#!/usr/bin/python

from __future__ import print_function
import h5py
import numpy as np

hz2eV = 4.13566769692386e-15  #(source: NIST/CODATA 2018)
radpsec2meV =  hz2eV/(2*np.pi)/1e-3
radpsec2eV = radpsec2meV*1e-3
radpfs2meV = radpsec2meV*1e15

#constants used in NCrystal
constant_c  = 299792458e10 # speed of light in Aa/s
constant_dalton2kg =  1.660539040e-27  # amu to kg (source: NIST/CODATA 2018)
constant_dalton2eVc2 =  931494095.17  # amu to eV/c^2 (source: NIST/CODATA 2018)
constant_avogadro = 6.022140857e23  # mol^-1 (source: NIST/CODATA 2018)
constant_boltzmann = 8.6173303e-5   # eV/K
const_neutron_mass = 1.674927471e-24  #gram
const_neutron_mass_evc2 = 1.0454075098625835e-28  #eV/(Aa/s)^2  #fixme: why not calculated from other constants).#<EXCLUDE-IN-NC1BRANCH>
const_neutron_atomic_mass = 1.00866491588  #atomic unit
constant_planck = 4.135667662e-15  #[eV*s]
constant_hbar = constant_planck*0.5/np.pi  #[eV*s]
constant_ekin2v = np.sqrt(2.0/const_neutron_mass_evc2)  #multiply with sqrt(ekin) to get velocity in Aa/s#<EXCLUDE-IN-NC1BRANCH>

class DensityOfState():
    def __init__(self, fre, dos):
        scale = np.trapz(dos,fre)
        self.maxfre = fre[-1]
        self.dos = dos/scale
        self.fre = fre
        print('scale', scale, 'scaled density of state', np.trapz(self.dos, self.fre))

    def getDensity(self, fre):
        return np.interp(np.abs(fre), self.fre, self.dos)


class NeutronScatteringLaw():
    def __init__(self, fname, kt, coherentLengh, incoXS):
        self.fname = fname
        self.h5file = h5py.File(fname, 'r')
        self.kt = kt
        self.coherentLengh = coherentLengh
        self.incoXS = incoXS

    def __del__(self):
        print('closing file', self.fname)
        self.h5file.close()

    def keys(self):
        return self.h5file.keys()

    def getDos(self, key, energyCuteV):
        hh=self.h5file[key]
        cvv=hh['cvv'][()]
        deltaTfs=hh['deltaTfs'][()]
        fs = 1e-15
        deltaT = deltaTfs*fs
        dos, omega = self.get_spectrum(cvv )
        omega  /= deltaT

        enSpacing = (omega[2]-omega[1])*radpsec2eV
        idx=int(energyCuteV/enSpacing)
        omzeroidx=omega.size//2+1
        if omega[omzeroidx]!=0.:
            raise RuntimeError('omega array should begin with zero')
        # print('omega is', omega[omega.size//2+1])

        vdos=DensityOfState(omega[omzeroidx:omzeroidx+idx], dos[omzeroidx:omzeroidx+idx])
        return vdos

    def quantumSwq(self, key, energyCuteV):
        quantumFtq, tfs, Q = self.quantumFtq(key, energyCuteV)
        swq = self.get_spectrum(quantumFtq, axis = 0)
        omega = np.fft.fftfreq(swq.shape[0])/(tfs[1]-tfs[0])*2*np.pi
        omega = np.fft.fftshift(omega)
        return swq, omega, Q


    def quantumFtq(self, key, energyCuteV):
        hh=self.h5file[key]
        Q=hh['Q'][()]
        deltaTfs=hh['deltaTfs'][()]
        ftq=hh['ftq'][()]
        tfs=hh['t'][()]

        fs = 1e-15
        deltaT = deltaTfs*fs
        vdos = self.getDos(key, energyCuteV)

        gamma_real=np.zeros(tfs.size)
        gamma_im=np.zeros(tfs.size)
        gamma_c=np.zeros(tfs.size)
        fre_domain=np.copy(vdos.fre)

        for i in range(tfs.size):
            t=tfs[i]*fs
            qntReal = self.integrandQuantumReal(fre_domain, vdos,t)
            qntIm= self.integrandQuantumIm(fre_domain, vdos,t)
            classic= self.integrandClassic(fre_domain, vdos,t)

            gamma_real[i]=np.trapz(qntReal,fre_domain)
            gamma_im[i]=np.trapz(qntIm,fre_domain)
            gamma_c[i]=np.trapz(classic,fre_domain)

        diff=gamma_real+1j*gamma_im-gamma_c
        quantumCorrection=np.zeros(ftq.shape, dtype=np.complex128)
        for i in range(ftq.shape[1]):
            quantumCorrection[:,i] = np.exp(-0.5*Q[i]**2*diff)

        return ftq*quantumCorrection, tfs, Q


    def integrandQuantumReal (self, omega, vdos, t):
        if t==0.:
            return np.zeros(omega.size)
        result = constant_hbar/const_neutron_mass_evc2*vdos.getDensity(omega)  / (omega) *((1-np.cos(omega*t))/np.tanh(constant_hbar*0.5*omega/self.kt))
        if omega[0]==0.:
            result[0]=self.kt*t**2*vdos.getDensity(0.)/const_neutron_mass_evc2
        return result


    def integrandQuantumIm(self, omega, vdos, t):
        if t==0.:
            return np.zeros(omega.size)
        result = constant_hbar/const_neutron_mass_evc2*vdos.getDensity(omega)  / (omega) * np.sin(omega*t)
        if omega[0]==0.:
            result[0]=constant_hbar*t*vdos.getDensity(0.)/const_neutron_mass_evc2
        return result

    def integrandClassic(self, omega, vdos, t):
        if t==0.:
            return np.zeros(omega.size)
        result = 2*self.kt*vdos.getDensity(omega)*(1-np.cos(omega*t))/(const_neutron_mass_evc2*omega**2)
        if omega[0]==0.:
            result[0]=self.kt*t**2*vdos.getDensity(0.)/const_neutron_mass_evc2
        return result


    def symmetrize(self, signal, axis=0):
        s = [slice(None)]*signal.ndim
        s[axis] = slice(-1,0,-1)

        # s2 = [slice(None)]*signal.ndim
        # s2[axis] = slice(0,-1,1)
        # signal = np.concatenate((signal[tuple(s)].conjugate(),signal[tuple(s2)]),axis=axis)

        signal = np.concatenate((signal[tuple(s)].conjugate(),signal),axis=axis)

        signal = np.ascontiguousarray(signal)
        return signal

    def get_spectrum(self, signal, axis=0):

        signal = self.symmetrize(signal,axis)
        s = [np.newaxis]*signal.ndim
        s[axis] = slice(None)

        # fftSignal = 0.5*np.fft.fftshift(np.fft.fft(signal,axis=axis),axes=axis)/np.pi
        fftSignal = np.fft.fftshift(np.fft.fft(np.fft.ifftshift(signal,axes=axis),axis=axis),axes=axis)/(2*np.pi)
        fre=np.fft.fftshift(np.fft.fftfreq(fftSignal.shape[axis]))*2*np.pi

        fftSize=fftSignal.shape[axis]
        s = [slice(None)]*signal.ndim
        s[axis] = slice(fftSize//4,fftSize//4*3,1)
        fftSignal=fftSignal[tuple(s)]

        fre = fre[fftSize//4:fftSize//4*3]

        print('real',np.abs(fftSignal.real).sum(), 'imag', np.abs(fftSignal.imag).sum())
        return np.abs(fftSignal.real), fre
