#!/usr/bin/python

from __future__ import print_function
import vacf
import time
import sys
import h5py
import numpy as np

fname = 'dump_h5md.h5'
if len(sys.argv)==2:
    fname = sys.argv[1]

f=h5py.File(fname,'r')


#[step, atom, value]
species=f['particles']['all']['species']['value'][()]

def trim(data, spe, key):
  old_shape = data.shape
  data=data[spe==key,:]
  atomLeft = data.shape[0]/old_shape[0]
  #1. get the corret shape
  data=data.reshape([old_shape[0], atomLeft, 3])
  #2. swap axes from [stepID, atomID, vector3] to [atomID, stepID, vector3]
  data=np.swapaxes(data,0,1)
  #3. make data contiguous
  #index used to be "timestep, "
  #pos, trj, atom
  return np.ascontiguousarray(data)

def getDeltaT(time):
  return (time[1]-time[0])  #in fetosecond

pos = trim(f['particles']['all']['position']['value'][()],species,1)
vel= trim(f['particles']['all']['velocity']['value'][()] ,species,1)
deltaTfs = getDeltaT(f['particles']['all']['species']['time'][()]) #in fs
box=f['particles']['all']['box']['edges']['value'][()]
f.close()

nStep= pos.shape[1]
nAtom = pos.shape[0]

#1 rad/sec = 6.5821e-16 eV
radpsec2meV =  6.5821e-13 * 2*np.pi
radpps2meV =  radpsec2meV*1e12
radpfs2meV =  radpsec2meV*1e15

print ('Delta time ',deltaTfs,'fs')
print ('Velocity size after trim ', vel.shape)
print ('box size', box.shape)



#####c
start_time = time.time()

rSize= 100
tSize = nStep/5
wlCut = 12.
gtr = np.zeros([tSize, rSize ], dtype = np.float64)

rSpacingInAa, tSpacingInfs, rScalefact = vacf.gtr_vec(pos, pos.size,
             box, box.size,
             gtr, gtr.size,
             rSize, wlCut,
             deltaTfs, nAtom,  nStep)
print  ('c function elapsed ' , time.time() - start_time, 's')


# import matplotlib.pyplot as plt
# spaRange = np.arange(gtr.shape[1]) *  rSpacingInAa
#
# idx = [1, 10, 67, 67*2 ,67*2**2, 67*2**3]
# for i in idx:
#     #plt.plot(spaRange,  grt[:,i], label = str(tSpacingInfs*i)+'fs')
#     plt.plot(spaRange, scalefact* gtr[i,:], label = str(tSpacingInfs*i)+'fs')
# plt.xlabel('R, Aa')
# plt.legend()
# plt.title('Gr mult by 4pi*r*r')
#
# plt.figure()
# idx = [1, 10, 67, 67*2 ,67*2**2, 67*2**3]
# for i in idx:
#     plt.plot(spaRange,  gtr[i,:], label = str(tSpacingInfs*i)+'fs')
# plt.xlabel('R, Aa')
# plt.legend()
# plt.title('Gr rwa')
# plt.show()

f=h5py.File('gtr.h5','w')
f['gtr'] = gtr
f['rSpacingInAa'] = rSpacingInAa
f['tSpacingInfs'] = tSpacingInfs
f['rScalefact'] = rScalefact
f['rSize'] = rSize
f['tSize'] = tSize

f.close()

# fqt=np.fft.rfft(grt,axis=0)
# sqw=np.fft.rfft(fqt.real,axis=1)
#
# rSpaceResl = wlCut/numSpatialBin
# kSpaceResl = 1./rSpaceResl
# qaxis = np.arange(numSpatialBin) * kSpaceResl # in Aa^-1
# taxis = np.arange(numTimeBin) * deltaTfs #in fs
# eaxis = np.arange(numTimeBin)/deltaTfs * radpfs2meV
# sqw=np.abs(np.fft.fft2(grt))
#
# plt.contour( eaxis,qaxis, sqw)
