#!/usr/bin/python

from __future__ import print_function
import h5py
import numpy as np
import sys
import matplotlib.pyplot as plt

hz2eV = 4.13566769692386e-15  #(source: NIST/CODATA 2018)
radpsec2meV =  hz2eV/(2*np.pi)/1e-3
radpfs2meV = radpsec2meV*1e15

#constants used in NCrystal
constant_c  = 299792458e10 # speed of light in Aa/s
constant_dalton2kg =  1.660539040e-27  # amu to kg (source: NIST/CODATA 2018)
constant_dalton2eVc2 =  931494095.17  # amu to eV/c^2 (source: NIST/CODATA 2018)
constant_avogadro = 6.022140857e23  # mol^-1 (source: NIST/CODATA 2018)
constant_boltzmann = 8.6173303e-5   # eV/K
const_neutron_mass = 1.674927471e-24  #gram
const_neutron_mass_evc2 = 1.0454075098625835e-28  #eV/(Aa/s)^2  #fixme: why not calculated from other constants).#<EXCLUDE-IN-NC1BRANCH>
const_neutron_atomic_mass = 1.00866491588  #atomic unit
constant_planck = 4.135667662e-15  #[eV*s]
constant_hbar = constant_planck*0.5/np.pi  #[eV*s]
constant_ekin2v = np.sqrt(2.0/const_neutron_mass_evc2)  #multiply with sqrt(ekin) to get velocity in Aa/s#<EXCLUDE-IN-NC1BRANCH>


def symmetrize(signal, axis=0):
    """Return a symmetrized version of an input signal

    :Parameters:
        #. signal (np.array): the input signal
        #. axis (int): the axis along which the signal should be symmetrized
    :Returns:
        #. np.array: the symmetrized signal
    """
    signal=signal.astype(np.complex128)
    s = [slice(None)]*signal.ndim
    s[axis] = slice(-1,0,-1)

    signal = np.concatenate((signal[tuple(s)].conjugate(),signal),axis=axis)

    return signal

def get_spectrum(signal, axis=0):

    signal = symmetrize(signal,axis)

    signal = np.ascontiguousarray(signal)

    s = [np.newaxis]*signal.ndim
    s[axis] = slice(None)

    # We compute the unitary inverse fourier transform with angular frequencies as described in
    # https://en.wikipedia.org/wiki/Fourier_transform#Discrete_Fourier_Transforms_and_Fast_Fourier_Transforms

    # For information about the manipulation around fftshift and ifftshift
    # http://www.mathworks.com/matlabcentral/newsreader/view_thread/285244

    # fftSignal = 0.5*np.fft.fftshift(np.fft.fft(signal,axis=axis),axes=axis)/np.pi
    fftSignal = 0.5*np.fft.fftshift(np.fft.fft(np.fft.ifftshift(signal,axes=axis),axis=axis),axes=axis)/np.pi
    print('real',np.abs(fftSignal.real).sum(), 'imag', np.abs(fftSignal.imag).sum())
    return np.abs(fftSignal.real)/signal.shape[axis]

class DensityOfState():
    def __init__(self, fre, dos):
        scale = np.trapz(dos,fre)
        self.maxfre = fre[-1]
        self.dos = dos/scale
        self.fre = fre
        print('scale', scale, 'scaled density of state', np.trapz(self.dos, self.fre))

    def getDensity(self, fre):
        return np.interp(np.abs(fre), self.fre, self.dos)

fname = 'locateh2o.h5'
if len(sys.argv)==2:
    fname = sys.argv[1]


f=h5py.File(fname,'r')

hh=f['hh']
cvv=hh['cvv'][()]
Q=hh['Q'][()]
deltaTfs=hh['deltaTfs'][()]
ftq=hh['ftq'][()]
tfs=hh['t'][()]
f.close()

if np.count_nonzero(np.isnan(ftq)): # a bug before 30 Apr. ftq is nan when time window is zero.
    ftq[0,:]=1

plt.figure()
for idx in np.arange(0,Q.size, Q.size//10):
    plt.semilogx(tfs, ftq[:,idx], label='Q='+str(Q[idx]))
plt.legend()


fs=1e-15

deltaT = deltaTfs*fs
dos = get_spectrum(cvv )
omega  =np.fft.fftfreq(dos.size)/deltaT*2*np.pi #divided by 2 because of the filp
dos = dos[dos.size//2:dos.size]
omega = omega[0:dos.size//2]

enSpacing = (omega[1]-omega[0])*radpsec2meV
idx=int(500./enSpacing) #energy cut 500meV

vdos=DensityOfState(omega[0:idx], dos[0:idx])
kt=0.0253
perkt=1./kt #in 1/eV

plt.figure()

x=np.linspace(0,vdos.fre.max(),100)
plt.plot(x*radpsec2meV, vdos.getDensity(x)/radpsec2meV, label='getDensity')
plt.xlabel('freq, meV')
plt.ylabel('dos, meV^-1')
plt.legend()




def integrandQuantumReal (omega, vdos, t):
    if t==0.:
        return np.zeros(omega.size)
    result = constant_hbar/const_neutron_mass_evc2*vdos.getDensity(omega)  / (omega) *((1-np.cos(omega*t))/np.tanh(constant_hbar*0.5*omega*perkt))
    if omega[0]==0.:
        result[0]=kt*t**2*vdos.getDensity(0.)/const_neutron_mass_evc2
    return result


def integrandQuantumIm(omega, vdos, t):
    if t==0.:
        return np.zeros(omega.size)
    result = constant_hbar/const_neutron_mass_evc2*vdos.getDensity(omega)  / (omega) * np.sin(omega*t)
    if omega[0]==0.:
        result[0]=constant_hbar*t*vdos.getDensity(0.)/const_neutron_mass_evc2
    return result

def integrandClassic(omega, vdos, t):
    if t==0.:
        return np.zeros(omega.size)
    result = 2*kt*vdos.getDensity(omega)*(1-np.cos(omega*t))/(const_neutron_mass_evc2*omega**2)
    if omega[0]==0.:
        result[0]=kt*t**2*vdos.getDensity(0.)/const_neutron_mass_evc2
    return result



gamma_real=np.zeros(tfs.size)
gamma_im=np.zeros(tfs.size)
gamma_c=np.zeros(tfs.size)

fre_domain=np.copy(vdos.fre)

for i in range(tfs.size):
    t=tfs[i]*fs
    qntReal = integrandQuantumReal(fre_domain, vdos,t)
    qntIm= integrandQuantumIm(fre_domain, vdos,t)
    classic= integrandClassic(fre_domain, vdos,t)

    gamma_real[i]=(np.trapz(qntReal,fre_domain))
    gamma_im[i]=(np.trapz(qntIm,fre_domain))
    gamma_c[i]=(np.trapz(classic,fre_domain))
#
    # print("integral real", gamma_real[i], t)
    # import sys
    # sys.exist()
#    print( "integral im", gamma_im)
#    print( "classic ", gamma_c)
#
#    plt.figure()
#    plt.plot(domain,qntReal,label='real')
#    plt.plot(domain,qntIm, label='im')
#    plt.plot(domain,classic, label='classic')
#
#    plt.legend()


diff=(gamma_real**2+gamma_im**2)**.5-gamma_c
kappa=1.6
f = np.exp(-0.5*kappa**2*diff)
plt.figure()
plt.semilogx(tfs*1e-3, f)
plt.title('quantum correction factor')

#
#
##from multiprocessing import Pool
##p = Pool(5)
##print(p.map(f, [1, 2, 3]))
#

diff=gamma_real+1j*gamma_im-gamma_c
quantumCorrection=np.zeros(ftq.shape, dtype=np.complex128)
for i in range(ftq.shape[1]):
    quantumCorrection[:,i] = np.exp(-0.5*Q[i]**2*diff)

swq = get_spectrum(ftq*quantumCorrection, axis=0)
# swq = get_spectrum(ftq, axis=0)

en=np.fft.fftfreq(swq.shape[0])/deltaT*2*np.pi*radpsec2meV
en=np.fft.fftshift(en)

plt.figure()
for i in [30,100,150]:
    order = 10
    plt.semilogy(en, swq[:,i], label='Q='+str(Q[i]))

plt.legend()


def qmeV(theta, en, enscat):
    ratio = enscat/en
    k0=np.sqrt(en/2.072124652399821)
    deg = np.pi/180.
    scale = np.sqrt(1.+ ratio - 2*np.cos(theta*deg) *np.sqrt(ratio) )
    return k0*scale

from scipy import interpolate
intSwq = interpolate.interp2d(en, Q, swq.T)

scatE =np.linspace(10., 1000., 1000)
en0 = 154.
angledeg=14
cab7Q =  qmeV(angledeg, en0, scatE )

#plt.plot(scatE,cab7Q)


xs=[]
for cq, dltE in zip(cab7Q, scatE-en0):
    xs.append( intSwq(dltE, cq))
xs=np.array(xs)

plt.figure()
plt.semilogy(scatE*1e-3, xs*1e3*80.26*2) #in eV
plt.show()
