#include "LDA.hpp"

#include <cmath>
#include <stdexcept>

using namespace cppR;
using namespace ublas;
using std::cout;
using std::cerr;

namespace CEBL {

  /*! Trains LDA classifier using training data.

    \param data training data
  */
  void LDA::train(const EEGTrainingData& data)
  {
    matrix<double> X = t(data.collapse());
    ublas::vector<int> Y = data.getTargets();

    // save vector of class numbers
    this->classes = unique(Y);
    int n_classes = classes.size();
    int n_samples = nrow(X);
    int n_features = ncol(X);

    // create an empty member variables, not necessary, but makes clear
    //  what the sizes of these are going to be
    this->covariance = createMatrix(0,n_features,n_features);
    this->covarianceInv = createMatrix(0,n_features,n_features); 
    this->priors = createVector(0,n_classes);
    this->means = createMatrix(0, n_classes, n_features);
    this->weights = createMatrix( 0,n_classes, n_features );
    this->bias = createVector(0,n_classes);


    // compute the covariance matrix by looping though classes and updating
    for(int k=0; k < n_classes; k++) {

      //allow for inturruption here
      this->inturruptionPoint();

      std::vector<bool> mask = createMask(classes[k], Y);
      matrix<double> Z = rowMask(X,mask);

      int n_samples_this_class = nrow(Z);
      priors(k) = double(n_samples_this_class) / n_samples;

      row(this->means,k) = colMeans(Z);

      matrix<double> temp = 
        createMatrix(ublas::vector<double>(row(means,k)),
                     n_samples_this_class, n_features, true);
      matrix<double> Zc = Z - temp;

      //update covariance
      this->covariance = this->covariance + prod(t(Zc), Zc);
    }

    //divide covariance by N - K
    this->covariance = this->covariance / (n_samples - n_classes);

    //invert covariance
    this->covarianceInv = solve(this->covariance);

    //compute weights and biases
    this->weights = prod(means,covarianceInv);
    matrix<double> wm = compProd(weights,means);
    for (unsigned int i=0; i < priors.size(); i++)
      priors[i] = log(priors[i]);
    this->bias = -0.5 * rowSums(wm) + priors;

    // set trained flag
    trained = true;
  }


  /*! Classifies data samples.

    \param data samples to classify
    
    \return vector of predicted classes
  */
  ublas::vector<int> LDA::use(const ublas::matrix<double> & data)
  {
    //X is transpose of data
    ublas::matrix<double> X = t(data);

    int n_classes = classes.size();
    int n_samples = nrow(X);

    //compute discriminant functions
    ublas::matrix<double> disc_functions(n_samples,n_classes);
    disc_functions = prod(X,t(weights));

    //add class bias to discriminant functions for each sample
    for (unsigned int s = 0; s < disc_functions.size1(); s++)
      row(disc_functions,s) = row(disc_functions,s) + bias;

    //for each column of deltas find max
    ublas::vector<int> predicted_classes;
    predicted_classes.resize(n_samples);

    for(int j=0; j<nrow(disc_functions); j++) {
      ublas::vector<double> disc_functionsRow = row(disc_functions,j);
      predicted_classes[j] = classes[whichMax(disc_functionsRow)];
    }

    //----------------------------------------------------------------------
    //compute probabilities for each class

    if(compute_probs)
      {
        // get vector of maximum for each row in disc functions
        ublas::vector<double> max_disc = 
          cppR::rowApply<double>(disc_functions,&cppR::max<double>);

        // create a matrix of the same size as disc functions whose
        //  columns are max_disc
        matrix<double> max_discs = 
          createMatrix(max_disc,disc_functions.size1(),disc_functions.size2());

        //take the exponent of the disc functions subtracted from their max
        matrix<double> probabilities = 
          apply(matrix<double>(disc_functions - max_discs), exp);

        //get rowsums from probabilities
        ublas::vector<double> sum_p = cppR::rowSums(probabilities);
        
        //divide probabilities by rowsums
        matrix<double> sum_p_rep = 
          createMatrix(sum_p,probabilities.size1(),probabilities.size2());
        probabilities = compDiv(probabilities, sum_p_rep);

        //copy to probabilities member variable
        this->probabilities.resize(probabilities.size1());
        for(unsigned i=0;i<probabilities.size1();i++)
          this->probabilities[i] = 
            asStdVector(ublas::vector<double>(row(probabilities,i)));

      }


    //return prediction
    return predicted_classes;
  }

  /*! Serializes state information.
    
    \return map of serialized state
  */
  map<string, SerializedObject> LDA::save() const
  {
    map<string, SerializedObject> ret;
    ret["trained"] = serialize(trained);
    ret["means"] = serialize(means);
    ret["covariance"] = serialize(covariance);
    ret["covarianceInv"] = serialize(covarianceInv);
    ret["classes"] = serialize(classes);
    ret["weights"] = serialize(weights);
    ret["bias"] = serialize(bias);
    return ret;
  }

  /*! Loads state from serialization.
    \param objects map of serialized objects
  */
  void LDA::load(map<string, SerializedObject> objects)
  {
    deserialize(objects["trained"],trained);
    deserialize(objects["means"],means);
    deserialize(objects["covariance"],covariance);
    deserialize(objects["covarianceInv"],covarianceInv);
    deserialize(objects["classes"],classes);
    deserialize(objects["weights"],weights);
    deserialize(objects["bias"],bias);
  }


  //! get list of parameters needed for classifier
  std::map<std::string, CEBL::Param> LDA::getParamsList()
  {
    std::map<std::string, CEBL::Param> params;
    CEBL::Param probs("Compute Probabilities",
                    "Should LDA compute probabilities? If this is checked, LDA will need to create covariance matrices when training.",
                    compute_probs);
    params["probs"] = probs;

    return params;
  }

  //! set list of parameters
  void LDA::setParamsList( std::map<std::string, CEBL::Param> &p)
  {
    compute_probs = p["probs"].getBool();
  }
}


/*************************************************************/
//DYNAMIC LOADING

/*! Factory entry point with C linkage, presumably resolved by name by the
  dynamic loader (see the DYNAMIC LOADING section).

  \return a newly allocated LDA classifier, owned by the caller; pair it
          with ObjectDestroy so it is freed by this module
*/
extern "C" CEBL::Classifier* ObjectCreate()
{
  CEBL::Classifier* classifier = new CEBL::LDA;
  return classifier;
}

/*! Counterpart to ObjectCreate: deletes a classifier inside the module that
  allocated it.

  Deleting through the base pointer relies on CEBL::Classifier having a
  virtual destructor (declared outside this file).

  \param p classifier to destroy; may be null
*/
extern "C" void ObjectDestroy(CEBL::Classifier* p)
{
  delete p;  // no-op when p is null
}
