﻿# -*- coding: utf-8 -*-

"""
    :author: Junrong Zhang
    :copyright: © 2020 CSNS
    :license: GPL V3, see LICENSE for more details.
"""

from flask_socketio import emit, SocketIO
from flask_rq2 import RQ
myrq=RQ(burst=True)
import mantid.simpleapi as md
from .models import hcline
from .config.development import REDIS_URL, NEXUS_PATH, IDF_PATH, CAL_PATH


#for the searching of mask and calibration files
import os
from os import walk


prepath_data = NEXUS_PATH
monsIDF = IDF_PATH
cal_file_path=CAL_PATH
mask_file_path=CAL_PATH

def find_mask_file():
    mask_files=[]
    for (root, dirs, files) in walk(mask_file_path):
        for f in files:
            if f[0:4] == "mask":
                mask_files.append(f)
    return mask_files
    
    
def find_cal_file():
    cal_files=[]
    for (root, dirs, files) in walk(cal_file_path):
        for f in files:
            if f[0:11] == "calibration":
                cal_files.append(f)
    return cal_files
               
def loadMonsWS(wsList, output):
    for ix in wsList:
        #name='RUN'+str(int(ix)).zfill(7)
        path_tmp = prepath_data+"/"+ix+"/detector.nxs"
        md.LoadCSNSNexus(Instrument = "SANS", Filename = path_tmp, OutputWorkspace = ix, Loadbank=False, Loadmonitor =True, Monitorname = "monitor2,monitor3")
        md.LoadInstrument(Workspace = ix, Filename=monsIDF, RewriteSpectraMap =True)

    for i in range(len(wsList)):
        if len(wsList)== 1:
            md.CloneWorkspace(InputWorkspace=str(wsList[i]), OutputWorkspace=output)
            md.DeleteWorkspace(Workspace=str(wsList[i]))
        elif wsList[i+1]!=wsList[-1]:
            md.Plus(LHSWorkspace=str(wsList[i]),RHSWorkspace=str(wsList[i+1]), OutputWorkspace=str(wsList[i+1]))
            md.DeleteWorkspace(Workspace=str(wsList[i]))
        else:
            md.Plus(LHSWorkspace=str(wsList[i]),RHSWorkspace=str(wsList[i+1]), OutputWorkspace=output)
            md.DeleteWorkspace(Workspace=str(wsList[i]))
            md.DeleteWorkspace(Workspace=str(wsList[i+1]))
            break

def getWSData(wsname):
    md.ConvertToPointData(InputWorkspace=wsname,OutputWorkspace=wsname)
    _x = md.mtd[wsname].readX(0)
    _y = md.mtd[wsname].readY(0)
    return _x, _y

@myrq.job
def plot_trans_empty_trans(ns, data):
    #job = myrq.get_current_job()
    socket = SocketIO(message_queue=REDIS_URL)

    #=======================
    #samRun=data['trans_sample_trans_runno']
    empRun=data['trans_empty_trans_runno']
    
    #loadMonsWS(samRun, "monTranSam") 
    loadMonsWS(empRun, "monTran") 
    #md.ConvertUnits(InputWorkspace='monTranSam', OutputWorkspace='monTranSam', Target='Wavelength', AlignBins=True)
    md.ConvertUnits(InputWorkspace='monTran', OutputWorkspace='monTran', Target='Wavelength', AlignBins=True)
    wave_rebin = data["Trans_Wavelength_Min"]+","+data["Trans_Wavelength_Step"]+","+data["Trans_Wavelength_Max"]
    #md.Rebin(InputWorkspace='monTranSam', OutputWorkspace="monTranSam", Params=wave_rebin)
    md.Rebin(InputWorkspace='monTran', OutputWorkspace="monTran", Params=wave_rebin)
    #md.CropWorkspace(InputWorkspace='monTranSam', OutputWorkspace='sam_2',StartWorkspaceIndex=0, EndWorkspaceIndex=0)
    md.CropWorkspace(InputWorkspace='monTran', OutputWorkspace='emp_2',StartWorkspaceIndex=0, EndWorkspaceIndex=0)
    #md.CropWorkspace(InputWorkspace='monTranSam', OutputWorkspace='sam_3',StartWorkspaceIndex=1, EndWorkspaceIndex=1)
    md.CropWorkspace(InputWorkspace='monTran', OutputWorkspace='emp_3',StartWorkspaceIndex=1, EndWorkspaceIndex=1)
    
    #x is same for different spectrum, which is the wavelength
    #_,sam_2= getWSData("sam_2")
    #_,sam_3= getWSData("sam_3")
    _,emp_2= getWSData("emp_2")
    x,emp_3= getWSData("emp_3")

    _data=[]
    _data.append(x)
    #_data.append(sam_2)
    _data.append(emp_2)
    name=[]
    #name.append("sample")
    name.append("empty")
    series1 = hcline(_data, name, title="Monitor2 Spectrum", xlabel="TOF", ylabel="Intensity")

    _data=[]
    _data.append(x)
    #_data.append(sam_3)
    _data.append(emp_3)
    name=[]
    #name.append("sample")
    name.append("empty")
    series2 = hcline(_data, name, title="Monitor3 Spectrum", xlabel="TOF", ylabel="Intensity")

    series={}
    series['series1']=series1
    series['series2']=series2

    socket.emit(data['event'], series, namespace=ns)


@myrq.job
def plot_trans_sample_trans(ns, data):
    #job = myrq.get_current_job()
    socket = SocketIO(message_queue=REDIS_URL)

    #=======================
    samRun=data['trans_sample_trans_runno']
    #empRun=data['trans_empty_trans_runno']
    
    loadMonsWS(samRun, "monTranSam") 
    #loadMonsWS(empRun, "monTran") 
    md.ConvertUnits(InputWorkspace='monTranSam', OutputWorkspace='monTranSam', Target='Wavelength', AlignBins=True)
    #md.ConvertUnits(InputWorkspace='monTran', OutputWorkspace='monTran', Target='Wavelength', AlignBins=True)
    wave_rebin = data["Trans_Wavelength_Min"]+","+data["Trans_Wavelength_Step"]+","+data["Trans_Wavelength_Max"]
    md.Rebin(InputWorkspace='monTranSam', OutputWorkspace="monTranSam", Params=wave_rebin)
    #md.Rebin(InputWorkspace='monTran', OutputWorkspace="monTran", Params=wave_rebin)
    md.CropWorkspace(InputWorkspace='monTranSam', OutputWorkspace='sam_2',StartWorkspaceIndex=0, EndWorkspaceIndex=0)
    #md.CropWorkspace(InputWorkspace='monTran', OutputWorkspace='emp_2',StartWorkspaceIndex=0, EndWorkspaceIndex=0)
    md.CropWorkspace(InputWorkspace='monTranSam', OutputWorkspace='sam_3',StartWorkspaceIndex=1, EndWorkspaceIndex=1)
    #md.CropWorkspace(InputWorkspace='monTran', OutputWorkspace='emp_3',StartWorkspaceIndex=1, EndWorkspaceIndex=1)

    _,sam_2= getWSData("sam_2")
    x,sam_3= getWSData("sam_3")
    #_,emp_2= getWSData("emp_2")
    #x,emp_3= getWSData("emp_3")

    _data=[]
    _data.append(x)
    _data.append(sam_2)
    #_data.append(emp_2)
    name=[]
    name.append("sample")
    #name.append("empty")
    series1 = hcline(_data, name, title="Monitor2 Spectrum", xlabel="TOF", ylabel="Intensity")

    _data=[]
    _data.append(x)
    _data.append(sam_3)
    #_data.append(emp_3)
    name=[]
    name.append("sample")
    #name.append("empty")
    series2 = hcline(_data, name, title="Monitor3 Spectrum", xlabel="TOF", ylabel="Intensity")

    series={}
    series['series1']=series1
    series['series2']=series2

    socket.emit(data['event'], series, namespace=ns)


@myrq.job
def plot_trans_emptycell_trans(ns, data):
    #job = myrq.get_current_job()
    socket = SocketIO(message_queue=REDIS_URL)

    #=======================
    samRun=data['trans_emptycell_trans_runno']
    #empRun=data['trans_empty_trans_runno']
    
    loadMonsWS(samRun, "monTranSam") 
    #loadMonsWS(empRun, "monTran") 
    md.ConvertUnits(InputWorkspace='monTranSam', OutputWorkspace='monTranSam', Target='Wavelength', AlignBins=True)
    #md.ConvertUnits(InputWorkspace='monTran', OutputWorkspace='monTran', Target='Wavelength', AlignBins=True)
    wave_rebin = data["Trans_Wavelength_Min"]+","+data["Trans_Wavelength_Step"]+","+data["Trans_Wavelength_Max"]
    md.Rebin(InputWorkspace='monTranSam', OutputWorkspace="monTranSam", Params=wave_rebin)
    #md.Rebin(InputWorkspace='monTran', OutputWorkspace="monTran", Params=wave_rebin)
    md.CropWorkspace(InputWorkspace='monTranSam', OutputWorkspace='sam_2',StartWorkspaceIndex=0, EndWorkspaceIndex=0)
    #md.CropWorkspace(InputWorkspace='monTran', OutputWorkspace='emp_2',StartWorkspaceIndex=0, EndWorkspaceIndex=0)
    md.CropWorkspace(InputWorkspace='monTranSam', OutputWorkspace='sam_3',StartWorkspaceIndex=1, EndWorkspaceIndex=1)
    #md.CropWorkspace(InputWorkspace='monTran', OutputWorkspace='emp_3',StartWorkspaceIndex=1, EndWorkspaceIndex=1)

    _,sam_2= getWSData("sam_2")
    x,sam_3= getWSData("sam_3")
    #_,emp_2= getWSData("emp_2")
    #x,emp_3= getWSData("emp_3")

    _data=[]
    _data.append(x)
    _data.append(sam_2)
    #_data.append(emp_2)
    name=[]
    name.append("emptycell")
    #name.append("empty")
    series1 = hcline(_data, name, title="Monitor2 Spectrum", xlabel="TOF", ylabel="Intensity")

    _data=[]
    _data.append(x)
    _data.append(sam_3)
    #_data.append(emp_3)
    name=[]
    name.append("emptycell")
    #name.append("empty")
    series2 = hcline(_data, name, title="Monitor3 Spectrum", xlabel="TOF", ylabel="Intensity")

    series={}
    series['series1']=series1
    series['series2']=series2

    socket.emit(data['event'], series, namespace=ns)


@myrq.job
def plot_trans_solvent_trans(ns, data):
    #job = myrq.get_current_job()
    socket = SocketIO(message_queue=REDIS_URL)

    #=======================
    samRun=data['trans_solvent_trans_runno']
    #empRun=data['trans_empty_trans_runno']
    
    loadMonsWS(samRun, "monTranSam") 
    #loadMonsWS(empRun, "monTran") 
    md.ConvertUnits(InputWorkspace='monTranSam', OutputWorkspace='monTranSam', Target='Wavelength', AlignBins=True)
    #md.ConvertUnits(InputWorkspace='monTran', OutputWorkspace='monTran', Target='Wavelength', AlignBins=True)
    wave_rebin = data["Trans_Wavelength_Min"]+","+data["Trans_Wavelength_Step"]+","+data["Trans_Wavelength_Max"]
    md.Rebin(InputWorkspace='monTranSam', OutputWorkspace="monTranSam", Params=wave_rebin)
    #md.Rebin(InputWorkspace='monTran', OutputWorkspace="monTran", Params=wave_rebin)
    md.CropWorkspace(InputWorkspace='monTranSam', OutputWorkspace='sam_2',StartWorkspaceIndex=0, EndWorkspaceIndex=0)
    #md.CropWorkspace(InputWorkspace='monTran', OutputWorkspace='emp_2',StartWorkspaceIndex=0, EndWorkspaceIndex=0)
    md.CropWorkspace(InputWorkspace='monTranSam', OutputWorkspace='sam_3',StartWorkspaceIndex=1, EndWorkspaceIndex=1)
    #md.CropWorkspace(InputWorkspace='monTran', OutputWorkspace='emp_3',StartWorkspaceIndex=1, EndWorkspaceIndex=1)

    _,sam_2= getWSData("sam_2")
    x,sam_3= getWSData("sam_3")
    #_,emp_2= getWSData("emp_2")
    #x,emp_3= getWSData("emp_3")

    _data=[]
    _data.append(x)
    _data.append(sam_2)
    #_data.append(emp_2)
    name=[]
    name.append("solvent")
    #name.append("empty")
    series1 = hcline(_data, name, title="Monitor2 Spectrum", xlabel="TOF", ylabel="Intensity")

    _data=[]
    _data.append(x)
    _data.append(sam_3)
    #_data.append(emp_3)
    name=[]
    name.append("solvent")
    #name.append("empty")
    series2 = hcline(_data, name, title="Monitor3 Spectrum", xlabel="TOF", ylabel="Intensity")

    series={}
    series['series1']=series1
    series['series2']=series2

    socket.emit(data['event'], series, namespace=ns)


@myrq.job
def cal_trans(ns, data):
    #job = myrq.get_current_job()
    socket = SocketIO(message_queue=REDIS_URL)
    print(data)
    
    series={}
    
    samRun=data['trans_sample_trans_runno']
    empRun=data['trans_empty_trans_runno']
    #bool emptycellChecked
    emptycellChecked=data['is_emptycell']
    #bool solventChecked
    solventChecked=data['is_solvent']
    #print(data['is_emptycell'])
    
    loadMonsWS(samRun, "monTranSam") 
    loadMonsWS(empRun, "monTran") 
    md.ConvertUnits(InputWorkspace='monTranSam', OutputWorkspace='monTranSam', Target='Wavelength', AlignBins=True)
    md.ConvertUnits(InputWorkspace='monTran', OutputWorkspace='monTran', Target='Wavelength', AlignBins=True)
    wave_rebin = data["Trans_Wavelength_Min"]+","+data["Trans_Wavelength_Step"]+','+data["Trans_Wavelength_Max"]
    md.Rebin(InputWorkspace='monTranSam', OutputWorkspace="monTranSam", Params=wave_rebin)
    md.Rebin(InputWorkspace='monTran', OutputWorkspace="monTran", Params=wave_rebin)
    print(wave_rebin)
    md.CalculateTransmission(SampleRunWorkspace='monTranSam', DirectRunWorkspace='monTran', OutputWorkspace='ans_sam', IncidentBeamMonitor=1, TransmissionMonitor=2, RebinParams=wave_rebin, FitMethod='Log', OutputUnfittedData=True)
    
    fitMethod = data["Trans_TransFittingMethod"]
    if fitMethod == "Smooth":
        smoothNP = data["Trans_Smooth_Ns"]
        md.SmoothData(InputWorkspace='ans_sam_unfitted', OutputWorkspace='ans_sam_unfitted', NPoints=smoothNP)
    elif fitMethod == "Linear":
        md.CalculateTransmission(SampleRunWorkspace='monTranSam', DirectRunWorkspace='monTran', OutputWorkspace='ans_sam', IncidentBeamMonitor=1, TransmissionMonitor=2, RebinParams=wave_rebin, FitMethod='Linear', OutputUnfittedData=True)
    elif fitMethod == "Raw":
        pass
    else:
        md.CalculateTransmission(SampleRunWorkspace='monTranSam', DirectRunWorkspace='monTran', OutputWorkspace='ans_sam', IncidentBeamMonitor=1, TransmissionMonitor=2, RebinParams=wave_rebin, FitMethod='Log', OutputUnfittedData=True)

    
    _, p_cal = getWSData("ans_sam")
    x, p_raw = getWSData("ans_sam_unfitted")
    
    _data=[]
    _data.append(x)
    _data.append(p_cal)
    _data.append(p_raw)
    print(x)
    print(p_cal)
    print(p_raw)
    print(_data)
    name=[]
    name.append("sample_fitted")
    name.append("sample_raw")
    
    
    
    '''
    _dataSampleRaw=[]
    _dataSampleRaw.append(x)
    _dataSampleRaw.append(p_raw)
    name=[]
    name.append("sample_raw")
    
    series1 = hcline(_dataSampleRaw, name, title="Transmission", xlabel="lambda(AA)", ylabel="Trans")
    series['series1']=series1
    
    _dataSampleFitted=[]
    _dataSampleFitted.append(x)
    _dataSampleFitted.append(p_cal)
    name=[]
    name.append("sample_fitted")
    
    series2 = hcline(_dataSampleRaw, name, title="Transmission", xlabel="lambda(AA)", ylabel="Trans")
    series['series2']=series2
    '''
    
    if (emptycellChecked):
        emptycellRun=data['trans_emptycell_trans_runno']
        loadMonsWS(emptycellRun, "monTranEmptycell")
        md.ConvertUnits(InputWorkspace='monTranEmptycell', OutputWorkspace='monTranEmptycell', Target='Wavelength', AlignBins=True)
        md.Rebin(InputWorkspace='monTranEmptycell', OutputWorkspace="monTranEmptycell", Params=wave_rebin)
        print(wave_rebin)
        md.CalculateTransmission(SampleRunWorkspace='monTranEmptycell', DirectRunWorkspace='monTran', OutputWorkspace='ans_emptycell', IncidentBeamMonitor=1, TransmissionMonitor=2, RebinParams=wave_rebin, FitMethod='Log', OutputUnfittedData=True)
        fitMethod = data["Trans_TransFittingMethod"]
        if fitMethod == "Smooth":
            smoothNP = data["Trans_Smooth_Ns"]
            md.SmoothData(InputWorkspace='ans_emptycell_unfitted', OutputWorkspace='ans_emptycell_unfitted', NPoints=smoothNP)
        elif fitMethod == "Linear":
            md.CalculateTransmission(SampleRunWorkspace='monTranEmptycell', DirectRunWorkspace='monTran', OutputWorkspace='ans_emptycell', IncidentBeamMonitor=1, TransmissionMonitor=2, RebinParams=wave_rebin, FitMethod='Linear', OutputUnfittedData=True)
        elif fitMethod == "Raw":
            pass
        else:
            md.CalculateTransmission(SampleRunWorkspace='monTranEmptycell', DirectRunWorkspace='monTran', OutputWorkspace='ans_emptycell', IncidentBeamMonitor=1, TransmissionMonitor=2, RebinParams=wave_rebin, FitMethod='Log', OutputUnfittedData=True)
        
        x, p1_cal = getWSData("ans_emptycell")
        _, p1_raw = getWSData("ans_emptycell_unfitted")
        
        #_data=[]
        #_data.append(x)
        _data.append(p1_cal)
        _data.append(p1_raw)
        
        #name=[]
        name.append("emptycell_fitted")
        name.append("emptycell_raw")
        
        
        
        '''
        _dataEmptycellRaw=[]
        _dataEmptycellRaw.append(x)
        _dataEmptycellRaw.append(p_raw)
        name=[]
        name.append("emptycell_raw")
    
        series3 = hcline(_dataEmptycellRaw, name, title="Transmission", xlabel="lambda(AA)", ylabel="Trans")
        series['series3']=series3
        
        
        _dataEmptycellFitted=[]
        _dataEmptycellFitted.append(x)
        _dataEmptycellFitted.append(p_cal)
        name=[]
        name.append("emptycell_fitted")
        
        series4 = hcline(_dataEmptycellFitted, name, title="Transmission", xlabel="lambda(AA)", ylabel="Trans")
        series['series4']=series4
        '''
        
    if (solventChecked):
        emptycellRun=data['trans_solvent_trans_runno']
        loadMonsWS(emptycellRun, "monTranSolvent")
        md.ConvertUnits(InputWorkspace='monTranSolvent', OutputWorkspace='monTranSolvent', Target='Wavelength', AlignBins=True)
        md.Rebin(InputWorkspace='monTranSolvent', OutputWorkspace="monTranSolvent", Params=wave_rebin)
        md.CalculateTransmission(SampleRunWorkspace='monTranSolvent', DirectRunWorkspace='monTran', OutputWorkspace='ans_solvent', IncidentBeamMonitor=1, TransmissionMonitor=2, RebinParams=wave_rebin, FitMethod='Log', OutputUnfittedData=True)
        fitMethod = data["Trans_TransFittingMethod"]
        if fitMethod == "Smooth":
            smoothNP = data["Trans_Smooth_Ns"]
            md.SmoothData(InputWorkspace='ans_solvent_unfitted', OutputWorkspace='ans_solvent_unfitted', NPoints=smoothNP)
        elif fitMethod == "Linear":
            md.CalculateTransmission(SampleRunWorkspace='monTranSolvent', DirectRunWorkspace='monTran', OutputWorkspace='ans_solvent', IncidentBeamMonitor=1, TransmissionMonitor=2, RebinParams=wave_rebin, FitMethod='Linear', OutputUnfittedData=True)
        elif fitMethod == "Raw":
            pass
        else:
            md.CalculateTransmission(SampleRunWorkspace='monTranSolvent', DirectRunWorkspace='monTran', OutputWorkspace='ans_solvent', IncidentBeamMonitor=1, TransmissionMonitor=2, RebinParams=wave_rebin, FitMethod='Log', OutputUnfittedData=True)
        
        _, p_cal = getWSData("ans_solvent")
        x, p_raw = getWSData("ans_solvent_unfitted")
        
        _data=_data
        _data.append(p_cal)
        _data.append(p_raw)
    
        name.append("solvent_fitted")
        name.append("solvent_raw")
        
        
        
        '''
        _dataSolventRaw=[]
        _dataSolventRaw.append(x)
        _dataSolventRaw.append(p_raw)
        name=[]
        name.append("solvent_raw")
    
        series5 = hcline(_dataSolventRaw, name, title="Transmission", xlabel="lambda(AA)", ylabel="Trans")
        series['series5']=series5
        
        
        _dataSolventFitted=[]
        _dataSolventFitted.append(x)
        _dataSolventFitted.append(p_cal)
        name=[]
        name.append("solvent_fitted")
    
        series6 = hcline(_dataSolventFitted, name, title="Transmission", xlabel="lambda(AA)", ylabel="Trans")
        series['series6']=series6
        '''
        
    
    
    series = hcline(_data, name, title="Transmission", xlabel="Wavelength (AA)", ylabel="Intensity")
 
    socket.emit(data['event'], series, namespace=ns)
'''
@myrq.job
def find_mask_file(ns):
    mask_files=[]
    for (root, dirs, files) in walk(mask_file_path):
        for f in files:
            if f[-1:4] = "mask":
                mask_files.append(f)
            
        
    socket.emit(data['event'], mask_files, namespace=ns)
'''