/*! SessionManager.cpp
 * \author Jeshua Bratman
 *
 * Manages saving and loading of the CEBL model configuration.
 */

#include "../CEBLModel.hpp"
#include "Serialization.hpp"
#include "Session.hpp"
#include "SessionManager.hpp"
#include "../TextUtils.hpp"

// Model internal classes.
//  SessionManager has to be a friend to each of these classes.
#include "DataSource.hpp"
#include "Training.hpp"
#include "ClassifiersConfig.hpp"
#include "FeaturesConfig.hpp"
#include "FilterConfig.hpp"
#include "DecisionConfig.hpp"


//----------------------------------------------------------------------
// CONSTRUCTORS / DESTRUCTORS


/*! Create a manager bound to \a model with a fresh, empty session.
 *
 * The model pointer is stored but not owned. The session object is
 * allocated here and released in the destructor.
 * NOTE(review): because this class owns a raw Session pointer, copying a
 * SessionManager would lead to a double delete unless copying is disabled
 * in the header -- confirm.
 */
SessionManager::SessionManager(CEBLModel *model)
  : model(model),
    current_session(new Session())
{
}

/*! Release the session object allocated by the constructor.
 *  The model is not owned and is therefore left alone.
 */
SessionManager::~SessionManager()
{
  delete this->current_session;
}


//----------------------------------------------------------------------
// PUBLIC METHODS

void SessionManager::save()
{
  if(shouldSaveAs())
    {
      throw FileException("Session has not yet been saved.");
    }
  else
    {
      this->updateSession();
      current_session->save();
    }
}

/*! Snapshot the model into the session and write it to \a filename.
 *  \param filename destination path handed to Session::save().
 */
void SessionManager::saveAs(string filename)
{
  const char *path = filename.c_str();
  updateSession();
  current_session->save(path);
}

/*! Read the session stored in \a filename and apply it to the model.
 *  \param filename source path handed to Session::load().
 *  If Session::load() throws, the model is left untouched.
 */
void SessionManager::load(string filename)
{
  const char *path = filename.c_str();
  current_session->load(path);
  updateModel();
}

bool SessionManager::shouldSaveAs()
{
  return current_session->shouldSaveAs();
}

/*! Turn an arbitrary name into a session key by replacing every space
 *  with an underscore.
 *
 * The mapping is not injective: "a b" and "a_b" produce the same key.
 */
string SessionManager::encodeKey(string str)
{
  string::size_type pos = str.find(' ');
  while(pos != string::npos)
    {
      str[pos] = '_';
      pos = str.find(' ', pos + 1);
    }
  return str;
}

/*! Reverse of encodeKey(): replace every underscore with a space.
 *
 * This is not an exact inverse. Underscores that were present in the
 * original name also become spaces.
 */
string SessionManager::decodeKey(string str)
{
  string::size_type pos = str.find('_');
  while(pos != string::npos)
    {
      str[pos] = ' ';
      pos = str.find('_', pos + 1);
    }
  return str;
}


//----------------------------------------------------------------------
// UPDATE MODEL AND SESSION



void SessionManager::updateModel()
{
  using namespace TextUtils;

  Session &s = *current_session;
  CEBLModel &m = *model;
  string temp;

  //--------------------------------------------------
  //data source
  s.setCurrentSection("data_source");
  if(s.exists("source"))
    {
      int temp = s.get<int>("source");
      m.dataSetSource(temp);
    }
  if(s.exists("stored_buffer"))
    {
      EEGData temp = s.get<ublas::matrix<double> >("stored_buffer");
      m.getDataSource()->setDataBuffer(temp);
    }

  //--------------------------------------------------
  //process
  s.setCurrentSection("data_process");
  if(s.exists("reference_enabled"))
    {
      bool temp = s.get<bool>("reference_enabled");
      m.processSetReferenceEnabled(temp);
    }
  if(s.exists("remove_enabled"))
    {
      bool temp = s.get<bool>("remove_enabled");
      m.processSetRemoveEnabled(temp);
    }
  if(s.exists("filter_enabled"))
    {
      bool temp = s.get<bool>("filter_enabled");
      m.processSetFilterEnabled(temp);
    }

  //--------------------------------------------------
  //channels
  s.setCurrentSection("channels");
  if(s.exists("configuration_string"))
    {
      string temp = s.get<string>("configuration_string");
      m.channelsSetConfigurationFromString(temp);
    }

  //--------------------------------------------------
  //training
  s.setCurrentSection("training");
  if(s.exists("training_data_is_loaded"))
    {
      bool temp = s.get<bool>("training_data_is_loaded");
      m.getTraining()->setDataIsLoaded(temp);
      if(temp)
        {
          if(s.exists("training_data"))
            {
              EEGTrainingData temp = s.get<EEGTrainingData>("training_data");
              m.getTraining()->setTrainingData(temp);
            }
        }
    }
  if(s.exists("num_classes"))
    {
      int temp = s.get<int>("num_classes");
      m.trainingSetNumClasses(temp);
    }
  if(s.exists("num_sequences"))
    {
      int temp = s.get<int>("num_sequences");
      m.trainingSetNumSequences(temp);
    }
  if(s.exists("sequence_length"))
    {
      int temp = s.get<int>("sequence_length");
      m.trainingSetSequenceLength(temp);
    }
  if(s.exists("pause_length"))
    {
      int temp = s.get<int>("pause_length");
      m.trainingSetPauseLength(temp);
    }

  if(s.exists("class_labels"))
    {
      std::vector<string> temp = s.get<std::vector<string> >("class_labels");
      m.trainingSetClassLabels(temp);
    }

  //--------------------------------------------------
  //features
  {
    s.setCurrentSection("features");
    if(s.exists("selected"))
      {
        string temp = s.get<string>("selected");
        m.featuresSetSelected(temp);
      }

    FeaturesConfig * cc = m.getFeaturesConfig();
    PluginLoader<Feature> *plugins = cc->getPluginLoader();
    //now load each plugin
    if(s.exists("names"))
      {
        vector<string> names;
        s.get("names",&names);
        for(unsigned i=0;i<names.size();i++)
          {
            Feature * p = plugins->getPlugin(names[i]);
            if(p != NULL)
              p->load(s.get<map<string, SerializedObject> >(encodeKey(names[i])
                                                            +"_internals"));
              
          }
      }
  }
  //--------------------------------------------------
  //classifiers
  {
    s.setCurrentSection("classifiers");
    if(s.exists("selected"))
      {
        string temp = s.get<string>("selected");
        m.classifiersSetSelected(temp);
      }

    ClassifiersConfig * cc = m.getClassifiersConfig();
    PluginLoader<Classifier> *plugins = cc->getPluginLoader();
    //now load each plugin
    if(s.exists("names"))
      {
        vector<string> names;
        s.get("names",&names);
        for(unsigned i=0;i<names.size();i++)
          {
            Classifier * p = plugins->getPlugin(names[i]);
            if(p != NULL)
              p->load(s.get<map<string, SerializedObject> >(encodeKey(names[i])
                                                            +"_internals"));
          }
      }
  }
  //--------------------------------------------------
  //decisions
  {
    s.setCurrentSection("decision");
    if(s.exists("selected"))
      {
        string temp = s.get<string>("selected");
        m.decisionSetSelected(temp);
      }

    DecisionConfig * cc = m.getDecisionConfig();
    PluginLoader<Decision> *plugins = cc->getPluginLoader();
    //now load each plugin
    if(s.exists("names"))
      {
        vector<string> names;
        s.get("names",&names);
        for(unsigned i=0;i<names.size();i++)
          {
            Decision * p = plugins->getPlugin(names[i]);
            if(p != NULL)
              p->load(s.get<map<string, SerializedObject> >(encodeKey(names[i])
                                                            +"_internals"));
          }
      }
  }
  //--------------------------------------------------
  //filters
  {
    s.setCurrentSection("filter");
    if(s.exists("selected"))
      {
        string temp = s.get<string>("selected");
        m.filterSetSelected(temp);
      }
    if(s.exists("lags"))
      {
        int temp = s.get<int>("lags");
        m.filterSetNumLags(temp);
      }
    if(s.exists("components"))
      {
        string temp = s.get<string>("components");
        m.filterSetSelectedComponentsString(temp);
      }


    FilterConfig * cc = m.getFilterConfig();
    PluginLoader<Filter> *plugins = cc->getPluginLoader();
    //now load each plugin
    if(s.exists("names"))
      {
        vector<string> names;
        s.get("names",&names);
        for(unsigned i=0;i<names.size();i++)
          {
            Filter * p = plugins->getPlugin(names[i]);
            if(p != NULL)
              p->load(s.get<map<string, SerializedObject> >(encodeKey(names[i])
                                                            +"_internals"));
          }
      }
  }
}

//------------------------------------------------------------------------------

void SessionManager::updateSession()
{
  Session &s = *current_session;
  CEBLModel &m = *model;

  //--------------------------------------------------
  //data source
  s.setCurrentSection("data_source");
  s ("source", m.dataGetSource())
    ("stored_buffer",m.dataGetStoredData().getMatrix());

  //--------------------------------------------------
  //process
  s.setCurrentSection("data_process");
  s ("reference_enabled", m.processGetReferenceEnabled())
    ("remove_enabled", m.processGetRemoveEnabled())
    ("filter_enabled", m.processGetFilterEnabled());

  //--------------------------------------------------
  //channels
  s.setCurrentSection("channels");
  s ("configuration_string", m.channelsGetConfigurationString());

  //--------------------------------------------------
  //device stream
  s.setCurrentSection("device");
  s ("location", m.deviceGetLocation())
    ("sample_rate",m.deviceGetSampleRate())
    ("block_size",m.deviceGetBlockSize());

  //--------------------------------------------------
  //file stream
  s.setCurrentSection("file_data_stream");
  s ("filename", m.fileStreamGetFilename());

  //--------------------------------------------------
  //training
  s.setCurrentSection("training");
  s ("training_data_is_loaded", m.trainingDataIsLoaded())
    ("training_data", m.trainingGetData())
    ("num_classes", m.trainingGetNumClasses())
    ("num_sequences", m.trainingGetNumSequences())
    ("sequence_length", m.trainingGetSequenceLength())
    ("pause_length", m.trainingGetPauseLength())
    ("class_labels", m.trainingGetClassLabels());

  //--------------------------------------------------
  //features
  {
    FeaturesConfig * cc = m.getFeaturesConfig();
    PluginLoader<Feature> *plugins = cc->getPluginLoader();
    vector<string> names = plugins->getNames();
    s.setCurrentSection("features");
    s ("selected", m.featuresGetSelected())
      ("names",names);
    //now save each plugin
    for(unsigned i=0;i<names.size();i++)
      {
        s(encodeKey(names[i]) + "_internals",
          plugins->getPlugin(names[i])->save());
      }
  }
  //--------------------------------------------------
  //classifiers
  {
    ClassifiersConfig * cc = m.getClassifiersConfig();
    PluginLoader<Classifier> *plugins = cc->getPluginLoader();
    vector<string> names = plugins->getNames();
    s.setCurrentSection("classifiers");
    s ("selected", m.classifiersGetSelected())
      ("names",names);
    //now save each plugin
    for(unsigned i=0;i<names.size();i++)
      {
        s(encodeKey(names[i]) + "_internals",
          plugins->getPlugin(names[i])->save());
      }
  }

  //--------------------------------------------------
  //filter
  {
    FilterConfig * cc = m.getFilterConfig();
    PluginLoader<Filter> *plugins = cc->getPluginLoader();
    vector<string> names = plugins->getNames();
    s.setCurrentSection("filter");
    s ("selected", m.filterGetSelected())
      ("names",names)
      ("components",m.filterGetSelectedComponentsString())
      ("lags",m.filterGetNumLags());
    //now save each plugin
    for(unsigned i=0;i<names.size();i++)
      {
        s(encodeKey(names[i]) + "_internals",
          plugins->getPlugin(names[i])->save());
      }
  }

  //--------------------------------------------------
  //decision
  {
    DecisionConfig * cc = m.getDecisionConfig();
    PluginLoader<Decision> *plugins = cc->getPluginLoader();
    vector<string> names = plugins->getNames();
    s.setCurrentSection("decision");
    s ("selected", m.decisionGetSelected())
      ("names",names);
    //now save each plugin
    for(unsigned i=0;i<names.size();i++)
      {
        s(encodeKey(names[i]) + "_internals",
          plugins->getPlugin(names[i])->save());
      }
  }
}
