/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.speciation;

import dr.evolution.tree.Tree;
import dr.evolution.util.Taxon;
import dr.evolution.util.Units;
import dr.evomodel.speciation.SpeciationModel;
import dr.evomodel.speciation.SpeciationModelGradientProvider;
import dr.inference.model.Parameter;
import java.util.Arrays;
import java.util.Set;

/**
 * Episodic birth-death model with serial sampling.
 *
 * <p>The time axis from 0 to {@code gridEnd} is divided into {@code numIntervals} equal epochs.
 * Epoch {@code n} spans {@code [ti(n), intervalTimes[n]]}, with {@code ti(0) = 0}. Within an epoch
 * the birth rate (lambda), death rate (mu), serial-sampling rate (psi) and treatment probability (r)
 * are constant. Each of these parameters holds either a single value shared by every epoch, or one
 * value per epoch.
 *
 * <p>The sampling probability (rho) is applied at epoch start times. If it has a single value, it is
 * applied only at time 0 ("present"). Otherwise {@code rho(n)} applies at {@code ti(n)}.
 *
 * <p>Gradients use {@value #PARAMS_PER_INTERVAL} consecutive slots per epoch, in the order
 * lambda, mu, psi, rho, r.
 *
 * <p>NOTE(review): the A/B/p/q quantities appear to follow the birth-death skyline formulation;
 * confirm against the calling likelihood before relying on that interpretation.
 */
public class BirthDeathEpisodicSeriallySampledModel
extends SpeciationModel
implements SpeciationModelGradientProvider {
    /** Number of gradient slots per epoch (lambda, mu, psi, rho, r). */
    private static final int PARAMS_PER_INTERVAL = 5;
    /** Offsets of each parameter inside one epoch's gradient slots. */
    private static final int LAMBDA = 0;
    private static final int MU = 1;
    private static final int PSI = 2;
    private static final int RHO = 3;
    private static final int TREATMENT = 4;
    /** Parameters whose derivatives are propagated through A_n, B_n, p_n and q_n (lambda, mu, psi, rho). */
    private static final int NUM_PROPAGATED = 4;
    /** Time-matching tolerance used when classifying samples; effectively an exact comparison. */
    private static final double TIME_EPSILON = Double.MIN_VALUE;

    // Never assigned in this class.
    Parameter samplingFractionAtPresent;
    Parameter birthRate;
    Parameter deathRate;
    Parameter serialSamplingRate;
    Parameter samplingProbability;
    Parameter treatmentProbability;
    Parameter originTime;
    // Currently unused: logConditioningProbability always returns 0.
    private boolean conditionOnSurvival;
    // True when the corresponding parameter has one value per epoch rather than a shared value.
    private boolean birthRateChanges = false;
    private boolean deathRateChanges = false;
    private boolean serialSamplingRateChanges = false;
    private boolean treatmentChanges = false;
    // True when samplingProbability has a single value, i.e. rho applies only at time 0.
    private boolean intensiveSamplingOnlyAtPresent = false;
    // False as soon as any samplingProbability entry is strictly positive.
    private boolean noIntensiveSampling = true;
    // Whether partialBCurrentPartialAll holds the B-derivatives of the epoch last passed to updateModelValues.
    boolean computedBCurrent;
    private double[][] partialBCurrentPartialAll;
    private double[][] partialPPreviousPartialAll;
    private double[][] partialPCurrentPartialAll;
    double absTol = 1.0E-8;
    int numIntervals = 1;
    double gridEnd;
    // intervalTimes[n] is the end time of epoch n.
    double[] intervalTimes;
    // piMinus1[n] = p_{n-1} evaluated at ti(n); piMinus1[0] = 1.
    protected double[] piMinus1;
    protected double[] Ai;
    protected double[] Bi;

    /**
     * Creates the model with the default name {@code "birthDeathEpisodicSeriallySampledModel"}.
     *
     * @see #BirthDeathEpisodicSeriallySampledModel(String, Parameter, Parameter, Parameter, Parameter, Parameter, Parameter, int, double, Units.Type)
     */
    public BirthDeathEpisodicSeriallySampledModel(Parameter birthRate, Parameter deathRate, Parameter serialSamplingRate, Parameter treatmentProbability, Parameter samplingProbability, Parameter originTime, int numIntervals, double gridEnd, Units.Type units) {
        this("birthDeathEpisodicSeriallySampledModel", birthRate, deathRate, serialSamplingRate, treatmentProbability, samplingProbability, originTime, numIntervals, gridEnd, units);
    }

    /** Returns this model itself as its own gradient provider. */
    @Override
    public SpeciationModelGradientProvider getProvider() {
        return this;
    }

    /**
     * Creates the model and registers and bounds its parameters.
     *
     * @param modelName            model name passed to {@link SpeciationModel}
     * @param birthRate            lambda; size 1 or {@code numIntervals}; bounded to [0, +inf)
     * @param deathRate            mu; size 1 or {@code numIntervals}; bounded to [0, +inf)
     * @param serialSamplingRate   psi; size 1 or {@code numIntervals}; bounded to [0, +inf)
     * @param treatmentProbability r; size 1 or {@code numIntervals}; bounded to [0, 1]
     * @param samplingProbability  rho; size 1 (present only) or {@code numIntervals}; bounded to [0, 1]
     * @param originTime           origin time (first entry used); bounded to [0, +inf)
     * @param numIntervals         number of equal-width epochs
     * @param gridEnd              end time of the last epoch
     * @param units                time units
     * @throws RuntimeException if a rate/probability parameter has a size other than 1 or {@code numIntervals}
     */
    public BirthDeathEpisodicSeriallySampledModel(String modelName, Parameter birthRate, Parameter deathRate, Parameter serialSamplingRate, Parameter treatmentProbability, Parameter samplingProbability, Parameter originTime, int numIntervals, double gridEnd, Units.Type units) {
        super(modelName, units);
        this.numIntervals = numIntervals;
        this.gridEnd = gridEnd;
        this.setupTimeline();

        checkLength(birthRate, "birthRate", numIntervals);
        checkLength(deathRate, "deathRate", numIntervals);
        checkLength(serialSamplingRate, "serialSamplingRate", numIntervals);
        checkLength(treatmentProbability, "r", numIntervals);
        checkLength(samplingProbability, "samplingProbability", numIntervals);

        this.birthRateChanges = birthRate.getSize() > 1;
        this.deathRateChanges = deathRate.getSize() > 1;
        this.serialSamplingRateChanges = serialSamplingRate.getSize() > 1;
        this.treatmentChanges = treatmentProbability.getSize() > 1;
        // A single rho value means intensive sampling only at time 0. Without this flag,
        // samplingProbability(n) would index past the end of a size-1 parameter for n > 0.
        this.intensiveSamplingOnlyAtPresent = samplingProbability.getSize() == 1;

        for (int i = 0; i < samplingProbability.getSize(); ++i) {
            if (valueOf(samplingProbability, i) > Double.MIN_VALUE) {
                this.noIntensiveSampling = false;
                break;
            }
        }

        this.birthRate = birthRate;
        this.registerBounded(birthRate, Double.POSITIVE_INFINITY);
        this.deathRate = deathRate;
        this.registerBounded(deathRate, Double.POSITIVE_INFINITY);
        this.serialSamplingRate = serialSamplingRate;
        this.registerBounded(serialSamplingRate, Double.POSITIVE_INFINITY);
        this.treatmentProbability = treatmentProbability;
        this.registerBounded(treatmentProbability, 1.0);
        this.samplingProbability = samplingProbability;
        this.registerBounded(samplingProbability, 1.0);
        this.originTime = originTime;
        this.registerBounded(originTime, Double.POSITIVE_INFINITY);
    }

    /** Throws if {@code parameter} has neither one value nor one value per epoch. */
    private static void checkLength(Parameter parameter, String label, int n) {
        int size = parameter.getSize();
        if (size != 1 && size != n) {
            throw new RuntimeException("Length of " + label + " parameter should be one or equal to the size of time parameter (size = " + n + ")");
        }
    }

    /** Registers {@code parameter} with this model and bounds every dimension to [0, upper]. */
    private void registerBounded(Parameter parameter, double upper) {
        this.addVariable(parameter);
        parameter.addBounds(new Parameter.DefaultBounds(upper, 0.0, parameter.getSize()));
    }

    /** Reads one entry of a parameter as a primitive double. */
    private static double valueOf(Parameter parameter, int index) {
        return (Double) parameter.getValue(index);
    }

    /** Fills {@code intervalTimes} with the end times of {@code numIntervals} equal epochs over [0, gridEnd]. */
    private void setupTimeline() {
        if (this.intervalTimes == null) {
            this.intervalTimes = new double[this.numIntervals];
        } else {
            Arrays.fill(this.intervalTimes, 0.0);
        }
        double width = this.gridEnd / (double) this.numIntervals;
        for (int i = 0; i < this.numIntervals; ++i) {
            this.intervalTimes[i] = (double) (i + 1) * width;
        }
    }

    /** Computes A = sqrt((lambda - mu - psi)^2 + 4 lambda psi). */
    public double Ai(double lambda, double mu, double psi) {
        return Math.sqrt(Math.pow(lambda - mu - psi, 2.0) + 4.0 * lambda * psi);
    }

    /** Computes B = ((1 - 2 (1 - rho) pPrev) lambda + mu + psi) / A. */
    public double Bi(double lambda, double mu, double psi, double rho, double a, double pPrev) {
        return ((1.0 - 2.0 * (1.0 - rho) * pPrev) * lambda + mu + psi) / a;
    }

    /**
     * Evaluates p_n(t) with this epoch's current rates and the cached A_n, B_n.
     * Requires {@link #updateModelValues(int)} to have populated {@code Ai}/{@code Bi}.
     */
    public double p(int n, double d) {
        return this.p(n, d, this.birthRate(n), this.deathRate(n), this.serialSamplingRate(n), this.Ai[n], this.Bi[n]);
    }

    /** Evaluates p_n(t) = (lambda + mu + psi - A (e(1+B) - (1-B)) / (e(1+B) + (1-B))) / (2 lambda), e = exp(A (t - ti(n))). */
    public double p(int n, double d, double lambda, double mu, double psi, double a, double b) {
        double growth = Math.exp(a * (d - this.ti(n)));
        double numerator = growth * (1.0 + b) - (1.0 - b);
        double denominator = growth * (1.0 + b) + (1.0 - b);
        return (lambda + mu + psi - a * numerator / denominator) / (2.0 * lambda);
    }

    /** Evaluates q_n(t) = 4 e / (e(1+B) + (1-B))^2, e = exp(A_n (t - ti(n))). */
    public double q(int n, double d) {
        double growth = Math.exp(this.Ai[n] * (d - this.ti(n)));
        double denominator = growth * (1.0 + this.Bi[n]) + (1.0 - this.Bi[n]);
        return 4.0 * growth / Math.pow(denominator, 2.0);
    }

    /** Evaluates log q_n(t) without forming q_n(t) first. */
    public double logq(int n, double d) {
        double exponent = this.Ai[n] * (d - this.ti(n));
        double denominator = Math.exp(exponent) * (1.0 + this.Bi[n]) + (1.0 - this.Bi[n]);
        return exponent + Math.log(4.0) - 2.0 * Math.log(denominator);
    }

    /** Birth rate of epoch n (shared value when the parameter has size 1). */
    public double birthRate(int n) {
        return valueOf(this.birthRate, this.birthRateChanges ? n : 0);
    }

    /** Death rate of epoch n (shared value when the parameter has size 1). */
    public double deathRate(int n) {
        return valueOf(this.deathRate, this.deathRateChanges ? n : 0);
    }

    /** Serial-sampling rate of epoch n (shared value when the parameter has size 1). */
    public double serialSamplingRate(int n) {
        return valueOf(this.serialSamplingRate, this.serialSamplingRateChanges ? n : 0);
    }

    /** Treatment probability r of epoch n (shared value when the parameter has size 1). */
    public double r(int n) {
        return valueOf(this.treatmentProbability, this.treatmentChanges ? n : 0);
    }

    /** Intensive sampling probability applied at ti(n); 0 for n > 0 when sampling is only at present. */
    public double samplingProbability(int n) {
        if (this.intensiveSamplingOnlyAtPresent) {
            return n == 0 ? valueOf(this.samplingProbability, 0) : 0.0;
        }
        return valueOf(this.samplingProbability, n);
    }

    private double lambda(int n) {
        return this.birthRate(n);
    }

    private double mu(int n) {
        return this.deathRate(n);
    }

    private double psi(int n) {
        return this.serialSamplingRate(n);
    }

    private double rho(int n) {
        return this.samplingProbability(n);
    }

    /**
     * Recomputes A_i, B_i and p_{i-1}(t_i) for all epochs. It then computes the derivatives of
     * p_interval at the end of epoch {@code interval} with respect to the parameters of epochs
     * 0..interval.
     *
     * <p>The derivative recursion reads the previous call's results as "previous". Callers
     * presumably invoke this with interval = 0, 1, 2, ... in order — TODO confirm.
     */
    public void updateModelValues(int interval) {
        this.Ai = new double[this.numIntervals];
        this.Bi = new double[this.numIntervals];
        this.piMinus1 = new double[this.numIntervals];
        for (int i = 0; i < this.numIntervals; ++i) {
            this.Ai[i] = this.Ai(this.birthRate(i), this.deathRate(i), this.serialSamplingRate(i));
        }
        this.piMinus1[0] = 1.0;
        this.Bi[0] = this.Bi(this.birthRate(0), this.deathRate(0), this.serialSamplingRate(0), this.samplingProbability(0), this.Ai[0], this.piMinus1[0]);
        for (int i = 1; i < this.numIntervals; ++i) {
            // B_i depends on p_{i-1} evaluated at the start of epoch i, so fill sequentially.
            this.piMinus1[i] = this.p(i - 1, this.intervalTimes[i - 1]);
            this.Bi[i] = this.Bi(this.birthRate(i), this.deathRate(i), this.serialSamplingRate(i), this.samplingProbability(i), this.Ai[i], this.piMinus1[i]);
        }

        if (this.partialBCurrentPartialAll == null) {
            this.partialBCurrentPartialAll = new double[this.numIntervals][NUM_PROPAGATED];
        }
        if (this.partialPCurrentPartialAll == null) {
            this.partialPCurrentPartialAll = new double[this.numIntervals][NUM_PROPAGATED];
        }
        if (this.partialPPreviousPartialAll == null) {
            this.partialPPreviousPartialAll = new double[this.numIntervals][NUM_PROPAGATED];
        }
        this.computedBCurrent = false;

        // Last call's "current" p-derivatives become this call's "previous". Swapping the buffers
        // keeps the two roles distinct instead of aliasing them.
        double[][] recycled = this.partialPPreviousPartialAll;
        this.partialPPreviousPartialAll = this.partialPCurrentPartialAll;
        this.partialPCurrentPartialAll = recycled;

        for (int i = 0; i <= interval; ++i) {
            this.partialPCurrentPartialAll[i] = this.partialPpartialAll(interval, i, this.intervalTimes[interval]);
        }
    }

    /** Not supported: this model is evaluated through the segment/event callbacks instead. */
    @Override
    public final double calculateTreeLogLikelihood(Tree tree) {
        throw new RuntimeException("Not yet implemented!");
    }

    /** Delegates to {@link #calculateTreeLogLikelihood(Tree)} for an empty taxon set; otherwise unsupported. */
    @Override
    public double calculateTreeLogLikelihood(Tree tree, Set<Taxon> set) {
        if (set.isEmpty()) {
            return this.calculateTreeLogLikelihood(tree);
        }
        throw new RuntimeException("Not implemented!");
    }

    /** Returns the epoch end times (the internal array, not a copy). */
    @Override
    public double[] getBreakPoints() {
        return this.intervalTimes;
    }

    /** Weighted log q_n increment between two times. */
    private double logqIncrement(int n, double start, double end, int lineages) {
        return (double) lineages * (this.logq(n, end) - this.logq(n, start));
    }

    @Override
    public double processModelSegmentBreakPoint(int n, double d, double d2, int n2) {
        return this.logqIncrement(n, d, d2, n2);
    }

    @Override
    public double processInterval(int n, double d, double d2, int n2) {
        return this.logqIncrement(n, d, d2, n2);
    }

    /** log q_n(origin) - log q_n(t), using the first entry of {@code originTime}. */
    @Override
    public double processOrigin(int n, double d) {
        return this.logq(n, valueOf(this.originTime, 0)) - this.logq(n, d);
    }

    /** Log birth rate of epoch n (honours a shared, size-1 birth rate). */
    @Override
    public double processCoalescence(int n, double d) {
        return Math.log(this.lambda(n));
    }

    /** True when a sample at time t is taken at the present and present-day sampling is switched on. */
    private boolean isPresentIntensiveSample(double t) {
        return t < TIME_EPSILON && this.samplingProbability(0) >= Double.MIN_VALUE;
    }

    /**
     * True when a sample in epoch n lies on the boundary intervalTimes[n] = ti(n + 1) and the rho
     * that applies there, rho(n + 1), exists and is positive.
     */
    private boolean isBoundaryIntensiveSample(int n, double t) {
        int next = n + 1;
        return next < this.numIntervals
                && Math.abs(t - this.intervalTimes[n]) < TIME_EPSILON
                && this.samplingProbability(next) >= Double.MIN_VALUE;
    }

    /**
     * Log contribution of a sampled tip at time {@code d} in epoch {@code n}. An intensive sample
     * contributes log rho. A serial sample contributes log psi + log(r + (1 - r) p_n(d)).
     */
    @Override
    public double processSampling(int n, double d) {
        if (this.isPresentIntensiveSample(d)) {
            return Math.log(this.samplingProbability(0));
        }
        if (this.isBoundaryIntensiveSample(n, d)) {
            return Math.log(this.samplingProbability(n + 1));
        }
        double treatment = this.r(n);
        return Math.log(this.serialSamplingRate(n)) + Math.log(treatment + (1.0 - treatment) * this.p(n, d));
    }

    /** Start time of epoch n (0 for the first epoch). */
    private double ti(int n) {
        return n == 0 ? 0.0 : this.intervalTimes[n - 1];
    }

    /** No conditioning is applied (always 0). */
    @Override
    public double logConditioningProbability(int n) {
        return 0.0;
    }

    private double partialApartialLambda(int n) {
        return (this.lambda(n) - this.mu(n) + this.psi(n)) / this.Ai[n];
    }

    private double partialApartialMu(int n) {
        return (-this.lambda(n) + this.mu(n) + this.psi(n)) / this.Ai[n];
    }

    private double partialApartialPsi(int n) {
        return (this.lambda(n) + this.mu(n) + this.psi(n)) / this.Ai[n];
    }

    /** dA_n / d(lambda, mu, psi, rho); A_n does not depend on rho. */
    private double[] partialApartialAll(int n) {
        return new double[]{this.partialApartialLambda(n), this.partialApartialMu(n), this.partialApartialPsi(n), 0.0};
    }

    /**
     * dB_n / d(parameters of epoch n2).
     *
     * <p>For n2 == n this is the direct derivative. For n2 < n the dependence goes through
     * p_{n-1}(ti(n)), using the "previous" p-derivatives. Results are cached in
     * {@code partialBCurrentPartialAll}. Once the n2 == n row has been computed, later calls return
     * cached rows until the next {@link #updateModelValues(int)}.
     */
    private double[] partialBpartialAll(int n, int n2) {
        if (this.computedBCurrent) {
            return this.partialBCurrentPartialAll[n2];
        }
        double[] dB = new double[NUM_PROPAGATED];
        double a = this.Ai[n];
        double lambda = this.lambda(n);
        if (n2 == n) {
            double[] dA = this.partialApartialAll(n);
            // Coefficient of lambda in B_n's numerator: 1 - 2 (1 - rho_n) p_{n-1}(ti(n)).
            double coefficient = 1.0 - 2.0 * (1.0 - this.rho(n)) * this.piMinus1[n];
            double numerator = coefficient * lambda + this.mu(n) + this.psi(n);
            double aSquared = a * a;
            dB[LAMBDA] = (a * coefficient - dA[LAMBDA] * numerator) / aSquared;
            dB[MU] = (a - dA[MU] * numerator) / aSquared;
            dB[PSI] = (a - dA[PSI] * numerator) / aSquared;
            dB[RHO] = 2.0 * lambda * this.piMinus1[n] / a;
        } else if (n2 < n) {
            double dBdPreviousP = -2.0 * (1.0 - this.rho(n)) * lambda / a;
            for (int i = 0; i < NUM_PROPAGATED; ++i) {
                dB[i] = dBdPreviousP * this.partialPPreviousPartialAll[n2][i];
            }
        }
        this.partialBCurrentPartialAll[n2] = dB;
        if (n2 == n) {
            this.computedBCurrent = true;
        }
        return dB;
    }

    /** dp_n(t) / d(parameters of epoch n2); zero when n2 > n. */
    private double[] partialPpartialAll(int n, int n2, double d) {
        double[] dp = new double[NUM_PROPAGATED];
        if (n2 > n) {
            return dp;
        }
        double a = this.Ai[n];
        double b = this.Bi[n];
        double lambda = this.lambda(n);
        double dt = d - this.ti(n);
        double growth = Math.exp(a * dt);
        double denominator = growth * (1.0 + b) + (1.0 - b);
        double[] dB = this.partialBpartialAll(n, n2);
        if (n2 == n) {
            double[] dA = this.partialApartialAll(n);
            // aTerm = A (e(1+B) - (1-B)) / (e(1+B) + (1-B)), the A-dependent part of p's numerator.
            double aTerm = a * (1.0 - 2.0 * (1.0 - b) / denominator);
            double[] dATerm = new double[3];
            for (int i = 0; i < 3; ++i) {
                double dDenominator = growth * (1.0 + b) * dA[i] * dt + (growth - 1.0) * dB[i];
                dATerm[i] = dA[i] - 2.0 * (denominator * (dA[i] * (1.0 - b) - dB[i] * a) - (1.0 - b) * dDenominator * a) / (denominator * denominator);
            }
            dp[LAMBDA] = (-this.mu(n) - this.psi(n) - lambda * dATerm[LAMBDA] + aTerm) / (2.0 * lambda * lambda);
            dp[MU] = (1.0 - dATerm[MU]) / (2.0 * lambda);
            dp[PSI] = (1.0 - dATerm[PSI]) / (2.0 * lambda);
            dp[RHO] = -a / lambda * ((1.0 - b) * (growth - 1.0) + denominator) * dB[RHO] / Math.pow(denominator, 2.0);
        } else {
            for (int i = 0; i < NUM_PROPAGATED; ++i) {
                dp[i] = -a / lambda * ((1.0 - b) * (growth - 1.0) + denominator) * dB[i] / Math.pow(denominator, 2.0);
            }
        }
        return dp;
    }

    /** dq_n(t) / d(parameters of epoch n2); zero when n2 > n. */
    private double[] partialqpartialAll(int n, int n2, double d) {
        double[] dq = new double[NUM_PROPAGATED];
        if (n < n2) {
            return dq;
        }
        double[] dA = this.partialApartialAll(n);
        double[] dB = this.partialBpartialAll(n, n2);
        double b = this.Bi[n];
        double dt = d - this.ti(n);
        double growth = Math.exp(this.Ai[n] * dt);
        double denominator = growth * (1.0 + b) + (1.0 - b);
        double denominatorCubed = Math.pow(denominator, 3.0);
        for (int i = 0; i < 3; ++i) {
            // A_n depends only on epoch n's own rates, so the A-path contributes only when n2 == n.
            double viaA = n == n2 ? dt * dA[i] * (denominator / 2.0 - growth * (1.0 + b)) : 0.0;
            dq[i] = 8.0 * growth * (viaA - dB[i] * (growth - 1.0)) / denominatorCubed;
        }
        dq[RHO] = -8.0 * growth * dB[RHO] * (growth - 1.0) / denominatorCubed;
        return dq;
    }

    @Override
    public Parameter getSamplingProbabilityParameter() {
        return this.samplingProbability;
    }

    @Override
    public Parameter getDeathRateParameter() {
        return this.deathRate;
    }

    @Override
    public Parameter getBirthRateParameter() {
        return this.birthRate;
    }

    @Override
    public Parameter getSamplingRateParameter() {
        return this.serialSamplingRate;
    }

    @Override
    public Parameter getTreatmentProbabilityParameter() {
        return this.treatmentProbability;
    }

    /** Resets the derivative buffers and evaluates the model for epoch 0. */
    @Override
    public void precomputeGradientConstants() {
        this.partialPPreviousPartialAll = new double[this.numIntervals][NUM_PROPAGATED];
        this.partialPCurrentPartialAll = new double[this.numIntervals][NUM_PROPAGATED];
        this.partialBCurrentPartialAll = new double[this.numIntervals][NUM_PROPAGATED];
        this.updateModelValues(0);
    }

    /**
     * Adds weight * d[log q_n(end) - log q_n(start)] for the parameters of epochs 0..n into
     * {@code gradient}.
     */
    private void accumulateLogqGradient(double[] gradient, int n, double start, double end, double weight) {
        double qStart = this.q(n, start);
        double qEnd = this.q(n, end);
        for (int epoch = 0; epoch <= n; ++epoch) {
            double[] dqStart = this.partialqpartialAll(n, epoch, start);
            double[] dqEnd = this.partialqpartialAll(n, epoch, end);
            int offset = epoch * PARAMS_PER_INTERVAL;
            for (int j = 0; j < NUM_PROPAGATED; ++j) {
                gradient[offset + j] += weight * (dqEnd[j] / qEnd - dqStart[j] / qStart);
            }
        }
    }

    @Override
    public void processGradientModelSegmentBreakPoint(double[] dArray, int n, double d, double d2, int n2) {
        this.accumulateLogqGradient(dArray, n, d, d2, (double) n2);
    }

    @Override
    public void processGradientInterval(double[] dArray, int n, double d, double d2, int n2) {
        this.accumulateLogqGradient(dArray, n, d, d2, (double) n2);
    }

    @Override
    public void processGradientOrigin(double[] dArray, int n, double d) {
        this.accumulateLogqGradient(dArray, n, d, valueOf(this.originTime, 0), 1.0);
    }

    @Override
    public void processGradientCoalescence(double[] dArray, int n, double d) {
        dArray[n * PARAMS_PER_INTERVAL + LAMBDA] += 1.0 / this.lambda(n);
    }

    /** Gradient counterpart of {@link #processSampling(int, double)}. */
    @Override
    public void processGradientSampling(double[] dArray, int n, double d) {
        if (this.isPresentIntensiveSample(d)) {
            dArray[RHO] += 1.0 / this.rho(0);
            return;
        }
        if (this.isBoundaryIntensiveSample(n, d)) {
            dArray[RHO + PARAMS_PER_INTERVAL * (n + 1)] += 1.0 / this.rho(n + 1);
            return;
        }
        int offset = PARAMS_PER_INTERVAL * n;
        dArray[offset + PSI] += 1.0 / this.psi(n);
        double treatment = this.r(n);
        double pn = this.p(n, d);
        double denominator = (1.0 - treatment) * pn + treatment;
        dArray[offset + TREATMENT] += (1.0 - pn) / denominator;
        double scale = (1.0 - treatment) / denominator;
        for (int epoch = 0; epoch <= n; ++epoch) {
            double[] dp = this.partialPpartialAll(n, epoch, d);
            int epochOffset = PARAMS_PER_INTERVAL * epoch;
            for (int j = 0; j < NUM_PROPAGATED; ++j) {
                dArray[epochOffset + j] += scale * dp[j];
            }
        }
    }

    /** No conditioning is applied, so the gradient is left unchanged. */
    @Override
    public void logConditioningProbability(int n, double[] dArray) {
    }

    /** Total gradient length: five slots per epoch. */
    @Override
    public int getGradientLength() {
        return PARAMS_PER_INTERVAL * this.numIntervals;
    }
}

