//local includes
#include "MNF.hpp"

//std includes
#include <iostream>

//namespaces
using namespace cppR;
using namespace ublas;
using namespace std;


namespace CEBL
{
  /**
   * Compute the MNF decomposition of data and store it in the members
   * psi and phi.
   *
   * Columns of data are treated as samples (nSamples = ncol). After the
   * row means are removed from a local copy X, the steps are:
   *   1. covX = X X^T is eigen-decomposed, giving Vx and its eigenvalues.
   *   2. SX is built from X shifted by one column, and
   *      Z = (X X^T + SX SX^T - SX X^T - X SX^T) / 2,
   *      which is algebraically (X - SX)(X - SX)^T / 2.
   *   3. ZHAT = INVDSx Vx^T Z Vx INVDSx is eigen-decomposed, giving PSIHAT,
   *      where INVDSx = diag(1 / sqrt(eigenvalues of covX)).
   *   4. psi = Vx INVDSx PSIHAT and phi = X^T psi; both are stored
   *      transposed.
   *
   * Precondition: data needs at least two columns, because column 1 of X
   * is read when the first column of SX is filled.
   *
   * @param data input matrix; it is copied, the argument is not modified.
   */
  void MNF::mnf(const ublas::matrix<double> & data)
  {
    matrix<double> X = data;

    //remove the mean of every row
    int nSamples = ncol(X);
    X -= createMatrix(rep(rowMeans(X), nSamples), 0, nSamples);

    //R code: X %*% t(X)
    matrix<double> covX = prod(X, t(X));

    //R code: eigen(covX)
    EigStruct<double> r = eigen(covX);

    //eigenvectors
    //R code: r$vectors
    matrix<double> Vx = r.vectors;

    //diagonal matrix holding 1/sqrt(eigenvalue)
    ublas::vector<double> DSx = vsqrt(r.values);
    matrix<double> INVDSx = diag(1.0/DSx);

    //R code: matrix(0, nrow(X), nSamples)
    matrix<double> SX = createMatrix(0, nrow(X), nSamples);

    //R code: SX[,2:nSamples] <- X[,1:(nSamples-1)]
    //NOTE(review): the 0,0 row arguments presumably select every row;
    //confirm against the cppR submatrix() definition
    submatrix(SX,0,0,1,nSamples-1) = submatrix(X,0,0,0,nSamples-2);

    //R code: SX[,1] <- X[,2]
    column(SX,0) = column(X,1);

    //cross products between X and its shifted copy
    matrix<double> Z1 = prod(SX, t(X));
    matrix<double> Z2 = prod(X, t(SX));
    matrix<double> Z4 = prod(SX, t(SX));

    //covX is used directly instead of through a copy; no SVD of Z2 is
    //taken because nothing below consumes it
    matrix<double> Z = (covX+Z4-Z1-Z2) * 0.5;

    //R code: ZHAT <- INVDSx %*% t(Vx) %*% Z %*% Vx %*% INVDSx
    //built right-to-left so that p1 (Vx %*% INVDSx) can be reused for psi
    matrix<double> p1 = prod(Vx, INVDSx);
    matrix<double> p2 = prod(Z, p1);
    matrix<double> p3 = prod(t(Vx), p2);
    matrix<double> ZHAT = prod(INVDSx, p3);

    r = eigen(ZHAT);
    matrix<double> PSIHAT = r.vectors;

    psi = prod(p1, PSIHAT);
    phi = prod(t(X), psi);

    //both results are stored transposed
    phi = t(phi);
    psi = t(psi);
  }

  /*********************************************************/
  /**
   * Build the filter matrix from data.
   *
   * Stores the row means of data in means, runs mnf() on the mean-removed
   * data, and sets filter = solve(psi) * selector * psi, where selector is
   * a diagonal matrix of ones with zeros at the indices listed in remove.
   * Finally marks the object as created.
   *
   * @param data   input matrix; columns are counted as samples and rows as
   *               components.
   * @param remove indices of components to zero out; entries that are
   *               negative or not smaller than the number of rows are
   *               ignored.
   */
  void MNF::make(const ublas::matrix<double> & data, const std::vector<int> &remove)
  {
    matrix<double> samples = static_cast<matrix<double> >(data);
    int sampleCount = ncol(samples);
    int componentCount = nrow(samples);

    //R Code: Xm <- X - matrix(rep(means,nSamples),ncol=nSamples)
    this->means = rowMeans(samples);
    matrix<double> centered = samples - createMatrix(rep(means, sampleCount),0,sampleCount);

    mnf(centered);

    //start with every component kept, then zero the ones to remove
    ublas::vector<double> keep = rep(1,componentCount);
    for(std::vector<int>::const_iterator it = remove.begin(); it != remove.end(); ++it)
    {
      //the unsigned comparison also rejects negative indices
      if(unsigned(*it) < keep.size())
      {
        keep[*it] = 0.0;
      }
    }

    matrix<double> selector = diag(keep);

    matrix<double> psiInverse = solve(psi);
    matrix<double> masked = prod(psiInverse, selector);
    this->filter = prod(masked, psi);

    created = true;
  }

  /**
   * Filter data: subtract the stored means from every column of data and
   * multiply the result by filter from the left.
   *
   * @param data input matrix; the argument is not modified.
   * @return filter * (data - means replicated across the columns).
   */
  ublas::matrix<double> MNF::apply(const ublas::matrix<double> &data) const
  {
    ublas::matrix<double> input = data;
    int columns = ncol(input);

    //replicate the stored means across all columns, then remove them
    matrix<double> offsets = createMatrix(rep(means, columns),0,columns);
    matrix<double> centered = input - offsets;

    ublas::matrix<double> result = prod(filter, centered);
    return result;
  }

  /**
   * Run mnf() on data and return the resulting phi multiplied by 20.
   *
   * Calling this overwrites the members psi and phi.
   *
   * @param data input matrix handed to mnf().
   * @return phi * 20.
   */
  ublas::matrix<double> MNF::extract(const ublas::matrix<double> & data)
  {
    mnf(data);

    //NOTE(review): the scale factor 20 is not explained anywhere in this file
    return static_cast<ublas::matrix<double> >(phi * 20);
  }


  /**
   * Serialize the state of this object.
   *
   * @return a map with the keys "created", "psi", "phi", "filter" and
   *         "means", each holding the serialized member of the same name.
   */
  map<string, SerializedObject> MNF::save() const
  {
    map<string, SerializedObject> state;

    //status flag
    state["created"] = serialize(created);

    //decomposition results
    state["psi"] = serialize(psi);
    state["phi"] = serialize(phi);

    //filter matrix and the means it was built with
    state["filter"] = serialize(filter);
    state["means"] = serialize(means);

    return state;
  }
  /**
   * Restore the state of this object from serialized members.
   *
   * Reads the keys "created", "psi", "phi", "filter" and "means" and
   * deserializes each into the member of the same name.
   *
   * @param objects map of serialized members, taken by value. A missing key
   *                is default-inserted into this local copy by operator[]
   *                before being passed to deserialize().
   */
  void MNF::load(map<string, SerializedObject> objects)
  {
    deserialize(objects["created"], created);
    deserialize(objects["psi"],psi);
    deserialize(objects["phi"],phi);
    deserialize(objects["filter"],filter);
    deserialize(objects["means"],means);
  }
}



/*************************************************************/
//DYNAMIC LOADING

// Factory hook with C linkage: allocates a new CEBL::MNF on the heap and
// returns it as a CEBL::Filter pointer. The caller owns the object;
// ObjectDestroy() in this file deletes such a pointer.
extern "C" CEBL::Filter* ObjectCreate()
{
  return new CEBL::MNF;
}

// Destruction hook with C linkage: deletes the filter object pointed to by p.
// NOTE(review): deleting through CEBL::Filter* relies on Filter declaring a
// virtual destructor — confirm in the Filter header.
extern "C" void ObjectDestroy(CEBL::Filter* p)
{
  delete p;
}
