/*! RealTimeClassification.cpp
 * \author Jeshua Bratman
 *
 * Controls real-time classifying.
 */

#include "RealTimeClassification.hpp"
#include <boost/thread/thread.hpp>
#include <boost/thread/mutex.hpp>
#include <boost/bind.hpp>
#include "../CEBLModel.hpp"
#include "cppR/cppR.hpp"
//----------------------------------------------------------------------
// CONSTRUCTORS / DESTRUCTORS


//! Construct a real-time classification controller bound to a model.
/*!
  \param model the CEBL model used for data, processing, features,
         classifier and decision access. The pointer is stored, not
         deleted by this class.

  All state flags start cleared: not classifying, not training, no
  failed training, no halt request, and no selected class (-1).
*/
RealTimeClassification::RealTimeClassification(CEBLModel *model)
{
  this->model = model;
  this->is_classifying = false;
  this->timeout_length = 100;
  this->currently_training_classifier = false;
  // previously left uninitialized until the first trainClassifier() call
  this->halt_training = false;
  this->train_classifier_thread = NULL;
  this->selected_class = -1;
  this->training_failed = false;
}

//! Destructor.
/*!
  Releases nothing: the model pointer is not owned by this object.
  NOTE(review): a training thread started by trainClassifierThreaded()
  is neither halted nor joined here. Destroying this object while that
  thread runs would leave it using a dangling pointer.
*/
RealTimeClassification::~RealTimeClassification()
{
  // intentionally empty
}

//----------------------------------------------------------------------
//GETTING OPERATIONS
//! Report whether real-time classification may be started.
/*!
  \return the model's classifierIsTrained() result.
*/
bool RealTimeClassification::isReady() const
{
  const bool trained = model->classifierIsTrained();
  return trained;
}
//! check whether data is currently being classified
bool RealTimeClassification::isClassifying() const
{
  return is_classifying;
}
//! reads queue of classified sequences and clears the queue
std::vector<int> RealTimeClassification::readClassificationQueue()
{
  std::vector<int> ret = classification_queue;
  clearClassificationQueue();
  return ret;
}

//! reads queue of classified sequences without clearing the queue
//! Return a copy of the queue of classified sequences without clearing it.
/*!
  NOTE(review): unlike clearClassificationQueue(), this does not take
  thread_lock. It can therefore race with timeoutFunction() appending
  results. Locking here would require thread_lock to be declared mutable.

  \return a snapshot of the queued classes, oldest first.
*/
std::vector<int> RealTimeClassification::peekClassificationQueue() const
{
  std::vector<int> snapshot(classification_queue);
  return snapshot;
}


//----------------------------------------------------------------------
//SETTING OPERATIONS
//! Discard every queued classification.
/*!
  Holds thread_lock while emptying the queue, so it cannot interleave
  with timeoutFunction() appending results.
*/
void RealTimeClassification::clearClassificationQueue()
{
  boost::mutex::scoped_lock guard(thread_lock);
  classification_queue.clear();
}


//----------------------------------------------------------------------
//CONTROL CLASSIFICATION
void RealTimeClassification::trainClassifier()
{
  this->training_failed = false;
  this->currently_training_classifier = true;
  this->halt_training = false;
  //--------------------------------------------------
  // process and featurize training data
  EEGTrainingData training_data = this->model->trainingGetData();
  EEGTrainingData processed_data;
  EEGData temp;
  processed_data.reserve(training_data.numClasses(),training_data.numSequences());
  for(int cls = 0; cls < training_data.numClasses(); cls++)
    {
      for(int seq = 0; seq < training_data.numSequences(cls); seq++)
        {
          //halt the training of classifier
          if(halt_training)
            {
              this->currently_training_classifier = false;
              this->training_failed = true;
              return;
            }
          //process the data
          temp = training_data.get(cls,seq);
          temp = model->processData(temp);
          //extract features
          try
            {
              model->featureReset();
              cout << "Featurizing class " << cls << ", seq " << seq << "\n";
              temp = model->featuresExtract(temp);
            }
          catch(exception &e)
            {
              this->training_failed = true;
              cerr << e.what() << ". Make sure you have installed the most recent feature plugins.\n";
              this->currently_training_classifier = false;
              return;
            }
          //add to new training data
          processed_data.set(cls,seq,temp);
          //halt the training of classifier
          if(halt_training)
            {
              this->currently_training_classifier = false;
              this->training_failed = true;
              return;
            }
        }
    }

  //--------------------------------------------------
  // train classifier on features
  try
    {
      this->model->classifierTrain(processed_data);
    }
  catch(exception &e)
    {
      cerr << e.what() << "\n";
      this->currently_training_classifier = false;
      this->training_failed = true;
      return;
    }
  this->currently_training_classifier = false;

}

//! Begin real-time classification.
/*!
  Starts the model's data stream, initializes the decision maker with the
  number of training classes, and starts the timeout that drives
  timeoutFunction().

  \throws ClassificationException if the classifier is not trained, or if
          the timeout could not be started. In the latter case the data
          stream started here is stopped again (best effort) before
          throwing, so a failed start does not leave acquisition running.
*/
void RealTimeClassification::startClassifying()
{
  if(isReady())
    {
      this->model->dataStart();
      //init the decision maker
      model->decisionInit(model->trainingGetNumClasses());
      try
        {
          this->timeoutStart();
          is_classifying = true;
        }
      catch(...)
        {
          // undo dataStart(); ignore cleanup errors so the original
          // failure is the one reported to the caller
          try
            {
              this->model->dataStop();
            }
          catch(...)
            {
            }
          throw ClassificationException("Failed to start timeout for real-time classification.");
        }
    }
  else
    {
      throw ClassificationException("Classifier is not ready.");
    }

}
void RealTimeClassification::stopClassifying()
{
  this->haltAndJoin();
  this->model->dataStop();
  this->is_classifying = false;
  this->clearClassificationQueue();
}

void RealTimeClassification::timeoutFunction()
{
  //read processed data
  EEGData new_data = model->dataReadAll();
  if(new_data.size1() == 0)
    return;

  //extract features from the data
  EEGData features;
  try
    {
      features = model->featuresExtract(new_data);
    }
  catch(...)
    {
      cerr << "Caught exception when extracting features.\n";
      halt = true;
      this->is_classifying = false;
      return;
    }
  //classify
  ublas::vector<int> classes;
  try
    {
      classes = model->classifierUse(features);
    }
  catch(exception &e)
    {
      cerr << "Caught exception when classifying: " << e.what() << "\n";
      halt = true;
      this->is_classifying = false;
      return;
    }
  catch(...)
    {
      cerr << "Caught exception when classifying: No Message\n";
      halt = true;
      this->is_classifying = false;
      return;
    }
  //add these classes to the queue
  {
    boost::mutex::scoped_lock lock(thread_lock);
    for(unsigned i=0;i<classes.size();i++)
      classification_queue.push_back(classes[i]);
  }

  //----------------------------------------
  //update proportions
  {
    using namespace cppR;

    // update decision
    if(model->classifierGetUseProbs())
      {  
        std::vector<std::vector<double> > probs 
          = model->classifierGetLastProbs();
        if(probs.size() > 0)
          {
            model->decisionUpdateWithProbabilities(probs);
          }
        else
          {
            cerr << "Realtime Classificaton: Probabilities are empty. Using predicted classes instead.\n";
            model->decisionUpdateWithClassification(classes);
          }
      }
    else
      {
        model->decisionUpdateWithClassification(classes);
      }

    // get decision proportions
    this->class_proportions = model->decisionDecideClasses();
    ublas::vector<double> props = 
      asUblasVector(this->class_proportions); 

    // if a proportion has reached 100%, set selected class
    if(max(props) >= 1.0)
      {
        this->selected_class = whichMax(props);
      }
  }
}



//----------------------------------------------------------------------
// Threaded Train Classifier

bool RealTimeClassification::isTrainingClassifier()
{
  return currently_training_classifier;
}

//! Worker-thread entry point used by trainClassifierThreaded().
/*!
  Runs obj->trainClassifier() and reports any exception on cerr. It then
  clears the training flag, deletes obj->train_classifier_thread and
  resets it to NULL.

  No exception may escape: Boost.Thread calls std::terminate() if a thread
  function propagates an exception other than boost::thread_interrupted.

  NOTE(review): the thread deletes its own boost::thread object. Also,
  trainClassifierThreaded() stores the pointer only after the thread is
  already running, so a very short training run may delete the old
  pointer value. A proper fix needs the thread handle owned and joined
  by RealTimeClassification.

  \param obj the controller whose classifier is trained.
*/
void runTrainClassifier(RealTimeClassification *obj)
{
  obj->currently_training_classifier = true;
  try
    {
      obj->trainClassifier();
    }
  catch(const exception &e)
    {
      cerr << "Exception caught when training classifier: " << e.what() << "\n";
    }
  catch(...)
    {
      cerr << "Exception caught when training classifier: No Message\n";
    }
  obj->currently_training_classifier = false;
  delete obj->train_classifier_thread;
  obj->train_classifier_thread = NULL;
  cout << "Classifier training stopped.\n";
}

void RealTimeClassification::trainClassifierThreaded()
{
  train_classifier_thread = new boost::thread(boost::bind(&runTrainClassifier, this));
}

//! Request that an in-progress threaded training run stop early.
/*!
  Only acts when training is in progress and a training thread exists.
  It then:
  - sets halt_training, which trainClassifier() checks between sequences;
  - asks the model's classifier and feature extractors to halt.
*/
void RealTimeClassification::trainClassifierHalt()
{
  // nothing to halt unless a threaded training run is active
  if(!currently_training_classifier || train_classifier_thread == NULL)
    return;

  halt_training = true;
  model->classifierHaltTrain();
  model->featuresHalt();
}


