#!/usr/bin/python

from __future__ import print_function
import vacf
import time
import sys
import h5py
import numpy as np

# limitations:
# 1. step time must be a constant
class Trajectory:
    def __init__(self, fileName=None, atomRepeat = 1, time2fs = 1.):
        if fileName:
            f=h5py.File(fileName,'r')
            #species axes: [atomID,timeStep]
            self.species=f['particles']['all']['species']['value'][0,:][()]
            self.atomType, counts = np.unique(self.species, return_counts=True)
            if atomRepeat!=1:
                self.atomType=np.repeat(self.atomType, atomRepeat)
                counts = np.repeat(counts, atomRepeat)/atomRepeat

            t = f['particles']['all']['species']['time']
            self.deltaTfs = (t[1]-t[0])*time2fs
            #pos axes: [atomID,timeStep, xyz]
            self.trj = self.swapaxes(f['particles']['all']['position']['value'][()])
            # self.vel = self.swapaxes(f['particles']['all']['velocity']['value'][()])
            #box axes: [timeStep, xyz]
            self.box=f['particles']['all']['box']['edges']['value'][()]
            self.minBoxL = self.box.min()
            self.maxBoxL = self.box.max()
            self.qResolution = 2*np.pi/self.minBoxL
            #upper limit of q depence on the step length of the atom

            f.close()
            self.nFrame= self.trj.shape[1]
            self.nAtom = self.trj.shape[0]

            # for atrj in self.trj:
            #     atrj -= np.trunc(atrj/self.box)*self.box

            self.nMolecule = self.gcd(counts[0],counts[1]) #fixme: this is for two element material ONLY
            self.nAtompMolecule = self.nAtom/self.nMolecule

            print('Input data contain the trajectories of', self.nAtom, 'atoms.')
            print("Species",self.species)
            for at, cnt in zip(self.atomType, counts):
                print('Number of type', at, 'atom:', cnt)
            print('Number of molecule:',self.nMolecule)
            print('Atom per molecule:',self.nAtompMolecule, self.species[0:self.nAtompMolecule])
            print('Number of frames:', self.nFrame, '.')
            print('Time step spacing:', self.deltaTfs, 'fs.')

            print('Time period:', (self.nFrame-1)*self.deltaTfs, 'fs.')
            print('Box size lengths are distributed between ', self.minBoxL, 'Aa and ', self.maxBoxL, 'Aa' )
            print('Q resolution is better then', self.qResolution, 'Aa^-1')

    def getAtomType(self):
        return self.atomType

    def gcd(self, x, y):
        while(y):
            x, y = y, x % y
        return x

    def swapaxes(self, data):
        data=np.swapaxes(data,0,1)
        return np.ascontiguousarray(data)

    def getDeltaT(self):
        return self.deltaTfs

    def incoInter(self, atomType, QVec, timeVec):
        return self.intermediate(atomType, atomType, QVec, timeVec)

    def getPairIndex(self, atomTypeIdx1, atomTypeIdx2, numPairs=None):
        if numPairs:
            if(numPairs>self.nMolecule):
                raise RuntimeError('number of pair is more then number of molecule')
        else:
            numPairs=self.nMolecule

        if type(atomTypeIdx1) is tuple:
            atomType1=atomTypeIdx1[0]
            atomType1Idx=atomTypeIdx1[1]
        else:
            atomType1=atomTypeIdx1
            atomType1Idx=0

        if type(atomTypeIdx2) is tuple:
            atomType2=atomTypeIdx2[0]
            atomType2Idx=atomTypeIdx2[1]
        else:
            atomType2=atomTypeIdx2
            atomType2Idx=0

        atomID = np.arange(0, self.nAtompMolecule*numPairs, self.nAtompMolecule, dtype=np.uint32)
        molecule=self.species[0:self.nAtompMolecule]
        idx1 = np.where(molecule == atomType1)[0][atomType1Idx]
        idx2 = np.where(molecule == atomType2)[0][atomType2Idx]

        atomicPair = np.ascontiguousarray(np.tile(atomID,(2,1)).swapaxes(0,1))
        atomicPair[:,0]+=idx1
        atomicPair[:,1]+=idx2
        print('Selected', len(atomicPair), 'atomic pairs. First few index of the paris:')
        for p in range(min(atomicPair.shape[0],10)):
            print('(',atomicPair[p,0], atomicPair[p,1], ')')
        return atomicPair

    def structFact(self, atomType1, atomType2, QVec, timeVec, numPair=None):
        if atomType1 not in self.atomType:
             raise RuntimeError('atom type ', atomType1, ' is not found in the atom type list')
        if atomType2 not in self.atomType:
             raise RuntimeError('atom type ', atomType2, ' is not found in the atom type list')

        #fixme: check if elements in timeVec are greater than nTimeStep
        FsTQ = np.zeros([timeVec.size, QVec.size ], dtype = np.float64)

        atomicPair= self.getPairIndex(atomType1, atomType2, numPair)

        vacf.structFact(self.trj, self.trj.size,
            atomicPair, atomicPair.size,
            self.box, self.box.size,
            QVec, timeVec,
            self.nAtom, self.nFrame, FsTQ, FsTQ.size);

        return FsTQ

    def intermediate(self, atomType1, atomType2, QVec, timeVec, numPair=None):
        if atomType1 not in self.atomType:
             raise RuntimeError('atom type ', atomType1, ' is not found in the atom type list')
        if atomType2 not in self.atomType:
             raise RuntimeError('atom type ', atomType2, ' is not found in the atom type list')

        #fixme: check if elements in timeVec are greater than nTimeStep
        FsTQ = np.zeros([timeVec.size, QVec.size ], dtype = np.float64)

        atomicPair= self.getPairIndex(atomType1, atomType2, numPair)

        vacf.intermediate(self.trj, self.trj.size,
            atomicPair, atomicPair.size,
            self.box, self.box.size,
            QVec, timeVec,
            self.nAtom, self.nFrame, FsTQ, FsTQ.size);

        return FsTQ

    def gtr(self, atomType1, atomType2, wavelengthCut, mWlBin):
        if atomType1 not in self.atomType:
             raise RuntimeError('atom type ', atomType1, ' is not found in the atom type list')
        if atomType2 not in self.atomType:
             raise RuntimeError('atom type ', atomType2, ' is not found in the atom type list')

        molecule=self.species[0:self.nAtompMolecule]
        idx1 = np.where(molecule == atomType1)[0][0]
        idx2 = np.where(molecule == atomType2)[0][0]
        print("first atomic pair idex in the trajectory file ", idx1, idx2)

        atomID = np.arange(0, self.nAtom, self.nAtompMolecule, dtype=np.uint32)
        atomicPair = np.ascontiguousarray(np.tile(atomID,(2,1)).swapaxes(0,1))
        atomicPair[:,0]+=idx1
        atomicPair[:,1]+=idx2

        gtr=np.zeros([self.nFrame//2,mWlBin], dtype = np.float64)

        return  vacf.gtr_vec(self.trj, self.trj.size,
                    atomicPair, atomicPair.size,
                    self.box, self.box.size,
                    gtr, gtr.size,
                    mWlBin, wavelengthCut,
                    self.deltaTfs, self.nAtom, self.nFrame);


    def cvv(self, atomType1, atomType2, numPair=None, atomType2Idx=0):
        if atomType1 not in self.atomType:
             raise RuntimeError('atom type ', atomType1, ' is not found in the atom type list')
        if atomType2 not in self.atomType:
             raise RuntimeError('atom type ', atomType2, ' is not found in the atom type list')

        atomicPair= self.getPairIndex(atomType1, atomType2, numPair)

        cvvVec = np.zeros(self.nFrame//2, dtype = np.float64)
        vacf.cvv(self.vel, self.vel.size, atomicPair, atomicPair.size,
                cvvVec, cvvVec.size, self.nAtom, self.nFrame)
        return cvvVec
