/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.branchratemodel.shrinkage;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evolution.tree.TreeTraitProvider;
import dr.evomodel.branchratemodel.AutoCorrelatedBranchRatesDistribution;
import dr.evomodel.branchratemodel.DifferentiableBranchRates;
import dr.inference.loggers.LogColumn;
import dr.inference.loggers.Loggable;
import dr.inference.loggers.NumberColumn;
import dr.inference.operators.shrinkage.BayesianBridgePriorSampler;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AndRule;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import dr.xml.XORRule;

public class IncrementClassifier
implements TreeTraitProvider,
Loggable {
    public static final String INCREMENT_CLASSIFIER = "incrementClassifier";
    public static final String BY_SIGN = "bySign";
    public static final String EPSILON = "epsilon";
    public static final String TARGET_PROBABILITY = "targetProbability";
    private AutoCorrelatedBranchRatesDistribution acbr;
    private DifferentiableBranchRates branchRateModel;
    private final classificationMode classifier;
    private double epsilon;
    private int dim;
    private TreeTraitProvider.Helper helper;
    private double[] classified;
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser(){
        private XMLSyntaxRule[] rules = new XMLSyntaxRule[]{new ElementRule(AutoCorrelatedBranchRatesDistribution.class), new XORRule(AttributeRule.newBooleanRule("bySign", false), new XORRule(AttributeRule.newDoubleRule("epsilon", false), new AndRule(AttributeRule.newDoubleRule("targetProbability", false), new ElementRule(BayesianBridgePriorSampler.class, false))))};

        @Override
        public String getParserName() {
            return IncrementClassifier.INCREMENT_CLASSIFIER;
        }

        @Override
        public Object parseXMLObject(XMLObject xMLObject) throws XMLParseException {
            AutoCorrelatedBranchRatesDistribution autoCorrelatedBranchRatesDistribution = (AutoCorrelatedBranchRatesDistribution)xMLObject.getChild(AutoCorrelatedBranchRatesDistribution.class);
            double d = xMLObject.getAttribute(IncrementClassifier.EPSILON, 0.0);
            double d2 = xMLObject.getAttribute(IncrementClassifier.TARGET_PROBABILITY, 0.0);
            boolean bl = xMLObject.getAttribute(IncrementClassifier.BY_SIGN, false);
            if (bl) {
                if (d != 0.0) {
                    throw new RuntimeException("Sign classifier should use epsilon = 0.0.");
                }
            } else if (!(bl || xMLObject.hasAttribute(IncrementClassifier.EPSILON) || xMLObject.hasAttribute(IncrementClassifier.TARGET_PROBABILITY))) {
                throw new RuntimeException("Must specify epsilon or target probability when not using sign classifier.");
            }
            if (d < 0.0) {
                throw new XMLParseException("epsilon must be positive.");
            }
            if (d2 < 0.0) {
                throw new XMLParseException("target probability must be positive.");
            }
            if (d2 != 0.0) {
                BayesianBridgePriorSampler bayesianBridgePriorSampler = (BayesianBridgePriorSampler)xMLObject.getChild(BayesianBridgePriorSampler.class);
                double d3 = bayesianBridgePriorSampler.getSteps();
                if (d2 < 100.0 / d3) {
                    throw new XMLParseException("For BB prior with " + bayesianBridgePriorSampler.getSteps() + " steps, target Probability should be greater than " + 100.0 / d3);
                }
                d = bayesianBridgePriorSampler.getEpsilon(d2);
            }
            IncrementClassifier incrementClassifier = new IncrementClassifier(autoCorrelatedBranchRatesDistribution, d, bl);
            return incrementClassifier;
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return this.rules;
        }

        @Override
        public String getParserDescription() {
            return "Classifies increment as 0 or 1 based on sign or, alternatively, arbitrary cutoff epsilon.";
        }

        @Override
        public Class getReturnType() {
            return IncrementClassifier.class;
        }
    };

    public IncrementClassifier(AutoCorrelatedBranchRatesDistribution autoCorrelatedBranchRatesDistribution, double d, boolean bl) {
        this.acbr = autoCorrelatedBranchRatesDistribution;
        this.branchRateModel = autoCorrelatedBranchRatesDistribution.getBranchRateModel();
        this.epsilon = d;
        this.classifier = bl ? classificationMode.BY_SIGN : classificationMode.BY_EPSILON;
        this.dim = autoCorrelatedBranchRatesDistribution.getDimension();
        this.classified = new double[this.dim];
        this.classify();
        System.out.println("EPSILON is " + d);
        this.helper = new TreeTraitProvider.Helper();
        this.setupTraits();
    }

    private void classify() {
        for (int i = 0; i < this.dim; ++i) {
            double d = this.acbr.getIncrement(i);
            this.classified[i] = 0.0;
            if (!(this.classifier.getIncrement(d) > this.epsilon)) continue;
            this.classified[i] = 1.0;
        }
    }

    private void setupTraits() {
        TreeTrait.D d = new TreeTrait.D(){

            @Override
            public String getTraitName() {
                return IncrementClassifier.INCREMENT_CLASSIFIER;
            }

            @Override
            public TreeTrait.Intent getIntent() {
                return TreeTrait.Intent.BRANCH;
            }

            @Override
            public Double getTrait(Tree tree, NodeRef nodeRef) {
                IncrementClassifier.this.classify();
                int n = IncrementClassifier.this.branchRateModel.getParameterIndexFromNode(nodeRef);
                return IncrementClassifier.this.classified[n];
            }

            @Override
            public boolean getLoggable() {
                return true;
            }
        };
        this.helper.addTrait(d);
    }

    @Override
    public TreeTrait[] getTreeTraits() {
        return this.helper.getTreeTraits();
    }

    @Override
    public TreeTrait getTreeTrait(String string) {
        return this.helper.getTreeTrait(string);
    }

    @Override
    public LogColumn[] getColumns() {
        LogColumn[] logColumnArray = new LogColumn[this.dim];
        this.classify();
        for (int i = 0; i < this.dim; ++i) {
            String string = "incrementClass.";
            final int n = i;
            logColumnArray[i] = new NumberColumn(string + n){

                @Override
                public double getDoubleValue() {
                    return IncrementClassifier.this.classified[n];
                }
            };
        }
        return logColumnArray;
    }

    public static enum classificationMode {
        BY_EPSILON("byEpsilon"){

            @Override
            double getIncrement(double d) {
                return Math.abs(d);
            }
        }
        ,
        BY_SIGN("bySign"){

            @Override
            double getIncrement(double d) {
                return d;
            }
        };

        private final String name;

        private classificationMode(String string2) {
            this.name = string2;
        }

        public String getName() {
            return this.name;
        }

        abstract double getIncrement(double var1);
    }
}

