/*! Training.cpp
 * \author Jeshua Bratman
 *
 * Holds the training options (number of classes, number and length of
 * sequences, pause length) and collects training data from the data source.
 */


#include "Training.hpp"
#include "DataProcess.hpp"
#include "FilterConfig.hpp"
#include "../CEBLModel.hpp"
#include "../cppR/cppR.hpp"
#include "DataIO.hpp"
#include "../TextUtils.hpp"
using namespace cppR;

//----------------------------------------------------------------------
// CONSTRUCTORS / DESTRUCTORS


Training::Training(CEBLModel *model)
{
  // the parameter shadows the member, so the member must be qualified here
  this->model = model;

  // session layout; lengths are in seconds (timeoutFunction() scales by 1000)
  num_sequences   = 3;
  sequence_length = 5;
  pause_length    = 1;

  // no real-time feedback and no trained classifier yet
  classification_feedback = false;
  classifier_trained      = false;

  // empty raw and filtered data sets; the filtered one is flagged as such
  data_is_loaded   = false;
  data_file_loaded = false;
  training_data.clear();
  training_data_filtered.clear();
  training_data_filtered.setFiltered(true);

  // idle training state machine
  training_is_active        = false;
  current_training_class    = 0;
  current_training_sequence = 0;
  timeout_length            = 100; // ms
  waiting                   = false;
  training_failed           = false;
  training_index            = -1;  // -1: timeoutFunction() initializes on its first tick

  // three classes by default (this also creates their default labels)
  setNumClasses(3);
}

Training::~Training()
{
  // nothing to release explicitly here
}


//----------------------------------------------------------------------

//GETTING OPERATIONS


std::vector<string> Training::getClassLabels()
{
  return class_labels;
}

string Training::getClassLabel(int class_num)
{
  // Label of class `class_num`.
  // Bounds-checked: class_labels may be shorter than num_classes after
  // setClassLabels(), so an unchecked [] could read past the end (UB).
  // Throws std::out_of_range for an invalid index, including negatives.
  return class_labels.at(class_num);
}

int Training::getNumClasses()
{
  return num_classes;
}

int Training::getNumSequences()
{
  return num_sequences;
}

int Training::getSequenceLength()
{
  return sequence_length;
}

int Training::getPauseLength()
{
  return pause_length;
}

EEGTrainingData Training::getData()
{
  // a copy of the raw (unfiltered) training data
  return this->training_data;
}

bool Training::dataIsLoaded()
{
  return data_is_loaded;
}

bool Training::isDataFileLoaded()
{
  return data_file_loaded;
}

string Training::getDataFilename()
{
  // path passed to the last successful loadData() call
  return this->data_filename;
}

bool Training::isActive()
{
  return training_is_active;
}

bool Training::failed()
{
  return training_failed;
}

string Training::getFailureMessage()
{
  // reason recorded by the most recent stopFailure()
  return this->failure_message;
}

bool Training::isPaused()
{
  return waiting;
}

int Training::getTrainingClass()
{
  return current_training_class;
}

int Training::getTrainingSequence()
{
  return current_training_sequence;
}

//--------------------------------------------------------------------------------

//SETTING OPERATIONS




void Training::setNumClasses(int n)
{
  num_classes = n;

  // Give every class that has no label yet a default "Class <i>" label.
  // Existing labels are kept and surplus labels are never removed; the
  // signed comparison means a negative n adds nothing.
  for(int idx = class_labels.size(); idx < num_classes; ++idx)
    {
      class_labels.push_back("Class " + TextUtils::IntToString(idx));
    }
}

void Training::setNumSequences(int n)
{
  this->num_sequences = n;
}

void Training::setSequenceLength(int n)
{
  this->sequence_length = n;
}

void Training::setPauseLength(int n)
{
  this->pause_length = n;
}

void Training::setClassLabels(std::vector<string> labels)
{
  this->class_labels = labels;
}

void Training::setClassLabel(int class_number, string label)
{
  // Rename one existing class label.
  // Uses at() so an invalid index (negative, or beyond the current label
  // list) throws std::out_of_range instead of writing past the end of
  // class_labels. Use setNumClasses() to create labels for new classes.
  this->class_labels.at(class_number) = label;
}

void Training::loadData(string filename)
{
  // Load raw training data from `filename` and make it the current data set.
  // The file is read first, so if DataIO throws, the existing state is left
  // untouched.
  this->training_data = DataIO::loadTrainingDataFromFile(filename);

  // Any filtered data still held belongs to the previous session, not to
  // the file just loaded. Drop it so saveData() cannot write it out next to
  // unrelated raw data (saveData saves the filtered set whenever the class
  // counts happen to match).
  this->training_data_filtered.clear();

  this->setDataIsLoaded(true);
  this->data_file_loaded = true;
  this->data_filename = filename;
}

void Training::clearData()
{
  this->training_data.clear();
  this->training_data_filtered.clear();
}

void Training::saveData(string filename)
{
  // The filtered data is only saved when it covers as many classes as the
  // raw data; otherwise just the raw data is written.
  const bool save_session =
    training_data_filtered.numClasses() == training_data.numClasses();

  if(!save_session)
    {
      DataIO::saveTrainingDataToFile(training_data, filename);
      return;
    }

  // Full session: raw + filtered data plus the current filter settings
  // (lags, selected components, filter matrix) from the model.
  int lags = model->filterGetNumLags();
  std::vector<int> components = model->filterGetSelectedComponents();
  ublas::matrix<double> matrix = model->filterGetFilterMatrix();
  DataIO::saveTrainingSessionToFile(training_data,
                                    filename,
                                    training_data_filtered,
                                    lags,
                                    components,
                                    matrix);
}

//--------------------------------------------------------------------------------
// TRAINING PROCESS


void Training::start()
{
  if(training_is_active && !classification_feedback)
    {
      return;
    }
  //check source to see if it is ready
  try
    {
      this->model->dataStart();
      if(!model->dataIsStarted())
        {
          throw TrainingException("Data source is not ready.");
        }
      this->model->dataStop();
    }
  catch(...)
    {
      throw TrainingException("Data source is not ready.");
    }

  this->initializeTraining();

  //start the thread
  timeoutStart();
  cout << "starting timeout\n";
}

void Training::initializeTraining()
{
  this->training_is_active = true;
  this->waiting = false;
  this->training_failed = false;
  this->training_index = -1;
  this->training_data.clear();
  this->training_data.reserve(num_classes,num_sequences);
  this->training_data.setChannelNames(this->model->channelsGetEnabledNames());
  this->training_data.setClassLabels(this->class_labels);
  this->training_data_filtered.clear();

  //create randomized list of classes to train on
  ublas::vector<int> ord = sample(rep(vectorRange(0,num_classes-1),
                                      num_sequences));
  training_class_ordering = asStdVector(ord);
  current_class_sequence = asStdVector(rep(0,num_classes));

  //  printVector(training_class_ordering);
  this->training_data.setSequenceOrder(training_class_ordering);

}

void Training::stop()
{
  if(isStarted())
    {
      // Let the timeout thread finish, then record this as a manual stop.
      haltAndJoin();
      stopFailure("Stopped manually.");
      return;
    }

  // no thread running: just mark training inactive
  training_is_active = false;
}

//stopping functions
void Training::stopSuccess()
{
  // End the run cleanly. The recorded data becomes the loaded data set, any
  // earlier classifier is considered stale, and `halt` tells
  // timeoutFunction() to stop working.
  training_failed    = false;
  training_is_active = false;
  data_is_loaded     = true;
  classifier_trained = false;
  halt               = true;
}

void Training::stopFailure(string msg)
{
  // Only an active run can fail. Once stopped, later calls are no-ops, so
  // the first recorded message wins.
  if(!training_is_active)
    return;

  training_failed    = true;
  training_is_active = false;
  data_is_loaded     = false;
  failure_message    = msg;
  halt               = true;
  classifier_trained = false;
  cout << "Training stopped: " << msg << "\n";
}


void Training::timeoutFunction()
{
  // One tick of the training state machine. Each tick either waits out the
  // pause before a sequence, or records/stores data for the current sequence
  // and advances once the sequence length has elapsed. Once `halt` is set
  // (by stopSuccess/stopFailure), the data source is stopped.

  //phase lengths are configured in seconds; the timer is compared in ms
  int wait_time = 1000 * pause_length;
  int train_time = 1000 * sequence_length;

  if(!this->halt)
    {

      //initialize on the first tick of a (re)started run
      if(training_index < 0)
        {
          training_index = 0;
          training_timer.restart();
          waiting = true;
        }
      //set the current training class
      if(unsigned(training_index) >= training_class_ordering.size())
        {
          stopFailure("Training index out of range.");
          return;
        }
      current_training_class = training_class_ordering.at(training_index);
      current_training_sequence =
        current_class_sequence[current_training_class];
      //if we are waiting, don't collect data, but wait
      if(waiting)
        {
          if(training_timer.elapsed() > wait_time)
            {
              //start the data source and begin collecting
              waiting = false;
              //set current class etc and start recording
              cout << "* collecting for class " << current_training_class << " sequence " << current_class_sequence[current_training_class] <<"\n";

              //try starting the data source
              try
                {
                  this->model->dataStart();
                }
              catch(exception & e)
                {
                  stopFailure(e.what());
                }

              //reset the timer
              training_timer.restart();
            }
        }
      //we are not waiting, now we are collecting
      else
        {
          //save the recorded data
          try
            {
              EEGData new_data = this->model->dataReadAllRaw();
              EEGData unfiltered = this->model->getDataProcess()->process(new_data,true,true,false);
              bool filter_enabled = this->model->getDataProcess()->getFilterEnabled();
              bool filter_trained =  this->model->getFilterConfig()->isTrained();
              int cls = current_training_class;
              int seq = current_class_sequence[cls];
              //save unfiltered data
              training_data.append(cls, seq, unfiltered);

              //also save the filtered version if filter is enabled and trained
              EEGData filtered;
              if(filter_enabled && filter_trained)
                {
                  filtered = this->model->getDataProcess()->process(new_data,true,true,true);
                  training_data_filtered.append(cls, seq, filtered);
                }

              //if we are providing classification feedback, classify these
              //new samples using the previously trained classifier
              if(classification_feedback && classifier_trained)
                {
                  if(filter_enabled && filter_trained)
                    this->classifySamples(filtered);
                  else
                    this->classifySamples(unfiltered);
                }
            }
          catch(exception & e)
            {
              stopFailure(e.what());
            }

          //see if it is time to stop recording. Skip this if the run was
          //just stopped (read/filter/classify failure above): otherwise
          //stopSuccess() below would overwrite the failure, or
          //trainClassifierAndContinue() would re-activate training while
          //halted.
          if(!this->halt && training_timer.elapsed() > train_time)
            {
              training_timer.restart();
              cout << "* done collecting for class " << current_training_class <<  " sequence " << current_class_sequence[current_training_class] << "\n";

              //try stopping the data source
              try
                {
                  this->model->dataStop();
                }
              catch(exception & e)
                {
                  stopFailure(e.what());
                }

              //a failure to stop the source ends the run as failed; don't
              //advance or report success on top of it
              if(!this->halt)
                {
                  //we are done
                  if(unsigned(training_index) >= training_class_ordering.size()-1)
                    {
                      //if we are just training. finish.
                      if(!classification_feedback)
                        {
                          stopSuccess();
                          cout << "* training complete\n";
                        }
                      //if we want feedback, we need to now train the classifier
                      //and do it again
                      else
                        {
                          trainClassifierAndContinue();
                        }
                    }
                  else
                    {
                      //increment class index
                      training_index++;
                      current_class_sequence[current_training_class]++;

                      //go back to waiting
                      waiting = true;
                      current_training_class = training_class_ordering.at(training_index);
                      training_timer.restart();
                    }
                }
            }
        }
    }

  if(this->halt)
    {
      //make sure the data source is not left running. Every other dataStop()
      //call here is guarded; don't let an exception escape this function,
      //which presumably runs on the thread started by timeoutStart().
      try
        {
          this->model->dataStop();
        }
      catch(exception & e)
        {
          cerr << "Failed to stop data source: " << e.what() << "\n";
        }
      return;
    }
}


//------------------------------------------------------------------------------
// TRAINING FEEDBACK

bool Training::isTrainingClassifier()
{
  // delegates to the model's real-time classifier state
  const bool busy = model->realtimeIsTrainingClassifier();
  return busy;
}

void Training::classifySamples(EEGData samples)
{
  // Classify a batch of newly recorded samples with the trained classifier
  // and refresh class_proportions from the model's decision state.
  // Feature-extraction or classification failures are logged to cerr and
  // end the run via stopFailure(). Exceptions from the decision-update calls
  // are not caught here and propagate to the caller.

  //extract features
  EEGData features;
  try
    {
      features = model->featuresExtract(samples);
    }
  catch(exception &e)
    {
      cerr << "Caught exception when extracting features: " << e.what() << "\n";
      stopFailure("Failed to extract features.");
      return;
    }
  catch(...)
    {
      cerr << "Caught exception when extracting features.\n";
      stopFailure("Failed to extract features.");
      return;
    }

  //classify
  ublas::vector<int> classes;
  try
    {
      classes = model->classifierUse(features);
    }
  catch(exception &e)
    {
      cerr << "Caught exception when classifying: " << e.what() << "\n";
      stopFailure("Failed to classify.");
      return;
    }
  catch(...)
    {
      cerr << "Caught exception when classifying: No Message\n";
      stopFailure("Failed to classify.");
      return;
    }

  //----------------------------------------
  //update the decision: use the classifier's last probabilities when it is
  //configured to provide them, otherwise the hard class labels
  if(model->classifierGetUseProbs())
    {
      model->decisionUpdateWithProbabilities(model->classifierGetLastProbs());
    }
  else
    {
      model->decisionUpdateWithClassification(classes);
    }
  //get decision proportions
  this->class_proportions = model->decisionDecideClasses();
}


void Training::trainClassifierAndContinue()
{
  // Called at the end of a feedback run. Trains the classifier on the data
  // just recorded, waits for that training to finish, and on success starts
  // a new run with classifier feedback enabled. On failure the run ends via
  // stopFailure().
  //
  // The former "no training data" check tested data_is_loaded right after
  // setting it to true, so it could never fire and has been removed.
  cout << "training classifier\n";
  training_failed = false;
  //the sequences just recorded are now the current training data
  data_is_loaded = true;
  classifier_trained = false;

  //train classifier on most recent training data
  cout << training_data << "\n";
  model->realtimeTrainClassifierThreaded();
  //block until the model reports that classifier training has finished
  while(model->realtimeIsTrainingClassifier())
    {
      this->sleep(1000);
    }

  if(!model->realtimeIsReady() || model->realtimeLastTrainFailed())
    {
      stopFailure("Failed to train classifier.");
      return;
    }

  //it succeeded: reset the decision state and start training again
  model->decisionInit(num_classes);
  classifier_trained = true;
  initializeTraining();
}
