import h5py
import numpy as np
import os
import time
import datetime
from xml.etree import ElementTree as ET
import re
from IO.Kafka import getDetectorData
from  Utils.Histogram import Hist2D
from signal import signal, SIGINT,SIGTERM,SIGTSTP
from sys import exit



def generatePixel(name):
    if name[:6]=="module":
        num=int(name[-3:-2])
        length=500
        if num==7 or num==8:
            length=300
        nx=int(length/5)
        ny=8
        idstart=int(name[6:]+"0001")
    else:
        num=int(name[-2:])
        idstart=int("201"+name[-2:]+"0001")
        nx=1
        ny=1
    pidNum=int(nx*ny)
    pixel_id=np.arange(idstart,idstart+pidNum)
    pixel_id=pixel_id.reshape((nx,ny))
    return pixel_id,idstart,pidNum


pubName = ["beamline","description","measurement_type","proposal_id","proton_charge","run_no","version","wavelength","start_time_tai","start_time_utc","end_time_tai","end_time_utc","instrument_bank_definition","instrument_bank_file","instrument_name"]

def writeGroup(parent,name,ntype):
    grp = parent.create_group(name)
    grp.attrs["NX_class"] = ntype
    return grp

def mergeNumpyAll(dataset,chunknum):
    name=[]
    for k in range(len(dataset)):
        name.append(dataset["var"+str(k)])
    param = tuple(name)
    all_arr = np.concatenate(param)
    num = len(all_arr)
    if num > chunknum:
        res_arr = all_arr[-num+chunknum:]
        ans_arr = all_arr[:chunknum]
    else:
        res_arr = None
        ans_arr = all_arr
    return ans_arr, res_arr
    


class nxsWrite():
    def __init__(self,rawPath,nxsPath,tofbins):
        #path include runno
        self.rawPath=rawPath
        self.nxsPath=nxsPath
        self.tofbins=tofbins
        step=int(40000/tofbins)
        self.time_of_flight = np.arange(0,40000+step,step)
    
    def handler(self, signal_received, frame):
        # Handle any cleanup here
        self.nxs.close()
        print('SIGINT or CTRL-C detected. Exiting gracefully')
        exit(0)
        

    def createBasicFramework(self):
        root = writeGroup(self.nxs,"csns","NXentry")
        grp1 = writeGroup(root,"instrument","NXinstrument")
        grp2 = writeGroup(root,"histogram_data","NXcollection")
        grp3 = writeGroup(root,"event_data","NXcollection")
        grp4 = writeGroup(root,"logs","NXcollection")
        grp5 = writeGroup(root,"process","NXprocess")
        grp6 = writeGroup(root,"user","NXuser")

    def createModule(self,module):
        son1 = writeGroup(self.nxs["/csns/instrument"],module,"NXdetector")
        son2 = writeGroup(self.nxs["/csns/histogram_data"],module,"NXdata")
        son3 = writeGroup(self.nxs["/csns/event_data"],module,"NXdata")


    def createEventDataPointer(self,module,chunkSize,compression):
        self.ds_tof=self.nxs["/csns/instrument/"+module].create_dataset("event_time_of_flight", (chunkSize, ), maxshape=(None, ), dtype='float32', compression=compression, chunks=(chunkSize, ))
        self.ds_pid=self.nxs["/csns/instrument/"+module].create_dataset("event_pixel_id", (chunkSize, ), maxshape=(None, ), dtype='int64', compression=compression, chunks=(chunkSize, ))
        self.ds_pulse=self.nxs["/csns/instrument/"+module].create_dataset("event_pulse_time", (chunkSize, ), maxshape=(None, ), dtype='int64', compression=compression, chunks=(chunkSize, ))
    
    def createHistDataPointer(self,module,pidNums):
        self.ds_hist=self.nxs["/csns/instrument/"+module].create_dataset("histogram_data", (pidNums,self.tofbins), dtype='int32')
    
    def appendEvtData(self,num,pixelId,tof,pulseTime,chunkSize):
        if num == 0:
            self.ds_tof[:] = tof
            self.ds_pid[:] = pixelId
            self.ds_pulse[:] = pulseTime
        else:
            npsize=chunkSize
            if len(tof)<chunkSize:
                npsize=len(tof)
            self.ds_tof.resize(self.ds_tof.shape[0] + npsize, axis=0)
            self.ds_pid.resize(self.ds_pid.shape[0] + npsize, axis=0)
            self.ds_pulse.resize(self.ds_pulse.shape[0] + npsize, axis=0)
            self.ds_tof[-npsize:]=tof
            self.ds_pid[-npsize:]=pixelId
            self.ds_pulse[-npsize:]=pulseTime    

    def writeFileAttrs(self,fName):
        self.nxs.attrs["default"]="entry"
        self.nxs.attrs["file_name"]=fName
        timestamp = str(datetime.datetime.now())
        self.nxs.attrs["file_time"]=timestamp
        self.nxs.attrs["NeXus_Version"]="4.3.0"
        self.nxs.attrs["HDF5_Version"]="1.8.12"

    def writePublicData(self,dic):
        for key in dic: 
            self.nxs["/csns"].create_dataset(key, data = dic[key])
    
    def writeEventLink(self,module):
        self.nxs["/csns/event_data/"+module]["event_time_of_flight"]=h5py.SoftLink("/csns/instrument/"+module+"/event_time_of_flight")
        self.nxs["/csns/event_data/"+module]["event_pixel_id"]=h5py.SoftLink("/csns/instrument/"+module+"/event_pixel_id")
        self.nxs["/csns/event_data/"+module]["event_pulse_time"]=h5py.SoftLink("/csns/instrument/"+module+"/event_pulse_time")  
        
        
    def writeHistogramLink(self,module):
        self.nxs["/csns/histogram_data/"+module]["histogram_data"]=h5py.SoftLink("/csns/instrument/"+module+"/histogram_data")
        self.nxs["/csns/histogram_data/"+module]["time_of_flight"]=h5py.SoftLink("/csns/instrument/"+module+"/time_of_flight")
        
    
    def writePidAndTof(self,module,pixel):
        grp=self.nxs["/csns/instrument/"+module]
        size=pixel.shape
        grp.create_dataset("pixel_id",size,dtype='int64', data=pixel)
        size=self.time_of_flight.shape
        grp.create_dataset("time_of_flight",size,dtype='int32', data=self.time_of_flight)


    def fillHist(self,pidNums,idstart,consumer,startT0,endT0):
        hist = Hist2D(pidNums,self.tofbins, [[-0.5 + idstart, idstart+pidNums+0.5],[0-0.5, 40000-0.5]])
        for msg in consumer:
            pulseId, tofs, pids = getDetectorData(msg)
            if pulseId>=startT0:
                hist.fill(pids,tofs)
                histVal=hist.hist
                self.ds_hist[:,:]=histVal
                if pulseId>=endT0:
                    break
        consumer.close()
        self.nxs.close()
    
    def recHistModule(self,name,consumer,t1,t2):
        # for both module and monitor
        filename=self.rawPath+"/"+name+"_hist.nxs"
        self.nxs=h5py.File(filename,"w")
        self.createBasicFramework()
        self.writeFileAttrs(filename)
        signal(SIGINT,self.handler)
        signal(SIGTERM,self.handler)
        signal(SIGTSTP,self.handler)
        self.createModule(name)
        pixel_id,idstart,pidNums = generatePixel(name)
        self.writePidAndTof(name,pixel_id)
        self.createHistDataPointer(name,pidNums)
        self.fillHist(pidNums,idstart,consumer,t1,t2)


    def runOnlyEvt(self,prepath, module,chunkSize,consumer,startT0,endT0): 
        self.recSingleModule(prepath, module,chunkSize)
        self.fillEvt(chunkSize,consumer,startT0,endT0)

    def fillEvt(self,chunkSize,consumer,startT0,endT0):
        tmp_tof = {}
        tmp_pid = {}
        tmp_pulse = {}
        evtNum = 0
        startNum = 0
        loopNum = 0
        for msg in consumer:
            pulseId, tofs, pids = getDetectorData(msg)
            if pulseId>=startT0:
                tmp_tof["var"+str(startNum)] = tofs
                tmp_pid["var"+str(startNum)] = pids
                evtNum += pids.size
                pulseID = []
                for i in range(pids.size):
                    pulseID.append(pulseId)
                tmp_pulse["var"+str(startNum)]=pulseID
                startNum += 1
                if evtNum >=chunkSize:
                    _tof, res_tof = mergeNumpyAll(tmp_tof,chunkSize)
                    _pid, res_pid = mergeNumpyAll(tmp_pid,chunkSize)
                    _pulse,res_pulse = mergeNumpyAll(tmp_pulse,chunkSize)
                    self.appendEvtData(loopNum,_pid,_tof,_pulse,chunkSize)
                    loopNum += 1
                    tmp_tof={}
                    tmp_pid={}
                    tmp_pulse={}
                    startNum=0
                    evtNum=0
                    if pulseId>=endT0:
                        print("larger: ",pulseId,module)
                        if res_tof is None:
                            pass
                        else:
                            self.appendEvtData(loopNum,res_pid,res_tof,res_pulse,chunkSize)
                        consumer.close()
                        break
                
                    if res_tof is None:
                        pass
                    else:
                        tmp_tof["var"+str(startNum)]=res_tof
                        tmp_pid["var"+str(startNum)]=res_pid
                        tmp_pulse["var"+str(startNum)]=res_pulse
                        evtNum += tmp_tof["var"+str(startNum)].shape[0]
                        startNum +=1
                #print("finish one pulse: ",time.time()-be)
                if pulseId>=endT0:
                    consumer.close()
                    break
        self.nxs.close()
        print("finish")


    def mergeModules(self,pubInfo,moduleList,dataMode):
        #for both modules and monitor
        filename=self.nxsPath+"/detector.nxs"
        self.nxs=h5py.File(filename,"w")
        self.createBasicFramework()
        self.writeFileAttrs(filename)
        self.writePublicData(pubInfo)
        be=time.time()
        for name in moduleList:
            be1=time.time()
            print("start with ",name)
            self.createModule(name)
            if dataMode == "hist":
                f0=h5py.File(self.rawPath+"/"+name+"_hist.nxs","r")
                hist = f0["/csns/instrument/"+name+"/histogram_data"]
                pixel_id=f0["/csns/instrument/"+name+"/pixel_id"]
                time_of_flight=f0["/csns/instrument/"+name+"/time_of_flight"]
                self.nxs["/csns/instrument/"+name].create_dataset("histogram_data",hist.shape,dtype=hist.dtype,data=hist)
                self.nxs["/csns/instrument/"+name].create_dataset("time_of_flight",time_of_flight.shape,dtype=time_of_flight.dtype,data=time_of_flight)
                self.nxs["/csns/instrument/"+name].create_dataset("pixel_id",pixel_id.shape,dtype=pixel_id.dtype,data=pixel_id)
                self.writeHistogramLink(name)
                f0.close()
            else:
                if name[:6]=="module":
                    f0=h5py.File(self.rawPath+"/"+name+".nxs","r")
                    tof = f0["/csns/instrument/"+name+"/event_time_of_flight"]
                    pid = f0["/csns/instrument/"+name+"/event_pixel_id"]
                    pulse = f0["/csns/instrument/"+name+"/event_pulse_time"]
                    self.nxs["/csns/instrument/"+name].create_dataset("event_time_of_flight",tof.shape,dtype=tof.dtype,compression='gzip',data=tof)
                    self.nxs["/csns/instrument/"+name].create_dataset("event_pulse_time",pulse.shape,dtype=pulse.dtype,compression='gzip',data=pulse)
                    self.nxs["/csns/instrument/"+name].create_dataset("event_pixel_id",pid.shape,dtype=pid.dtype,compression='gzip',data=pid)
                    hist = f0["/csns/instrument/"+name+"/histogram_data"]
                    pixel_id=f0["/csns/instrument/"+name+"/pixel_id"]
                    time_of_flight=f0["/csns/instrument/"+name+"/time_of_flight"]
                    self.nxs["/csns/instrument/"+name].create_dataset("histogram_data",hist.shape,dtype=hist.dtype,data=hist)
                    self.nxs["/csns/instrument/"+name].create_dataset("time_of_flight",time_of_flight.shape,dtype=time_of_flight.dtype,data=time_of_flight)
                    self.nxs["/csns/instrument/"+name].create_dataset("pixel_id",pixel_id.shape,dtype=pixel_id.dtype,data=pixel_id)
                    self.writeHistogramLink(name)
                    self.writeEventLink(name)            
                    f0.close()

        self.nxs.close()
        print("finish, ",time.time()-be)
