#!/usr/bin/python
# -*- coding: utf-8 -*-

import sys
import os
import epics
from epics.ca import CAThread, withInitialContext
import redis
import datetime
import time
import random
import numpy as np
from scipy import exp
from scipy.optimize import leastsq
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt
import matplotlib.animation as anim
from drawnow import drawnow
import warnings
import signal
import json
from Queue import Queue
from threading import Thread
import threading
import logging
import multiprocessing
import atexit
import jsonArray
import paramiko
import neon

# write = sys.stdout.write
# flush = sys.stdout.flush

import warnings
warnings.filterwarnings('ignore')


def getTime():
    return time.strftime('%Y-%m-%d %H:%M:%S')


class getRedisServer:

    def __init__(
        self,
        ip,
        port,
        retry,
        ):

        self.status = False
        client = False
        self.server = None
        begin = time.time()
        for i in range(int(retry)):
            if not self.server:
                self.server = redis.Redis(host=ip, port=port,
                        password='sanlie;123', db=0,
                        socket_connect_timeout=1.0)
            else:
                try:
                    client = self.server.client_list()
                except redis.exceptions.ConnectionError:
                    pass

                    # logging.warning("Attempt to connect NEON again!")

                if client:
                    self.status = True

                    # print getTime(), "INFO: NEON connected"

                    break
                else:

                    # logging.exception("NEON not available")

                    print getTime(), 'WARNING: Redis retry'
            time.sleep(0.2)

    def getStatus(self):
        return self.status

    def getServer(self):
        return self.server


class getEpicsServer:

    def __init__(self, retry):

        epics.ca.use_initial_context()

        pvCommand = 'EXP_IB2_RM:soft:ctrl_cmd'
        pvStatus = 'EXP_IB2_RM:soft:analyse_stat'

        self.status = False

        self.epicsCommand = None
        self.epicsStatus = None

        for i in range(int(retry)):
            if not self.epicsCommand:
                try:
                    self.epicsCommand = epics.PV(pvCommand)
                    self.epicsStatus = epics.PV(pvStatus)
                    self.status = True
                    print getTime(), 'INFO: EPICS connected'
                    break
                except:
                    print getTime(), 'WARNING: EPICS retry'
            time.sleep(0.2)

    def getStatus(self):
        return self.status

    def getServer(self):
        return (self.epicsCommand, self.epicsStatus)


class getDroneServer:

    def __init__(self, ip, module):
        self.user = 'drone'
        self.ip = ip
        self.status = False
        self.sshserver = None
        self.clitrpc = None

    def create(self):
        _client = paramiko.SSHClient()
        _client.load_system_host_keys()
        for i in range(30):
            try:
                _client.connect(ip, username=user)
                self.sshserver = _client
                m_neonRedis = neon.Neon.NeonRedis(host='10.1.31.116',
                        port=9000, db=0, isWritable=True)
                clitpob = neon.Neon.NeonService.POBox(m_neonRedis,
                        '/MR/process/detector', '10.1.31.118:cockpit01')
                self.clitrpc = \
                    neon.Neon.NeonService.NeonRPC(sendPOBox=clitpob,
                        recvPOBox=clitpob)
                break
            except paramiko.AuthenticationException:
                print 'Warning: Drone Authentication Failed.'
            except:
                print 'Warning: Connect Drone Again.'
                time.sleep(0.2)

        if self.sshserver:
            _cmd = '/home/drone/workspace/drone/daqrun/module01'
            (stdin, stdout, stderr) = self.sshserver.exec_command(_cmd)
            print 'INFO: ', stdout
        else:
            print 'ERROR: ssh not available.'

    def start(self):
        self.clitrpc.execRPC('10.1.31.118:drone01', 'Start',
                             {'string': 'start', 'times': 3}, 'Print',
                             -1.0)
        pass

    def configure(self):
        self.clear()

    def stop(self):
        self.clear()

    def abort(self):
        _cmd = 'pkill -9 -u drone module01'
        (stdin, stdout, stderr) = self.sshserver.exec_command(_cmd)

    def clear(self):
        self.clitrpc.execRPC('10.1.31.118:drone01', 'Clear',
                             {'string': 'clear', 'times': 3}, 'Print',
                             -1.0)


class setHeartbeat(threading.Thread):

    def __init__(self, refreshtime):

        threading.Thread.__init__(self)

        pvHeart = 'EXP_IB2_RM:soft:analyse_hb'
        self.pvname = epics.PV(pvHeart)
        self.refreshtime = refreshtime

    def setStatus(self, _status):
        self.status = _status

    def getStatus(self):
        return self.status

    def process(self):

        # _data={'timestamp':time.strftime('%Y-%m-%d %H:%M:%S')}
        # self.pvname.put(json.dumps(_data))

        self.pvname.put(getTime(), use_complete=True)

    def run(self):
        while True:
            try:
                self.process()
            except:
                pass
            time.sleep(self.refreshtime)


class setStatus(threading.Thread):

    def __init__(self, refreshtime):

        threading.Thread.__init__(self)

        self.status = 'WAITING'
        self.pvname = epics.PV('EXP_IB2_RM:soft:analyse_stat')
        self.refreshtime = refreshtime

    def setStatus(self, _status):
        self.status = _status

    def getStatus(self):
        return self.status

    def process(self):
        self.pvname.put(self.getStatus(), use_complete=True)

    def run(self):
        while True:
            try:
                self.process()
            except:
                pass
            time.sleep(self.refreshtime)


class getEpicsCommand(threading.Thread):

    def __init__(
        self,
        threadHeartbeat,
        redisServer,
        epicsCommand,
        epicsStatus,
        ):
        threading.Thread.__init__(self)
        epics.ca.use_initial_context()
        self.epicsQueue = Queue()
        self.redisServer = redisServer
        self.epicsCommand = epicsCommand
        self.epicsStatus = epicsStatus
        self.threadHeartbeat = threadHeartbeat

        self.setTOF()

        self.MRstatusList = [
            'WAITING',
            'INITIALIZED',
            'READY',
            'ANALYSING',
            'ANALYSED',
            'ENDED',
            'ERROR',
            ]
        self.MRcommandList = ['CONF', 'STARTANALYSE', 'NEXTPOINT',
                              'STOP', 'RESET']
        self.MYstatusList = [
            'waiting',
            'unconfigured',
            'configuring',
            'ready',
            'running',
            'paused',
            'error',
            ]
        self.MYcommandList = [
            'configure',
            'unconfigure',
            'start',
            'pause',
            'resume',
            'stop',
            'abort',
            ]

        self.mantidHeartBeat = '/MR/heartbeat/mantid'
        self.pilotHeartBeat = '/MR/heartbeat/pilot'
        self.setCommand(6, 1, 1, 50)

        self.epicsCommand.get()
        self.epicsCommand.add_callback(self.getCommand)

    def getCommand(self, char_value=None, **kw):
        self.epicsQueue.put(char_value)

    def setCommand(
        self,
        cmd,
        stat,
        epics,
        retry,
        ):
        cmd = int(cmd)
        stat = int(stat)
        epics = int(epics)
        retry = int(retry)
        if epics != -1:
            _status = self.MRstatusList[epics]
        self.redisServer.set('/MR/control/command', json.dumps(cmd))

        _count = -1
        while True:
            if retry != -1:
                _count += 1
                if _count > retry:
                    break

            mantidStatus = self.getMantidStatus()
            pilotStatus = self.getPilotStatus()
            if pilotStatus == stat and mantidStatus == stat:
                print '   Waiting: pilot,  ', \
                    self.MYstatusList[pilotStatus]
                print '   Waiting: mantid, ', \
                    self.MYstatusList[mantidStatus]
                if epics != -1:
                    self.threadHeartbeat.setStatus(_status)
                    self.epicsStatus.put(_status)
                    print getTime(), 'INFO: send state, ', \
                        self.MRstatusList[epics]

                print getTime(), 'INFO: set state, ', \
                    self.MRstatusList[epics]
                break
            else:
                print '   Waiting: pilot,  ', \
                    self.MYstatusList[pilotStatus]
                print '   Waiting: mantid, ', \
                    self.MYstatusList[mantidStatus]
            time.sleep(0.5)

    def getDroneStatus(self):
        droneStatus = self.redisServer.get(self.droneHeartBeat)
        droneStatus = json.load(droneStatus)['status']
        droneStatus = json.load(droneStatus)['timestamp']
        return droneStatus

    def getMantidStatus(self):
        try:
            mantidStatus = self.redisServer.get(self.mantidHeartBeat)
            mantidStatus = json.loads(mantidStatus)['status']
        except:
            mantidStatus = 0

        return mantidStatus

    def getPilotStatus(self):
        try:
            pilotStatus = self.redisServer.get(self.pilotHeartBeat)
            pilotStatus = json.loads(pilotStatus)['status']
        except:
            pilotStatus = 0

        return pilotStatus

    def setTOF(self):
        _tmin = 100
        _tmax = 40000
        _bin = 16.0
        ntof = int((_tmax - _tmin) / _bin)
        self.tofmantid = np.zeros(ntof + 1)
        for i in range(ntof + 1):
            self.tofmantid[i] = int(_tmin + (i - 1) * _bin - 0.5 * _bin)

        _json = jsonArray.jsonEncoder(self.tofmantid.tolist())
        self.redisServer.set('/MR/workspace/detector/module01/tof',
                             _json)

    def getEpicsConfigure(self):
        run_type = epics.caget('EXP_IB2_RM:soft:run_type')
        run_no = epics.caget('EXP_IB2_RM:soft:run_no')
        start_pos = epics.caget('EXP_IB2_RM:soft:start_pos')
        end_pos = epics.caget('EXP_IB2_RM:soft:end_pos')
        scan_point = epics.caget('EXP_IB2_RM:soft:scan_points')
        proton_charge = epics.caget('EXP_IB2_RM:soft:proton_charge')
        end_time = epics.caget('EXP_IB2_RM:soft:end_time')
        userID = epics.caget('EXP_IB2_RM:soft:user_id')
        proposalID = epics.caget('EXP_IB2_RM:soft:proposal_id')

        _points = []
        _step = (float(end_pos) - float(start_pos)) \
            / (float(scan_point) - 1)
        for i in range(int(scan_point)):
            _points.append(float(start_pos) + i * _step)

        _configure = {}
        _configure['runNo'] = run_no
        _configure['userID'] = userID
        _configure['proposalID'] = proposalID
        _configure['runType'] = run_type
        _configure['positionPoint'] = _points
        _configure['protonCharge'] = proton_charge
        _configure['tof'] = self.tofmantid.tolist()

        _configure = json.dumps(_configure, ensure_ascii=False)
        self.redisServer.set('/MR/control/configure', _configure)
        print getTime(), 'INFO: Set configure, ', run_no

    def run(self):
        while True:
            char_value = self.epicsQueue.get()
            print getTime(), 'INFO: Receive command ', char_value

            if char_value == 'CONF':
                self.getEpicsConfigure()
                time.sleep(0.1)
                self.setCommand(0, 3, 2, -1)
            elif char_value == 'STARTANALYSE':
                self.setCommand(2, 4, 3, -1)
            elif char_value == 'NEXTPOINT':
                self.setCommand(5, 3, 2, -1)
            elif char_value == 'RESET':
                self.setCommand(6, 1, 1, -1)
            elif char_value == 'FINISH':
                self.setCommand(5, 3, 4, -1)
            elif char_value == 'STOP':
                self.setCommand(5, 3, -1, -1)
                self.setCommand(1, 1, 5, -1)
            else:
                print getTime(), 'WARNING: Invalid command ', char_value

    def stop(self):
        self.p.clear_callbacks()


if __name__ == '__main__':
    print '==================='
    print 'Welcome to CockPit!'
    print '        RM        '
    print '==================='

    # connect Redis

    try:
        myredis = getRedisServer('10.1.31.116', 9000, 10)
        if myredis.getStatus():
            redisServer = myredis.getServer()
            print getTime(), 'INFO: Redis Connected.'
    except:
        redisServer = None
        print getTime(), 'ERROR: Redis Failed.'

    print '==================='

    # connect Epics

    try:
        myepics = getEpicsServer(10)
        if myepics.getStatus():
            (epicsCommand, epicsStatus) = myepics.getServer()
            print getTime(), 'INFO: EPICS Connected.'
    except:
        epicsCommand = None
        epicsStatus = None
        print getTime(), 'ERROR: EPICS Failed.'

    print '==================='

    # start drone

    # start heartbeat

    print '==================='
    try:
        threadHeartbeat = setHeartbeat(1.0)
        threadHeartbeat.setDaemon(True)
        threadHeartbeat.start()
        print getTime(), 'INFO: Heartbeat connected.'
    except:
        print getTime(), 'ERROR: Heartbeat Failed.'

    time.sleep(0.1)

    # set status

    print '==================='
    try:
        threadSetStatus = setStatus(1.0)
        threadSetStatus.setDaemon(True)
        threadSetStatus.start()
        print getTime(), 'INFO: State thread created.'
    except:
        print getTime(), 'ERROR: State thread Failed.'
    time.sleep(0.1)

    # receive epics command

    print '==================='
    try:
        threadGetEpics = getEpicsCommand(threadSetStatus, redisServer,
                epicsCommand, epicsStatus)
        threadGetEpics.setDaemon(True)
        threadGetEpics.start()
        print getTime(), 'INFO: Command thread created.'
    except:
        print getTime(), 'ERROR: Command thread Failed.'
    time.sleep(0.1)

    print '==================='

    try:
        threadGetEpics.join()
        threadSetStatus.join()
        threadHeartbeat.join()
    except:
        pass

    print '==================='
    threadSetStatus.setStatus('ERROR')
    epicsStatus.put('ERROR')
