import IOHelper
import time
import math
import numpy as np
from  Utils.Histogram import Hist2D


def getTOFrange(t1,t2,tstep,tofbins):
    startNum=0
    endNum=tofbins
    for i in range(tofbins):
        v=i*tstep
        if v>t1:
            startNum=i
            break
    for i in range(tofbins):
        v=i*tstep
        if v>t2:
            endNum=i
            break
    return startNum, endNum



class task_online():
    def __init__(self,conf):
        self.rds=IOHelper.getRedisHelper(conf["redis_data"])
        self.conf=conf
        self.onConf=conf["online_configure"]
        self.stime=self.onConf["sleepTime"]
        self.moduleGroupMatch=IOHelper.getModuleGroupMatch(conf)
        self.bankDict=IOHelper.getGroupDict(conf)
        print(self.moduleGroupMatch)
        print(self.bankDict)

    def initPid(self,module):
        bank=self.moduleGroupMatch[module]
        if bank[:7]=="monitor":
            pidNums=1
            idstart=1
        else:
            ny=self.onConf["detectorInfo"]["tubeNums"]
            xsize=self.onConf["detectorInfo"]["pixelSize"]
            nx=int(self.onConf["detectorInfo"]["tubeLen"][bank]/xsize)
            pidNums=int(nx*ny)
            idstart=int(module[-5:]+'0001')
        return pidNums, idstart

    def initTof(self,module):
        bank=self.moduleGroupMatch[module]
        if bank[:7]=="monitor":
            tofbins=self.onConf["tofInfo"]["tofbins_monitor"]
        else:
            tofbins=self.onConf["tofInfo"]["tofbins"]
        L1=self.onConf["sampleDist"]
        tmax=self.onConf["tofInfo"]["tofmax"]
        tmin=self.onConf["tofInfo"]["tofmin"]
        step=int((tmax-tmin)/tofbins)
        wmin=self.onConf["tofInfo"]["waveMin"]
        wmax=self.onConf["tofInfo"]["waveMax"]
        tmin=252.7*L1*wmin
        tmax=252.7*L1*wmax
        idx1,idx2=getTOFrange(tmin,tmax,step,tofbins)
        tofs=np.arange(idx1*step,idx2*step,step)
        return tofs,idx1,idx2,L1

    def initDspacing(self,module):
        bank=self.moduleGroupMatch[module]
        dDict={}
        if bank[:7]=="monitor":
            pass
        else:
            dRebin=self.onConf["dRebin"][bank]
            dRebin=dRebin.split(",")
            dDict["dmin"]=float(dRebin[0])
            dDict["dstep"]=float(dRebin[1])
            dDict["dmax"]=float(dRebin[2])
            tmp=int((float(dRebin[2])-float(dRebin[0]))/float(dRebin[1]))
            dDict["dbins"]=tmp
        return dDict

    def initPath(self, module):
        if module[:7]=="monitor":
            cxxPath=self.onConf["xxPath_m"]+"/monitor1/value"
            yllPath=self.onConf["llPath"]+"/monitor1"
        else:
            cxxPath=self.onConf["xxPath"]+"/"+module+"/value"
            yllPath=self.onConf["llPath"]+"/"+module
        return cxxPath, yllPath

    def initDetPos(self,module):
        fname=self.onConf["pidInfoPath"]+"/"+module+".json"
        return IOHelper.getConf(fname)


    def initImageAxis(self,module,sendPath):
        bank=self.moduleGroupMatch[module]
        ny=self.onConf["detectorInfo"]["tubeNums"]
        xsize=self.onConf["detectorInfo"]["pixelSize"]
        nx=int(self.onConf["detectorInfo"]["tubeLen"][bank]/xsize)
        xsize=xsize*1000
        ysize=self.onConf["detectorInfo"]["tubeDiameter"]*1000
        xaxis=np.arange(0,nx*xsize,xsize)
        yaxis=np.arange(0,ny*ysize,ysize)
        self.rds.writeNumpyArray(sendPath+"/xy_image/x",xaxis)
        self.rds.writeNumpyArray(sendPath+"/xy_image/y",yaxis)
        return nx,ny

    def tof2d(self,val,dDict,pidNums,idstart,tofs,posInfo):
        hist = Hist2D(1,dDict["dbins"], [[0,1.5],[dDict["dmin"], dDict["dmax"]]])
        cst=np.ones(len(tofs))
        for i in range(pidNums):
            w=val[i,:]
            pidname=str(idstart+i)
            info=posInfo[pidname]
            constB=505.4*info[0]*math.sin(info[1]/2.0)
            td=tofs/constB
            hist.fill(cst,td,w)
        histVal=hist.hist
        d_counts=histVal[0]
        d=(hist.yedge[1:]+hist.yedge[:-1])*0.5
        return d, d_counts

    def d2q(self,x,y):
        q=2*math.pi/x
        q=q[::-1]
        q_counts=y[::-1]
        return q, q_counts

    def d2tof(self,x,y,module):
        bank=self.moduleGroupMatch[module]
        cvtTOF=self.onConf["convertTOF"][bank]
        xtof=x*cvtTOF
        tof_counts=y
        return xtof,tof_counts

    def tof2wave(self,x,y,l1):
        wave=x/(252.7*l1)
        wave_counts=y
        return wave,wave_counts

    def send1Ddata(self,sendPath,dataName,x,y):
        path=sendPath+"/"+dataName
        self.rds.writeNumpyArray(path,x)
        path=sendPath+"/"+dataName+"_counts"
        self.rds.writeNumpyArray(path,y)


    def sendImage(self,sendPath,nx,ny,val):
        tmp=np.sum(val,axis = 1)
        tmp=tmp.T
        tmp=tmp.reshape(ny,nx)
        self.rds.writeNumpyArray(sendPath+"/xy_image/value", tmp)


    def get1Ddata(self, sendPath, dataName):
        path=sendPath+"/"+dataName
        return (self.rds.readNumpyArray(path))

    def runBank(self):
        while True:
            for bank in self.bankDict:
                try:
                    mList=self.bankDict[bank]
                    path=self.onConf["llPath"]+"/"+mList[0]
                    x_d = self.get1Ddata(path,"d")
                    x_q=self.get1Ddata(path,"q")
                    constA=self.onConf["convertTOF"][bank]
                    x_tof=x_d*constA
                    v_d = np.zeros(len(x_d))
                    v_q = np.zeros(len(x_q))

                    for name in mList:
                        path=self.onConf["llPath"]+"/"+name
                        v_d+=self.get1Ddata(path,"d_counts")
                        v_q+=self.get1Ddata(path,"q_counts")

                    path=self.onConf["llPath"]+"/group"+bank[-2:]+"/raw"
                    self.send1Ddata(path,"d",x_d,v_d)
                    self.send1Ddata(path,"tof",x_tof,v_d)
                    self.send1Ddata(path,"q",x_q,v_q)
                except:
                    pass
            time.sleep(self.stime)


    def runModule(self,name):
        #'''
        cxxPath, yllPath = self.initPath(name)
        tofs,crop1,crop2,_=self.initTof(name)
        pidNums,idstart=self.initPid(name)
        dInfo=self.initDspacing(name)
        posInfo=self.initDetPos(name)
        nx,ny = self.initImageAxis(name,yllPath)
        print(posInfo)
        while True:
            tmp=self.rds.readNumpyArray(cxxPath)
            value=tmp.T
            value=value[:,crop1:crop2]
            d,d_counts=self.tof2d(value,dInfo,pidNums,idstart,tofs,posInfo)
            q,q_counts=self.d2q(d,d_counts)
            xtof,tof_counts=self.d2tof(d,d_counts,name)
            self.send1Ddata(yllPath,"d",d,d_counts)
            self.send1Ddata(yllPath,"q",q,q_counts)
            self.send1Ddata(yllPath,"tof",xtof,tof_counts)
            self.sendImage(yllPath,nx,ny,value)
            time.sleep(self.stime)
        #'''
    def runMonitor(self,name):
        cxxPath, yllPath = self.initPath(name)
        tofs,crop1,crop2,l1=self.initTof(name)
        #pidNums,idstart=self.initPid(name)
        while True:
            tmp=self.rds.readNumpyArray(cxxPath)
            value=tmp.T
            value=value[:,crop1:crop2]
            value=value.flatten()
            wave,wave_counts=self.tof2wave(tofs,value,l1)
            self.send1Ddata(yllPath,"wave",wave,wave_counts)
            self.send1Ddata(yllPath,"tof",tofs,value)
            time.sleep(self.stime)

    #'''
    def getTasks(self):
        allModules=IOHelper.getModulesConf(self.conf)
        tasks={}
        mList=allModules.keys()
        for taskName in mList:
            if taskName[:7]=="monitor":
                tasks[taskName]=[self.runMonitor,(taskName,)]
            else:
                tasks[taskName]=[self.runModule,(taskName,)]
        tasks["bank"]=[self.runBank,()]

        return tasks
