from mantid.simpleapi import *
import xml.etree.ElementTree as ET
import sys
import time
import WriteGSA


def readXml(file):
    tree = ET.parse(file)
    root = tree.getroot()
    configure={}

    _tmp = root.findall('Bank')[0].text
    configure['Bank']=_tmp
    _tmp = root.findall('Module')[0].text
    configure['Module']=_tmp

    _tmp = root.findall('Sample_det')[0].text
    configure['sam_run_det']=_tmp
    _tmp = root.findall('Vanadium_det')[0].text
    configure['v_run_det']=_tmp
    _tmp = root.findall('Hold_det')[0].text
    configure['hold_run_det']=_tmp
    _tmp = root.findall('Empty_det')[0].text
    configure['empty_run_det']=_tmp
    _tmp = root.findall('TF')[0].text
    configure['timeFocus']=_tmp
    _tmp = root.findall('wmin')[0].text
    configure['wmin']=_tmp
    _tmp = root.findall('wmax')[0].text
    configure['wmax']=_tmp
    _tmp = root.findall('dmin')[0].text
    configure['dmin']=_tmp
    _tmp = root.findall('dmax')[0].text
    configure['dmax']=_tmp
    _tmp = root.findall('Absorption')[0].text
    configure['Absorption']=_tmp
    _tmp = root.findall('Multiple_scattering')[0].text
    configure['MS']=_tmp
    _tmp = root.findall('Volume')[0].text
    configure['Volume']=_tmp
    _tmp = root.findall('Mass')[0].text
    configure['Mass']=_tmp
    return configure


def loadWS(runList, bankName, prepath):
    path='/home/dur/work/nexusData/GPPD/RUN'
    for ix in runList:
        name=str(ix)
        try:
            filePath=path+name.zfill(7)+'/detector.nxs'
            LoadCSNSNexus(Filename=filePath, OutputWorkspace=name, Bankname=bankName,Loadmonitor=True, Monitorname='monitor2')
            LoadInstrument(Workspace=name+'_2', Filename=prepath+'/paramData/monitor2.xml', RewriteSpectraMap=True) 
            LoadInstrument(Workspace=name+'_1', Filename=prepath+'/paramData/'+bankName+'.xml', RewriteSpectraMap=True) 
        except:
            detFile=path+name.zfill(7)+'/'+bankName+'.nxs'
            monFile=path+name.zfill(7)+'/monitor2.nxs'
            LoadNexus(Filename=detFile,OutputWorkspace=name+'_1')
            LoadNexus(Filename=monFile,OutputWorkspace=name+'_2')



def mergeWS(runList,output):
    wsList=[str(s) for s in runList]
    for i in range(len(wsList)):
        if len(wsList)== 1:
            CloneWorkspace(InputWorkspace=wsList[i], OutputWorkspace=output)
            DeleteWorkspace(Workspace=wsList[i])
        elif wsList[i+1]!=wsList[-1]:
            Plus(LHSWorkspace=wsList[i],RHSWorkspace=wsList[i+1], OutputWorkspace=wsList[i+1])
            DeleteWorkspace(Workspace=wsList[i])
        else:
            Plus(LHSWorkspace=wsList[i],RHSWorkspace=wsList[i+1], OutputWorkspace=output)
            DeleteWorkspace(Workspace=wsList[i])
            DeleteWorkspace(Workspace=wsList[i+1])
            break

def normData(wsName, bankName, wave_rebin,output, monSUM, prepath):
        
    ConvertUnits(InputWorkspace=wsName+'_1', OutputWorkspace='det', Target = 'Wavelength', AlignBins = True)
    ConvertUnits(InputWorkspace=wsName+'_2', OutputWorkspace='m', Target = 'Wavelength', AlignBins = True)
    SumSpectra(InputWorkspace='m',OutputWorkspace='m')
    Rebin(InputWorkspace = 'det', OutputWorkspace = 'det', Params = wave_rebin)
    Rebin(InputWorkspace = 'm', OutputWorkspace = 'm', Params = wave_rebin)
    if monSUM:
        name=mtd['m']
        value=sum(name.readY(0))
        CreateSingleValuedWorkspace(OutputWorkspace='m', DataValue=value)
    else:
        pass
    Divide(LHSWorkspace="det", RHSWorkspace="m", OutputWorkspace=output, AllowDifferentNumberSpectra=True)
    ConvertUnits(InputWorkspace=output, OutputWorkspace='ans', Target = 'dSpacing', AlignBins = True)
    SumSpectra(InputWorkspace='ans', OutputWorkspace='ans')
    SaveAscii(InputWorkspace='ans',Filename=prepath+'/'+output+'.dat')
    DeleteWorkspace(Workspace = wsName)
    DeleteWorkspace(Workspace = 'det')
    DeleteWorkspace(Workspace = 'm')
    DeleteWorkspace(Workspace = 'ans')


    
def subHold(samWS,holdWS, output, prepath):
    Minus(LHSWorkspace = samWS,RHSWorkspace = holdWS, OutputWorkspace = output)
    ResetNegatives(InputWorkspace=output, OutputWorkspace=output, AddMinimum=False)
    ConvertUnits(InputWorkspace=output, OutputWorkspace='ans', Target = 'dSpacing', AlignBins = True)
    SumSpectra(InputWorkspace='ans', OutputWorkspace='ans')
    SaveAscii(InputWorkspace='ans',Filename=prepath+'/'+output+'.dat')
    DeleteWorkspace(Workspace = 'ans')
    
def Reduction():
    _be=time.time()

    conf=readXml('configure.xml')

    if conf['Bank'] is not None:
        if int(conf['Bank']) == 0:
            bankName='bank1'
        elif int(conf['Bank']) == 1:
            bankName='bank2'
        elif int(conf['Bank']) == 2:
            bankName='bank3'
    
    if conf['Module'] is not None:
        moduleName=conf['Module']

    if conf['hold_run_det'] is not None:
        hold_run=[int(i) for i in conf['hold_run_det'].split(',')]
        hNum=len(hold_run)
    else:
        hNum=0
    if conf['sam_run_det'] is not None:
        sam_run=[int(i) for i in conf['sam_run_det'].split(',')]
        sNum=len(sam_run)
    else:
        sNum=0
    if conf['v_run_det'] is not None:
        v_run=[int(i) for i in conf['v_run_det'].split(',')]
        vNum=len(v_run)
    else:
        vNum=0
    if conf['empty_run_det'] is not None:
        empty_run=[int(i) for i in conf['empty_run_det'].split(',')]
        eNum=len(empty_run)
    else:
        eNum=0

    if conf['Volume'] is not None:
        volume=float(conf['Volume'])
    else:
        volume=1.0
    if conf['Mass'] is not None:
        mass=float(conf['Mass'])
    else:
        mass=1.0
    if conf['Absorption'] is not None:
        if int(conf['Absorption'])!= 0:
            Absorption=True
        else:
            Absorption=False
    else:
        Absorption=1
    if conf['Volume'] is not None:
        volume=float(conf['Volume'])
    else:
        volume=1.0
    if conf['Mass'] is not None:
        mass=float(conf['Mass'])
    else:
        mass=1.0

    if conf['wmin'] is not None:
        w1=float(conf['wmin'])
    else:
        w1=0.7
    if conf['wmax'] is not None:
        w2=float(conf['wmax'])
    else:
        w1=4.5
    wave_rebin=str(w1)+',0.001,'+str(w2)

    if conf['dmin'] is not None:
        d1=float(conf['dmin'])
    else:
        d1=0.7
    if conf['dmax'] is not None:
        d2=float(conf['dmax'])
    else:
        d2=2.8
    #d_rebin=str(d1)+',0.0004,'+str(d2)
    #d_rebin=str(d1)+',0.0002,'+str(d2)
    d_rebin=str(d1)+',-0.0008,'+str(d2)

    monSUM=True
    prepath=sys.path[0]
    
    CreateSingleValuedWorkspace(OutputWorkspace='volume', DataValue=volume)

    if sNum>0:
        loadWS(sam_run, bankName, prepath)
        mergeWS(sam_run,'sam')
        normData('sam',bankName, wave_rebin,'sam_raw', monSUM,prepath)
    if hNum>0:
        loadWS(hold_run, bankName, prepath)
        mergeWS(hold_run,'hold')
        normData('hold',bankName, wave_rebin,'hold_raw', monSUM,prepath)
    if vNum>0:
        loadWS(v_run, bankName, prepath)
        mergeWS(v_run,'v')
        normData('v',bankName, wave_rebin,'v_raw', monSUM,prepath)


    if sNum>0:
        if hNum>0:
            subHold('sam_raw','hold_raw','sam_hold',prepath)
            ConvertUnits(InputWorkspace='sam_hold', OutputWorkspace='sam_d', Target = 'dSpacing', AlignBins = True)
            SumSpectra(InputWorkspace='sam_d', OutputWorkspace='sam_d')
        else:
            ConvertUnits(InputWorkspace='sam_raw', OutputWorkspace='sam_d', Target = 'dSpacing', AlignBins = True)
            SumSpectra(InputWorkspace='sam_d', OutputWorkspace='sam_d')

        if vNum>0:
            if hNum>0:
                subHold('v_raw','hold_raw','v_hold',prepath)
                if Absorption:
                    MultipleScatteringCylinderAbsorption(InputWorkspace='v_hold', OutputWorkspace='v_hold')
                else:
                    pass
                ConvertUnits(InputWorkspace='v_hold', OutputWorkspace='v_hold', Target = 'dSpacing', AlignBins = True)
                SumSpectra(InputWorkspace='v_hold', OutputWorkspace='v_d')
            else:
                if Absorption:
                    MultipleScatteringCylinderAbsorption(InputWorkspace='v_raw', OutputWorkspace='v_raw')
                else:
                    pass
                ConvertUnits(InputWorkspace='v_raw', OutputWorkspace='v_raw', Target = 'dSpacing', AlignBins = True)
                SumSpectra(InputWorkspace='v_raw', OutputWorkspace='v_d')

            StripVanadiumPeaks(InputWorkspace='v_d', OutputWorkspace='v_d')
            Rebin(InputWorkspace = 'sam_d', OutputWorkspace = 'sam_d', Params = d_rebin)
            Rebin(InputWorkspace = 'v_d', OutputWorkspace = 'v_d', Params = d_rebin)
            Divide(LHSWorkspace = 'sam_d', RHSWorkspace = 'v_d', OutputWorkspace = 'ans')
            Divide(LHSWorkspace = 'ans', RHSWorkspace = 'volume', OutputWorkspace = 'ans')
            SaveAscii(InputWorkspace='ans',Filename=prepath+'/sam_Id.dat', Separator='CSV')
            ConvertUnits(InputWorkspace='ans', OutputWorkspace='ans', Target = 'TOF', AlignBins = True)
            #Rebin(InputWorkspace = 'ans', OutputWorkspace = 'ans', Params = '5000,16,40000')
            ReplaceSpecialValues(InputWorkspace='ans', OutputWorkspace='ans', NaNValue=0, InfinityValue=0)
            CropWorkspace(InputWorkspace='ans', OutputWorkspace='ans', Xmin=0, Xmax=40000)
            SaveAscii(InputWorkspace='ans',Filename=prepath+'/sam_Itof.dat', Separator='CSV')

        else:
            Divide(LHSWorkspace = 'sam_d', RHSWorkspace = 'volume', OutputWorkspace = 'sam_d')
            SaveAscii(InputWorkspace='sam_d',Filename=prepath+'/sam_Id.dat', Separator='CSV')
            ConvertUnits(InputWorkspace='sam_d', OutputWorkspace='ans', Target = 'TOF', AlignBins = True)
            #Rebin(InputWorkspace = 'ans', OutputWorkspace = 'ans', Params = '5000,16,40000')
            ReplaceSpecialValues(InputWorkspace='ans', OutputWorkspace='ans', NaNValue=0, InfinityValue=0)
            CropWorkspace(InputWorkspace='ans', OutputWorkspace='ans', Xmin=0, Xmax=40000)
            SaveAscii(InputWorkspace='ans',Filename=prepath+'/sam_Itof.dat', Separator='CSV')
            
    else:
        pass

    print "total time: ", time.time()-_be, " seconds!"

    return bankName


#bname=Reduction()
#WriteGSA.WriteGSA('test',bname)
