from Utils.NexusData import CSNSNexus
from  Utils.Histogram import Hist2D
from signal import signal, SIGINT,SIGTERM,SIGTSTP
import numpy as np
import h5py
import math

def mergeNumpyAll(dataset,chunknum):
    name=[]
    for k in range(len(dataset)):
        name.append(dataset["var"+str(k)])
    param = tuple(name)
    all_arr = np.concatenate(param)
    num = len(all_arr)
    if num > chunknum:
        res_arr = all_arr[-num+chunknum:]
        ans_arr = all_arr[:chunknum]
    else:
        res_arr = None
        ans_arr = all_arr
    return ans_arr, res_arr

def generatePixel(name):
    if name[:6]=="module":
        num=int(name[-3:-2])
        length=500
        if num==7 or num==8:
            length=300
        nx=int(length/5)
        ny=8
        idstart=int(name[6:]+"0001")
    else:
        num=int(name[-2:])
        idstart=int("201"+name[-2:]+"0001")
        nx=1
        ny=1
    pidNum=int(nx*ny)
    pixel_id=np.arange(idstart,idstart+pidNum)
    pixel_id=pixel_id.reshape((nx,ny))
    return pixel_id,idstart,pidNum

def copyNxsData(oldfile,newfile,path):
    tmp=path.rsplit("/",1)
    val = oldfile[path]
    newfile[tmp[0]].create_dataset(tmp[1],val.shape,dtype=val.dtype,data=val)


class recDetector():
    def __init__(self,moduleList,rawPath,nxsPath):
        self.moduleList=moduleList
        self.rawPath=rawPath
        self.nxsFile=nxsPath+"/detector.nxs"
        self.nxs=h5py.File(self.nxsFile,"w")

    def writeWholeNexus(self,pubInfos,moduleList):
        wnxs=CSNSNexus()
        wnxs.writeFileAttrs(self.nxs,self.nxsFile)
        wnxs.createBasicFramework(self.nxs)
        wnxs.writePublic(self.nxs,pubInfos)
        prepath="/csns/instrument/"
        for name in moduleList:
            pathList=[prepath+name+"/histogram_data",prepath+name+"/pixel_id",prepath+name+"/time_of_flight"]
            wnxs.createModule(self.nxs,name)
            fname=self.rawPath+"/"+name+"_hist.nxs"
            f0=h5py.File(fname,"r")
            for path in pathList:
                copyNxsData(f0,self.nxs,path)
            wnxs.writeHistLink(self.nxs,name)
            f0.close()
        #write pc
        fname=self.rawPath+"/pclog.nxs"
        f0=h5py.File(fname,"r")
        path="/csns/logs/proton_charge"
        copyNxsData(f0,self.nxs,path)
        path="/csns/logs/utc_tai"
        copyNxsData(f0,self.nxs,path)
        path="/csns/proton_charge"
        copyNxsData(f0,self.nxs,path)
        f0.close()
        self.nxs.close()


class recHist():
    def __init__(self,name,rawPath,tofbins,evtStep):
        self.rawPath=rawPath
        self.evtStep=evtStep
        self.module=name
        nxsFile=self.rawPath+"/"+name+"_hist.nxs"
        self.nxs=h5py.File(nxsFile,"w")
        wnxs=CSNSNexus()
        wnxs.writeFileAttrs(self.nxs,nxsFile)
        wnxs.createBasicFramework(self.nxs)
        wnxs.createModule(self.nxs,name)
        pixels,idstart,pidNums=generatePixel(name)
        step = int(40000/tofbins)
        tofs = np.arange(0,40000+step,step)

        wnxs.writeModuleInfo(self.nxs,name,pixels,tofs)
        self.ds_hist = wnxs.createHistDataPointer(self.nxs, name,pidNums,tofbins)
        self.histogram = Hist2D(pidNums,tofbins, [[-0.5 + idstart, idstart+pidNums+0.5],[0-0.5, 40000-0.5]])

    def handler(self, signal_received, frame):
        # Handle any cleanup here
        self.nxs.close()
        print('SIGINT or CTRL-C detected. Exiting gracefully')
        exit(0)
    
    def closeNXS(self):
        self.nxs.close()


    def fillHist(self,binList):
        signal(SIGINT,self.handler)
        signal(SIGTERM,self.handler)
        signal(SIGTSTP,self.handler)
        for i in binList:
            fname=self.rawPath+"/"+self.module+"_evt_"+str(i)+".nxs"
            f0=h5py.File(fname,"r")
            tofs = f0["/csns/instrument/"+self.module+"/event_time_of_flight"]
            pids = f0["/csns/instrument/"+self.module+"/event_pixel_id"]
            tot=len(pids)
            bins=math.ceil(tot/self.evtStep)
            for i in range(bins):
                start=int(i*self.evtStep)
                end=int((i+1)*self.evtStep)
                if end>tot:
                    end=tot
                x=pids[start:end]
                y=tofs[start:end]
                self.histogram.fill(x,y)
            #print("fill hist from ",n," ",time.time()-be2)
            f0.close()
        histVal=self.histogram.hist
        self.ds_hist[:,:]=histVal

class recLog():
    def __init__(self,nxsFile,tai, pc):
        self.nxs=h5py.File(nxsFile,"w")
        wnxs=CSNSNexus()
        wnxs.writeFileAttrs(self.nxs,nxsFile)
        wnxs.createBasicFramework(self.nxs)
        wnxs.writePCLog(self.nxs,tai,pc)
        tot=pc.sum()
        wnxs.writePCPublic(self.nxs,tot)

    def closeNXS(self):
        self.nxs.close()


class recEvt():
    def __init__(self,name,nxsFile,chunkSize,compression):
        self.nxs=h5py.File(nxsFile,"w")
        wnxs=CSNSNexus()
        wnxs.writeFileAttrs(self.nxs,nxsFile)
        wnxs.createBasicFramework(self.nxs)
        wnxs.createModule(self.nxs,name)
        self.ds_pid, self.ds_tof, self.ds_pulse = wnxs.createEvtDataPointer(self.nxs, name,chunkSize,compression)

    def handler(self, signal_received, frame):
        # Handle any cleanup here
        self.nxs.close()
        print('SIGINT or CTRL-C detected. Exiting gracefully')
        exit(0)

    def closeNXS(self):
        self.nxs.close()

    def initEvtParams(self):
        self.tmp_tof = {}
        self.tmp_pid = {}
        self.tmp_pulse = {}
        self.evtNum = 0
        self.startNum = 0
        self.loopNum = 0

    def appendEvtData(self,num,pixelId,tof,pulseTime,chunkSize):
        if num == 0:
            self.ds_tof[:] = tof
            self.ds_pid[:] = pixelId
            self.ds_pulse[:] = pulseTime
        else:
            npsize=chunkSize
            if len(tof)<chunkSize:
                npsize=len(tof)
            self.ds_tof.resize(self.ds_tof.shape[0] + npsize, axis=0)
            self.ds_pid.resize(self.ds_pid.shape[0] + npsize, axis=0)
            self.ds_pulse.resize(self.ds_pulse.shape[0] + npsize, axis=0)
            self.ds_tof[-npsize:]=tof
            self.ds_pid[-npsize:]=pixelId
            self.ds_pulse[-npsize:]=pulseTime


    def fillEvt(self,pulseId,tofs,pids,startT0,endT0,chunksize):
        judge=False
        signal(SIGINT,self.handler)
        signal(SIGTERM,self.handler)
        signal(SIGTSTP,self.handler)
        #be=time.time()
        if pulseId>=startT0:
            self.tmp_tof["var"+str(self.startNum)] = tofs
            self.tmp_pid["var"+str(self.startNum)] = pids
            self.evtNum += pids.size
            pulseID = []
            for i in range(pids.size):
                pulseID.append(pulseId)
            self.tmp_pulse["var"+str(self.startNum)]=pulseID
            self.startNum += 1
            if self.evtNum >=chunksize:
                _tof, res_tof = mergeNumpyAll(self.tmp_tof,chunksize)
                _pid, res_pid = mergeNumpyAll(self.tmp_pid,chunksize)
                _pulse,res_pulse = mergeNumpyAll(self.tmp_pulse,chunksize)
                self.appendEvtData(self.loopNum,_pid,_tof,_pulse,chunksize)
                self.loopNum += 1
                self.tmp_tof={}
                self.tmp_pid={}
                self.tmp_pulse={}
                self.startNum=0
                self.evtNum=0
                if pulseId>=endT0:
                    if res_tof is None:
                        pass
                    else:
                        self.appendEvtData(self.loopNum,res_pid,res_tof,res_pulse,chunksize)
                    judge=True
                if res_tof is None:
                    pass
                else:
                    self.tmp_tof["var"+str(self.startNum)]=res_tof
                    self.tmp_pid["var"+str(self.startNum)]=res_pid
                    self.tmp_pulse["var"+str(self.startNum)]=res_pulse
                    self.evtNum += self.tmp_tof["var"+str(self.startNum)].shape[0]
                    self.startNum +=1
            #print(time.time()-be,"second for every pulse!")
            if pulseId>=endT0:
                judge=True
        return judge


