from Tasks import Task
from kafka import KafkaConsumer, TopicPartition
import IO.Kafka.Data.Detector.EventData as EventData
import IO.Kafka.Data.Global.MetaData as MetaData
import re,time,logging
from IO.Kafka import getDetectorData, getGlobalData
from  Utils.Histogram import Hist2D
from IO.RedisHelper import RedisHelper, getRedisHelper
from time import sleep
import json

def getModuleName(conf):
    _topic = conf['topic']
    _pattern = '^BL16-Detector-(Bank\d{2}-Module\d{2})-EventData$'
    _r = re.match(_pattern, _topic)
    _p = None
    if _r is not None:
        _p = _r.group(1)
    return _p

class KafkaMetaData(Task):
    def __init__(self, status, detRat, conf, redis_conf, backward=0):
        super().__init__(status, detRat)
        if type(conf) is not dict:
            raise Exception('Error: conf must be a dict')
        self.conf = conf
        self.previousPulseId = None
        self.dataType = self.conf['data_type']
        self.backward=backward
        self.BeamTime = 0
        self.runStartTime=0

        self.consumer = KafkaConsumer(group_id=self.conf['group_id'],
                bootstrap_servers = self.conf['bootstrap_servers'],
                auto_offset_reset = self.conf['auto_offset_reset'],
                enable_auto_commit = self.conf['enable_auto_commit'],
                consumer_timeout_ms = float(self.conf['consumer_timeout_ms']))
        topicPartition =  TopicPartition(self.conf['topic'], 0)
        self.consumer.assign([topicPartition])
        self.consumer.seek_to_end()
        _current_offset = self.consumer.position(topicPartition)
        if _current_offset - self.backward > 0:
            self.consumer.seek(topicPartition, int(_current_offset - self.backward))

        #redis
        self.rds = getRedisHelper(redis_conf)
        #self.redisPath = getRedisPath(conf)

    def getDataType(self):
        return self.dataType

    def run(self):
        proton_charge_path = "/mpi/imacs/proton_charge"
        current_pulse_path = "/mpi/imacs/pulse/current"
        progress_path = "/mpi/imacs/progress"
        start_pulse_path = "/mpi/imacs/pulse/start"
        stop_pulse_path = "/mpi/imacs/pulse/stop"
        detector_rate_path = "/mpi/imacs/detector_rates"
        detector_count_path  = "/mpi/imacs/detector_counts"
        end_path = "/mpi/imacs/last_run"



        run_info_path = "/mpi/control/runInfo"
        #config_path = "/mpi/control/configure" #all path

        running = False

        _pts = list(self.consumer.assignment())
        #consumer.seek_to_end(_pts[0])
        #consumer.seek(_pts[0], consumer.position(_pts[0]) - 25*60*60)
        dataType = self.getDataType()
        hitRdsTime = time.time()
        startPulseId = 0
        endPulseId = 0
        appName = 'proton_charge'
        totCharge = 0

        print(self.rds.read(run_info_path).decode())

        runinfo =   json.loads(self.rds.read(run_info_path).decode())
        self.status.value==2

        for msg in self.consumer:
            pulseId, pCharge, timeNano, timeSecond = getGlobalData(msg)
            if self.status.value==1:
                totCharge += pCharge.sum()

            if (time.time()-hitRdsTime)>1.: #every 1 sec
                logging.info('status is {}'.format(self.status.value))
                #start
                newinfo = json.loads(self.rds.read(run_info_path).decode())
                print(newinfo)
                if newinfo is None:
                    self.status.value==2
                    hitRdsTime = time.time()
                    continue

                if runinfo is not None:
                    if newinfo['runNo']!=runinfo['runNo']: #if new info arrive
                        if runinfo is not None:
                            toDur = {}
                            toDur['runNo']=runinfo['runNo']
                            toDur['startPulseId']=pulseId
                            toDur['endPulseId']=endPulseId
                            self.rds.write('/mpi/expRecInfo', json.dumps(toDur))
                            logging.info('overright durong info {}'.format(toDur))
                        with self.lock:
                            self.status.value=2
                        sleep(1)
                        with self.lock:
                            self.status.value=1

                        runinfo=newinfo
                        #next run
                        logging.info('new run {} started.\n {} '.format(newinfo['runNo'],self.rds.read(run_info_path).decode() ) )


                        self.BeamTime=int(newinfo['endValue'])
                        startPulseID=pulseId
                        endPulseId=startPulseID+self.BeamTime*25

                        self.rds.write(start_pulse_path, pulseId)
                        self.rds.write(stop_pulse_path, pulseId+self.BeamTime*25)


                        totCharge=0
                        runinfo=newinfo
                        self.runStartTime=time.time()
                        with self.lock:
                            self.detcnt.value = 0
                        hitRdsTime=time.time()
                        continue

                #if time up or new run override
                if  runinfo is not None and self.BeamTime!=0:
                    if (newinfo['runNo']==runinfo['runNo']) :
                        if time.time()-self.runStartTime > self.BeamTime :
                            toDur = {}
                            toDur['runNo']=runinfo['runNo']
                            toDur['startPulseId']=pulseId
                            toDur['endPulseId']=endPulseId
                            self.rds.write('/mpi/expRecInfo', json.dumps(toDur))
                            with self.lock:
                                self.status.value=0
                            logging.info('overtime durong info {}'.format(toDur))
                            self.rds.write(end_path, runinfo['runNo'])
                            hitRdsTime=time.time()
                            continue
                        else:
                            self.rds.write( progress_path, 100*(time.time()-self.runStartTime )/ self.BeamTime )

                if self.status.value==1:
                    self.rds.write(current_pulse_path, pulseId)
                    dc = float(self.detcnt.value)
                    self.rds.write(detector_count_path, dc )
                    self.rds.write(detector_rate_path, (dc/(time.time()-self.runStartTime)))
                    self.rds.write(proton_charge_path, totCharge)
                    logging.info('{}:  accumulated {}, current pulse {}'.format(appName, totCharge, pulseId ) )
                hitRdsTime=time.time()
