/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.hmc;

import dr.inference.hmc.DerivativeWrtParameterProvider;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.HessianWrtParameterProvider;
import dr.inference.hmc.ParallelGradientExecutor;
import dr.inference.model.CompoundLikelihood;
import dr.inference.model.DerivativeOrder;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.inference.model.ReciprocalLikelihood;
import dr.xml.Reportable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Future;

/**
 * Combines several {@link GradientWrtParameterProvider}s by summing their derivatives
 * element-wise.
 *
 * <p>The dimension and parameter are taken from the first provider in the list. When more
 * than one provider is supplied, every provider must report the same dimension and the same
 * parameter values, and the joint likelihood is a {@link CompoundLikelihood} built from the
 * distinct likelihoods of all providers.
 *
 * <p>Summed results are accumulated into a fresh array, so arrays handed back by the
 * underlying providers are never modified by this class. With a single provider its result
 * is returned as-is.
 */
public class JointGradient
        implements GradientWrtParameterProvider,
        HessianWrtParameterProvider,
        DerivativeWrtParameterProvider,
        Reportable {
    private final int dimension;
    private final Likelihood likelihood;
    private final Parameter parameter;
    /** {@code null} when derivatives are evaluated serially. */
    private final ParallelGradientExecutor parallelExecutor;
    /** All providers, in the order given to the constructor. */
    final List<GradientWrtParameterProvider> derivativeList;
    /** The subset of {@link #derivativeList} that also implements {@link DerivativeWrtParameterProvider}. */
    final List<DerivativeWrtParameterProvider> newDerivativeList;
    private final DerivativeOrder highestOrder;
    private static final boolean DEBUG = false;
    private static final boolean DEBUG_KILL = false;

    /**
     * Creates a joint gradient that evaluates its providers serially.
     *
     * @param list the providers to sum; must contain at least one element
     */
    public JointGradient(List<GradientWrtParameterProvider> list) {
        this(list, 0);
    }

    /**
     * Creates a joint gradient over the given providers.
     *
     * @param list the providers to sum; must contain at least one element
     * @param n    forwarded to {@link ParallelGradientExecutor} when it is greater than 1 or
     *             negative; 0 or 1 selects serial evaluation (presumably a thread count —
     *             confirm against {@code ParallelGradientExecutor})
     * @throws IllegalArgumentException if {@code list} is empty
     * @throws RuntimeException         if several providers are given and they disagree on
     *                                  dimension or on current parameter values
     */
    public JointGradient(List<GradientWrtParameterProvider> list, int n) {
        if (list.isEmpty()) {
            // Fail with a clear message instead of an IndexOutOfBoundsException from get(0).
            throw new IllegalArgumentException("JointGradient requires at least one gradient provider");
        }
        this.derivativeList = list;

        GradientWrtParameterProvider first = list.get(0);
        this.dimension = first.getDimension();
        this.parameter = first.getParameter();

        if (list.size() == 1) {
            this.likelihood = first.getLikelihood();
        } else {
            // Collect each distinct likelihood once, preserving first-seen order.
            List<Likelihood> components = new ArrayList<Likelihood>();
            for (GradientWrtParameterProvider provider : list) {
                if (provider.getDimension() != this.dimension) {
                    throw new RuntimeException("Unequal parameter dimensions");
                }
                if (!Arrays.equals(provider.getParameter().getParameterValues(),
                        this.parameter.getParameterValues())) {
                    throw new RuntimeException("Unequal parameter values");
                }

                Likelihood providerLikelihood = provider.getLikelihood();
                if (providerLikelihood instanceof ReciprocalLikelihood) {
                    // A ReciprocalLikelihood is added as a single unit rather than being
                    // expanded into its likelihood set.
                    if (!components.contains(providerLikelihood)) {
                        components.add(providerLikelihood);
                    }
                } else {
                    for (Likelihood component : providerLikelihood.getLikelihoodSet()) {
                        if (!components.contains(component)) {
                            components.add(component);
                        }
                    }
                }
            }
            this.likelihood = new CompoundLikelihood(components);
        }

        this.newDerivativeList = new ArrayList<DerivativeWrtParameterProvider>();
        for (GradientWrtParameterProvider provider : list) {
            if (provider instanceof DerivativeWrtParameterProvider) {
                this.newDerivativeList.add((DerivativeWrtParameterProvider) provider);
            }
        }
        this.highestOrder = DerivativeWrtParameterProvider.getHighestOrder(this.newDerivativeList);

        this.parallelExecutor = (n > 1 || n < 0) ? new ParallelGradientExecutor(n, list) : null;
    }

    @Override
    public Likelihood getLikelihood() {
        return this.likelihood;
    }

    @Override
    public Parameter getParameter() {
        return this.parameter;
    }

    @Override
    public int getDimension(DerivativeOrder derivativeOrder) {
        return derivativeOrder.getDerivativeDimension(this.dimension);
    }

    @Override
    public int getDimension() {
        return this.dimension;
    }

    /**
     * Sums the requested derivative over every provider that implements
     * {@link DerivativeWrtParameterProvider}. Providers that do not implement that interface
     * contribute nothing to the result.
     *
     * @param derivativeOrder the order to evaluate; asserted not to exceed {@link #getHighestOrder()}
     * @return the element-wise sum; with a single provider, that provider's own array
     * @throws IndexOutOfBoundsException if no provider implements {@link DerivativeWrtParameterProvider}
     */
    @Override
    public double[] getDerivativeLogDensity(DerivativeOrder derivativeOrder) {
        assert (this.highestOrder.getValue() >= derivativeOrder.getValue());

        int count = this.newDerivativeList.size();
        double[] first = this.newDerivativeList.get(0).getDerivativeLogDensity(derivativeOrder);
        if (count == 1) {
            return first;
        }

        // Accumulate into a copy: the provider may have returned an internal buffer.
        double[] sum = first.clone();
        for (int i = 1; i < count; ++i) {
            addInPlace(sum, this.newDerivativeList.get(i).getDerivativeLogDensity(derivativeOrder));
        }
        return sum;
    }

    @Override
    public DerivativeOrder getHighestOrder() {
        return this.highestOrder;
    }

    /**
     * Sums the diagonal Hessians of all providers.
     *
     * @throws ClassCastException if a provider is not a {@link HessianWrtParameterProvider}
     */
    @Override
    public double[] getDiagonalHessianLogDensity() {
        return this.getDerivativeLogDensity(DerivativeType.DIAGONAL_HESSIAN);
    }

    /**
     * Sums the full Hessians of all providers entry by entry. Always evaluated serially.
     *
     * @return the summed matrix; with a single provider, that provider's own matrix
     * @throws ClassCastException if a provider is not a {@link HessianWrtParameterProvider}
     */
    @Override
    public double[][] getHessianLogDensity() {
        assert (this.derivativeList.get(0) instanceof HessianWrtParameterProvider);

        int count = this.derivativeList.size();
        double[][] first = ((HessianWrtParameterProvider) this.derivativeList.get(0)).getHessianLogDensity();
        if (count == 1) {
            return first;
        }

        // Deep-copy the first matrix so the provider's rows are not modified.
        double[][] sum = new double[first.length][];
        for (int row = 0; row < first.length; ++row) {
            sum[row] = first[row].clone();
        }

        for (int i = 1; i < count; ++i) {
            assert (this.derivativeList.get(i) instanceof HessianWrtParameterProvider);
            double[][] term = ((HessianWrtParameterProvider) this.derivativeList.get(i)).getHessianLogDensity();
            // Bound each loop by the actual row/column counts rather than term[0].length.
            for (int row = 0; row < term.length; ++row) {
                for (int col = 0; col < term[row].length; ++col) {
                    sum[row][col] += term[row][col];
                }
            }
        }
        return sum;
    }

    /**
     * Evaluates the given derivative type over all providers, using the parallel executor
     * when one was configured and the serial loop otherwise.
     */
    double[] getDerivativeLogDensity(DerivativeType derivativeType) {
        if (this.parallelExecutor != null) {
            return this.getDerivativeLogDensityParallelImpl(derivativeType);
        }
        return this.getDerivativeLogDensitySerialImpl(derivativeType);
    }

    /** Delegates evaluation to the executor and reduces the per-provider results by summation. */
    private double[] getDerivativeLogDensityParallelImpl(DerivativeType derivativeType) {
        return this.parallelExecutor.getDerivativeLogDensityInParallel(derivativeType, (futures, length) -> {
            double[] sum = new double[length];
            for (Future<?> future : futures) {
                double[] term = (double[]) future.get();
                for (int i = 0; i < length; ++i) {
                    sum[i] += term[i];
                }
            }
            return sum;
        }, this.dimension);
    }

    /** Evaluates each provider in list order on the calling thread and sums the results. */
    private double[] getDerivativeLogDensitySerialImpl(DerivativeType derivativeType) {
        int count = this.derivativeList.size();
        double[] first = derivativeType.getDerivativeLogDensity(this.derivativeList.get(0));
        if (count == 1) {
            return first;
        }

        // Accumulate into a copy, matching the parallel path which also sums into a new array.
        double[] sum = first.clone();
        for (int i = 1; i < count; ++i) {
            addInPlace(sum, derivativeType.getDerivativeLogDensity(this.derivativeList.get(i)));
        }
        return sum;
    }

    /** Adds every element of {@code increment} to the element of {@code target} at the same index. */
    private static void addInPlace(double[] target, double[] increment) {
        for (int j = 0; j < increment.length; ++j) {
            target[j] += increment[j];
        }
    }

    @Override
    public double[] getGradientLogDensity() {
        return this.getDerivativeLogDensity(DerivativeType.GRADIENT);
    }

    @Override
    public String getReport() {
        return "jointGradient." + this.parameter.getParameterName() + "\n"
                + GradientWrtParameterProvider.getReportAndCheckForError(this, Double.NEGATIVE_INFINITY,
                        Double.POSITIVE_INFINITY, GradientWrtParameterProvider.TOLERANCE);
    }

    /** Selects which per-provider quantity is summed. */
    enum DerivativeType {
        GRADIENT("gradient") {

            @Override
            public double[] getDerivativeLogDensity(GradientWrtParameterProvider gradientWrtParameterProvider) {
                return gradientWrtParameterProvider.getGradientLogDensity();
            }
        },
        DIAGONAL_HESSIAN("diagonalHessian") {

            @Override
            public double[] getDerivativeLogDensity(GradientWrtParameterProvider gradientWrtParameterProvider) {
                // Throws ClassCastException if the provider does not supply Hessians.
                return ((HessianWrtParameterProvider) gradientWrtParameterProvider).getDiagonalHessianLogDensity();
            }
        };

        private final String type;

        private DerivativeType(String string2) {
            this.type = string2;
        }

        /** Returns this derivative type evaluated on the given provider. */
        public abstract double[] getDerivativeLogDensity(GradientWrtParameterProvider var1);
    }
}

