
from abc import ABC, abstractmethod
import flatbuffers
from kafka import KafkaConsumer, TopicPartition
import Data.Detector.EventData
import Data.Detector.MetaData
import numpy as np

class Task(ABC):
    def __init__(self, queue):
        self.queue = queue

    @abstractmethod
    def get(self, pars):
        pass

    @abstractmethod
    def set(self):
        pass


class KafkaGrabber(Task):
    def __init__(self, queue, conf):
        super().__init__(queue)
        if type(conf) is not dict:
            raise Exception('Error: conf must be a dict')
        self.conf = conf
        self.previousPulseId = None
        _dataType = self.conf['data_type']
        if _dataType == 'meta':
            self._decodeData = self._decodeMetaData
        elif _dataType == 'event':
            self._decodeData = self._decodeEventData
        else:
            raise Exception(f'Not supported data type: {_dataType}')

    def _getKafkaConsumer(self):
        consumer = KafkaConsumer(group_id=self.conf['group_id'],
                bootstrap_servers=self.conf['bootstrap_servers'],
                auto_offset_reset=self.conf['auto_offset_reset'],
                enable_auto_commit=self.conf['enable_auto_commit'],
                consumer_timeout_ms=self.conf['consumer_timeout_ms'])
        consumer.subscribe(topics=[self.conf['topic']])
        return consumer
    
    def _consumerMessage(self, consumer):
        for msg in consumer:
            self._decodeData(msg)
        pass

    def _checkPulseId(self, currentPulseId):
        if self.previousPulseId is not None:
            _inc = currentPulseId - self.previousPulseId
            if _inc > 1:
                print(f'[WARN] pulseId not continueous: pre {self.previousPulseId}, cur {currentPulseId}')
                return false
            return True
        return True


    def _decodeEventData(self, msg):
        _data = Data.Detector.EventData.EventData.GetRootAsEventData(msg.value, 0)
        _pulseId = _data.PulseId()
        _tofs = _data.TofAsNumpy()
        _pids = _data.PosAsNumpy()
        _events = np.c_[_pids, _tofs]
        self._checkPulseId(_pulseId)
        self.previousPulseId = _pulseId
        print(_pulseId)

    def _decodeMetaData(self, msg):
        _data = Data.Detector.MetaData.MetaData.GetRootAsMetaData(msg.value, 0)
        _pulseId = _data.PulseId()
        _deviceId = _data.DeviceId()
        _deviceName = _data.DeviceName()
        _value = _data.ValueAsNumpy()
        _timeNano = _data.TimeNanoAsNumpy()
        _timeSecond = _data.TimeSecondAsNumpy()
        self._checkPulseId(_pulseId)
        self.previousPulseId = _pulseId
        print(_value)

    def get(self, pars):
        pass

    def set(self):
        pass

    def run(self):
        consumer = self._getKafkaConsumer()
        self._consumerMessage(consumer)

class EpicsListener(Task):
    def __init__(self, queue):
        super().__init__(queue)
        pass

class EpicsWriter(Task):
    def __init__(self, queue):
        super().__init__(queue)
        pass

class RedisTask(Task):
    def __init__(self, queue):
        super().__init__(queue)
        pass

class DataBaseTask(Task):
    def __init__(self, queue):
        super().__init__(queue)
        pass
