from Tasks import Task
from kafka import KafkaConsumer, TopicPartition
import IO.Kafka.Data.Detector.EventData as EventData
import IO.Kafka.Data.Global.MetaData as MetaData
import re,time,logging
from IO.Kafka import getDetectorData, getGlobalData
from  Utils.Histogram import Hist2D
from IO.RedisHelper import RedisHelper, getRedisHelper
from Tasks.TaskHelper import *
from time import sleep
import numpy as np


class KafkaTask(Task):
    def __init__(self, confKafka, jumpBack=0, confHist=None, sharedControl=None, sharedNumHit=None, sharedAccumulated=None, slavePipe=None, masterPipe=None, lastExitCode=0):
        super().__init__(sharedControl, sharedNumHit, sharedAccumulated, slavePipe, masterPipe, lastExitCode)
        if type(confKafka) is not dict:
            raise Exception('KafkaTask.__init__ confKafka must be a dict')
        if type(confHist) is not dict and confHist is not None:
            raise Exception('KafkaTask.__init__ confHist must be a dict or a None')

        try:
            confKafka['topic']
        except KeyError:
            raise Exception('Error: confKafka does not contain a topic key')
        self.confKafka = confKafka
        self.confHist = confHist
        self.jumpBack=jumpBack

    def initKafka(self):
        self.dataType = self.confKafka['data_type']
        self.consumer = KafkaConsumer(group_id=self.confKafka['group_id'],
                bootstrap_servers = self.confKafka['bootstrap_servers'],
                auto_offset_reset = self.confKafka['auto_offset_reset'],
                enable_auto_commit = self.confKafka['enable_auto_commit'],
                consumer_timeout_ms = float(self.confKafka['consumer_timeout_ms']))
        topicPartition =  TopicPartition(self.confKafka['topic'], 0)
        self.consumer.assign([topicPartition])
        #self.consumer.seek_to_end() fixme;
        self.consumer.seek_to_beginning()
        _current_offset = self.consumer.position(topicPartition)
        if _current_offset - self.jumpBack < 0:
            raise Exception('KafkaTask.initKafka jump back is greater than current offset ')
        self.consumer.seek(topicPartition, int(_current_offset - self.jumpBack))

        #set the function that decodes a kafka message
        if self.dataType=='event':
            self.funcRead=getDetectorData
        elif self.dataType=='meta':
            self.funcRead=getGlobalData
        else:
            raise Exception('KafkaTask.initKafka ' + str(self.dataType) + ' is an unknown kafka data type ')

        #init a histogram if histConf is given
        if self.confHist is None:
            self.hist = None
        else:
            self.hist = Hist2D( confHist['x_bin'], confHist['y_bin'],
                [[confHist['x_lower'], confHist['x_upper']], [confHist['y_lower'], confHist['y_upper']]])


    def realInit(self):
        if not self.init:
            self.initKafka()
            self.init = True

    def realRun(self, msg):
        if self.control is None:
            return
        #idle
        if self.control[0].value == 0:
            pass

        #action
        elif self.control[0].value == 1:
            data = self.funcRead(msg)
            #fill histogram
            if self.hist is not None:
                self.hist(data[0],data[1])

            #accumulate total number
            if self.accumulated is not None:
                with self.accumulated[1]:
                    self.accumulated[0].value += data[1].sum()

            if self.numhit is not None:
                with self.numhit[1]:
                    self.numhit[0].value += data[1].size

        #reset
        elif self.control[0].value == 2:
            if self.hist is not None:
                self.hist.hist=np.zeros(self.hist.hist.shape)

    def run(self):
        self.realInit()
        for msg in self.consumer:
            self.realRun(msg)
            print(self.numhit[0].value)
