import time, datetime
from IO.RedisHelper import RedisHelper
#from  Utils.Histogram import Hist2D
import baseComp
from IO.Kafka import getDetectorData, getGlobalData, KafkaGrabber
import re
import os
import numpy as np
import subprocess
import json


def getNameFromTopic(confInfo):
    _topic = confInfo['topic']
    num=_topic.find("Monitor")
    if num==-1:
        num=_topic.find("Bank")
        _p="module1"+_topic[num+4:num+6]
        num=_topic.find("Module")
        _p+=_topic[num+6:num+8]
    else:
        _p="monitor"+_topic[num+7:num+9]
    return _p

def getModuleDictFromConf(conf):
    moduleDict={}
    nameList=["detector_modules","monitor_modules"]
    for name in nameList:
        try:
            for _item in conf['modules'][name]:
                if _item.get('enabled') is not None:
                    if _item['enabled'] == False:
                        print(f'{_item["topic"]} is not enabled, skip.')
                        continue
                    if _item["data_type"] == "event":
                        mname = getNameFromTopic(_item)
                        moduleDict[mname]=_item
        except:
            pass
    return moduleDict


def getItemFromConf(conf,nodename,datatype,topic):
    for _item in conf["modules"][nodename]:
        if _item["data_type"]== datatype:
            if _item["topic"]==topic:
                return _item
    return None



def createDir(path):
    if os.path.exists(path):
        pass
    else:
        os.mkdir(path)


class kafkaRec():
    def __init__(self,confFile):
        with open(confFile,"r") as json_file:
            self.conf=json.load(json_file)
        self.tofs=self.conf["rec_configure"]["tofbins"]
        self.chunksize=self.conf["rec_configure"]["chunksize"]
        self.myDict={} 
        
    def getRedisHelper(redis_conf):
        _mode = redis_conf['mode'].lower()
        _password = redis_conf['password']
        _servers = []
        for _item in redis_conf['servers']:
            _servers.append((_item['host'], _item['port']))
        if _mode == 'standalone':
            return RedisHelper(_servers[0], _password, 10)
        elif _mode == 'sentinel':
            return RedisHelper(_servers, _password, 10, master_name=redis_conf['master_name'])
        else:
            raise Exception(f'Redis mode not supported: {_mode}')
    
    def connectRedis(self,name):
        redis_conf=self.conf[name]
        return self.getRedisHelper(redis_conf)


    def getRun2Rec(self,startNo):
        self.expPath=self.conf["rec_configure"]["local_expinfo_path"]
        self.localdbPath=self.conf["rec_configure"]["local_completeinfo_path"]
        print(self.localdbPath)
        recList=[]
        complete=os.listdir(self.localdbPath)
        expdb=os.listdir(self.expPath)
        for item in expdb:
            if item[:3]=="RUN" and item not in complete:
                if int(item[-7:])>startNo:
                    recList.append(item)
        return recList

    def getPulseInfoFromFile(self,runno):
        fname=self.expPath+"/"+runno
        with open(fname, "r") as conf_file:
            conf=json.load(conf_file)
        self.startT0=conf["startPulseId"]
        self.endT0=conf["endPulseId"]


    def getPusleInfoFromRedis(self):
        pass

    def getPublicInfo(self,runno):
        tc=timeConvert()
        startLocal=tc.pulseId2Local(self.startT0)
        startStamp=tc.local_to_utc_stamp(startLocal)
        endLocal=tc.pulseId2Local(self.endT0)
        endStamp=tc.local_to_utc_stamp(endLocal)
        pubDict={"start_time_utc":[bytes(startLocal,encoding="utf8")],
             "end_time_utc":[bytes(endLocal,encoding="utf8")],
             "start_time_tai":[bytes(str(startStamp),encoding="utf8")],
             "end_time_tai":[bytes(str(endStamp),encoding="utf8")],
             "beamline":[b"BL16"],
             "run_no":[str.encode(runno)]}
        return pubDict

    def mergeNxs(self,runno,dataMode):
        createDir(self.nxsPath)
        pubInfo=self.getPublicInfo(runno)
        moduleList=self.myDict.keys()
        if len(moduleList)==0:
            mDict=getModuleDictFromConf(self.conf)
            moduleList=mDict.keys()
        print(moduleList)
        run = baseComp.nxsWrite(self.rawPath,self.nxsPath,self.tofs)
        run.mergeModules(pubInfo,moduleList,dataMode)


    def offlineDetector(self,conf,mname,t1,t2,dataMode):
        tc=timeConvert()
        offset=tc.getOffset(self.startT0)
        task = KafkaGrabber(conf,offset)
        consumer=task.getKafkaConsumer()
        run = baseComp.nxsWrite(self.rawPath, self.nxsPath,self.tofs)
        run.startSingleModule(mname,dataMode,self.chunksize)
        for msg in consumer:
            pulseId, tofs, pids = getDetectorData(msg)
            if dataMode=="hist":
                finish=run.fillHist(pulseId,tofs,pids,t1,t2)
            elif dataMode=="evt":
                finish=run.fillEvt(pulseId,tofs,pids,t1,t2,self.chunksize)
            if finish:
                break
        consumer.close()
        run.closeNXS()
        print("finish online ",mname)

    def getPath(self,runno):
        self.dbPath=self.conf["rec_configure"]["db_path"]
        self.dbIDPath=self.conf["rec_configure"]["db_id_path"]
        self.cloudPath=self.conf["rec_configure"]["cloud_path"]
        self.expPath=self.conf["rec_configure"]["local_expinfo_path"]
        self.localdbPath=self.conf["rec_configure"]["local_completeinfo_path"]
        self.rawPath=self.conf["rec_configure"]["rawdata_path"]+"/"+runno
        self.nxsPath=self.conf["rec_configure"]["complete_path"]+"/"+runno


    def getTask(self,dataMode,lineMode,runno):
        #lineMode is online or offline
        createDir(self.rawPath)
        mDict=getModuleDictFromConf(self.conf)
        for taskName in mDict:
            if lineMode=="offline":
                self.myDict[taskName]=[self.offlineDetector,(mDict[taskName],taskName,self.startT0,self.endT0,dataMode)]


    def onlineDetector(conf,mname,prepath,tofs,chunksize,t1,t2):
        print("start online: ",mname)
        print("start offset: ",mname,offset)
        task = KafkaGrabber(conf, offset)
        consumer=task.getKafkaConsumer()
        nxsfile=prepath+"/"+mname+".nxs"
        run = baseComp.nxsWrite()
        if mname[:6]=="module":
            run.recSingleModule(prepath,mname,tofs, chunksize, consumer,t1,t2)
        else:
            print("start rec monitor")
            run.recSingleMonitor(prepath,mname,tofs,consumer,t1,t2)
        #_cmd = 'scp ./complete/complete_'+mname+" "+prepath
        #subprocess.check_call(_cmd, shell=True)
        print("finish online ",mname)


    def sendData(self,runno):
        expFile="./runlog/"+runno
        _cmd="cp "+expFile+" "+self.nxsPath
        subprocess.check_call(_cmd, shell=True)

        _cmd="scp -r "+self.nxsPath+" "+self.cloudPath
        subprocess.check_call(_cmd, shell=True)
        #_cmd = 'scp complete  '+transPath1+runNo
        #subprocess.check_call(_cmd, shell=True)
        _cmd="scp -r "+self.nxsPath+" "+self.dbPath
        subprocess.check_call(_cmd, shell=True)
        
        f=open(runno, "w")
        f.close()
        _cmd="scp "+runno+" "+self.dbIDPath
        subprocess.check_call(_cmd, shell=True)
        _cmd="rm -rf "+runno
        subprocess.check_call(_cmd, shell=True)




class timeConvert():
    def __init__(self):
        self.local_format="%Y-%m-%dT%H:%M:%S"

    def getCurrPulseIdSecond(self):
        confFile = "conf.json"
        with open(confFile,"r") as json_file:
            conf=json.load(json_file)
        item=conf["modules"]["global_modules"]
        #print(item)
        task = KafkaGrabber(item[0])
        consumer=task.getKafkaConsumer()
        for msg in consumer:
            pulseId,_,_,timeSecond = getGlobalData(msg)
            break
        consumer.close()
        return pulseId,timeSecond[0]

    def utc_stamp_to_local(self,seconds):
        utc_time=datetime.datetime.utcfromtimestamp(seconds)
        local=utc_time+datetime.timedelta(hours=8)
        local=local.strftime(self.local_format)
        return local

    def local_to_utc_stamp(self,local):
        dt=datetime.datetime.strptime(local, self.local_format)
        dt=dt+datetime.timedelta(hours=-8)
        stamp=time.mktime(dt.timetuple())
        return stamp
    
    def local2PulseId(self,local):
        ts = self.local_to_utc_stamp(local)
        curPulseId,curTime=self.getCurrPulseIdSecond()
        offsetTS=curTime-ts
        if offsetTS<0:
            print("error in kafka data!")
        else:
            offsetPulseId=offsetTS*25
            pulseId=curPulseId-offsetPulseId
        return pulseId

    def pulseId2Local(self, pulseId):
        curPulseId,curTime=self.getCurrPulseIdSecond()
        offsetPulseId=curPulseId-pulseId
        if offsetPulseId<0:
            print("error in kafka data!")
        else:
            offsetTS=offsetPulseId/25.0
            ts=curTime-offsetTS
            local = self.utc_stamp_to_local(ts)
        return local

    def getOffset(self,pulseId):
        curPulseId,curTime=self.getCurrPulseIdSecond()
        offset=curPulseId-pulseId
        return offset
