/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.sentiment;

import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.neural.NeuralUtils;
import edu.stanford.nlp.neural.SimpleTensor;
import edu.stanford.nlp.neural.rnn.RNNCoreAnnotations;
import edu.stanford.nlp.optimization.AbstractCachingDiffFunction;
import edu.stanford.nlp.sentiment.SentimentModel;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.TwoDimensionalMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import org.ejml.simple.SimpleBase;
import org.ejml.simple.SimpleMatrix;

/**
 * Cost function and gradient for training a {@link SentimentModel} on one batch of trees.
 *
 * <p>{@link #calculate(double[])} loads a parameter vector into the model, forward-propagates a deep
 * copy of every training tree, back-propagates the classification error through each copy to
 * accumulate per-parameter derivatives, and finally applies batch averaging and L2 regularization.
 * The cost is stored in the inherited {@code value} field and the flattened gradient in
 * {@code derivative}.
 */
public class SentimentCostAndGradient
extends AbstractCachingDiffFunction {
    /** The model being trained; {@link #calculate(double[])} overwrites its parameters. */
    SentimentModel model;
    /** Trees whose summed node errors form the data term of the cost. */
    List<Tree> trainingBatch;

    /**
     * @param model the model whose parameters are optimized
     * @param trainingBatch the trees making up this batch; they are deep-copied, never annotated
     */
    public SentimentCostAndGradient(SentimentModel model, List<Tree> trainingBatch) {
        this.model = model;
        this.trainingBatch = trainingBatch;
    }

    /** Returns the number of parameters, as reported by {@code model.totalParamSize()}. */
    @Override
    public int domainDimension() {
        return this.model.totalParamSize();
    }

    /**
     * Sums the prediction errors stored (by backprop) on every non-leaf node of {@code tree}.
     * Leaves carry no prediction and contribute zero.
     */
    private static double sumError(Tree tree) {
        if (tree.isLeaf()) {
            return 0.0;
        }
        if (tree.isPreTerminal()) {
            return RNNCoreAnnotations.getPredictionError(tree);
        }
        double error = 0.0;
        for (Tree child : tree.children()) {
            error += sumError(child);
        }
        return RNNCoreAnnotations.getPredictionError(tree) + error;
    }

    /**
     * Returns the index of the largest entry of {@code predictions}. Ties keep the earliest index,
     * and because the comparison is a strict {@code >}, NaN entries never win.
     */
    public int getPredictedClass(SimpleMatrix predictions) {
        int argmax = 0;
        for (int i = 1; i < predictions.getNumElements(); ++i) {
            if (predictions.get(i) > predictions.get(argmax)) {
                argmax = i;
            }
        }
        return argmax;
    }

    /**
     * Evaluates the batch cost and its gradient at {@code theta}.
     *
     * <p>Side effects: loads {@code theta} into {@link #model} via {@code vectorToParams} and sets
     * the inherited {@code value} (cost) and {@code derivative} (gradient) fields. The trees in
     * {@link #trainingBatch} are not modified; deep copies are annotated instead.
     *
     * <p>The cost is the sum of all node prediction errors averaged over the batch, plus an L2
     * penalty on every parameter group. An empty batch has no data term, so the result is the
     * regularization term alone instead of NaN.
     *
     * @param theta the flattened model parameters; its length also fixes the gradient's length
     */
    @Override
    public void calculate(double[] theta) {
        this.model.vectorToParams(theta);

        // Zero-initialized derivative accumulators, keyed and shaped like the model parameters.
        // Tree-backed maps keep the keys in sorted order, presumably so the flattened gradient
        // lines up with the model's own parameter layout -- confirm against vectorToParams.
        TwoDimensionalMap<String, String, SimpleMatrix> binaryTD = zeroMatrices(this.model.binaryTransform);
        TwoDimensionalMap<String, String, SimpleMatrix> binaryCD;
        if (this.model.op.combineClassification) {
            // Binary nodes then use the shared unary "" classifier (see backprop), so no separate
            // binary classification derivatives are collected.
            binaryCD = TwoDimensionalMap.treeMap();
        } else {
            binaryCD = zeroMatrices(this.model.binaryClassification);
        }
        TwoDimensionalMap<String, String, SimpleTensor> binaryTensorTD;
        if (this.model.op.useTensors) {
            binaryTensorTD = zeroTensors(this.model.binaryTensors);
        } else {
            binaryTensorTD = TwoDimensionalMap.treeMap();
        }
        TreeMap<String, SimpleMatrix> unaryCD = zeroMatrices(this.model.unaryClassification);
        TreeMap<String, SimpleMatrix> wordVectorD = zeroMatrices(this.model.wordVectors);

        // Forward-propagate deep copies so the node annotations never touch the caller's trees.
        List<Tree> forwardPropTrees = Generics.newArrayList();
        for (Tree tree : this.trainingBatch) {
            Tree trainingTree = tree.deepCopy();
            this.forwardPropagateTree(trainingTree);
            forwardPropTrees.add(trainingTree);
        }

        double error = 0.0;
        for (Tree tree : forwardPropTrees) {
            this.backpropDerivativesAndError(tree, binaryTD, binaryCD, binaryTensorTD, unaryCD, wordVectorD);
            error += sumError(tree);
        }

        // Average the data term over the batch. With an empty batch, 1.0 / 0 would be Infinity and
        // 0 * Infinity would turn the cost and every gradient entry into NaN, so use 0 instead.
        double scale = this.trainingBatch.isEmpty() ? 0.0 : 1.0 / (double) this.trainingBatch.size();
        this.value = error * scale;
        this.value += this.scaleAndRegularize(binaryTD, this.model.binaryTransform, scale, this.model.op.trainOptions.regTransform);
        this.value += this.scaleAndRegularize(binaryCD, this.model.binaryClassification, scale, this.model.op.trainOptions.regClassification);
        this.value += this.scaleAndRegularizeTensor(binaryTensorTD, this.model.binaryTensors, scale, this.model.op.trainOptions.regTransform);
        this.value += this.scaleAndRegularize(unaryCD, this.model.unaryClassification, scale, this.model.op.trainOptions.regClassification);
        this.value += this.scaleAndRegularize(wordVectorD, this.model.wordVectors, scale, this.model.op.trainOptions.regWordVector);
        this.derivative = NeuralUtils.paramsToVector(theta.length,
                binaryTD.valueIterator(),
                binaryCD.valueIterator(),
                SimpleTensor.iteratorSimpleMatrix(binaryTensorTD.valueIterator()),
                unaryCD.values().iterator(),
                wordVectorD.values().iterator());
    }

    /** Returns a tree-backed two-key map with an all-zero matrix shaped like each entry of {@code params}. */
    private static TwoDimensionalMap<String, String, SimpleMatrix> zeroMatrices(TwoDimensionalMap<String, String, SimpleMatrix> params) {
        TwoDimensionalMap<String, String, SimpleMatrix> zeros = TwoDimensionalMap.treeMap();
        for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : params) {
            SimpleMatrix param = entry.getValue();
            zeros.put(entry.getFirstKey(), entry.getSecondKey(), new SimpleMatrix(param.numRows(), param.numCols()));
        }
        return zeros;
    }

    /** Returns a sorted map with an all-zero matrix shaped like each entry of {@code params}. */
    private static TreeMap<String, SimpleMatrix> zeroMatrices(Map<String, SimpleMatrix> params) {
        TreeMap<String, SimpleMatrix> zeros = Generics.newTreeMap();
        for (Map.Entry<String, SimpleMatrix> entry : params.entrySet()) {
            SimpleMatrix param = entry.getValue();
            zeros.put(entry.getKey(), new SimpleMatrix(param.numRows(), param.numCols()));
        }
        return zeros;
    }

    /** Returns a tree-backed two-key map with an all-zero tensor shaped like each entry of {@code params}. */
    private static TwoDimensionalMap<String, String, SimpleTensor> zeroTensors(TwoDimensionalMap<String, String, SimpleTensor> params) {
        TwoDimensionalMap<String, String, SimpleTensor> zeros = TwoDimensionalMap.treeMap();
        for (TwoDimensionalMap.Entry<String, String, SimpleTensor> entry : params) {
            SimpleTensor param = entry.getValue();
            zeros.put(entry.getFirstKey(), entry.getSecondKey(), new SimpleTensor(param.numRows(), param.numCols(), param.numSlices()));
        }
        return zeros;
    }

    /**
     * Finalizes the derivatives of one parameter group.
     *
     * <p>Each derivative {@code D} is replaced by {@code D * scale + regCost * W}, where {@code W} is
     * the matching current parameter.
     *
     * @param derivatives accumulated derivatives; must contain every key of {@code currentMatrices}
     *     (a missing key causes a NullPointerException)
     * @param currentMatrices the current parameter values
     * @param scale factor applied to the accumulated data-term derivatives
     * @param regCost L2 regularization strength
     * @return the L2 penalty {@code regCost / 2 * sum(W .* W)} summed over the group
     */
    double scaleAndRegularize(TwoDimensionalMap<String, String, SimpleMatrix> derivatives, TwoDimensionalMap<String, String, SimpleMatrix> currentMatrices, double scale, double regCost) {
        double cost = 0.0;
        for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : currentMatrices) {
            SimpleMatrix param = entry.getValue();
            SimpleMatrix D = derivatives.get(entry.getFirstKey(), entry.getSecondKey());
            derivatives.put(entry.getFirstKey(), entry.getSecondKey(), D.scale(scale).plus(param.scale(regCost)));
            cost += param.elementMult(param).elementSum() * regCost / 2.0;
        }
        return cost;
    }

    /**
     * Single-key-map counterpart of
     * {@link #scaleAndRegularize(TwoDimensionalMap, TwoDimensionalMap, double, double)}.
     */
    double scaleAndRegularize(Map<String, SimpleMatrix> derivatives, Map<String, SimpleMatrix> currentMatrices, double scale, double regCost) {
        double cost = 0.0;
        for (Map.Entry<String, SimpleMatrix> entry : currentMatrices.entrySet()) {
            SimpleMatrix param = entry.getValue();
            SimpleMatrix D = derivatives.get(entry.getKey());
            derivatives.put(entry.getKey(), D.scale(scale).plus(param.scale(regCost)));
            cost += param.elementMult(param).elementSum() * regCost / 2.0;
        }
        return cost;
    }

    /**
     * Tensor counterpart of
     * {@link #scaleAndRegularize(TwoDimensionalMap, TwoDimensionalMap, double, double)}.
     */
    double scaleAndRegularizeTensor(TwoDimensionalMap<String, String, SimpleTensor> derivatives, TwoDimensionalMap<String, String, SimpleTensor> currentMatrices, double scale, double regCost) {
        double cost = 0.0;
        for (TwoDimensionalMap.Entry<String, String, SimpleTensor> entry : currentMatrices) {
            SimpleTensor param = entry.getValue();
            SimpleTensor D = derivatives.get(entry.getFirstKey(), entry.getSecondKey());
            derivatives.put(entry.getFirstKey(), entry.getSecondKey(), D.scale(scale).plus(param.scale(regCost)));
            cost += param.elementMult(param).elementSum() * regCost / 2.0;
        }
        return cost;
    }

    /**
     * Back-propagates from the root of {@code tree}. The root receives no gradient from above, so it
     * starts with a zero upstream delta of size {@code numHid}.
     */
    private void backpropDerivativesAndError(Tree tree, TwoDimensionalMap<String, String, SimpleMatrix> binaryTD, TwoDimensionalMap<String, String, SimpleMatrix> binaryCD, TwoDimensionalMap<String, String, SimpleTensor> binaryTensorTD, Map<String, SimpleMatrix> unaryCD, Map<String, SimpleMatrix> wordVectorD) {
        SimpleMatrix delta = new SimpleMatrix(this.model.op.numHid, 1);
        this.backpropDerivativesAndError(tree, binaryTD, binaryCD, binaryTensorTD, unaryCD, wordVectorD, delta);
    }

    /**
     * Recursively accumulates the derivatives for the subtree rooted at {@code tree}.
     *
     * <p>Also records each node's weighted prediction error via
     * {@code RNNCoreAnnotations.setPredictionError}. Expects {@link #forwardPropagateTree(Tree)} to
     * have annotated the tree already.
     *
     * @param deltaUp error signal arriving from the parent node
     */
    private void backpropDerivativesAndError(Tree tree, TwoDimensionalMap<String, String, SimpleMatrix> binaryTD, TwoDimensionalMap<String, String, SimpleMatrix> binaryCD, TwoDimensionalMap<String, String, SimpleTensor> binaryTensorTD, Map<String, SimpleMatrix> unaryCD, Map<String, SimpleMatrix> wordVectorD, SimpleMatrix deltaUp) {
        if (tree.isLeaf()) {
            return;
        }
        SimpleMatrix currentVector = RNNCoreAnnotations.getNodeVector(tree);
        String category = this.model.basicCategory(tree.label().value());

        // One-hot gold vector; it stays all zeros for unlabeled nodes (goldClass < 0).
        SimpleMatrix goldLabel = new SimpleMatrix(this.model.numClasses, 1);
        int goldClass = RNNCoreAnnotations.getGoldClass(tree);
        if (goldClass >= 0) {
            goldLabel.set(goldClass, 1.0);
        }
        double nodeWeight = this.model.op.trainOptions.getClassWeight(goldClass);
        SimpleMatrix predictions = RNNCoreAnnotations.getPredictions(tree);

        // (softmax output - gold) scaled by the class weight; unlabeled nodes contribute no gradient.
        SimpleMatrix deltaClass = goldClass >= 0
                ? predictions.minus(goldLabel).scale(nodeWeight)
                : new SimpleMatrix(predictions.numRows(), predictions.numCols());
        SimpleMatrix localCD = deltaClass.mult(NeuralUtils.concatenateWithBias(currentVector).transpose());

        // Weighted cross-entropy of the prediction against the gold one-hot vector.
        double error = -NeuralUtils.elementwiseApplyLog(predictions).elementMult(goldLabel).elementSum();
        RNNCoreAnnotations.setPredictionError(tree, error * nodeWeight);

        if (tree.isPreTerminal()) {
            unaryCD.put(category, unaryCD.get(category).plus(localCD));
            String word = this.model.getVocabWord(tree.children()[0].label().value());
            SimpleMatrix currentVectorDerivative = NeuralUtils.elementwiseApplyTanhDerivative(currentVector);
            SimpleMatrix deltaFromClass = this.model.getUnaryClassification(category).transpose().mult(deltaClass);
            // Keep the first numHid rows (presumably dropping the bias row added by
            // concatenateWithBias), then apply the tanh derivative.
            deltaFromClass = deltaFromClass.extractMatrix(0, this.model.op.numHid, 0, 1).elementMult(currentVectorDerivative);
            SimpleMatrix deltaFull = deltaFromClass.plus(deltaUp);
            // NOTE(review): assumes getVocabWord always yields a key present in model.wordVectors;
            // otherwise this throws a NullPointerException.
            wordVectorD.put(word, wordVectorD.get(word).plus(deltaFull));
            return;
        }

        String leftCategory = this.model.basicCategory(tree.children()[0].label().value());
        String rightCategory = this.model.basicCategory(tree.children()[1].label().value());
        if (this.model.op.combineClassification) {
            unaryCD.put("", unaryCD.get("").plus(localCD));
        } else {
            binaryCD.put(leftCategory, rightCategory, binaryCD.get(leftCategory, rightCategory).plus(localCD));
        }
        SimpleMatrix currentVectorDerivative = NeuralUtils.elementwiseApplyTanhDerivative(currentVector);
        SimpleMatrix deltaFromClass = this.model.getBinaryClassification(leftCategory, rightCategory).transpose().mult(deltaClass);
        deltaFromClass = deltaFromClass.extractMatrix(0, this.model.op.numHid, 0, 1).elementMult(currentVectorDerivative);
        SimpleMatrix deltaFull = deltaFromClass.plus(deltaUp);

        SimpleMatrix leftVector = RNNCoreAnnotations.getNodeVector(tree.children()[0]);
        SimpleMatrix rightVector = RNNCoreAnnotations.getNodeVector(tree.children()[1]);
        SimpleMatrix childrenVector = NeuralUtils.concatenateWithBias(leftVector, rightVector);
        SimpleMatrix W_df = deltaFull.mult(childrenVector.transpose());
        binaryTD.put(leftCategory, rightCategory, binaryTD.get(leftCategory, rightCategory).plus(W_df));

        SimpleMatrix deltaDown;
        if (this.model.op.useTensors) {
            SimpleTensor Wt_df = this.getTensorGradient(deltaFull, leftVector, rightVector);
            binaryTensorTD.put(leftCategory, rightCategory, binaryTensorTD.get(leftCategory, rightCategory).plus(Wt_df));
            deltaDown = this.computeTensorDeltaDown(deltaFull, leftVector, rightVector, this.model.getBinaryTransform(leftCategory, rightCategory), this.model.getBinaryTensor(leftCategory, rightCategory));
        } else {
            deltaDown = this.model.getBinaryTransform(leftCategory, rightCategory).transpose().mult(deltaFull);
        }

        // Split the downward delta into the left and right child halves and push each through tanh.
        int hidden = deltaFull.numRows();
        SimpleMatrix leftDerivative = NeuralUtils.elementwiseApplyTanhDerivative(leftVector);
        SimpleMatrix rightDerivative = NeuralUtils.elementwiseApplyTanhDerivative(rightVector);
        SimpleMatrix leftDeltaDown = deltaDown.extractMatrix(0, hidden, 0, 1);
        SimpleMatrix rightDeltaDown = deltaDown.extractMatrix(hidden, hidden * 2, 0, 1);
        this.backpropDerivativesAndError(tree.children()[0], binaryTD, binaryCD, binaryTensorTD, unaryCD, wordVectorD, leftDerivative.elementMult(leftDeltaDown));
        this.backpropDerivativesAndError(tree.children()[1], binaryTD, binaryCD, binaryTensorTD, unaryCD, wordVectorD, rightDerivative.elementMult(rightDeltaDown));
    }

    /**
     * Computes the delta sent to the two children of a tensor-composed node.
     *
     * <p>The result is {@code sum_k deltaFull[k] * (V_k + V_k^T) * [l; r]} plus the first
     * {@code 2 * size} rows of {@code W^T * deltaFull}, where {@code V_k} is slice {@code k} of
     * {@code Wt}.
     */
    private SimpleMatrix computeTensorDeltaDown(SimpleMatrix deltaFull, SimpleMatrix leftVector, SimpleMatrix rightVector, SimpleMatrix W, SimpleTensor Wt) {
        SimpleMatrix WTDelta = W.transpose().mult(deltaFull);
        SimpleMatrix WTDeltaNoBias = WTDelta.extractMatrix(0, deltaFull.numRows() * 2, 0, 1);
        int size = deltaFull.getNumElements();
        SimpleMatrix deltaTensor = new SimpleMatrix(size * 2, 1);
        SimpleMatrix fullVector = NeuralUtils.concatenate(leftVector, rightVector);
        for (int slice = 0; slice < size; ++slice) {
            SimpleMatrix scaledFullVector = fullVector.scale(deltaFull.get(slice));
            SimpleMatrix sliceMatrix = Wt.getSlice(slice);
            deltaTensor = deltaTensor.plus(sliceMatrix.plus(sliceMatrix.transpose()).mult(scaledFullVector));
        }
        return deltaTensor.plus(WTDeltaNoBias);
    }

    /**
     * Computes the tensor gradient for one node. Slice {@code k} equals
     * {@code deltaFull[k] * [l; r] * [l; r]^T}.
     */
    private SimpleTensor getTensorGradient(SimpleMatrix deltaFull, SimpleMatrix leftVector, SimpleMatrix rightVector) {
        int size = deltaFull.getNumElements();
        SimpleTensor Wt_df = new SimpleTensor(size * 2, size * 2, size);
        SimpleMatrix fullVector = NeuralUtils.concatenate(leftVector, rightVector);
        // The transpose is loop-invariant; mult does not modify its argument.
        SimpleMatrix fullVectorT = fullVector.transpose();
        for (int slice = 0; slice < size; ++slice) {
            Wt_df.setSlice(slice, fullVector.scale(deltaFull.get(slice)).mult(fullVectorT));
        }
        return Wt_df;
    }

    /**
     * Runs the model bottom-up over {@code tree}.
     *
     * <p>Stores the softmax class predictions, the predicted class index and the node vector on each
     * non-leaf node's {@link CoreLabel}.
     *
     * @throws AssertionError if a leaf is reached directly, if a non-preterminal node does not have
     *     exactly two children, or if a node label is not a {@link CoreLabel}
     */
    public void forwardPropagateTree(Tree tree) {
        if (tree.isLeaf()) {
            throw new AssertionError("We should not have reached leaves in forwardPropagate");
        }
        SimpleMatrix nodeVector;
        SimpleMatrix classification;
        if (tree.isPreTerminal()) {
            classification = this.model.getUnaryClassification(tree.label().value());
            String word = tree.children()[0].label().value();
            nodeVector = NeuralUtils.elementwiseApplyTanh(this.model.getWordVector(word));
        } else if (tree.children().length == 1) {
            throw new AssertionError("Non-preterminal nodes of size 1 should have already been collapsed");
        } else if (tree.children().length == 2) {
            Tree left = tree.children()[0];
            Tree right = tree.children()[1];
            this.forwardPropagateTree(left);
            this.forwardPropagateTree(right);
            String leftCategory = left.label().value();
            String rightCategory = right.label().value();
            SimpleMatrix W = this.model.getBinaryTransform(leftCategory, rightCategory);
            classification = this.model.getBinaryClassification(leftCategory, rightCategory);
            SimpleMatrix leftVector = RNNCoreAnnotations.getNodeVector(left);
            SimpleMatrix rightVector = RNNCoreAnnotations.getNodeVector(right);
            SimpleMatrix linear = W.mult(NeuralUtils.concatenateWithBias(leftVector, rightVector));
            if (this.model.op.useTensors) {
                SimpleTensor tensor = this.model.getBinaryTensor(leftCategory, rightCategory);
                SimpleMatrix tensorOut = tensor.bilinearProducts(NeuralUtils.concatenate(leftVector, rightVector));
                nodeVector = NeuralUtils.elementwiseApplyTanh(linear.plus(tensorOut));
            } else {
                nodeVector = NeuralUtils.elementwiseApplyTanh(linear);
            }
        } else {
            throw new AssertionError("Tree not correctly binarized");
        }
        SimpleMatrix predictions = NeuralUtils.softmax(classification.mult(NeuralUtils.concatenateWithBias(nodeVector)));
        int index = this.getPredictedClass(predictions);
        if (!(tree.label() instanceof CoreLabel)) {
            throw new AssertionError("Expected CoreLabels in the nodes");
        }
        CoreLabel label = (CoreLabel) tree.label();
        label.set(RNNCoreAnnotations.Predictions.class, predictions);
        label.set(RNNCoreAnnotations.PredictedClass.class, index);
        label.set(RNNCoreAnnotations.NodeVector.class, nodeVector);
    }
}

