#include "TrajAna.hh"

#include <omp.h>
#include <time.h>       /* time */

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <numeric>
#include <stdexcept>
#include <vector>

#include "NumpyHist1D.hh"

// FIXME: for the intermediate function, three numerically "better" approaches
// were tested and none of them gave an observable improvement:
// 1. a Lebedev quadrature of higher order
// 2. cos instead of cosf
// 3. a stable (compensated) sum for accumulating the cosine values
 
class LebedevQuadrature {
public:
  LebedevQuadrature()
  {
    // //theta, phi, weight
    // std::vector<double> raw = {
    //    0.000000000000000, 90.000000000000000, 0.066666666666667,
    //  180.000000000000000, 90.000000000000000, 0.066666666666667,
    //   90.000000000000000, 90.000000000000000, 0.066666666666667,
    //  -90.000000000000000, 90.000000000000000, 0.066666666666667,
    //   90.000000000000000, 0.000000000000000, 0.066666666666667,
    //   90.000000000000000  , 180.000000000000000, 0.066666666666667,
    //   45.000000000000000,54.735610317245346, 0.075000000000000,
    //   45.000000000000000  , 125.264389682754654, 0.075000000000000,
    //  -45.000000000000000,54.735610317245346, 0.075000000000000,
    //  -45.000000000000000  , 125.264389682754654, 0.07500000000000,
    //  135.000000000000000,54.735610317245346, 0.075000000000000,
    //  135.000000000000000  , 125.264389682754654, 0.075000000000000,
    // -135.000000000000000,54.735610317245346, 0.075000000000000,
    // -135.000000000000000  , 125.264389682754654, 0.075000000000000};

    std::vector<double> raw = {
      0.000000000000000 ,   90.000000000000000  ,   0.166666666666667,
      180.000000000000000 ,   90.000000000000000   ,  0.166666666666667,
      90.000000000000000  ,  90.000000000000000  ,   0.166666666666667,
      -90.000000000000000 ,   90.000000000000000 ,    0.166666666666667,
      90.000000000000000  ,   0.000000000000000  ,   0.166666666666667,
      90.000000000000000  , 180.000000000000000  ,   0.166666666666667};

    for(auto it=raw.begin(); it<raw.end(); it+=3)
    {
      double theta = *it;
      double phi = *(it+1);
      double w = *(it+2);
      m_dir_w.push_back(std::vector<double> {cos ( theta ) * sin ( phi ), sin ( theta ) * sin ( phi ), cos ( phi ), w} );
    }
  }
  ~LebedevQuadrature(){}

  std::vector<std::vector<double>>::const_iterator begin() const {return m_dir_w.begin();};
  std::vector<std::vector<double>>::const_iterator end() const {return m_dir_w.end();};

private:
  std::vector<std::vector<double>> m_dir_w;
};


void intermediate(double* traj, unsigned trajSize,
  unsigned *atomicPair, unsigned atomicPairSize,
  double *box, unsigned boxSize,
  double *q, unsigned qSize,
  unsigned *timeStep, unsigned TimStepSize,
  unsigned nAtom, unsigned nTimeStep, double* FsTQ, unsigned FsTQSize)
{

  if(trajSize!=nAtom*nTimeStep*3)
  {
    throw std::runtime_error ("incoherentIntermediate: position vector size error");
  }
  if(FsTQSize!=qSize*TimStepSize)
  {
    throw std::runtime_error ("incoherentIntermediate: return vector size error");
  }


  std::vector<double*> trajOfAtom1;
  std::vector<double*> trajOfAtom2;
  for(unsigned i=0; i<atomicPairSize/2;i++)
  {
    trajOfAtom1.push_back(traj+ (*(atomicPair+i*2))*nTimeStep*3);
    trajOfAtom2.push_back(traj+ (*(atomicPair+i*2+1))*nTimeStep*3);
  }

  #pragma omp parallel for default(none) shared(TimStepSize,trajOfAtom1, trajOfAtom2, nTimeStep, timeStep, box, boxSize, q, qSize, FsTQ)
  for(unsigned i=0;i<TimStepSize;i++)
  {
    std::vector<double> fq;
    qtCorrelation(trajOfAtom1, trajOfAtom2,  nTimeStep, timeStep[i],
      box, boxSize, q, qSize, fq);
    std::copy(fq.begin(), fq.end(), FsTQ + qSize*i );
  }

}

void qtCorrelation(const std::vector<double*>& atomPos1, const std::vector<double*>& atomPos2, unsigned nTimeStep, unsigned timeIdx,
  double *box, unsigned boxSize, double *q, unsigned qSize, std::vector<double>& fq)
{
  if(atomPos1.size() != atomPos2.size())
  {
    throw std::runtime_error ("intScat: atomPos vector size error");
  }
  if(boxSize!=nTimeStep*3)
  {
    throw std::runtime_error ("intScat: box size error");
  }
  fq.resize(0);
  fq.resize(qSize,0.);

  LebedevQuadrature lquad;
  double qdirX(0.),qdirY(0.),qdirZ(1.);

  // std::vector<StableSum> ssum(qSize);

  double *pFq = fq.data();
  // // #pragma omp parallel for reduction(+:pFq[:qNum])
  for(unsigned n=0;n<atomPos1.size();n++)
  {
    double *pos1 = atomPos1[n];
    double *pos2 = atomPos2[n] + timeIdx*3;
    double *end = pos1 + (nTimeStep-timeIdx)*3;
    //very 3 numbers in pos belongs to 1 time step
    //fixme assuming x,y,z are always smaller than 10
    double *boxpos = box ;

    // loop over time steps
    for( ; pos1<end; boxpos +=3 )
    {
      double x = (*(pos1++))-(*(pos2++));
      double y = (*(pos1++))-(*(pos2++));
      double z = (*(pos1++))-(*(pos2++));

      if(ncabs(x)*2 > *boxpos )
      {
        x = x<0. ? x + *boxpos : x - *boxpos;
      }
      if(ncabs(y)*2 > *(boxpos+1) )
      {
        y = y<0. ? y + *(boxpos+1) : y - *(boxpos+1);
      }
      if(ncabs(z)*2 > *(boxpos+2))
      {
        z = z<0. ? z + *(boxpos+2) : z - *(boxpos+2);
      }

      // //get mean
      for(auto it=lquad.begin();it!=lquad.end();++it)
      {
        double dotprt = x*(*it)[0] + y*(*it)[1] +z*(*it)[2];
        double sum(0.);
        for(unsigned iq=0; iq<qSize;iq++)
        {
          //fixeme: use cosf?
          *(pFq+iq)  += cosf(dotprt*(*(q+iq)))*(*it)[3];
          // ssum[iq].add(cosf(dotprt*(*(q+iq)))*(*it)[3]);
        }
      }

      // // isotropic random number in fact does the same trick. TBI
      // randisotropic(qdirX,qdirY,qdirZ);
      // double dotprt = x*qdirX + y*qdirY +z*qdirZ;
      // //loop over q
      // for(unsigned iq=0; iq<qNum;iq++)
      // {
      //   *(pFq+iq) += cos(dotprt*((iq+1)*qSpacing) );
      // }
    }
  }
  double normFact = 1./( atomPos1.size()*(nTimeStep-timeIdx) );
  for (auto &v : fq)
    v *= normFact;
  // for(unsigned i=0;i<qSize;i++)
  // {
  //   fq[i]=ssum[i].sum()*normFact;
  // }
}

double fsqt(double* atomPos1, unsigned atomPos1Size,
  double* atomPos2, unsigned atomPos2Size,
  double *box, unsigned boxSize, double q,
  unsigned timeIdx, unsigned nAtom, unsigned nTimeStep)
{
  if(atomPos1Size!=nAtom*nTimeStep*3 || atomPos1Size!=atomPos2Size)
  {
    throw std::runtime_error ("intScat: vector size error");
  }

  LebedevQuadrature lquad;
  double tot (0.);
  double qdirX(0.),qdirY(0.),qdirZ(1.);


  for(unsigned n=0;n<nAtom;n++)
  {
    double *pos1 = atomPos1 + n*(nTimeStep*3);
    double *pos2 = atomPos2 + n*(nTimeStep*3)+timeIdx*3;
    double *end = pos1 + (nTimeStep-timeIdx)*3;
    //very 3 numbers in pos belongs to 1 time step
    //fixme assuming x,y,z are always smaller than 10
    double *boxpos = box ;

    // printf("%g\n", threshold);
    // abort();
    for( ; pos1<end; boxpos +=3 )
    {
      double x = (*(pos1++))-(*(pos2++));
      double y = (*(pos1++))-(*(pos2++));
      double z = (*(pos1++))-(*(pos2++));

      if(ncabs(x)*2 > *boxpos )
      {
        x = x<0. ? x + *boxpos : x - *boxpos;
      }
      if(ncabs(y)*2 > *(boxpos+1) )
      {
        y = y<0. ? y + *(boxpos+1) : y - *(boxpos+1);
      }
      if(ncabs(z)*2 > *(boxpos+2))
      {
        z = z<0. ? z + *(boxpos+2) : z - *(boxpos+2);
      }

      // //get mean
      // for(auto it=lquad.begin();it!=lquad.end();++it)
      // {
      //   double dotprt = x*(*it)[0]*q + y*(*it)[1]*q +z*(*it)[2]*q;
      //   tot += cos(dotprt)*(*it)[3];
      // }

      // isotropic random number in fact does the same trick. TBI
      randisotropic(qdirX,qdirY,qdirZ);
      double dotprt = x*qdirX*q + y*qdirY*q +z*qdirZ*q;
      tot += cos(dotprt);

    }
  }
  return tot/(nAtom*(nTimeStep-timeIdx));
}


//velocity autocorrelation function
double vaf(double* vel, unsigned velSize,
  unsigned idx, unsigned nAtom, unsigned nTimeStep)
{
  return vafRl(vel, velSize,  idx,  nAtom,  nTimeStep);
}

//velocity autocorrelation function
void vaf_vec(double* vel, unsigned velSize,
             double* returnPointer, unsigned retsize,
             unsigned nAtom, unsigned nTimeStep)
{
  if(retsize>nTimeStep*0.5)
    throw std::runtime_error ("vaf_vec: returned vector size should not be greater than half of nTimeStep");

  #pragma omp parallel for default(none) shared(vel,velSize,returnPointer,retsize,nAtom,nTimeStep)
  for(unsigned i=0;i<retsize;i++)
  {
    if(!i)
      printf("vaf_vec: max thread %d, Num thread %d \n", omp_get_max_threads(), omp_get_num_threads() );
    returnPointer[i]=vaf(vel, velSize, i,nAtom, nTimeStep );
  }
}

void gtr_vec(double* pos, unsigned posSize,
            double *box, unsigned boxSize,
            double* returnPointer, unsigned retsize,
            unsigned numSpatialBin, double wavelengthCut,
            double dltT, unsigned nAtom, unsigned nTimeStep)
{
  unsigned maxLagTimeStep = retsize/numSpatialBin;
  //calculate distance distribution at fixed time
  #pragma omp parallel for default(none) shared(pos, posSize, box, boxSize, dltT, nAtom, nTimeStep, maxLagTimeStep, returnPointer, numSpatialBin, wavelengthCut)
  for(unsigned i=0; i<maxLagTimeStep;i++ ) //fixme i=0 is wasted
  {
    if(i==1)
      printf("grt_vec: max thread %d, Num thread %d \n", omp_get_max_threads(), omp_get_num_threads() );

    //space 0 to 15 Aa, time
    NumpyHist1D hist1d(numSpatialBin, 0., wavelengthCut);
    //gtr(hist1d, pos, posSize, box, boxSize, i, nAtom, nTimeStep);
    pairCorrelation(hist1d, pos, posSize, pos, posSize, box, boxSize, i, nAtom, nTimeStep);
    //Convert the histogram for the distances at a fixed time  to a pointwise distribution density.
    //For self-correlation, density is zero for r=0, except when also t=0, density(0,0) is a delta function.
    //The delta function is not treated at the moment.
    //For collected-correlation, density at the averaged pair distance.
    double *pt = returnPointer + numSpatialBin*i;
    *(pt++) = 0.;
    std::vector<double> density;
    density.reserve(numSpatialBin);
    density.push_back(0.);

    double rSpacing = wavelengthCut/numSpatialBin;
    double r = rSpacing;
    for(auto it = hist1d.getRaw().begin(); it != hist1d.getRaw().end()-1;)
    {
      density.push_back(0.5*(*it+ (*(++it))) );
      r+= rSpacing;
    }

    std::copy(density.begin(), density.end(), returnPointer + numSpatialBin*i);

  }
}


EXPORT_SYMBOL void pairCorrelation(NumpyHist1D& hist, double* pos, unsigned posSize,
              double * pos2, unsigned posSize2,
              double *box, unsigned boxSize,
              unsigned timeIdx, unsigned nAtom, unsigned nTimeStep)
{
  if(posSize!=nAtom*nTimeStep*3 || posSize2!=nAtom*nTimeStep*3)
  {
    printf("posSize %d, posSize cal %d, nAtom %d, nTimeStep %d\n",
            posSize, nAtom*nTimeStep*3, nAtom, nTimeStep);
    throw std::runtime_error ("grt: vector size error");
  }

  if(boxSize!=nTimeStep*3)
  {
    printf("boxSize %d, boxSize cal %d, nAtom %d, nTimeStep %d\n",
            posSize, nAtom*nTimeStep*3, nAtom, nTimeStep);
    throw std::runtime_error ("grt: vector size error");
  }

  for(unsigned n=0;n<nAtom;n++)
  {
    double *atom1_pos = pos+ n*(nTimeStep*3);
    double *end = pos + n*(nTimeStep*3)+(nTimeStep-timeIdx)*3;
    double *atom2_pos = pos2 + n*(nTimeStep*3)+timeIdx*3;
    double *boxpos = box ;
    //very 3 numbers in pos belongs to 1 time step
    //fixme assuming x,y,z are always smaller than 10
    double avghalfBoxSize = (boxpos[0]+boxpos[1]+boxpos[2])*0.5/3;
    double threshold = avghalfBoxSize*avghalfBoxSize;
    // printf("%g\n", threshold);
    // abort();
    for( ; atom1_pos<end; boxpos +=3 )
    {
      double x = (*(atom1_pos++))-(*(atom2_pos++));
      double y = (*(atom1_pos++))-(*(atom2_pos++));
      double z = (*(atom1_pos++))-(*(atom2_pos++));

      double xx = x*x;
      double yy = y*y;
      double zz = z*z;

      if(xx>threshold)
      {
        x = x<0 ? x + boxpos[0] : x - boxpos[0];
        xx = x*x;
      }
      if(yy>threshold)
      {
        y = y<0 ? y + boxpos[1] : y - boxpos[1];
        yy = y*y;
      }
      if(zz>threshold)
      {
        z = z<0 ? z + boxpos[2] : z - boxpos[2];
        zz = z*z;
      }

      hist.fill_unguard(sqrt(xx+yy+zz));
    }
  }
    if(timeIdx%100==0)
    {
      double overScope = hist.getOverflow() + hist.getUnderflow();
      double tot = hist.getIntegral()+overScope;
      printf("lagtime idx %d, miss %g\n", timeIdx, overScope/tot);
      printf("integral %g, underflow %g, overflow %g \n",  hist.getIntegral(),
      hist.getUnderflow(), hist.getOverflow());
    }

}
