#include "StructureFactor.hh"

#include <omp.h>
#include <time.h>       /* time */

#include <algorithm>    // std::min
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <iostream>
#include <numeric>
#include <stdexcept>
#include <vector>

#include "NumpyHist1D.hh"

// Driver: runs structFactRaw() once for every requested time lag and stores
// the qSize results of lag i in FsTQ[qSize*i .. qSize*i + qSize).
//
// Parameters
//   traj, trajSize       flat coordinate buffer; atom i is read starting at
//                        traj + i*nTimeStep*3. trajSize must equal
//                        nAtom*nTimeStep*3.
//   atomicPair(+Size)    not used by the current implementation.
//   box, boxSize         forwarded unchanged to structFactRaw().
//   q, qSize             forwarded unchanged to structFactRaw(); qSize is also
//                        the length of one output row.
//   tWindow, TimStepSize list of time lags; every entry must be strictly
//                        smaller than nTimeStep.
//   nAtom, nTimeStep     number of atoms / frames described by traj.
//   FsTQ, FsTQSize       output buffer; FsTQSize must equal qSize*TimStepSize.
//
// Throws std::runtime_error when a lag is out of range or a buffer size does
// not match. Also prints one diagnostic line per lag to stdout.
void structFact(double* traj, unsigned trajSize,
  unsigned *atomicPair, unsigned atomicPairSize,
  double *box, unsigned boxSize,
  double *q, unsigned qSize,
  unsigned *tWindow, unsigned TimStepSize,
  unsigned nAtom, unsigned nTimeStep, double* FsTQ, unsigned FsTQSize)
{
  (void)atomicPair;
  (void)atomicPairSize;

  // Take a private copy of the lag list. The original author noted that the
  // caller-owned tWindow buffer (passed through ctypes) did not reliably
  // persist, so nothing below reads tWindow again.
  std::vector<unsigned> cache;
  cache.reserve(TimStepSize);
  for(unsigned i=0; i<TimStepSize; i++)
  {
    // A lag equal to nTimeStep leaves zero frames to normalise by in
    // structFactRaw(), so it has to be rejected as well ("smaller than").
    if(tWindow[i] >= nTimeStep)
      throw std::runtime_error ("intermediate: elements of timStep must be smaller than number of nTimeStep");
    cache.push_back(tWindow[i]);
  }

  // Size products are formed in size_t so they cannot wrap in 32 bits.
  const std::size_t framesTimesDim = static_cast<std::size_t>(nTimeStep)*3;
  if(static_cast<std::size_t>(trajSize) != static_cast<std::size_t>(nAtom)*framesTimesDim)
  {
    throw std::runtime_error ("intermediate: position vector size error");
  }

  if(static_cast<std::size_t>(FsTQSize) != static_cast<std::size_t>(qSize)*TimStepSize)
  {
    throw std::runtime_error ("intermediate: return vector size error");
  }

  // One pointer per atom to the start of that atom's coordinates.
  std::vector<double*> trajOfAtom1;
  trajOfAtom1.reserve(nAtom);
  for(unsigned i=0; i<nAtom; i++)
  {
    trajOfAtom1.push_back(traj + i*framesTimesDim);
  }
  // The second atom set is passed empty; structFactRaw() does not read it.
  std::vector<double*> trajOfAtom2;

  std::vector<double> fq;  // reused for every lag; the callee resizes it
  for(unsigned i=0; i<TimStepSize; i++)
  {
    printf("%p , TimStepSize %u, time step %u, i %u\n",
      static_cast<void*>(cache.data()), TimStepSize, cache[i], i);
    structFactRaw(trajOfAtom1, trajOfAtom2, nTimeStep, cache[i],
      box, boxSize, q, qSize, fq);
    std::copy(fq.begin(), fq.end(), FsTQ + static_cast<std::size_t>(qSize)*i);
  }
}

// Histogram-based estimate of a time-lagged structure factor.
//
// For every frame f in [0, nTimeStep - timeIdx) all Cartesian components of
// all atoms are binned into two 1-D histograms spanning [-minBoxL, minBoxL]
// (minBoxL = smallest entry of box): hist1 takes atomPos1 at frame f, hist2
// takes atomPos2 at frame f + timeIdx. For each q[iq] the count-weighted sums
// of cos(q*r) and sin(q*r) over the bin centres r are formed for both
// histograms and
//     2*(C1*C2 + S1*S2) / (weight * nTimeStep)
// is added to fq[iq], where weight is the total count of both histograms.
//
// Parameters
//   atomPos1, atomPos2  per-atom pointers to consecutive (x,y,z) triples, one
//                       triple per frame; both vectors must be non-empty and
//                       of equal length.
//   nTimeStep           number of frames behind every pointer.
//   timeIdx             frame lag applied to atomPos2; must be < nTimeStep.
//   box, boxSize        box lengths; boxSize must equal nTimeStep*3.
//   q, qSize            wave numbers to evaluate.
//   fq                  output; resized to qSize and overwritten.
//
// Throws std::runtime_error on inconsistent sizes or an out-of-range lag.
void structFactHistT(const std::vector<double*>& atomPos1, const std::vector<double*>& atomPos2, unsigned nTimeStep, unsigned timeIdx,
    double *box, unsigned boxSize, double *q, unsigned qSize, std::vector<double>& fq)
{
  if(atomPos1.size() != atomPos2.size() || atomPos1.empty())
  {
    throw std::runtime_error ("intScat: atomPos vector size error");
  }
  if(boxSize!=nTimeStep*3)
  {
    throw std::runtime_error ("intScat: box size error");
  }
  // The lagged read below touches frame (f + timeIdx); without this guard the
  // frame count underflows and the reads leave the trajectory buffer.
  if(timeIdx >= nTimeStep)
  {
    throw std::runtime_error ("intScat: timeIdx must be smaller than nTimeStep");
  }

  double minBoxL = *std::min_element( box, box+boxSize );

  unsigned histNumBin = 9999;
  double histmin = -minBoxL;
  double histmax = minBoxL;
  double histSpacing = (histmax-histmin)/histNumBin;

  NumpyHist1D hist1( histNumBin, histmin, histmax);
  NumpyHist1D hist2( histNumBin, histmin, histmax);

  fq.assign(qSize, 0.);

  // Only frames whose lagged partner (f + timeIdx) exists can be used.
  const unsigned usableFrames = nTimeStep - timeIdx;
  for (unsigned frame_number = 0; frame_number < usableFrames; ++frame_number)
  {
    hist1.reset();
    hist2.reset();
    for(std::size_t n=0; n<atomPos1.size(); n++)
    {
      double *pos1 = atomPos1[n] + frame_number*3;
      double *pos2 = atomPos2[n] + (timeIdx+frame_number)*3;

      // x, y and z all go into the same 1-D histogram.
      for(unsigned d=0; d<3; d++)
      {
        hist1.fill_unguard(pos1[d]);
        hist2.fill_unguard(pos2[d]);
      }
    }

    auto histRaw = hist1.getRaw();
    auto histRaw2 = hist2.getRaw();  // second histogram (was read from hist1)

    for(unsigned iq=0; iq<qSize; iq++)
    {
      const double aQ = q[iq];
      double sumCosx (0.), sumSinx (0.), sumCosx2 (0.), sumSinx2 (0.), weight(0.);
      for(unsigned ihist=0; ihist<histNumBin; ihist++)
      {
        const double num1 = histRaw[ihist];
        const double num2 = histRaw2[ihist];
        if(!num1 && !num2)
          continue;  // empty bin in both histograms: nothing to add

        // Bin centre; the phase is evaluated once in double precision and
        // shared by both histograms.
        const double r = (histmin + histSpacing*(0.5+ihist));
        const double cosQr = std::cos(aQ*r);
        const double sinQr = std::sin(aQ*r);

        if(num1)
        {
          weight += num1;
          sumCosx += cosQr * num1;
          sumSinx += sinQr * num1;
        }
        if(num2)
        {
          weight += num2;
          sumCosx2 += cosQr * num2;
          sumSinx2 += sinQr * num2;
        }
      }
      // NOTE(review): the divisor uses nTimeStep although only
      // (nTimeStep - timeIdx) frames are accumulated; kept as written,
      // confirm the intended normalisation.
      fq[iq] += 2*(sumCosx * sumCosx2 + sumSinx * sumSinx2)/(weight*nTimeStep);
    }
  }
}

// Appends to k_vectors every triple of non-negative integers (a, b, c) with
// a*a + b*b + c*c == k_squared. The search enumerates sorted triples
// a <= b <= c and, for each hit, appends the sorted triple followed by its
// distinct reorderings. Existing contents of k_vectors are left in place.
void generate_k_vectors(unsigned int const & k_squared, std::vector< std::vector< unsigned int > > & k_vectors)
{
  // Small helper so the permutation list below reads as a table.
  auto emit = [&k_vectors](unsigned int x, unsigned int y, unsigned int z)
  {
    k_vectors.push_back( {x, y, z} );
  };

  // Smallest component: 3*a*a <= k_squared.
  const unsigned int firstLimit = static_cast< unsigned int > (sqrt(k_squared/3.0));
  for (unsigned int a = 0; a <= firstLimit; ++a)
  {
    // Middle component: 2*b*b <= k_squared - a*a.
    const unsigned int secondLimit = static_cast< unsigned int > (sqrt((k_squared - a*a)/2.0));
    for (unsigned int b = a; b <= secondLimit; ++b)
    {
      const unsigned int remainder = k_squared - a*a - b*b;
      const unsigned int c = static_cast< unsigned int > (sqrt(static_cast< double > (remainder)));

      // Keep the candidate only when the remainder is a perfect square.
      if (remainder != c*c)
      {
        continue;
      }

      emit(a, b, c);

      const bool firstPairEqual = (a == b);
      const bool lastPairEqual = (b == c);
      if (firstPairEqual && lastPairEqual)
      {
        // All three equal: a single arrangement.
        continue;
      }

      if (firstPairEqual || lastPairEqual)
      {
        // Exactly two equal: the two cyclic rotations complete the set.
        emit(b, c, a);
        emit(c, a, b);
      }
      else
      {
        // All distinct: the five remaining orderings.
        emit(a, c, b);
        emit(b, a, c);
        emit(b, c, a);
        emit(c, a, b);
        emit(c, b, a);
      }
    }
  }
}


// Direct (non-histogram) structure-factor sum on integer wave-vector shells.
//
// Shell k_index (0-based) consists of all non-negative integer triples
// (n0, n1, n2) with n0^2 + n1^2 + n2^2 == (k_index+1)^2, as produced by
// generate_k_vectors(). For every frame and every triple of the shell the
// phase of an atom at (x, y, z) is
//     2*pi*( x*n0/Lx + y*n1/Ly + z*n2/Lz )
// with (Lx, Ly, Lz) taken from box at that frame, and
//     ( sum_atoms cos(phase) )^2 + ( sum_atoms sin(phase) )^2
// is accumulated with unit weight per atom. The total is then scaled by
//     1 / ( atomPos1.size() * (nTimeStep - timeIdx) * shellSize ).
//
// Parameters
//   atomPos1    per-atom pointers to consecutive (x,y,z) triples, one triple
//               per frame; must be non-empty.
//   atomPos2    not read by this function.
//   nTimeStep   number of frames behind every pointer.
//   timeIdx     only enters the normalisation; must be < nTimeStep.
//   box/boxSize three box lengths per frame; boxSize must equal nTimeStep*3.
//   q           not read; only qSize (number of shells) is used.
//   fq          output; resized to qSize and overwritten.
//
// Throws std::runtime_error on inconsistent sizes or an out-of-range lag.
// Prints the shell sizes and every result to stdout as diagnostics.
void structFactRaw(const std::vector<double*>& atomPos1, const std::vector<double*>& atomPos2, unsigned nTimeStep, unsigned timeIdx,
    double *box, unsigned boxSize, double *q, unsigned qSize, std::vector<double>& fq)
{
  (void)atomPos2;
  (void)q;

  if( atomPos1.empty())
  {
    throw std::runtime_error ("intScat: atomPos vector size error");
  }
  if(boxSize!=nTimeStep*3)
  {
    throw std::runtime_error ("intScat: box size error");
  }
  // (nTimeStep - timeIdx) is the divisor of the normalisation below; reject
  // lags that would make it zero or wrap around.
  if(timeIdx >= nTimeStep)
  {
    throw std::runtime_error ("intScat: timeIdx must be smaller than nTimeStep");
  }

  fq.assign(qSize, 0.);

  // One list of integer wave vectors per shell.
  std::vector< std::vector< std::vector< unsigned int > > > k_vectors(qSize);
  for(unsigned i=0; i<qSize; i++)
  {
    generate_k_vectors((i+1)*(i+1), k_vectors[i]);
  }

  for(unsigned i=0; i<k_vectors.size(); i++)
  {
    std::cout << i << " " << k_vectors[i].size() <<std::endl;
  }

  // NOTE(review): the sum below runs over all nTimeStep frames while the
  // normalisation divides by (nTimeStep - timeIdx); kept as written, confirm
  // which of the two is intended.
  const double normFact = 1./( atomPos1.size()*(nTimeStep-timeIdx) );

  // k_vectors is only read inside the loop, so it is shared between threads
  // instead of being deep-copied per thread. Each iteration writes a
  // different element of fq.
  #pragma omp parallel for
  for (unsigned k_index = 0; k_index < qSize; ++k_index)
  {
    const std::vector< std::vector< unsigned int > >& shell = k_vectors[k_index];
    const double scale = 1./shell.size();
    double shellSum = 0.0;  // thread-local accumulator for this shell

    for (unsigned frame_number = 0; frame_number < nTimeStep; ++frame_number)
    {
      // 2*pi/L per axis for this frame, hoisted out of the atom loop.
      const double *pbox = box + frame_number*3;
      const double twoPiOverL[3] = { 2*M_PI/pbox[0], 2*M_PI/pbox[1], 2*M_PI/pbox[2] };

      for (const std::vector< unsigned int >& kvec : shell)
      {
        double sum_cos_term = 0.0;
        double sum_sin_term = 0.0;

        for (const double *atomTraj : atomPos1)
        {
          const double *pos = atomTraj + frame_number*3;
          const double dotprd = twoPiOverL[0]*pos[0]*kvec[0] +
              twoPiOverL[1]*pos[1]*kvec[1] +
              twoPiOverL[2]*pos[2]*kvec[2];

          // Double-precision trig: the float versions would round the phase
          // to 24 bits before evaluating it.
          sum_cos_term += std::cos(dotprd);
          sum_sin_term += std::sin(dotprd);
        }

        shellSum += sum_cos_term * sum_cos_term + sum_sin_term * sum_sin_term;
      }
    }

    fq[k_index] = shellSum * (normFact*scale);
    printf("fq[%u]=%g\n", k_index, fq[k_index]);
  }
}
