#include "TrajAna.hh"
#include <numeric>
#include <omp.h>
#include <cmath>
#include "NumpyHist1D.hh"
#include <vector>

//velocity autocorrelation function
double vaf(double* vel, unsigned velSize,
  unsigned idx, unsigned nAtom, unsigned nStep)
{
  return vafRl(vel, velSize,  idx,  nAtom,  nStep);
}


//velocity autocorrelation function
void vaf_vec(double* vel, unsigned velSize,
             double* returnPointer, unsigned retsize,
             unsigned nAtom, unsigned nStep)
{
  if(retsize>nStep*0.5)
    throw std::runtime_error ("vaf_vec: returned vector size should not be greater than half of nStep");

  //#pragma omp parallel for default(none) shared(vel,velSize,returnPointer,retsize,nAtom,nStep)
  for(unsigned i=0;i<retsize;i++)
  {
    if(!i)
      printf("vaf_vec: max thread %d, Num thread %d \n", omp_get_max_threads(), omp_get_num_threads() );
    returnPointer[i]=vaf(vel, velSize, i,nAtom, nStep );
  }
}

void gtr_vec(double* pos, unsigned posSize,
            double *box, unsigned boxSize,
            double* returnPointer, unsigned retsize,
            unsigned numSpatialBin, double wavelengthCut,
            double dltT, unsigned nAtom, unsigned nStep)
{
  unsigned maxLagTimeStep = retsize/numSpatialBin;
  //calculate distance distribution at fixed time
  #pragma omp parallel for default(none) shared(pos, posSize, box, boxSize, dltT, nAtom, nStep, maxLagTimeStep, returnPointer, numSpatialBin, wavelengthCut)
  for(unsigned i=0; i<maxLagTimeStep;i++ ) //fixme i=0 is wasted
  {
    if(i==1)
      printf("grt_vec: max thread %d, Num thread %d \n", omp_get_max_threads(), omp_get_num_threads() );

    //space 0 to 15 Aa, time
    NumpyHist1D hist1d(numSpatialBin, 0., wavelengthCut);
    gtr(hist1d, pos, posSize, box, boxSize, i, nAtom, nStep);
    //Convert the histogram for the distances at a fixed time  to a pointwise distribution density.
    //For self-correlation, density is zero for r=0, except when also t=0, density(0,0) is a delta function.
    //The delta function is not treated at the moment.
    //For collected-correlation, density at the averaged pair distance.
    double *pt = returnPointer + numSpatialBin*i;
    *(pt++) = 0.;
    std::vector<double> density;
    density.reserve(numSpatialBin);
    density.push_back(0.);

    double rSpacing = wavelengthCut/numSpatialBin;
    double r = rSpacing;
    for(auto it = hist1d.getRaw().begin(); it != hist1d.getRaw().end()-1;)
    {
      density.push_back(0.5*(*it+ (*(++it))) );
      r+= rSpacing;
    }

    std::copy(density.begin(), density.end(), returnPointer + numSpatialBin*i);

  }
}

void gtr(NumpyHist1D& hist, double* pos, unsigned posSize,
         double *box, unsigned boxSize,
         unsigned timeIdx, unsigned nAtom, unsigned nStep)
{
  if(posSize!=nAtom*nStep*3)
  {
    printf("posSize %d, posSize cal %d, nAtom %d, nStep %d\n",
            posSize, nAtom*nStep*3, nAtom, nStep);
    throw std::runtime_error ("grt: vector size error");
  }

  if(boxSize!=nStep*3)
  {
    printf("boxSize %d, boxSize cal %d, nAtom %d, nStep %d\n",
            posSize, nAtom*nStep*3, nAtom, nStep);
    throw std::runtime_error ("grt: vector size error");
  }

  for(unsigned n=0;n<nAtom;n++)
  {
    double *pos1 = pos+ n*(nStep*3);
    double *pos2 = pos + n*(nStep*3)+timeIdx*3;
    double *end = pos + n*(nStep*3)+(nStep-timeIdx)*3;
    double *boxpos = box ;
    //very 3 numbers in pos belongs to 1 time step
    //fixme assuming x,y,z are always smaller than 10
    double threshold = 144.;
    for( ; pos1<end; boxpos +=3 )
    {
      double x = (*(pos1++))-(*(pos2++));
      double y = (*(pos1++))-(*(pos2++));
      double z = (*(pos1++))-(*(pos2++));

      double xx = x*x;
      double yy = y*y;
      double zz = z*z;

      if(xx>threshold)
      {
        x = x<0 ? x + boxpos[0] : x - boxpos[0];
        xx = x*x;
      }
      if(yy>threshold)
      {
        y = y<0 ? y + boxpos[1] : y - boxpos[1];
        yy = y*y;
      }
      if(zz>threshold)
      {
        z = z<0 ? z + boxpos[2] : z - boxpos[2];
        zz = z*z;
      }

      hist.fill_unguard(sqrt(xx+yy+zz));
    }
  }
    if(timeIdx%100==0)
    {
      double overScope = hist.getOverflow() + hist.getUnderflow();
      double tot = hist.getIntegral()+overScope;
      printf("lagtime idx %d, miss %g\n", timeIdx, overScope/tot);
      printf("integral %g, underflow %g, overflow %g \n",  hist.getIntegral(),
      hist.getUnderflow(), hist.getOverflow());
    }
}
