from IO.RedisHelper import RedisHelper
import json
import time
import numpy as np
from  Utils.Histogram import Hist2D
import math
import matplotlib.pyplot as plt
import sys

def getNameFromTopic(confInfo):
    _topic = confInfo['topic']
    num=_topic.find("Monitor")
    if num==-1:
        num=_topic.find("Bank")
        _p="module1"+_topic[num+4:num+6]
        num=_topic.find("Module")
        _p+=_topic[num+6:num+8]
    else:
        _p="monitor"+_topic[num+7:num+9]
    return _p

def getModuleDictFromConf(conf):
    moduleDict={}
    nameList=["detector_modules","monitor_modules"]
    for name in nameList:
        try:
            for _item in conf['modules'][name]:
                if _item.get('enabled') is not None:
                    if _item['enabled'] == False:
                        print(f'{_item["topic"]} is not enabled, skip.')
                        continue
                    if _item["data_type"] == "event":
                        mname = getNameFromTopic(_item)
                        moduleDict[mname]=_item
        except:
            pass
    return moduleDict

def getRedisHelper(redis_conf):
    _mode = redis_conf['mode'].lower()
    _password = redis_conf['password']
    _servers = []
    for _item in redis_conf['servers']:
        _servers.append((_item['host'], _item['port']))
    if _mode == 'standalone':
        return RedisHelper(_servers[0], _password, 10)
    elif _mode == 'sentinel':
        return RedisHelper(_servers, _password, 10, master_name=redis_conf['master_name'])
    else:
        raise Exception(f'Redis mode not supported: {_mode}')


def getBankGroup(module):
    if module[:4]=="modu":
        num=int(module[-5:][1:3])
        if num==7 or num==8:
            return "bank7"
        if num==6 or num==9:
            return "bank6"
        if num==5 or num==10:
            return "bank5"
        if num==4 or num==11:
            return "bank4"
        if num==3 or num==12:
            return "bank3"
        if num==2 or num==13:
            return "bank2"
        if num==1:
            return "bank1"
    else:
        name=module[:7]+str(int(module[-2:]))
        return name

class moduleIO():
    def __init__(self,rds,module,conf,refreshtime):
        self.rds=rds
        self.rtime=refreshtime
        self.module=module
        self.bank=getBankGroup(module)
        if self.bank[:7]=="monitor":
            self.pidNums=1
            self.histPath=conf["xxPath_m"]+"/"+self.bank+"/value"
            self.sendPath=conf["llPath"]+"/"+self.bank
        else:
            self.getDValue(conf)
            self.getPidValue(conf)
            self.histPath=conf["xxPath"]+"/"+module+"/value"
            self.sendPath=conf["llPath"]+"/"+module
         
        self.L1=conf["sampleDist"]
        self.getTOFValue(conf)

    def getTOFrange(self,t1,t2,tstep):
        startNum=0
        endNum=self.tofbins
        for i in range(self.tofbins):
            v=i*tstep
            if v>t1:
                startNum=i
                break
        for i in range(self.tofbins):
            v=i*tstep
            if v>t2:
                endNum=i
                break
        return startNum, endNum

    def getDValue(self,conf):
        dRebin=conf["dRebin"][self.bank]
        dRebin=dRebin.split(",")
        self.dmin=float(dRebin[0])
        self.dstep=float(dRebin[1])
        self.dmax=float(dRebin[2])
        self.dbins=int((self.dmax-self.dmin)/self.dstep)

    def getTOFValue(self,conf):
        self.cvtTOF=conf["convertTOF"][self.bank]
        if self.bank[:7]=="monitor":
            self.tofbins=conf["tofInfo"]["tofbins_monitor"]
        else:
            self.tofbins=conf["tofInfo"]["tofbins"]
        tmax=conf["tofInfo"]["tofmax"]
        tmin=conf["tofInfo"]["tofmin"]
        step=int((tmax-tmin)/self.tofbins)
        wmin=conf["tofInfo"]["waveMin"]
        wmax=conf["tofInfo"]["waveMax"]
        tmin=252.7*self.L1*wmin
        tmax=252.7*self.L1*wmax
        self.idx1,self.idx2=self.getTOFrange(tmin,tmax,step)
        self.tofs=np.arange(self.idx1*step,self.idx2*step,step)
        self.cst=np.ones(self.idx2-self.idx1)
    def getPidValue(self,conf):
        self.ny=conf["detectorInfo"]["tubeNums"]
        xsize=conf["detectorInfo"]["pixelSize"]
        self.nx=int(conf["detectorInfo"]["tubeLen"][self.bank]/xsize)
        xsize=xsize*1000
        self.pidNums=int(self.nx*self.ny)
        fname=conf["pidInfoPath"]+"/"+self.module+".json"
        with open(fname,'r') as jf:
            self.pidDict=json.load(jf)
        self.idstart=int(self.module[-5:]+'0001')
        ysize=conf["detectorInfo"]["tubeDiameter"]*1000
        self.xaxis=np.arange(0,self.nx*xsize,xsize)
        self.yaxis=np.arange(0,self.ny*ysize,ysize)


    def send1Ddata(self,dataname,x,y):
        path=self.sendPath+"/"+dataname
        self.rds.writeNumpyArray(path,x)
        path=self.sendPath+"/"+dataname+"_counts"
        self.rds.writeNumpyArray(path,y)


    def calDiffraction(self):
        #get d spacing
        hist = Hist2D(1,self.dbins, [[0,1.5],[self.dmin, self.dmax]])
        for i in range(self.pidNums):
            w=self.value[i,:]
            pidname=str(self.idstart+i)
            info=self.pidDict[pidname]
            constB=505.4*info[0]*math.sin(info[1]/2.0)
            td=self.tofs/constB
            hist.fill(self.cst,td,w)
        histVal=hist.hist
        d_counts=histVal[0]
        d=(hist.yedge[1:]+hist.yedge[:-1])*0.5
        self.send1Ddata("d",d,d_counts)
        #convert to q
        q=2*math.pi/d
        q=q[::-1]
        q_counts=d_counts[::-1]
        self.send1Ddata("q",q,q_counts)
        #convert to tof
        xtof=d*self.cvtTOF
        tof_counts=d_counts
        self.send1Ddata("tof",xtof,tof_counts)


    def calWavelength(self):
        wave = self.tofs/(252.7*self.L1)
        self.send1Ddata("wave",wave,self.value)
        self.send1Ddata("tof",self.tofs,self.value)


    def calDetectorImage(self):
        tmp=np.sum(self.value,axis = 1)
        tmp=tmp.T
        _tmp=tmp.reshape(self.ny,self.nx)
        self.rds.writeNumpyArray(self.sendPath+"/xy_image/x",self.xaxis)
        self.rds.writeNumpyArray(self.sendPath+"/xy_image/y",self.yaxis)
        self.rds.writeNumpyArray(self.sendPath+"/xy_image/value", _tmp)

    def process(self):
        tmp = self.rds.readNumpyArray(self.histPath)
        self.value=tmp.T
        self.value=self.value[:,self.idx1:self.idx2]
        
        if self.module[:7]=="monitor":
            self.value=self.value.flatten()
            self.calWavelength()
        else:
            self.calDiffraction()
            self.calDetectorImage()

    def runModule(self):
        while True:
            self.process()
            time.sleep(self.rtime)


class bankIO():
    def __init__(self,rds,moduleList,conf,refreshtime):
        self.moduleList=moduleList
        self.sendPath=conf["llPath"]
        self.bankInfo=self.getBankDict()
        self.rds=rds
        self.stime=refreshtime
        self.constDict=conf["convertTOF"]
        self.bankDict=self.getBankDict()
        
    def getBankDict(self):
        tmp={}
        tmp["bank2"]=[]
        tmp["bank3"]=[]
        tmp["bank4"]=[]
        tmp["bank5"]=[]
        tmp["bank6"]=[]
        tmp["bank7"]=[]
        for name in self.moduleList:
            bank=getBankGroup(name)
            if bank[:7]=="monitor":
                pass
            else:
                tmp[bank].append(name)
        return tmp

    def get1Ddata(self,dataname,name):
        path=self.sendPath+"/"+name+"/"+dataname
        return self.rds.readNumpyArray(path)
        

    def sendData(self,dataname,name,xv,yv):
        path=self.sendPath+"/group"+name[4:].zfill(2)+"/raw"
        self.rds.writeNumpyArray(path+"/"+dataname,xv)
        self.rds.writeNumpyArray(path+"/"+dataname+"_counts",yv)

    def mergeData(self):
        for bname in self.bankDict.keys():
            mList=self.bankDict[bname]
            if len(mList)==0:
                pass
            else:
                x_d = self.get1Ddata("d",mList[0])
                constA = self.constDict[bname]
                x_tof = x_d*constA
                v = np.zeros(len(x_d))
                for mname in mList:
                    v+=self.get1Ddata("d_counts",mname)
                self.sendData("d",bname,x_d,v)
                self.sendData("tof",bname,x_tof,v)
            

    def runBank(self):
        while True:
            #try:
            self.mergeData()
            #except:
            #    pass
            #time.sleep(self.stime)
