import h5py
import numpy as np
import datetime
import nxsComp
import Data.Detector.EventData
import time
class dataSvc():
    def __init__(self, consumer, filename):
        self.chunkSize = 100000
        self.consumer = consumer
        self.filename = filename
        self.nxsFile = h5py.File(filename, "w")
        print ("start")
    def mergeNumpyAll(self, dataset):
        name=[]
        for k in range(len(dataset)):
            name.append(dataset["var"+str(k)])
        param = tuple(name)
        all_arr = np.concatenate(param)
        num = len(all_arr)
        if num > self.chunkSize:
            res_arr = all_arr[-num+self.chunkSize:]
            ans_arr = all_arr[:self.chunkSize]
        else:
            res_arr = None
            ans_arr = all_arr
        #print (res_arr)
        return ans_arr, res_arr

    def writeFile(self):
        nxsComp.writeFileAttrs(self.nxsFile,self.filename)
        entry = nxsComp.writeEntry(self.nxsFile)
        self.ds_tof = entry.create_dataset("tof", (self.chunkSize, ), maxshape=(None, ), dtype='int64', compression='gzip', chunks=(self.chunkSize, ))
        self.ds_pid = entry.create_dataset("pid", (self.chunkSize, ), maxshape=(None, ), dtype='int64', compression='gzip', chunks=(self.chunkSize, ))
        self.ds_pulse = entry.create_dataset("pulse", (self.chunkSize, ), maxshape=(None, ), dtype='int64', compression='gzip', chunks=(self.chunkSize, ))
        
    def getData(self):
        be=time.time()
        tmp_tof = {}
        tmp_pid = {}
        tmp_pulse = {}
        evtNum = 0
        loopNum = 0
        startNum = 0
        for msg in self.consumer:
            _data = Data.Detector.EventData.EventData.GetRootAsEventData(msg.value, 0)
            tmp_tof["var"+str(startNum)] = _data.TofAsNumpy()
            tmp_pid["var"+str(startNum)] = _data.PosAsNumpy()
            #tmp_pulse["var"+str(startNum)] = _data.PulseIDAsNumpy()
            evtNum += _data.TofLength()
            pulseID = []
            #_data.PulseId()

            for i in range(_data.TofLength()):
                pulseID.append(_data.PulseId())
            tmp_pulse["var"+str(startNum)]=pulseID
            #print (pulseID, evtNum)
            startNum += 1
            if evtNum >=self.chunkSize:
                _tof, res_tof = self.mergeNumpyAll(tmp_tof)
                _pid, res_pid = self.mergeNumpyAll(tmp_pid)
                _pulse,res_pulse = self.mergeNumpyAll(tmp_pulse)
                print (loopNum)
                if loopNum == 0:
                    self.ds_tof[:] = _tof
                    self.ds_pid[:] = _pid
                    self.ds_pulse[:] = _pulse
                else:
                    self.ds_tof.resize(self.ds_tof.shape[0] + self.chunkSize, axis=0)
                    self.ds_pid.resize(self.ds_pid.shape[0] + self.chunkSize, axis=0)
                    self.ds_pulse.resize(self.ds_pulse.shape[0] + self.chunkSize, axis=0)
                    self.ds_tof[-self.chunkSize:]=_tof
                    self.ds_pid[-self.chunkSize:]=_pid
                    self.ds_pulse[-self.chunkSize:]=_pulse
                    print (self.ds_tof.shape)
                    print (str(time.time()-be) , "seconds")
                #print(_tof.shape)
                loopNum += 1
                tmp_tof={}
                tmp_pid={}
                startNum=0
                evtNum=0
                if res_tof is None:
                    pass
                else:
                    tmp_tof["var"+str(startNum)]=res_tof
                    tmp_pid["var"+str(startNum)]=res_pid
                    tmp_pulse["var"+str(startNum)]=res_pulse
                    evtNum += tmp_tof["var"+str(startNum)].shape[0]
                    startNum +=1
            #if loopNum > 1080000:
            #    break

    def recNxs(self):
        self.writeFile()
        self.getData()
        self.nxsFile.close()
        print ("finish!!!")
