/*
 * Decompiled with CFR 0.152.
 */
package LBJ2.learn;

import LBJ2.classify.Classifier;
import LBJ2.classify.DiscretePrimitiveStringFeature;
import LBJ2.classify.Feature;
import LBJ2.classify.FeatureVector;
import LBJ2.classify.ScoreSet;
import LBJ2.learn.Learner;
import LBJ2.learn.Lexicon;
import LBJ2.util.ExceptionlessInputStream;
import LBJ2.util.ExceptionlessOutputStream;
import LBJ2.util.FVector;
import LBJ2.util.IVector;
import LBJ2.util.OVector;
import java.io.PrintStream;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.Date;
import java.util.Iterator;
import liblinear.FeatureNode;
import liblinear.Linear;
import liblinear.Model;
import liblinear.Parameter;
import liblinear.Problem;
import liblinear.SolverType;

/**
 * A linear classifier trained with liblinear (an SVM or logistic regression, depending on the
 * configured solver type).
 *
 * <p>Examples passed to {@link #learn(int[], double[], int[], double[])} are only buffered as
 * liblinear {@code FeatureNode} arrays; the optimization itself runs in {@link #doneLearning()},
 * which calls {@code Linear.train} and keeps just the resulting weight array. Predictions are
 * dot products between an example and those weights.
 */
public class SupportVectorMachine
extends Learner {
    /** Solver name used when {@link #solverType} is the empty string. */
    public static final String defaultSolverType = "L2LOSS_SVM";
    /** Default value of {@link #C}. */
    public static final double defaultC = 1.0;
    /** Default value of {@link #epsilon}. */
    public static final double defaultEpsilon = 0.1;
    /** Default value of {@link #bias}. */
    public static final double defaultBias = 1.0;

    /** Solver whose weights are always stored per class, even when there are only two classes. */
    private static final String MULTICLASS_SOLVER = "MCSVM_CS";

    /** Orders liblinear feature nodes by ascending feature index (no subtraction, so no overflow). */
    private static final Comparator<FeatureNode> FEATURE_INDEX_ORDER = new Comparator<FeatureNode>() {
        public int compare(FeatureNode a, FeatureNode b) {
            return a.index < b.index ? -1 : (a.index == b.index ? 0 : 1);
        }
    };

    /** Whether the "call doneLearning() first" warning has already been printed this round. */
    private boolean warningPrinted;
    /** liblinear solver name; the empty string selects {@link #defaultSolverType}. */
    protected String solverType;
    /** liblinear cost parameter, passed to {@code liblinear.Parameter}. */
    protected double C;
    /** liblinear stopping tolerance, passed to {@code liblinear.Parameter}. */
    protected double epsilon;
    /** Value of the extra bias feature appended to every example; negative disables it. */
    protected double bias;
    /** 1 when a bias feature is appended to every example, 0 otherwise. */
    protected int biasFeatures;
    /** Whether to print training start/finish messages to standard output. */
    protected boolean displayLL = false;
    /** Number of labels in {@link #newLabelLexicon} at the end of training. */
    protected int numClasses;
    /** One more than the largest lexicon feature index seen while learning. */
    protected int numFeatures;
    /** Whether any label in {@link #newLabelLexicon} is a conjunctive feature. */
    protected boolean conjunctiveLabels;
    /** Weights returned by liblinear; {@code null} until {@link #doneLearning()} runs. */
    protected double[] weights;
    /** Label index of every buffered example; {@code null} outside a training round. */
    protected IVector allLabels;
    /** {@code FeatureNode[]} of every buffered example; {@code null} outside a training round. */
    protected OVector allExamples;
    /** Values the labeler may produce; never {@code null}. */
    protected String[] allowableValues;
    /** Label lexicon whose indices match liblinear's class numbering (may be {@code labelLexicon}). */
    protected Lexicon newLabelLexicon = this.labelLexicon;

    /** Creates an unnamed learner with default parameters. */
    public SupportVectorMachine() {
        this("");
    }

    /** Creates an unnamed learner with the given {@code C} and default remaining parameters. */
    public SupportVectorMachine(double c) {
        this(c, defaultEpsilon);
    }

    /** Creates an unnamed learner with the given {@code C} and {@code epsilon}. */
    public SupportVectorMachine(double c, double epsilon) {
        this(c, epsilon, defaultBias);
    }

    /** Creates an unnamed learner with the given {@code C}, {@code epsilon} and {@code bias}. */
    public SupportVectorMachine(double c, double epsilon, double bias) {
        this(c, epsilon, bias, "");
    }

    /** Creates an unnamed learner with the given parameters and solver type. */
    public SupportVectorMachine(double c, double epsilon, double bias, String solverType) {
        this("", c, epsilon, bias, solverType, false);
    }

    /** Creates an unnamed learner with every parameter specified. */
    public SupportVectorMachine(double c, double epsilon, double bias, String solverType, boolean displayLL) {
        this("", c, epsilon, bias, solverType, displayLL);
    }

    /** Creates a named learner with default parameters. */
    public SupportVectorMachine(String name) {
        this(name, new Parameters());
    }

    /** Creates a named learner with the given {@code C}. */
    public SupportVectorMachine(String name, double c) {
        this(name, c, defaultEpsilon);
    }

    /** Creates a named learner with the given {@code C} and {@code epsilon}. */
    public SupportVectorMachine(String name, double c, double epsilon) {
        this(name, c, epsilon, defaultBias);
    }

    /** Creates a named learner with the given {@code C}, {@code epsilon} and {@code bias}. */
    public SupportVectorMachine(String name, double c, double epsilon, double bias) {
        this(name, c, epsilon, bias, "");
    }

    /** Creates a named learner with the given parameters and solver type. */
    public SupportVectorMachine(String name, double c, double epsilon, double bias, String solverType) {
        this(name, c, epsilon, bias, solverType, false);
    }

    /**
     * Creates a named learner with every parameter specified.
     *
     * @param name learner name passed to {@link Learner}
     * @param c liblinear cost parameter
     * @param epsilon liblinear stopping tolerance
     * @param bias bias feature value; negative disables the bias feature
     * @param solverType liblinear solver name; empty selects {@link #defaultSolverType}
     * @param displayLL whether to print training progress messages
     */
    public SupportVectorMachine(String name, double c, double epsilon, double bias, String solverType, boolean displayLL) {
        super(name);
        Parameters parameters = new Parameters();
        parameters.C = c;
        parameters.epsilon = epsilon;
        parameters.bias = bias;
        parameters.solverType = solverType;
        parameters.displayLL = displayLL;
        this.allowableValues = new String[0];
        this.setParameters(parameters);
    }

    /** Creates an unnamed learner from a parameter object. */
    public SupportVectorMachine(Parameters parameters) {
        this("", parameters);
    }

    /** Creates a named learner from a parameter object. */
    public SupportVectorMachine(String name, Parameters parameters) {
        super(name);
        this.allowableValues = new String[0];
        this.setParameters(parameters);
    }

    /** Copies every SVM-specific setting from {@code parameters} into this learner. */
    public void setParameters(Parameters parameters) {
        this.C = parameters.C;
        this.epsilon = parameters.epsilon;
        this.bias = parameters.bias;
        this.biasFeatures = this.bias >= 0.0 ? 1 : 0;
        this.solverType = parameters.solverType;
        this.displayLL = parameters.displayLL;
    }

    /** Returns a parameter object describing this learner's current configuration. */
    public Learner.Parameters getParameters() {
        Parameters parameters = new Parameters(super.getParameters());
        parameters.C = this.C;
        parameters.epsilon = this.epsilon;
        parameters.bias = this.bias;
        parameters.solverType = this.solverType;
        parameters.displayLL = this.displayLL;
        return parameters;
    }

    /** Sets the labeler and caches its allowable values (an empty array when it has none). */
    public void setLabeler(Classifier classifier) {
        super.setLabeler(classifier);
        String[] values = classifier == null ? null : classifier.allowableValues();
        this.allowableValues = values == null ? new String[0] : values;
    }

    /** Returns the cached allowable label values (never {@code null}). */
    public String[] allowableValues() {
        return this.allowableValues;
    }

    /**
     * Allocates the example buffers; {@code numExamples} is used as their initial capacity and
     * the second argument is not used by this learner.
     */
    public void initialize(int numExamples, int numFeaturesHint) {
        this.allLabels = new IVector(numExamples);
        this.allExamples = new OVector(numExamples);
    }

    /**
     * Buffers one training example for liblinear.
     *
     * <p>Feature indices are shifted by one, sorted, and duplicate indices are merged by summing
     * their values. When a bias feature is enabled, the node array keeps one trailing empty slot
     * that {@link #doneLearning()} fills in.
     *
     * @param exampleFeatures lexicon indices of the example's features
     * @param exampleValues values parallel to {@code exampleFeatures}
     * @param exampleLabels must contain exactly one label index
     * @param labelValues unused
     */
    public void learn(int[] exampleFeatures, double[] exampleValues, int[] exampleLabels, double[] labelValues) {
        if (this.allLabels == null || this.allLabels.size() == 0) {
            // First example of a new training round: discard any previously trained model.
            if (this.allLabels == null) {
                this.allLabels = new IVector();
                this.allExamples = new OVector();
            }
            this.weights = null;
            this.warningPrinted = false;
        }
        assert exampleLabels.length == 1 : "Example must have a single label.";

        int count = exampleFeatures.length;
        FeatureNode[] nodes = new FeatureNode[count + this.biasFeatures];
        for (int i = 0; i < count; ++i) {
            // +1 maps 0-based lexicon indices onto liblinear's 1-based feature numbering.
            int index = exampleFeatures[i] + 1;
            this.numFeatures = Math.max(this.numFeatures, index);
            nodes[i] = new FeatureNode(index, exampleValues[i]);
        }
        Arrays.sort(nodes, 0, count, FEATURE_INDEX_ORDER);

        // Compact in place: nodes[0..distinct) holds one node per index, values summed.
        int distinct = 0;
        for (int i = 0; i < count; ++i) {
            FeatureNode node = nodes[i];
            if (distinct > 0 && nodes[distinct - 1].index == node.index) {
                nodes[distinct - 1] = new FeatureNode(node.index, nodes[distinct - 1].value + node.value);
            } else {
                nodes[distinct++] = node;
            }
        }
        if (distinct < count) {
            // The trailing slots must not contain stale nodes, so copy into an exact-size array.
            FeatureNode[] compact = new FeatureNode[distinct + this.biasFeatures];
            System.arraycopy(nodes, 0, compact, 0, distinct);
            nodes = compact;
        }

        this.allLabels.add(exampleLabels[0]);
        this.allExamples.add(nodes);
    }

    /**
     * Trains liblinear on every buffered example and keeps the resulting weights.
     *
     * <p>Does nothing (beyond optional messages) when no training round is in progress. On
     * success the example buffers are released.
     */
    public void doneLearning() {
        super.doneLearning();
        if (this.displayLL) {
            System.out.println("  Training via liblinear at " + new Date());
        }
        // Checked before any buffer access so an empty round cannot cause a NullPointerException.
        if (this.allLabels == null) {
            if (this.displayLL) {
                System.out.println("    No training examples; no action taken.");
                System.out.println("  Finished training at " + new Date());
            }
            return;
        }

        if (this.labelLexicon.size() > 2 || this.solverType.equals(MULTICLASS_SOLVER)) {
            this.remapTrainingLabels();
        }
        if (this.solverType.length() == 0) {
            this.solverType = defaultSolverType;
        }

        this.numClasses = this.newLabelLexicon.size();
        for (int i = 0; i < this.numClasses && !this.conjunctiveLabels; ++i) {
            this.conjunctiveLabels = this.newLabelLexicon.lookupKey(i).isConjunctive();
        }

        int exampleCount = this.allExamples.size();
        int featureCount = this.numFeatures + this.biasFeatures;
        if (this.biasFeatures == 1) {
            // The bias feature gets the index just past every real feature.
            for (int i = 0; i < exampleCount; ++i) {
                FeatureNode[] nodes = (FeatureNode[]) this.allExamples.get(i);
                nodes[nodes.length - 1] = new FeatureNode(featureCount, this.bias);
            }
        }

        boolean binary = !this.solverType.equals(MULTICLASS_SOLVER)
            && this.numClasses == 2
            && this.allowableValues.length == 2;
        if (exampleCount > 0 && binary) {
            this.orderBinaryTrainingLabels(exampleCount);
        }

        Problem problem = new Problem();
        problem.bias = this.bias;
        problem.l = exampleCount;
        problem.n = featureCount;
        problem.x = new FeatureNode[exampleCount][];
        for (int i = 0; i < exampleCount; ++i) {
            problem.x[i] = (FeatureNode[]) this.allExamples.get(i);
        }
        problem.y = this.allLabels.toArray();

        Parameter parameter = new Parameter(Parameters.getSolverType(this.solverType), this.C, this.epsilon);
        Model model = Linear.train(problem, parameter);
        this.weights = model.getFeatureWeights();
        this.allExamples = null;
        this.allLabels = null;
        if (this.displayLL) {
            System.out.println("  Finished training at " + new Date());
        }
    }

    /**
     * Renumbers the buffered labels by order of first appearance and builds the matching
     * {@link #newLabelLexicon} and predictions; reuses {@code labelLexicon} when nothing changes.
     */
    private void remapTrainingLabels() {
        this.newLabelLexicon = new Lexicon();
        boolean unchanged = true;
        for (int i = 0; i < this.allExamples.size(); ++i) {
            int oldLabel = this.allLabels.get(i);
            Feature label = this.labelLexicon.lookupKey(oldLabel);
            int newLabel = this.newLabelLexicon.lookup(label, true);
            unchanged &= newLabel == oldLabel;
            this.allLabels.set(i, newLabel);
        }
        if (unchanged && this.newLabelLexicon.size() == this.labelLexicon.size()) {
            this.newLabelLexicon = this.labelLexicon;
        } else if (this.newLabelLexicon.size() > this.labelLexicon.size()) {
            System.err.println("LBJ ERROR: SupportVectorMachine: new label lexicon is too big!");
            new Exception().printStackTrace();
            System.exit(1);
        } else {
            int size = this.newLabelLexicon.size();
            this.predictions = new FVector(size);
            for (int i = 0; i < size; ++i) {
                this.createPrediction(this.newLabelLexicon, i);
            }
        }
    }

    /**
     * For binary problems, makes the first buffered example carry the label for
     * {@code allowableValues[1]} and rebuilds {@link #newLabelLexicon}/predictions to match.
     * This keeps class index 0 bound to {@code allowableValues[1]}, which
     * {@link #featureValue(int[], double[])} selects for non-negative scores; it relies on
     * liblinear ordering classes by first appearance in the training data.
     */
    private void orderBinaryTrainingLabels(int exampleCount) {
        DiscretePrimitiveStringFeature firstClass = new DiscretePrimitiveStringFeature(
            this.labeler.containingPackage, this.labeler.name, "", this.allowableValues[1], (short) 1, (short) 2);
        int firstLabel = this.newLabelLexicon.lookup(firstClass);

        int position = 0;
        while (position < exampleCount && this.allLabels.get(position) == 1 - firstLabel) {
            ++position;
        }
        if (position > 0 && position < exampleCount) {
            // Swap the first example with the first one carrying firstLabel.
            this.allLabels.set(0, firstLabel);
            this.allLabels.set(position, 1 - firstLabel);
            this.allExamples.set(0, this.allExamples.set(position, this.allExamples.get(0)));

            this.newLabelLexicon = new Lexicon();
            this.newLabelLexicon.lookup(firstClass, true);
            this.newLabelLexicon.lookup(new DiscretePrimitiveStringFeature(
                this.labeler.containingPackage, this.labeler.name, "", this.allowableValues[0], (short) 0, (short) 2), true);
            this.predictions = new FVector(2);
            this.createPrediction(this.newLabelLexicon, 0);
            this.createPrediction(this.newLabelLexicon, 1);
        }
    }

    /**
     * True when liblinear stored a single weight per feature (two or fewer classes and a solver
     * other than MCSVM_CS); class 1 is then scored as the negation of class 0.
     */
    private boolean usesSingleWeightVector() {
        return this.numClasses <= 2 && !this.solverType.equals(MULTICLASS_SOLVER);
    }

    /** Prints the "doneLearning() not called" warning once per training round. */
    private void warnIfDoneLearningNotCalled() {
        if (this.allLabels != null && !this.warningPrinted) {
            System.err.println("LBJ WARNING: SupportVectorMachine's doneLearning() method should be called before attempting to make predictions.");
            this.warningPrinted = true;
        }
    }

    /**
     * Writes a human-readable dump of the parameters and (if trained) every weight.
     * Does not modify the learner's state.
     */
    public void write(PrintStream out) {
        this.demandLexicon();
        out.println(this.name + ": " + this.C + ", " + this.epsilon + ", " + this.bias + ", " + this.solverType);
        if (this.weights != null) {
            out.println();
            out.println("Feature weights:");
            out.println("=========================================");
            int rows = this.numFeatures + (this.bias >= 0.0 ? 1 : 0);
            // Use a local stride instead of overwriting numClasses, which later predictions need.
            int columns = this.usesSingleWeightVector() ? 1 : this.numClasses;
            for (int column = 0; column < columns; ++column) {
                if (columns > 1) {
                    out.println("Class = " + this.newLabelLexicon.lookupKey(column).getStringValue());
                }
                for (int row = 0; row < rows; ++row) {
                    if (row < this.numFeatures) {
                        out.print(this.lexicon.lookupKey(row));
                    } else {
                        out.print("[bias]");
                    }
                    out.println("\t\t\t" + this.weights[row * columns + column]);
                }
            }
            out.println("=========================================");
        }
        out.println("End of SupportVectorMachine");
    }

    /** Serializes the learner; {@link #read(ExceptionlessInputStream)} reads the same layout. */
    public void write(ExceptionlessOutputStream out) {
        super.write(out);
        out.writeString(this.solverType);
        out.writeDouble(this.C);
        out.writeDouble(this.epsilon);
        out.writeDouble(this.bias);
        out.writeBoolean(this.displayLL);
        out.writeInt(this.numClasses);
        out.writeInt(this.numFeatures);
        out.writeBoolean(this.conjunctiveLabels);
        out.writeInt(this.allowableValues.length);
        for (int i = 0; i < this.allowableValues.length; ++i) {
            out.writeString(this.allowableValues[i]);
        }
        // The remapped label lexicon is written only when it differs from labelLexicon.
        boolean separateLabels = this.newLabelLexicon != this.labelLexicon;
        out.writeBoolean(separateLabels);
        if (separateLabels) {
            this.newLabelLexicon.write(out);
        }
        if (this.weights == null) {
            out.writeInt(0);
        } else {
            out.writeInt(this.weights.length);
            for (int i = 0; i < this.weights.length; ++i) {
                out.writeDouble(this.weights[i]);
            }
        }
    }

    /**
     * Restores a learner written by {@link #write(ExceptionlessOutputStream)}. A weight count of
     * zero (what an untrained learner writes) restores {@code weights} to {@code null}.
     */
    public void read(ExceptionlessInputStream in) {
        super.read(in);
        this.solverType = in.readString();
        this.C = in.readDouble();
        this.epsilon = in.readDouble();
        this.bias = in.readDouble();
        this.biasFeatures = this.bias >= 0.0 ? 1 : 0;
        this.displayLL = in.readBoolean();
        this.numClasses = in.readInt();
        this.numFeatures = in.readInt();
        this.conjunctiveLabels = in.readBoolean();
        int valueCount = in.readInt();
        this.allowableValues = new String[valueCount];
        for (int i = 0; i < valueCount; ++i) {
            this.allowableValues[i] = in.readString();
        }
        this.newLabelLexicon = in.readBoolean() ? Lexicon.readLexicon(in) : this.labelLexicon;
        int weightCount = in.readInt();
        if (weightCount == 0) {
            this.weights = null;
        } else {
            this.weights = new double[weightCount];
            for (int i = 0; i < weightCount; ++i) {
                this.weights[i] = in.readDouble();
            }
        }
    }

    /**
     * Returns the prediction feature of the highest-scoring class, or {@code null} if the
     * learner has not been trained.
     */
    public Feature featureValue(int[] exampleFeatures, double[] exampleValues) {
        if (this.weights == null) {
            this.warnIfDoneLearningNotCalled();
            return null;
        }
        int best = 0;
        if (this.usesSingleWeightVector()) {
            if (this.score(exampleFeatures, exampleValues, 0) < 0.0) {
                best = 1;
            }
        } else {
            double bestScore = Double.NEGATIVE_INFINITY;
            for (int label = 0; label < this.numClasses; ++label) {
                double current = this.score(exampleFeatures, exampleValues, label);
                if (current > bestScore) {
                    bestScore = current;
                    best = label;
                }
            }
        }
        return this.predictions.get(best);
    }

    /** Returns the string value of {@link #featureValue(int[], double[])}. */
    public String discreteValue(int[] exampleFeatures, double[] exampleValues) {
        return this.featureValue(exampleFeatures, exampleValues).getStringValue();
    }

    /** Wraps {@link #featureValue(int[], double[])} in a feature vector. */
    public FeatureVector classify(int[] exampleFeatures, double[] exampleValues) {
        return new FeatureVector(this.featureValue(exampleFeatures, exampleValues));
    }

    /**
     * Returns the score of every class keyed by its string value; empty if untrained. For
     * single-weight-vector models the second class scores the negation of the first.
     */
    public ScoreSet scores(int[] exampleFeatures, double[] exampleValues) {
        ScoreSet result = new ScoreSet();
        if (this.weights == null) {
            this.warnIfDoneLearningNotCalled();
            return result;
        }
        if (this.usesSingleWeightVector()) {
            double first = this.score(exampleFeatures, exampleValues, 0);
            result.put(this.newLabelLexicon.lookupKey(0).getStringValue(), first);
            result.put(this.newLabelLexicon.lookupKey(1).getStringValue(), -first);
        } else {
            for (int label = 0; label < this.numClasses; ++label) {
                String value = this.newLabelLexicon.lookupKey(label).getStringValue();
                result.put(value, this.score(exampleFeatures, exampleValues, label));
            }
        }
        return result;
    }

    /** Scores {@code example} for class 0; only meaningful for single-weight-vector models. */
    public double score(Object example) {
        assert this.usesSingleWeightVector() : "Cannot call score(Object) in a multi-class classifier.";
        return this.score(example, 0);
    }

    /** Scores {@code example} for the class with index {@code label}. */
    public double score(Object example, int label) {
        Object[] exampleArray = this.getExampleArray(example, false);
        return this.score((int[]) exampleArray[0], (double[]) exampleArray[1], label);
    }

    /**
     * Computes the dot product of an example with the weights of one class, plus the bias term.
     *
     * <p>Features whose index is at least {@link #numFeatures} (unseen in training) are ignored.
     * For single-weight-vector models, class 1 is the negation of class 0 and any other class
     * index is treated as class 0. Returns 0 when untrained. Does not modify the learner.
     *
     * @param exampleFeatures lexicon indices of the example's features
     * @param exampleValues values parallel to {@code exampleFeatures}
     * @param label class index in {@link #newLabelLexicon}
     */
    public double score(int[] exampleFeatures, double[] exampleValues, int label) {
        assert exampleFeatures.length == exampleValues.length : "Array mismatch; improperly formatted input.";
        if (this.weights == null) {
            this.warnIfDoneLearningNotCalled();
            return 0.0;
        }
        // Weights are laid out feature-major: weights[feature * stride + class].
        boolean singleVector = this.usesSingleWeightVector();
        int stride = singleVector ? 1 : this.numClasses;
        int column = singleVector ? 0 : label;

        double sum = 0.0;
        for (int i = 0; i < exampleFeatures.length; ++i) {
            int feature = exampleFeatures[i];
            if (feature < this.numFeatures) {
                sum += this.weights[feature * stride + column] * exampleValues[i];
            }
        }
        if (this.bias >= 0.0) {
            sum += this.bias * this.weights[this.numFeatures * stride + column];
        }
        return singleVector && label == 1 ? -sum : sum;
    }

    /** Returns the best-scoring prediction among the label values in {@code candidates}. */
    public Feature valueOf(Object example, Collection candidates) {
        Object[] exampleArray = this.getExampleArray(example, false);
        return this.valueOf((int[]) exampleArray[0], (double[]) exampleArray[1], candidates);
    }

    /**
     * Returns the best-scoring prediction among the label values (strings) in
     * {@code candidates}; an empty collection means every class is a candidate.
     * Returns {@code null} if untrained.
     */
    public Feature valueOf(int[] exampleFeatures, double[] exampleValues, Collection candidates) {
        if (this.weights == null) {
            this.warnIfDoneLearningNotCalled();
            return null;
        }
        double bestScore = Double.NEGATIVE_INFINITY;
        int best = -1;
        Iterator iterator = candidates.iterator();
        if (!iterator.hasNext()) {
            for (int label = 0; label < this.numClasses; ++label) {
                double current = this.score(exampleFeatures, exampleValues, label);
                if (current > bestScore) {
                    best = label;
                    bestScore = current;
                }
            }
            return this.predictions.get(best);
        }
        if (this.conjunctiveLabels) {
            return this.conjunctiveValueOf(exampleFeatures, exampleValues, iterator);
        }
        while (iterator.hasNext()) {
            String value = (String) iterator.next();
            DiscretePrimitiveStringFeature candidate = new DiscretePrimitiveStringFeature(
                this.labeler.containingPackage, this.labeler.name, "", value,
                this.valueIndexOf(value), (short) this.allowableValues.length);
            if (!this.newLabelLexicon.contains(candidate)) {
                continue;  // values never seen in training cannot be scored
            }
            int label = this.newLabelLexicon.lookup(candidate);
            double current = this.score(exampleFeatures, exampleValues, label);
            if (current > bestScore) {
                best = label;
                bestScore = current;
            }
        }
        // NOTE(review): best stays -1 when no candidate was seen in training.
        return this.predictions.get(best);
    }

    /**
     * Like {@link #valueOf(int[], double[], Collection)} for conjunctive labels: each candidate
     * string is matched by value against the classes in {@link #newLabelLexicon}, whose indices
     * are the ones used by {@link #score(int[], double[], int)} and {@code predictions}.
     */
    protected Feature conjunctiveValueOf(int[] exampleFeatures, double[] exampleValues, Iterator candidates) {
        double bestScore = Double.NEGATIVE_INFINITY;
        int best = -1;
        while (candidates.hasNext()) {
            String value = (String) candidates.next();
            for (int label = 0; label < this.numClasses; ++label) {
                if (this.newLabelLexicon.lookupKey(label).valueEquals(value)) {
                    double current = this.score(exampleFeatures, exampleValues, label);
                    if (current > bestScore) {
                        bestScore = current;
                        best = label;
                    }
                    break;  // only the first matching class is considered per candidate
                }
            }
        }
        return this.predictions.get(best);
    }

    /** Discards the trained model and any buffered examples. */
    public void forget() {
        super.forget();
        this.numFeatures = 0;
        this.numClasses = 0;
        this.allLabels = null;
        this.allExamples = null;
        this.weights = null;
        this.conjunctiveLabels = false;
    }

    /** Bundle of {@link SupportVectorMachine} settings. */
    public static class Parameters
    extends Learner.Parameters {
        /** liblinear solver name; empty selects {@link SupportVectorMachine#defaultSolverType}. */
        public String solverType;
        /** liblinear cost parameter. */
        public double C;
        /** liblinear stopping tolerance. */
        public double epsilon;
        /** Bias feature value; negative disables the bias feature. */
        public double bias;
        /** Whether to print training progress messages. */
        public boolean displayLL;

        /** Creates parameters with every SVM setting at its default. */
        public Parameters() {
            this.solverType = "";
            this.C = SupportVectorMachine.defaultC;
            this.epsilon = SupportVectorMachine.defaultEpsilon;
            this.bias = SupportVectorMachine.defaultBias;
            this.displayLL = false;
        }

        /** Copies generic learner settings and uses defaults for every SVM setting. */
        public Parameters(Learner.Parameters parameters) {
            super(parameters);
            this.solverType = "";
            this.C = SupportVectorMachine.defaultC;
            this.epsilon = SupportVectorMachine.defaultEpsilon;
            this.bias = SupportVectorMachine.defaultBias;
            this.displayLL = false;
        }

        /** Copy constructor. */
        public Parameters(Parameters parameters) {
            super(parameters);
            this.solverType = parameters.solverType;
            this.C = parameters.C;
            this.epsilon = parameters.epsilon;
            this.bias = parameters.bias;
            this.displayLL = parameters.displayLL;
        }

        /** Applies these parameters to {@code learner}, which must be a SupportVectorMachine. */
        public void setParameters(Learner learner) {
            ((SupportVectorMachine) learner).setParameters(this);
        }

        /**
         * Maps a solver name to liblinear's {@code SolverType}; unrecognized names (including
         * {@link SupportVectorMachine#defaultSolverType}) map to {@code L2LOSS_SVM}.
         */
        public static SolverType getSolverType(String name) {
            if (name.equals("L2_LR")) {
                return SolverType.L2_LR;
            }
            if (name.equals("L2LOSS_SVM_DUAL")) {
                return SolverType.L2LOSS_SVM_DUAL;
            }
            if (name.equals("L1LOSS_SVM_DUAL")) {
                return SolverType.L1LOSS_SVM_DUAL;
            }
            if (name.equals(SupportVectorMachine.MULTICLASS_SOLVER)) {
                return SolverType.MCSVM_CS;
            }
            return SolverType.L2LOSS_SVM;
        }

        /**
         * Lists the settings that differ from their defaults as comma-separated
         * {@code name = value} pairs. An empty solver type counts as the default solver.
         */
        public String nonDefaultString() {
            String result = super.nonDefaultString();
            if (this.solverType.length() != 0 && !this.solverType.equals(SupportVectorMachine.defaultSolverType)) {
                result = result + ", solverType = \"" + this.solverType + "\"";
            }
            if (this.C != SupportVectorMachine.defaultC) {
                result = result + ", C = " + this.C;
            }
            if (this.epsilon != SupportVectorMachine.defaultEpsilon) {
                result = result + ", epsilon = " + this.epsilon;
            }
            if (this.bias != SupportVectorMachine.defaultBias) {
                result = result + ", bias = " + this.bias;
            }
            if (result.startsWith(", ")) {
                result = result.substring(2);
            }
            return result;
        }
    }
}

