from mantid.simpleapi import *
import jsonArray
import threading
from Queue import Queue
import time
import matplotlib
import numpy as np
from canvas import *
import redis
import sys
import datetime
import cStringIO
import pickle
import Image

class dataGet(threading.Thread):

    def __init__(self, threadID, neonRedis, refreshtime, bankList, monitorList):
        threading.Thread.__init__(self)
        self.paused = False
        self.pause_cond = threading.Condition(threading.Lock())
        self.thread_stop = False

        self.refreshtime=refreshtime
        self.neonRedis=neonRedis

        self.bankList=bankList
        self.monitorList=monitorList

        self.nbank=len(self.bankList)
        self.nmonitor=len(self.monitorList)
        print self.nbank, self.nmonitor
        #set path of data
        self.neonpathDetPid=[[] for i in range(self.nbank)]
        self.neonpathDetTof=[[] for i in range(self.nbank)]
        self.neonpathDetValue=[[] for i in range(self.nbank)]
        self.neonpathMonPid=[[] for i in range(self.nmonitor)]
        self.neonpathMonTof=[[] for i in range(self.nmonitor)]
        self.neonpathMonValue=[[] for i in range(self.nmonitor)]

        self.detWorkspace=[]
        self.monWorkspace=[]
        self.modulenum_bank=[]
        self.modulenum_mon=[]


        #get data path from neon
        det=bankList.keys()
        for m in range(self.nbank):
            self.detWorkspace.append(det[m])
        
        for i in range(len(self.detWorkspace)):
            _moduleList=bankList[self.detWorkspace[i]]
            self.modulenum_bank.append(len(_moduleList))
            for j in _moduleList:
                self.neonpathDetPid[i].append("/GPPD/workspace/data/module"+str(j).zfill(2)+"/pid")
                self.neonpathDetTof[i].append("/GPPD/workspace/data/module"+str(j).zfill(2)+"/tof")
                self.neonpathDetValue[i].append("/GPPD/workspace/data/module"+str(j).zfill(2)+"/value")

        mon=monitorList.keys()
        for n in range(self.nmonitor):
            self.monWorkspace.append(mon[n])

        for i in range(len(self.monWorkspace)):
            _moduleList=monitorList[self.monWorkspace[i]]
            self.modulenum_mon.append(len(_moduleList))
            for j in _moduleList:
                self.neonpathMonPid[i].append("/GPPD/workspace/data/monitor"+str(j).zfill(2)+"/pid")
                self.neonpathMonTof[i].append("/GPPD/workspace/data/monitor"+str(j).zfill(2)+"/tof")
                self.neonpathMonValue[i].append("/GPPD/workspace/data/monitor"+str(j).zfill(2)+"/value")

        # set data path to neon
        self.neonpathDetImg=[]
        self.neonpathDetXYImg=[]
        self.neonpathDetTofSet=[]
        self.neonpathDetCountsSet=[]
        
        for i in range(len(self.detWorkspace)):
            self.neonpathDetImg.append("/GPPD/workspace/MantidData/bank"+str(i+1).zfill(2)+"/pid_image")
            self.neonpathDetXYImg.append("/GPPD/workspace/MantidData/bank"+str(i+1).zfill(2)+"/xy_image")
            self.neonpathDetTofSet.append("/GPPD/workspace/MantidData/group"+str(i+1).zfill(2)+"/tof")
            self.neonpathDetCountsSet.append("/GPPD/workspace/MantidData/group"+str(i+1).zfill(2)+"/counts")
        
        self.neonpathMonImg=[]
        self.neonpathMonTofSet=[]
        self.neonpathMonCountsSet=[]
        self.neonpathMonXSet=[]
        self.neonpathMonYSet=[]
        self.neonpathMonValueSet=[]
        for i in range(len(self.monWorkspace)):
            self.neonpathMonImg.append("/GPPD/workspace/MantidData/monitor"+str(i+1).zfill(2)+"/pid_image")
            self.neonpathMonTofSet.append("/GPPD/workspace/MantidData/monitor"+str(i+1).zfill(2)+"/tof")
            self.neonpathMonCountsSet.append("/GPPD/workspace/MantidData/monitor"+str(i+1).zfill(2)+"/counts")
            self.neonpathMonXSet.append("/GPPD/workspace/MantidData/monitor"+str(i+1).zfill(2)+"/axis_x")
            self.neonpathMonYSet.append("/GPPD/workspace/MantidData/monitor"+str(i+1).zfill(2)+"/axis_y")
            self.neonpathMonValueSet.append("/GPPD/workspace/MantidData/monitor"+str(i+1).zfill(2)+"/value")
        


        self.detPid = self.getPid(bankList, self.neonpathDetPid)
        print "success in get det pid data"
        print type(self.detPid)
        print type(self.detPid[0][1])
        self.detTof = self.getTof(bankList, self.neonpathDetTof)
        print "success in get det tof data"
        print type(self.detTof[0][1])

        self.monPid = self.getPid(monitorList, self.neonpathMonPid)
        print "success in get mon pid data"
        _path3="/GPPD/workspace/data/monitor01/tof"
        _json=self.neonRedis.get(_path3)
        if _json is None:
            print "No data!"
        self.monTof=jsonArray.jsonDecoder(_json)
        print "the length of mon tof : ",len(self.monTof) 



        #self.monTof = self.getTof(monitorList, self.neonpathMonTof)
        print "success in get mon tof data"
        print "the length of pid in monitor is ", len(self.monPid)
        #print "the length of tof in monitor is ", len(self.monTof[0])
        self.createCanvas()

    def getCanvas(self):
        return self.canvas

    def createCanvas(self):   
        self.canvas = figureCanvas(
            16,
            8,
            72,
            'Neutron Distribution',
            'SpecNo ',
            'TOF / us',
            True,
            )

    def getPid(self, bank, path):
    # usage: self.getPid(self.bankList, self.neonpathDetPid)
    # usage: self.getPid(self.monitorList, self.neonpathMoPid)
        nbank=len(bank)
        pid=[[] for k in range(nbank)]
        for i in range(nbank):
            for _path in path[i]:
                while True:
                    _json=self.neonRedis.get(_path)
                    if _json is None:
                        print "Warning: No pid data!!!"
                    else:
                        _array=jsonArray.jsonDecoder(_json)
                        if _array is None:
                            print "Warning: Incompleted pid data!!!"
                        else:
                            #print len(_array)
                            for k in _array:
                                pid[i].append(k)
                            break

            print "the pixel number in  bankpid[i] is: "+str(len(pid[i]))        
        if (i!=nbank-1):
            print "error: iteration number is: "+str(i)

        #for i in range(nbank):
            #pid[i]=np.array(pid[i])
        #print pid
        return pid

    def getTof(self, bank, path):
        nbank=len(bank)
        tof=[[] for k in range(nbank)]
        for i in range(nbank):
            for _path in path[i]:
                _json=self.neonRedis.get(_path)
                if _json is None:
                    print "Error: No tof data!!!"
                else:
                    _array=jsonArray.jsonDecoder(_json)
                    if (len(_array)< 4377):
                        print "error: tof unfinished!"
                    print "the length of tof is: ", len(_array)
                    for k in _array:
                        tof[i].append(k)
                    break
            #print len(tof[i])
        if (i!=nbank-1):
            print "error: iteration number is: "+str(i)
        #for i in range(nbank):
            #tof[i]=np.array(tof[i])
        #print tof
        return tof

    def getValue(self, bank, path, modulenum):
        be=time.time()
        nbank=len(bank)
        value=[[] for k in range(nbank)]
        _tmp=[]
        st=0
        for i in range(nbank):
            #mvalue=[[] for j in range(modulenum[i])]
            #print "the length of mvalue is "+str(len(mvalue))
            for _path in path[i]:
                _json=self.neonRedis.get(_path)
                _array=jsonArray.jsonDecoder(_json)
                for ix in range(len(_array)):
                    #specNo=144+ix
                    for iy in range(len(_array[ix])):
                        value[i].append(_array[ix,iy])
                        #detws[[i][specNo][iy]]=_array[ix][iy]

        #for i in nbank:
        #loadCSNSraw(detWorksapce[i], detws[i])

            #print len(_tmp)
                #value[i].extend(_tmp)
            #print value[i]
        print len(value)
        print (time.time()-be)
        return value               


    def getDataFromNeon(self):

        self.detValue=self.getValue(self.bankList, self.neonpathDetValue, 4)
        self.monValue=self.getValue(self.monitorList, self.neonpathMonValue, 1)
        
    def createWorkspace(self):
        _monpid=self.monPid[0]
        #_montof=self.monTof[0]
        _monv=self.monValue[0]
        for i in range(len(self.bankList)):
            _bankpid=self.detPid[i]
            _banktof=self.detTof[i]
            _bankv=self.detValue[i]
        
            _n1=[]
            _n2=[]
            _n3=[]
            _n4=[]
            _n5=[]
            _n6=[]

            for j in range(len(_bankpid)):
                _n1.append(int(_bankpid[j]))
            for j in range(len(_banktof)):
                _n2.append(int(_banktof[j]))
            for j in range(len(_bankv)):
                _n3.append(int(_bankv[j]))
            for j in range(len(_monpid)):
                _n4.append(int(_monpid[j]))
            #for j in range(len(_montof)):
            for j in range(len(self.monTof)):
                _n5.append(int(self.monTof[j]))
            for j in range(len(_monv)):
                _n6.append(int(_monv[j]))
            _n0=self.detWorkspace[i]            

            LoadCSNSRaw(OutputWorkspace=_n0, PixelID_bank=_n1, TimeOfFlight_bank=_n2, Counts_bank=_n3, PixelID_monitor=_n4, TimeOfFlight_monitor=_n5, Counts_monitor=_n6)
            print "   INFO: Mantid Workspace updated. "        
           

    def reduceData(self):
        wsName=[]
        for j in range(len(self.detWorkspace)):
            wsName.append(self.detWorkspace[j])
        for i in range(len(self.detWorkspace)):
            _bname=str(wsName[i]+"_1")
            _mname=str(wsName[i]+"_2")

            print "workspace name of bank: ", _bname           
            print "workspace name of monitor: ", _mname           

            LoadInstrument(Workspace=_bname, Filename='instrument/GPPD/GPPD_IDF_bank.xml', RewriteSpectraMap='True')
            LoadCalFile(InputWorkspace=_bname, CalFilename='instrument/GPPD/GPPD_calfile.cal', WorkspaceName='GPPDraw')
            MaskDetectors(Workspace=_bname, MaskedWorkspace='GPPDraw_mask')
            AlignDetectors(InputWorkspace=_bname, OutputWorkspace=_bname+'test_alignD', OffsetsWorkspace='GPPDraw_offsets')
            Rebin(InputWorkspace=_bname+'test_alignD', OutputWorkspace=_bname+'test_rebin', Params='0.1,0.005,2.5')
            DiffractionFocussing(InputWorkspace=_bname+'test_rebin', OutputWorkspace=_bname+'test_dfocus', GroupingWorkspace='GPPDraw_group')
            ConvertUnits(InputWorkspace=_bname+'test_dfocus', OutputWorkspace=_bname+'test_wave', Target='Wavelength')
            # deal with monitor data
            LoadInstrument(Workspace=_mname, Filename='instrument/GPPD/GPPD_IDF_mon.xml', RewriteSpectraMap='True')
            LoadCalFile(InputWorkspace=_mname, CalFilename='instrument/GPPD/monitor_calfile.cal', WorkspaceName='GPPDmon')
            AlignDetectors(InputWorkspace=_mname, OutputWorkspace=_mname+'test_alignD', OffsetsWorkspace='GPPDmon_offsets')
            Rebin(InputWorkspace=_mname+'test_alignD', OutputWorkspace=_mname+'test_rebin', Params='0.1,0.005,2.5')
            DiffractionFocussing(InputWorkspace=_mname+'test_rebin', OutputWorkspace=_mname+'test_dfocus', GroupingWorkspace='GPPDmon_group')
            # create tof-counts data for monitor
            ConvertUnits(InputWorkspace=_mname+'test_dfocus', OutputWorkspace=_mname+'_tof', Target='TOF')
            Rebin(InputWorkspace=_mname+'_tof', OutputWorkspace=_mname+'_tof_rebin', Params='5000,8,35000')
            


            ConvertUnits(InputWorkspace=_mname+'test_dfocus', OutputWorkspace=_mname+'test_wave', Target='Wavelength')
            Rebin(InputWorkspace=_mname+'test_wave', OutputWorkspace=_mname+'_wave_rebin', Params='0.7,0.005,4.2')
            RebinToWorkspace(WorkspaceToRebin=_bname+'test_wave', WorkspaceToMatch=_mname+'_wave_rebin', OutputWorkspace=_bname+'_wave_rebin')
            Divide(LHSWorkspace=_bname+'_wave_rebin', RHSWorkspace=_mname+'_wave_rebin', OutputWorkspace=_bname+'_sam_nor', AllowDifferentNumberSpectra=True)
            ConvertUnits(InputWorkspace=_bname+'_sam_nor', OutputWorkspace=_bname+'_sam_nor_tof', Target='TOF')
            Rebin(InputWorkspace=_bname+'_sam_nor_tof', OutputWorkspace=_bname+'_sam_final_tof', Params='5000,8,35000')
            
            print " INFO: Data Reduction Finished!"

    def getPidTofFigure(self):
    
        xaxis=[]
        yaxis=[]
        zaxis=[]
        for i in range(len(self.detWorkspace)):
            #get pid-tof data
            rawname=str(self.detWorkspace[0]+'_1')
            name1=mtd[rawname]
            numHist=name1.getNumberHistograms()
            # get tof
            tof=name1.readX(0)
            ntof=len(tof)-1
            print "tof:", ntof 
            #get pid
            pid=[]
            for k in range(numHist):
                pid.append(k)
            #set 2D array
            x=np.zeros((numHist,ntof))
            y=np.zeros((numHist,ntof))
            z=np.zeros((numHist,ntof))
            for j in range(numHist):
                z[j][:]=name1.readY(j)
            for j in range(numHist):
                for k in range(ntof):
                    x[j][k]=pid[j]
                    y[j][k]=tof[k]
            xaxis.append(x)
            yaxis.append(y)
            zaxis.append(z)
            
        return xaxis,yaxis,zaxis
            
    def getXYMonitor(self):
        _tmpx=[]
        _tmpy=[]
        _tmpz=[]
        rawname=str(self.detWorkspace[0]+'_2')
        name1=mtd[rawname]
        numHist=name1.getNumberHistograms()
        for i in range(numHist):
            det=name1.getDetector(i)
            #idlist.append(det.getID())
            #print idlist
            pos=det.getRelativePos()
            n2=sum(name1.readY(i))
            _tmpx.append(pos.X())
            _tmpy.append(pos.Y())
            _tmpz.append(n2)
        xaxis=list(set(_tmpx))
        yaxis=list(set(_tmpy))
        xaxis.sort()
        yaxis.sort()
        zaxis=[[] for i in range(len(xaxis))]
        for j in range(len(yaxis)):
            for k in range(len(xaxis)):
                zaxis[j].append(_tmpz[(j+1)*k])    

        return xaxis, yaxis, zaxis
    

    def getXYBank(self):
        _tmpx=[]
        _tmpy=[]
        xaxis=[[] for i in range(6)]
        yaxis=[[] for i in range(6)]
        zaxis=[[] for i in range(6)]
        #_tmpz=[]
        idlist=[[] for i in range(4)]
        counts=[[] for i in range(4)]
        for i in range(len(self.detWorkspace)):
            fname=str(self.detWorkspace[i]+'_1')
            name=mtd[fname]
            num=name.getNumberHistograms()

            for j in range(4):
                for k in range(144):
                    det=name.getDetector((j+1)*k)
                    idlist[j].append(det.getID())
                    n2=sum(name.readY((j+1)*k))
                    counts[j].append(n2)
                    # get x y position
                    pos=det.getRelativePos()
                    _tmpx.append(pos.X())
                    _tmpy.append(pos.Y())
            # set x y axis              
            x1=list(set(_tmpx))
            y1=list(set(_tmpy))
            x1.sort()
            y1.sort()
            x2=[]
            y2=[]
            print "x1 length is: ", len(x1)            
            for k in range(len(x1)):
                x2.append(x1[k]+0.06)
                y2.append(x1[k]+0.06)

            xaxis[i].extend(x1) 
            yaxis[i].extend(y1)
            
            xaxis[i].extend(x2) 
            yaxis[i].extend(y2)
            print "the length of xaxis[i]: ", len(xaxis[i])
            
            # set counts data
            _tmpz=np.zeros((24,24))
            for m in range(4):
                if idlist[m][0]==10000 or idlist[m][0]==50000:
                    for x in range(12):
                        for y in range(12):
                            _tmpz[x,y]=counts[m][(x+1)*y]

                if idlist[m][0]==20000 or idlist[m][0]==60000:
                    for x in range(12):
                        for y in range(12):
                            _tmpz[x+12,y]=counts[m][(x+1)*y]

                if idlist[m][0]==30000 or idlist[m][0]==70000:
                    for x in range(12):
                        for y in range(12):
                            _tmpz[x,y+12]=counts[m][(x+1)*y]

                if idlist[m][0]==40000 or idlist[m][0]==80000:
                    for x in range(12):
                        for y in range(12):
                            _tmpz[x+12,y+12]=counts[m][(x+1)*y]
            _z1=[[] for a in range(24)]
            for p in range(24):
                _z1[p].extend(_tmpz[p, :])
            zaxis[i].extend(_z1)
        return xaxis, yaxis, zaxis        

    def getTofCountsBank(self):
        xaxis=[]
        yaxis=[]
        for i in range(len(self.detWorkspace)):
            #get counts and tof data
            fname=str(self.detWorkspace[i]+'_1_sam_final_tof')
            name2=mtd[fname]
            #counts=name2.extractY()
            #tof=name2.extractX()
            counts=name2.readY(0)
            tof=[]
            _tmp=name2.readX(0)
            for j in range(len(_tmp)-1):
                tof.append(_tmp[j])
            xaxis.append(tof)
            yaxis.append(counts)
        return xaxis, yaxis

    def getTofCountsMonitor(self):
        xaxis=[]
        yaxis=[]
        for i in range(len(self.monWorkspace)):
            #get counts and tof data
            fname=str(self.detWorkspace[i]+'_2_tof_rebin')
            name2=mtd[fname]
            counts=name2.readY(0)
            tof=[]
            _tmp=name2.readX(0)
            for j in range(len(_tmp)-1):
                tof.append(_tmp[j])
            xaxis.append(tof)
            yaxis.append(counts)
        return xaxis, yaxis




    def dataSet(self):


        x_bank, y_bank, value_bank=self.getXYBank()
        for i in range(len(self.detWorkspace)):
            _be=time.time()
            self.canvas.ax.clear()
            self.canvas.ax.pcolormesh(x_bank[i],y_bank[i],value_bank[i])
            self.canvas.draw()

            _p=self.neonpathDetXYImg[i]
            self.setImgFromCanvasJson(_p, self.canvas)
            print time.time()-_be
 
        print "finish XY-img"




        x,y,z=self.getPidTofFigure()
        for i in range(len(self.detWorkspace)):
            _be=time.time()
            self.canvas.ax.clear()
            self.canvas.ax.pcolormesh(x[i],y[i],z[i])
            self.canvas.draw()
            
            _p=self.neonpathDetImg[i]
            self.setImgFromCanvasJson(_p, self.canvas)
            print time.time()-_be
 
        print "finish img"
        b_tof, b_counts=self.getTofCountsBank()
        for i in range(len(self.detWorkspace)):
            _p1=self.neonpathDetTofSet[i]
            _data1=b_tof[i]
            self.setDataFromMantid(_p1, _data1)

            _p2=self.neonpathDetCountsSet[i]
            _data2=b_counts[i]
            self.setDataFromMantid(_p2, _data2)
    
        m_tof, m_counts=self.getTofCountsMonitor()
        for i in range(len(self.monWorkspace)):
            _p1=self.neonpathMonTofSet[i]
            _data1=m_tof[i]
            self.setDataFromMantid(_p1, _data1)

            _p2=self.neonpathMonCountsSet[i]            
            _data2=m_counts[i]
            self.setDataFromMantid(_p2, _data2)
    
        x_mon, y_mon, value_mon=self.getXYMonitor()
        print x_mon, y_mon, value_mon
        for i in range(len(self.monWorkspace)):
            _p1=self.neonpathMonXSet[i]
            self.setDataFromMantid(_p1, x_mon)
            _p2=self.neonpathMonYSet[i]
            self.setDataFromMantid(_p2, y_mon)
            _p3=self.neonpathMonValueSet[i]
            self.setDataFromMantid(_p3, value_mon)

 

    def process(self):
        _be=time.time()
        print "   INFO: getDataFromNeon", datetime.datetime.now()
        self.getDataFromNeon()
        print "   INFO: createWorkspace", datetime.datetime.now()
        self.createWorkspace()
        print "   INFO: reduceData", datetime.datetime.now()
        self.reduceData()
        print "   INFO: dataSet", datetime.datetime.now()
        self.dataSet()
        print "===================================="
        print time.time()-_be, "seconds"
        print "===================================="
    def run(self):
        while True:
            self.process()
            time.sleep(self.refreshtime)

    def setDataFromMantid(self, neonpath, data):
        _json_data=jsonArray.jsonEncoder(data)
        self.neonRedis.set(neonpath, _json_data)

    def setImgFromFile(self, neonpath,imgFile):
        _str = open(imgFile, "rb").read()
        self.neonRedis.set(neonpath, _str)
    
    #=============================
    # save image to file
    def setImgFromCanvasFile(self, neonpath, canvas):
        canvas.fig.savefig('img/pid_send.png', format='png')
        _str = open('img/pid_send.png', "rb").read()
        self.neonRedis.set(neonpath, _str)
    
    #=============================
    # save image to stringIO
    def setImgFromCanvasIO(self, neonpath, canvas):
        _strio=cStringIO.StringIO()
        canvas.fig.savefig(_strio, transparent=True, format='png')
        _strio.seek(0)
        _strio=cStringIO.StringIO(_strio.read())

        self.neonRedis.set(neonpath, _strio.getvalue())
        _strio.close()

    def getImgFromNeon(self, neonpath):
        _img=self.neonRedis.get(neonpath)
        _img=cStringIO.StringIO(_img)
        Image.open(_img).save('img/pid_receive.png')

    #=============================
    # save image to string
    def setImgFromCanvasStr(self, neonpath, canvas):
        _str = canvas.tostring_rgb()
        self.neonRedis.set(neonpath, _str)

    def getImgFromNeonStr(self, neonpath):
        _img=self.neonRedis.get(neonpath)
        _img=np.fromstring(_img, dtype='uint8').reshape(1200,750,3)
        Image.fromarray(_img,"RGB").save('img/pid_receive.png')
    
    #=============================
    # save img to array, then to json
    def setImgFromCanvasJson(self, neonpath, canvas):
        _array = np.fromstring(canvas.tostring_rgb(), np.uint8)
        _array = _array.reshape(canvas.get_width_height()[::-1] + (3,))
        _json=jsonArray.jsonEncoder(_array)
        self.neonRedis.set(neonpath, _json)
    
    def getImgFromNeonJson(self, neonpath):
        _json=self.neonRedis.get(neonpath)
        _array=jsonArray.jsonDecoder(_json)
        Image.fromarray(_array,"RGB").save('img/pid_receive.png')

    def getImgFromNeon(self, neonpath):
        _img=self.neonRedis.get(neonpath)
        _img=cStringIO.StringIO(_img)
        Image.open(_img).save('img/pid_receive.png')
    
    def pause(self):
        #self.paused = True
        #self.pause_cond.acquire()
        self.can_run.set()
        with self.arrayZ.mutex:
            self.arrayZ.queue.clear()
            self.valueC.queue.clear()

    def resume(self):
        #self.paused = False
        #self.pause_cond.notify()
        #self.pause_cond.release()
        self.can_run.clear()

    def stop(self):
        self.paused = True
        self.pause_cond.acquire()
        self.can_run.set()

class connectNeon():

    def __init__(
        self,
        ip,
        port,
        timeout,
        ):

        self.status=False
        client=None
        self.server=None
        begin=time.time()

        while True:
            if not client:
                self.server = redis.Redis(host=ip, port=port, db=0)
                self.status=True
                print "   INFO: NEON Started"
                break
            else:
                if time.time()-begin>timeout:
                    print "   ERROR: Connect NEON timeout"  
                    sys.exit()
                    #break
                try:
                    print "   INFO: Attempt to connect NEON"  
                    client=redis.client_list()
                except:
                    pass

    def getServer(self):
        return self.server

