﻿# -*- coding: utf-8 -*-

"""
    :author: Junrong Zhang
    :copyright: © 2020 CSNS
    :license: GPL V3, see LICENSE for more details.
"""

from flask_socketio import emit, SocketIO
from flask_rq2 import RQ
myrq=RQ(burst=True)
import mantid.simpleapi as md
from .models import hcline
from .config.development import REDIS_URL, NEXUS_PATH, IDF_PATH, CAL_PATH


#for the searching of mask and calibration files
import os
from os import walk


prepath_data = NEXUS_PATH
monsIDF = IDF_PATH
cal_file_path=CAL_PATH
mask_file_path=CAL_PATH

def find_mask_file():
    mask_files=[]
    for (root, dirs, files) in walk(mask_file_path):
        for f in files:
            if f[0:4] == "mask":
                mask_files.append(f)
    return mask_files
    
    
def find_cal_file():
    cal_files=[]
    for (root, dirs, files) in walk(cal_file_path):
        for f in files:
            if f[0:11] == "calibration":
                cal_files.append(f)
    return cal_files
               
def loadMonsWS(wsList, output):
    for ix in wsList:
        #name='RUN'+str(int(ix)).zfill(7)
        path_tmp = prepath_data+"/"+ix+"/detector.nxs"
        md.LoadCSNSNexus(Instrument = "SANS", Filename = path_tmp, OutputWorkspace = ix, Loadbank=False, Loadmonitor =True, Monitorname = "monitor2,monitor3")
        md.LoadInstrument(Workspace = ix, Filename=monsIDF, RewriteSpectraMap =True)

    for i in range(len(wsList)):
        if len(wsList)== 1:
            md.CloneWorkspace(InputWorkspace=str(wsList[i]), OutputWorkspace=output)
            md.DeleteWorkspace(Workspace=str(wsList[i]))
        elif wsList[i+1]!=wsList[-1]:
            md.Plus(LHSWorkspace=str(wsList[i]),RHSWorkspace=str(wsList[i+1]), OutputWorkspace=str(wsList[i+1]))
            md.DeleteWorkspace(Workspace=str(wsList[i]))
        else:
            md.Plus(LHSWorkspace=str(wsList[i]),RHSWorkspace=str(wsList[i+1]), OutputWorkspace=output)
            md.DeleteWorkspace(Workspace=str(wsList[i]))
            md.DeleteWorkspace(Workspace=str(wsList[i+1]))
            break

def getWSData(wsname):
    md.ConvertToPointData(InputWorkspace=wsname,OutputWorkspace=wsname)
    _x = md.mtd[wsname].readX(0)
    _y = md.mtd[wsname].readY(0)
    return _x, _y

@myrq.job
def plot_trans_empty_trans(ns, data):
    #job = myrq.get_current_job()
    socket = SocketIO(message_queue=REDIS_URL)

    #=======================
    #samRun=data['trans_sample_trans_runno']
    empRun=data['trans_empty_trans_runno']
    
    #loadMonsWS(samRun, "monTranSam") 
    loadMonsWS(empRun, "monTran") 
    #md.ConvertUnits(InputWorkspace='monTranSam', OutputWorkspace='monTranSam', Target='Wavelength', AlignBins=True)
    md.ConvertUnits(InputWorkspace='monTran', OutputWorkspace='monTran', Target='Wavelength', AlignBins=True)
    wave_rebin = data["Trans_Wavelength_Min"]+","+data["Trans_Wavelength_Step"]+","+data["Trans_Wavelength_Max"]
    #md.Rebin(InputWorkspace='monTranSam', OutputWorkspace="monTranSam", Params=wave_rebin)
    md.Rebin(InputWorkspace='monTran', OutputWorkspace="monTran", Params=wave_rebin)
    #md.CropWorkspace(InputWorkspace='monTranSam', OutputWorkspace='sam_2',StartWorkspaceIndex=0, EndWorkspaceIndex=0)
    md.CropWorkspace(InputWorkspace='monTran', OutputWorkspace='emp_2',StartWorkspaceIndex=0, EndWorkspaceIndex=0)
    #md.CropWorkspace(InputWorkspace='monTranSam', OutputWorkspace='sam_3',StartWorkspaceIndex=1, EndWorkspaceIndex=1)
    md.CropWorkspace(InputWorkspace='monTran', OutputWorkspace='emp_3',StartWorkspaceIndex=1, EndWorkspaceIndex=1)
    
    #x is same for different spectrum, which is the wavelength
    #_,sam_2= getWSData("sam_2")
    #_,sam_3= getWSData("sam_3")
    _,emp_2= getWSData("emp_2")
    x,emp_3= getWSData("emp_3")

    _data=[]
    _data.append(x)
    #_data.append(sam_2)
    _data.append(emp_2)
    name=[]
    #name.append("sample")
    name.append("empty")
    series1 = hcline(_data, name, title="Monitor2 Spectrum", xlabel="TOF", ylabel="Intensity")

    _data=[]
    _data.append(x)
    #_data.append(sam_3)
    _data.append(emp_3)
    name=[]
    #name.append("sample")
    name.append("empty")
    series2 = hcline(_data, name, title="Monitor3 Spectrum", xlabel="TOF", ylabel="Intensity")

    series={}
    series['series1']=series1
    series['series2']=series2

    socket.emit(data['event'], series, namespace=ns)


@myrq.job
def plot_trans_sample_trans(ns, data):
    #job = myrq.get_current_job()
    socket = SocketIO(message_queue=REDIS_URL)

    #=======================
    samRun=data['trans_sample_trans_runno']
    #empRun=data['trans_empty_trans_runno']
    
    loadMonsWS(samRun, "monTranSam") 
    #loadMonsWS(empRun, "monTran") 
    md.ConvertUnits(InputWorkspace='monTranSam', OutputWorkspace='monTranSam', Target='Wavelength', AlignBins=True)
    #md.ConvertUnits(InputWorkspace='monTran', OutputWorkspace='monTran', Target='Wavelength', AlignBins=True)
    wave_rebin = data["Trans_Wavelength_Min"]+","+data["Trans_Wavelength_Step"]+","+data["Trans_Wavelength_Max"]
    md.Rebin(InputWorkspace='monTranSam', OutputWorkspace="monTranSam", Params=wave_rebin)
    #md.Rebin(InputWorkspace='monTran', OutputWorkspace="monTran", Params=wave_rebin)
    md.CropWorkspace(InputWorkspace='monTranSam', OutputWorkspace='sam_2',StartWorkspaceIndex=0, EndWorkspaceIndex=0)
    #md.CropWorkspace(InputWorkspace='monTran', OutputWorkspace='emp_2',StartWorkspaceIndex=0, EndWorkspaceIndex=0)
    md.CropWorkspace(InputWorkspace='monTranSam', OutputWorkspace='sam_3',StartWorkspaceIndex=1, EndWorkspaceIndex=1)
    #md.CropWorkspace(InputWorkspace='monTran', OutputWorkspace='emp_3',StartWorkspaceIndex=1, EndWorkspaceIndex=1)

    _,sam_2= getWSData("sam_2")
    x,sam_3= getWSData("sam_3")
    #_,emp_2= getWSData("emp_2")
    #x,emp_3= getWSData("emp_3")

    _data=[]
    _data.append(x)
    _data.append(sam_2)
    #_data.append(emp_2)
    name=[]
    name.append("sample")
    #name.append("empty")
    series1 = hcline(_data, name, title="Monitor2 Spectrum", xlabel="TOF", ylabel="Intensity")

    _data=[]
    _data.append(x)
    _data.append(sam_3)
    #_data.append(emp_3)
    name=[]
    name.append("sample")
    #name.append("empty")
    series2 = hcline(_data, name, title="Monitor3 Spectrum", xlabel="TOF", ylabel="Intensity")

    series={}
    series['series1']=series1
    series['series2']=series2

    socket.emit(data['event'], series, namespace=ns)

@myrq.job
def cal_trans(ns, data):
    #job = myrq.get_current_job()
    socket = SocketIO(message_queue=REDIS_URL)

    samRun=data['trans_sample_trans_runno']
    empRun=data['trans_empty_trans_runno']

    loadMonsWS(samRun, "monTranSam") 
    loadMonsWS(empRun, "monTran") 
    md.ConvertUnits(InputWorkspace='monTranSam', OutputWorkspace='monTranSam', Target='Wavelength', AlignBins=True)
    md.ConvertUnits(InputWorkspace='monTran', OutputWorkspace='monTran', Target='Wavelength', AlignBins=True)
    wave_rebin = data["Trans_Wavelength_Min"]+","+data["Trans_Wavelength_Step"]+','+data["Trans_Wavelength_Max"]
    md.Rebin(InputWorkspace='monTranSam', OutputWorkspace="monTranSam", Params=wave_rebin)
    md.Rebin(InputWorkspace='monTran', OutputWorkspace="monTran", Params=wave_rebin)

    md.CalculateTransmission(SampleRunWorkspace='monTranSam', DirectRunWorkspace='monTran', OutputWorkspace='ans', IncidentBeamMonitor=1, TransmissionMonitor=2, RebinParams=wave_rebin, FitMethod='Log', OutputUnfittedData=True)
    
    fitMethod = data["Trans_TransFittingMethod"]
    if fitMethod == "Smooth":
        smoothNP = data["Trans_Smooth_Ns"]
        md.SmoothData(InputWorkspace='ans_unfitted', OutputWorkspace='ans_unfitted', NPoints=smoothNP)
    elif fitMethod == "Linear":
        md.CalculateTransmission(SampleRunWorkspace='monTranSam', DirectRunWorkspace='monTran', OutputWorkspace='ans', IncidentBeamMonitor=1, TransmissionMonitor=2, RebinParams=wave_rebin, FitMethod='Linear', OutputUnfittedData=True)
    elif fitMethod == "Raw":
        pass
    else:
        md.CalculateTransmission(SampleRunWorkspace='monTranSam', DirectRunWorkspace='monTran', OutputWorkspace='ans', IncidentBeamMonitor=1, TransmissionMonitor=2, RebinParams=wave_rebin, FitMethod='Log', OutputUnfittedData=True)


    _, p_cal = getWSData("ans")
    x, p_raw = getWSData("ans_unfitted")

    _data=[]
    _data.append(x)
    _data.append(p_raw)
    _data.append(p_cal)
    name=[]
    name.append("raw")
    name.append("fitted")
    series = hcline(_data, name, title="Trasnsimission", xlabel="Wavelength", ylabel="Intensity")
 
    socket.emit(data['event'], series, namespace=ns)
'''
@myrq.job
def find_mask_file(ns):
    mask_files=[]
    for (root, dirs, files) in walk(mask_file_path):
        for f in files:
            if f[-1:4] = "mask":
                mask_files.append(f)
            
        
    socket.emit(data['event'], mask_files, namespace=ns)
'''