/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.treedatalikelihood.discrete;

import dr.evolution.datatype.DataType;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeUtils;
import dr.evolution.util.TaxonList;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.substmodel.SubstitutionModel;
import dr.evomodel.treelikelihood.AncestralStateBeagleTreeLikelihood;
import dr.inference.distribution.ParametricDistributionModel;
import dr.inference.model.Statistic;
import dr.math.MathUtils;
import dr.math.UnivariateFunction;
import dr.math.UnivariateMinimum;
import dr.math.matrixAlgebra.Vector;
import dr.xml.Reportable;
import java.util.Set;

public class ASRSubstitutionModelConvolutionStatistic
extends Statistic.Abstract
implements Reportable {
    private BranchRateModel branchRates;
    private final Statistic rateAncestor;
    private final Statistic rateDescendant;
    private AncestralStateBeagleTreeLikelihood asrLikelihood;
    private SubstitutionModel substitutionModelAncestor;
    private SubstitutionModel substitutionModelDescendant;
    private SubstitutionModel pairedSubstitutionModelAncestor;
    private SubstitutionModel pairedSubstitutionModelDescendant;
    private final Set<String> leafSetDescendant;
    private final Tree tree;
    private final DataType dataType;
    private final boolean bootstrap;
    private final ParametricDistributionModel prior;
    private final String name;
    private final int[] doublets;
    private final int[] doubletsTo;
    private final boolean isPartitioned;
    private final boolean partitionConditionalOnEndState;
    private final boolean doubletsAreSafe;
    private final boolean takeDistanceAsFixed;
    private final boolean anchorAtPresent;

    public ASRSubstitutionModelConvolutionStatistic(String string, AncestralStateBeagleTreeLikelihood ancestralStateBeagleTreeLikelihood, SubstitutionModel substitutionModel, SubstitutionModel substitutionModel2, int[] nArray, int[] nArray2, SubstitutionModel substitutionModel3, SubstitutionModel substitutionModel4, BranchRateModel branchRateModel, Statistic statistic, Statistic statistic2, boolean bl, boolean bl2, TaxonList taxonList, boolean bl3, ParametricDistributionModel parametricDistributionModel) throws TreeUtils.MissingTaxonException {
        this.name = string;
        this.bootstrap = bl3;
        this.prior = parametricDistributionModel;
        this.asrLikelihood = ancestralStateBeagleTreeLikelihood;
        this.substitutionModelAncestor = substitutionModel;
        this.substitutionModelDescendant = substitutionModel2;
        this.branchRates = branchRateModel;
        this.rateAncestor = statistic;
        this.rateDescendant = statistic2;
        this.takeDistanceAsFixed = bl;
        this.anchorAtPresent = bl2;
        this.dataType = substitutionModel.getFrequencyModel().getDataType();
        if (this.dataType != this.substitutionModelDescendant.getFrequencyModel().getDataType()) {
            throw new RuntimeException("Incompatible datatypes in substitution models for ASRSubstitutionModelConvolution.");
        }
        this.tree = this.asrLikelihood.getTreeModel();
        this.leafSetDescendant = taxonList != null ? TreeUtils.getLeavesForTaxa(this.tree, taxonList) : null;
        this.doublets = nArray;
        this.doubletsTo = nArray2;
        this.pairedSubstitutionModelAncestor = substitutionModel3;
        this.pairedSubstitutionModelDescendant = substitutionModel4;
        this.isPartitioned = substitutionModel3 != null;
        this.partitionConditionalOnEndState = nArray2.length > 0;
        boolean bl4 = this.doubletsAreSafe = !this.isPartitioned || !this.doubletsCanOverlap();
        if (this.isPartitioned) {
            if (nArray.length % 2 != 0) {
                throw new RuntimeException("Improperly specified doublets");
            }
            if (bl3) {
                throw new RuntimeException("Cannot currently bootstrap context-dependent models.");
            }
            if (substitutionModel3 != null && substitutionModel3 == null || substitutionModel3 == null && substitutionModel3 != null) {
                throw new RuntimeException("If specifying models for doublets must specify ancestral and descendant models.");
            }
            if (substitutionModel3.getFrequencyModel().getFrequencies().length != substitutionModel4.getFrequencyModel().getFrequencies().length) {
                throw new RuntimeException("Doublet models do not match in size.");
            }
            if (substitutionModel3.getFrequencyModel().getFrequencies().length != this.dataType.getStateCount() * this.dataType.getStateCount()) {
                throw new RuntimeException("Doublet models are not sized for doublets.");
            }
            if (this.partitionConditionalOnEndState && nArray.length != nArray2.length) {
                throw new RuntimeException("Length of doublets and doubletsTo do not match.");
            }
        }
    }

    @Override
    public int getDimension() {
        return 1;
    }

    @Override
    public String getDimensionName(int n) {
        if (this.name == null) {
            return "timeBeforeMRCA";
        }
        return this.name;
    }

    @Override
    public String getStatisticName() {
        return "name";
    }

    @Override
    public double getStatisticValue(int n) {
        NodeRef nodeRef = this.getNode(this.leafSetDescendant);
        NodeRef nodeRef2 = this.tree.getParent(nodeRef);
        if (!this.doubletsAreSafe && this.doubletsOverlapOnSequence(nodeRef2)) {
            return Double.NaN;
        }
        double d = this.getBranchTime(nodeRef);
        UnivariateMinimum univariateMinimum = this.optimizeTimes(nodeRef, nodeRef2, d);
        return (1.0 - univariateMinimum.minx) * d;
    }

    @Override
    public String getReport() {
        StringBuilder stringBuilder = new StringBuilder("asrSubstitutionModelConvolutionStatistic Report\n\n");
        stringBuilder.append("Estimated time of shift before common ancestor: ").append(this.getStatisticValue(0)).append("\n");
        stringBuilder.append("Using rates: ").append(new Vector(this.getRates(this.getNode(this.leafSetDescendant), 0.0))).append("\n");
        stringBuilder.append("Using substitution models named: ").append(this.substitutionModelAncestor.getId()).append(", ").append(this.substitutionModelDescendant.getId()).append("\n");
        stringBuilder.append("Using taxon set: ").append(this.leafSetDescendant).append("\n");
        stringBuilder.append("Using prior? ").append(this.prior != null).append("\n");
        if (this.prior != null) {
            stringBuilder.append("  Using prior of type: ").append(this.prior.getModelName()).append("\n");
            stringBuilder.append("  Using prior named: ").append(this.prior.getId()).append("\n");
        }
        stringBuilder.append("Using bootstrap? ").append(this.bootstrap).append("\n");
        if (this.isPartitioned) {
            stringBuilder.append("Using partitioned model for doublets: \n");
            for (int i = 0; i < this.doublets.length / 2; ++i) {
                stringBuilder.append("  ").append(this.dataType.getChar(this.doublets[2 * i])).append(this.dataType.getChar(this.doublets[2 * i + 1])).append("\n");
            }
            stringBuilder.append("Using doublet substitution models named: ").append(this.pairedSubstitutionModelAncestor.getId()).append(", ").append(this.pairedSubstitutionModelDescendant.getId()).append("\n");
        }
        stringBuilder.append("\n\n");
        return stringBuilder.toString();
    }

    private double getBranchTime(NodeRef nodeRef) {
        return this.tree.getNodeHeight(this.tree.getParent(nodeRef)) - this.tree.getNodeHeight(nodeRef);
    }

    private double[] getRawRates(NodeRef nodeRef) {
        double[] dArray = new double[]{this.rateAncestor != null ? this.rateAncestor.getStatisticValue(0) : this.branchRates.getBranchRate(this.tree, nodeRef), this.rateDescendant != null ? this.rateDescendant.getStatisticValue(0) : this.branchRates.getBranchRate(this.tree, nodeRef)};
        return dArray;
    }

    private double[] getRates(NodeRef nodeRef, double d) {
        double[] dArray = this.getRawRates(nodeRef);
        if (this.takeDistanceAsFixed) {
            double d2 = this.branchRates.getBranchRate(this.tree, nodeRef);
            if (this.anchorAtPresent) {
                dArray[0] = (d2 - dArray[1] * (1.0 - d)) / d;
            } else {
                dArray[1] = (d2 - dArray[0] * d) / (1.0 - d);
            }
        }
        return dArray;
    }

    private void convolveMatrices(double[] dArray, double[] dArray2, double[] dArray3, int n) {
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                for (int k = 0; k < n; ++k) {
                    int n2 = i * n + j;
                    dArray3[n2] = dArray3[n2] + dArray[i * n + k] * dArray2[k * n + j];
                }
            }
        }
    }

    private boolean isPair(int n, int n2) {
        if (!this.isPartitioned) {
            return false;
        }
        boolean bl = false;
        int n3 = this.doublets.length / 2;
        for (int i = 0; i < n3; ++i) {
            if (n != this.doublets[2 * i] || n2 != this.doublets[2 * i + 1]) continue;
            bl = true;
            break;
        }
        return bl;
    }

    private boolean doPartition(int n, int n2, int n3, int n4) {
        if (!this.partitionConditionalOnEndState) {
            return this.isPair(n, n2);
        }
        if (!this.isPartitioned) {
            return false;
        }
        boolean bl = false;
        int n5 = this.doublets.length / 2;
        for (int i = 0; i < n5; ++i) {
            if (n != this.doublets[2 * i] || n2 != this.doublets[2 * i + 1] || n3 != this.doubletsTo[2 * i] || n4 != this.doubletsTo[2 * i + 1]) continue;
            bl = true;
            break;
        }
        return bl;
    }

    private int whichPair(int n, int n2) {
        if (!this.isPartitioned) {
            return -1;
        }
        int n3 = -1;
        int n4 = this.doublets.length / 2;
        for (int i = 0; i < n4; ++i) {
            if (n != this.doublets[2 * i] || n2 != this.doublets[2 * i + 1]) continue;
            n3 = i;
            break;
        }
        return n3;
    }

    private boolean doubletsCanOverlap() {
        int n = this.doublets.length / 2;
        boolean bl = false;
        block0: for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                if (this.doublets[2 * i] != this.doublets[2 * j + 1]) continue;
                bl = true;
                continue block0;
            }
        }
        return bl;
    }

    private boolean doubletsOverlapOnSequence(NodeRef nodeRef) {
        int[] nArray = this.asrLikelihood.getStatesForNode(this.tree, nodeRef);
        boolean bl = false;
        boolean bl2 = false;
        for (int i = 0; i < nArray.length - 1; ++i) {
            if (this.isPair(nArray[i], nArray[i + 1])) {
                if (bl2) {
                    bl = true;
                    break;
                }
                bl2 = true;
                continue;
            }
            bl2 = false;
        }
        return bl;
    }

    private int getDoublet(int n, int n2, int n3) {
        return n * n3 + n2;
    }

    private double computeLogLikelihood(double d, double d2, double d3, int[] nArray, int[] nArray2) {
        int n;
        int n2 = this.dataType.getStateCount();
        int n3 = n2 * n2;
        double[] dArray = new double[n3];
        double[] dArray2 = new double[n3];
        this.substitutionModelAncestor.getTransitionProbabilities(d, dArray);
        this.substitutionModelDescendant.getTransitionProbabilities(d2, dArray2);
        double[] dArray3 = new double[n2 * n2];
        this.convolveMatrices(dArray, dArray2, dArray3, n2);
        double[] dArray4 = new double[n2 * n2];
        for (int i = 0; i < n2 * n2; ++i) {
            dArray4[i] = Math.log(dArray3[i]);
        }
        double[] dArray5 = new double[n3 * n3];
        double[] dArray6 = new double[n3 * n3];
        if (this.isPartitioned) {
            double[] dArray7 = new double[n3 * n3];
            double[] dArray8 = new double[n3 * n3];
            this.pairedSubstitutionModelAncestor.getTransitionProbabilities(d * 2.0, dArray7);
            this.pairedSubstitutionModelDescendant.getTransitionProbabilities(d2 * 2.0, dArray8);
            double[] dArray9 = new double[n3 * n3];
            double[] dArray10 = new double[n3 * n3];
            double[][] dArray11 = new double[n3][n3];
            this.pairedSubstitutionModelAncestor.getInfinitesimalMatrix(dArray9);
            this.pairedSubstitutionModelDescendant.getInfinitesimalMatrix(dArray10);
            this.convolveMatrices(dArray7, dArray8, dArray5, n2);
            for (n = 0; n < n3 * n3; ++n) {
                dArray6[n] = Math.log(dArray7[n]);
            }
        }
        double d4 = 0.0;
        int n4 = nArray.length - 1;
        int n5 = 0;
        for (int i = 0; i < nArray.length; ++i) {
            if (i < n4 && this.doPartition(nArray[i], nArray[i + 1], nArray2[i], nArray2[i + 1])) {
                n = this.getDoublet(nArray[i], nArray[i + 1], n2);
                int n6 = this.getDoublet(nArray2[i], nArray2[i + 1], n2);
                d4 += dArray6[n * n3 + n6];
                ++i;
                n5 += 2;
                continue;
            }
            d4 += dArray4[nArray[i] * n2 + nArray2[i]];
        }
        if (this.prior != null) {
            d4 += this.prior.logPdf(d2 / d3);
        }
        return d4;
    }

    private NodeRef getNode(Set<String> set) {
        return set != null ? TreeUtils.getCommonAncestorNode(this.tree, set) : this.tree.getRoot();
    }

    private UnivariateMinimum optimizeTimes(final NodeRef nodeRef, NodeRef nodeRef2, final double d) {
        Object object;
        Object object2;
        final int[] nArray = this.asrLikelihood.getStatesForNode(this.tree, nodeRef2);
        final int[] nArray2 = this.asrLikelihood.getStatesForNode(this.tree, nodeRef);
        int n = nArray.length;
        if (this.bootstrap) {
            object2 = new int[n];
            object = new int[n];
            System.arraycopy(nArray, 0, object2, 0, n);
            System.arraycopy(nArray2, 0, object, 0, n);
            for (int i = 0; i < n; ++i) {
                int n2 = MathUtils.nextInt(n);
                nArray[i] = object2[n2];
                nArray2[i] = (int)object[n2];
            }
        }
        object2 = new UnivariateFunction(){

            @Override
            public double evaluate(double d5) {
                double[] dArray = ASRSubstitutionModelConvolutionStatistic.this.getRates(nodeRef, d5);
                double d2 = d5 * dArray[0] * d;
                double d3 = (1.0 - d5) * dArray[1] * d;
                double d4 = ASRSubstitutionModelConvolutionStatistic.this.computeLogLikelihood(d2, d3, dArray[1], nArray, nArray2);
                return -d4;
            }

            @Override
            public double getLowerBound() {
                double d3 = 0.0;
                if (ASRSubstitutionModelConvolutionStatistic.this.takeDistanceAsFixed && ASRSubstitutionModelConvolutionStatistic.this.anchorAtPresent) {
                    double d2 = ASRSubstitutionModelConvolutionStatistic.this.branchRates.getBranchRate(ASRSubstitutionModelConvolutionStatistic.this.tree, nodeRef);
                    double[] dArray = ASRSubstitutionModelConvolutionStatistic.this.getRawRates(nodeRef);
                    d3 = 1.0 - d2 / dArray[1];
                }
                return d3;
            }

            @Override
            public double getUpperBound() {
                double d3 = 1.0;
                if (ASRSubstitutionModelConvolutionStatistic.this.takeDistanceAsFixed && !ASRSubstitutionModelConvolutionStatistic.this.anchorAtPresent) {
                    double d2 = ASRSubstitutionModelConvolutionStatistic.this.branchRates.getBranchRate(ASRSubstitutionModelConvolutionStatistic.this.tree, nodeRef);
                    double[] dArray = ASRSubstitutionModelConvolutionStatistic.this.getRawRates(nodeRef);
                    d3 = d2 / dArray[0];
                }
                return d3;
            }
        };
        object = new UnivariateMinimum();
        ((UnivariateMinimum)object).findMinimum((UnivariateFunction)object2);
        return object;
    }
}

