#include "PassBandFunctions.hpp"

#include <iostream>
using namespace std;
using namespace cppR;


/**************************************************/

// File-local shorthand for the circle constant, taken from the M_PI macro.
const double pi = M_PI;

// Logarithm of `value` in the given integer `base`.
template <class T>
double logb(int base, T value)
{
  // Change-of-base identity: log_base(value) = ln(value) / ln(base).
  return log(value) / log(base);
}
// Base-2 logarithm of `value`, delegated to the general-base helper logb().
template <class T>
double log2(T value)
{
  const int binary_base = 2;
  return logb(binary_base, value);
}



/**************************************************/

// ORIGINAL DOCUMENTATION FROM chebbp2.m
// Source: http://www.dsp.rice.edu/software/fir.shtml

//  A second program for the design of symmetric bandpass FIR filters with
//  flat monotonically decreasing passbands (on either side of wp)
//  and equiripple stopbands. This program is similar to chebbp.m,
//  but it uses a different set of input parameters.
//
//  output
//   h  : filter
//  input
//   N   : length of total filter
//   L   : degree of flatness
//   wp  : passband frequency of flatness
//   ws1 : first stopband edge
//   ws2 : second stopband edge
//   Need: 0 < ws1 < wp < ws2 < pi
//
//  Author: Ivan W. Selesnick, Rice University



/**
 * C++ port of chebbp2.m (see the original documentation block above):
 * designs a symmetric bandpass FIR filter by iteratively solving a linear
 * system on a set of reference frequencies `rs`, evaluating the resulting
 * response on a dense grid with an FFT, and moving the reference set to the
 * extrema found on that grid.
 *
 * @param N    total filter length; must be odd
 * @param L    degree of flatness; must be positive and divisible by 4
 * @param wp   passband frequency of flatness
 * @param ws1  first stopband edge
 * @param ws2  second stopband edge; requires 0 < ws1 < wp < ws2 < pi
 * @return     Matrix(0,0) after a message on cerr when an argument check
 *             fails; otherwise the column vector h assembled at the end of
 *             the function (expected to have N rows -- TODO confirm for q <= 2).
 *
 * Side effects: prints "\tErr is ..." to cout on every iteration and
 * "\tI have converged" when the error drops below the threshold.
 */
Matrix chebbp2(int N, int L, double wp, double ws1, double ws2)
{

  //check numbers are in the right range
  if((N%2)==0 || (L%4) != 0)
    {
      cerr << "N must be odd and L must be divisible by 4\n";
      return Matrix(0,0);
    }
  if(0 >= ws1 || ws1 >= wp || wp >= ws2 || ws2 >= pi)
    {
      cerr << "need: 0 < ws1 < wp < ws2 < pi\n";
      return Matrix(0,0);
    }
  if(L < 1)
    {
      cerr << "L must be positive\n";
      return Matrix(0,0);
    }


  // q: number of cosine coefficients solved for per iteration (rows of `a`)
  int q = int((N-L + 1) / 2);
  // g: smallest power of two that is >= 8*N
  int g = int(pow(2,ceil(log2(8*N)))); //number of grid points
  // convergence threshold compared against Err at the end of each iteration
  double SN = 1e-8; //SMALL NUMBER

  //w <- t(0:g) * pi / g
  // dense frequency grid: g+1 equally spaced points covering [0, pi]
  Matrix w(1, g+1);
  for(int i=0; i<w.size2(); i++)
      w(0,i) = i * pi / g;

  // Split the q+1 reference points between the two stopbands in proportion
  // to their widths ws1 and (pi - ws2); each band gets at least one point.
  double d = ws1 / (pi-ws2);
  int q1 = int(round((q+1)/(1+1/d)));

  if (q1 == 0)
    q1 = 1;
  else if (q1 == q+1)
    q1 = q;

  int q2 = q + 1 - q1;


  // rs1: q1 points equally spaced over [0, ws1] (just ws1 when q1 == 1)
  Matrix rs1(1,q1);
  if(q1 == 1)
    {
      rs1.resize(1,1);
      rs1(0,0) = ws1;
    }
  else
    {
      for(int i=0; i<rs1.size2(); i++)
        rs1(0,i) = i * (ws1/(q1-1));
    }
  // rs2: q2 points equally spaced over [ws2, pi] (just ws2 when q2 == 1)
  Matrix rs2(1,q2);
  if(q2 == 1)
    {
      rs2.resize(1,1);
      rs2(0,0) = ws2;
    }
  else
    {
      for(int i=0; i<rs2.size2(); i++)
        rs2(0,i) = i * ((pi-ws2)/(q2-1)) + ws2;
    }


  // rs: initial reference set, rs1 followed by rs2, stored as a (q+1) x 1 column
  Matrix rs(rs1.size2() + rs2.size2(),1);
  for(int i=0;i<rs1.size2();i++)
    rs(i,0) = rs1(0,i);
  for(int i=0;i<rs2.size2();i++)
    rs(i+rs1.size2(),0) = rs2(0,i);


  //******************************

  //Matrix Z(int(2 * (g+1-q))-1,1);
  // number of zero entries in the FFT input F built inside the loop
  // (F has 2*q-1 + Z_size = 2*g rows when q > 2)
  int Z_size = 2 * (g+1-q)-1;

  // A1 <- (-1)^(L/2) * (sin(w/2-wp/2) * sin(w/2+wp/2)) ^(L/2)
  // A1 is evaluated once on the dense grid w
  Matrix temp = compProd(apply(Matrix(w/2) - double(wp/2),sin),
        		 apply(Matrix(w/2) + double(wp/2),sin));
  Matrix A1 = apply(temp, pow, L/2) * pow(-1.0,L/2);


  //si = column of q+1 entries alternating +1, -1 (starting with +1)
  Matrix si(q+1,1);
  for(int i=0;i<si.size1();i++)
    si(i,0) = (i % 2)==0 ? 1 : -1;

  //n = single row of numbers 0:q-1
  Matrix n(1,q);
  for(int i=0;i<n.size2();i++)
    n(0,i) = i;

  //A1r <- (-1)^(L/2) * (sin(rs/2-wp/2) * sin(rs/2+wp/2))^(L/2)
  // same expression as A1, but evaluated on the reference set rs
  temp = compProd(apply(Matrix(rs/2) - double(wp/2),sin),
        	  apply(Matrix(rs/2) + double(wp/2),sin));
  Matrix A1r = apply(temp, pow, L/2) * pow(-1.0,L/2);

  //******************************
  // cosine coefficients of the current solution; rewritten every iteration
  Matrix a;

  // at most 15 exchange iterations; leaves early once Err < SN
  for(int it=0; it<15; it++)
    {
      // Solve [cos(rs*n), si/A1r] * x = 1/A1r for the q+1 unknowns in x.
      temp = cbind(Matrix(apply(Matrix(prod(rs,n)),cos)),
        	   Matrix(compDiv(si,A1r)));
      Matrix x = solve(temp,compDiv(1.0,A1r));

      // a = negated first q entries of x; the last entry of x is del
      a.resize(x.size1()-1,1);
      for(int i=0;i<a.size1();i++)
        a(i,0) = -x(i,0);

      double del = x(q,0);
      //matrix to compute fft of
      // F holds a(0) at index 0 and a(i)/2 at indices i and size-i, with
      // zeros everywhere else.
      // NOTE(review): for q == 2 the else branch is taken and a(1) is not
      // stored in F -- confirm whether `q > 1` was intended.
      Matrix F;
      if(q > 2)
        {
          F = createMatrix(0,2*q-1 + Z_size,1);
          F(0,0) = a(0,0);
          //secondcomp
          int i;
          for(i=1;i<q;i++)
            {
              F(i,0) = a(i,0)/2;
              F(F.size1()-i,0) = a(i,0)/2;
            }
        }
      else
        {
          F = createMatrix(0,1 + Z_size,1);
          F(0,0) = a(0,0);
        }

      //compute fft of F and save the real part
      Matrix A2 = Re(fft(F));
      // keep the leading part so that A2 lines up with the grid w
      // (presumably rows 0..g of column 0 -- verify against submatrix())
      A2 = submatrix(A2,0,g,0,0);

      //A = A1 .* t(A2) + 1, one value per grid point of w
      Matrix A = compProd(A1,t(A2)) + 1.0;
      // NOTE(review): Y is not used anywhere later in this function
      Matrix Y = si*del;

      //local maxima of A and of -A (i.e. the local minima of A)
      // localMax() returns 1-based positions, hence the "-1" when indexing A
      std::vector<int> lmA = localMax(A);

      std::vector<int> lmAInv = localMax(A*(-1));
      for(int i=0;i<lmAInv.size();i++)
        lmA.push_back(lmAInv[i]);
      sort(lmA.begin(),lmA.end());


      int llmA = lmA.size();
      //cout << "llmA = " << llmA << "\n";
      //cout << "A SIZE: " << A.size1() <<"," << A.size2() << "\n";
      // When the number of extrema is not q+1, drop one extremum: the last
      // one if the last two differ less in A than the first two do,
      // otherwise the first one.
      if(llmA != q+1)
        {
          /*cout << "A SIZE: " << A.size1() <<"," << A.size2() << "\n";
          int ind1 = lmA[llmA-1]-1;
          int ind2 = lmA[llmA-2]-1;
          int ind3 = lmA[0]-1;
          int ind4 = lmA[1]-1;
          cout << "INDS: " << ind1 << ", " << ind2 << ", " << ind3 << ", " << ind4 << "\n";
          */
          if(abs(A(0,lmA[llmA-1]-1)-A(0,lmA[llmA-2]-1))
             < abs(A(0,lmA[0]-1)-A(0,lmA[1]-1)))
            {
              lmA.resize(llmA-1);
            }
          else
            {
              lmA.erase(lmA.begin());
            }
        }

      //ri holds the frequencies of the extrema: (1-based index - 1) * pi/g
      // k tracks the entry closest to wp, which is removed right after the loop
      std::vector<double> ri(lmA.size());
      int k = 0;
      double min = DBL_MAX;
      for(int i=0;i<lmA.size();i++)
        {
          ri[i] = (double(lmA[i])-1.0) * pi/g;
          double v = abs(ri[i]-wp);
          if(v < min)
            {
              min = v;
              k = i;
            }
        }

      ri.erase(ri.begin()+k);

      //compute Aws1 and Aws2
      // same response expression as A, evaluated exactly at the band edges
      int sign = int(pow(-1.0, L/2));
      double Aws1 = (prod(apply(Matrix(n*ws1),cos), a))(0,0) * sign * pow(sin(ws1/2-wp/2) * sin(ws1/2 + wp/2), L/2) + 1;
      double Aws2 = (prod(apply(Matrix(n*ws2),cos), a))(0,0) * sign * pow(sin(ws2/2-wp/2) * sin(ws2/2 + wp/2), L/2) + 1;

      //check if any values of ri lie strictly between wp and ws2
      bool in_range = false;
      for(int i=0;i<ri.size();i++)
        if(ri[i] > wp && ri[i] < ws2)
          {
            in_range = true;
            break;
          }

      //add ws1 or ws2
      // exactly one band edge joins the reference set, replacing the
      // extremum that was removed above
      if((Aws1 > Aws2) || (in_range))
        ri.push_back(ws1);
      else
        ri.push_back(ws2);

      sort(ri.begin(),ri.end());

      //copy ri to rs
      //cout << "ri.size() = " << ri.size() << ", rs.size1() = " << rs.size1() << "\n";
      // assumes ri.size() <= rs.size1() -- TODO confirm; there is no bounds check here
      for(int i=0;i<ri.size();i++)
        rs(i,0) = ri[i];

      // re-evaluate A1r on the updated reference set
      A1r = apply(compProd(apply(Matrix(rs * (.5)) - (wp/2),sin),
        			  apply(Matrix(rs * (.5)) + (wp/2),sin))
        			, pow
        			, L/2) * sign;

      // Ar: response at the new reference points; Err is the larger of
      // max(Ar - |del|) and |del| - |min(Ar)|
      Matrix Ar = Matrix(compProd(Matrix(prod(apply(Matrix(prod(rs,n)),cos),a)), A1r)) + 1.0;
      double armax = cppR::max(Ar-abs(del));
      double e1 = abs(del)-abs(cppR::min(Ar));

      double Err = armax > e1 ? armax : e1;
      /*if(it == 3)
        {
          Matrix temp = Ar-abs(del);
          for(int i=0;i<temp.size1();i++)
            {
              cout << (i+1) << "] " << temp(i,0) << "\n";
            }
            }*/
      //cout << "abs(del) = " << abs(del) << "\n";
      //cout << "armax = " << armax << "   e1 = " << e1 << "\n";
      cout << "\tErr is " << Err << "\n";
      if(Err < SN)
        {
          cout << "\tI have converged\n";
          break;
        }
    }//end for loop
  /* for(int i=0;i<a.size1();i++)
    {
    cout << (i+1) << "] " << a(i,0) << "\n";
    }*/

  //set up h value vector for convolution
  // symmetric layout: a(q-1)/2 ... a(1)/2, a(0), a(1)/2 ... a(q-1)/2
  // NOTE(review): as with F above, q == 2 takes the else branch and drops
  // a(1) -- confirm whether `q > 1` was intended.
  Matrix h;
  if(q > 2)
    {
      h.resize(q*2-1,1);
      int h_index = 0;
      for(int i=q-1;i>=1;i--)
        {
          h(h_index++,0) = a(i,0)/2;
        }
      h(h_index++,0) = a(0,0);
      for(int i=1;i<q;i++)
        {
          h(h_index++,0) = a(i,0)/2;
        }
    }
  else
    {
      h.resize(1,1);
      h(0,0) = a(0,0);
    }

  //set up y vector for convolution
  // three-tap kernel [1, -2*cos(wp), 1]
  Matrix y(3,1);
  y(0,0) = 1;
  y(1,0) = -2 * cos(wp);
  y(2,0) = 1;

  //do the convolutions
  // apply the kernel L/2 times, scaling by 1/4 after each pass
  for(int i=0;i<L/2;i++)
    {
      h = convolve(h, y) * (1.0/4.0);
    }
  // add 1 at index (N+1)/2 - 1, the centre tap of an N-row vector
  int ind = ((N+1)/2) - 1;
  h(ind,0) += 1.0;

  return(h);

}

/**************************************************/

/**
 * Finds the local maxima of a vector stored as a single-row or single-column
 * matrix.
 *
 * A matrix with fewer rows than columns is transposed first; after that only
 * column 0 is inspected.  An interior element is a local maximum when it is
 * >= its predecessor and strictly > its successor.  The first and last
 * elements count when they are strictly greater than their only neighbour.
 *
 * @param X  data vector (taken by value because it may be transposed locally)
 * @return   1-based positions of the maxima in ascending order; empty when X
 *           holds fewer than two elements (that case used to index out of
 *           bounds).
 */
std::vector<int> localMax(ublas::matrix<double> X)
{
  std::vector<int> k;

  if(X.size1() < X.size2())
    X = t(X);
  const int size = X.size1();

  // With fewer than two samples there is no neighbour to compare against,
  // and X(1,0) / X(size-2,0) below would be out of range.
  if(size < 2 || X.size2() == 0)
    return k;

  // Left end point.
  if(X(0,0) > X(1,0))
    k.push_back(1);

  // Interior points: X(i,0) is the candidate, reported 1-based as i+1.
  for(int i=1;i<size-1;i++)
    {
      const bool rising  = X(i-1,0) <= X(i,0);
      const bool falling = X(i,0) > X(i+1,0);
      if(rising && falling)
        k.push_back(i+1);
    }

  // Right end point.
  if(X(size-1,0) > X(size-2,0))
    k.push_back(size);

  // Positions were appended from left to right, so k is already ascending
  // and needs no sort.
  return k;
}


/**************************************************/

/**
 * Convenience wrapper around chebbp2() that takes the band edges as
 * frequencies relative to a sampling rate instead of radian frequencies.
 *
 * Each frequency f is mapped to f / (Fs/2) * pi, so Fs/2 maps to pi.
 *
 * @param N       filter length, forwarded to chebbp2()
 * @param L       degree of flatness, forwarded to chebbp2()
 * @param Fs      sampling rate, in the same unit as the three frequencies
 * @param fstop1  first stopband edge
 * @param fpass   passband frequency
 * @param fstop2  second stopband edge
 * @return        whatever chebbp2() returns; Matrix(0,0) when Fs is not
 *                positive
 *
 * Prints the incoming frequencies and the converted radian values to stdout.
 */
Matrix makePassband(int N, int L, int Fs, double fstop1, double fpass, double fstop2)
{

  printf("fstop1=%f, fpass=%f, fstop2=%f\n",fstop1, fpass, fstop2);

  if(Fs <= 0)
    {
      cerr << "Fs must be positive\n";
      return Matrix(0,0);
    }

  // Fs is an int: `Fs/2` would truncate for odd rates (and be 0 for Fs == 1),
  // so the Nyquist frequency is computed in floating point.
  const double nyquist = Fs / 2.0;

  double wp  = fpass  / nyquist * pi;
  double ws1 = fstop1 / nyquist * pi;
  double ws2 = fstop2 / nyquist * pi;

  printf("wp: %f, ws1: %f, ws2 %f\n",wp,ws1,ws2);

  return(chebbp2(N,L,wp,ws1,ws2));
}


/*************************************************/
//convolve
/**
 * FFT-based convolution of two real column vectors, following R's
 * convolve(x, y, conj = TRUE, type = "open").
 *
 * x is left-padded with ny-1 zeros and y is right-padded with nx-1 zeros so
 * that both have nx+ny-1 rows; the result is
 *   Re( fft( fft(x) .* Conj(fft(y)), true ) ) / (nx+ny-1).
 * The second argument `true` of fft() is presumably the inverse flag, with an
 * unnormalised inverse as in R -- verify against cppR::fft.
 *
 * @param x  first sequence, column vector (nx rows)
 * @param y  second sequence, column vector (ny rows)
 * @return   column vector with nx+ny-1 rows
 */
ublas::matrix<double> convolve(ublas::matrix<double> x, ublas::matrix<double> y)
{
  /*
    ASSUME:
     x,y are numeric
     conj = true
     type = open

  */

  const int nx = x.size1();
  const int ny = y.size1();

  // Length of both padded sequences, and therefore of the transforms.  This
  // is also the normalisation factor; it used to stay at nx (= 1) when x had
  // a single element, which left the result scaled by ny.
  const int n = nx + ny - 1;

  //type is open: x1 = (ny-1 zeros) followed by x
  Matrix x1(n,1);

  int x1_index = 0;
  for(int i=0;i<ny-1;i++)
    {
      x1(x1_index++,0) = 0;
    }
  for(int i=0; i<nx; i++)
    {
      x1(x1_index++,0) = x(i,0);
    }

  // y followed by nx-1 zeros; nothing to append when nx == 1
  if(nx > 1)
    {
      y.resize(n, 1);
      for(int i=ny;i<n;i++)
        y(i,0) = 0;
    }

  ublas::matrix< std::complex<double> > xfft1 = fft(x1);
  //assume conj=TRUE
  ublas::matrix< std::complex<double> > yfft1 = Conj(fft(y));

  //assume real=TRUE
  ublas::matrix<double> ret = Re(fft(compProd(xfft1,yfft1),true)) * (1.0/n);

  return ret;
}

/**************************************************/


/**
 * Applies the filter with numerator coefficients B to the rows of x, carrying
 * history across calls through `state`.
 *
 * The denominator A is fixed to createMatrix(1,1,1) inside the function
 * (presumably a 1x1 matrix holding 1), so every `nA > 1` branch below is
 * inactive as the code stands.
 *
 * @param B      coefficient column vector; B.size2() must be 1
 * @param x      input, x.size1() rows (nDim) by x.size2() columns (nX);
 *               presumably one signal per row and one sample per column --
 *               TODO confirm
 * @param state  history from a previous call; when state.empty is set, or a
 *               stored matrix has a zero dimension, zero history is used
 * @return       FilterResult whose `filtered` member is the output y and
 *               whose `state` member holds the updated history with
 *               empty == false
 * @throws const char*  (a string literal) when B or A is not a single column
 *
 * Side effect: writes the intermediate matrix v to "v.txt" on every call.
 */
FilterResult filter(Matrix B,  Matrix x, FilterState state)
{
  // denominator coefficients, hard-coded here
  Matrix A = createMatrix(1,1,1);

  // NOTE(review): this throws a string literal (const char*), not a
  // std::exception, and the text states the requirement rather than the
  // failure -- callers must catch const char*.
  if(B.size2() != 1 || A.size2() != 1)
    throw("B.size2() == 1 or A.size2() == 1");

  int nB = B.size1();
  int nA = A.size1();

  int nX = x.size2();
  int nDim = x.size1();

  // history matrices: nB-1 past input columns and nA-1 past output columns
  Matrix state_x;
  Matrix state_y;
  if(state.empty)
    {
      state_x = createMatrix(0.0, nDim, nB-1);
      state_y = createMatrix(0.0, nDim, nA-1);
    }
  else
    {
      state_x = state.x;
      state_y = state.y;
    }
  // fall back to zero history when a supplied state matrix has no rows or columns
  if(state_x.size1() == 0 || state_x.size2()  == 0)
    state_x = createMatrix(0.0, nDim, nB-1);
  if(state_y.size1() == 0 || state_y.size2()  == 0)
    state_y = createMatrix(0.0, nDim, nA-1);

  // prepend the input history to the new samples
  state_x = cbind(state_x, x);
  int ncsx = ncol(state_x);

  // start with the B(0) term
  Matrix v = x * B(0,0);
  //cout << "B(0,0) = " << B(0,0) << "\n";
  //cout << "v(0,0) = " << v(0,0) << "\n";

  // NOTE(review): looks like leftover debugging output -- this writes a file
  // named v.txt on every call.
  writeTable(v,"v.txt");
  if(nB > 1)
    {
      // add B(i) times an nX-column window of state_x for each remaining coefficient
      // NOTE(review): if submatrix() takes 0-based inclusive column bounds
      // (as the other calls in this function suggest), the window for i == 1
      // is columns nB-1 .. ncsx-1, i.e. x itself rather than x delayed by
      // one sample -- confirm the intended bounds against submatrix().
      for(int i=1; i<nB; i++)
        {
          v = v + submatrix(state_x, 0, 0, ncsx-i-nX+1, ncsx-i) * B(i,0);
        }
      // keep the last nB-1 columns as input history for the next call
      state_x = submatrix(state_x,0,0, ncsx-nB+1, ncsx-1);
    }
 state_y = cbind(state_y, v);
  if(nA > 1)
    {
      // feedback part; not reached while A is the hard-coded single-row matrix above
      Matrix Arev = -rev(Matrix(submatrix(A,1,0,0,0)));
      for(int i=0; i<nX; i++)
        {
          submatrix(state_y, 0, 0, nA+i-2, nA+i-2) = submatrix(state_y, 0, 0, nA+i-2, nA+i-2)
            + prod(submatrix(state_y, 0, 0, i, i+nA-3) , Arev);
        }
    }
  int ncsy = ncol(state_y);
  // the output is the last nX columns of state_y
  Matrix y = submatrix(state_y,0,0, ncsy-nX, ncsy-1);
  /*writeTable(x,"x.txt");
  writeTable(state_y,"sty.txt");
  writeTable(y,"y.txt");*/
  // with a single-row A there is no output history to keep
  if(nA > 1)
    state_y = submatrix(state_y,0,0,ncsy-nA, ncsy-1);
  else
    state_y.resize(0,0);


  //Return values
  state.x = state_x;
  state.y = state_y;
  state.empty = false;
  FilterResult res;
  res.state = state;
  res.filtered = y;
  return(res);
}

