#!/usr/bin/env python3

from IO.Kafka import getDetectorData, getGlobalData, KafkaGrabber
from  Utils.Histogram import Hist2D 

from Tasks.TaskManager import TaskManager
from IO.RedisHelper import RedisHelper
from IO.conf_helper import ConfHelper

from time import sleep
import multiprocessing as mp
import time
import numpy as np
import logging
import re
import sys

logging.basicConfig(
    format='%(asctime)s %(levelname)-8s %(message)s',
    level=logging.INFO,
    datefmt='%Y-%m-%d %H:%M:%S')

def getRedisPath(conf):
    _topic = conf['topic']
    _pattern = '^BL16-Detector-Bank(\d{2})-Module(\d{2})-EventData$'
    _r = re.match(_pattern, _topic)
    _p = None
    if _r is not None:
        _p = '1' + _r.group(1) + _r.group(2)

    return '/MPI/workspace/detector/module' + _p + '/value'

def getModuleName(conf):
    _topic = conf['topic']
    _pattern = '^BL16-Detector-(Bank\d{2}-Module\d{2})-EventData$'
    _r = re.match(_pattern, _topic)
    _p = None
    if _r is not None:
        _p = _r.group(1)

    return _p

def getPixelIdBase(conf):
    _topic = conf['topic']
    _pattern = '^BL16-Detector-Bank(\d{2})-Module(\d{2})-EventData$'
    _r = re.match(_pattern, _topic)
    if _r is not None:
        _p = '1' + _r.group(1) + _r.group(2) + '0000'
        return int(_p)
    raise Exception(f'Cannot find pixel ID base for topic: {_topic}')
from math import remainder

def onlineDetector(conf, redis_conf):
    print('onlineDetector profiling is running')
    appName = getModuleName(conf)

    _pixels = conf['module_length'] // 5 * 8
    _pixelIdBase = getPixelIdBase(conf)
    _lowId = _pixelIdBase + 1
    _highId = _pixelIdBase + _pixels

    task = KafkaGrabber(conf, 100000)
    consumer = task.getKafkaConsumer()
    #consumer.seek_to_beginning()
    #consumer.seek_to_end()
    _pts = list(consumer.assignment())
    consumer.seek_to_end(_pts[0])
    consumer.seek(_pts[0], consumer.position(_pts[0]) - 25*60*60)
    dataType = task.getDataType()

    interval = 25*30
    previous_pulseId = None
    startTime = time.time()
    j = 0
    count = 0
    total = 0
    error_pids = np.array([], dtype='uint32')
    error_tofs = np.array([], dtype='uint32')
    disc_pulses = []
    error_pulses = []
    error_events = 0
    total_events = 0
    for msg in consumer:
        skip = False
        pulseId, tof, pid = getDetectorData(msg)

        total += 1

        if previous_pulseId is not None:
            _d = pulseId - previous_pulseId
            if _d != 1:
                if _d > 0:
                    disc_pulses.append([previous_pulseId, pulseId])
                else:
                    error_pulses.append([previous_pulseId, pulseId])
        previous_pulseId = pulseId

        if not tof.size * pid.size:
            count += 1
            skip = True
        if tof.size != pid.size:
            error_events += 1
            skip = True

        if not skip:
            total_events += tof.shape[0]
            error_pids = np.concatenate((error_pids, pid[pid > _highId]), axis=0)
            error_pids = np.concatenate((error_pids, pid[pid < _lowId]), axis=0)
            error_tofs = np.concatenate((error_tofs, tof[tof > 40000]), axis=0)
            error_tofs = np.concatenate((error_tofs, tof[tof < 0]), axis=0)

        if total > interval:
            print(f'[{appName}] Current pulse: {pulseId}')
            _et = time.time() - startTime
            t = interval/25/60
            logging.info(f'[{appName}] Proccessing speed: {total/_et} pulses/second, {total_events/_et} events/second')
            logging.info(f'[{appName}] In {t} minutes: Empty pulses: {count}/{total}, {count*100/total}%. Error pulses: {error_events}/{total}, {error_events*100/total}%.')
            if len(disc_pulses) > 0:
                logging.warning(f'[{appName}] Intermittent pulses: {len(disc_pulses)}, {len(disc_pulses)*100/total}%, {disc_pulses}')
            if len(error_pulses) > 0:
                logging.warning(f'[{appName}] Error pulses: {len(error_pulses)}, {len(error_pulses)*100/total}%, {error_pulses}')
            if error_pids.shape[0] > 0:
                logging.warning(f'[{appName}] Error pixels: {error_pids.shape[0]}, {set(error_pids)}')
            if error_tofs.shape[0] > 0:
                logging.warning(f'[{appName}] Error tofs: {error_tofs.shape[0]}, {set(error_tofs)}')

            total = 0
            count = 0
            error_pids = np.array([], dtype='uint32')
            error_tofs = np.array([], dtype='uint32')
            disc_pulses = []
            error_pulses = []
            error_events = 0
            total_events = 0
            startTime = time.time()

if __name__ == '__main__':
    conf_helper = ConfHelper()
    conf = conf_helper.get_conf()
    redis_conf = conf['redis']
    ip_port = (redis_conf['host'], redis_conf['port'])
    passwd = redis_conf['password']
    rds = RedisHelper(ip_port, passwd, 10)

    conf = conf_helper.get_conf_from_redis(rds, '/MPI/conf')

    tasks = {}
    for _item in conf['modules']['detector_modules']:
        _taskName = 'DET-' + getModuleName(_item)
        tasks[_taskName] = [onlineDetector, (_item, redis_conf)]

    if len(tasks) == 0:
        sys.exit(0)

    mp.set_start_method('fork')
    print('Starting task manager')
    man = TaskManager(tasks)
    print("main thread sleeping")
    sleep(100000000)
    man.stopProc()
    print ('Main exit')
