/*
 * Decompiled with CFR 0.152.
 */
package LBJ2.learn;

import LBJ2.classify.Feature;
import LBJ2.learn.Learner;
import LBJ2.learn.Lexicon;
import LBJ2.learn.SparsePerceptron;
import LBJ2.learn.SparseWeightVector;
import LBJ2.util.DVector;
import LBJ2.util.ExceptionlessInputStream;
import LBJ2.util.ExceptionlessOutputStream;
import java.io.PrintStream;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Map;

public class SparseAveragedPerceptron
extends SparsePerceptron {
    /** Prototype weight vector; the no-argument {@link Parameters} constructor clones it. */
    public static final AveragedWeightVector defaultWeightVector;
    /**
     * The inherited {@code weightVector} viewed as an {@link AveragedWeightVector}.
     * It is re-assigned whenever {@code weightVector} may have been replaced
     * (in {@code setParameters}, {@code initialize}, {@code forget} and {@code read}).
     */
    protected AveragedWeightVector awv;
    /**
     * Bias accumulator: {@code promote}/{@code demote} add or subtract
     * (examples seen so far) * rate, and {@code score} uses it to derive the averaged bias.
     */
    protected double averagedBias;
    /** Decompiler rendering of the compiler's assertion switch; assigned in the static initializer. */
    static final /* synthetic */ boolean $assertionsDisabled;

    /**
     * Creates a learner with an empty name string, delegating to
     * {@link #SparseAveragedPerceptron(String)} for every other setting.
     */
    public SparseAveragedPerceptron() {
        this("");
    }

    /**
     * Creates a learner with an empty name string and the given learning rate.
     *
     * @param d the learning rate
     */
    public SparseAveragedPerceptron(double d) {
        this("", d);
    }

    /**
     * Creates a learner with an empty name string.
     *
     * @param d  the learning rate
     * @param d2 the score threshold
     */
    public SparseAveragedPerceptron(double d, double d2) {
        this("", d, d2);
    }

    /**
     * Creates a learner with an empty name string.
     *
     * @param d  the learning rate
     * @param d2 the score threshold
     * @param d3 the thickness, applied as both positive and negative thickness
     *           by the constructor this delegates to
     */
    public SparseAveragedPerceptron(double d, double d2, double d3) {
        this("", d, d2, d3);
    }

    /**
     * Creates a learner with an empty name string.
     *
     * @param d  the learning rate
     * @param d2 the score threshold
     * @param d3 the positive thickness
     * @param d4 the negative thickness
     */
    public SparseAveragedPerceptron(double d, double d2, double d3, double d4) {
        this("", d, d2, d3, d4);
    }

    /**
     * Creates a learner with an empty name string, configured from a parameters object.
     *
     * @param parameters the settings to apply
     */
    public SparseAveragedPerceptron(Parameters parameters) {
        this("", parameters);
    }

    /**
     * Creates a learner using a learning rate of {@code 0.1}.
     *
     * @param string passed on to the superclass constructor (presumably the learner's name)
     */
    public SparseAveragedPerceptron(String string) {
        this(string, 0.1);
    }

    /**
     * Creates a learner using a threshold of {@code 0.0}.
     *
     * @param string passed on to the superclass constructor (presumably the learner's name)
     * @param d      the learning rate
     */
    public SparseAveragedPerceptron(String string, double d) {
        this(string, d, 0.0);
    }

    /**
     * Creates a learner using a thickness of {@code 0.0}.
     *
     * @param string passed on to the superclass constructor (presumably the learner's name)
     * @param d      the learning rate
     * @param d2     the score threshold
     */
    public SparseAveragedPerceptron(String string, double d, double d2) {
        this(string, d, d2, 0.0);
    }

    /**
     * Creates a learner whose positive and negative thickness are the same value.
     *
     * @param string passed on to the superclass constructor (presumably the learner's name)
     * @param d      the learning rate
     * @param d2     the score threshold
     * @param d3     the thickness, used for both the positive and the negative side
     */
    public SparseAveragedPerceptron(String string, double d, double d2, double d3) {
        this(string, d, d2, d3, d3);
    }

    /**
     * Creates a learner from individually supplied settings. The values are
     * packed into a fresh {@link Parameters} object (which starts with a clone
     * of {@code defaultWeightVector}) and applied through {@code setParameters}.
     *
     * @param string passed on to the superclass constructor (presumably the learner's name)
     * @param d      the learning rate
     * @param d2     the score threshold
     * @param d3     the positive thickness
     * @param d4     the negative thickness
     */
    public SparseAveragedPerceptron(String string, double d, double d2, double d3, double d4) {
        super(string);
        final Parameters settings = new Parameters();
        settings.negativeThickness = d4;
        settings.positiveThickness = d3;
        settings.threshold = d2;
        settings.learningRate = d;
        setParameters(settings);
    }

    /**
     * Creates a learner configured from a parameters object.
     *
     * @param string     passed on to the superclass constructor (presumably the learner's name)
     * @param parameters the settings to apply via {@code setParameters}
     */
    public SparseAveragedPerceptron(String string, Parameters parameters) {
        super(string);
        this.setParameters(parameters);
    }

    /**
     * Returns this learner's settings as a new {@link Parameters} object built
     * from whatever the superclass reports.
     *
     * @return a freshly constructed {@link Parameters} instance
     */
    public Learner.Parameters getParameters() {
        return new Parameters((SparsePerceptron.Parameters)super.getParameters());
    }

    /**
     * Applies the given settings through the superclass and then refreshes
     * {@code awv} from the (possibly replaced) {@code weightVector}.
     *
     * @param parameters the settings to apply
     * @throws ClassCastException if {@code weightVector} is not an
     *         {@link AveragedWeightVector} after the superclass call
     */
    public void setParameters(Parameters parameters) {
        super.setParameters(parameters);
        // Keep the typed alias in sync with the weight vector the superclass just installed.
        this.awv = (AveragedWeightVector)this.weightVector;
    }

    /**
     * Scores an example using the averaged weights and, once at least one
     * example has been counted, the averaged bias
     * {@code (examples * bias - averagedBias) / examples}.
     *
     * @param nArray feature indices of the example
     * @param dArray feature values, parallel to {@code nArray}
     * @return the averaged dot product, plus the averaged bias when the example count is positive
     */
    public double score(int[] nArray, double[] dArray) {
        final double averagedDot = awv.dot(nArray, dArray, initialWeight);
        final int seen = awv.getExamples();
        if (seen <= 0) {
            // Nothing counted yet: no bias term is added.
            return averagedDot;
        }
        return averagedDot + ((double)seen * bias - averagedBias) / (double)seen;
    }

    /**
     * Moves the model towards the example: raises the bias by {@code d}, adds
     * (examples so far) * {@code d} to the bias accumulator, and adds the
     * example scaled by {@code d} to the weight vector.
     *
     * @param nArray feature indices of the example
     * @param dArray feature values, parallel to {@code nArray}
     * @param d      the update rate
     */
    public void promote(int[] nArray, double[] dArray, double d) {
        bias += d;
        // The example count is read before scaledAdd increments it.
        averagedBias += (double)awv.getExamples() * d;
        awv.scaledAdd(nArray, dArray, d, initialWeight);
    }

    /**
     * Moves the model away from the example: lowers the bias by {@code d},
     * subtracts (examples so far) * {@code d} from the bias accumulator, and
     * adds the example scaled by {@code -d} to the weight vector.
     *
     * @param nArray feature indices of the example
     * @param dArray feature values, parallel to {@code nArray}
     * @param d      the update rate
     */
    public void demote(int[] nArray, double[] dArray, double d) {
        bias -= d;
        // The example count is read before scaledAdd increments it.
        averagedBias -= (double)awv.getExamples() * d;
        awv.scaledAdd(nArray, dArray, -d, initialWeight);
    }

    /**
     * Trains on one example. The current (non-averaged) weights plus bias give
     * an activation. A positive example (label 1) below
     * {@code threshold + positiveThickness} is promoted, and a negative example
     * (label 0) at or above {@code threshold - negativeThickness} is demoted.
     * Otherwise the example only increments the example count.
     *
     * @param nArray  feature indices of the example
     * @param dArray  feature values, parallel to {@code nArray}
     * @param nArray2 the label array; with assertions enabled it must hold exactly
     *                one entry whose value is 0 or 1
     * @param dArray2 label values; not read by this method
     */
    public void learn(int[] nArray, double[] dArray, int[] nArray2, double[] dArray2) {
        // Explicit form of the compiler's assert desugaring; kept because the
        // class declares $assertionsDisabled itself.
        if (!$assertionsDisabled && nArray2.length != 1) {
            throw new AssertionError((Object)"Example must have a single label.");
        }
        if (!$assertionsDisabled && !(nArray2[0] == 0 || nArray2[0] == 1)) {
            throw new AssertionError((Object)"Example has unallowed label value.");
        }
        final boolean positive = nArray2[0] == 1;
        final double activation = awv.simpleDot(nArray, dArray, initialWeight) + bias;
        if (positive) {
            if (activation < threshold + positiveThickness) {
                promote(nArray, dArray, getLearningRate());
                return;
            }
        } else if (activation >= threshold - negativeThickness) {
            demote(nArray, dArray, getLearningRate());
            return;
        }
        // No update was needed: just count the example.
        awv.correctExample();
    }

    /**
     * Replaces the weight vector with a new {@link AveragedWeightVector} built
     * from {@code n2} entries, each set to {@code initialWeight}.
     *
     * @param n  not read by this method
     * @param n2 the number of weights to allocate
     */
    public void initialize(int n, int n2) {
        final double[] startingWeights = new double[n2];
        Arrays.fill(startingWeights, initialWeight);
        final AveragedWeightVector fresh = new AveragedWeightVector(startingWeights);
        awv = fresh;
        weightVector = fresh;
    }

    /**
     * Resets the learner: delegates to the superclass, re-reads {@code awv}
     * from {@code weightVector}, and clears the bias accumulator.
     */
    public void forget() {
        super.forget();
        // weightVector may have been replaced by the superclass; refresh the typed alias.
        this.awv = (AveragedWeightVector)this.weightVector;
        this.averagedBias = 0.0;
    }

    /**
     * Prints a textual description of the learner: one header line with the
     * name, learning rate, initial weight, threshold, both thicknesses, bias
     * and bias accumulator, followed by the averaged weight vector. When a
     * non-empty lexicon is available, the weights are written with it.
     *
     * @param printStream the stream to print to
     */
    public void write(PrintStream printStream) {
        final String header = name + ": " + learningRate + ", " + initialWeight + ", " + threshold + ", " + positiveThickness + ", " + negativeThickness + ", " + bias + ", " + averagedBias;
        printStream.println(header);
        final boolean hasLexicon = lexicon != null && lexicon.size() != 0;
        if (hasLexicon) {
            awv.write(printStream, lexicon);
        } else {
            awv.write(printStream);
        }
    }

    /**
     * Writes the learner in binary form: the superclass's representation
     * followed by the bias accumulator as a double.
     *
     * @param exceptionlessOutputStream the stream to write to
     */
    public void write(ExceptionlessOutputStream exceptionlessOutputStream) {
        super.write(exceptionlessOutputStream);
        // Must stay after the superclass data; read() consumes the fields in the same order.
        exceptionlessOutputStream.writeDouble(this.averagedBias);
    }

    /**
     * Reads the learner from binary form: the superclass's representation
     * first, then the bias accumulator as a double.
     *
     * @param exceptionlessInputStream the stream to read from
     */
    public void read(ExceptionlessInputStream exceptionlessInputStream) {
        super.read(exceptionlessInputStream);
        // The superclass read may have installed a new weight vector; refresh the typed alias.
        this.awv = (AveragedWeightVector)this.weightVector;
        this.averagedBias = exceptionlessInputStream.readDouble();
    }

    // Class initialization: record whether assertions are enabled for this class
    // and create the prototype weight vector that Parameters() clones.
    static {
        $assertionsDisabled = !SparseAveragedPerceptron.class.desiredAssertionStatus();
        defaultWeightVector = new AveragedWeightVector();
    }

    public static class AveragedWeightVector
    extends SparseWeightVector {
        /**
         * Per-feature accumulator. Each {@code scaledAdd} adds
         * (examples so far) * (weight change) to it, and
         * {@code getAveragedWeight} combines it with the current weight.
         */
        public DVector averagedWeights;
        /** Number of examples counted; incremented by {@code scaledAdd} and {@code correctExample}. */
        protected int examples;

        /** Creates a vector backed by {@code new DVector(1024)}. */
        public AveragedWeightVector() {
            this(new DVector(1024));
        }

        /**
         * Creates a vector from an array of values by wrapping it in a {@code DVector}.
         *
         * @param dArray the initial values
         */
        public AveragedWeightVector(double[] dArray) {
            this(new DVector(dArray));
        }

        /**
         * Creates a vector from a {@code DVector}. The superclass receives a
         * clone of {@code dVector}, while {@code averagedWeights} keeps a
         * reference to {@code dVector} itself (no copy is made).
         *
         * @param dVector the initial values
         */
        public AveragedWeightVector(DVector dVector) {
            super((DVector)dVector.clone());
            this.averagedWeights = dVector;
        }

        /** Counts one more example without changing any weight. */
        public void correctExample() {
            examples += 1;
        }

        /**
         * Returns the number of examples counted so far.
         *
         * @return the current value of {@code examples}
         */
        public int getExamples() {
            return this.examples;
        }

        /**
         * Returns the averaged weight of a feature, computed as
         * {@code (examples * weight - accumulator) / examples}.
         *
         * @param n the feature index
         * @param d the default used when looking up both the accumulator and the weight
         * @return the averaged weight, or {@code 0.0} when no examples have been counted
         */
        public double getAveragedWeight(int n, double d) {
            if (examples != 0) {
                final double accumulated = averagedWeights.get(n, d);
                final double current = getWeight(n, d);
                return ((double)examples * current - accumulated) / (double)examples;
            }
            return 0.0;
        }

        /**
         * Averaged dot product using a default weight of {@code 0.0}.
         *
         * @param nArray feature indices
         * @param dArray feature values, parallel to {@code nArray}
         * @return the result of {@code dot(nArray, dArray, 0.0)}
         */
        public double dot(int[] nArray, double[] dArray) {
            return this.dot(nArray, dArray, 0.0);
        }

        /**
         * Dot product of the example with the averaged weights.
         *
         * @param nArray feature indices
         * @param dArray feature values, parallel to {@code nArray}
         * @param d      the default weight passed to {@code getAveragedWeight}
         * @return the sum over all features of averaged weight times feature value
         */
        public double dot(int[] nArray, double[] dArray, double d) {
            double total = 0.0;
            int position = 0;
            while (position < nArray.length) {
                total += getAveragedWeight(nArray[position], d) * dArray[position];
                position++;
            }
            return total;
        }

        /**
         * Dot product computed by the superclass's {@code dot} (bypassing the
         * averaged override in this class), with a default weight of {@code 0.0}.
         *
         * @param nArray feature indices
         * @param dArray feature values, parallel to {@code nArray}
         * @return the superclass dot product
         */
        public double simpleDot(int[] nArray, double[] dArray) {
            return super.dot(nArray, dArray, 0.0);
        }

        /**
         * Dot product computed by the superclass's {@code dot} (bypassing the
         * averaged override in this class).
         *
         * @param nArray feature indices
         * @param dArray feature values, parallel to {@code nArray}
         * @param d      the default weight passed to the superclass
         * @return the superclass dot product
         */
        public double simpleDot(int[] nArray, double[] dArray, double d) {
            return super.dot(nArray, dArray, d);
        }

        /**
         * Adds the example scaled by {@code d}, using a default weight of {@code 0.0}.
         *
         * @param nArray feature indices
         * @param dArray feature values, parallel to {@code nArray}
         * @param d      the scaling factor
         */
        public void scaledAdd(int[] nArray, double[] dArray, double d) {
            this.scaledAdd(nArray, dArray, d, 0.0);
        }

        /**
         * Adds the example scaled by {@code d} to the weights. For each feature,
         * the accumulator also grows by (examples so far) * (realized weight
         * change). The example count is incremented once, after all features
         * have been processed.
         *
         * @param nArray feature indices
         * @param dArray feature values, parallel to {@code nArray}
         * @param d      the scaling factor
         * @param d2     the default weight used when reading the current weight
         */
        public void scaledAdd(int[] nArray, double[] dArray, double d, double d2) {
            final int count = nArray.length;
            for (int position = 0; position < count; position++) {
                final int feature = nArray[position];
                final double before = getWeight(feature, d2);
                final double after = before + d * dArray[position];
                // Use the realized difference (after - before), which includes rounding.
                updateAveragedWeight(feature, (double)examples * (after - before));
                setWeight(feature, after);
            }
            examples += 1;
        }

        /**
         * Adds {@code d} to the accumulator entry of feature {@code n}, using
         * {@code 0.0} as the default for both the read and the write.
         *
         * @param n the feature index
         * @param d the amount to add
         */
        protected void updateAveragedWeight(int n, double d) {
            averagedWeights.set(n, averagedWeights.get(n, 0.0) + d, 0.0);
        }

        /**
         * Prints the averaged weight of every index below
         * {@code averagedWeights.size()}, one per line, between the
         * "Begin AveragedWeightVector" and "End AveragedWeightVector" markers.
         *
         * @param printStream the stream to print to
         */
        public void write(PrintStream printStream) {
            printStream.println("Begin AveragedWeightVector");
            int index = 0;
            while (index < averagedWeights.size()) {
                final double averaged = getAveragedWeight(index, 0.0);
                printStream.println(averaged);
                index++;
            }
            printStream.println("End AveragedWeightVector");
        }

        /**
         * Prints the averaged weights labelled with their lexicon entries, one
         * per line, between the "Begin AveragedWeightVector" and
         * "End AveragedWeightVector" markers.
         *
         * <p>Entries whose index is below {@code weights.size()} are listed
         * first; the remaining entries follow and are tagged " (pruned)".
         * Within each group, entries are ordered by {@code Feature.compareTo}.
         * Labels are padded with spaces to a common odd column width that is at
         * least one character wider than the longest label.
         *
         * @param printStream the stream to print to
         * @param lexicon     supplies the map from feature to integer index
         */
        public void write(PrintStream printStream, Lexicon lexicon) {
            printStream.println("Begin AveragedWeightVector");
            Map map = lexicon.getMap();
            // map is a raw Map, so entrySet().toArray(T[]) has the erased return
            // type Object[]; the explicit cast is required for this to compile.
            Map.Entry[] entryArray = (Map.Entry[])map.entrySet().toArray(new Map.Entry[map.size()]);
            Arrays.sort(entryArray, new Comparator(){

                public int compare(Object object, Object object2) {
                    Map.Entry entry = (Map.Entry)object;
                    Map.Entry entry2 = (Map.Entry)object2;
                    int n = (Integer)entry.getValue();
                    int n2 = (Integer)entry2.getValue();
                    int size = AveragedWeightVector.this.weights.size();
                    if ((n < size) != (n2 < size)) {
                        // Exactly one of the two is past the end of the weights:
                        // order by index, without risking int overflow from subtraction.
                        return n < n2 ? -1 : (n > n2 ? 1 : 0);
                    }
                    return ((Feature)entry.getKey()).compareTo(entry2.getKey());
                }
            });
            // First pass: find the widest label.
            int width = 0;
            for (int i = 0; i < entryArray.length; ++i) {
                String label = entryArray[i].getKey().toString() + (((Integer)entryArray[i].getValue()) < this.weights.size() ? "" : " (pruned)");
                width = Math.max(width, label.length());
            }
            // Round up to the next odd width (even + 2 would stay even, so this
            // mirrors the original: +2 when even, +1 when odd).
            width += (width % 2 == 0) ? 2 : 1;
            // Second pass: print each padded label followed by its averaged weight.
            for (int i = 0; i < entryArray.length; ++i) {
                String label = entryArray[i].getKey().toString() + (((Integer)entryArray[i].getValue()) < this.weights.size() ? "" : " (pruned)");
                printStream.print(label);
                for (int column = label.length(); column < width; ++column) {
                    printStream.print(" ");
                }
                int index = (Integer)entryArray[i].getValue();
                double averaged = this.getAveragedWeight(index, 0.0);
                printStream.println(averaged);
            }
            printStream.println("End AveragedWeightVector");
        }

        /**
         * Writes the vector in binary form: the superclass's representation,
         * then the example count as an int, then the accumulator vector.
         *
         * @param exceptionlessOutputStream the stream to write to
         */
        public void write(ExceptionlessOutputStream exceptionlessOutputStream) {
            super.write(exceptionlessOutputStream);
            // Field order must match read().
            exceptionlessOutputStream.writeInt(this.examples);
            this.averagedWeights.write(exceptionlessOutputStream);
        }

        /**
         * Reads the vector from binary form: the superclass's representation,
         * then the example count as an int, then the accumulator vector (read
         * into the existing {@code averagedWeights} object).
         *
         * @param exceptionlessInputStream the stream to read from
         */
        public void read(ExceptionlessInputStream exceptionlessInputStream) {
            super.read(exceptionlessInputStream);
            // Field order must match write(ExceptionlessOutputStream).
            this.examples = exceptionlessInputStream.readInt();
            this.averagedWeights.read(exceptionlessInputStream);
        }

        /**
         * Returns a copy of this vector whose accumulator is an independent
         * clone rather than a shared reference.
         *
         * @return the copy, as produced by {@code super.clone()} with
         *         {@code averagedWeights} replaced by its own clone
         */
        public Object clone() {
            final AveragedWeightVector copy = (AveragedWeightVector)super.clone();
            copy.averagedWeights = (DVector)averagedWeights.clone();
            return copy;
        }

        /**
         * Returns a new vector of this type created with the no-argument constructor.
         *
         * @return a fresh {@link AveragedWeightVector}
         */
        public SparseWeightVector emptyClone() {
            return new AveragedWeightVector();
        }
    }

    /**
     * Settings object for {@link SparseAveragedPerceptron}. It adds no fields
     * of its own; it differs from the superclass's parameters in its default
     * weight vector and in which learner type it configures.
     */
    public static class Parameters
    extends SparsePerceptron.Parameters {
        /** Creates default settings whose weight vector is a clone of {@code defaultWeightVector}. */
        public Parameters() {
            this.weightVector = (AveragedWeightVector)defaultWeightVector.clone();
        }

        /**
         * Creates settings from a superclass parameters object by delegating to
         * the superclass constructor; nothing else is assigned here.
         *
         * @param parameters the settings to copy from
         */
        public Parameters(SparsePerceptron.Parameters parameters) {
            super(parameters);
        }

        /**
         * Creates settings from another instance of this class by delegating to
         * the superclass constructor.
         *
         * @param parameters the settings to copy from
         */
        public Parameters(Parameters parameters) {
            super(parameters);
        }

        /**
         * Applies these settings to a learner.
         *
         * @param learner the learner to configure; cast to {@link SparseAveragedPerceptron}
         * @throws ClassCastException if {@code learner} is not a {@link SparseAveragedPerceptron}
         */
        public void setParameters(Learner learner) {
            ((SparseAveragedPerceptron)learner).setParameters(this);
        }
    }
}

