import sys
from mantid.simpleapi import *

#path_param='/csns_workspace/CSNS/shenfeiran/absor/paramData'
#path_data='/csns_workspace/CSNS/shenfeiran/instrument_data/BL18-commissioning'
#path_save='/csns_workspace/CSNS/shenfeiran/absor/Correction'
#path_param="/home/dur/cloudBU/GPPDoffline/paramData"
#path_data='/home/dur/work/dp_GPPD/nexusData'
#path_save="/home/dur/tmp"
path_data="/csns_workspace/CSNS/durong/instrument_data/BL18-commissioning"
path_param="/csns_workspace/CSNS/durong/Desktop/absCorr/paramData"
path_save = "/csns_workspace/CSNS/durong/Desktop/absCorr"
wave_rebin='0.1,0.001,4.9'
d_rebin='0.5,-0.0012,3.8'
v_run=[5406]
hold_run=[5414]
sam_run=[5480]

#absorption correction parameter
absoXSection=369.6811
scatXSection=7.7263
numDensity=0.001095
#modify sample shape
shape='''
<cylinder id="stick">
  <centre-of-bottom-base x="0.0" y="0.0" z="0.0" />
  <axis x="0.0" y="1.0" z="0.0" />
  <radius val="0.0045" />
  <height val="0.007" />
  </cylinder>'''


bankDict={
        "bank1":[521,531,532,533,541,542,543,623,631,632,633,641,642,643],
        "bank3":[123,131,132,133,141,142,143,221,231,232,233,241,242,243],
        "bank2_new":[131,141,233,243,322,331,332,333,341,342,343,422,431,432,433,441,442,443],
        "bank2_old":[322,331,332,333,341,342,343,422,431,432,433,441,442,443]
        }
bankName="bank2"
mergeMode="old"  

def mergeWS(runList,output):
    wsList=[str(s) for s in runList]
    for i in range(len(wsList)):
        if len(wsList)== 1:
            CloneWorkspace(InputWorkspace=wsList[i], OutputWorkspace=output)
            DeleteWorkspace(Workspace=wsList[i])
        elif wsList[i+1]!=wsList[-1]:
            try:
                Plus(LHSWorkspace=wsList[i],RHSWorkspace=wsList[i+1], OutputWorkspace=wsList[i+1])
                DeleteWorkspace(Workspace=wsList[i])
            except:
                logging.warning('TOF can not match and merge failed!')
        else:
            Plus(LHSWorkspace=wsList[i],RHSWorkspace=wsList[i+1], OutputWorkspace=output)
            DeleteWorkspace(Workspace=wsList[i])
            DeleteWorkspace(Workspace=wsList[i+1])
            break

def loadDetWS(runList, moduleName, output):
    for ix in runList:
        name=str(ix)
        filepath = path_data+'/RUN'+str(ix).zfill(7)+'/detector.nxs'
        LoadCSNSNexus(Filename=filepath, OutputWorkspace=name, Bankname=moduleName,Loadmonitor=False)
        LoadInstrument(Workspace=name, Filename=path_param+'/'+moduleName+'.xml', RewriteSpectraMap=True)
    mergeWS(runList,'tmp')
    ConvertUnits(InputWorkspace='tmp',OutputWorkspace='tmp', Target = 'Wavelength', AlignBins = True)
    Rebin(InputWorkspace = 'tmp', OutputWorkspace = 'tmp', Params = wave_rebin)
	#absorption correction
    CreateSampleShape(InputWorkspace='tmp',ShapeXML=shape)
    AbsorptionCorrection(InputWorkspace='tmp',OutputWorkspace='corr_tmp',AttenuationXSection=absoXSection/1.81, ScatteringXSection=scatXSection, SampleNumberDensity=numDensity)
    mtd['corr']=mtd['tmp']/mtd['corr_tmp']
    ConvertUnits(InputWorkspace='corr', OutputWorkspace='corr', Target = 'dSpacing', AlignBins = True)
    SumSpectra(InputWorkspace='corr', OutputWorkspace=output)

	
	
def loadMonWS(runList):
    for ix in runList:
        name=str(ix)
        filepath = path_data+'/RUN'+str(ix).zfill(7)+'/detector.nxs'
        LoadCSNSNexus(Filename=filepath, OutputWorkspace=name, Loadbank=False,Loadmonitor=True, Monitorname="monitor2")
        LoadInstrument(Workspace=name, Filename=path_param+'/monitor2.xml', RewriteSpectraMap=True)
    mergeWS(runList,'tmp')
    ConvertUnits(InputWorkspace='tmp',OutputWorkspace='tmp', Target = 'Wavelength', AlignBins = True)
    Rebin(InputWorkspace = 'tmp', OutputWorkspace = 'tmp', Params = wave_rebin)
    SumSpectra(InputWorkspace='tmp',OutputWorkspace='tmp')
    value=sum(mtd['tmp'].readY(0))
    return value
    

def convertTOF():
    name=mtd['sam_d']
    tmp=name.readX(0)
    y=name.readY(0)
    e=name.readE(0)
    d=[]
    for i in range(len(tmp)-1):
        d.append(tmp[i]+(tmp[i+1]-tmp[i])/2.0)

    tof=[]
    if bankName=='bank1':
        #constA=15482.00858 #5320001
        constA=14728.384562 #5330111
        #constA=15813.5438 #5310001
        
        constB= 0.0
    elif bankName=='bank3':
        constA=5043.368242
        constB= 0.0
    elif bankName=='bank2':
        constA=11457.59304
        constB= -0.000203986

    for i in range(len(d)):
        tof.append(d[i]*constA+constB)
        
    return tof, y, e

def tof2d(tof):
    d=[]
    if bankName=='bank1':
        #a=2.84
        #b=15485.77
        #c=3.38
        a=4.46
        b=14723.34
        c=12.20

    if bankName=='bank2':
        a=2.52
        b=11458.31
        c=4.0

    if bankName=='bank3':
        a=2.74
        b=5032.97
        c=10.77

    for i in range(len(tof)):
        output=(-b+(b**2.0-4*a*(c-tof[i]))**0.5)/2.0/a
        d.append(output)   
    return d

       

def writeFormat(runNo):
    CR=unichr(13)
    LF=unichr(10)
    CRLF=CR+LF
    bname=bankName.upper()
    TOF, YINT, YESD=convertTOF()
    #write gsas
    f_gsas=path_save+'/GPPD_'+bname+'_V2.0_'+(runNo)+'.gsa'
    file_gsas=open(f_gsas, 'w')
    num=bankName[-1]
    if bankName=='bank2':
        strtmp='GPPD Diffraction Histogram for 90-degree bank, '+str(runNo)
    elif bankName=='bank1':
        strtmp='GPPD Diffraction Histogram for 180-degree bank, '+str(runNo)
    else:
        strtmp='GPPD Diffraction Histogram for 30-degree bank, '+str(runNo)
    file_gsas.write('%-80s%s' % (strtmp,CRLF))
    n=len(TOF)
    mul=1
    step = TOF[1]-TOF[0]
    istart = 1
    for i in range(n-1):
        if TOF[i+1]-TOF[i] != step:
            istart = i+1
            break
    res = (TOF[istart]-TOF[istart-1])/float(TOF[istart-1])
    iformat = 2
    if iformat ==0:
        nrec = 1
        nch = n
        strtmp='%s%s %5d %5d %s %6d %6d %2d %2d %s' % ('BANK ',num, n, nch, 'CONST', int(TOF[0]), step, 0, 0, 'FXYE')
    if iformat ==1:
        nrec = 8
        nch=int((n-1)/nrec)
        n=nch*nrec
        strtmp='%s %6d %6d %s %6d %6d %6d %.6f %s' % ('BANK 2', n, nch, 'RALF', int(TOF[0])*32, step*32, int(TOF[istart])*32, res, 'ALT')
    if iformat ==2:
        nrec = 1
        nch = n-1
        n=nch
        strtmp='%s%s %6d %6d %s %6d %6d %6d %.6f %s' % ('BANK ',num, n, nch, 'RALF', int(TOF[0])*32, step*32, int(TOF[istart])*32, res, 'FXYE')
    if iformat ==3:
        nrec = 1
        nch = n
        strtmp='%s %6d %6d %s %6d %.6f %6d %d %s' % ('BANK 2', n, nch, 'SLOG', int(TOF[0]), res, int(TOF[n-1]), 0, 'FXYE')
    file_gsas.write('%-80s%s' % (strtmp,CRLF))
    # write data
    for i in range(nch):
        for j in range(nrec):
            if i*nrec+j==0:
                tof = TOF[0]
            else:
                tof = (TOF[i*nrec+j]+TOF[i*nrec+j-1])*0.5
            if iformat ==1:
                strtmp='%7d%8d%5d' % (tof*32, int(mul*YINT[i*nrec+j]*(TOF[i*nrec+j+1]-TOF[i*nrec+j])), int(mul*YESD[i*nrec+j]*(TOF[i*nrec+j+1]-TOF[i*nrec+j])))
            elif iformat ==2:
                strtmp='%15.6f %15.10f %15.10f' % (tof, YINT[i*nrec+j]*(TOF[i*nrec+j+1]-TOF[i*nrec+j]), YESD[i*nrec+j]*(TOF[i*nrec+j+1]-TOF[i*nrec+j]))
            else:
                strtmp='%8d %15.5f %15.5f' % (tof, YINT[i*nrec+j], YESD[i*nrec+j])
            file_gsas.write('%s' % strtmp)
        file_gsas.write('%s' % CRLF)
    print 'Histogram for GSAS:       '+f_gsas

    #write ZR
    f_zr=path_save+'/GPPD_'+bname+'_V2.0_'+runNo+'.histogramIgor'
    file_zr=open(f_zr, 'w')
    strtmp='IGOR'
    file_zr.write('%s%s' % (strtmp,CRLF))
    strtmp='%s %s %s %s' % ('WAVES/O', 'tof',  'yint',  'yerr')
    file_zr.write('%s%s' % (strtmp, CRLF))
    strtmp='BEGIN'
    file_zr.write('%s%s' % (strtmp,CRLF))
    # write data
    for i in range(n):
        strtmp='%.2f %.6f %.6f' % (TOF[i], YINT[i], YESD[i])
        file_zr.write('%s' % strtmp)
        file_zr.write('%s' % CRLF)
    strtmp='END'
    file_zr.write('%s' % (strtmp))
    print 'Histogram for ZR:       '+f_zr
    
    #write FP
    f_fp=path_save+'/GPPD_'+bname+'_V2.0_'+runNo+'.dat'
    file_fp=open(f_fp, 'w')
    if bankName=='bank2':
        strtmp='GPPD Diffraction Histogram for 90-degree bank, '+str(runNo)
    elif bankName=='bank1':
        strtmp='GPPD Diffraction Histogram for 180-degree bank, '+str(runNo)
    else:
        strtmp='GPPD Diffraction Histogram for 30-degree bank, '+str(runNo)
    file_fp.write('%s%s' % (strtmp,CRLF))
    strtmp='The original intensities and sigmas have been multiplied by 10000'
    file_fp.write('%s%s' % (strtmp,CRLF))
    strtmp='%s %s %s' % ('TOF', 'INT', 'ERR')
    file_fp.write('%s%s' % (strtmp,CRLF))
    for i in range(n):
        strtmp='%.2f %.6f %.6f' % (TOF[i], YINT[i]*(10**(4)), YESD[i]*(10**(4)))
        file_fp.write('%s' % strtmp)
        file_fp.write('%s' % CRLF)
    print 'Histogram for FP:       '+f_fp
    #write d
    d=tof2d(TOF)
    f_fp=path_save+'/GPPD_'+bname+'_V2.0_'+runNo+'_d.dat'
    file_fp=open(f_fp, 'w')
    # write title
    if bankName=='bank2':
        strtmp='GPPD Diffraction Histogram for 90-degree bank, '+str(runNo)
    elif bankName=='bank1':
        strtmp='GPPD Diffraction Histogram for 180-degree bank, '+str(runNo)
    else:
        strtmp='GPPD Diffraction Histogram for 30-degree bank, '+str(runNo)
    file_fp.write('%s%s' % (strtmp,CRLF))

    strtmp='The original intensities and sigmas have been multiplied by 10000'
    file_fp.write('%s%s' % (strtmp,CRLF))
    
    strtmp='%s %s %s' % ('d', 'INT', 'ERR')
    file_fp.write('%s%s' % (strtmp,CRLF))

    # write data
    for i in range(n):
        strtmp='%.6f %.6f %.6f' % (d[i], YINT[i]*(10**(4)), YESD[i]*(10**(4)))
        file_fp.write('%s' % strtmp)
        file_fp.write('%s' % CRLF)

    print 'Histogram for FPd:       '+f_fp


def loadNormWS(runNo,bName,output):
    filename=path_data+'/RUN'+str(runNo).zfill(7)+'/detector.nxs'
    LoadCSNSNexus(Filename=filename, OutputWorkspace='sam',Bankname=bName,Loadmonitor=True, Monitorname='monitor2')
    LoadInstrument(Workspace='sam_2', Filename=path_param+'/monitor2.xml', RewriteSpectraMap=True)
    LoadInstrument(Workspace='sam_1', Filename=path_param+'/'+bankName+'.xml', RewriteSpectraMap=True)
    ConvertUnits(InputWorkspace='sam_2', OutputWorkspace='m', Target = 'Wavelength', AlignBins = True)
    ConvertUnits(InputWorkspace='sam_1', OutputWorkspace='det', Target = 'Wavelength', AlignBins = True)
    Rebin(InputWorkspace = 'm', OutputWorkspace = 'm', Params = wave_rebin)
    Rebin(InputWorkspace = 'det', OutputWorkspace = 'det', Params = wave_rebin)
    SumSpectra(InputWorkspace='m',OutputWorkspace='m')   
    name=mtd['m']
    value=sum(name.readY(0))
    CreateSingleValuedWorkspace(OutputWorkspace='m', DataValue=value)
    Divide(LHSWorkspace="det", RHSWorkspace="m", OutputWorkspace=output, AllowDifferentNumberSpectra=True)
    DeleteWorkspace(Workspace = 'sam')
    DeleteWorkspace(Workspace = 'm')
    DeleteWorkspace(Workspace = 'det')
    
def convertDspacing(wsName, output):
    ConvertUnits(InputWorkspace=wsName, OutputWorkspace=wsName, Target = 'dSpacing', AlignBins = True)
    SumSpectra(InputWorkspace=wsName, OutputWorkspace=output)
    Rebin(InputWorkspace = output, OutputWorkspace = output, Params = d_rebin)
    DeleteWorkspace(Workspace = wsName)


def getCalFile(runno, bankName, mergeMode, expMode):
        filename=""
        mode="_"+mergeMode
        l = os.listdir(path_param)
        for name in l:
            if name[0]==expMode[0]:
                int1 = name.find(str(runno))
                int2 = name.find(bankName)
                if(mode=="_old"):
                    int3=name.find(mode)
                    if int1!=-1 and int2!=-1 and int3!=-1:
                        filename = name
                        break
                else:
                    if int1!=-1 and int2!=-1:
                        filename = name
                        break
        return filename


def processV(bName,runNo):
    fname = path_param+"/"+getCalFile(runNo, bName,mergeMode,'v')
    print fname
    if os.path.exists(fname):
        LoadNexus(Filename=fname, OutputWorkspace='v')
        Rebin(InputWorkspace = 'v', OutputWorkspace = 'v_d', Params = d_rebin)
        DeleteWorkspace(Workspace='v')
    else:
        loadNormWS(runNo, bName, 'v')
        MultipleScatteringCylinderAbsorption(InputWorkspace='v', OutputWorkspace='v')
        convertDspacing('v','v_d')
        StripVanadiumPeaks(InputWorkspace='v_d', OutputWorkspace='v_d')

def processHold(bName, runNo):
    fname =path_param+"/"+getCalFile(runNo, bName,mergeMode,'hold')
    #sys.exit()
    if os.path.exists(fname):
        LoadNexus(Filename=fname, OutputWorkspace='hold')
        Rebin(InputWorkspace = 'hold', OutputWorkspace = 'hold_d', Params = d_rebin)
    else:
        loadNormWS(runNo, bName, 'hold')
        convertDspacing('hold','hold_d')
	mtd['hold_d']*=0.5
    

if len(v_run)==1:
    processV(bankName,v_run[0])

if len(hold_run)==1:
    processHold(bankName,hold_run[0])

if bankName=="bank2":
    name=bankName+"_"+mergeMode
else:
    name=bankName
moduleList=bankDict[name]
num=0
deltaD=[]
maxD=[]
minD=[]
for mnum in moduleList:
    mname="module"+str(mnum)
    loadDetWS(sam_run, mname, 'det')   
    value=float(mtd['det'].readX(0)[1]-mtd['det'].readX(0)[0])
    maxD.append(float(mtd['det'].readX(0)[-1]))
    minD.append(float(mtd['det'].readX(0)[1]))
    deltaD.append(value)
    RenameWorkspace(InputWorkspace='det',OutputWorkspace=str(num))
    num+=1

step_d=max(deltaD)
dmin=min(minD)
dmax=max(maxD)
dRebin=str(dmin)+','+str(step_d)+','+str(dmax)
for j in range(num):
    if j==0:
        Rebin(InputWorkspace=str(j),OutputWorkspace='ans', Params=dRebin)
    else:
        Rebin(InputWorkspace=str(j),OutputWorkspace='tmp', Params=dRebin)
        mtd['ans']+=mtd['tmp']
mc=loadMonWS(sam_run)
mtd['ans']=mtd['ans']/mc
Rebin(InputWorkspace = 'ans', OutputWorkspace = 'sam_d', Params = d_rebin)

if len(hold_run)==1:
    Minus(LHSWorkspace = 'sam_d',RHSWorkspace = 'hold_d', OutputWorkspace = 'sam_d', AllowDifferentNumberSpectra=True)
else:
    pass
if len(v_run)==1:
    Divide(LHSWorkspace = 'sam_d', RHSWorkspace = 'v_d', OutputWorkspace = 'sam_d')
else:
    pass
runName='RUN'+str(sam_run[0]).zfill(7)
writeFormat(runName)
    
    
