/*
 * Decompiled with CFR 0.152.
 */
package dr.evolution.tree;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.SimpleTree;
import dr.evolution.tree.Tree;
import dr.math.ConjugateDirectionSearch;
import dr.math.MultivariateFunction;
import dr.math.MultivariateMinimum;

/**
 * A tree built from a source tree whose internal node heights are re-estimated by
 * non-parametric rate smoothing.
 * <p>
 * The branch lengths of the source tree are treated as distances. The rate on a branch is
 * that distance divided by the branch's time span on this tree. The optimizer chooses, for
 * every internal node, how far that node sits above its tallest child. It minimises the sum of:
 * <ul>
 *   <li>the squared differences between each branch rate and the rates of its child branches, and</li>
 *   <li>the squared difference between the rates of the two branches below the root.</li>
 * </ul>
 */
public class RateSmoothingTree extends SimpleTree {

    /** Objective minimised by {@link #smoothRates()}; arguments are per-internal-node height increments (plus mu, if optimised). */
    private final MultivariateFunction nonParametricRateSmoothing = new MultivariateFunction() {

        @Override
        public double evaluate(double[] argument) {
            return RateSmoothingTree.this.evaluateObjective(argument);
        }

        @Override
        public int getNumArguments() {
            int count = RateSmoothingTree.this.getInternalNodeCount();
            return RateSmoothingTree.this.optimizeMu ? count + 1 : count;
        }

        @Override
        public double getLowerBound(int n) {
            // mu (when optimised) must stay strictly positive; height increments may be zero.
            return RateSmoothingTree.this.isMuArgument(n) ? Double.MIN_VALUE : 0.0;
        }

        @Override
        public double getUpperBound(int n) {
            return Double.MAX_VALUE;
        }
    };

    private int nodeCount;
    /** Height of each internal node above its tallest child, indexed by (node number - external node count). */
    private double[] nodeValues;
    private Tree sourceTree;
    private double mu;
    private double sumDist;
    private double sumTime;
    // NOTE(review): nothing in this class ever sets optimizeMu to true, so the mu-argument
    // path is currently unreachable. Even if it were enabled, getSumOfRates() overwrites mu
    // with sumDist / sumTime on every evaluation.
    private boolean optimizeMu;
    private int muIndex;

    /**
     * Creates a rate-smoothing tree from {@code tree} with an initial mu of 1.0.
     *
     * @param tree the source tree; its branch lengths supply the per-branch distances
     */
    public RateSmoothingTree(Tree tree) {
        this(tree, 1.0);
    }

    /**
     * Creates a rate-smoothing tree from {@code tree} with the given initial mu.
     *
     * @param tree the source tree; its branch lengths supply the per-branch distances
     * @param d    initial value returned by {@link #getMu()} before any smoothing
     */
    public RateSmoothingTree(Tree tree, double d) {
        // Must copy the source tree; without this, an implicit super() yields an empty tree.
        super(tree);
        this.sourceTree = tree;
        this.mu = d;
        this.optimizeMu = false;
    }

    /**
     * Returns mu. After {@link #smoothRates()} or {@link #getSumOfRates()}, this is the
     * total source branch length divided by the total branch time on this tree.
     *
     * @return the current mu
     */
    public double getMu() {
        return this.mu;
    }

    /**
     * Optimises the internal node heights of this tree to minimise the rate-smoothing
     * objective. It starts from a height increment of 1.0 above the tallest child for every
     * internal node. On return, node heights, node rates and mu reflect the optimum found.
     */
    public void smoothRates() {
        this.nodeCount = this.getInternalNodeCount();
        int argumentCount = this.nodeCount;
        if (this.optimizeMu) {
            this.muIndex = this.nodeCount;
            ++argumentCount;
        }
        this.nodeValues = new double[this.nodeCount];
        double[] argument = new double[argumentCount];
        for (int i = 0; i < this.nodeCount; ++i) {
            argument[i] = 1.0;
        }
        if (this.optimizeMu) {
            argument[this.muIndex] = this.mu;
        }
        MultivariateMinimum optimizer = new ConjugateDirectionSearch();
        optimizer.optimize(this.nonParametricRateSmoothing, argument, 1.0E-8, 1.0E-8);
        // The optimizer writes its optimum back into `argument`, but its last trial evaluation
        // may have been elsewhere. Re-evaluate so the tree state matches the returned optimum.
        this.nonParametricRateSmoothing.evaluate(argument);
    }

    /**
     * Computes the rate-smoothing objective for the current node heights.
     * <p>
     * Despite the name, this is a sum of squared rate differences: each parent/child pair of
     * branches contributes once, and the two root branches contribute once. As side effects,
     * it stores each non-root node's rate via {@code setNodeRate} and recomputes mu.
     *
     * @return the objective value
     * @throws IllegalArgumentException if the root does not have exactly two children
     */
    public double getSumOfRates() {
        NodeRef root = this.getRoot();
        if (this.getChildCount(root) != 2) {
            throw new IllegalArgumentException("The tree must have a bifurcating root node");
        }
        double[] score = new double[]{0.0};
        this.sumDist = 0.0;
        this.sumTime = 0.0;
        double firstRate = this.sumScoreAtNode(this.getChild(root, 0), score);
        double secondRate = this.sumScoreAtNode(this.getChild(root, 1), score);
        this.mu = this.sumDist / this.sumTime;
        double rootDifference = secondRate - firstRate;
        return score[0] + rootDifference * rootDifference;
    }

    /** Copies the optimizer arguments into the tree and returns the objective at that point. */
    private double evaluateObjective(double[] argument) {
        System.arraycopy(argument, 0, this.nodeValues, 0, this.getInternalNodeCount());
        this.setNodeHeightsFromValues(this.getRoot());
        if (this.optimizeMu) {
            this.mu = argument[this.muIndex];
        }
        return this.getSumOfRates();
    }

    /** True when argument index {@code n} is the mu argument rather than a node increment. */
    private boolean isMuArgument(int n) {
        return this.optimizeMu && n == this.muIndex;
    }

    /**
     * Computes the rate on the branch above {@code nodeRef}, then recursively the rates below it.
     * For each child, it adds the squared difference between this branch's rate and the child
     * branch's rate into {@code score[0]}.
     *
     * @param nodeRef a non-root node
     * @param score   single-element accumulator for the objective
     * @return the rate on the branch above {@code nodeRef}
     */
    private double sumScoreAtNode(NodeRef nodeRef, double[] score) {
        double rate = this.getRateAtNode(nodeRef);
        if (!this.isExternal(nodeRef)) {
            for (int i = 0; i < this.getChildCount(nodeRef); ++i) {
                double childRate = this.sumScoreAtNode(this.getChild(nodeRef, i), score);
                double difference = rate - childRate;
                score[0] += difference * difference;
            }
        }
        return rate;
    }

    /**
     * Computes the rate on the branch above {@code nodeRef}, defined as the source-tree branch
     * length divided by the branch time on this tree. It records the rate on the node and adds
     * the distance and time into the running totals used for mu.
     */
    private double getRateAtNode(NodeRef nodeRef) {
        double time = this.getNodeHeight(this.getParent(nodeRef)) - this.getNodeHeight(nodeRef);
        // Assumes the source tree uses the same node numbering as this copy.
        double distance = this.sourceTree.getBranchLength(this.sourceTree.getNode(nodeRef.getNumber()));
        double rate;
        if (time == 0.0) {
            // NOTE(review): a positive distance over zero time is mapped to the smallest
            // positive double rather than a very large rate; kept as-is.
            rate = distance == 0.0 ? 1.0 : Double.MIN_VALUE;
        } else {
            rate = distance / time;
        }
        this.sumDist += distance;
        this.sumTime += time;
        this.setNodeRate(nodeRef, rate);
        return rate;
    }

    /**
     * Sets each internal node's height, bottom-up, to its tallest child's height plus that
     * node's entry in {@code nodeValues}. External node heights are left unchanged.
     *
     * @return the (possibly updated) height of {@code nodeRef}
     */
    private double setNodeHeightsFromValues(NodeRef nodeRef) {
        if (!this.isExternal(nodeRef)) {
            double tallestChild = this.setNodeHeightsFromValues(this.getChild(nodeRef, 0));
            for (int i = 1; i < this.getChildCount(nodeRef); ++i) {
                double childHeight = this.setNodeHeightsFromValues(this.getChild(nodeRef, i));
                if (childHeight > tallestChild) {
                    tallestChild = childHeight;
                }
            }
            // Assumes internal nodes are numbered after all external nodes — TODO confirm.
            this.setNodeHeight(nodeRef, tallestChild + this.nodeValues[nodeRef.getNumber() - this.getExternalNodeCount()]);
        }
        return this.getNodeHeight(nodeRef);
    }
}

