#include "TrajAna.hh"
#include <numeric>
#include <omp.h>
#include <cmath>
#include <algorithm>
#include "NumpyHist1D.hh"
#include <time.h>       /* time */
#include <algorithm>    // std::min

// fixme: for the intermediate function, three numerical refinements were
// tested; none of them produced an observable improvement:
// 1. Lebedev quadrature of higher order
// 2. cos instead of cosf
// 3. stable summation for accumulating the cosine values

// Fixed table of directions on the unit sphere with quadrature weights.
// Every stored entry is a 4-element vector {x, y, z, w}: the direction is
// built from the tabulated angles (theta, phi), given in degrees, as
// (cos(theta)*sin(phi), sin(theta)*sin(phi), cos(phi)); w is the weight.
class LebedevQuadrature {
public:
  // Converts the tabulated (theta, phi, weight) triplets into {x, y, z, w}.
  LebedevQuadrature()
  {
    // Disabled alternative table with seven directions (theta, phi, weight):
    //     0.000000000000000,  90.000000000000000, 0.066666666666667*2,
    //    90.000000000000000,  90.000000000000000, 0.066666666666667*2,
    //    90.000000000000000,   0.000000000000000, 0.066666666666667*2,
    //    45.000000000000000,  54.735610317245346, 0.075000000000000*2,
    //    45.000000000000000, 125.264389682754654, 0.075000000000000*2,
    //   -45.000000000000000,  54.735610317245346, 0.075000000000000*2,
    //   -45.000000000000000, 125.264389682754654, 0.07500000000000*2
    // fixme: the number of elements can be reduced by 1 when applying
    // cos(a+b+c)+cos(a+b-c)+cos(a-b+c)+cos(a-b-c)=4*cos(a)*cos(b)*cos(c)

    // Active table: flattened (theta [deg], phi [deg], weight) triplets.
    const std::vector<double> table = {
       0.000000000000000, 90.000000000000000, 0.166666666666667*2,
      90.000000000000000, 90.000000000000000, 0.166666666666667*2,
      90.000000000000000,  0.000000000000000, 0.166666666666667*2};

    const double toRad = M_PI/180.;
    const double *triplet = table.data();
    const double *const stop = triplet + table.size();
    for( ; triplet<stop; triplet+=3)
    {
      const double theta = triplet[0]*toRad;
      const double phi = triplet[1]*toRad;
      const double sinPhi = sin(phi);
      const std::vector<double> entry {cos(theta)*sinPhi, sin(theta)*sinPhi, cos(phi), triplet[2]};
      m_dir_w.push_back(entry);
    }
  }
  ~LebedevQuadrature() = default;

  // Read-only iteration over the {x, y, z, w} entries.
  std::vector<std::vector<double>>::const_iterator begin() const { return m_dir_w.begin(); }
  std::vector<std::vector<double>>::const_iterator end() const { return m_dir_w.end(); }

private:
  // One {x, y, z, w} vector per tabulated direction.
  std::vector<std::vector<double>> m_dir_w;
};

//velocity function
// Velocity correlation of atom pairs for the lags 0 .. retsize-1.
//
// vel            flattened velocities; the data of atom a is read starting at
//                vel + a*nTimeStep*3
// velSize        not read by this function
// atomicPair     flattened list of atom index pairs (first, second)
// atomicPairSize number of entries in atomicPair; pairs are formed from
//                entries 2*i and 2*i+1, a trailing unpaired entry is ignored
// returnPointer  output; element i receives cvv_t(..., i, nTimeStep)
// retsize        number of lags to evaluate
// nAtom          not read by this function (fixme: remove)
// nTimeStep      number of time steps stored per atom
//
// Throws std::runtime_error when retsize is greater than half of nTimeStep.
//
// NOTE(review): the atom indices in atomicPair are not range-checked here;
// presumably the caller guarantees they address valid atoms inside vel.
void cvv(double* vel, unsigned velSize,
         unsigned *atomicPair, unsigned atomicPairSize,
         double* returnPointer, unsigned retsize,
         unsigned nAtom, unsigned nTimeStep) //fixme remove unused nAtom
{
  // The messages used to name vaf_vec (copy/paste), which misdirected debugging.
  if(retsize>nTimeStep*0.5)
    throw std::runtime_error ("cvv: returned vector size should not be greater than half of nTimeStep");

  // Start address of the velocity series of each pair member.
  const unsigned nPair = atomicPairSize/2;
  std::vector<double*> velj;
  std::vector<double*> veljp;
  velj.reserve(nPair);
  veljp.reserve(nPair);
  for(unsigned i=0; i<nPair; i++)
  {
    velj.push_back(vel + atomicPair[i*2]*nTimeStep*3);
    veljp.push_back(vel + atomicPair[i*2+1]*nTimeStep*3);
  }

  // Each lag is independent, so lags are distributed over the threads.
  #pragma omp parallel for default(none) shared(velj,veljp,returnPointer,retsize,nTimeStep)
  for(unsigned i=0;i<retsize;i++)
  {
    // Report the thread configuration once, from the iteration handling lag 0.
    if(!i)
      printf("cvv: max thread %d, Num thread %d \n", omp_get_max_threads(), omp_get_num_threads() );
    returnPointer[i]=cvv_t(velj, veljp, i, nTimeStep );
  }
}


// Velocity correlation at a single lag.
//
// For every pair n the products vel1[n][k]*vel2[n][k + timeIdx*3] are summed
// over k in [0, (nTimeStep-timeIdx)*3); the total is divided by
// vel1.size()*(nTimeStep-timeIdx).
//
// vel1, vel2  per-pair start addresses of the two velocity series; each must
//             provide nTimeStep*3 readable values
// timeIdx     lag in time steps, must be smaller than nTimeStep
// nTimeStep   number of time steps stored per series
//
// Returns the normalised sum; the result is NaN (0/0) for an empty pair list.
// Throws std::runtime_error when the two lists differ in length or when
// timeIdx is not smaller than nTimeStep.
double cvv_t(const std::vector<double*>& vel1, const std::vector<double*>& vel2,
  unsigned timeIdx, unsigned nTimeStep)
{
  if(vel1.size() != vel2.size())
  {
    throw std::runtime_error ("cvv_t: atom velocity vector size error");
  }
  // nTimeStep-timeIdx is unsigned: without this guard a too large lag wraps
  // around and the loop below would run far beyond the velocity arrays.
  if(timeIdx >= nTimeStep)
  {
    throw std::runtime_error ("cvv_t: timeIdx must be smaller than nTimeStep");
  }

  const unsigned nOverlap = nTimeStep-timeIdx; // time steps shared by both series
  double tot (0.);
  for(unsigned n=0;n<vel1.size();n++)
  {
    const double *velj = vel1[n];
    const double *endj = velj+nOverlap*3;
    const double *veljp = vel2[n]+timeIdx*3; // second series shifted by the lag

    for( ; velj<endj; velj++, veljp++)
    {
      tot += (*velj)*(*veljp);
    }
  }
  return tot/(vel1.size()*nOverlap);
}


// Evaluates qtCorrelationHist for a list of time windows and stores the
// results row by row in FsTQ.
//
// traj           flattened trajectory, nAtom*nTimeStep*3 values; the data of
//                atom a starts at traj + a*nTimeStep*3
// atomicPair     flattened list of atom index pairs; entries 2*i and 2*i+1
//                form pair i, a trailing unpaired entry is ignored
// box            box data forwarded to qtCorrelationHist, nTimeStep*3 values
// q              qSize values forwarded to qtCorrelationHist
// tWindow        TimStepSize time-window indices, each smaller than nTimeStep
// FsTQ           output of qSize*TimStepSize values; the result for window i
//                is written to FsTQ + qSize*i
//
// Throws std::runtime_error on inconsistent sizes, on an empty pair list, on
// an atom index >= nAtom and on a window index >= nTimeStep.
//
// All validation happens before the OpenMP loop: an exception thrown inside
// a parallel region cannot propagate to the caller.
void intermediate(double* traj, unsigned trajSize,
  unsigned *atomicPair, unsigned atomicPairSize,
  double *box, unsigned boxSize,
  double *q, unsigned qSize,
  unsigned *tWindow, unsigned TimStepSize,
  unsigned nAtom, unsigned nTimeStep, double* FsTQ, unsigned FsTQSize)
{
  // Work on a private copy: the ctypes-owned tWindow buffer was observed to
  // disappear from time to time.
  const std::vector<unsigned> cache(tWindow, tWindow+TimStepSize);
  for(unsigned i=0;i<TimStepSize; i++)
  {
    // A window equal to nTimeStep leaves no sample at all, hence >=.
    if(cache[i] >= nTimeStep)
      throw std::runtime_error ("intermediate: elements of timStep must be smaller than number of nTimeStep");
  }
  const unsigned *timeStep=cache.data();

  if(trajSize!=nAtom*nTimeStep*3)
  {
    throw std::runtime_error ("intermediate: position vector size error");
  }
  if(FsTQSize!=qSize*TimStepSize)
  {
    throw std::runtime_error ("intermediate: return vector size error");
  }
  // Same condition the worker checks; tested here so that it cannot fire
  // inside the parallel region.
  if(boxSize!=nTimeStep*3)
  {
    throw std::runtime_error ("intermediate: box size error");
  }

  const unsigned nPair = atomicPairSize/2;
  if(nPair==0)
  {
    // The worker reads the first pair unconditionally.
    throw std::runtime_error ("intermediate: at least one atom pair is required");
  }

  std::vector<double*> trajOfAtom1;
  std::vector<double*> trajOfAtom2;
  trajOfAtom1.reserve(nPair);
  trajOfAtom2.reserve(nPair);
  for(unsigned i=0; i<nPair;i++)
  {
    const unsigned first = atomicPair[i*2];
    const unsigned second = atomicPair[i*2+1];
    if(first>=nAtom || second>=nAtom)
    {
      throw std::runtime_error ("intermediate: atom index in atomicPair is out of range");
    }
    trajOfAtom1.push_back(traj + first*nTimeStep*3);
    trajOfAtom2.push_back(traj + second*nTimeStep*3);
  }

  // One time window per iteration; every iteration writes its own row of FsTQ.
  #pragma omp parallel for default(none)  shared(timeStep, TimStepSize, trajOfAtom1, trajOfAtom2, nTimeStep, box, boxSize, q, qSize, FsTQ)
  for(unsigned i=0;i<TimStepSize;i++)
  {
    std::vector<double> fq;
    qtCorrelationHist(trajOfAtom1, trajOfAtom2,  nTimeStep, *(timeStep+i),
      box, boxSize, q, qSize, fq);
    std::copy(fq.begin(), fq.end(), FsTQ + qSize*i );
  }

}

// Histogram-based estimate of the mean of cos(q*d) over coordinate differences d.
//
// Steps:
//  1. Estimate the mean absolute x displacement from the first pair, using at
//     most 10000 time steps and folding by the box x length of each step.
//  2. Fill a 9999-bin histogram spanning [-10*mean, 10*mean] with the
//     differences atomPos1[n][k] - atomPos2[n][k + timeIdx*3] of all pairs,
//     for every k in [0, (nTimeStep-timeIdx)*3).
//  3. For each q[iq] store in fq[iq] the histogram-weighted mean of
//     cos(q*binCentre).
//
// fq is resized to qSize. When the estimated mean is below 1e-15 and
// timeIdx is 0, every fq entry is set to 1 and no histogram is built.
//
// Throws std::runtime_error when the two position lists differ in length,
// when boxSize != nTimeStep*3, when the lists are empty, or when timeIdx is
// larger than nTimeStep.
//
// NOTE(review): step 2 fills all three components without the box folding
// used in step 1 - presumably the trajectory is already unwrapped; confirm.
void qtCorrelationHist(const std::vector<double*>& atomPos1, const std::vector<double*>& atomPos2, unsigned nTimeStep, unsigned timeIdx,
    double *box, unsigned boxSize, double *q, unsigned qSize, std::vector<double>& fq)
{
  if(atomPos1.size() != atomPos2.size())
  {
    throw std::runtime_error ("intScat: atomPos vector size error");
  }
  if(boxSize!=nTimeStep*3)
  {
    throw std::runtime_error ("intScat: box size error");
  }
  // The estimate below reads the first pair unconditionally.
  if(atomPos1.empty())
  {
    throw std::runtime_error ("intScat: empty atom list");
  }
  // nTimeStep-timeIdx is unsigned and would wrap around.
  if(timeIdx>nTimeStep)
  {
    throw std::runtime_error ("intScat: timeIdx must not exceed nTimeStep");
  }

  //quickly estimate the mean distance of x
  double sum(0.);
  const unsigned loop = std::min(nTimeStep-timeIdx, unsigned(10000));

  const double *pTest1 = atomPos1[0];
  const double *pTest2 = atomPos2[0] + timeIdx*3;

  for(unsigned i=0 ; i < loop; i++ )
  {
    const double boxX = *(box+i*3);
    double x = (*(pTest1+i*3))-(*(pTest2+i*3));
    if(ncabs(x)*2 > boxX)
    {
      x = x<0. ? x + boxX : x - boxX;
    }
    sum += ncabs(x);
  }
  const double mean = sum /loop;

  if(ncabs(mean)<1e-15 && timeIdx==0)
  {
    // No displacement at zero lag: the correlation is exactly one.
    fq.assign(qSize, 1.);
    return;
  }
  fq.assign(qSize, 0.);

  const double multiplication = 10.;
  NumpyHist1D hist1d( 9999, -mean*multiplication , mean*multiplication );

  for(unsigned n=0;n<atomPos1.size();n++)
  {
    const double *pos1 = atomPos1[n];
    const double *pos2 = atomPos2[n] + timeIdx*3;
    const double *end = pos1 + (nTimeStep-timeIdx)*3;

    // loop over time steps and components
    for( ; pos1<end;  )
    {
      hist1d.fill_unguard((*(pos1++))-(*(pos2++)));
    }
  }

  const double histmin = hist1d.getXMin();
  const double histmax = hist1d.getXMax();
  const double histNumBin = hist1d.getNBins();
  const double histSpacing = (histmax-histmin)/histNumBin;
  // Bind by reference: the bin contents are only read.
  const auto& cont = hist1d.getRaw();

  //calculate the mean from the histogram
  for(unsigned iq=0;iq<qSize;iq++)
  {
    const double aQ = *(q+iq);
    double sumCosx (0.), weight(0.);
    for(unsigned ihist=0;ihist<histNumBin;ihist++)
    {
      const double num = cont[ihist];
      if(num)
      {
        weight += num;
        // cosine evaluated at the bin centre
        sumCosx += cosf(aQ* (histmin + histSpacing*(0.5+ihist))) * num;
      }
    }
    if(weight)
      fq[iq] = sumCosx/weight;
    else
      printf("empty histogram detected at time idx %u, mean is %e \n", timeIdx, mean);
  }
}

// Mean of cos(q*dx) over all pairs and time origins, using only the x
// component of the displacement.
//
// For pair a and time origin t, dx = atomPos1[a][3*t] - atomPos2[a][3*(t+timeIdx)],
// folded by the box x length of step t when |dx|*2 exceeds it. The cosine
// values are accumulated per q in a StableSum and finally scaled by
// 1/(atomPos1.size()*(nTimeStep-timeIdx)). fq is resized to qSize.
//
// Throws std::runtime_error when the two position lists differ in length or
// when boxSize != nTimeStep*3.
void qtCorrelation1d(const std::vector<double*>& atomPos1, const std::vector<double*>& atomPos2, unsigned nTimeStep, unsigned timeIdx,
  double *box, unsigned boxSize, double *q, unsigned qSize, std::vector<double>& fq)
{
  if(atomPos1.size() != atomPos2.size())
    throw std::runtime_error ("intScat: atomPos vector size error");
  if(boxSize!=nTimeStep*3)
    throw std::runtime_error ("intScat: box size error");

  fq.assign(qSize, 0.);

  // one accumulator per q value
  std::vector<StableSum> acc(qSize);

  for(unsigned a=0; a<atomPos1.size(); ++a)
  {
    const double *r1 = atomPos1[a];
    const double *r2 = atomPos2[a] + timeIdx*3;
    const double *const stop = r1 + (nTimeStep-timeIdx)*3;
    //fixme assuming x,y,z are always smaller than 10
    const double *cell = box;

    // every 3 numbers belong to 1 time step; only the first one (x) is used
    for( ; r1<stop; r1+=3, r2+=3, cell+=3)
    {
      double dx = *r1 - *r2;
      if(ncabs(dx)*2 > *cell)
      {
        if(dx<0.)
          dx = dx + *cell;
        else
          dx = dx - *cell;
      }

      //fixme: use cosf?
      for(unsigned iq=0; iq<qSize; ++iq)
        acc[iq].add(cos(dx*q[iq]));
    }
  }

  const double normFact = 1./( atomPos1.size()*(nTimeStep-timeIdx) );
  for(unsigned iq=0; iq<qSize; ++iq)
    fq[iq] = acc[iq].sum()*normFact;
}


void qtCorrelation(const std::vector<double*>& atomPos1, const std::vector<double*>& atomPos2, unsigned nTimeStep, unsigned timeIdx,
  double *box, unsigned boxSize, double *q, unsigned qSize, std::vector<double>& fq)
{
  if(atomPos1.size() != atomPos2.size())
  {
    throw std::runtime_error ("intScat: atomPos vector size error");
  }
  if(boxSize!=nTimeStep*3)
  {
    throw std::runtime_error ("intScat: box size error");
  }
  fq.resize(0);
  fq.resize(qSize,0.);

  LebedevQuadrature lquad;
  double qdirX(0.),qdirY(0.),qdirZ(1.);

  // std::vector<StableSum> ssum(qSize);

  double *pFq = fq.data();
  // #pragma omp parallel for reduction(+:pFq[:atomPos1.size()])
  for(unsigned n=0;n<atomPos1.size();n++)
  {
    double *pos1 = atomPos1[n];
    double *pos2 = atomPos2[n] + timeIdx*3;
    double *end = pos1 + (nTimeStep-timeIdx)*3;
    //very 3 numbers in pos belongs to 1 time step
    //fixme assuming x,y,z are always smaller than 10
    double *boxpos = box ;

    // loop over time steps
    for( ; pos1<end; boxpos +=3 )
    {
      double x = (*(pos1++))-(*(pos2++));
      double y = (*(pos1++))-(*(pos2++));
      double z = (*(pos1++))-(*(pos2++));

      if(ncabs(x)*2 > *boxpos )
      {
        x = x<0. ? x + *boxpos : x - *boxpos;
      }
      if(ncabs(y)*2 > *(boxpos+1) )
      {
        y = y<0. ? y + *(boxpos+1) : y - *(boxpos+1);
      }
      if(ncabs(z)*2 > *(boxpos+2))
      {
        z = z<0. ? z + *(boxpos+2) : z - *(boxpos+2);
      }

      // //get mean
      for(auto it=lquad.begin();it!=lquad.end();++it)
      {
        double dotprt = x*(*it)[0] + y*(*it)[1] +z*(*it)[2];
        // double sum(0.);
        for(unsigned iq=0; iq<qSize;iq++)
        {
          //fixeme: use cosf?
          *(pFq+iq)  += cosf(dotprt*(*(q+iq)))*(*it)[3];
          // ssum[iq].add(cos(dotprt*(*(q+iq)))*(*it)[3]);
        }
      }

      // //get mean

        // double dotprt = x;
        // // double sum(0.);
        // for(unsigned iq=0; iq<qSize;iq++)
        // {
        //   //fixeme: use cosf?
        //   *(pFq+iq)  += cosf(dotprt*(*(q+iq)));
        //   // ssum[iq].add(cos(dotprt*(*(q+iq)))*(*it)[3]);
        // }


      // // isotropic random number in fact does the same trick. TBI
      // randisotropic(qdirX,qdirY,qdirZ);
      // double dotprt = x*qdirX + y*qdirY +z*qdirZ;
      // //loop over q
      // for(unsigned iq=0; iq<qNum;iq++)
      // {
      //   *(pFq+iq) += cos(dotprt*((iq+1)*qSpacing) );
      // }
    }
  }
  double normFact = 1./( atomPos1.size()*(nTimeStep-timeIdx) );
  for (auto &v : fq)
    v *= normFact;
  // for(unsigned i=0;i<qSize;i++)
  // {
  //   fq[i]=ssum[i].sum()*normFact;
  // }
}
//
// double fsqt(double* atomPos1, unsigned atomPos1Size,
//   double* atomPos2, unsigned atomPos2Size,
//   double *box, unsigned boxSize, double q,
//   unsigned timeIdx, unsigned nAtom, unsigned nTimeStep)
// {
//   if(atomPos1Size!=nAtom*nTimeStep*3 || atomPos1Size!=atomPos2Size)
//   {
//     throw std::runtime_error ("intScat: vector size error");
//   }
//
//   LebedevQuadrature lquad;
//   double tot (0.);
//   double qdirX(0.),qdirY(0.),qdirZ(1.);
//
//
//   for(unsigned n=0;n<nAtom;n++)
//   {
//     double *pos1 = atomPos1 + n*(nTimeStep*3);
//     double *pos2 = atomPos2 + n*(nTimeStep*3)+timeIdx*3;
//     double *end = pos1 + (nTimeStep-timeIdx)*3;
//     //very 3 numbers in pos belongs to 1 time step
//     //fixme assuming x,y,z are always smaller than 10
//     double *boxpos = box ;
//
//     // printf("%g\n", threshold);
//     // abort();
//     for( ; pos1<end; boxpos +=3 )
//     {
//       double x = (*(pos1++))-(*(pos2++));
//       double y = (*(pos1++))-(*(pos2++));
//       double z = (*(pos1++))-(*(pos2++));
//
//       if(ncabs(x)*2 > *boxpos )
//       {
//         x = x<0. ? x + *boxpos : x - *boxpos;
//       }
//       if(ncabs(y)*2 > *(boxpos+1) )
//       {
//         y = y<0. ? y + *(boxpos+1) : y - *(boxpos+1);
//       }
//       if(ncabs(z)*2 > *(boxpos+2))
//       {
//         z = z<0. ? z + *(boxpos+2) : z - *(boxpos+2);
//       }
//
//       // //get mean
//       // for(auto it=lquad.begin();it!=lquad.end();++it)
//       // {
//       //   double dotprt = x*(*it)[0]*q + y*(*it)[1]*q +z*(*it)[2]*q;
//       //   tot += cos(dotprt)*(*it)[3];
//       // }
//
//       // isotropic random number in fact does the same trick. TBI
//       randisotropic(qdirX,qdirY,qdirZ);
//       double dotprt = x*qdirX*q + y*qdirY*q +z*qdirZ*q;
//       tot += cos(dotprt);
//
//     }
//   }
//   return tot/(nAtom*(nTimeStep-timeIdx));
// }


//velocity autocorrelation function
double vaf(double* vel, unsigned velSize,
  unsigned idx, unsigned nAtom, unsigned nTimeStep)
{
  // Pure forwarding layer: all arguments are handed unchanged to vafRl and
  // its result is returned as is.
  const double result = vafRl(vel, velSize, idx, nAtom, nTimeStep);
  return result;
}

//velocity autocorrelation function
// Fills returnPointer[lag] with vaf(vel, velSize, lag, nAtom, nTimeStep) for
// lag = 0 .. retsize-1, distributing the lags over OpenMP threads.
//
// Throws std::runtime_error when retsize is greater than half of nTimeStep.
void vaf_vec(double* vel, unsigned velSize,
             double* returnPointer, unsigned retsize,
             unsigned nAtom, unsigned nTimeStep)
{
  const bool tooManyLags = retsize>nTimeStep*0.5;
  if(tooManyLags)
  {
    throw std::runtime_error ("vaf_vec: returned vector size should not be greater than half of nTimeStep");
  }

  #pragma omp parallel for default(none) shared(vel,velSize,returnPointer,retsize,nAtom,nTimeStep)
  for(unsigned lag=0; lag<retsize; ++lag)
  {
    // thread configuration is reported once, by the iteration handling lag 0
    if(lag==0)
    {
      printf("vaf_vec: max thread %d, Num thread %d \n", omp_get_max_threads(), omp_get_num_threads() );
    }
    *(returnPointer+lag) = vaf(vel, velSize, lag, nAtom, nTimeStep );
  }
}

void gtr_vec(double* traj, unsigned trajSize,
            unsigned *atomicPair, unsigned atomicPairSize,
            double *box, unsigned boxSize,
            double* returnPointer, unsigned retsize,
            unsigned numWavlengtBin, double wavelengthCut,
            double dltT, unsigned nAtom, unsigned nTimeStep)
{
  if(trajSize!=nAtom*nTimeStep*3)
  {
    throw std::runtime_error ("incoherentIntermediate: position vector size error");
  }

  std::vector<double*> trajOfAtom1;
  std::vector<double*> trajOfAtom2;
  for(unsigned i=0; i<atomicPairSize/2;i++)
  {
    trajOfAtom1.push_back(traj+ (*(atomicPair+i*2))*nTimeStep*3);
    trajOfAtom2.push_back(traj+ (*(atomicPair+i*2+1))*nTimeStep*3);
  }

  //calculate distance distribution at fixed time
  #pragma omp parallel for default(none) shared(trajOfAtom1, trajOfAtom2, box, boxSize, dltT, nAtom, nTimeStep, returnPointer, numWavlengtBin, wavelengthCut)
  for(unsigned i=0; i<nTimeStep/2;i++ ) //fixme i=0 is wasted
  {
    if(i==1)
      printf("grt_vec: max thread %d, Num thread %d \n", omp_get_max_threads(), omp_get_num_threads() );

    //space 0 to 15 Aa, time
    NumpyHist1D hist1d(numWavlengtBin, 0., wavelengthCut);

    pairCorrelation(hist1d, trajOfAtom1, trajOfAtom2, box, boxSize,  nTimeStep, i);
    //Convert the histogram for the distances at a fixed time  to a pointwise distribution density.
    //For self-correlation, density is zero for r=0, except when also t=0, density(0,0) is a delta function.
    //The delta function is not treated at the moment.
    //For collected-correlation, density at the averaged pair distance.
    double *pt = returnPointer + numWavlengtBin*i;
    *(pt++) = 0.;
    std::vector<double> density;
    density.reserve(numWavlengtBin);
    density.push_back(0.);

    double rSpacing = wavelengthCut/numWavlengtBin;
    double r = rSpacing;
    for(auto it = hist1d.getRaw().begin(); it != hist1d.getRaw().end()-1;)
    {
      density.push_back(0.5*(*it+ (*(++it))) );
      r+= rSpacing;
    }

    std::copy(density.begin(), density.end(), returnPointer + numWavlengtBin*i);

  }
}


// Fills hist with pair distances at a fixed lag.
//
// For pair a and time origin t the displacement r1(t) - r2(t+timeIdx) is
// folded component-wise by the box lengths of step t (when |d_i|*2 exceeds
// the length) and its Euclidean norm is passed to hist.fill_unguard. All
// time origins t in [0, nTimeStep-timeIdx) are used. boxSize is not read.
//
// Throws std::runtime_error when the two position lists differ in length.
//
// A disabled diagnostic printed, every 100 lags, the fraction of entries in
// the under/overflow (hist.getUnderflow, hist.getOverflow, hist.getIntegral).
EXPORT_SYMBOL void pairCorrelation(NumpyHist1D& hist,
             const std::vector<double*>& atomPos1, const std::vector<double*>& atomPos2,
              double *box, unsigned boxSize,
              unsigned nTimeStep, unsigned timeIdx)
{
  if(atomPos1.size()!=atomPos2.size())
    throw std::runtime_error ("pairCorrelation: vector size error");

  // Fold a coordinate difference by one box length when it exceeds half of it.
  auto fold = [](double d, double len) -> double
  {
    if(ncabs(d)*2 > len)
      d = d<0. ? d + len : d - len;
    return d;
  };

  for(unsigned a=0; a<atomPos1.size(); ++a)
  {
    const double *r1 = atomPos1[a];
    const double *r2 = atomPos2[a] + timeIdx*3;
    const double *const stop = r1 + (nTimeStep-timeIdx)*3;
    const double *cell = box;

    // three values (x, y, z) per time step in all three arrays
    for( ; r1<stop; r1+=3, r2+=3, cell+=3)
    {
      const double dx = fold(r1[0]-r2[0], cell[0]);
      const double dy = fold(r1[1]-r2[1], cell[1]);
      const double dz = fold(r1[2]-r2[2], cell[2]);

      hist.fill_unguard(sqrt(dx*dx+dy*dy+dz*dz));
    }
  }
}
