from mantid.simpleapi import *
import jsonArray
import threading
from Queue import Queue
import time
import matplotlib
import numpy as np

class dataGet(threading.Thread):

    def __init__(self, threadID, neonRedis, refreshtime, bankList, monitorList):
        threading.Thread.__init__(self)
        self.paused = False
        self.pause_cond = threading.Condition(threading.Lock())
        self.thread_stop = False

        self.refreshtime=refreshtime
        self.neonRedis=neonRedis

        self.bankList=bankList
        self.monitorList=monitorList

        self.nbank=len(self.bankList)
        self.nmonitor=len(self.monitorList)
        print self.nbank, self.nmonitor
        #set path of data
        self.neonpathDetPid=[[] for i in range(self.nbank)]
        self.neonpathDetTof=[[] for i in range(self.nbank)]
        self.neonpathDetValue=[[] for i in range(self.nbank)]
        self.neonpathMonPid=[[] for i in range(self.nmonitor)]
        self.neonpathMonTof=[[] for i in range(self.nmonitor)]
        self.neonpathMonValue=[[] for i in range(self.nmonitor)]

        self.detWorkspace=[]
        self.monWorkspace=[]
        self.modulenum_bank=[]
        self.modulenum_mon=[]
        det=bankList.keys()
        for m in range(self.nbank):
            self.detWorkspace.append(det[m])
        
        for i in range(len(self.detWorkspace)):
            _moduleList=bankList[self.detWorkspace[i]]
            self.modulenum_bank.append(len(_moduleList))
            for j in _moduleList:
                self.neonpathDetPid[i].append("/GPPD/workspace/data/module"+str(j).zfill(2)+"/pid")
                self.neonpathDetTof[i].append("/GPPD/workspace/data/module"+str(j).zfill(2)+"/tof")
                self.neonpathDetValue[i].append("/GPPD/workspace/data/module"+str(j).zfill(2)+"/value")

        mon=monitorList.keys()
        for n in range(self.nmonitor):
            self.monWorkspace.append(mon[n])

        for i in range(len(self.monWorkspace)):
            _moduleList=monitorList[self.monWorkspace[i]]
            self.modulenum_mon.append(len(_moduleList))
            for j in _moduleList:
                self.neonpathMonPid[i].append("/GPPD/workspace/data/monitor"+str(j).zfill(2)+"/pid")
                self.neonpathMonTof[i].append("/GPPD/workspace/data/monitor"+str(j).zfill(2)+"/tof")
                self.neonpathMonValue[i].append("/GPPD/workspace/data/monitor"+str(j).zfill(2)+"/value")

        self.detPid = self.getPid(bankList, self.neonpathDetPid)
        print "success in get det pid data"
        print type(self.detPid)
        print type(self.detPid[0][1])
        self.detTof = self.getTof(bankList, self.neonpathDetTof)
        print "success in get det tof data"
        print type(self.detTof[0][1])

        self.monPid = self.getPid(monitorList, self.neonpathMonPid)
        print "success in get mon pid data"
        self.monTof = self.getTof(monitorList, self.neonpathMonTof)
        print "success in get mon tof data"

        #_nbank=len(self.detPid)
        #_nmonitor=len(self.monPid)
        #for i in _nbank: self.detWs=np.zeros((_nbank, len(self.detPid[i]), len(self.detTof[i])))
        #for i in _nmonitor: self.monWs=np.zeros((_nmonitor, len(self.monPid[i]), len(self.monTof[i])))

    def getPid(self, bank, path):
    # usage: self.getPid(self.bankList, self.neonpathDetPid)
    # usage: self.getPid(self.monitorList, self.neonpathMoPid)
        nbank=len(bank)
        pid=[[] for k in range(nbank)]
        for i in range(nbank):
            for _path in path[i]:
                while True:
                    _json=self.neonRedis.get(_path)
                    if _json is None:
                        print "Warning: No pid data!!!"
                    else:
                        _array=jsonArray.jsonDecoder(_json)
                        if _array is None:
                            print "Warning: Incompleted pid data!!!"
                        else:
                            #print len(_array)
                            for k in _array:
                                pid[i].append(k)
                            break

            print "the pixel number in  bankpid[i] is: "+str(len(pid[i]))        
        if (i!=nbank-1):
            print "error: iteration number is: "+str(i)

        #for i in range(nbank):
            #pid[i]=np.array(pid[i])
        #print pid
        return pid

    def getTof(self, bank, path):
        nbank=len(bank)
        tof=[[] for k in range(nbank)]
        for i in range(nbank):
            for _path in path[i]:
                _json=self.neonRedis.get(_path)
                if _json is None:
                    print "Error: No tof data!!!"
                else:
                    _array=jsonArray.jsonDecoder(_json)
                    if (len(_array)< 4377):
                        print "error: tof unfinished!"
                    for k in _array:
                        tof[i].append(k)
                    break
            #print len(tof[i])
        if (i!=nbank-1):
            print "error: iteration number is: "+str(i)
        #for i in range(nbank):
            #tof[i]=np.array(tof[i])
        #print tof
        return tof

    def getValue(self, bank, path, modulenum):
        be=time.time()
        nbank=len(bank)
        value=[[] for k in range(nbank)]
        _tmp=[]
        st=0
        for i in range(nbank):
            #mvalue=[[] for j in range(modulenum[i])]
            #print "the length of mvalue is "+str(len(mvalue))
            for _path in path[i]:
                _json=self.neonRedis.get(_path)
                _array=jsonArray.jsonDecoder(_json)
                for ix in range(len(_array)):
                    #specNo=144+ix
                    for iy in range(len(_array[ix])):
                        value[i].append(_array[ix,iy])
                        #detws[[i][specNo][iy]]=_array[ix][iy]

        #for i in nbank:
        #loadCSNSraw(detWorksapce[i], detws[i])

            #print len(_tmp)
                #value[i].extend(_tmp)
            #print value[i]
        print len(value)
        print (time.time()-be)
        return value               


    def getDataInWorkspace(self):

        detValue=self.getValue(self.bankList, self.neonpathDetValue, 4)
        monValue=self.getValue(self.monitorList, self.neonpathMonValue, 4)
        wsName=[]
        _monpid=self.monPid[0]
        _montof=self.monTof[0]
        _monv=monValue[0]
        for i in range(len(self.bankList)):
            _bankpid=self.detPid[i]
            _banktof=self.detTof[i]
            _bankv=detValue[i]
        
         
            _n1=[]
            _n2=[]
            _n3=[]
            _n4=[]
            _n5=[]
            _n6=[]

            for j in range(len(_bankpid)):
                _n1.append(int(_bankpid[j]))
            for j in range(len(_banktof)):
                _n2.append(int(_banktof[j]))
            for j in range(len(_bankv)):
                _n3.append(int(_bankv[j]))
            for j in range(len(_monpid)):
                _n4.append(int(_monpid[j]))
            for j in range(len(_montof)):
                _n5.append(int(_montof[j]))
            for j in range(len(_monv)):
                _n6.append(int(_monv[j]))
            _n0=self.detWorkspace[i]            

            LoadCSNSRaw(OutputWorkspace=_n0, PixelID_bank=_n1, TimeOfFlight_bank=_n2, Counts_bank=_n3, PixelID_monitor=_n4, TimeOfFlight_monitor=_n5, Counts_monitor=_n6)
            wsName.append(_n0)
            print "outputWS= ", wsName[i]        
            

    def reduceData(self, detPid, detTof, detValue, monPid, monTof, monValue):
        _monitorpid=[] 
        _monitortof=[] 
        _monitorvalue=[] 
        for i in range(len(self.monWorkspace)):
            if (self.monWorkspace[i]=="monitor2"):
                _monitorpid.extend(monPid[i])
                _monitortof.extend(monTof[i])
                _monitorvalue.extend(monValue[i])
                break

        xpid=[[] for i in range(len(self.detWorkspace))]        
        ypid=[[] for i in range(len(self.detWorkspace))]        
        value=[[] for i in range(len(self.detWorkspace))]        
        timeOfFlight=[[] for i in range(len(self.detWorkspace))]        
        totalCounts=[[] for i in range(len(self.detWorkspace))]        

        for i in range(len(self.detWorkspace)):
            _bankname=self.detWorkspace[i]
            _bankpid=detPid[i]
            _banktof=detTof[i]
            _bankvalue=detValue[i]

            LoadCSNSRaw(OutputWorkspace=_bankname, PixelID_bank=_bankpid, TimeOfFlight_bank=_banktof, Counts_bank=_bankvalue, PixelID_monitor=_monitorpid, TimeOfFlight_monitor=_monitortof, Counts_monitor=_monitorvalue)
            LoadInstrument(Workspace=_bankname+'_1', Filename='/home/dur/work/gppd_reduction/code_reduction/GPPD_IDF_bank.xml', RewriteSpectraMap='True')
            LoadCalFile(InputWorkspace=_bankname+'_1', CalFilename='/home/dur/work/gppd_reduction/code_reduction/calfile.cal', WorkspaceName='GPPDraw')
            MaskDetectors(Workspace=_bankname+'_1', MaskedWorkspace='GPPDraw_mask')
            AlignDetectors(InputWorkspace=_bankname+'_1', OutputWorkspace=_bankname+'test_alignD', OffsetsWorkspace='GPPDraw_offsets')
            Rebin(InputWorkspace=_bankname+'test_alignD', OutputWorkspace=_bankname+'test_rebin', Params='0.1,0.005,2.5')
            DiffractionFocussing(InputWorkspace=_bankname+'test_rebin', OutputWorkspace=_bankname+'test_dfocus', GroupingWorkspace='GPPDraw_group')
            ConvertUnits(InputWorkspace=_bankname+'test_dfocus', OutputWorkspace=_bankname+'test_wave', Target='Wavelength')
            Rebin(InputWorkspace=_bankname+'test_wave', OutputWorkspace=_bankname+'test_wave_rebin', Params='0.85,0.005,4.25')
            ConvertUnits(InputWorkspace=_bankname+'test_dfocus', OutputWorkspace=_bankname+'test_dfocus_tof', Target='TOF')
            # get x-axis, y-axis, counts data
            name1=mtd[_bankname+'test_1']
            num=name1.getNumberHistograms()

            for j in range(num):
                det=name1.getDetector(j)
                position=det.getPos()
                xpid[i].append(position.X())
                ypid[i].append(position.Y())
                value[i].append(name1.counts(j))

            # get tof-counts 2D spectrum
            name2=mtd[_bankname+'test_dfocus_tof']
            counts=name2.extractY()
            timebin=name2.extractX()
            for n in range(len(counts)):
                totalCounts[i].append(counts[n])
            for m in range(len(timebin)):
                timeOfFlight[i].append(timebin[m])

        print "finish reduction!"
        return xpid, ypid, value, timeOfFlight, totalCounts

    def setValue(self, value, detname):
        #for i in range(self.nbank):
            
        pass


    def process(self):
        be=time.time()
        print "process"

        #detValue=self.getValue(self.bankList, self.neonpathDetValue) 
        #print "finish get det data in process"

        #detPid, detTof, detValue, monPid, monTof, monValue=self.getData()
        #xpid, ypid, value, timeOfFlight, totalCounts=self.reduceData(detPid, detTof, detValue, monPid, monTof, monValue)




        '''
        if (self.neonpathMonValue != None):
            _jsonZ=self.neonRedis.get(self.neonpathMonValue)
            _dataZ=jsonArray.jsonDecoder(_jsonZ)
            self.arrayQ.put(_dataZ) 
            self.arrayQ.task_done()

        print (time.time()-be)
        '''
    def run(self):
        while True:
            self.process()
            print "get"
            time.sleep(self.refreshtime)

    def pause(self):
        #self.paused = True
        #self.pause_cond.acquire()
        self.can_run.set()
        with self.arrayZ.mutex:
            self.arrayZ.queue.clear()
            self.valueC.queue.clear()

    def resume(self):
        #self.paused = False
        #self.pause_cond.notify()
        #self.pause_cond.release()
        self.can_run.clear()

    def stop(self):
        self.paused = True
        self.pause_cond.acquire()
        self.can_run.set()


class dataSet(threading.Thread):

    def __init__(self,threadID,neonRedis,refreshtime,neonpathX,neonpathY,neonpathZ,neonpathC,xbin,ybin,setQueue):

        threading.Thread.__init__(self)
        self.paused = False
        self.pause_cond = threading.Condition(threading.Lock())
        self.thread_stop = False

        self.neonRedis=neonRedis
        self.neonpathX=neonpathX
        self.neonpathY=neonpathY
        self.neonpathZ=neonpathZ
        self.neonpathC=neonpathC
        self.setQueue=setQueue
        self.refreshtime=refreshtime

        self.setXaxis(xbin)
        self.setYaxis(ybin)

    def setXaxis(self, xbins):
        _json_data=jsonArray.jsonEncoder(xbins)
        self.neonRedis.set(self.neonpathX, _json_data)

    def setYaxis(self, ybins):
        _json_data=jsonArray.jsonEncoder(ybins)
        self.neonRedis.set(self.neonpathY, _json_data)

    def process(self):
        be=time.time()

        if (self.neonpathC != None):
            pass
            #_dataC=self.valueC.get()
            #self.neonRedis.set(self.neonpathCvalue, _dataC)

        if (self.neonpathZ != None):
            if self.setQueue.empty():
                return

        _arrayZ=self.setQueue.get()
        _jsonZ=jsonArray.jsonEncoder(_arrayZ)
        self.neonRedis.set(self.neonpathZ, _jsonZ)

        print (time.time()-be)

    def run(self):
        while True:
            self.process()
        print "set"
        time.sleep(self.refreshtime)

    def destroy(self):
        self.neonRedis.delete(self.neonX)
        self.neonRedis.delete(self.neonY)
        self.neonRedis.delete(self.neonZ)
        self.neonRedis.delete(self.neonC)

    def pause(self):
        self.paused = True
        self.pause_cond.acquire()

    def resume(self):
        self.paused = False
        self.pause_cond.notify()
        self.pause_cond.release()

    def stop(self):
        self.thread_stop = True


