/*
 * Decompiled with CFR 0.152.
 */
package LBJ2.learn;

import LBJ2.classify.Feature;
import LBJ2.classify.FeatureVector;
import LBJ2.classify.RealPrimitiveStringFeature;
import LBJ2.classify.ScoreSet;
import LBJ2.learn.Learner;
import LBJ2.learn.SparseWeightVector;
import LBJ2.util.ExceptionlessInputStream;
import LBJ2.util.ExceptionlessOutputStream;
import java.io.PrintStream;

public class StochasticGradientDescent
extends Learner {
    public static final double defaultLearningRate = 0.1;
    public static final SparseWeightVector defaultWeightVector;
    protected SparseWeightVector weightVector;
    protected double bias;
    protected double learningRate;
    static final /* synthetic */ boolean $assertionsDisabled;

    public StochasticGradientDescent() {
        this("");
    }

    public StochasticGradientDescent(double d) {
        this("", d);
    }

    public StochasticGradientDescent(Parameters parameters) {
        this("", parameters);
    }

    public StochasticGradientDescent(String string) {
        this(string, 0.1);
    }

    public StochasticGradientDescent(String string, double d) {
        super(string);
        Parameters parameters = new Parameters();
        parameters.learningRate = d;
        this.setParameters(parameters);
    }

    public StochasticGradientDescent(String string, Parameters parameters) {
        super(string);
        this.setParameters(parameters);
    }

    public void setParameters(Parameters parameters) {
        this.weightVector = parameters.weightVector;
        this.learningRate = parameters.learningRate;
    }

    public Learner.Parameters getParameters() {
        Parameters parameters = new Parameters(super.getParameters());
        parameters.weightVector = this.weightVector.emptyClone();
        parameters.learningRate = this.learningRate;
        return parameters;
    }

    public double getLearningRate() {
        return this.learningRate;
    }

    public void setLearningRate(double d) {
        this.learningRate = d;
    }

    public void forget() {
        super.forget();
        this.weightVector = this.weightVector.emptyClone();
        this.bias = 0.0;
    }

    public String getOutputType() {
        return "real";
    }

    public void learn(int[] nArray, double[] dArray, int[] nArray2, double[] dArray2) {
        if (!$assertionsDisabled && nArray2.length != 1) {
            throw new AssertionError((Object)"Example must have a single label.");
        }
        double d = dArray2[0];
        double d2 = this.learningRate * (d - this.weightVector.dot(nArray, dArray) - this.bias);
        this.weightVector.scaledAdd(nArray, dArray, d2);
        this.bias += d2;
    }

    public ScoreSet scores(int[] nArray, double[] dArray) {
        return null;
    }

    public Feature featureValue(int[] nArray, double[] dArray) {
        return new RealPrimitiveStringFeature(this.containingPackage, this.name, "", this.realValue(nArray, dArray));
    }

    public double realValue(int[] nArray, double[] dArray) {
        return this.weightVector.dot(nArray, dArray) + this.bias;
    }

    public FeatureVector classify(int[] nArray, double[] dArray) {
        return new FeatureVector(this.featureValue(nArray, dArray));
    }

    public void write(PrintStream printStream) {
        printStream.println(this.name + ": " + this.learningRate + ", " + this.bias);
        if (this.lexicon.size() == 0) {
            this.weightVector.write(printStream);
        } else {
            this.weightVector.write(printStream, this.lexicon);
        }
    }

    public void write(ExceptionlessOutputStream exceptionlessOutputStream) {
        super.write(exceptionlessOutputStream);
        exceptionlessOutputStream.writeDouble(this.learningRate);
        exceptionlessOutputStream.writeDouble(this.bias);
        this.weightVector.write(exceptionlessOutputStream);
    }

    public void read(ExceptionlessInputStream exceptionlessInputStream) {
        super.read(exceptionlessInputStream);
        this.learningRate = exceptionlessInputStream.readDouble();
        this.bias = exceptionlessInputStream.readDouble();
        this.weightVector = SparseWeightVector.readWeightVector(exceptionlessInputStream);
    }

    public Object clone() {
        StochasticGradientDescent stochasticGradientDescent = null;
        try {
            stochasticGradientDescent = (StochasticGradientDescent)super.clone();
        }
        catch (Exception exception) {
            System.err.println("Error cloning StochasticGradientDescent: " + exception);
            System.exit(1);
        }
        stochasticGradientDescent.weightVector = (SparseWeightVector)this.weightVector.clone();
        return stochasticGradientDescent;
    }

    static {
        $assertionsDisabled = !StochasticGradientDescent.class.desiredAssertionStatus();
        defaultWeightVector = new SparseWeightVector();
    }

    public static class Parameters
    extends Learner.Parameters {
        public SparseWeightVector weightVector;
        public double learningRate;

        public Parameters() {
            this.weightVector = (SparseWeightVector)defaultWeightVector.clone();
            this.learningRate = 0.1;
        }

        public Parameters(Learner.Parameters parameters) {
            super(parameters);
            this.weightVector = (SparseWeightVector)defaultWeightVector.clone();
            this.learningRate = 0.1;
        }

        public Parameters(Parameters parameters) {
            super(parameters);
            this.weightVector = parameters.weightVector;
            this.learningRate = parameters.learningRate;
        }

        public void setParameters(Learner learner) {
            ((StochasticGradientDescent)learner).setParameters(this);
        }

        public String nonDefaultString() {
            String string = super.nonDefaultString();
            if (this.learningRate != 0.1) {
                string = string + ", learningRate = " + this.learningRate;
            }
            if (string.startsWith(", ")) {
                string = string.substring(2);
            }
            return string;
        }
    }
}

