import h5py
import numpy as np
import datetime
import baseComp
import Data.Detector.EventData
import time
import sys
#from IO.RedisHelper import RedisHelper
import json

from signal import signal, SIGINT,SIGTERM,SIGTSTP
from sys import exit

def mergeNumpyAll(dataset,chunknum):
    name=[]
    for k in range(len(dataset)):
        name.append(dataset["var"+str(k)])
    param = tuple(name)
    all_arr = np.concatenate(param)
    num = len(all_arr)
    if num > chunknum:
        res_arr = all_arr[-num+chunknum:]
        ans_arr = all_arr[:chunknum]
    else:
        res_arr = None
        ans_arr = all_arr
    return ans_arr, res_arr

def getRedis(confFile):
    with open(confFile,"r") as json_file:
        conf=json.load(json_file)
    ip_port=conf["ip_port"]
    passwd=conf["passwd"]
    rds=RedisHelper(ip_port,passwd,10)
    return rds

cprType='gzip'

class dataSvc():
    def __init__(self, consumer,xmlpath):
        self.chunkSize = 10000
        self.consumer = consumer
        self.xmlpath=xmlpath
        #self.filename = filename
        #self.startT0=startT0
        #self.endT0=endT0
        #self.mnum=mNum
        #self.nxsFile = h5py.File(filename, "w")
        #print ("start")
    

    def handler(self, signal_received, frame):
        # Handle any cleanup here
        self.nxs.close()
        print('SIGINT or CTRL-C detected. Exiting gracefully')
        exit(0)
    
    def createNXS(self,filename):
        self.nxs=h5py.File(filename,"w")
        nf=baseComp.nxsWrite(self.nxs)
        nf.createBasicFramework()
        nf.writeFileAttrs(filename)
        self.nxs.close()
    
    def writePublic(self,nfObj):
        #get exp conf from redis
        expInfo={"beamline":"BL15",
                "run_no":"RUN0000001",
                "measurement_type":"test",
                "start_time_utc":"2021-01-01T00:00:00",
                "end_time_utc":"2021-01-01T00:30:00",
                "version":"1.0"}
        nfObj.writePublicData(expInfo)

    def writeSingleModule(self,nfObj,module):
        mdObj=baseComp.getMetaData()
        self.ds_pid,self.ds_tof,self.ds_pulse = nfObj.createEvent(module,10000,cprType)
        #hist data
        pidArr = mdObj.getPixelId(module,self.xmlpath)
        nfObj.writePid(module,pidArr,cprType)
        tofArr = mdObj.getTof(0,40000,16)
        nfObj.writeTof(module,tofArr,cprType)
        #histArr = mdObj.getHistData()

    def appendEvtData(self,num,pixelId,tof,pulseTime):
        if num == 0:
            self.ds_tof[:] = tof
            self.ds_pid[:] = pixelId
            self.ds_pulse[:] = pulseTime
        else:
            npsize=self.chunkSize
            if len(tof)<self.chunkSize:
                npsize=len(tof)
            self.ds_tof.resize(self.ds_tof.shape[0] + npsize, axis=0)
            self.ds_pid.resize(self.ds_pid.shape[0] + npsize, axis=0)
            self.ds_pulse.resize(self.ds_pulse.shape[0] + npsize, axis=0)
            self.ds_tof[-npsize:]=tof
            self.ds_pid[-npsize:]=pixelId
            #print(npsize, len(pulseTime),len(tof),len(pixelId))
            self.ds_pulse[-npsize:]=pulseTime

    def getDataFromKafka(self,t1,t2):
        tmp_tof = {}
        tmp_pid = {}
        tmp_pulse = {}
        evtNum = 0
        startNum = 0
        loopNum = 0
        be0=time.time()
        for msg in self.consumer:
            be1=time.time()
            if True:
                _data = Data.Detector.EventData.EventData.GetRootAsEventData(msg.value, 0)
                be2=time.time()
                if _data.PulseId()>=t1:
                    tmp_tof["var"+str(startNum)] = _data.TofAsNumpy()
                    tmp_pid["var"+str(startNum)] = _data.PosAsNumpy()
                    #tmp_pulse["var"+str(startNum)] = _data.PulseIDAsNumpy()
                    evtNum += _data.TofLength()
                    pulseID = []
                    #print("pulseId:",_data.PulseId())
                    for i in range(_data.TofLength()):
                        pulseID.append(_data.PulseId())
                    tmp_pulse["var"+str(startNum)]=pulseID
                    print (pulseID, evtNum)
                    startNum += 1
                    if evtNum >=self.chunkSize:
                        _tof, res_tof = mergeNumpyAll(tmp_tof,self.chunkSize)
                        _pid, res_pid = mergeNumpyAll(tmp_pid,self.chunkSize)
                        _pulse,res_pulse = mergeNumpyAll(tmp_pulse,self.chunkSize)
                        #print (loopNum,_tof)
                        self.appendEvtData(loopNum,_pid,_tof,_pulse)
                        loopNum += 1
                        tmp_tof={}
                        tmp_pid={}
                        tmp_pulse={}
                        startNum=0
                        evtNum=0
                        if _data.PulseId()>=t2:
                            if res_tof is None:
                                pass
                            else:
                                print("larger: ",_data.PulseId())
                                self.appendEvtData(loopNum,res_pid,res_tof,res_pulse)         
                            self.consumer.close()
                            break
                
                        if res_tof is None:
                            pass
                        else:
                            tmp_tof["var"+str(startNum)]=res_tof
                            tmp_pid["var"+str(startNum)]=res_pid
                            tmp_pulse["var"+str(startNum)]=res_pulse
                            evtNum += tmp_tof["var"+str(startNum)].shape[0]
                            startNum +=1
                        #print (str(time.time()-be) , "seconds finish one pulse",)
            #except:
            #    print("error happened!")
            #    h5file.close()
            #    self.consumer.close()
            #    sys.exit()
            #if loopNum>100:
            #    self.consumer.close()
            #    break
        print (str(time.time()-be0) , "seconds for all evt")
                

    def recEvtNxs(self,filename,module,t1,t2):
        self.createNXS(filename)
        self.nxs=h5py.File(filename,"r+")
        nf=baseComp.nxsWrite(self.nxs)
        signal(SIGINT,self.handler)
        signal(SIGTERM,self.handler)
        signal(SIGTSTP,self.handler)
        #signal(SIGKILL,self.handler)

        nf.createModule(module) 
        self.writeSingleModule(nf,module)
        self.getDataFromKafka(t1,t2)
        self.nxs.close()
        print ("finish!!!")
