/*
 * Decompiled with CFR 0.152.
 */
package LBJ2.learn;

import LBJ2.classify.Classifier;
import LBJ2.classify.DiscretePrimitiveStringFeature;
import LBJ2.classify.Feature;
import LBJ2.classify.FeatureVector;
import LBJ2.classify.ScoreSet;
import LBJ2.learn.BiasedRandomWeightVector;
import LBJ2.learn.Learner;
import LBJ2.learn.LinearThresholdUnit;
import LBJ2.learn.SparseWeightVector;
import LBJ2.util.ExceptionlessInputStream;
import LBJ2.util.ExceptionlessOutputStream;
import LBJ2.util.OVector;
import java.io.PrintStream;
import java.util.Collection;
import java.util.Iterator;

/**
 * A multi-class linear learner trained with a MIRA-style (Margin Infused Relaxed
 * Algorithm) update.
 *
 * <p>One {@link BiasedRandomWeightVector} is kept per label in {@link #network}; the
 * vector at index {@code i} scores the label stored at index {@code i} of
 * {@code labelLexicon}. Prediction returns the label whose vector has the largest dot
 * product with the example. The learner expects a labeler that returns discrete values;
 * {@link #setLabeler(Classifier)} prints a warning otherwise.
 */
public class SparseMIRA
extends Learner {
    /** Absolute tolerance used when comparing doubles during an update. */
    public static final double TOLERANCE = 1.0E-9;
    /** One weight vector per label index; vectors are created lazily by {@code learn}. */
    protected OVector network = new OVector();
    /**
     * Set once a conjunctive label has been added to the network. When set, the
     * candidate-restricted {@code valueOf}/{@code scores} overloads match labels by
     * string value instead of building a discrete label feature and looking it up.
     */
    protected boolean conjunctiveLabels;

    /** Creates an unnamed learner. */
    public SparseMIRA() {
        this("");
    }

    /**
     * Creates an unnamed learner. {@code parameters} is accepted for uniformity with
     * other learners; {@link Parameters} defines nothing of its own, so it is ignored.
     */
    public SparseMIRA(Parameters parameters) {
        this("", parameters);
    }

    /** Creates a learner with the given name. */
    public SparseMIRA(String string) {
        super(string);
    }

    /** Creates a learner with the given name; {@code parameters} is ignored. */
    public SparseMIRA(String string, Parameters parameters) {
        this(string);
    }

    /** Returns a fresh default {@link Parameters} object. */
    public Learner.Parameters getParameters() {
        return new Parameters();
    }

    /**
     * Sets the label classifier, warning on standard error when this class (or a
     * subclass whose name still contains "SparseMIRA") is given a labeler whose output
     * type is not {@code "discrete"}.
     */
    public void setLabeler(Classifier classifier) {
        if (this.getClass().getName().indexOf("SparseMIRA") != -1 && !classifier.getOutputType().equals("discrete")) {
            System.err.println("LBJ WARNING: SparseMIRA will only work with a label classifier that returns discrete.");
            System.err.println("             The given label classifier, " + classifier.getClass().getName() + ", returns " + classifier.getOutputType() + ".");
        }
        super.setLabeler(classifier);
    }

    /**
     * Performs one MIRA-style update for a single labeled example.
     *
     * <p>The network is first grown so that it has a weight vector for every label index
     * up to and including the example's label. With only one label there is nothing to
     * discriminate and no update is made. Otherwise each label's normalized score
     * {@code s_r = (w_r . x) / (||x||^2 + 1)} is computed, a threshold {@code theta} is
     * found by bisection such that the multipliers
     * {@code tau_r = min(theta - s_r, [r == label])} sum to approximately zero, and
     * {@code tau_r * x} is added to every vector whose multiplier is not (nearly) zero.
     *
     * @param nArray  indices of the example's active features
     * @param dArray  values of the example's active features, parallel to {@code nArray}
     * @param nArray2 label indices; only element 0 (the example's label) is used
     * @param dArray2 label values; unused
     */
    public void learn(int[] nArray, double[] dArray, int[] nArray2, double[] dArray2) {
        int label = nArray2[0];
        int oldSize = this.network.size();
        // Grow the network lazily and record whether any newly covered label is
        // conjunctive, since that flag selects the conjunctive lookup paths later.
        for (int i = oldSize; i <= label; ++i) {
            Feature key = this.labelLexicon.lookupKey(i);
            if (key != null && key.isConjunctive()) {
                this.conjunctiveLabels = true;
            }
            this.network.add(new BiasedRandomWeightVector());
        }
        // Read the size back from the network so the label count always equals the
        // number of vectors actually present.
        int labelCount = this.network.size();
        if (labelCount == 1) {
            return;
        }
        // NOTE(review): the +1 presumably accounts for the bias weight carried by
        // BiasedRandomWeightVector — confirm against that class.
        double normSquared = FeatureVector.L2NormSquared(dArray) + 1.0;
        double[] normalizedScores = new double[labelCount];
        boolean[] isLabel = new boolean[labelCount];
        BiasedRandomWeightVector[] vectors = new BiasedRandomWeightVector[labelCount];
        double low = Double.MAX_VALUE;
        double high = -Double.MAX_VALUE;
        for (int i = 0; i < labelCount; ++i) {
            isLabel[i] = i == label;
            vectors[i] = this.weightVector(i);
            normalizedScores[i] = vectors[i].dot(nArray, dArray) / normSquared;
            low = Math.min(low, normalizedScores[i]);
            high = Math.max(high, normalizedScores[i]);
        }
        // Widen the bracket: at `low` every multiplier is <= -1 (negative sum), and at
        // `high` only the true label contributes 1 (positive sum).
        low -= 1.0;
        high += 1.0;
        while (!SparseMIRA.nearlyEqualTo(low, high)) {
            double mid = (low + high) / 2.0;
            // Once low and high are adjacent doubles the midpoint equals one of them and
            // the interval can no longer shrink; stop instead of looping forever.
            if (mid <= low || mid >= high) {
                break;
            }
            if (SparseMIRA.sumMultipliers(mid, normalizedScores, isLabel) <= 0.0) {
                low = mid;
            } else {
                high = mid;
            }
        }
        for (int i = 0; i < labelCount; ++i) {
            double multiplier = SparseMIRA.getMultiplier(low, normalizedScores[i], isLabel[i]);
            if (!SparseMIRA.nearlyEqualTo(multiplier, 0.0)) {
                vectors[i].scaledAdd(nArray, dArray, multiplier);
            }
        }
    }

    /**
     * Returns the update multiplier {@code min(theta - score, isLabel ? 1 : 0)} for one
     * label.
     *
     * @param d  the threshold {@code theta}
     * @param d2 the label's normalized score
     * @param bl whether this label is the example's true label
     */
    private static double getMultiplier(double d, double d2, boolean bl) {
        return Math.min(d - d2, bl ? 1.0 : 0.0);
    }

    /** Returns the sum of {@link #getMultiplier} over all labels for threshold {@code d}. */
    private static double sumMultipliers(double d, double[] dArray, boolean[] blArray) {
        double sum = 0.0;
        for (int i = 0; i < dArray.length; ++i) {
            sum += SparseMIRA.getMultiplier(d, dArray[i], blArray[i]);
        }
        return sum;
    }

    /** Returns whether {@code d} and {@code d2} differ by strictly less than {@link #TOLERANCE}. */
    private static boolean nearlyEqualTo(double d, double d2) {
        double diff = d - d2;
        return -TOLERANCE < diff && diff < TOLERANCE;
    }

    /**
     * Discards everything learned: the superclass state, all weight vectors, and the
     * conjunctive-label flag (so a forgotten learner behaves like a fresh one).
     */
    public void forget() {
        super.forget();
        this.network = new OVector();
        this.conjunctiveLabels = false;
    }

    /** Returns the raw score of every label in the network, keyed by label string. */
    public ScoreSet scores(int[] nArray, double[] dArray) {
        return this.scoreAllLabels(nArray, dArray);
    }

    /**
     * Returns the prediction feature of the highest-scoring label, or {@code null} if no
     * label scores above negative infinity (for example, when nothing has been learned).
     */
    public Feature featureValue(int[] nArray, double[] dArray) {
        int bestIndex = this.bestLabelIndex(nArray, dArray);
        return bestIndex == -1 ? null : this.predictions.get(bestIndex);
    }

    /**
     * Returns the string value of {@link #featureValue(int[], double[])}. Throws
     * {@code NullPointerException} when that method returns {@code null}.
     */
    public String discreteValue(int[] nArray, double[] dArray) {
        return this.featureValue(nArray, dArray).getStringValue();
    }

    /** Wraps {@link #featureValue(int[], double[])} in a {@link FeatureVector}. */
    public FeatureVector classify(int[] nArray, double[] dArray) {
        return new FeatureVector(this.featureValue(nArray, dArray));
    }

    /**
     * Converts {@code object} to feature arrays and returns the best label among the
     * candidates in {@code collection}; see {@link #valueOf(int[], double[], Collection)}.
     */
    public Feature valueOf(Object object, Collection collection) {
        Object[] objectArray = this.getExampleArray(object, false);
        return this.valueOf((int[])objectArray[0], (double[])objectArray[1], collection);
    }

    /**
     * Returns the prediction feature of the highest-scoring label whose string value is
     * in {@code collection}. An empty collection means every label is a candidate.
     * Candidates that are unknown, or that have no weight vector yet, are skipped.
     *
     * @return the winning label's prediction feature, or {@code null} if no candidate
     *         could be scored
     */
    public Feature valueOf(int[] nArray, double[] dArray, Collection collection) {
        Iterator iterator = collection.iterator();
        if (!iterator.hasNext()) {
            int bestIndex = this.bestLabelIndex(nArray, dArray);
            return bestIndex == -1 ? null : this.predictions.get(bestIndex);
        }
        if (this.conjunctiveLabels) {
            return this.conjunctiveValueOf(nArray, dArray, iterator);
        }
        double best = Double.NEGATIVE_INFINITY;
        int bestIndex = -1;
        while (iterator.hasNext()) {
            int index = this.knownLabelIndex((String)iterator.next());
            if (index == -1) {
                continue;
            }
            double score = this.weightVector(index).dot(nArray, dArray);
            if (score > best) {
                best = score;
                bestIndex = index;
            }
        }
        return bestIndex == -1 ? null : this.predictions.get(bestIndex);
    }

    /**
     * Conjunctive-label variant of {@code valueOf}. Each candidate string is matched
     * against the prediction features by value; the first matching label is scored and
     * the highest-scoring match wins.
     *
     * @param iterator remaining candidate label strings
     * @return the winning label's prediction feature, or {@code null} if none matched
     */
    protected Feature conjunctiveValueOf(int[] nArray, double[] dArray, Iterator iterator) {
        double best = Double.NEGATIVE_INFINITY;
        int bestIndex = -1;
        int size = this.network.size();
        while (iterator.hasNext()) {
            String value = (String)iterator.next();
            for (int i = 0; i < size; ++i) {
                // The network holds BiasedRandomWeightVector entries (see learn/read/clone);
                // casting them to LinearThresholdUnit always failed.
                BiasedRandomWeightVector vector = this.weightVector(i);
                if (vector == null || !this.predictions.get(i).valueEquals(value)) {
                    continue;
                }
                double score = vector.dot(nArray, dArray);
                if (score > best) {
                    best = score;
                    bestIndex = i;
                }
                break;
            }
        }
        return bestIndex == -1 ? null : this.predictions.get(bestIndex);
    }

    /**
     * Converts {@code object} to feature arrays and scores the candidate labels; see
     * {@link #scores(int[], double[], Collection)}.
     */
    public ScoreSet scores(Object object, Collection collection) {
        Object[] objectArray = this.getExampleArray(object, false);
        return this.scores((int[])objectArray[0], (double[])objectArray[1], collection);
    }

    /**
     * Scores the labels whose string values are in {@code collection}. An empty
     * collection scores every label. Candidates that are unknown, or that have no weight
     * vector yet, are omitted from the result.
     */
    public ScoreSet scores(int[] nArray, double[] dArray, Collection collection) {
        Iterator iterator = collection.iterator();
        if (!iterator.hasNext()) {
            return this.scoreAllLabels(nArray, dArray);
        }
        if (this.conjunctiveLabels) {
            return this.conjunctiveScores(nArray, dArray, iterator);
        }
        ScoreSet scoreSet = new ScoreSet();
        while (iterator.hasNext()) {
            String value = (String)iterator.next();
            int index = this.knownLabelIndex(value);
            if (index == -1) {
                continue;
            }
            scoreSet.put(value, this.weightVector(index).dot(nArray, dArray));
        }
        return scoreSet;
    }

    /**
     * Conjunctive-label variant of {@code scores}. Each candidate string is matched by
     * value against the label lexicon keys, and the first match is scored.
     *
     * @param iterator remaining candidate label strings
     */
    protected ScoreSet conjunctiveScores(int[] nArray, double[] dArray, Iterator iterator) {
        ScoreSet scoreSet = new ScoreSet();
        int size = this.network.size();
        while (iterator.hasNext()) {
            String value = (String)iterator.next();
            for (int i = 0; i < size; ++i) {
                BiasedRandomWeightVector vector = this.weightVector(i);
                if (vector == null || !this.labelLexicon.lookupKey(i).valueEquals(value)) {
                    continue;
                }
                scoreSet.put(value, vector.dot(nArray, dArray));
                break;
            }
        }
        return scoreSet;
    }

    /**
     * Writes a human-readable dump: a {@code "label: "} line followed by that label's
     * weight vector, for each label, then a terminating line.
     */
    public void write(PrintStream printStream) {
        int size = this.network.size();
        for (int i = 0; i < size; ++i) {
            printStream.println("label: " + this.predictions.get(i).getStringValue());
            this.weightVector(i).write(printStream, this.lexicon);
        }
        printStream.println("End of SparseMIRA");
    }

    /**
     * Writes the binary model: the superclass state, the number of weight vectors, then
     * each vector.
     *
     * <p>NOTE(review): {@code conjunctiveLabels} is not part of this format, so it is
     * not restored by {@link #read(ExceptionlessInputStream)}.
     */
    public void write(ExceptionlessOutputStream exceptionlessOutputStream) {
        super.write(exceptionlessOutputStream);
        int size = this.network.size();
        exceptionlessOutputStream.writeInt(size);
        for (int i = 0; i < size; ++i) {
            this.weightVector(i).write(exceptionlessOutputStream);
        }
    }

    /**
     * Reads a model written by {@link #write(ExceptionlessOutputStream)}, replacing the
     * current network. {@code conjunctiveLabels} is left unchanged because it is not
     * serialized.
     */
    public void read(ExceptionlessInputStream exceptionlessInputStream) {
        super.read(exceptionlessInputStream);
        int size = exceptionlessInputStream.readInt();
        this.network = new OVector(size);
        for (int i = 0; i < size; ++i) {
            this.network.add(SparseWeightVector.readWeightVector(exceptionlessInputStream));
        }
    }

    /**
     * Returns a copy whose network holds clones of this learner's weight vectors. On a
     * cloning failure the error is printed and the JVM exits with status 1.
     */
    public Object clone() {
        SparseMIRA copy = null;
        try {
            copy = (SparseMIRA)super.clone();
        }
        catch (Exception exception) {
            System.err.println("Error cloning SparseMIRA: " + exception);
            exception.printStackTrace();
            System.exit(1);
        }
        int size = this.network.size();
        copy.network = new OVector(size);
        for (int i = 0; i < size; ++i) {
            copy.network.add(this.weightVector(i).clone());
        }
        return copy;
    }

    /** Returns the weight vector for label index {@code i}. */
    private BiasedRandomWeightVector weightVector(int i) {
        return (BiasedRandomWeightVector)this.network.get(i);
    }

    /** Scores every label in the network, keyed by the label lexicon's string value. */
    private ScoreSet scoreAllLabels(int[] nArray, double[] dArray) {
        ScoreSet scoreSet = new ScoreSet();
        int size = this.network.size();
        for (int i = 0; i < size; ++i) {
            double score = this.weightVector(i).dot(nArray, dArray);
            scoreSet.put(this.labelLexicon.lookupKey(i).getStringValue(), score);
        }
        return scoreSet;
    }

    /**
     * Returns the index of the highest-scoring label, or -1 if no label scores above
     * negative infinity (for example, when the network is empty).
     */
    private int bestLabelIndex(int[] nArray, double[] dArray) {
        double best = Double.NEGATIVE_INFINITY;
        int bestIndex = -1;
        int size = this.network.size();
        for (int i = 0; i < size; ++i) {
            double score = this.weightVector(i).dot(nArray, dArray);
            if (score > best) {
                best = score;
                bestIndex = i;
            }
        }
        return bestIndex;
    }

    /**
     * Returns the label-lexicon index of the discrete label with string value
     * {@code value}, or -1 if the lexicon does not contain it or no weight vector has
     * been created for it yet.
     */
    private int knownLabelIndex(String value) {
        DiscretePrimitiveStringFeature labelFeature = new DiscretePrimitiveStringFeature(this.labeler.containingPackage, this.labeler.name, "", value, this.labeler.valueIndexOf(value), (short)this.labeler.allowableValues().length);
        if (!this.labelLexicon.contains(labelFeature)) {
            return -1;
        }
        int index = this.labelLexicon.lookup(labelFeature);
        return index < this.network.size() ? index : -1;
    }

    /** Parameters for {@link SparseMIRA}; it defines no parameters of its own. */
    public static class Parameters
    extends Learner.Parameters {
        /** Creates default parameters. */
        public Parameters() {
        }

        /** Initializes from {@code parameters} via the {@code Learner.Parameters} constructor. */
        public Parameters(Learner.Parameters parameters) {
            super(parameters);
        }

        /** Initializes from another {@code SparseMIRA.Parameters}. */
        public Parameters(Parameters parameters) {
            super(parameters);
        }

        /** No-op: there are no SparseMIRA-specific parameters to apply. */
        public void setParameters(Learner learner) {
        }
    }
}

