import h5py
import numpy as np
import os
import time
import datetime
from xml.etree import ElementTree as ET
import re
from IO.Kafka import getDetectorData
from  Utils.Histogram import Hist2D
from signal import signal, SIGINT,SIGTERM,SIGTSTP
from sys import exit

pubName = ["beamline","description","measurement_type","proposal_id","proton_charge","run_no","version","wavelength","start_time_tai","start_time_utc","end_time_tai","end_time_utc","instrument_bank_definition","instrument_bank_file","instrument_name"]

def writeGroup(parent,name,ntype):
    grp = parent.create_group(name)
    grp.attrs["NX_class"] = ntype
    return grp

def mergeNumpyAll(dataset,chunknum):
    name=[]
    for k in range(len(dataset)):
        name.append(dataset["var"+str(k)])
    param = tuple(name)
    all_arr = np.concatenate(param)
    num = len(all_arr)
    if num > chunknum:
        res_arr = all_arr[-num+chunknum:]
        ans_arr = all_arr[:chunknum]
    else:
        res_arr = None
        ans_arr = all_arr
    return ans_arr, res_arr
    


class nxsWrite():
    def __init__(self):
        pass
    
    def handler(self, signal_received, frame):
        # Handle any cleanup here
        self.nxs.close()
        print('SIGINT or CTRL-C detected. Exiting gracefully')
        exit(0)
        

    def createBasicFramework(self):
        root = writeGroup(self.nxs,"csns","NXentry")
        grp1 = writeGroup(root,"instrument","NXinstrument")
        grp2 = writeGroup(root,"histogram_data","NXcollection")
        grp3 = writeGroup(root,"event_data","NXcollection")
        grp4 = writeGroup(root,"logs","NXcollection")
        grp5 = writeGroup(root,"process","NXprocess")
        grp6 = writeGroup(root,"user","NXuser")

    def createModule(self,module):
        son1 = writeGroup(self.nxs["/csns/instrument"],module,"NXdetector")
        son2 = writeGroup(self.nxs["/csns/histogram_data"],module,"NXdata")
        son3 = writeGroup(self.nxs["/csns/event_data"],module,"NXdata")


    def createDataPointer(self,module,chunkSize,compression,nx,ny):
        self.ds_tof=self.nxs["/csns/instrument/"+module].create_dataset("event_time_of_flight", (chunkSize, ), maxshape=(None, ), dtype='float32', compression=compression, chunks=(chunkSize, ))
        self.ds_pid=self.nxs["/csns/instrument/"+module].create_dataset("event_pixel_id", (chunkSize, ), maxshape=(None, ), dtype='int64', compression=compression, chunks=(chunkSize, ))
        self.ds_pulse=self.nxs["/csns/instrument/"+module].create_dataset("event_pulse_time", (chunkSize, ), maxshape=(None, ), dtype='int64', compression=compression, chunks=(chunkSize, ))
        self.ds_hist=self.nxs["/csns/instrument/"+module].create_dataset("histogram_data", (nx,ny), dtype='int32')
    
    def appendEvtData(self,num,pixelId,tof,pulseTime,chunkSize):
        if num == 0:
            self.ds_tof[:] = tof
            self.ds_pid[:] = pixelId
            self.ds_pulse[:] = pulseTime
        else:
            npsize=chunkSize
            if len(tof)<chunkSize:
                npsize=len(tof)
            self.ds_tof.resize(self.ds_tof.shape[0] + npsize, axis=0)
            self.ds_pid.resize(self.ds_pid.shape[0] + npsize, axis=0)
            self.ds_pulse.resize(self.ds_pulse.shape[0] + npsize, axis=0)
            self.ds_tof[-npsize:]=tof
            self.ds_pid[-npsize:]=pixelId
            #print(npsize, len(pulseTime),len(tof),len(pixelId))
            self.ds_pulse[-npsize:]=pulseTime    

    def writeFileAttrs(self,fName):
        self.nxs.attrs["default"]="entry"
        self.nxs.attrs["file_name"]=fName
        timestamp = str(datetime.datetime.now())
        self.nxs.attrs["file_time"]=timestamp
        self.nxs.attrs["NeXus_Version"]="4.3.0"
        self.nxs.attrs["HDF5_Version"]="1.8.12"

    def writePublicData(self,dic):
        for key in dic: 
            self.nxs["/csns"].create_dataset(key, data = dic[key])
    
    
    def writeEventLink(self,module):
        self.nxs["/csns/event_data/"+module]["event_time_of_flight"]=h5py.SoftLink("/csns/instrument/"+module+"/event_time_of_flight")
        self.nxs["/csns/event_data/"+module]["event_pixel_id"]=h5py.SoftLink("/csns/instrument/"+module+"/event_pixel_id")
        self.nxs["/csns/event_data/"+module]["event_pulse_time"]=h5py.SoftLink("/csns/instrument/"+module+"/event_pulse_time")  
        
        
    def writeHistogramLink(self,module):
        self.nxs["/csns/histogram_data/"+module]["histogram_data"]=h5py.SoftLink("/csns/instrument/"+module+"/histogram_data")
        self.nxs["/csns/histogram_data/"+module]["time_of_flight"]=h5py.SoftLink("/csns/instrument/"+module+"/time_of_flight")
        
    
    def writePidAndTof(self,module,pixel,tof):
        grp=self.nxs["/csns/instrument/"+module]
        size=pixel.shape
        grp.create_dataset("pixel_id",size,dtype='int64', data=pixel)
        size=tof.shape
        grp.create_dataset("time_of_flight",size,dtype='int32', data=tof)

    def recSingleMonitor(self,prepath,monitor,pixel_id,time_of_flight,consumer,startT0, endT0):
        filename=prepath+"/"+monitor+".nxs"
        self.nxs=h5py.File(filename,"w")
        self.createBasicFramework()
        self.writeFileAttrs(filename)
        signal(SIGINT,self.handler)
        signal(SIGTERM,self.handler)
        signal(SIGTSTP,self.handler)
        self.createModule(monitor)
        self.writePidAndTof(monitor,pixel_id,time_of_flight)
        idstart = pixel_id.flatten()[0]
        pidNums = pixel_id.size
        tofBins = time_of_flight.size-1
        self.createDataPointer(monitor,100,"gzip",pidNums,tofBins)
        hist = Hist2D(pidNums,tofBins, [[-0.5 + idstart, idstart+pidNums+0.5],[0-0.5, 40000-0.5]])
        for msg in consumer:
            pulseId, tofs, pids = getDetectorData(msg)
            #print(pids)
            if pulseId>=startT0:
                hist.fill(pids,tofs)
                histVal=hist.hist
                self.ds_hist[:,:]=histVal
            if pulseId>=endT0:
                break
        consumer.close()
        self.nxs.close()
        

    def recSingleModule(self,prepath, module,pixel_id,time_of_flight,chunkSize,consumer,startT0,endT0):
        filename=prepath+"/"+module+".nxs"
        self.nxs=h5py.File(filename,"w")
        self.createBasicFramework()
        self.writeFileAttrs(filename)
        signal(SIGINT,self.handler)
        signal(SIGTERM,self.handler)
        signal(SIGTSTP,self.handler)
        
        self.createModule(module)
        self.writePidAndTof(module,pixel_id,time_of_flight)
        idstart = pixel_id.flatten()[0]
        pidNums = pixel_id.size
        tofBins = time_of_flight.size-1
        self.createDataPointer(module,chunkSize,"gzip",pidNums,tofBins)

        hist = Hist2D(pidNums,tofBins, [[-0.5 + idstart, idstart+pidNums+0.5],[0-0.5, 40000-0.5]])
        tmp_tof = {}
        tmp_pid = {}
        tmp_pulse = {}
        evtNum = 0
        startNum = 0
        loopNum = 0
        print("start evt ")
        for msg in consumer:
            pulseId, tofs, pids = getDetectorData(msg)
            print("pulseId: ",pulseId)
            if pulseId>=startT0:
                hist.fill(pids,tofs)
                tmp_tof["var"+str(startNum)] = tofs
                tmp_pid["var"+str(startNum)] = pids
                evtNum += pids.size
                pulseID = []
                for i in range(pids.size):
                    pulseID.append(pulseId)
                tmp_pulse["var"+str(startNum)]=pulseID
                print(pulseId,evtNum)
                startNum += 1
                histVal=hist.hist
                self.ds_hist[:,:]=histVal
                if evtNum >=chunkSize:
                    _tof, res_tof = mergeNumpyAll(tmp_tof,chunkSize)
                    _pid, res_pid = mergeNumpyAll(tmp_pid,chunkSize)
                    _pulse,res_pulse = mergeNumpyAll(tmp_pulse,chunkSize)
                    #print (loopNum,_tof)
                    self.appendEvtData(loopNum,_pid,_tof,_pulse,chunkSize)
                    loopNum += 1
                    tmp_tof={}
                    tmp_pid={}
                    tmp_pulse={}
                    startNum=0
                    evtNum=0
                    if pulseId>=endT0:
                        if res_tof is None:
                            pass
                        else:
                            print("larger: ",pulseId,module)
                            self.appendEvtData(loopNum,res_pid,res_tof,res_pulse,chunkSize)
                        consumer.close()
                        break
                
                    if res_tof is None:
                        pass
                    else:
                        tmp_tof["var"+str(startNum)]=res_tof
                        tmp_pid["var"+str(startNum)]=res_pid
                        tmp_pulse["var"+str(startNum)]=res_pulse
                        evtNum += tmp_tof["var"+str(startNum)].shape[0]
                        startNum +=1
        
        self.nxs.close()
        print("finish")

    def mergeModules(self,prepath,pubInfo,moduleList,monitorList):
        filename=prepath+"/detector.nxs"
        self.nxs=h5py.File(filename,"w")
        self.createBasicFramework()
        self.writeFileAttrs(filename)
        self.writePublicData(pubInfo)
        be=time.time()
        for name in moduleList:
            be1=time.time()
            print("start with ",name)
            self.createModule(name)
            f0=h5py.File(prepath+"/"+name+".nxs","r")
            tof = f0["/csns/instrument/"+name+"/event_time_of_flight"]
            pid = f0["/csns/instrument/"+name+"/event_pixel_id"]
            pulse = f0["/csns/instrument/"+name+"/event_pulse_time"]
            hist = f0["/csns/instrument/"+name+"/histogram_data"]
            pixel_id=f0["/csns/instrument/"+name+"/pixel_id"]
            time_of_flight=f0["/csns/instrument/"+name+"/time_of_flight"]
            
            self.nxs["/csns/instrument/"+name].create_dataset("event_time_of_flight",tof.shape,dtype=tof.dtype,compression='gzip',data=tof)
            self.nxs["/csns/instrument/"+name].create_dataset("event_pulse_time",pulse.shape,dtype=pulse.dtype,compression='gzip',data=pulse)
            self.nxs["/csns/instrument/"+name].create_dataset("event_pixel_id",pid.shape,dtype=pid.dtype,compression='gzip',data=pid)
            self.nxs["/csns/instrument/"+name].create_dataset("histogram_data",hist.shape,dtype=hist.dtype,data=hist)
            self.nxs["/csns/instrument/"+name].create_dataset("time_of_flight",time_of_flight.shape,dtype=time_of_flight.dtype,data=time_of_flight)
            self.nxs["/csns/instrument/"+name].create_dataset("pixel_id",pixel_id.shape,dtype=pixel_id.dtype,data=pixel_id)
            self.writeEventLink(name)
            self.writeHistogramLink(name)
            f0.close()
            print("time: ",time.time()-be1)
        for name in monitorList:
            print("start with ",name)
            self.createModule(name)
            f0=h5py.File(prepath+"/"+name+".nxs","r")
            hist = f0["/csns/instrument/"+name+"/histogram_data"]
            pixel_id=f0["/csns/instrument/"+name+"/pixel_id"]
            time_of_flight=f0["/csns/instrument/"+name+"/time_of_flight"]

            self.nxs["/csns/instrument/"+name].create_dataset("histogram_data",hist.shape,dtype=hist.dtype,data=hist)
            self.nxs["/csns/instrument/"+name].create_dataset("time_of_flight",time_of_flight.shape,dtype=time_of_flight.dtype,data=time_of_flight)
            self.nxs["/csns/instrument/"+name].create_dataset("pixel_id",pixel_id.shape,dtype=pixel_id.dtype,data=pixel_id)
            self.writeHistogramLink(name)
            f0.close()
            print("time: ",time.time()-be1)

        self.nxs.close()
        print("finish, ",time.time()-be)
