#!/usr/bin/python

from __future__ import print_function
import vacf
import time
import sys
import h5py
import numpy as np

# limitations:
# 1. step time must be a constant
class Trajectory:
    def __init__(self, fileName=None):
        if fileName:
            f=h5py.File(fileName,'r')
            #species axes: [atomID,timeStep]
            self.species=f['particles']['all']['species']['value'][0,:][()]
            self.atomType, counts = np.unique(self.species, return_counts=True)
            # // if(!qSpacing)
            # //    qSpacing = 2*M_PI/ (*std::max_element(box, box+boxSize)) *2;

            t = f['particles']['all']['species']['time']
            self.deltaTfs = t[1]-t[0]
            #pos axes: [atomID,timeStep, xyz]
            self.trj = self.swapaxes(f['particles']['all']['position']['value'][()])
            self.vel = self.swapaxes(f['particles']['all']['velocity']['value'][()])
            #box axes: [timeStep, xyz]
            self.box=f['particles']['all']['box']['edges']['value'][()]
            self.minBoxL = self.box.min()
            self.maxBoxL = self.box.max()
            self.qResolution = 0.5/self.minBoxL
            #upper limit of q depence on the step length of the atom

            f.close()
            self.nFrame= self.trj.shape[1]
            self.nAtom = self.trj.shape[0]

            self.nMolecule = self.gcd(counts[0],counts[1]) #fixme: this is for two element material ONLY
            self.nAtompMolecule = self.nAtom/self.nMolecule

            print('Input data contain the trajectories of', self.nAtom, 'atoms.')
            for at, cnt in zip(self.atomType, counts):
                print('Number of type', at, 'atom:', cnt)
            print('Number of molecule:',self.nMolecule)
            print('Atom per molecule:',self.nAtompMolecule, self.species[0:self.nAtompMolecule])
            print('Number of frames:', self.nFrame, '.')
            print('Time step spacing:', self.deltaTfs, 'fs.')

            print('Time period:', (self.nFrame-1)*self.deltaTfs, 'fs.')
            print('Box size lengths are distributed between ', self.minBoxL, 'Aa and ', self.maxBoxL, 'Aa' )
            print('Q resolution is better then', self.qResolution, 'Aa^-1')

    def getAtomType(self):
        return self.atomType

    def gcd(self, x, y):
        while(y):
            x, y = y, x % y
        return x

    def getDeltaT(self):
        return self.deltaTfs

    def swapaxes(self, data):
        data=np.swapaxes(data,0,1)
        return np.ascontiguousarray(data)

    # def axis0fft(self, a, spacing, units=1.):
    #     s = axis0fft(a)
    #     om=np.fft.fftfreq(a.shape[1]*2-2)/spacing*units
    #     return om, s

# sampleRate = 1./(timeStep[1]-timeStep[0]) #per fs
# #1 rad/sec = 6.5821e-16 eV
# radpsec2meV =  6.5821e-13 * 2*np.pi
# radpfs2meV = radpsec2meV*1e15
# fftsize = timeStep.size-2
# om=np.fft.fftfreq(fftsize)*sampleRate*radpfs2meV


    def axis0fft2d(self, a, spacing=None, units=1. ):
        #make an even function to make real result
        if len(a.shape) > 2:
            raise RuntimeError('input of axis0fft must be lower than 2d')
        b=np.flip(a,axis=0)
        c=np.concatenate((a,b[1:]),axis=0)[:-1]
        result=np.fft.fft(c,axis=0)
        result=result.real.T/result.shape[1] #normalisation
        if(spacing):
            om=np.fft.fftfreq(result.shape[1])/spacing*units
            print('om shape', om.shape, ' s shape', result.shape)
            return om[:om.size//2], result[:,:result.shape[1]//2]
        return result

    def fftReal(self, a, spacing, units=1.):
        if len(a.shape) > 1:
            raise RuntimeError('input of fftReal must be 1d array')
        result=np.fft.fft(a)
        om=np.fft.fftfreq(a.size)/spacing*units
        return om[:a.size//2], result[:a.size//2].real

    def incoInter(self, atomType, QVec, timeVec):
        return self.intermediate(atomType, atomType, QVec, timeVec)

    def getPairIndex(self, atomType1, atomType2, numPairs=None):
        if numPairs:
            if(numPairs>self.nMolecule):
                raise RuntimeError('number of pair is more then number of molecule')
        else:
            numPairs=self.nMolecule

        atomID = np.arange(0, self.nAtompMolecule*numPairs, self.nAtompMolecule, dtype=np.uint32)
        molecule=self.species[0:self.nAtompMolecule]
        idx1 = np.where(molecule == atomType1)[0][0]
        idx2 = np.where(molecule == atomType2)[0][0]

        atomicPair = np.ascontiguousarray(np.tile(atomID,(2,1)).swapaxes(0,1))
        atomicPair[:,0]+=idx1
        atomicPair[:,1]+=idx2
        print('Selected', len(atomicPair), 'atomic pairs. First few index of the paris:')
        for p in range(min(atomicPair.shape[0],10)):
            print('(',atomicPair[p,0], atomicPair[p,1], ')')
        return atomicPair


    def intermediate(self, atomType1, atomType2, QVec, timeVec, numPair):
        if atomType1 not in self.atomType:
             raise RuntimeError('atom type ', atomType1, ' is not found in the atom type list')
        if atomType2 not in self.atomType:
             raise RuntimeError('atom type ', atomType2, ' is not found in the atom type list')

        #fixme: check if elements in timeVec are greater than nTimeStep
        FsTQ = np.zeros([timeVec.size, QVec.size ], dtype = np.float64)

        atomicPair= self.getPairIndex(atomType1, atomType2, numPair)

        vacf.intermediate(self.trj, self.trj.size,
            atomicPair, atomicPair.size,
            self.box, self.box.size,
            QVec, timeVec,
            self.nAtom, self.nFrame, FsTQ, FsTQ.size);

        return FsTQ

    def gtr(self, atomType1, atomType2, wavelengthCut, mWlBin):
        if atomType1 not in self.atomType:
             raise RuntimeError('atom type ', atomType1, ' is not found in the atom type list')
        if atomType2 not in self.atomType:
             raise RuntimeError('atom type ', atomType2, ' is not found in the atom type list')

        molecule=self.species[0:self.nAtompMolecule]
        idx1 = np.where(molecule == atomType1)[0][0]
        idx2 = np.where(molecule == atomType2)[0][0]
        print("first atomic pair idex in the trajectory file ", idx1, idx2)

        atomID = np.arange(0, self.nAtom, self.nAtompMolecule, dtype=np.uint32)
        atomicPair = np.ascontiguousarray(np.tile(atomID,(2,1)).swapaxes(0,1))
        atomicPair[:,0]+=idx1
        atomicPair[:,1]+=idx2

        gtr=np.zeros([self.nFrame//2,mWlBin], dtype = np.float64)

        return  vacf.gtr_vec(self.trj, self.trj.size,
                    atomicPair, atomicPair.size,
                    self.box, self.box.size,
                    gtr, gtr.size,
                    mWlBin, wavelengthCut,
                    self.deltaTfs, self.nAtom, self.nFrame);


    def cvv(self, atomType1, atomType2, numPair=None):
        if atomType1 not in self.atomType:
             raise RuntimeError('atom type ', atomType1, ' is not found in the atom type list')
        if atomType2 not in self.atomType:
             raise RuntimeError('atom type ', atomType2, ' is not found in the atom type list')

        atomicPair= self.getPairIndex(atomType1, atomType2, numPair)

        cvvVec = np.zeros(self.nFrame//2, dtype = np.float64)
        vacf.cvv(self.vel, self.vel.size, atomicPair, atomicPair.size,
                cvvVec, cvvVec.size, self.nAtom, self.nFrame)
        return cvvVec
