#!/usr/bin/python

from __future__ import print_function
import h5py
import numpy as np

from constants import *
from helper import real2realFFT


from scipy import interpolate
class DensityOfState():
    def __init__(self, fre, dos):
        scale = np.trapz(dos,fre)
        self.maxfre = fre[-1]
        self.dos = dos/scale
        self.fre = fre
        self.interp = interpolate.interp1d(self.fre, self.dos)
        # print('scale', scale, 'scaled density of state', np.trapz(self.dos, self.fre))

    def getDensity(self, fre):
        return self.interp(np.abs(fre))


class NeutronScatteringLaw():
    def __init__(self, fname, kt, coherentLengh, incoXS):
        self.fname = fname
        self.h5file = h5py.File(fname, 'r')
        self.kt = kt
        self.coherentLengh = coherentLengh
        self.incoXS = incoXS

    def __del__(self):
        print('closing file', self.fname)
        self.h5file.close()

    def keys(self):
        return self.h5file.keys()

    def getDos(self, key, energyCuteV):
        hh=self.h5file[key]
        cvv=hh['cvv'][()]
        deltaTfs=hh['deltaTfs'][()]
        fs = 1e-15
        deltaT = deltaTfs*fs
        dos, omega = real2realFFT(cvv, deltaT/(2*np.pi) ) #to angular frequncy, in rad*Hz

        enSpacing = (omega[2]-omega[1])*const_radpsec2eV
        idx=int(energyCuteV/enSpacing)
        omzeroidx=omega.size//2+1
        if omega[omzeroidx]!=0.:
            raise RuntimeError('omega array should begin with zero')
        # print('omega is', omega[omega.size//2+1])

        vdos=DensityOfState(omega[omzeroidx:omzeroidx+idx], dos[omzeroidx:omzeroidx+idx])
        return vdos

    def quantumSwq(self, key, energyCuteV, quantum=1):
        quantumFtq, t, Q = self.quantumFtq(key, energyCuteV)

        swq , fre = real2realFFT(quantumFtq,(t[1]-t[0])/(2*np.pi), axis=0) #to angular frequncy, in rad*Hz
        swq /= 2*np.pi*const_hbar
        # swq=swq/np.trapz(swq.T, fre)

        if quantum:
            mididx = fre.size//2+1
            detailbal = np.exp(-fre[mididx:]*const_radpsec2meV/25.3)
            swq[mididx:, :] =(np.flip(swq[1:mididx-1, :],axis=0).T*detailbal).T

        return swq, fre, Q


    def quantumFtq(self, key, energyCuteV, quantum=1):
        hh=self.h5file[key]
        Q=hh['Q'][()]
        deltaTfs=hh['deltaTfs'][()]
        ftq=hh['ftq'][()]
        tfs=hh['t'][()]

        fs = 1e-15
        if not quantum:
            return ftq, tfs*fs, Q

        deltaT = deltaTfs*fs
        tmax = (tfs[-1]-tfs[0])*1e-15
        deltaf = 1/tmax
        deltaAngularF = deltaf*2*np.pi
        deltaEneV = deltaAngularF*const_radpsec2eV
        fmax=1./(2.*deltaT)
        enMax = fmax*2*np.pi*const_radpsec2eV
        print('group',key,'summary:')
        print('  Time period is', tmax*1e12, 'ps. Delta energy is therefor', deltaEneV*1e3, 'meV')
        print('  Time step is', deltaT*1e15, 'fs. Max energy is therefor',enMax,'eV'  )

        vdos = self.getDos(key, energyCuteV)
        gamma_real=np.zeros(tfs.size)
        gamma_im=np.zeros(tfs.size)
        gamma_c=np.zeros(tfs.size)
        fre_domain=np.copy(vdos.fre)

        for i in range(tfs.size):
            t=tfs[i]*fs
            qntReal = self.integrandQuantumReal(fre_domain, vdos,t)
            qntIm= self.integrandQuantumIm(fre_domain, vdos,t)
            classic= self.integrandClassic(fre_domain, vdos,t)

            gamma_real[i]=np.trapz(qntReal,fre_domain)
            gamma_im[i]=np.trapz(qntIm,fre_domain)
            gamma_c[i]=np.trapz(classic,fre_domain)

        diff=gamma_real+1j*gamma_im-gamma_c
        quantumCorrection=np.zeros(ftq.shape, dtype=np.complex128)
        for i in range(ftq.shape[1]):
            quantumCorrection[:,i] = np.exp(-0.5*Q[i]**2*diff)

        return ftq*quantumCorrection, tfs*fs, Q


    def integrandQuantumReal (self, omega, vdos, t):
        if t==0.:
            return np.zeros(omega.size)
        result = const_hbar/const_neutron_mass_evc2*vdos.getDensity(omega)  / (omega) *((1-np.cos(omega*t))/np.tanh(const_hbar*0.5*omega/self.kt))
        if omega[0]==0.:
            result[0]=self.kt*t**2*vdos.getDensity(0.)/const_neutron_mass_evc2
        return result


    def integrandQuantumIm(self, omega, vdos, t):
        if t==0.:
            return np.zeros(omega.size)
        result = const_hbar/const_neutron_mass_evc2*vdos.getDensity(omega)  / (omega) * np.sin(omega*t)
        if omega[0]==0.:
            result[0]=const_hbar*t*vdos.getDensity(0.)/const_neutron_mass_evc2
        return result

    def integrandClassic(self, omega, vdos, t):
        if t==0.:
            return np.zeros(omega.size)
        result = 2*self.kt*vdos.getDensity(omega)*(1-np.cos(omega*t))/(const_neutron_mass_evc2*omega**2)
        if omega[0]==0.:
            result[0]=self.kt*t**2*vdos.getDensity(0.)/const_neutron_mass_evc2
        return result
