#!/usr/bin/python
from mantid.simpleapi import *
import sys
import json
import math

class sampleInfo():
    def __init__(self,filename):
        #self.filename=filename
        self.sampleName={}
        self.sampleEle=[]
        self.sampleEleNum=[]
        self.ratio=[]
        path = "/home/dur/work/PDF/"
        with open(path+"elementInfo.json", 'r') as f:
            self.eleInfo = json.load(f)
        f.close()

        with open(filename, 'r') as f:
            self.load_dict = json.load(f)
        f.close()
    
    def getSampleInfo(self):
        self.sam_h=float(self.load_dict['height'])
        self.beam_h=float(self.load_dict['beamheight'])
        self.sam_mass=float(self.load_dict['mass'])
        self.sam_r=float(self.load_dict['radius'])
        sample=str(self.load_dict["sample_name"])
        sample=sample.split('-')
        for i in range(len(sample)):
            tmp=sample[i]
            tmp=tmp.split(":")
            self.sampleEle.append(str(tmp[0]))
            self.sampleEleNum.append(float(tmp[1]))
        for i in range(len(self.sampleEle)):
            self.ratio.append(self.sampleEleNum[i]/sum(self.sampleEleNum))


    def getDensity(self):
        self.density = self.sam_mass/(math.pi*self.sam_r*self.sam_r*self.sam_h)

    def getMolarMass(self):
        self.molarMass=0.0
        for i in range(len(self.sampleEle)):
            self.molarMass+=float(self.sampleEleNum[i])*float(self.eleInfo[self.sampleEle[i]]['molarmass'])

    def getRealHeight(self):
        if self.sam_h>=self.beam_h:
            return self.beam_h
        else:
            return self.sam_h

    def getAttenXS(self):
        self.attenXS = 0
        for i in range(len(self.sampleEle)):
            ele=self.sampleEle[i]
            self.attenXS+=float(self.eleInfo[ele]["Abs_xs"])*self.ratio[i]

    def getScattXS(self):
        self.scattXS = 0
        for i in range(len(self.sampleEle)):
            ele=self.sampleEle[i]
            self.scattXS+=float(self.eleInfo[ele]["Scatt_xs"])*self.ratio[i]


    def getB_avg_sqrd(self):
        self.b_avg_sqrd = 0.0
        for i in range(len(self.sampleEle)):
            ele=self.sampleEle[i]
            self.b_avg_sqrd += float(self.eleInfo[ele]['Coh_b'])/10.0*self.ratio[i]
        self.b_avg_sqrd=math.pow(self.b_avg_sqrd,2)

    def getB_sqrd_avg(self):
        self.b_sqrd_avg = 0.0
        for i in range(len(self.sampleEle)):
            ele=self.sampleEle[i]
            self.b_sqrd_avg += math.pow(float(self.eleInfo[ele]['Coh_b'])/10.0,2)*self.ratio[i]

    def getNumDensity(self):
        self.numDensity = math.pow(10,-24)*self.density/self.molarMass*6.02*math.pow(10,23)

    def getVolume(self,r_real, h_real):
        return (math.pi*r_real*r_real*h_real)

    def getAtomNum(self):
        volume = self.getVolume(self.sam_r, self.getRealHeight())
        #self.natoms = self.numDensity*math.pow(10,24)*volume
        self.natoms = self.numDensity*volume


    def getAll(self):
        self.getSampleInfo()
        self.getDensity()
        self.getMolarMass()
        self.getAttenXS()
        self.getScattXS()
        self.getNumDensity()
        self.getAtomNum()
        self.getB_sqrd_avg()
        self.getB_avg_sqrd()




class data_handling():
    def __init__(self,runList, waveRebin, outputName):
        self.path_cur='/home/dur/work/PDF'
        self.path_data='/home/dur/work/nexusData'
        self.path_xml='/home/dur/work/paramData'
        self.waveRebin=waveRebin
        self.runNo=runList
        self.outputName=outputName

    def loadBankWS(self, moduleName,pids,output):
        for ix in self.runNo:
            name=str(ix)
            filePath=self.path_data+'/RUN'+name.zfill(7)+'/detector.nxs'
            LoadCSNSNexus(Filename=filePath, OutputWorkspace=name, Bankname=moduleName,Loadbank=True)
            LoadInstrument(Workspace=name, Filename=self.path_xml+'/module/'+moduleName+'.xml', RewriteSpectraMap=True)
        self.mergeWS(output)
        ExtractSpectra(InputWorkspace=output,OutputWorkspace=output, WorkspaceIndexList=pids)


    def loadMonWS(self,output):
        for ix in self.runNo:
            name=str(ix)
            filePath=self.path_data+'/RUN'+name.zfill(7)+'/detector.nxs'
            LoadCSNSNexus(filePath, OutputWorkspace=name,Loadbank=False,Bankname='monitor2')
            LoadInstrument(Workspace=name, Filename=self.path_xml+'/module/monitor2.xml', RewriteSpectraMap=True)
        self.mergeWS(output)

    def mergeWS(self, output):
        wsList=[str(s) for s in self.runNo]
        for i in range(len(wsList)):
            if len(wsList)== 1:
                CloneWorkspace(InputWorkspace=wsList[i], OutputWorkspace=output)
                DeleteWorkspace(Workspace=wsList[i])
            elif wsList[i+1]!=wsList[-1]:
                Plus(LHSWorkspace=wsList[i],RHSWorkspace=wsList[i+1], OutputWorkspace=wsList[i+1])
                DeleteWorkspace(Workspace=wsList[i])
            else:
                Plus(LHSWorkspace=wsList[i],RHSWorkspace=wsList[i+1], OutputWorkspace=output)
                DeleteWorkspace(Workspace=wsList[i])
                DeleteWorkspace(Workspace=wsList[i+1])
                break

    def ws2wave(self, wsname,judge):
        ConvertUnits(InputWorkspace=wsname,OutputWorkspace=wsname ,Target = 'Wavelength', AlignBins = True)
        Rebin(InputWorkspace=wsname, OutputWorkspace=wsname,Params=self.waveRebin)
        if judge:
            pass
        else:
            SumSpectra(InputWorkspace=wsname,OutputWorkspace=wsname)

    def absorCorr(self,wsName,numDensity, radius, attenXS, scattXS):
        MultipleScatteringCylinderAbsorption(InputWorkspace=wsName, OutputWorkspace=wsName,SampleNumberDensity=numDensity, CylinderSampleRadius=radius,AttenuationXSection=attenXS/1.81,ScatteringXSection=scattXS)

    def ws2d(self,wsname, output):
        ConvertUnits(InputWorkspace=wsname,OutputWorkspace=wsname ,Target = 'dSpacing', AlignBins = True)
        SumSpectra(InputWorkspace=wsname,OutputWorkspace=output)

    def getdRebin(self, moduleNums):
        deltaD=[]
        maxD=[]
        minD=[]
        for num in range(moduleNums):
            value=float(mtd[str(num)].readX(0)[1]-mtd[str(num)].readX(0)[0])
            maxD.append(float(mtd[str(num)].readX(0)[-1]))
            minD.append(float(mtd[str(num)].readX(0)[1]))
            deltaD.append(value)
        step_d=max(deltaD)
        dmin=max(minD)
        dmax=min(maxD)
        dRebin=str(dmin)+','+str(step_d)+','+str(dmax)
        return dRebin

    def sumModules(self, moduleNums, d_rebin, output):
        for j in range(moduleNums):
            if j==0:
                Rebin(InputWorkspace=str(j),OutputWorkspace='ans', Params=d_rebin)
            else:
                Rebin(InputWorkspace=str(j),OutputWorkspace='tmp', Params=d_rebin)
                mtd['ans']+=mtd['tmp']
        RenameWorkspace(InputWorkspace='ans',OutputWorkspace=output)

    def getPidRange(self, moduleNum, xstep, num):
        start = num*xstep
        no=moduleNum/100
        pidstr=''
        if no%2==0:
            for j in range(48):
                end=111*j+110-xstep*num
                sta=end-xstep
                pidstr+=str(sta)+'-'+str(end)+','
        else:
            for j in range(48):
                sta=111*j+start
                end=sta+17
                pidstr+=str(sta)+'-'+str(end)+','

        pidstr=pidstr[:-1]
        return pidstr

    def processV(self, wsname, npoints):
        StripVanadiumPeaks(InputWorkspace=wsname,OutputWorkspace=wsname)
        SmoothData(InputWorkspace=wsname,OutputWorkspace=wsname,NPoints=npoints)
        SetUncertainties(InputWorkspace=wsname,OutputWorkspace=wsname,SetError='zero')



#exp infomation
v_run=[5406,5407]# V-rod
hold_run=[5414,5432]#hold
bg_run=[5415]#background
sam_run=[5474]

path_cur = "/home/dur/work/PDF/"

#vanadium info
vFile=path_cur+'vanadiumInfo.json'
vInfo=sampleInfo(vFile)
vInfo.getAll()
print ("==================vanadium================")
print ('scatt_XS, atten_XS:',vInfo.scattXS,vInfo.attenXS)
print ('molar mass: ', vInfo.molarMass)
print ('density: ',vInfo.density)
v_volume=vInfo.getVolume(vInfo.sam_r,vInfo.getRealHeight())
print ('volume: ',v_volume)
print ('num density: ',vInfo.numDensity)
print ('atom nums: ',vInfo.numDensity*v_volume)
#v_numDensity=0.0721
prefactor = vInfo.numDensity*v_volume/4.0/math.pi*vInfo.scattXS
print ('v_factor: ',prefactor)

#sample info
samFile=path_cur+'sampleInfo.json'
samInfo=sampleInfo(samFile)
samInfo.getAll()
print ("==================sample================")
print ('scatt_XS, atten_XS:',samInfo.scattXS,samInfo.attenXS)
print ('molar mass: ', samInfo.molarMass)
print ('density: ',samInfo.density)
sam_volume=samInfo.getVolume(samInfo.sam_r,samInfo.getRealHeight())
print ('volume: ',sam_volume)
print ('num density: ',samInfo.numDensity)
print ('atom nums: ',samInfo.numDensity*sam_volume)
laue_term=samInfo.b_sqrd_avg/samInfo.b_avg_sqrd
print ("laue_term: ", laue_term)
print ("b_avg_sqrd: ",samInfo.b_avg_sqrd)
print ("b_sqrd_avg: ",samInfo.b_sqrd_avg)
#sys.exit()
waveRebin="0.25,0.001,4.7"
with open(path_cur+"bankInfo.json", 'r') as f1:
    load_dict = json.load(f1)
#print('Please enter the num of bank')
#count=''
count = str(sys.argv[1])
xstep=int(load_dict[count]['step_x'])#get angle(step)
dRebin=str(load_dict[count]["d_rebin"])

angleNos=int(111/xstep)

def loadNxs(expType, count, num):
    fname="/home/dur/work/PDF/tmpNxs/"+expType+"_"+count+"_"+str(num)+".nxs"
    LoadNexus(Filename=fname, OutputWorkspace=expType)

for num in range(angleNos):
    loadNxs('v',count, num)
    loadNxs('bg',count, num)
    loadNxs('hold',count, num)
    loadNxs('sam',count, num)
    #Rebin(InputWorkspace='v',OutputWorkspace='v',Params=dRebin)
    #Rebin(InputWorkspace='bg',OutputWorkspace='bg',Params=dRebin)
    #Rebin(InputWorkspace='hold',OutputWorkspace='hold',Params=dRebin)
    #Rebin(InputWorkspace='sam',OutputWorkspace='sam',Params=dRebin)
    mtd['v']-=mtd['bg']
    mtd['sam']-=(mtd['hold']*0.5)
    mtd['v']/=prefactor
    mtd['sam']=mtd['sam']/mtd['v']
    mtd['sam']/=samInfo.natoms
    ConvertUnits(InputWorkspace='sam',OutputWorkspace='sam_q', Target='MomentumTransfer', AlignBins=False)
    mtd['ans'] = (1. / samInfo.b_avg_sqrd) * mtd['sam_q'] - laue_term + 1
    SaveNexus(Filename='sq_'+count+"_"+str(num)+".nxs",InputWorkspace='ans')
