from Tasks import Task
from kafka import KafkaConsumer, TopicPartition
import IO.Kafka.Data.Detector.EventData as EventData
import IO.Kafka.Data.Global.MetaData as MetaData
import re,time,logging
from IO.Kafka import getDetectorData, getGlobalData
from  Utils.Histogram import Hist2D
from IO.RedisHelper import RedisHelper, getRedisHelper
from Tasks.TaskHelper import *
from time import sleep
import numpy as np

class KafkaMonitor(Task):
    def __init__(self, status, detRat,  conf, redis_conf, backward=0):
        super().__init__(status, detRat)
        if type(conf) is not dict:
            raise Exception('Error: conf must be a dict')
        self.conf = conf
        self.previousPulseId = None
        self.dataType = self.conf['data_type']
        self.backward=backward

        self.consumer = KafkaConsumer(group_id=self.conf['group_id'],
                bootstrap_servers = self.conf['bootstrap_servers'],
                auto_offset_reset = self.conf['auto_offset_reset'],
                enable_auto_commit = self.conf['enable_auto_commit'],
                consumer_timeout_ms = float(self.conf['consumer_timeout_ms']))
        topicPartition =  TopicPartition(self.conf['topic'], 0)
        self.consumer.assign([topicPartition])
        self.consumer.seek_to_end()
        _current_offset = self.consumer.position(topicPartition)
        if _current_offset - self.backward > 0:
            self.consumer.seek(topicPartition, int(_current_offset - self.backward))

        #histogram
        tofBinSize = 40000
        pixBinSize = 1

        self.hist = Hist2D(tofBinSize, pixBinSize, [[0-0.5, 40000+0.5], [0, 1e10]])

        #redis
        self.rds = getRedisHelper(redis_conf)
        self.countPath = '/mpi/imacs/monitor_counts'
        self.ratePath = '/mpi/imacs/monitor_rates'

    def getDataType(self):
        return self.dataType

    def run(self):
        _pts = list(self.consumer.assignment())
        #consumer.seek_to_end(_pts[0])
        #consumer.seek(_pts[0], consumer.position(_pts[0]) - 25*60*60)
        dataType = self.getDataType()
        startTime = time.time()
        hitRdsTime = time.time()
        startPulseID = 0
        count = 0
        total = 0
        appName = 'Monitor01'

        for msg in self.consumer:
            if self.status.value==0:
                sleep(1)
                continue
            elif self.status.value==2:
                self.hist.hist=np.zeros(self.hist.hist.shape)
                total = 0
                count = 0
                sleep(0.1)
                logging.warning('{} is performing clearing'.format(appName) )
                continue

            pulseId, tof, pid = getDetectorData(msg)
            total += 1

            pulseStart = time.time()

            if startPulseID == 0:
                startPulseID=pulseId
            #print(pulseId, tof.shape, pid.min(), pid.max())

            if not tof.size * pid.size:
                count += 1
                logging.warning(appName+ ' data with zero size')
                continue
            if tof.size != pid.size:
                logging.warning(appName+' tof and pid arrays differ in size')
                continue

            if not self.hist.fill(tof, pid):
                print(tof)
                print(pid)
                print('tof type', type(tof), tof.size)
                print('pid type', type(pid), pid.size)
                print('tof range', tof.min(),tof.max())
                print('pid range', pid.min(),pid.max() )
                logging.warning('{} contains invalid data, range of tof(), range of pid{}'.format(\
                                appName, (tof.min(),tof.max()), (pid.min(),pid.max()) ))
                return

            with self.lock:
                self.detcnt.value += tof.size

            pulseEnd=time.time()

            if (time.time()-hitRdsTime)>5.: #every 5 sec
                histVal=self.hist.hist
                self.rds.writeNumpyArray(self.redisPath, histVal[:,1])
                #data_back=rds.readNumpyArray(redisPath)
                #np.testing.assert_array_equal(histVal,data_back)
                #print('passed redis numpy read write test')
                logging.info(f'Empty pulses: {count}/{total}, {count*100/total}%')
                redisTime=time.time()
                totEvent = histVal.sum()
                self.rds.write(self.countPath, totEvent)
                if (pulseId-startPulseID)!=0.:
                    self.rds.write(self.ratePath, totEvent/((pulseId-startPulseID)*40.))

                hitRdsTime=time.time()
