/*
 * Decompiled with CFR 0.152.
 */
package LBJ2.learn;

import LBJ2.classify.Classifier;
import LBJ2.classify.DiscretePrimitiveStringFeature;
import LBJ2.classify.Feature;
import LBJ2.classify.FeatureVector;
import LBJ2.classify.ScoreSet;
import LBJ2.learn.Learner;
import LBJ2.learn.SparseAveragedPerceptron;
import LBJ2.util.ExceptionlessInputStream;
import LBJ2.util.ExceptionlessOutputStream;
import LBJ2.util.OVector;
import java.io.PrintStream;
import java.util.Arrays;

public class AdaBoost
extends Learner {
    public static final Learner defaultWeakLearner = new SparseAveragedPerceptron();
    public static final int defaultRounds = 10;
    protected Learner weakLearner;
    protected int rounds;
    protected Learner[] weakLearners;
    protected double[] alpha;
    protected OVector allExamples;
    protected String[] allowableValues;

    public AdaBoost() {
        this("");
    }

    public AdaBoost(Learner learner) {
        this("", learner);
    }

    public AdaBoost(int n) {
        this("", n);
    }

    public AdaBoost(Learner learner, int n) {
        this("", learner, n);
    }

    public AdaBoost(Parameters parameters) {
        this("", parameters);
    }

    public AdaBoost(String string) {
        this(string, new Parameters());
    }

    public AdaBoost(String string, Learner learner) {
        this(string, learner, 10);
    }

    public AdaBoost(String string, int n) {
        this(string, defaultWeakLearner, n);
    }

    public AdaBoost(String string, Learner learner, int n) {
        super(string);
        this.weakLearner = learner;
        this.rounds = n;
        this.allExamples = new OVector();
        this.allowableValues = new String[]{"*", "*"};
    }

    public AdaBoost(String string, Parameters parameters) {
        super(string);
        this.setParameters(parameters);
        this.allExamples = new OVector();
        this.allowableValues = new String[]{"*", "*"};
    }

    public void setParameters(Parameters parameters) {
        this.weakLearner = parameters.weakLearner;
        this.rounds = parameters.rounds;
    }

    public Learner.Parameters getParameters() {
        Parameters parameters = new Parameters(super.getParameters());
        parameters.weakLearner = this.weakLearner;
        parameters.rounds = this.rounds;
        return parameters;
    }

    public String[] allowableValues() {
        return this.allowableValues;
    }

    public void setLabeler(Classifier classifier) {
        if (classifier == null || classifier.allowableValues().length != 2) {
            System.err.println("Error: " + this.name + ": An LTU must be given a single binary label classifier.");
            new Exception().printStackTrace();
            System.exit(1);
        }
        super.setLabeler(classifier);
        this.allowableValues = classifier.allowableValues();
        this.labelLexicon.clear();
        this.labelLexicon.lookup(new DiscretePrimitiveStringFeature(classifier.containingPackage, classifier.name, "", this.allowableValues[0], 0, 2), true);
        this.labelLexicon.lookup(new DiscretePrimitiveStringFeature(classifier.containingPackage, classifier.name, "", this.allowableValues[1], 1, 2), true);
        this.createPrediction(0);
        this.createPrediction(1);
    }

    public void initialize(int n, int n2) {
        this.allExamples = new OVector(n);
    }

    public void learn(Object object) {
        this.allExamples.add(this.getExampleArray(object));
    }

    public void learn(int[] nArray, double[] dArray, int[] nArray2, double[] dArray2) {
        this.allExamples.add(new Object[]{nArray, dArray, nArray2, dArray2});
    }

    public void doneLearning() {
        int n = this.allExamples.size();
        if (n == 0) {
            return;
        }
        double[] dArray = new double[n];
        Arrays.fill(dArray, 1.0 / (double)n);
        this.weakLearners = new Learner[this.rounds];
        this.alpha = new double[this.rounds];
        for (int i = 0; i < this.rounds; ++i) {
            int n2;
            int n3;
            Object[][] objectArrayArray = new Object[n][];
            for (n3 = 0; n3 < n; ++n3) {
                double d = Math.random();
                int n4 = 0;
                for (double d2 = 0.0; d2 <= d; d2 += dArray[n4++]) {
                }
                objectArrayArray[n3] = (Object[])this.allExamples.get(n4 - 1);
            }
            this.weakLearners[i] = (Learner)this.weakLearner.clone();
            this.weakLearners[i].setLabelLexicon(this.labelLexicon);
            this.weakLearners[i].learn((Object[])objectArrayArray);
            this.weakLearners[i].doneLearning();
            n3 = 0;
            boolean[] blArray = new boolean[n];
            for (int j = 0; j < n; ++j) {
                String string = this.labelLexicon.lookupKey(((int[])objectArrayArray[j][2])[0]).getStringValue();
                String string2 = this.weakLearners[i].featureValue(objectArrayArray[j]).getStringValue();
                blArray[j] = string.equals(string2);
                if (!blArray[j]) continue;
                ++n3;
            }
            double d = (double)n3 / (double)(n - n3);
            this.alpha[i] = Math.log(d) / 2.0;
            if (i + 1 >= this.rounds) continue;
            double d3 = Math.sqrt(d);
            double d4 = 0.0;
            for (n2 = 0; n2 < n; ++n2) {
                if (blArray[n2]) {
                    int n5 = n2;
                    dArray[n5] = dArray[n5] / d3;
                } else {
                    int n6 = n2;
                    dArray[n6] = dArray[n6] * d3;
                }
                d4 += dArray[n2];
            }
            n2 = 0;
            while (n2 < n) {
                int n7 = n2++;
                dArray[n7] = dArray[n7] / d4;
            }
        }
        this.allExamples = null;
    }

    public void forget() {
        super.forget();
        this.weakLearners = null;
        this.alpha = null;
        this.allExamples = new OVector();
    }

    protected double[] sumAlphas(int[] nArray, double[] dArray) {
        double[] dArray2 = new double[2];
        for (int i = 0; i < this.rounds; ++i) {
            short s;
            short s2 = s = this.weakLearners[i].featureValue(nArray, dArray).getValueIndex();
            dArray2[s2] = dArray2[s2] + this.alpha[i];
        }
        return dArray2;
    }

    public ScoreSet scores(int[] nArray, double[] dArray) {
        double[] dArray2 = this.sumAlphas(nArray, dArray);
        String[] stringArray = new String[]{this.labelLexicon.lookupKey(0).getStringValue(), this.labelLexicon.lookupKey(1).getStringValue()};
        return new ScoreSet(stringArray, dArray2);
    }

    public Feature featureValue(int[] nArray, double[] dArray) {
        double[] dArray2 = this.sumAlphas(nArray, dArray);
        return this.predictions.get(dArray2[0] > dArray2[1] ? 0 : 1);
    }

    public String discreteValue(int[] nArray, double[] dArray) {
        double[] dArray2 = this.sumAlphas(nArray, dArray);
        return this.allowableValues[dArray2[0] > dArray2[1] ? 0 : 1];
    }

    public FeatureVector classify(int[] nArray, double[] dArray) {
        return new FeatureVector(this.featureValue(nArray, dArray));
    }

    public void write(PrintStream printStream) {
        int n;
        printStream.println(this.name);
        if (this.rounds > 0) {
            printStream.print(this.alpha[0]);
            for (n = 1; n < this.rounds; ++n) {
                printStream.print(", " + this.alpha[n]);
            }
            printStream.println();
        } else {
            printStream.println("---");
        }
        printStream.println(this.weakLearner.getClass().getName());
        this.weakLearner.write(printStream);
        for (n = 0; n < this.rounds; ++n) {
            this.weakLearners[n].setLexicon(this.lexicon);
            this.weakLearners[n].write(printStream);
            this.weakLearners[n].setLexicon(null);
        }
    }

    public void write(ExceptionlessOutputStream exceptionlessOutputStream) {
        int n;
        super.write(exceptionlessOutputStream);
        this.weakLearner.write(exceptionlessOutputStream);
        exceptionlessOutputStream.writeInt(this.rounds);
        for (n = 0; n < this.rounds; ++n) {
            this.weakLearners[n].write(exceptionlessOutputStream);
        }
        for (n = 0; n < this.rounds; ++n) {
            exceptionlessOutputStream.writeDouble(this.alpha[n]);
        }
        exceptionlessOutputStream.writeString(this.allowableValues[0]);
        exceptionlessOutputStream.writeString(this.allowableValues[1]);
    }

    public void read(ExceptionlessInputStream exceptionlessInputStream) {
        int n;
        super.read(exceptionlessInputStream);
        this.weakLearner = Learner.readLearner(exceptionlessInputStream);
        this.rounds = exceptionlessInputStream.readInt();
        for (n = 0; n < this.rounds; ++n) {
            this.weakLearners[n] = Learner.readLearner(exceptionlessInputStream);
        }
        for (n = 0; n < this.rounds; ++n) {
            this.alpha[n] = exceptionlessInputStream.readDouble();
        }
        this.allowableValues = new String[2];
        this.allowableValues[0] = exceptionlessInputStream.readString();
        this.allowableValues[1] = exceptionlessInputStream.readString();
    }

    public static class Parameters
    extends Learner.Parameters {
        protected Learner weakLearner;
        protected int rounds;

        public Parameters() {
            this.weakLearner = (Learner)defaultWeakLearner.clone();
            this.rounds = 10;
        }

        public Parameters(Learner.Parameters parameters) {
            super(parameters);
            this.weakLearner = (Learner)defaultWeakLearner.clone();
            this.rounds = 10;
        }

        public Parameters(Parameters parameters) {
            super(parameters);
            this.weakLearner = parameters.weakLearner;
            this.rounds = parameters.rounds;
        }

        public void setParameters(Learner learner) {
            ((AdaBoost)learner).setParameters(this);
        }

        public String nonDefaultString() {
            String string = super.nonDefaultString();
            if (this.rounds != 10) {
                string = string + ", rounds = " + this.rounds;
            }
            if (string.startsWith(", ")) {
                string = string.substring(2);
            }
            return string;
        }
    }
}

