/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.treedatalikelihood.continuous;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evomodel.treedatalikelihood.TreeDataLikelihood;
import dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.continuous.ContinuousTraitGradientForBranch;
import dr.evomodel.treedatalikelihood.preorder.BranchConditionalDistributionDelegate;
import dr.evomodel.treedatalikelihood.preorder.BranchSufficientStatistics;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.model.CompoundParameter;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.math.MultivariateFunction;
import dr.math.NumericalDerivative;
import dr.math.matrixAlgebra.Vector;
import dr.xml.Reportable;
import java.util.List;

/**
 * Gradient of the tree data log-likelihood with respect to a {@link CompoundParameter} of
 * branch-specific optima.
 *
 * <p>The analytic gradient is assembled branch by branch. For each branch it combines the
 * {@link BranchSufficientStatistics} exposed by the branch conditional distribution tree trait
 * with a {@link ContinuousTraitGradientForBranch}. A numerical gradient computed with
 * {@link NumericalDerivative} is also available for cross-checking via {@link #getReport()}.
 */
public class BranchSpecificOptimaGradient
implements GradientWrtParameterProvider,
Reportable {
    private final TreeDataLikelihood treeDataLikelihood;
    private final ContinuousTraitGradientForBranch branchProvider;
    private final TreeTrait<List<BranchSufficientStatistics>> treeTraitProvider;
    private final int numBranches;
    private final int numTraits;
    private final Tree tree;
    private final int dimension;
    private final Parameter parameter;

    /**
     * Creates the gradient provider.
     *
     * <p>If the branch conditional density trait for {@code string} is not yet registered on
     * the likelihood, it is added through the supplied delegate.
     *
     * @param string name of the continuous trait; mapped to the branch conditional trait name
     *     via {@link BranchConditionalDistributionDelegate#getName(String)}
     * @param treeDataLikelihood likelihood whose log density is differentiated
     * @param continuousDataLikelihoodDelegate delegate used to register the branch conditional
     *     density trait when it is missing
     * @param continuousTraitGradientForBranch supplies the per-branch gradient contribution
     * @param compoundParameter the optima parameter; its component count is taken as the
     *     number of traits
     * @throws IllegalStateException if the branch conditional trait is still unavailable after
     *     attempting to register it
     */
    public BranchSpecificOptimaGradient(String string, TreeDataLikelihood treeDataLikelihood, ContinuousDataLikelihoodDelegate continuousDataLikelihoodDelegate, ContinuousTraitGradientForBranch continuousTraitGradientForBranch, CompoundParameter compoundParameter) {
        this.branchProvider = continuousTraitGradientForBranch;
        this.treeDataLikelihood = treeDataLikelihood;
        this.tree = treeDataLikelihood.getTree();
        this.numTraits = compoundParameter.getParameterCount();
        // One branch per non-root node.
        this.numBranches = this.tree.getNodeCount() - 1;
        this.parameter = compoundParameter;
        this.dimension = compoundParameter.getDimension();

        final String branchTraitName = BranchConditionalDistributionDelegate.getName(string);
        if (treeDataLikelihood.getTreeTrait(branchTraitName) == null) {
            continuousDataLikelihoodDelegate.addBranchConditionalDensityTrait(string);
        }

        // getTreeTrait is not generically typed for this trait; the cast is confined here.
        @SuppressWarnings("unchecked")
        final TreeTrait<List<BranchSufficientStatistics>> branchTrait =
                (TreeTrait<List<BranchSufficientStatistics>>) treeDataLikelihood.getTreeTrait(branchTraitName);
        if (branchTrait == null) {
            throw new IllegalStateException(
                    "Branch conditional trait '" + branchTraitName + "' is not available on the likelihood");
        }
        this.treeTraitProvider = branchTrait;

        // Evaluate once at construction (presumably to fail fast / warm up the trait
        // computation). A private helper is used so subclasses cannot intercept this call.
        this.computeGradientLogDensity();
    }

    /** {@inheritDoc} */
    @Override
    public Likelihood getLikelihood() {
        return this.treeDataLikelihood;
    }

    /** {@inheritDoc} */
    @Override
    public Parameter getParameter() {
        return this.parameter;
    }

    /** {@inheritDoc} */
    @Override
    public int getDimension() {
        return this.dimension;
    }

    /**
     * Returns the analytic gradient of the log-likelihood with respect to the optima.
     *
     * @return array of length {@code numBranches * numTraits}; entry
     *     {@code branch + numBranches * trait} holds the derivative for that branch and trait
     */
    @Override
    public double[] getGradientLogDensity() {
        return this.computeGradientLogDensity();
    }

    /** Assembles the per-branch gradient contributions into one flat array (trait-major blocks). */
    private double[] computeGradientLogDensity() {
        final double[] gradient = new double[this.numBranches * this.numTraits];
        for (int branch = 0; branch < this.numBranches; ++branch) {
            // NOTE(review): assumes node indices 0..numBranches-1 are exactly the non-root
            // nodes (i.e. the root has index numBranches) — confirm this holds for the trees used.
            final NodeRef node = this.tree.getNode(branch);
            final List<BranchSufficientStatistics> statistics = this.treeTraitProvider.getTrait(this.tree, node);
            final double[] branchGradient = this.branchProvider.getGradientForBranch(statistics.get(0), node);
            for (int trait = 0; trait < this.numTraits; ++trait) {
                gradient[branch + this.numBranches * trait] = branchGradient[trait];
            }
        }
        return gradient;
    }

    /**
     * Computes the gradient numerically with {@link NumericalDerivative}, for checking the
     * analytic gradient.
     *
     * <p>The parameter values are perturbed during the computation. They are restored
     * afterwards, even if an evaluation fails. The likelihood is then marked dirty so that
     * subsequent evaluations reflect the restored values.
     *
     * @return numerical gradient of length {@link #getDimension()}
     */
    public double[] getNumericalGradient() {
        final double[] savedValues = this.parameter.getParameterValues();
        try {
            return NumericalDerivative.gradient(this.logLikelihoodFunction(), savedValues.clone());
        } finally {
            for (int i = 0; i < this.dimension; ++i) {
                this.parameter.setParameterValue(i, savedValues[i]);
            }
            // Mirror the invalidation done after every perturbation in logLikelihoodFunction.
            this.treeDataLikelihood.makeDirty();
        }
    }

    /**
     * Wraps the log-likelihood as a function of the full parameter vector.
     *
     * <p>Each evaluation writes its argument into the parameter and forces recomputation.
     */
    private MultivariateFunction logLikelihoodFunction() {
        return new MultivariateFunction() {

            @Override
            public double evaluate(double[] values) {
                for (int i = 0; i < dimension; ++i) {
                    parameter.setParameterValue(i, values[i]);
                }
                treeDataLikelihood.makeDirty();
                return treeDataLikelihood.getLogLikelihood();
            }

            @Override
            public int getNumArguments() {
                return dimension;
            }

            @Override
            public double getLowerBound(int n) {
                // Optima are unconstrained real values, so there is no lower bound.
                return Double.NEGATIVE_INFINITY;
            }

            @Override
            public double getUpperBound(int n) {
                return Double.POSITIVE_INFINITY;
            }
        };
    }

    /**
     * Reports the analytic ("peeling") and numerical gradients side by side.
     *
     * <p>The numerical gradient is computed first. It restores the parameter and marks the
     * likelihood dirty before the analytic gradient is evaluated.
     */
    @Override
    public String getReport() {
        final double[] numerical = this.getNumericalGradient();
        final StringBuilder report = new StringBuilder();
        report.append("peeling: ").append(new Vector(this.getGradientLogDensity()));
        report.append("\n");
        report.append("numeric: ").append(new Vector(numerical));
        report.append("\n");
        return report.toString();
    }
}

