/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.maxent;

import edu.stanford.nlp.maxent.iis.LambdaSolve;
import edu.stanford.nlp.optimization.CGMinimizer;
import edu.stanford.nlp.optimization.DiffFunction;
import edu.stanford.nlp.optimization.Function;
import edu.stanford.nlp.optimization.Minimizer;
import edu.stanford.nlp.optimization.QNMinimizer;
import edu.stanford.nlp.util.ReflectionLoading;
import java.util.Arrays;

/**
 * Fits the weights ({@code lambda}) of a {@link LambdaSolve} model by minimizing
 * its negative log likelihood, optionally penalized by a Gaussian prior on the
 * weights.
 *
 * <p>Three optimizers are exposed: {@link #solveQN()} (also used by
 * {@link #solve()}), {@link #solveCG()} and {@link #solveL1(double)}. Each one
 * starts from an all-zero weight vector, stores the optimizer's result in
 * {@code prob.lambda}, and writes progress information to {@code System.err}.
 */
public class CGRunner {
    // Unused in this class as it stands; kept from the decompiled original.
    private static final boolean SAVE_LAMBDAS_REGULARLY = false;
    /** Tolerance used when the caller does not supply one. */
    private static final double DEFAULT_TOLERANCE = 1.0E-4;
    /** Prior variance used when the caller does not supply one. */
    private static final double DEFAULT_SIGMASQUARED = 0.5;

    private final LambdaSolve prob;
    private final String filename;
    private final double tol;
    private final boolean useGaussianPrior;
    private final double priorSigmaS;
    private final double[] sigmaSquareds;

    /**
     * Creates a runner with the default tolerance and the default prior variance.
     *
     * @param prob     the model whose {@code lambda} array will be optimized
     * @param filename file name handed to the monitor (not otherwise used here)
     */
    public CGRunner(LambdaSolve prob, String filename) {
        this(prob, filename, DEFAULT_SIGMASQUARED);
    }

    /**
     * Creates a runner with the default tolerance and a single prior variance
     * shared by all weights.
     *
     * @param prob        the model whose {@code lambda} array will be optimized
     * @param filename    file name handed to the monitor (not otherwise used here)
     * @param priorSigmaS prior variance; {@code 0.0} or positive infinity disables the prior
     */
    public CGRunner(LambdaSolve prob, String filename, double priorSigmaS) {
        this(prob, filename, DEFAULT_TOLERANCE, priorSigmaS);
    }

    /**
     * Creates a runner with an explicit tolerance and a single prior variance
     * shared by all weights.
     *
     * @param prob        the model whose {@code lambda} array will be optimized
     * @param filename    file name handed to the monitor (not otherwise used here)
     * @param tol         tolerance passed to the minimizer
     * @param priorSigmaS prior variance; {@code 0.0} or positive infinity disables the prior
     */
    public CGRunner(LambdaSolve prob, String filename, double tol, double priorSigmaS) {
        this.prob = prob;
        this.filename = filename;
        this.tol = tol;
        this.useGaussianPrior = priorSigmaS != 0.0 && priorSigmaS != Double.POSITIVE_INFINITY;
        this.priorSigmaS = priorSigmaS;
        this.sigmaSquareds = null;
    }

    /**
     * Creates a runner with an explicit tolerance and one prior variance per weight.
     *
     * @param prob          the model whose {@code lambda} array will be optimized
     * @param filename      file name handed to the monitor (not otherwise used here)
     * @param tol           tolerance passed to the minimizer
     * @param sigmaSquareds per-weight prior variances; {@code null} disables the prior
     */
    public CGRunner(LambdaSolve prob, String filename, double tol, double[] sigmaSquareds) {
        this.prob = prob;
        this.filename = filename;
        this.tol = tol;
        this.useGaussianPrior = sigmaSquareds != null;
        this.sigmaSquareds = sigmaSquareds;
        // Placeholder: the scalar variance is ignored whenever sigmaSquareds is non-null.
        this.priorSigmaS = -1.0;
    }

    /** Optimizes the model; delegates to {@link #solveQN()}. */
    public void solve() {
        this.solveQN();
    }

    /**
     * Optimizes the model with a {@link QNMinimizer}, monitored by a
     * {@code MonitorFunction}, starting from an all-zero weight vector.
     */
    public void solveQN() {
        LikelihoodFunction df = this.newLikelihoodFunction();
        MonitorFunction monitor = new MonitorFunction(this.prob, df, this.filename);
        QNMinimizer cgm = new QNMinimizer(monitor, 10);
        double[] result = cgm.minimize(df, this.tol, new double[df.domainDimension()]);
        this.finishOptimization(df, monitor, result);
    }

    /**
     * Optimizes the model with a {@link CGMinimizer}, monitored by a
     * {@code MonitorFunction}, starting from an all-zero weight vector.
     */
    public void solveCG() {
        LikelihoodFunction df = this.newLikelihoodFunction();
        MonitorFunction monitor = new MonitorFunction(this.prob, df, this.filename);
        CGMinimizer cgm = new CGMinimizer(monitor);
        double[] result = cgm.minimize(df, this.tol, new double[df.domainDimension()]);
        this.finishOptimization(df, monitor, result);
    }

    /**
     * Optimizes the model with the minimizer class
     * {@code edu.stanford.nlp.optimization.OWLQNMinimizer}, which is loaded by
     * reflection, starting from an all-zero weight vector.
     *
     * @param weight argument forwarded to the reflectively loaded minimizer
     *               (presumably the L1 regularization weight; confirm against that class)
     */
    public void solveL1(double weight) {
        LikelihoodFunction df = this.newLikelihoodFunction();
        Minimizer owl = (Minimizer)ReflectionLoading.loadByReflection("edu.stanford.nlp.optimization.OWLQNMinimizer", weight);
        double[] result = owl.minimize(df, this.tol, new double[df.domainDimension()]);
        this.prob.lambda = result;
        System.err.println("after optimization value is " + df.valueAt(result));
    }

    /** Builds the objective function from this runner's configuration. */
    private LikelihoodFunction newLikelihoodFunction() {
        return new LikelihoodFunction(this.prob, this.tol, this.useGaussianPrior, this.priorSigmaS, this.sigmaSquareds);
    }

    /**
     * Stores the optimizer's result in the model and reports the final objective value.
     *
     * <p>The objective is evaluated exactly once and the monitoring line is
     * actually printed; previously the string returned by
     * {@code reportMonitoring} was discarded and the objective was evaluated twice.
     */
    private void finishOptimization(LikelihoodFunction df, MonitorFunction monitor, double[] result) {
        this.prob.lambda = result;
        double finalValue = df.valueAt(result);
        // The per-iteration monitor output ends without a newline, so start a fresh line first.
        System.err.println();
        System.err.println(monitor.reportMonitoring(finalValue));
        System.err.println("after optimization value is " + finalValue);
    }

    /**
     * Progress monitor handed to the minimizers. Its {@code valueAt} does not
     * compute anything from its argument: it prints the likelihood most recently
     * computed by the wrapped {@code LikelihoodFunction} and counts invocations.
     */
    private static final class MonitorFunction implements Function {
        private final LambdaSolve model;
        private final LikelihoodFunction lf;
        private final String filename;
        private int iterations;

        public MonitorFunction(LambdaSolve m, LikelihoodFunction lf, String filename) {
            this.model = m;
            this.lf = lf;
            this.filename = filename;
        }

        /**
         * Prints a progress line for the current iteration and, on every 30th
         * iteration after the first, asks the model to check its correctness.
         *
         * @param lambda ignored; the reported value is the one cached by the likelihood function
         * @return a dummy constant ({@code 42.0}) that carries no meaning
         */
        public double valueAt(double[] lambda) {
            double likelihood = this.lf.likelihood();
            System.err.println();
            System.err.print(this.reportMonitoring(likelihood));
            if (this.iterations > 0 && this.iterations % 30 == 0) {
                this.model.checkCorrectness();
            }
            ++this.iterations;
            return 42.0;
        }

        /**
         * Formats a progress line; does not print anything itself.
         *
         * @param likelihood the objective value to report
         * @return the line containing the iteration count, the value and the
         *         number of {@code valueAt} calls made on the likelihood function
         */
        public String reportMonitoring(double likelihood) {
            return "Iter. " + this.iterations + ": " + "neg. log cond. likelihood = " + likelihood + " [" + this.lf.numCalls() + " calls to valueAt]";
        }

        public int domainDimension() {
            return this.lf.domainDimension();
        }
    }

    /**
     * Objective function: the value returned by
     * {@code LambdaSolve.logLikelihoodScratch()} plus, when the Gaussian prior is
     * enabled, the penalty {@code sum_i lambda_i^2 / (2 * sigmaSquared_i)}.
     */
    private static final class LikelihoodFunction implements DiffFunction {
        private final LambdaSolve model;
        private final double tol;
        private final boolean useGaussianPrior;
        private final double[] sigmaSquareds;
        private int valueAtCalls;
        private double likelihood;

        /**
         * @param m                the model being optimized
         * @param tol              threshold used by {@link #derivativeAt(double[])} to decide
         *                         whether it is being asked about the last evaluated point
         * @param useGaussianPrior whether to add the Gaussian penalty
         * @param sigmaSquared     variance applied to every weight when {@code sigmaSquareds} is {@code null}
         * @param sigmaSquareds    per-weight variances, or {@code null} to use {@code sigmaSquared}
         */
        public LikelihoodFunction(LambdaSolve m, double tol, boolean useGaussianPrior, double sigmaSquared, double[] sigmaSquareds) {
            this.model = m;
            this.tol = tol;
            this.useGaussianPrior = useGaussianPrior;
            if (useGaussianPrior) {
                // One variance per weight, held in a private copy sized to the model.
                this.sigmaSquareds = new double[this.model.lambda.length];
                if (sigmaSquareds != null) {
                    // NOTE(review): if the supplied array is shorter than lambda, the remaining
                    // variances stay 0.0 and the penalty divides by zero; if it is longer,
                    // arraycopy throws. Callers are assumed to pass a matching length - confirm.
                    System.arraycopy(sigmaSquareds, 0, this.sigmaSquareds, 0, sigmaSquareds.length);
                } else {
                    Arrays.fill(this.sigmaSquareds, sigmaSquared);
                }
            } else {
                this.sigmaSquareds = null;
            }
        }

        public int domainDimension() {
            return this.model.lambda.length;
        }

        /** Returns the value computed by the most recent {@link #valueAt(double[])} call. */
        public double likelihood() {
            return this.likelihood;
        }

        /** Returns how many times {@link #valueAt(double[])} has been invoked. */
        public int numCalls() {
            return this.valueAtCalls;
        }

        /**
         * Evaluates the (optionally penalized) objective at {@code lambda}.
         *
         * <p>Side effects: installs {@code lambda} (the same array, not a copy) as
         * the model's weight vector, increments the call counter and caches the result.
         */
        public double valueAt(double[] lambda) {
            ++this.valueAtCalls;
            this.model.lambda = lambda;
            double lik = this.model.logLikelihoodScratch();
            if (this.useGaussianPrior) {
                for (int i = 0; i < lambda.length; ++i) {
                    lik += lambda[i] * lambda[i] / (2.0 * this.sigmaSquareds[i]);
                }
            }
            this.likelihood = lik;
            return lik;
        }

        /**
         * Returns the gradient of the objective at {@code lambda}.
         *
         * <p>The derivatives are read from the model's current state. If any
         * coordinate of {@code lambda} differs from the model's weights by more
         * than {@code tol}, the objective is re-evaluated at {@code lambda} first;
         * points closer than {@code tol} are treated as the same point. The
         * prior's contribution {@code lambda_j / sigmaSquared_j} is added in
         * place to the array returned by {@code model.getDerivatives()}.
         */
        public double[] derivativeAt(double[] lambda) {
            boolean eq = true;
            for (int j = 0; j < lambda.length; ++j) {
                if (Math.abs(lambda[j] - this.model.lambda[j]) > this.tol) {
                    eq = false;
                    break;
                }
            }
            if (!eq) {
                System.err.println("derivativeAt: call with different value");
                this.valueAt(lambda);
            }
            double[] drvs = this.model.getDerivatives();
            if (this.useGaussianPrior) {
                for (int j = 0; j < lambda.length; ++j) {
                    drvs[j] += lambda[j] / this.sigmaSquareds[j];
                }
            }
            return drvs;
        }
    }
}

