from .Task import Task
from kafka import KafkaConsumer, TopicPartition
import re,time,logging
from iMACS.IO.Kafka import getDetectorData, getGlobalData
from  iMACS.Utils.Histogram import Hist2D
from iMACS.IO.RedisHelper import RedisHelper
from iMACS.Config import ConfigHelper
from time import sleep
import numpy as np


class KafkaTask(Task):
    def __init__(self, configKafka, jumpBack=0, confHist=None, sharedControl=None, sharedNumHit=None, sharedAccumulated=None, lastExitCode=0):
        super().__init__(sharedControl, sharedNumHit, sharedAccumulated, None, None, lastExitCode)
        if type(configKafka) is not dict:
            raise Exception('KafkaTask.__init__ configKafka must be a dict')
        if type(confHist) is not dict and confHist is not None:
            raise Exception('KafkaTask.__init__ confHist must be a dict or a None')

        try:
            configKafka['topic']
        except KeyError:
            raise Exception('Error: configKafka does not contain a topic key')
        self.jumpBack =  jumpBack
        self.configKafka = ConfigHelper(configKafka)

    def initKafka(self):
        self.dataType = self.configKafka.getValue('data_type')
        self.consumer = KafkaConsumer(group_id=self.configKafka.getValue('group_id'),
                bootstrap_servers = self.configKafka.getValue('bootstrap_servers'),
                auto_offset_reset = self.configKafka.getValue('auto_offset_reset'),
                enable_auto_commit = self.configKafka.getValue('enable_auto_commit'),
                consumer_timeout_ms = float(self.configKafka.getValue('consumer_timeout_ms')))
        topicPartition =  TopicPartition(self.configKafka.getValue('topic'), 0)
        self.consumer.assign([topicPartition])
        #self.consumer.seek_to_end() fixme; use seek_type
        self.consumer.seek_to_beginning()
        _current_offset = self.consumer.position(topicPartition)
        if _current_offset - self.jumpBack < 0:
            raise Exception('KafkaTask.initKafka jump back is greater than current offset ')
        self.consumer.seek(topicPartition, int(_current_offset - self.jumpBack))

        #set the function that decodes a kafka message
        if self.dataType=='event':
            self.funcRead=getDetectorData
        elif self.dataType=='meta':
            self.funcRead=getGlobalData
        else:
            raise Exception('KafkaTask.initKafka ' + str(self.dataType) + ' is an unknown kafka data type ')

        #init a histogram if histConf is given
        if self.configKafka.getValue('histogram') is None:
            self.hist = None
        else:
            xBinNum = self.configKafka.getValue('histogram', 'xBinNum' )
            yBinNum = self.configKafka.getValue('histogram', 'yBinNum' )
            xmin = self.configKafka.getValue('histogram', 'xmin' )
            xmax = self.configKafka.getValue('histogram', 'xmax' )
            ymin = self.configKafka.getValue('histogram', 'ymin' )
            ymax = self.configKafka.getValue('histogram', 'ymax' )
            self.hist = Hist2D( xBinNum , yBinNum \
                [[xmin, ymax], [xmin, ymax]])

        #fixme add peroidicRW


    def realInit(self):
        if not self.init:
            self.initKafka()
            self.init = True

    def realRun(self, msg):
        if self.control is None:
            return
        #idle
        if self.control[0].value == 0:
            pass

        #action
        elif self.control[0].value == 1:
            data = self.funcRead(msg)
            #fill histogram
            if self.hist is not None:
                self.hist(data[1],data[2])

            #accumulate total number
            if self.accumulated is not None:
                with self.accumulated[1]:
                    self.accumulated[0].value += data[1].sum()

            if self.numhit is not None:
                with self.numhit[1]:
                    self.numhit[0].value += data[1].size

        #reset
        elif self.control[0].value == 2:
            if self.hist is not None:
                self.hist.hist=np.zeros(self.hist.hist.shape)

    def run(self):
        self.realInit()
        for msg in self.consumer:
            self.realRun(msg)
