#include "MSPRT.hpp"
#include "cppR/cppR.hpp"

using namespace cppR;

namespace CEBL {

  //! Construct an MSPRT decision with default settings.
  /*!
    Defaults: threshold 0.8, gain g 0.01, 3 classes, plugin name "MSPRT".

    log_threshold is derived from the threshold here as well (the same
    way init() does). This means it is never read uninitialized, e.g. by
    decideClasses() or save(), before init() has been called.

    sums is not sized here; updateWithProbabilities() reports a size
    mismatch until init() has been called with the class count.
   */
  MSPRT::MSPRT()
  {
    this->threshold = 0.8;
    this->g = 0.01;
    this->num_classes = 3;
    this->plugin_name = "MSPRT";
    this->log_threshold = log(this->threshold);
  }

  //! Destructor.
  /*!
    Nothing is released explicitly here; the body is intentionally empty.
   */
  MSPRT::~MSPRT() { }



  //! get list of parameters needed for decision process
  std::map<std::string, CEBL::Param> MSPRT::getParamsList()
  {
    std::map<std::string, CEBL::Param> params;
    CEBL::Param thresh("Threshold", "", this->threshold);
    thresh.setStep(0.001);
    thresh.setMax(10);
    thresh.setMin(0);
    params["thresh"] = thresh;

    CEBL::Param g("g", "Gain parameter", this->g);
    g.setStep(0.001);
    g.setMax(1.0);
    g.setMin(0.0);
    params["g"] = g;

    return params;
  }


  //! Apply parameter values, typically those returned by getParamsList().
  /*!
    Reads the "thresh" and "g" entries of \a p. A key that is absent
    leaves the corresponding setting unchanged. Previously, operator[]
    default-inserted an empty Param into the caller's map and read its
    value.

    If either value actually changes, init(num_classes) is called. This
    refreshes the cached log threshold and resets the accumulated sums
    for the new settings.

    \param p map of parameter name -> Param (not modified by this call).
   */
  void MSPRT::setParamsList( std::map<std::string, CEBL::Param> &p)
  {
    const double old_threshold = this->threshold;
    const double old_g = this->g;

    std::map<std::string, CEBL::Param>::iterator it = p.find("thresh");
    if(it != p.end())
      {
        this->threshold = it->second.getDouble();
      }

    it = p.find("g");
    if(it != p.end())
      {
        this->g = it->second.getDouble();
      }

    if(old_g != this->g || old_threshold != this->threshold)
      {
        this->init(num_classes);
      }
  }


  void MSPRT::updateWithProbabilities(std::vector<double> probs)
  {
    if(sums.size() != probs.size())
      {
        cerr << "MSPRT: size of probability vectory doesn't seem right. Did you initialize the decision\n";
        return;
      }
    sums = sums + cppR::asUblasVector(probs);
    ublas::vector<double> y = g * sums;
    double logsumexp = log(sum(apply(y,exp)));
    this->log_probs = y - rep(logsumexp,y.size());
  }

  //! Report each class's progress towards the selection threshold.
  /*!
    Returns exp(log_probs[i]) / exp(log_threshold) for every class i. A
    value >= 1.0 means that class has reached the threshold. In that case
    init(num_classes) resets the accumulated evidence. The returned vector
    is computed before the reset, so it still shows the selection.

    If no log_probs have been computed yet (empty vector), an empty result
    is returned and no reset is attempted. This avoids calling max() on an
    empty vector.
   */
  std::vector<double> MSPRT::decideClasses()
  {
    if(log_probs.size() == 0)
      {
        return std::vector<double>();
      }

    ublas::vector<double> percents;
    double expthresh = 1.0 / exp(log_threshold);
    percents = apply(log_probs, exp) * expthresh;

    //reset if we have selected a class
    if(max(percents) >= 1.0)
      {
        init(num_classes);
      }
    return cppR::asStdVector(percents);
  }

  //! (Re)initialize the decision state for \a num_classes classes.
  /*!
    - caches log(threshold) in log_threshold;
    - stores the class count;
    - zeroes the accumulated sums;
    - sets every entry of log_probs to log(1/num_classes).

    The last value is exactly what updateWithProbabilities() computes
    from all-zero sums. Resetting log_probs together with sums keeps the
    two consistent. A decideClasses() call made after a reset but before
    the next update therefore no longer re-reports the class that was
    just selected from stale log_probs.

    \param num_classes number of classes; assumed positive (for 0 both
                       vectors end up empty).
   */
  void MSPRT::init(int num_classes)
  {
    this->log_threshold = log(this->threshold);
    this->num_classes = num_classes;
    this->sums = cppR::rep(0,num_classes);
    this->log_probs = cppR::rep(-log(static_cast<double>(num_classes)), num_classes);
  }

  //----------------------------------------------------------------------
  //SAVING and LOADING

  //! Serialize the decision state so it can be written to an archive.
  /*!
    Produces one entry per stored field, keyed "g", "threshold",
    "log_threshold", "log_probs" and "sums". load() reads the same keys.
   */
  map<string, SerializedObject> MSPRT::save() const
  {
    map<string, SerializedObject> archive;
    archive["g"]             = serialize(g);
    archive["threshold"]     = serialize(threshold);
    archive["log_threshold"] = serialize(log_threshold);
    archive["log_probs"]     = serialize(log_probs);
    archive["sums"]          = serialize(sums);
    return archive;
  }

  //! Deserialize the decision state from an archive produced by save().
  /*!
    Restores sums, log_probs, log_threshold, threshold and g.

    num_classes is not stored in the archive, so it is derived from the
    restored sums. Otherwise a later init(num_classes), triggered by a
    decision in decideClasses() or a parameter change in setParamsList(),
    would resize sums to the stale pre-load class count. Every subsequent
    updateWithProbabilities() would then fail its size check. If the
    archive's sums is empty (saved before init()), num_classes is left
    unchanged.
   */
  void MSPRT::load(map<string, SerializedObject> objects)
  {
    deserialize(objects["sums"],sums);
    deserialize(objects["log_probs"],log_probs);
    deserialize(objects["log_threshold"],log_threshold);
    deserialize(objects["threshold"],threshold);
    deserialize(objects["g"],g);

    if(sums.size() > 0)
      {
        num_classes = static_cast<int>(sums.size());
      }
  }

}



/*************************************************************/
//DYNAMIC LOADING

//! Plugin factory: allocate a new MSPRT decision with default settings.
//! The caller owns the returned object and releases it with ObjectDestroy().
extern "C" CEBL::Decision* ObjectCreate()
{
  CEBL::Decision* instance = new CEBL::MSPRT();
  return instance;
}

//! Plugin destructor: delete a decision created by ObjectCreate().
//! NOTE(review): deleting through CEBL::Decision* relies on Decision having
//! a virtual destructor so that ~MSPRT runs -- confirm in the Decision header.
extern "C" void ObjectDestroy(CEBL::Decision* p)
{
  if(p != 0)
    {
      delete p;
    }
}

