/*
 * Decompiled with CFR 0.152.
 */
package LBJ2.learn;

import LBJ2.classify.Classifier;
import LBJ2.classify.DiscretePrimitiveStringFeature;
import LBJ2.classify.Feature;
import LBJ2.classify.FeatureVector;
import LBJ2.classify.ScoreSet;
import LBJ2.learn.Learner;
import LBJ2.learn.LinearThresholdUnit;
import LBJ2.learn.SparseAveragedPerceptron;
import LBJ2.util.ExceptionlessInputStream;
import LBJ2.util.ExceptionlessOutputStream;
import LBJ2.util.OVector;
import java.io.PrintStream;
import java.util.Collection;
import java.util.Iterator;

/**
 * A multi-class learner built from a network of {@link LinearThresholdUnit}s,
 * one unit per label.
 *
 * <p>Unit {@code i} of {@link #network} corresponds to the label at index
 * {@code i} of {@code labelLexicon}; slots may be {@code null} when no unit has
 * been created for that index. Training is one-vs-all: each example is given
 * to every existing unit, labelled {@code 1} for the unit of the example's own
 * label and {@code 0} for all others. Prediction picks the label whose unit
 * yields the highest score.
 *
 * <p>New units are created by cloning {@link #baseLTU}.
 */
public class SparseNetworkLearner
extends Learner {
    /** Prototype LTU that {@link Parameters} clones when no base LTU is supplied. */
    public static final LinearThresholdUnit defaultBaseLTU = new SparseAveragedPerceptron();
    /** Prototype cloned whenever a unit is needed for a label that has none yet. */
    protected LinearThresholdUnit baseLTU;
    /** One LTU per label index; entries may be {@code null}. */
    protected OVector network;
    /** Example count passed to {@code initialize} of each newly created unit. */
    protected int numExamples;
    /** Feature count passed to {@code initialize} of each newly created unit. */
    protected int numFeatures;
    /**
     * Set when a unit is created for a label whose lexicon key reports
     * {@code isConjunctive()}; switches the candidate-restricted lookups to
     * the conjunctive code paths.
     */
    protected boolean conjunctiveLabels;

    /** Creates an unnamed learner using a clone of {@link #defaultBaseLTU}. */
    public SparseNetworkLearner() {
        this("");
    }

    /**
     * Creates an unnamed learner whose units are cloned from {@code ltu}.
     *
     * @param ltu prototype for every unit in the network
     */
    public SparseNetworkLearner(LinearThresholdUnit ltu) {
        this("", ltu);
    }

    /**
     * Creates an unnamed learner configured from {@code parameters}.
     *
     * @param parameters learner parameters, including the base LTU
     */
    public SparseNetworkLearner(Parameters parameters) {
        this("", parameters);
    }

    /**
     * Creates a named learner using a clone of {@link #defaultBaseLTU}.
     *
     * @param name the learner's name
     */
    public SparseNetworkLearner(String name) {
        this(name, new Parameters());
    }

    /**
     * Creates a named learner whose units are cloned from {@code ltu}.
     *
     * @param name the learner's name
     * @param ltu  prototype for every unit in the network
     */
    public SparseNetworkLearner(String name, LinearThresholdUnit ltu) {
        super(name);
        Parameters parameters = new Parameters();
        parameters.baseLTU = ltu;
        setParameters(parameters);
        network = new OVector();
    }

    /**
     * Creates a named learner configured from {@code parameters}.
     *
     * @param name       the learner's name
     * @param parameters learner parameters, including the base LTU
     */
    public SparseNetworkLearner(String name, Parameters parameters) {
        super(name);
        setParameters(parameters);
        network = new OVector();
    }

    /**
     * Installs the base LTU from {@code parameters}. Prints a warning to
     * stderr if that LTU's output type is not {@code "discrete"}.
     *
     * @param parameters parameters whose {@code baseLTU} becomes the prototype
     */
    public void setParameters(Parameters parameters) {
        LinearThresholdUnit ltu = parameters.baseLTU;
        if (!ltu.getOutputType().equals("discrete")) {
            System.err.println("LBJ WARNING: SparseNetworkLearner will only work with a LinearThresholdUnit that returns discrete.");
            System.err.println("             The given LTU, " + ltu.getClass().getName() + ", returns " + ltu.getOutputType() + ".");
        }
        setLTU(ltu);
    }

    /**
     * Returns this learner's parameters, with {@code baseLTU} set to the
     * current prototype (shared, not cloned).
     *
     * @return a {@link Parameters} instance describing this learner
     */
    public Learner.Parameters getParameters() {
        Parameters parameters = new Parameters(super.getParameters());
        parameters.baseLTU = baseLTU;
        return parameters;
    }

    /**
     * Sets the prototype LTU and renames it to {@code name + "$baseLTU"}.
     * Units already in the network are not affected.
     *
     * @param ltu the new prototype
     */
    public void setLTU(LinearThresholdUnit ltu) {
        baseLTU = ltu;
        baseLTU.name = name + "$baseLTU";
    }

    /**
     * Sets the label classifier. For classes whose name contains
     * {@code "SparseNetworkLearner"}, prints a warning to stderr if the
     * classifier's output type is not {@code "discrete"}.
     *
     * @param labeler the label classifier
     */
    public void setLabeler(Classifier labeler) {
        boolean isNetworkClass = getClass().getName().indexOf("SparseNetworkLearner") != -1;
        if (isNetworkClass && !labeler.getOutputType().equals("discrete")) {
            System.err.println("LBJ WARNING: SparseNetworkLearner will only work with a label classifier that returns discrete.");
            System.err.println("             The given label classifier, " + labeler.getClass().getName() + ", returns " + labeler.getOutputType() + ".");
        }
        super.setLabeler(labeler);
    }

    /**
     * Sets the feature extractor on this learner, on the prototype LTU, and
     * on every unit in the network.
     *
     * @param extractor the feature extractor
     */
    public void setExtractor(Classifier extractor) {
        super.setExtractor(extractor);
        baseLTU.setExtractor(extractor);
        int size = network.size();
        for (int i = 0; i < size; ++i) {
            LinearThresholdUnit ltu = unitAt(i);
            // Slots without a unit hold null (see learn/clone/write); skip them.
            if (ltu != null) {
                ltu.setExtractor(extractor);
            }
        }
    }

    /**
     * Trains the network on one example, one-vs-all.
     *
     * <p>If the example's label has no unit yet, one is cloned from
     * {@link #baseLTU}, initialized, and stored at the label's index. Every
     * existing unit then learns the example, labelled {@code 1} for the unit
     * at the example's label index and {@code 0} for all other units.
     *
     * @param exampleFeatures feature indices of the example
     * @param exampleValues   feature values of the example
     * @param exampleLabels   label indices; only element {@code 0} is used
     * @param labelValues     label values, passed through to each unit
     */
    public void learn(int[] exampleFeatures, double[] exampleValues, int[] exampleLabels, double[] labelValues) {
        int label = exampleLabels[0];
        int size = network.size();

        if (label >= size || network.get(label) == null) {
            conjunctiveLabels |= labelLexicon.lookupKey(label).isConjunctive();
            LinearThresholdUnit ltu = (LinearThresholdUnit) baseLTU.clone();
            ltu.initialize(numExamples, numFeatures);
            network.set(label, ltu);
            // Only extend the loop bound. When the label filled a null slot
            // below the current size, units at higher indices must still see
            // this example as a negative.
            size = Math.max(size, label + 1);
        }

        int[] binaryLabel = new int[1];
        for (int i = 0; i < size; ++i) {
            LinearThresholdUnit ltu = unitAt(i);
            if (ltu == null) continue;
            binaryLabel[0] = i == label ? 1 : 0;
            ltu.learn(exampleFeatures, exampleValues, binaryLabel, labelValues);
        }
    }

    /** Finishes learning on this learner and on every non-null unit. */
    public void doneLearning() {
        super.doneLearning();
        int size = network.size();
        for (int i = 0; i < size; ++i) {
            LinearThresholdUnit ltu = unitAt(i);
            if (ltu != null) {
                ltu.doneLearning();
            }
        }
    }

    /**
     * Records the counts passed to {@code initialize} of units created later
     * by {@link #learn}. Existing units are not re-initialized.
     *
     * @param numExamples example count
     * @param numFeatures feature count
     */
    public void initialize(int numExamples, int numFeatures) {
        this.numExamples = numExamples;
        this.numFeatures = numFeatures;
    }

    /** Signals end-of-round to this learner and to every non-null unit. */
    public void doneWithRound() {
        super.doneWithRound();
        int size = network.size();
        for (int i = 0; i < size; ++i) {
            LinearThresholdUnit ltu = unitAt(i);
            if (ltu != null) {
                ltu.doneWithRound();
            }
        }
    }

    /**
     * Discards all learned state: the superclass state, every unit, and the
     * conjunctive-label flag. The flag is recomputed by {@link #learn} as
     * units are created again.
     */
    public void forget() {
        super.forget();
        network = new OVector();
        conjunctiveLabels = false;
    }

    /**
     * Scores {@code example}, restricted to the labels in {@code candidates}.
     *
     * @param example    the raw example object
     * @param candidates candidate label strings; empty means all labels
     * @return the label scores
     */
    public ScoreSet scores(Object example, Collection candidates) {
        Object[] exampleArray = getExampleArray(example, false);
        return scores((int[]) exampleArray[0], (double[]) exampleArray[1], candidates);
    }

    /**
     * Scores an example against the candidate labels. Each score is the
     * unit's score minus its threshold, except on the conjunctive path (see
     * {@link #conjunctiveScores}).
     *
     * <p>If {@code candidates} is empty, every non-null unit is scored.
     * Candidates with no lexicon entry or no unit are omitted.
     *
     * @param exampleFeatures feature indices of the example
     * @param exampleValues   feature values of the example
     * @param candidates      candidate label strings
     * @return the label scores
     */
    public ScoreSet scores(int[] exampleFeatures, double[] exampleValues, Collection candidates) {
        Iterator it = candidates.iterator();
        if (!it.hasNext()) {
            return allScores(exampleFeatures, exampleValues);
        }
        if (conjunctiveLabels) {
            return conjunctiveScores(exampleFeatures, exampleValues, it);
        }

        ScoreSet result = new ScoreSet();
        while (it.hasNext()) {
            String label = (String) it.next();
            int index = networkIndexOfLabel(label);
            if (index < 0) continue;
            LinearThresholdUnit ltu = unitAt(index);
            if (ltu == null) continue;
            result.put(label, ltu.score(exampleFeatures, exampleValues) - ltu.getThreshold());
        }
        return result;
    }

    /**
     * Scores the remaining labels from {@code it} by finding, for each one,
     * the first non-null unit whose lexicon key {@code valueEquals} that
     * label.
     *
     * <p>NOTE(review): unlike {@link #scores(int[], double[], Collection)},
     * this returns the raw unit score without subtracting
     * {@code getThreshold()}. Confirm whether that asymmetry is intended.
     *
     * @param exampleFeatures feature indices of the example
     * @param exampleValues   feature values of the example
     * @param it              iterator over candidate label strings
     * @return the label scores
     */
    protected ScoreSet conjunctiveScores(int[] exampleFeatures, double[] exampleValues, Iterator it) {
        ScoreSet result = new ScoreSet();
        int size = network.size();
        while (it.hasNext()) {
            String label = (String) it.next();
            for (int i = 0; i < size; ++i) {
                LinearThresholdUnit ltu = unitAt(i);
                if (ltu == null || !labelLexicon.lookupKey(i).valueEquals(label)) continue;
                result.put(label, ltu.score(exampleFeatures, exampleValues));
                break;
            }
        }
        return result;
    }

    /**
     * Scores an example against every non-null unit (score minus threshold).
     *
     * @param exampleFeatures feature indices of the example
     * @param exampleValues   feature values of the example
     * @return the label scores, keyed by each label's string value
     */
    public ScoreSet scores(int[] exampleFeatures, double[] exampleValues) {
        return allScores(exampleFeatures, exampleValues);
    }

    /**
     * Returns the prediction feature of the highest-scoring unit. Ties go to
     * the lowest index.
     *
     * @param exampleFeatures feature indices of the example
     * @param exampleValues   feature values of the example
     * @return the predicted label feature, or {@code null} if the network has
     *         no units
     */
    public Feature featureValue(int[] exampleFeatures, double[] exampleValues) {
        int best = argmaxUnit(exampleFeatures, exampleValues);
        return best == -1 ? null : predictions.get(best);
    }

    /**
     * Returns the string value of {@link #featureValue}.
     *
     * @param exampleFeatures feature indices of the example
     * @param exampleValues   feature values of the example
     * @return the predicted label, or {@code null} if there is no prediction
     */
    public String discreteValue(int[] exampleFeatures, double[] exampleValues) {
        Feature prediction = featureValue(exampleFeatures, exampleValues);
        return prediction == null ? null : prediction.getStringValue();
    }

    /**
     * Wraps {@link #featureValue} in a {@link FeatureVector}.
     *
     * @param exampleFeatures feature indices of the example
     * @param exampleValues   feature values of the example
     * @return a vector holding the prediction
     */
    public FeatureVector classify(int[] exampleFeatures, double[] exampleValues) {
        return new FeatureVector(featureValue(exampleFeatures, exampleValues));
    }

    /**
     * Predicts a label for {@code example}, restricted to {@code candidates}.
     *
     * @param example    the raw example object
     * @param candidates candidate label strings; empty means all labels
     * @return the predicted label feature, or {@code null} if none applies
     */
    public Feature valueOf(Object example, Collection candidates) {
        Object[] exampleArray = getExampleArray(example, false);
        return valueOf((int[]) exampleArray[0], (double[]) exampleArray[1], candidates);
    }

    /**
     * Returns the prediction for the highest-scoring candidate label (raw
     * unit score). Ties go to the first candidate encountered.
     *
     * <p>If {@code candidates} is empty, every non-null unit competes.
     *
     * @param exampleFeatures feature indices of the example
     * @param exampleValues   feature values of the example
     * @param candidates      candidate label strings
     * @return the predicted label feature, or {@code null} if no candidate
     *         has a unit
     */
    public Feature valueOf(int[] exampleFeatures, double[] exampleValues, Collection candidates) {
        Iterator it = candidates.iterator();
        int best;
        if (!it.hasNext()) {
            best = argmaxUnit(exampleFeatures, exampleValues);
        } else if (conjunctiveLabels) {
            return conjunctiveValueOf(exampleFeatures, exampleValues, it);
        } else {
            best = -1;
            double bestScore = Double.NEGATIVE_INFINITY;
            while (it.hasNext()) {
                String label = (String) it.next();
                int index = networkIndexOfLabel(label);
                if (index < 0) continue;
                LinearThresholdUnit ltu = unitAt(index);
                if (ltu == null) continue;
                double score = ltu.score(exampleFeatures, exampleValues);
                if (score > bestScore) {
                    bestScore = score;
                    best = index;
                }
            }
        }
        // No candidate matched a unit: mirror featureValue and return null
        // rather than indexing predictions with -1.
        return best == -1 ? null : predictions.get(best);
    }

    /**
     * Conjunctive-label counterpart of the candidate loop in
     * {@link #valueOf(int[], double[], Collection)}. Each candidate is
     * matched to the first non-null unit whose lexicon key
     * {@code valueEquals} it.
     *
     * @param exampleFeatures feature indices of the example
     * @param exampleValues   feature values of the example
     * @param it              iterator over candidate label strings
     * @return the predicted label feature, or {@code null} if no candidate
     *         has a unit
     */
    protected Feature conjunctiveValueOf(int[] exampleFeatures, double[] exampleValues, Iterator it) {
        double bestScore = Double.NEGATIVE_INFINITY;
        int best = -1;
        int size = network.size();
        while (it.hasNext()) {
            String label = (String) it.next();
            for (int i = 0; i < size; ++i) {
                LinearThresholdUnit ltu = unitAt(i);
                if (ltu == null || !labelLexicon.lookupKey(i).valueEquals(label)) continue;
                double score = ltu.score(exampleFeatures, exampleValues);
                if (score > bestScore) {
                    bestScore = score;
                    best = i;
                }
                break;
            }
        }
        return best == -1 ? null : predictions.get(best);
    }

    /**
     * Writes a human-readable dump: the base LTU's class name and contents,
     * then each non-null unit preceded by a {@code "label: "} line, then a
     * terminator line.
     *
     * @param out destination stream
     */
    public void write(PrintStream out) {
        out.println(baseLTU.getClass().getName());
        baseLTU.write(out);
        int size = network.size();
        for (int i = 0; i < size; ++i) {
            LinearThresholdUnit ltu = unitAt(i);
            if (ltu == null) continue;
            out.println("label: " + labelLexicon.lookupKey(i).getStringValue());
            // Lend our lexicon to the unit only for the duration of the write.
            ltu.setLexicon(lexicon);
            try {
                ltu.write(out);
            } finally {
                ltu.setLexicon(null);
            }
        }
        out.println("End of SparseNetworkLearner");
    }

    /**
     * Serializes this learner: superclass state, base LTU, conjunctive flag,
     * network size, then each unit. A {@code null} string marks an empty
     * slot.
     *
     * @param out destination stream
     */
    public void write(ExceptionlessOutputStream out) {
        super.write(out);
        baseLTU.write(out);
        out.writeBoolean(conjunctiveLabels);
        int size = network.size();
        out.writeInt(size);
        for (int i = 0; i < size; ++i) {
            LinearThresholdUnit ltu = unitAt(i);
            if (ltu == null) {
                out.writeString(null);
            } else {
                ltu.write(out);
            }
        }
    }

    /**
     * Reads state in the order written by
     * {@link #write(ExceptionlessOutputStream)}.
     *
     * @param in source stream
     */
    public void read(ExceptionlessInputStream in) {
        super.read(in);
        baseLTU = (LinearThresholdUnit) Learner.readLearner(in);
        conjunctiveLabels = in.readBoolean();
        int size = in.readInt();
        network = new OVector(size);
        for (int i = 0; i < size; ++i) {
            // Assumes readLearner yields null for the null marker written for
            // empty slots — TODO confirm against Learner.readLearner.
            network.add(Learner.readLearner(in));
        }
    }

    /**
     * Returns a copy with a cloned base LTU and a cloned network. Each unit
     * is cloned and empty slots stay {@code null}. Exits the JVM if the
     * superclass clone fails.
     *
     * @return the copy
     */
    public Object clone() {
        SparseNetworkLearner copy = null;
        try {
            copy = (SparseNetworkLearner) super.clone();
        }
        catch (Exception e) {
            System.err.println("Error cloning SparseNetworkLearner: " + e);
            e.printStackTrace();
            System.exit(1);
        }
        copy.baseLTU = (LinearThresholdUnit) baseLTU.clone();
        int size = network.size();
        copy.network = new OVector(size);
        for (int i = 0; i < size; ++i) {
            LinearThresholdUnit ltu = unitAt(i);
            copy.network.add(ltu == null ? null : ltu.clone());
        }
        return copy;
    }

    /** Returns the unit at network index {@code i}; may be {@code null}. */
    private LinearThresholdUnit unitAt(int i) {
        return (LinearThresholdUnit) network.get(i);
    }

    /**
     * Returns the label-lexicon index of label string {@code label}, or
     * {@code -1} if the lexicon does not contain it.
     */
    private int networkIndexOfLabel(String label) {
        DiscretePrimitiveStringFeature key = new DiscretePrimitiveStringFeature(labeler.containingPackage, labeler.name, "", label, labeler.valueIndexOf(label), (short) labeler.allowableValues().length);
        return labelLexicon.contains(key) ? labelLexicon.lookup(key) : -1;
    }

    /** Scores every non-null unit as score minus threshold, keyed by label string. */
    private ScoreSet allScores(int[] exampleFeatures, double[] exampleValues) {
        ScoreSet result = new ScoreSet();
        int size = network.size();
        for (int i = 0; i < size; ++i) {
            LinearThresholdUnit ltu = unitAt(i);
            if (ltu == null) continue;
            result.put(labelLexicon.lookupKey(i).getStringValue(), ltu.score(exampleFeatures, exampleValues) - ltu.getThreshold());
        }
        return result;
    }

    /** Returns the index of the highest raw-scoring non-null unit (first on ties), or -1. */
    private int argmaxUnit(int[] exampleFeatures, double[] exampleValues) {
        double bestScore = Double.NEGATIVE_INFINITY;
        int best = -1;
        int size = network.size();
        for (int i = 0; i < size; ++i) {
            LinearThresholdUnit ltu = unitAt(i);
            if (ltu == null) continue;
            double score = ltu.score(exampleFeatures, exampleValues);
            if (score > bestScore) {
                bestScore = score;
                best = i;
            }
        }
        return best;
    }

    /** Parameters for {@link SparseNetworkLearner}: the prototype LTU. */
    public static class Parameters
    extends Learner.Parameters {
        /** Prototype LTU cloned for each label's unit. */
        public LinearThresholdUnit baseLTU;

        /** Creates parameters holding a clone of {@link #defaultBaseLTU}. */
        public Parameters() {
            this.baseLTU = (LinearThresholdUnit) defaultBaseLTU.clone();
        }

        /**
         * Copies the superclass parameters and uses a clone of
         * {@link #defaultBaseLTU}.
         *
         * @param parameters superclass parameters to copy
         */
        public Parameters(Learner.Parameters parameters) {
            super(parameters);
            this.baseLTU = (LinearThresholdUnit) defaultBaseLTU.clone();
        }

        /**
         * Copy constructor. The base LTU reference is shared, not cloned.
         *
         * @param parameters parameters to copy
         */
        public Parameters(Parameters parameters) {
            super(parameters);
            this.baseLTU = parameters.baseLTU;
        }

        /**
         * Applies these parameters to {@code learner}, which must be a
         * {@link SparseNetworkLearner}.
         *
         * @param learner the learner to configure
         */
        public void setParameters(Learner learner) {
            ((SparseNetworkLearner) learner).setParameters(this);
        }

        /**
         * Returns {@code "<SimpleLTUClassName>: <LTU non-default params>"}.
         *
         * @return a description of the non-default settings
         */
        public String nonDefaultString() {
            String className = baseLTU.getClass().getName();
            String simpleName = className.substring(className.lastIndexOf('.') + 1);
            return simpleName + ": " + baseLTU.getParameters().nonDefaultString();
        }
    }
}

