from .Task import Task
from kafka import KafkaConsumer, TopicPartition
import re,time,logging
from iMACS.IO.Kafka import getDetectorData, getGlobalData
from iMACS.IO import ManagedHistRedis
from iMACS.IO.RedisHelper import RedisHelper
from iMACS.Config import ConfigHelper
from time import sleep
import numpy as np


class ManagedKafkaTask(Task):
    def __init__(self, kafkaCfg, sharedControl=None, sharedNumHit=None, sharedAccumulated=None, lastExitCode=0):
        super().__init__(sharedControl, sharedNumHit, sharedAccumulated, None, None, lastExitCode)
        if type(kafkaCfg) is not dict:
            raise Exception('KafkaTask.__init__ configKafka must be a dict')
        self.managedCfg = ConfigHelper(kafkaCfg)

    def initKafka(self):
        self.dataType = self.managedCfg.getValue('data_type')
        self.consumer = KafkaConsumer(group_id=self.managedCfg.getValue('group_id'),
                bootstrap_servers = self.managedCfg.getValue('bootstrap_servers'),
                auto_offset_reset = self.managedCfg.getValue('auto_offset_reset'),
                enable_auto_commit = self.managedCfg.getValue('enable_auto_commit'),
                consumer_timeout_ms = float(self.managedCfg.getValue('consumer_timeout_ms')))
        topicPartition =  TopicPartition(self.managedCfg.getValue('topic'), 0)
        self.consumer.assign([topicPartition])
        #self.consumer.seek_to_end() fixme; use seek_type
        self.consumer.seek_to_beginning()
        _current_offset = self.consumer.position(topicPartition)
        self.jumpBack=0 #fixme
        if _current_offset - self.jumpBack < 0:
            raise Exception('KafkaTask.initKafka jump back is greater than current offset ')
        self.consumer.seek(topicPartition, int(_current_offset - self.jumpBack))

        #set the function that decodes a kafka message
        if self.dataType=='event':
            self.funcRead=getDetectorData
        elif self.dataType=='meta':
            self.funcRead=getGlobalData
        else:
            raise Exception('KafkaTask.initKafka ' + str(self.dataType) + ' is an unknown kafka data type ')

        #init a histogram if histConf is given
        histCfg = self.managedCfg.getValue('histogram')
        if histCfg is None:
            self.hist = None
        else:
            self.hist = ManagedHistRedis( histCfg)
            self.hist.startPeriodicRW()

        #fixme add peroidicRW


    def realInit(self):
        if not self.init:
            self.initKafka()
            self.init = True

    def realRun(self, msg):
        if self.control is None:
            raise Exception('No control')

        #idle
        if self.control.value == 0:
            pass

        #action
        elif self.control.value == 1:
            data = self.funcRead(msg)
            #fill histogram
            if data[1].size*data[2].size !=0 and data[1].size == data[2].size:
                with self.hist.lock:
                    self.hist.fill(data[1],data[2])

                #accumulate total number
                if self.accumulated is not None:
                    with self.accumulated.get_lock():
                        self.accumulated.value += data[1].sum()

                if self.numhit is not None:
                    with self.numhit.get_lock():
                        self.numhit.value += data[1].size

        #reset
        elif self.control.value == 2:
            if self.hist is not None:
                if self.hist.hist.sum() != 0:
                    self.hist.resetHist()

    def run(self):
        self.realInit()
        for msg in self.consumer:
            self.realRun(msg)
