#include "QDA.hpp"

using namespace cppR;
using namespace ublas;
using namespace std;

namespace CEBL {

  //! Train classifier over training data
  /*!
    Fits one Gaussian per class for quadratic discriminant analysis.  For
    each of the first using_classes labels in classes (= unique targets)
    it stores
      - priors[k]         : fraction of all samples carrying that label,
      - row k of means    : the class mean,
      - covariances[k]    : the class covariance, normalised by the class
                            sample count N_k (maximum-likelihood estimate),
      - covariancesInv[k] : its inverse (via solve),
      - covariancesDet[k] : its determinant,
    which use() combines into the per-class discriminant functions.

    \param data training data; data.collapse() is transposed so that rows
                are samples and columns are features, and
                data.getTargets() supplies one class label per sample.

    On success sets trained_classes and trained.  If fewer distinct labels
    were collected than using_classes, an error is printed and the
    classifier is left untrained.
  */
  void QDA::train(const EEGTrainingData& data) {

    // Drop any previously trained model before the member matrices are
    // overwritten below, so an interruption or exception part-way through
    // cannot leave a half-rebuilt model still marked as trained.
    trained = false;

    bool debug = false;

    //X is nsamples by nfeatures
    ublas::matrix<double> X = t(data.collapse());
    if(debug)
      cout << "Starting QDA Training\n"
           << "X = matrix[" << X.size1() << "," << X.size2() << "]\n";

    ublas::vector<int> Y = data.getTargets();

    classes = unique(Y);
    int nClasses = using_classes;
    int classesSize = classes.size();

    if(using_classes<classesSize){
      cout << "More than the requested number of classes was collected.  Only the first " << nClasses << " will be trained."<<endl;
    }
    else if(using_classes > classesSize){
      cerr << "Data was only collected for " << classes.size() << " classes.  Cannot train " << using_classes << " classes." <<endl;
      trained = false;
      return;
    }

    int nSamples = nrow(X);
    int nFeatures = ncol(X);
    covariances.resize(nClasses);
    covariancesInv.resize(nClasses);
    for (int k = 0; k < nClasses; k++) {
      //allow for interruption here
      this->inturruptionPoint();

      covariances[k] = createMatrix(0,nFeatures,nFeatures);
      covariancesInv[k] = createMatrix(0,nFeatures,nFeatures);
    }
    covariancesDet = createVector(0,nClasses);

    priors = rep(0, nClasses);

    means = createMatrix(0, nClasses, nFeatures);
    if(debug)
      cout << "Classes: " << classes << endl << flush;

    for(int k=0; k< nClasses; k++) {
      //allow for interruption here
      this->inturruptionPoint();

      // Select samples (X rows) whose target (Y) is this class.  classes
      // comes from unique(Y), so every class has at least one sample.
      // (Direct initialisation sizes Z from the mask; no pre-sizing needed.)
      std::vector<bool> mask = createMask(classes[k], Y);
      ublas::matrix<double> Z(rowMask(X, mask));

      int nSamplesThisClass = nrow(Z);

      // Priors are the proportion of samples from each class.  nSamples
      // also counts samples of any collected classes beyond nClasses; that
      // scales every prior by the same factor, which shifts all
      // discriminants by the same constant and leaves predictions unchanged.
      priors[k] = double(nSamplesThisClass) / nSamples;

      // Class means
      row(means,k) = colMeans(Z);

      // Centre the class samples on the class mean.
      //R Code: x.c[mask,] <- x[mask,] - matrix( clsf$means[k,], N.k, p, byrow=TRUE);
      matrix<double> temp = createMatrix(ublas::vector<double>(row(means,k)),
                                         nSamplesThisClass, nFeatures, true);
      matrix<double> Zc = Z - temp;

      //allow for interruption here
      this->inturruptionPoint();

      // Covariance normalised by N_k, plus the inverse and determinant that
      // the discriminant in use() needs.
      covariances[k] =  prod(t(Zc), Zc) / nSamplesThisClass;
      covariancesInv[k] = solve(covariances[k]);
      covariancesDet[k] = det(covariances[k]);

      // TODO: a singular covariance (e.g. fewer samples than features in a
      // class) makes the inverse/determinant unusable.  A disabled rank
      // check used to inflate it:
      // if(rank(c) < nrow(c)) {
      //             cout << "Rank is too small. Inflating covariance matrix\n";
      //             c = c * .9 + diag(.1, nrow(c));
      //          }

    }

    //trained_lags = nFeatures/19-1; //the 19 here is from having 19 channels activated
                                     //       when this was written
                                     //change when/if channels are recorded in classifier
    trained_classes = nClasses;
    trained = true;
  }

  //! Use classifier on data and return classes
  /*!
    Evaluates the quadratic discriminant of every trained class k,
      g_k(x) = -0.5*log|S_k| + log(prior_k) - 0.5*(x-mu_k)' S_k^{-1} (x-mu_k),
    and labels each sample with the class whose g_k is largest.

    \param data feature matrix with one column per sample (it is transposed
                internally so that rows become samples); its row count must
                equal the number of features the model was trained on.
    \return one predicted class label per sample, or an empty vector (with
            a message on cerr) if the classifier is untrained, data is
            empty, or the feature count does not match the model.

    If compute_probs is set, this->probabilities is also filled with the
    per-sample softmax of the discriminant values (one entry per class).
  */
  ublas::vector<int> QDA::use(const ublas::matrix<double> & data)
  {
    if(!trained)
      {
        cerr << "QDA use error: classifier has not been trained.\n";
        ublas::vector<int> ret;
        return ret;
      }
    if(data.size1() == 0)
      {
        cerr << "QDA use error: 0 size data matrix given.\n";
        ublas::vector<int> ret;
        return ret;
      }
    if(data.size1() != means.size2())
      {
        cerr << "QDA use error: data has " << data.size1()
             << " features but the classifier was trained on "
             << means.size2() << ".\n";
        ublas::vector<int> ret;
        return ret;
      }

    // rows of X are samples, columns are features
    ublas::matrix<double> X = data;
    X = t(X);

    // Classify against the classes the model was actually fitted with:
    // means/covariances hold exactly trained_classes entries, so using
    // using_classes here could index past them if it changed after training.
    int nClasses = trained_classes;
    int nSamples = nrow(X);
    int nFeatures = ncol(X);

    ublas::matrix<double> disc_functions(nSamples,nClasses);

    // discriminant value of every sample (row) for each class (column)
    for(int k=0; k< nClasses; k++)
      {
        // Xc = X - mu_k, with the class mean repeated on every row
        matrix<double> mu = createMatrix(ublas::vector<double>(row(means,k)),
                                         nSamples, nFeatures, true);
        ublas::matrix<double> Xc = X - mu;

        // NOTE(review): a zero or numerically negative determinant makes
        // this term -inf/NaN; train() does not regularise the covariance.
        double scalarpart = -0.5 * log(covariancesDet[k]) + log(priors[k]);

        // row-wise quadratic form (x - mu_k)' S_k^{-1} (x - mu_k)
        ublas::matrix<double> a = prod(Xc, covariancesInv[k]);
        ublas::matrix<double> sa = compProd(a, Xc);
        ublas::vector<double> vectorpart = -0.5 * rowSums(sa);
        for(unsigned int vi=0; vi<vectorpart.size(); vi++)
          disc_functions(vi,k) = vectorpart[vi] + scalarpart;
      }

    // predicted class = label of the largest discriminant in each row
    ublas::vector<int> predicted_classes(nSamples);
    for(int j=0; j<nrow(disc_functions); j++)
      {
        ublas::vector<double> disc_functionsRow = row(disc_functions,j);
        predicted_classes[j] = classes[whichMax(disc_functionsRow)];
      }

    //----------------------------------------------------------------------
    // compute probabilities for each class: a softmax over each sample's
    // discriminant values
    if(compute_probs)
      {
        // get vector of maximum for each row in disc functions
        ublas::vector<double> max_disc =
          cppR::rowApply<double>(disc_functions,&cppR::max<double>);

        // create a matrix of the same size as disc functions whose
        //  columns are max_disc
        matrix<double> max_discs =
          createMatrix(max_disc,disc_functions.size1(),disc_functions.size2());

        // subtract each row's max before exponentiating so the largest
        // term is exp(0) = 1 and exp() cannot overflow
        matrix<double> probs =
          apply(matrix<double>(disc_functions - max_discs), exp);

        // normalise every row to sum to one
        ublas::vector<double> sum_p = cppR::rowSums(probs);
        matrix<double> sum_p_rep =
          createMatrix(sum_p,probs.size1(),probs.size2());
        probs = compDiv(probs, sum_p_rep);

        // copy to probabilities member variable, one row per sample
        this->probabilities.resize(probs.size1());
        for(unsigned i=0;i<probs.size1();i++)
          this->probabilities[i] =
            asStdVector(ublas::vector<double>(row(probs,i)));
      }

    //----------------------------------------------------------------------
    //return predicted classes
    return predicted_classes;
  }



  //! serialize for saving to archive
  /*!
    Packs the full model state into a key -> SerializedObject map.  The keys
    written here are the same ones QDA::load() reads back.
  */
  map<string, SerializedObject> QDA::save() const
  {
    map<string, SerializedObject> archive;

    // training status
    archive["trained"]         = serialize(trained);

    // fitted Gaussian parameters
    archive["priors"]          = serialize(priors);
    archive["means"]           = serialize(means);
    archive["covariances"]     = serialize(covariances);
    archive["covariancesInv"]  = serialize(covariancesInv);
    archive["covariancesDet"]  = serialize(covariancesDet);

    // class bookkeeping
    archive["classes"]         = serialize(classes);
    archive["trained_classes"] = serialize(trained_classes);
    archive["using_classes"]   = serialize(using_classes);

    // lag settings
    archive["using_lags"]      = serialize(using_lags);
    archive["trained_lags"]    = serialize(trained_lags);

    return archive;
  }

  //! serialize class to load from archive
  /*!
    Restores the model state written by QDA::save(), reading the same keys.
    objects is taken by value, so operator[] lookups (which insert a
    default SerializedObject when a key is missing) only modify this local
    copy, never the caller's map.
    NOTE(review): how deserialize() handles such a default entry (e.g. an
    archive from an older version lacking a key) is not visible here.
  */
  void QDA::load(map<string, SerializedObject> objects)
  {
    // training status
    deserialize(objects["trained"],trained);
    // fitted Gaussian parameters
    deserialize(objects["priors"],priors);
    deserialize(objects["means"],means);
    deserialize(objects["covariances"],covariances);
    deserialize(objects["covariancesInv"],covariancesInv);
    deserialize(objects["covariancesDet"],covariancesDet);
    // class bookkeeping
    deserialize(objects["classes"],classes);
    deserialize(objects["trained_classes"],trained_classes);
    deserialize(objects["using_classes"],using_classes);
    // lag settings
    deserialize(objects["using_lags"],using_lags);
    deserialize(objects["trained_lags"],trained_lags);
  }



  //! get list of parameters needed for classifier
  std::map<std::string, CEBL::Param> QDA::getParamsList()
  {
    std::map<std::string, CEBL::Param> params;
    CEBL::Param probs("Compute Probabilities",
                    "Should QDA compute probabilities when you use the classifier?",
                    compute_probs);
    params["probs"] = probs;

    return params;
  }

  //! set list of parameters
  /*!
    Reads the "probs" entry (as produced by getParamsList()) into
    compute_probs.  The entry is looked up with find() rather than
    operator[], so a missing key neither inserts a default-constructed
    Param into the caller's map nor overwrites compute_probs with that
    default's value; the previous setting is kept and a message is printed.
  */
  void QDA::setParamsList( std::map<std::string, CEBL::Param> &p)
  {
    std::map<std::string, CEBL::Param>::iterator probs = p.find("probs");
    if(probs == p.end())
      {
        cerr << "QDA setParamsList: no \"probs\" parameter given; compute_probs left unchanged.\n";
        return;
      }
    compute_probs = probs->second.getBool();
  }
}



/*************************************************************/
//DYNAMIC LOADING

//! C-linkage factory for the plugin loader: allocates a new QDA classifier.
//! The caller owns the returned object and should release it with
//! ObjectDestroy().
extern "C" CEBL::Classifier* ObjectCreate()
{
  CEBL::Classifier* classifier = new CEBL::QDA;
  return classifier;
}

//! C-linkage counterpart to ObjectCreate(): frees a classifier it allocated.
//! A null pointer is accepted and ignored.
//! NOTE(review): deleting through CEBL::Classifier* relies on Classifier
//! having a virtual destructor — confirm in its header.
extern "C" void ObjectDestroy(CEBL::Classifier* p)
{
  if(!p)
    return;
  delete p;
}
