/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.operators;

import dr.inference.distribution.DistributionLikelihood;
import dr.inference.distribution.MultivariateDistributionLikelihood;
import dr.inference.distribution.NormalDistributionModel;
import dr.inference.model.CompoundLikelihood;
import dr.inference.model.Likelihood;
import dr.inference.model.MaskedParameter;
import dr.inference.model.MatrixParameterInterface;
import dr.inference.model.Parameter;
import dr.inference.model.TransformedParameter;
import dr.inference.model.Variable;
import dr.inference.operators.GibbsOperator;
import dr.inference.operators.SimpleMetropolizedGibbsOperator;
import dr.math.MathUtils;
import dr.math.distributions.CompoundGaussianProcess;
import dr.math.distributions.GaussianProcessRandomGenerator;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.math.matrixAlgebra.WrappedVector;
import dr.math.matrixAlgebra.missingData.MissingOps;
import dr.util.Attribute;
import dr.util.Transform;
import java.util.ArrayList;
import java.util.List;
import org.ejml.data.DenseMatrix64F;
import org.ejml.factory.DecompositionFactory;
import org.ejml.interfaces.decomposition.QRDecomposition;
import org.ejml.ops.CommonOps;

/**
 * Gibbs-style operator that updates a {@link Parameter} by elliptical slice sampling.
 *
 * <p>Each call to {@link #doOperation(Likelihood)}:
 * <ol>
 *   <li>draws a log-density cutoff below the current slice density;</li>
 *   <li>draws an auxiliary point {@code nu} from the Gaussian process;</li>
 *   <li>proposes points {@code x*cos(phi) + nu*sin(phi)} on the ellipse through the current state
 *       (taken about {@code priorMean} when it is non-null);</li>
 *   <li>shrinks the angle bracket towards {@code phi = 0} (the current state) until a proposal
 *       lies above the cutoff.</li>
 * </ol>
 *
 * <p>Before being written to the parameter, every proposal can optionally be translated
 * (each coordinate centred) and/or rotated (via a QR decomposition of its first points).
 * These transforms are controlled by the translation/rotation-invariance flags.
 */
public class EllipticalSliceOperator
        extends SimpleMetropolizedGibbsOperator
        implements GibbsOperator {

    /**
     * Compile-time switch selecting how the slice density is evaluated.
     *
     * <p>When {@code true}, only the likelihood terms that are not part of the Gaussian process
     * are evaluated. When {@code false}, the full likelihood is evaluated with the path parameter
     * and the Gaussian prior is subtracted.
     */
    private static final boolean MINIMAL_EVALUATION = true;

    /** Dimension of each point when the translation/rotation-invariance transforms are applied. */
    private static final int TRANSFORM_DIMENSION = 2;

    private final Parameter variable;
    private final GaussianProcessRandomGenerator gaussianProcess;
    private final boolean signalConstituentParameters;
    private final double bracketAngle;
    private final boolean translationInvariant;
    private final boolean rotationInvariant;

    /** Optional mean of the Gaussian prior; never assigned in this class, so always null here. */
    private double[] priorMean = null;

    /** Power applied through {@code evaluate}; only read when MINIMAL_EVALUATION is false. */
    private double pathParameter = 1.0;

    /**
     * Creates an operator with no angle bracketing and no invariance transforms.
     *
     * @param variable the parameter to update
     * @param gaussianProcess generator for draws from the Gaussian prior of {@code variable}
     * @param drawByRow currently ignored by this implementation
     * @param signalConstituentParameters if true, fire the parameter's default change event;
     *        otherwise fire a single ALL_VALUES_CHANGED event
     */
    public EllipticalSliceOperator(Parameter variable,
                                   GaussianProcessRandomGenerator gaussianProcess,
                                   boolean drawByRow,
                                   boolean signalConstituentParameters) {
        this(variable, gaussianProcess, drawByRow, signalConstituentParameters, 0.0, false, false);
    }

    /**
     * Creates an operator.
     *
     * @param variable the parameter to update
     * @param gaussianProcess generator for draws from the Gaussian prior of {@code variable}
     * @param drawByRow currently ignored by this implementation
     * @param signalConstituentParameters if true, fire the parameter's default change event;
     *        otherwise fire a single ALL_VALUES_CHANGED event
     * @param bracketAngle width of the initial angle bracket in [0, 2*pi); 0 means a full circle
     * @param translationInvariant centre every proposal before setting it
     * @param rotationInvariant rotate every proposal into a canonical orientation before setting it
     * @throws IllegalArgumentException if {@code bracketAngle} is outside [0, 2*pi) or NaN, or if
     *         the dimension of a Gaussian process draw differs from that of {@code variable}
     */
    public EllipticalSliceOperator(Parameter variable,
                                   GaussianProcessRandomGenerator gaussianProcess,
                                   boolean drawByRow,
                                   boolean signalConstituentParameters,
                                   double bracketAngle,
                                   boolean translationInvariant,
                                   boolean rotationInvariant) {
        // Written as a negated range test so that NaN is rejected as well.
        if (!(bracketAngle >= 0.0 && bracketAngle < Math.PI * 2)) {
            throw new IllegalArgumentException("Invalid bracket angle");
        }
        this.variable = variable;
        this.gaussianProcess = gaussianProcess;
        this.signalConstituentParameters = signalConstituentParameters;
        this.bracketAngle = bracketAngle;
        this.translationInvariant = translationInvariant;
        this.rotationInvariant = rotationInvariant;

        // One throw-away draw, used only to validate dimensions up front.
        int variableDimension = variable.getDimension();
        int drawDimension = ((double[]) gaussianProcess.nextRandom()).length;
        if (variableDimension != drawDimension) {
            throw new IllegalArgumentException("Dimension of variable (" + variableDimension
                    + ") does not match dimension of Gaussian process draw (" + drawDimension + ")");
        }
    }

    public Variable<Double> getVariable() {
        return this.variable;
    }

    /**
     * Returns the log density of the Gaussian prior at the current parameter values.
     *
     * <p>Uses the generator's own likelihood when it provides one; otherwise evaluates
     * {@code logPdf} directly on the current parameter values.
     */
    private double getLogGaussianPrior() {
        return this.gaussianProcess.getLikelihood() == null
                ? this.gaussianProcess.logPdf(this.variable.getParameterValues())
                : this.gaussianProcess.getLikelihood().getLogLikelihood();
    }

    /** Recursively flattens nested {@link CompoundLikelihood}s into {@code leaves}. */
    private void unwindCompoundLikelihood(Likelihood likelihood, List<Likelihood> leaves) {
        if (likelihood instanceof CompoundLikelihood) {
            for (Likelihood child : ((CompoundLikelihood) likelihood).getLikelihoods()) {
                unwindCompoundLikelihood(child, leaves);
            }
        } else {
            leaves.add(likelihood);
        }
    }

    /** Returns the non-compound leaf likelihoods reachable from {@code likelihood}. */
    private List<Likelihood> unwindCompoundLikelihood(Likelihood likelihood) {
        ArrayList<Likelihood> leaves = new ArrayList<Likelihood>();
        unwindCompoundLikelihood(likelihood, leaves);
        return leaves;
    }

    /** True if {@code likelihood} is (part of) the Gaussian process; comparison is by reference. */
    private boolean containsGaussianProcess(Likelihood likelihood) {
        if (this.gaussianProcess instanceof CompoundGaussianProcess) {
            return ((CompoundGaussianProcess) this.gaussianProcess).contains(likelihood);
        }
        return this.gaussianProcess == likelihood;
    }

    /**
     * Performs one elliptical slice-sampling update of the variable.
     *
     * <p>The accepted point has already been written to the variable by {@code drawFromSlice},
     * so 0.0 is returned unconditionally.
     */
    @Override
    public double doOperation(Likelihood likelihood) {
        Likelihood sliceLikelihood = MINIMAL_EVALUATION
                ? excludeGaussianProcess(likelihood)
                : likelihood;
        // The cutoff must come from the same density that drawFromSlice compares against.
        double cutoffDensity = sliceLogDensity(sliceLikelihood) + MathUtils.randomLogDouble();
        drawFromSlice(sliceLikelihood, cutoffDensity);
        return 0.0;
    }

    /** Builds a compound of every leaf likelihood that is not part of the Gaussian process. */
    private CompoundLikelihood excludeGaussianProcess(Likelihood likelihood) {
        ArrayList<Likelihood> remaining = new ArrayList<Likelihood>();
        for (Likelihood leaf : unwindCompoundLikelihood(likelihood)) {
            if (!containsGaussianProcess(leaf)) {
                remaining.add(leaf);
            }
        }
        return new CompoundLikelihood(remaining);
    }

    /**
     * Log density against which slice proposals are accepted.
     *
     * <p>The Gaussian prior is excluded because proposals are drawn from it.
     */
    private double sliceLogDensity(Likelihood likelihood) {
        if (MINIMAL_EVALUATION) {
            // `likelihood` already excludes the Gaussian process terms (see doOperation).
            return likelihood.getLogLikelihood();
        }
        return this.evaluate(likelihood, this.pathParameter) - this.getLogGaussianPrior();
    }

    /**
     * Returns {@code x*cos(angle) + nu*sin(angle)}, taken about {@code mean} when it is non-null.
     */
    private double[] pointOnEllipse(double[] x, double[] nu, double angle, double[] mean) {
        int dim = x.length;
        double cos = Math.cos(angle);
        double sin = Math.sin(angle);
        double[] point = new double[dim];
        if (mean == null) {
            for (int i = 0; i < dim; ++i) {
                point[i] = x[i] * cos + nu[i] * sin;
            }
        } else {
            for (int i = 0; i < dim; ++i) {
                point[i] = (x[i] - mean[i]) * cos + (nu[i] - mean[i]) * sin + mean[i];
            }
        }
        return point;
    }

    /**
     * Centres, in place, consecutive points of {@code dim} coordinates by subtracting the
     * per-coordinate mean.
     *
     * <p>Trailing values that do not fill a whole point are left untouched.
     */
    private static void translate(double[] values, int dim) {
        int pointCount = values.length / dim;
        double[] mean = new double[dim];
        for (int p = 0; p < pointCount; ++p) {
            for (int d = 0; d < dim; ++d) {
                mean[d] += values[p * dim + d];
            }
        }
        for (int d = 0; d < dim; ++d) {
            mean[d] /= pointCount;
        }
        for (int p = 0; p < pointCount; ++p) {
            for (int d = 0; d < dim; ++d) {
                values[p * dim + d] -= mean[d];
            }
        }
    }

    /**
     * Rotates, in place, consecutive points of {@code dim} coordinates by Q^T.
     *
     * <p>Q comes from a QR decomposition of the matrix whose columns are the first {@code dim}
     * points. The signs of Q and R are flipped together when R(0,0) is negative.
     *
     * @throws IllegalArgumentException if fewer than {@code dim * dim} values are supplied
     */
    private static void rotateNd(double[] values, int dim) {
        if (values.length < dim * dim) {
            throw new IllegalArgumentException("Rotation needs at least " + dim + " points of dimension "
                    + dim + " (" + (dim * dim) + " values), got " + values.length);
        }
        DenseMatrix64F firstPoints = new DenseMatrix64F(dim, dim);
        for (int row = 0; row < dim; ++row) {
            for (int col = 0; col < dim; ++col) {
                firstPoints.set(row, col, values[col * dim + row]);
            }
        }
        QRDecomposition<DenseMatrix64F> qr = DecompositionFactory.qr(dim, dim);
        qr.decompose(firstPoints);
        DenseMatrix64F q = qr.getQ(null, true);
        DenseMatrix64F r = qr.getR(null, true);
        if (r.get(0, 0) < 0.0) {
            CommonOps.scale(-1.0, r);
            CommonOps.scale(-1.0, q);
        }
        DenseMatrix64F rotation = new DenseMatrix64F(dim, dim);
        CommonOps.transpose(q, rotation);
        for (int p = 0; p < values.length / dim; ++p) {
            WrappedVector.Raw point = new WrappedVector.Raw(values, p * dim, dim);
            // Input and output alias the same slice; this relies on MissingOps supporting that,
            // as the original code did.
            MissingOps.matrixVectorMultiple(rotation, point, point, dim);
        }
    }

    private static void rotate(double[] values, int dim) {
        rotateNd(values, dim);
    }

    /**
     * Applies the optional invariance transforms, in place, to points of {@code dim} coordinates.
     *
     * @param values flat array of consecutive points
     * @param removeTranslation centre each coordinate on zero
     * @param removeRotation rotate into the orientation defined by the first {@code dim} points
     * @param dim coordinates per point
     */
    public static void transformPoint(double[] values, boolean removeTranslation,
                                      boolean removeRotation, int dim) {
        if (removeTranslation) {
            translate(values, dim);
        }
        if (removeRotation) {
            rotate(values, dim);
        }
    }

    private void transformPoint(double[] values) {
        transformPoint(values, this.translationInvariant, this.rotationInvariant, TRANSFORM_DIMENSION);
    }

    /** Writes all values to the variable without firing change events. */
    private void setAllParameterValues(double[] values) {
        if (this.variable instanceof MatrixParameterInterface) {
            ((MatrixParameterInterface) this.variable).setAllParameterValuesQuietly(values, 0);
        } else {
            for (int i = 0; i < values.length; ++i) {
                this.variable.setParameterValueQuietly(i, values[i]);
            }
        }
    }

    /** Transforms {@code values} in place, stores them in the variable and fires one change event. */
    private void setVariable(double[] values) {
        transformPoint(values);
        setAllParameterValues(values);
        if (this.signalConstituentParameters) {
            this.variable.fireParameterChangedEvent();
        } else {
            this.variable.fireParameterChangedEvent(-1, Variable.ChangeType.ALL_VALUES_CHANGED);
        }
    }

    /**
     * Searches the ellipse through the current state for a point above {@code cutoffDensity}.
     *
     * <p>Each candidate is written to the variable before it is evaluated. The loop ends with
     * the first accepted point left in place.
     */
    private void drawFromSlice(Likelihood likelihood, double cutoffDensity) {
        double[] current = this.variable.getParameterValues();
        double[] auxiliary = (double[]) this.gaussianProcess.nextRandom();

        double phi;
        Interval bracket;
        if (this.bracketAngle == 0.0) {
            // Full circle: the bracket [phi - 2*pi, phi] contains 0, i.e. the current state.
            phi = MathUtils.nextDouble() * 2.0 * Math.PI;
            bracket = new Interval(phi - Math.PI * 2, phi);
        } else {
            double phiMin = -this.bracketAngle * MathUtils.nextDouble();
            bracket = new Interval(phiMin, phiMin + this.bracketAngle);
            phi = bracket.draw();
        }

        while (true) {
            setVariable(pointOnEllipse(current, auxiliary, phi, this.priorMean));
            if (sliceLogDensity(likelihood) > cutoffDensity) {
                return;
            }
            bracket.adjust(phi);
            phi = bracket.draw();
        }
    }

    @Override
    public int getStepCount() {
        return 1;
    }

    @Override
    public void setPathParameter(double d) {
        this.pathParameter = d;
    }

    @Override
    public String getOperatorName() {
        return "ellipticalSliceSampler";
    }

    /**
     * Small demonstration: runs the operator repeatedly and prints the sample mean and
     * standard deviation of two derived parameters.
     */
    public static void main(String[] args) {
        final int iterations = 100000;

        Parameter.Default state = new Parameter.Default(new double[]{1.0, 0.0});
        // Component 0 of `state`, used directly.
        MaskedParameter component0 = new MaskedParameter(state,
                new Parameter.Default(new double[]{1.0, 0.0}), true);
        // Component 1 of `state`, passed through a log transform.
        TransformedParameter component1 = new TransformedParameter(
                new MaskedParameter(state, new Parameter.Default(new double[]{0.0, 1.0}), true),
                new Transform.LogTransform(), true);

        NormalDistributionModel normalModel = new NormalDistributionModel(component0, component1, true);
        DistributionLikelihood dataLikelihood = new DistributionLikelihood(normalModel);
        MultivariateNormalDistribution prior = new MultivariateNormalDistribution(
                new double[]{0.0, 0.0}, new double[][]{{0.001, 0.0}, {0.0, 0.001}});
        MultivariateDistributionLikelihood priorLikelihood = new MultivariateDistributionLikelihood(prior);
        priorLikelihood.addData(state);
        dataLikelihood.addData(new Attribute.Default<double[]>("Data",
                new double[]{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0}));

        ArrayList<Likelihood> terms = new ArrayList<Likelihood>();
        terms.add(dataLikelihood);
        terms.add(priorLikelihood);
        CompoundLikelihood posterior = new CompoundLikelihood(0, terms);

        // NOTE(review): the operator is given `prior` itself, while `posterior` contains
        // `priorLikelihood`. containsGaussianProcess compares by reference, so `priorLikelihood`
        // may not be excluded from the slice density. Confirm that the prior is not counted twice.
        EllipticalSliceOperator operator = new EllipticalSliceOperator(state, prior, false, true);

        Parameter[] tracked = {component0, component1};
        double[] sum = new double[tracked.length];
        double[] sumOfSquares = new double[tracked.length];
        for (int iter = 0; iter < iterations; ++iter) {
            operator.doOperation(posterior);
            for (int i = 0; i < tracked.length; ++i) {
                double value = (Double) tracked[i].getValue(0);
                sum[i] += value;
                sumOfSquares[i] += value * value;
            }
        }

        // NOTE(review): the second column is the standard deviation of the draws, not a
        // standard error, despite the header.
        System.out.println("E(x)\tStErr(x)");
        for (int i = 0; i < tracked.length; ++i) {
            double mean = sum[i] / (double) iterations;
            double meanOfSquares = sumOfSquares[i] / (double) iterations;
            // Rounding can make E[x^2] - E[x]^2 slightly negative; clamp to avoid sqrt(NaN).
            double variance = Math.max(0.0, meanOfSquares - mean * mean);
            System.out.println(mean + " " + Math.sqrt(variance));
        }
    }

    /**
     * Angle bracket [lower, upper] around 0 (the current state) that shrinks towards 0 on each
     * rejection.
     */
    private static final class Interval {
        private double lower;
        private double upper;

        Interval(double lower, double upper) {
            this.lower = lower;
            this.upper = upper;
        }

        /** Moves the bracket end on the same side of zero as {@code rejected} to {@code rejected}. */
        void adjust(double rejected) {
            if (rejected > 0.0) {
                this.upper = rejected;
            } else if (rejected < 0.0) {
                this.lower = rejected;
            } else {
                throw new IllegalStateException("Shrunk to current position; bad.");
            }
        }

        /** Draws an angle uniformly from [lower, upper). */
        double draw() {
            return MathUtils.nextDouble() * (this.upper - this.lower) + this.lower;
        }
    }
}

