#include "StructureFactor.hh"
#include <numeric>
#include <omp.h>
#include <cmath>
#include <algorithm>
#include "NumpyHist1D.hh"
#include <time.h>       /* time */
#include <algorithm>    // std::min

void structFact(double* traj, unsigned trajSize,
  unsigned *atomicPair, unsigned atomicPairSize,
  double *box, unsigned boxSize,
  double *q, unsigned qSize,
  unsigned *tWindow, unsigned TimStepSize,
  unsigned nAtom, unsigned nTimeStep, double* FsTQ, unsigned FsTQSize)
{
  std::vector<unsigned> cache; //ctypes object persistence!! tWindow disappear from time to tim
  cache.reserve(TimStepSize);
  for(unsigned i=0;i<TimStepSize; i++)
  {
    if(*(tWindow) > nTimeStep)
      throw std::runtime_error ("intermediate: elements of timStep must be smaller than number of nTimeStep");
    cache.push_back(*(tWindow++));
  }
  unsigned *timeStep=cache.data();

  if(trajSize!=nAtom*nTimeStep*3)
  {
    throw std::runtime_error ("intermediate: position vector size error");
  }

  if(FsTQSize!=qSize*TimStepSize)
  {
    throw std::runtime_error ("intermediate: return vector size error");
  }

  std::vector<double*> trajOfAtom1;
  std::vector<double*> trajOfAtom2;
  for(unsigned i=0; i<atomicPairSize/2;i++)
  {
    trajOfAtom1.push_back(traj+ (*(atomicPair+i*2))*nTimeStep*3);
    trajOfAtom2.push_back(traj+ (*(atomicPair+i*2+1))*nTimeStep*3);
  }

  #pragma omp parallel for default(none)  shared(timeStep, TimStepSize, trajOfAtom1, trajOfAtom2, nTimeStep, box, boxSize, q, qSize, FsTQ)
  for(unsigned i=0;i<TimStepSize;i++)
  {
    std::vector<double> fq;
    printf("%p , TimStepSize %d, time step %d, i %d\n",timeStep, TimStepSize, timeStep[i], i);
    // structFactHistT(trajOfAtom1, trajOfAtom2,  nTimeStep, *(timeStep+i),
    //   box, boxSize, q, qSize, fq);
    structFactRaw(trajOfAtom1, trajOfAtom2,  nTimeStep, *(timeStep+i),
      box, boxSize, q, qSize, fq);
    std::copy(fq.begin(), fq.end(), FsTQ + qSize*i );
  }

}

void structFactHistT(const std::vector<double*>& atomPos1, const std::vector<double*>& atomPos2, unsigned nTimeStep, unsigned timeIdx,
    double *box, unsigned boxSize, double *q, unsigned qSize, std::vector<double>& fq)
{
  if(atomPos1.size() != atomPos2.size() || atomPos1.empty())
  {
    throw std::runtime_error ("intScat: atomPos vector size error");
  }
  if(boxSize!=nTimeStep*3)
  {
    throw std::runtime_error ("intScat: box size error");
  }

  bool isIncoherent = true;
  if ( std::find(atomPos1.begin(), atomPos1.end(), atomPos2[0]) != atomPos2.end() )
    isIncoherent = false;

  double maxBoxL = *std::max_element( box, box+boxSize );


  NumpyHist1D hist1d( 999999, -maxBoxL,maxBoxL);
  // printf("window size %d, mean %g, max %g, hist max %g\n", timeIdx, mean, max, histrange);

  for(unsigned n=0;n<atomPos1.size();n++)
  {
    double *pos1 = atomPos1[n];
    double *pos2 = atomPos2[n] + timeIdx*3;
    double *end = pos1 + (nTimeStep-timeIdx)*3;

    // loop over time steps
    for( ; pos1<end;  )
    {
      hist1d.fill_unguard(*(pos1++));
      hist1d.fill_unguard(*(pos2++));
    }
  }

  double histmin = hist1d.getXMin();
  double histmax = hist1d.getXMax();
  double histNumBin = hist1d.getNBins();
  double histSpacing = (histmax-histmin)/histNumBin;
  auto histRaw = hist1d.getRaw();

  //calcualte mean from the histogram
  fq.resize(0);
  fq.resize(qSize,0);
  double *pFq = fq.data();
  // #pragma omp parallel for reduction(+:pFq[:fq.size()])
  for(unsigned iq=0;iq<qSize;iq++)
  {
    double aQ = *(q+iq);
    double sumCosx (0.), sumSinx (0.), weight(0.);
    for(unsigned ihist=0;ihist<histNumBin;ihist++)
    {
      double num = histRaw[ihist];
      if(num)
      {
        weight += num;
        sumCosx += cosf(aQ* (histmin + histSpacing*(0.5+ihist))) * num;
        sumSinx += sinf(aQ* (histmin + histSpacing*(0.5+ihist))) * num;
      }
    }
    if(weight)
      *(pFq+iq) = (sumCosx*sumCosx + sumSinx*sumSinx) /weight;
    else
      printf("!!!empty histogram detected at time idx %d\n", timeIdx);
  }

}

void structFactRaw(const std::vector<double*>& atomPos1, const std::vector<double*>& atomPos2, unsigned nTimeStep, unsigned timeIdx,
    double *box, unsigned boxSize, double *q, unsigned qSize, std::vector<double>& fq)
{
  if(atomPos1.size() != atomPos2.size() || atomPos1.empty())
  {
    throw std::runtime_error ("intScat: atomPos vector size error");
  }
  if(boxSize!=nTimeStep*3)
  {
    throw std::runtime_error ("intScat: box size error");
  }

  bool isIncoherent = true;
  if ( std::find(atomPos1.begin(), atomPos1.end(), atomPos2[0]) != atomPos2.end() )
    isIncoherent = false;

  //estimate avg box size
  double minBoxSize (1e200);
  unsigned boxloop = std::min(unsigned(1000),boxSize/3);
  for(unsigned i=0 ; i<boxloop;i++)
  {
    minBoxSize = std::min(minBoxSize,box[i]);
  }

  //quickly estimate the mean distance of x
  double sum(0.), max(-1000.);
  unsigned loop = std::min(nTimeStep-timeIdx, unsigned(10000));

  double *pTest1 = atomPos1[0];
  double *pTest2 = atomPos2[0] + timeIdx*3;

  for(unsigned i=0 ; i < loop; i++ )
  {
    double x = (*(pTest1+i*3))-(*(pTest2+i*3));
    if(ncabs(x)*2 > *(box+i*3))
    {
      x = x<0. ? x + *(box+i*3) : x - *(box+i*3);
    }
    x = ncabs(x);
    max = std::max(x,max);
    sum += x;
  }
  const double mean = sum /loop;
  fq.resize(0);
  fq.resize(qSize,0.);

  if(ncabs(mean)<1e-15 && timeIdx==0)
  {
    for(unsigned iq=0;iq<qSize;iq++)
    {
        fq[iq] = 1.;
    }
    return;
  }

  double scattering_length = 1.;
  //#pragma omp parallel for firstprivate(k_vectors)
  for (unsigned k_index = 0; k_index < qSize; ++k_index)
  {
    // for (size_t frame_number = 0; frame_number < nTimeStep; ++frame_number)
    for (unsigned frame_number = 0; frame_number < nTimeStep; ++frame_number)
    {
      double sum_cos_term = 0.0;
      double sum_sin_term = 0.0;
      for(unsigned n=0;n<atomPos1.size();n++)
      {
        double *pos1 = atomPos1[n]+frame_number*3;
        double *pos2 = atomPos2[n] + (timeIdx+frame_number)*3;
        double *end = pos1 + (nTimeStep-timeIdx)*3;

        // for(unsigned i=0;i<3;i++)
        {
          sum_cos_term += scattering_length * cosf(q[k_index] * (*pos1));
          sum_sin_term += scattering_length* sinf(q[k_index] * (*pos1));

          sum_cos_term += scattering_length * cosf(q[k_index] * (*pos2));
          sum_sin_term += scattering_length* sinf(q[k_index]* (*pos2));

          pos1++; pos2++;end++;
        }

      }
      double frameContribution =  (sum_cos_term * sum_cos_term + sum_sin_term * sum_sin_term);
      // printf("frameContribution k %d, frame %d, cont %g, cos %g, sin %g\n",k_index,frame_number, frameContribution, sum_cos_term, sum_sin_term);
      fq[k_index] += frameContribution;
    }
    double normFact = 1./( 2*atomPos1.size()*(nTimeStep-timeIdx) );
    fq[k_index] *= normFact;
  }
}
