#!/usr/bin/env python3

from IO.Kafka import getDetectorData, getGlobalData, KafkaGrabber
from  Utils.Histogram import Hist2D 

from Tasks.TaskManager import TaskManager
from IO.RedisHelper import RedisHelper
from IO.conf_helper import ConfHelper

from time import sleep
import multiprocessing as mp
import time
import numpy as np
import logging
import re
import sys

logging.basicConfig(
    format='%(asctime)s %(levelname)-8s %(message)s',
    level=logging.INFO,
    datefmt='%Y-%m-%d %H:%M:%S')

def getRedisPath(conf):
    _topic = conf['topic']
    _pattern = '^BL16-Detector-Bank(\d{2})-Module(\d{2})-EventData$'
    _r = re.match(_pattern, _topic)
    _p = None
    if _r is not None:
        _p = '1' + _r.group(1) + _r.group(2)

    return '/MPI/workspace/detector/module' + _p + '/value'

def getModuleName(conf):
    _topic = conf['topic']
    _pattern = '^BL16-Detector-(Bank\d{2}-Module\d{2})-EventData$'
    _r = re.match(_pattern, _topic)
    _p = None
    if _r is not None:
        _p = _r.group(1)

    return _p

def getPixelIdBase(conf):
    _topic = conf['topic']
    _pattern = '^BL16-Detector-Bank(\d{2})-Module(\d{2})-EventData$'
    _r = re.match(_pattern, _topic)
    if _r is not None:
        _p = '1' + _r.group(1) + _r.group(2) + '0000'
        return int(_p)
    raise Exception(f'Cannot find pixel ID base for topic: {_topic}')
from math import remainder

def onlineDetector(conf, redis_conf):
    print('onlineDetector is running')
    appName = getModuleName(conf)
    redisPath = getRedisPath(conf)

    ip_port = (redis_conf['host'], redis_conf['port'])
    passwd = redis_conf['password']
    rds= RedisHelper(ip_port, passwd, 10)

    _pixels = conf['module_length'] // 5 * 8
    _pixelIdBase = getPixelIdBase(conf)
    _lowId = _pixelIdBase + 1
    _highId = _pixelIdBase + _pixels

    tofBinSize = 2500
    pixBinSize = _pixels

    hist = Hist2D(tofBinSize, pixBinSize, [[0-0.5, 40000-0.5], [-0.5 + _lowId, _highId + 0.5]])
    task = KafkaGrabber(conf, 100000)
    consumer = task.getKafkaConsumer()
    #consumer.seek_to_beginning()
    #consumer.seek_to_end()
    _pts = list(consumer.assignment())
    #consumer.seek_to_end(_pts[0])
    #consumer.seek(_pts[0], consumer.position(_pts[0]) - 25*60*60)
    dataType = task.getDataType()
    startTime = time.time()
    i = time.time()
    startPulseID = 0
    count = 0
    total = 0
    for msg in consumer:
        pulseId, tof, pid = getDetectorData(msg)
        total += 1

        pulseStart = time.time()

        if startPulseID == 0:
            startPulseID=pulseId
        #print(pulseId, tof.shape, pid.min(), pid.max())

        if not tof.size * pid.size:
            count += 1
            logging.warning(appName+ ' data with zero size')
            continue
        if tof.size != pid.size:
            logging.warning(appName+' tof and pid arrays differ in size')
            continue

        if not hist.fill(tof, pid):
            print(tof)
            print(pid)
            print('tof type', type(tof), tof.size)
            print('pid type', type(pid), pid.size)
            print('tof range', tof.min(),tof.max())
            print('pid range', pid.min(),pid.max() )
            logging.warning('{} contains invalid data, range of tof(), range of pid{}'.format(\
                            appName, (tof.min(),tof.max()), (pid.min(),pid.max()) ))
            return

        pulseEnd=time.time()

        if (time.time()-i)>5.:
            histVal=hist.hist
            rds.writeNumpyArray(redisPath, histVal)
            #data_back=rds.readNumpyArray(redisPath)
            #np.testing.assert_array_equal(histVal,data_back)
            #print('passed redis numpy read write test')
            logging.info(f'Empty pulses: {count}/{total}, {count*100/total}%')
            redisTime=time.time()
            totEvent = histVal.sum()
            logging.info('{}: process pulse {} used {:5.2f}ms, histogram accumulated {}, rate {:5.2f}kHz, redis {:5.2f}ms'.format(appName, pulseId, (pulseEnd-pulseStart)*1e3, totEvent, \
                          totEvent/((pulseId-startPulseID)*40.), 1e3*(redisTime-pulseEnd)  ))
            i=time.time()

if __name__ == '__main__':
    conf_helper = ConfHelper()
    conf = conf_helper.get_conf()
    redis_conf = conf['redis']
    ip_port = (redis_conf['host'], redis_conf['port'])
    passwd = redis_conf['password']
    rds = RedisHelper(ip_port, passwd, 10)
    del conf['redis']
    _r = conf_helper.write_conf_to_redis(rds, '/MPI/conf', conf)
    if not _r:
        print('Write conf to Redis failed.')

    conf = conf_helper.get_conf_from_redis(rds, '/MPI/conf')

    tasks = {}
    for _item in conf['modules']['detector_modules']:
        if _item.get('enabled') is not None:
            if _item['enabled'] == False:
                print(f'{_item["topic"]} is not enabled, skip.')
                continue
        _taskName = 'DET-' + getModuleName(_item)
        tasks[_taskName] = [onlineDetector, (_item, redis_conf, )]

    if len(tasks) == 0:
        sys.exit(0)

    mp.set_start_method('fork')
    print('Starting task manager')
    man = TaskManager(tasks)
    print("main thread sleeping")
    sleep(100000000)
    man.stopProc()
    print ('Main exit')
