/*
 * Decompiled with CFR 0.152.
 */
package bartMachine;

import OpenSourceExtensions.TDoubleHashSetAndArray;
import OpenSourceExtensions.UnorderedPair;
import bartMachine.Classifier;
import bartMachine.StatToolbox;
import bartMachine.Tools;
import bartMachine.bartMachine_b_hyperparams;
import gnu.trove.list.array.TDoubleArrayList;
import gnu.trove.list.array.TIntArrayList;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;

/**
 * A node in a single BART regression tree.
 *
 * <p>A node is either a leaf ({@link #isLeaf} is {@code true}) carrying a prediction in
 * {@link #y_pred}, or an internal node carrying a split rule. Under that rule a record whose
 * value in column {@link #splitAttributeM} is {@code <= splitValue} goes to {@link #left} and
 * any other value goes to {@link #right}. A value for which {@code Classifier.isMissing}
 * returns {@code true} goes right iff {@link #sendMissingDataRight}.
 *
 * <p>While the model is being fit, each node also carries the training rows that reach it
 * ({@link #indicies}, {@link #responses}, {@link #n_eta}) plus several transient caches derived
 * from them. {@link #flushNodeData()} drops most of that data once it is no longer needed.
 *
 * <p>Serialization note: no explicit {@code serialVersionUID} is declared. The default UID is
 * computed from the class's fields and non-private members, so changing those breaks
 * deserialization of previously saved trees.
 */
public class bartMachineTreeNode
implements Cloneable,
Serializable {
    /** Debug switch; not read anywhere in this class. */
    public static final boolean DEBUG_NODES = false;
    /** "Unset" sentinel for double-valued fields (equals {@code -Double.MAX_VALUE}). */
    protected static final double BAD_FLAG_double = -1.7976931348623157E308;
    /** "Unset" sentinel for int-valued fields (equals {@code -Integer.MAX_VALUE}). */
    protected static final int BAD_FLAG_int = -2147483647;
    /** Model settings and training data; supplies X_y, X_y_by_col, p and the caching flags. */
    private bartMachine_b_hyperparams bart;
    /** Parent node, or {@code null} for a root. */
    public bartMachineTreeNode parent;
    /** Child receiving records with value {@code <= splitValue} (or missing, if routed left). */
    public bartMachineTreeNode left;
    /** Child receiving all other records. */
    public bartMachineTreeNode right;
    /** Distance from the root: 0 for a root, parent's depth + 1 otherwise. */
    public int depth;
    /** Whether this node is a leaf (no split rule, no children). */
    public boolean isLeaf;
    /** Column index used by the split rule; {@link #BAD_FLAG_int} when unset. */
    public int splitAttributeM = BAD_FLAG_int;
    /** Threshold of the split rule; {@link #BAD_FLAG_double} when unset. */
    public double splitValue = BAD_FLAG_double;
    /** Direction for records whose split value is missing: {@code true} = right. */
    public boolean sendMissingDataRight;
    /** Leaf prediction on the transformed response scale; {@link #BAD_FLAG_double} when unset. */
    public double y_pred = BAD_FLAG_double;
    /** Not read or written in this class beyond its initializer; presumably set by the sampler. */
    public double y_avg = BAD_FLAG_double;
    /** Not read or written in this class beyond its initializer; presumably set by the sampler. */
    public double posterior_var = BAD_FLAG_double;
    /** Not read or written in this class beyond its initializer; presumably set by the sampler. */
    public double posterior_mean = BAD_FLAG_double;
    /** Number of training rows reaching this node. */
    public transient int n_eta;
    /** Fitted values indexed by training-row index; the same array is shared across the tree. */
    public transient double[] yhats;
    /** Training-row indices (into {@code bart.X_y}) reaching this node. */
    protected int[] indicies;
    /** Responses of the rows in {@link #indicies}, parallel to it, on the transformed scale. */
    protected transient double[] responses;
    /** Cached square of {@link #sumResponses()}; 0.0 means "not yet computed". */
    private transient double sum_responses_qty_sqd;
    /** Cached sum of {@link #responses}; 0.0 means "not yet computed". */
    private transient double sum_responses_qty;
    /** Cache of predictors that vary among this node's rows (used when mem_cache_for_speed). */
    private transient TIntArrayList possible_rule_variables;
    /** Cache of candidate split values per predictor (used when mem_cache_for_speed). */
    private transient HashMap<Integer, TDoubleHashSetAndArray> possible_split_vals_by_attr;
    /** Not used in this class; presumably a cached predictor count set by callers — TODO confirm. */
    protected transient Integer padj;
    /** Cached per-predictor split counts for the subtree; see {@link #attributeSplitCounts()}. */
    private int[] attribute_split_counts;

    /** Creates a blank node; used by {@link #clone()}. */
    public bartMachineTreeNode() {
    }

    /**
     * Creates a leaf child of {@code parent}.
     *
     * @param parent the parent node, or {@code null} to create a depth-0 node
     * @param bart the model settings/data this node refers to
     */
    public bartMachineTreeNode(bartMachineTreeNode parent, bartMachine_b_hyperparams bart) {
        this.parent = parent;
        this.bart = bart;
        // Only dereference the parent after the null check. Reading parent.yhats first made
        // the null branch unreachable and threw NullPointerException for a null parent.
        if (parent != null) {
            this.yhats = parent.yhats;
            this.depth = parent.depth + 1;
        }
        this.isLeaf = true;
    }

    /**
     * Creates a leaf child of {@code parent} that shares the parent's settings.
     *
     * @param parent the parent node; must be non-null
     */
    public bartMachineTreeNode(bartMachineTreeNode parent) {
        this(parent, parent.bart);
    }

    /**
     * Creates a root leaf (depth 0).
     *
     * @param bart the model settings/data this node refers to
     */
    public bartMachineTreeNode(bartMachine_b_hyperparams bart) {
        this.bart = bart;
        this.isLeaf = true;
        this.depth = 0;
    }

    /**
     * Copies this node and, recursively, its subtree.
     *
     * <p>Tree structure is duplicated and the children's parent links point into the copy. The
     * returned node keeps this node's {@code parent}. Data arrays and the rule/split caches are
     * shared by reference, not copied. {@code y_pred}, {@code y_avg}, {@code posterior_*}, the
     * response-sum caches, {@code padj} and {@code attribute_split_counts} are not copied and
     * keep their defaults in the copy.
     *
     * @return the copied subtree root
     */
    @Override
    public bartMachineTreeNode clone() {
        bartMachineTreeNode copy = new bartMachineTreeNode();
        copy.bart = this.bart;
        copy.parent = this.parent;
        copy.isLeaf = this.isLeaf;
        copy.splitAttributeM = this.splitAttributeM;
        copy.splitValue = this.splitValue;
        copy.possible_rule_variables = this.possible_rule_variables;
        copy.sendMissingDataRight = this.sendMissingDataRight;
        copy.possible_split_vals_by_attr = this.possible_split_vals_by_attr;
        copy.depth = this.depth;
        copy.responses = this.responses;
        copy.indicies = this.indicies;
        copy.n_eta = this.n_eta;
        copy.yhats = this.yhats;
        if (this.left != null) {
            copy.left = this.left.clone();
            copy.left.parent = copy;
        }
        if (this.right != null) {
            copy.right = this.right.clone();
            copy.right.parent = copy;
        }
        return copy;
    }

    /** Returns the sample average of this node's responses (transformed scale). */
    public double avgResponse() {
        return StatToolbox.sample_average(this.responses);
    }

    /**
     * Collects the leaves of this subtree that hold at least {@code n} training rows.
     *
     * @param n minimum {@link #n_eta} for a leaf to be included
     * @return leaves in left-to-right order
     */
    public ArrayList<bartMachineTreeNode> getTerminalNodesWithDataAboveOrEqualToN(int n) {
        ArrayList<bartMachineTreeNode> leaves = new ArrayList<bartMachineTreeNode>();
        bartMachineTreeNode.findTerminalNodesDataAboveOrEqualToN(this, leaves, n);
        return leaves;
    }

    /** Returns every leaf of this subtree, in left-to-right order. */
    public ArrayList<bartMachineTreeNode> getTerminalNodes() {
        return this.getTerminalNodesWithDataAboveOrEqualToN(0);
    }

    /**
     * Appends to {@code out} the leaves under {@code node} with {@code n_eta >= n}.
     *
     * <p>An internal node missing a child is reported on stderr. The recursion then still
     * descends into it and will throw a NullPointerException.
     */
    private static void findTerminalNodesDataAboveOrEqualToN(bartMachineTreeNode node, ArrayList<bartMachineTreeNode> out, int n) {
        if (node.isLeaf && node.n_eta >= n) {
            out.add(node);
        } else if (!node.isLeaf) {
            if (node.left == null || node.right == null) {
                System.err.println("error node no children during findTerminalNodesDataAboveOrEqualToN id: " + node.stringID());
            }
            bartMachineTreeNode.findTerminalNodesDataAboveOrEqualToN(node.left, out, n);
            bartMachineTreeNode.findTerminalNodesDataAboveOrEqualToN(node.right, out, n);
        }
    }

    /** Returns the internal nodes of this subtree whose two children are both leaves. */
    public ArrayList<bartMachineTreeNode> getPrunableAndChangeableNodes() {
        ArrayList<bartMachineTreeNode> found = new ArrayList<bartMachineTreeNode>();
        bartMachineTreeNode.findPrunableAndChangeableNodes(this, found);
        return found;
    }

    /** Appends to {@code out} the internal nodes under {@code node} whose children are both leaves. */
    private static void findPrunableAndChangeableNodes(bartMachineTreeNode node, ArrayList<bartMachineTreeNode> out) {
        if (node.isLeaf) {
            return;
        }
        if (node.left.isLeaf && node.right.isLeaf) {
            out.add(node);
        } else {
            bartMachineTreeNode.findPrunableAndChangeableNodes(node.left, out);
            bartMachineTreeNode.findPrunableAndChangeableNodes(node.right, out);
        }
    }

    /**
     * Turns {@code node} into a leaf by dropping its children and clearing its split rule.
     * Its data, prediction and caches are left untouched.
     */
    public static void pruneTreeAt(bartMachineTreeNode node) {
        node.left = null;
        node.right = null;
        node.isLeaf = true;
        node.splitAttributeM = BAD_FLAG_int;
        node.splitValue = BAD_FLAG_double;
    }

    /** Returns the height of this subtree (0 for a leaf). */
    public int deepestNode() {
        if (this.isLeaf) {
            return 0;
        }
        return 1 + Math.max(this.left.deepestNode(), this.right.deepestNode());
    }

    /**
     * Returns the prediction of the leaf that {@code record} falls into.
     *
     * @param record a feature vector indexed by predictor column
     */
    public double Evaluate(double[] record) {
        return this.EvaluateNode(record).y_pred;
    }

    /**
     * Routes {@code record} from this node down to a leaf using each node's split rule.
     *
     * @param record a feature vector indexed by predictor column
     * @return the leaf reached
     */
    public bartMachineTreeNode EvaluateNode(double[] record) {
        bartMachineTreeNode node = this;
        while (!node.isLeaf) {
            node = node.routesLeft(record[node.splitAttributeM]) ? node.left : node.right;
        }
        return node;
    }

    /**
     * Applies this node's split rule to one value.
     *
     * @return {@code true} if a record with this split-column value goes to the left child
     */
    private boolean routesLeft(double value) {
        if (Classifier.isMissing(value)) {
            return !this.sendMissingDataRight;
        }
        return value <= this.splitValue;
    }

    /**
     * Releases per-node training data in this subtree.
     *
     * <p>Clears yhats, responses and the rule/split caches, and clears indicies too when
     * {@code bart.flush_indices_to_save_ram} is set.
     */
    public void flushNodeData() {
        this.yhats = null;
        if (this.bart.flush_indices_to_save_ram) {
            this.indicies = null;
        }
        this.responses = null;
        this.possible_rule_variables = null;
        this.possible_split_vals_by_attr = null;
        if (this.left != null) {
            this.left.flushNodeData();
        }
        if (this.right != null) {
            this.right.flushNodeData();
        }
    }

    /**
     * Re-splits this node's rows between its children using the current rule, recursively.
     *
     * <p>Call after changing {@link #splitAttributeM}, {@link #splitValue} or
     * {@link #sendMissingDataRight}. Does nothing on a leaf.
     */
    public void propagateDataByChangedRule() {
        if (this.isLeaf) {
            return;
        }
        TIntArrayList leftIndices = new TIntArrayList(this.n_eta);
        TIntArrayList rightIndices = new TIntArrayList(this.n_eta);
        TDoubleArrayList leftResponses = new TDoubleArrayList(this.n_eta);
        TDoubleArrayList rightResponses = new TDoubleArrayList(this.n_eta);
        for (int i = 0; i < this.n_eta; ++i) {
            int row = this.indicies[i];
            double[] record = (double[]) this.bart.X_y.get(row);
            if (this.routesLeft(record[this.splitAttributeM])) {
                leftIndices.add(row);
                leftResponses.add(this.responses[i]);
            } else {
                rightIndices.add(row);
                rightResponses.add(this.responses[i]);
            }
        }
        bartMachineTreeNode.assignPropagatedData(this.left, leftIndices, leftResponses);
        bartMachineTreeNode.assignPropagatedData(this.right, rightIndices, rightResponses);
        this.left.propagateDataByChangedRule();
        this.right.propagateDataByChangedRule();
    }

    /** Installs re-split rows on {@code child} and invalidates every cache derived from its old rows. */
    private static void assignPropagatedData(bartMachineTreeNode child, TIntArrayList indices, TDoubleArrayList childResponses) {
        child.n_eta = childResponses.size();
        child.responses = childResponses.toArray();
        child.indicies = indices.toArray();
        // The cached sums and split candidates describe the child's previous rows; drop them
        // so they are recomputed from the new data on next use (0.0 / null = "not computed").
        child.sum_responses_qty = 0.0;
        child.sum_responses_qty_sqd = 0.0;
        child.clearRulesAndSplitCache();
    }

    /**
     * Reloads {@link #responses} in this subtree from a full response vector and resets the
     * response-sum caches. Row membership ({@link #indicies}) is unchanged.
     *
     * @param allResponses responses indexed by training-row index
     */
    public void updateWithNewResponsesRecursively(double[] allResponses) {
        this.responses = new double[this.n_eta];
        this.sum_responses_qty_sqd = 0.0;
        this.sum_responses_qty = 0.0;
        for (int i = 0; i < this.n_eta; ++i) {
            this.responses[i] = allResponses[this.indicies[i]];
        }
        if (this.isLeaf) {
            return;
        }
        this.left.updateWithNewResponsesRecursively(allResponses);
        this.right.updateWithNewResponsesRecursively(allResponses);
    }

    /** Returns the number of leaves in this subtree. */
    public int numLeaves() {
        if (this.isLeaf) {
            return 1;
        }
        return this.left.numLeaves() + this.right.numLeaves();
    }

    /** Returns the total number of nodes (internal and leaf) in this subtree. */
    public int numNodesAndLeaves() {
        if (this.isLeaf) {
            return 1;
        }
        return 1 + this.left.numNodesAndLeaves() + this.right.numNodesAndLeaves();
    }

    /** Returns how many internal nodes in this subtree have two leaf children. */
    public int numPruneNodesAvailable() {
        if (this.isLeaf) {
            return 0;
        }
        if (this.left.isLeaf && this.right.isLeaf) {
            return 1;
        }
        return this.left.numPruneNodesAvailable() + this.right.numPruneNodesAvailable();
    }

    /** Returns {@link #y_pred} on the original scale, or {@link #BAD_FLAG_double} if unset. */
    public double prediction_untransformed() {
        return this.y_pred == BAD_FLAG_double ? BAD_FLAG_double : this.bart.un_transform_y(this.y_pred);
    }

    /** Returns {@link #avgResponse()} mapped back to the original response scale. */
    public double avg_response_untransformed() {
        return this.bart.un_transform_y(this.avgResponse());
    }

    /**
     * Returns the square of {@link #sumResponses()}, cached.
     *
     * <p>0.0 doubles as the "not computed" marker, so a genuinely zero value is recomputed on
     * every call.
     */
    public double sumResponsesQuantitySqd() {
        if (this.sum_responses_qty_sqd == 0.0) {
            this.sum_responses_qty_sqd = Math.pow(this.sumResponses(), 2.0);
        }
        return this.sum_responses_qty_sqd;
    }

    /**
     * Returns the sum of the first {@link #n_eta} responses, cached.
     *
     * <p>0.0 doubles as the "not computed" marker, so a genuinely zero sum is recomputed on
     * every call.
     */
    public double sumResponses() {
        if (this.sum_responses_qty == 0.0) {
            double total = 0.0;
            for (int i = 0; i < this.n_eta; ++i) {
                total += this.responses[i];
            }
            this.sum_responses_qty = total;
        }
        return this.sum_responses_qty;
    }

    /**
     * Returns the predictors that vary among this node's rows.
     * The result is cached when {@code bart.mem_cache_for_speed} is set.
     */
    protected TIntArrayList predictorsThatCouldBeUsedToSplitAtNode() {
        if (!this.bart.mem_cache_for_speed) {
            return this.tabulatePredictorsThatCouldBeUsedToSplitAtNode();
        }
        if (this.possible_rule_variables == null) {
            this.possible_rule_variables = this.tabulatePredictorsThatCouldBeUsedToSplitAtNode();
        }
        return this.possible_rule_variables;
    }

    /** Computes, without caching, the predictor columns in [0, p) that vary at this node. */
    private TIntArrayList tabulatePredictorsThatCouldBeUsedToSplitAtNode() {
        TIntArrayList usable = new TIntArrayList();
        for (int attribute = 0; attribute < this.bart.p; ++attribute) {
            double[] column = (double[]) this.bart.X_y_by_col.get(attribute);
            if (this.columnVariesAtNode(column)) {
                usable.add(attribute);
            }
        }
        return usable;
    }

    /**
     * Returns whether two consecutive rows of this node differ in {@code column}.
     * NOTE(review): if missing values are encoded as NaN, any missing value makes the column
     * count as varying (NaN != NaN) — confirm this is intended.
     */
    private boolean columnVariesAtNode(double[] column) {
        for (int j = 1; j < this.indicies.length; ++j) {
            if (column[this.indicies[j - 1]] != column[this.indicies[j]]) {
                return true;
            }
        }
        return false;
    }

    /** Returns the number of candidate split values for the current {@link #splitAttributeM}. */
    public int nAdj() {
        return this.possibleSplitValuesGivenAttribute().size();
    }

    /**
     * Returns the candidate split values for the current {@link #splitAttributeM}.
     * Results are cached per attribute when {@code bart.mem_cache_for_speed} is set.
     */
    protected TDoubleHashSetAndArray possibleSplitValuesGivenAttribute() {
        if (!this.bart.mem_cache_for_speed) {
            return this.tabulatePossibleSplitValuesGivenAttribute();
        }
        if (this.possible_split_vals_by_attr == null) {
            this.possible_split_vals_by_attr = new HashMap<Integer, TDoubleHashSetAndArray>();
        }
        TDoubleHashSetAndArray splitValues = this.possible_split_vals_by_attr.get(this.splitAttributeM);
        if (splitValues == null) {
            splitValues = this.tabulatePossibleSplitValuesGivenAttribute();
            this.possible_split_vals_by_attr.put(this.splitAttributeM, splitValues);
        }
        return splitValues;
    }

    /**
     * Computes, without caching, the distinct non-missing values of {@link #splitAttributeM}
     * among this node's rows, excluding the largest value. Splitting at the maximum would leave
     * the right child empty.
     */
    private TDoubleHashSetAndArray tabulatePossibleSplitValuesGivenAttribute() {
        double[] column = (double[]) this.bart.X_y_by_col.get(this.splitAttributeM);
        double[] valuesHere = new double[this.n_eta];
        for (int i = 0; i < this.n_eta; ++i) {
            double value = column[this.indicies[i]];
            // Missing values are mapped to the sentinel so they can be removed as one entry.
            valuesHere[i] = Classifier.isMissing(value) ? BAD_FLAG_double : value;
        }
        TDoubleHashSetAndArray candidates = new TDoubleHashSetAndArray(valuesHere);
        candidates.remove(BAD_FLAG_double);
        candidates.remove(Tools.max(valuesHere));
        return candidates;
    }

    /**
     * Picks a split value uniformly from {@link #possibleSplitValuesGivenAttribute()}.
     *
     * @return the chosen value, or {@link #BAD_FLAG_double} if there are no candidates
     */
    public double pickRandomSplitValue() {
        TDoubleHashSetAndArray candidates = this.possibleSplitValuesGivenAttribute();
        if (candidates.size() == 0) {
            return BAD_FLAG_double;
        }
        int pick = (int) Math.floor(StatToolbox.rand() * (double) candidates.size());
        return candidates.getAtIndex(pick);
    }

    /** Returns {@code true} for a lone node with no parent and no children. */
    public boolean isStump() {
        return this.parent == null && this.left == null && this.right == null;
    }

    /** Returns the part of {@code toString()} after the '@' (the identity hash in hex by default). */
    public String stringID() {
        return this.toString().split("@")[1];
    }

    /**
     * Loads all training rows into this (root) node.
     *
     * @param X_y the training records; only its size is used here
     * @param y_trans responses indexed by row, at least {@code X_y.size()} long
     * @param p number of predictors; responses are copied only when {@code p >= 0}
     */
    public void setStumpData(ArrayList<double[]> X_y, double[] y_trans, int p) {
        this.n_eta = X_y.size();
        this.responses = new double[this.n_eta];
        this.indicies = new int[this.n_eta];
        for (int i = 0; i < this.n_eta; ++i) {
            this.indicies[i] = i;
        }
        if (p >= 0) {
            for (int i = 0; i < this.n_eta; ++i) {
                this.responses[i] = y_trans[i];
            }
        }
        this.yhats = new double[this.n_eta];
        this.sendMissingDataRight = bartMachineTreeNode.pickRandomDirectionForMissingData();
    }

    /** Writes {@link #y_pred} into the shared {@link #yhats} slot of each of this node's rows. */
    public void updateYHatsWithPrediction() {
        for (int i = 0; i < this.indicies.length; ++i) {
            this.yhats[this.indicies[i]] = this.y_pred;
        }
    }

    /**
     * Returns how many internal nodes of this subtree split on each predictor.
     * The result is computed once and cached. Nothing in this class invalidates the cache.
     */
    public int[] attributeSplitCounts() {
        if (this.attribute_split_counts == null) {
            this.attribute_split_counts = new int[this.bart.p];
            this.attributeSplitCountsInner(this.attribute_split_counts);
        }
        return this.attribute_split_counts;
    }

    /** Adds this subtree's per-predictor split counts into {@code counts}. */
    public void attributeSplitCountsInner(int[] counts) {
        if (this.isLeaf) {
            return;
        }
        counts[this.splitAttributeM]++;
        this.left.attributeSplitCountsInner(counts);
        this.right.attributeSplitCountsInner(counts);
    }

    /**
     * Adds to {@code pairs} one unordered pair for each split variable of an internal node in
     * this subtree combined with the split variable of each internal node below it.
     */
    public void findInteractions(HashSet<UnorderedPair<Integer>> pairs) {
        if (this.isLeaf) {
            return;
        }
        this.findSplitAttributesUsedUnderneath(this.splitAttributeM, pairs);
        this.left.findInteractions(pairs);
        this.right.findInteractions(pairs);
    }

    /** Pairs {@code attribute} with the split variable of every internal node strictly below this one. */
    private void findSplitAttributesUsedUnderneath(int attribute, HashSet<UnorderedPair<Integer>> pairs) {
        if (this.isLeaf) {
            return;
        }
        if (!this.left.isLeaf) {
            pairs.add(new UnorderedPair<Integer>(attribute, this.left.splitAttributeM));
        }
        if (!this.right.isLeaf) {
            pairs.add(new UnorderedPair<Integer>(attribute, this.right.splitAttributeM));
        }
        this.left.findSplitAttributesUsedUnderneath(attribute, pairs);
        this.right.findSplitAttributesUsedUnderneath(attribute, pairs);
    }

    /** Drops the cached usable-predictor list and per-predictor split candidates. */
    public void clearRulesAndSplitCache() {
        this.possible_rule_variables = null;
        this.possible_split_vals_by_attr = null;
    }

    /** Returns {@code true} ("send missing data right") when a uniform draw is not below 0.5. */
    public static boolean pickRandomDirectionForMissingData() {
        return !(StatToolbox.rand() < 0.5);
    }

    /**
     * Describes this node's position as a root-to-node path of "L"/"R" steps ("?" if the parent
     * does not link back to this node).
     *
     * @param bl whether a root should be rendered as "P" rather than the empty string
     */
    public String stringLocation(boolean bl) {
        if (this.parent == null) {
            return bl ? "P" : "";
        }
        String parentPath = this.parent.stringLocation(false);
        if (this.parent.left == this) {
            return parentPath + "L";
        }
        if (this.parent.right == this) {
            return parentPath + "R";
        }
        return parentPath + "?";
    }

    /** Same as {@code stringLocation(true)}. */
    public String stringLocation() {
        return this.stringLocation(true);
    }

    /**
     * Dumps this node's rule, data and caches to stdout for debugging.
     *
     * @param string a caller-supplied label printed in the header
     */
    public void printNodeDebugInfo(String string) {
        System.out.println("\n" + string + " node debug info for " + this.stringLocation(true) + (this.isLeaf ? " (LEAF) " : " (INTERNAL NODE) ") + " d = " + this.depth);
        System.out.println("-----------------------------------------");
        System.out.println("n_eta = " + this.n_eta + " y_pred = " + (this.y_pred == BAD_FLAG_double ? "BLANK" : Double.valueOf(this.bart.un_transform_y_and_round(this.y_pred))));
        System.out.println("parent = " + this.parent + " left = " + this.left + " right = " + this.right);
        if (this.parent != null) {
            System.out.println("----- PARENT RULE:   X_" + this.parent.splitAttributeM + " <= " + this.parent.splitValue + " & M -> " + (this.parent.sendMissingDataRight ? "R" : "L") + " ------");
            double[] parentSplitValues = this.sortedColumnValuesAtNode(this.parent.splitAttributeM);
            System.out.println("   all X_" + this.parent.splitAttributeM + " values here: [" + Tools.StringJoin(parentSplitValues) + "]");
        }
        if (!this.isLeaf) {
            System.out.println("----- RULE:   X_" + this.splitAttributeM + " <= " + this.splitValue + " & M -> " + (this.sendMissingDataRight ? "R" : "L") + " ------");
            double[] ownSplitValues = this.sortedColumnValuesAtNode(this.splitAttributeM);
            System.out.println("   all X_" + this.splitAttributeM + " values here: [" + Tools.StringJoin(ownSplitValues) + "]");
        }
        System.out.println("sum_responses_qty = " + this.sum_responses_qty + " sum_responses_qty_sqd = " + this.sum_responses_qty_sqd);
        if (this.bart.mem_cache_for_speed) {
            System.out.println("possible_rule_variables: [" + Tools.StringJoin(this.possible_rule_variables, ", ") + "]");
            System.out.println("possible_split_vals_by_attr: {");
            if (this.possible_split_vals_by_attr != null) {
                for (Integer attribute : this.possible_split_vals_by_attr.keySet()) {
                    double[] splitValues = this.possible_split_vals_by_attr.get(attribute).toArray();
                    Arrays.sort(splitValues);
                    System.out.println("  " + attribute + " -> [" + Tools.StringJoin(splitValues) + "],");
                }
                System.out.print(" }\n");
            } else {
                System.out.println(" NULL MAP\n}");
            }
        }
        // flushNodeData() nulls these arrays, so guard them to keep the dump usable afterwards.
        if (this.responses == null) {
            System.out.println("responses: NULL");
        } else {
            System.out.println("responses: (size " + this.responses.length + ") [" + Tools.StringJoin(this.bart.un_transform_y_and_round(this.responses)) + "]");
        }
        if (this.indicies == null) {
            System.out.println("indicies: NULL");
        } else {
            System.out.println("indicies: (size " + this.indicies.length + ") [" + Tools.StringJoin(this.indicies) + "]");
        }
        if (this.yhats == null) {
            System.out.println("y_hat_vec: NULL");
        } else if (Arrays.equals(this.yhats, new double[this.yhats.length])) {
            System.out.println("y_hat_vec: (size " + this.yhats.length + ") [ BLANK ]");
        } else {
            System.out.println("y_hat_vec: (size " + this.yhats.length + ") [" + Tools.StringJoin(this.bart.un_transform_y_and_round(this.yhats)) + "]");
        }
        System.out.println("-----------------------------------------\n\n\n");
    }

    /** Returns the sorted values of column {@code attribute} over this node's rows (debug helper). */
    private double[] sortedColumnValuesAtNode(int attribute) {
        double[] column = (double[]) this.bart.X_y_by_col.get(attribute);
        double[] values = new double[this.n_eta];
        for (int i = 0; i < this.n_eta; ++i) {
            values[i] = column[this.indicies[i]];
        }
        Arrays.sort(values);
        return values;
    }

    /** Returns the left child. */
    public bartMachineTreeNode getLeft() {
        return this.left;
    }

    /** Returns the right child. */
    public bartMachineTreeNode getRight() {
        return this.right;
    }

    /** Returns this node's depth. */
    public int getGeneration() {
        return this.depth;
    }

    /** Sets this node's depth. */
    public void setGeneration(int n) {
        this.depth = n;
    }

    /** Returns whether this node is a leaf. */
    public boolean isLeaf() {
        return this.isLeaf;
    }

    /** Sets whether this node is a leaf. */
    public void setLeaf(boolean bl) {
        this.isLeaf = bl;
    }

    /** Sets the left child (the child's parent link is not updated). */
    public void setLeft(bartMachineTreeNode node) {
        this.left = node;
    }

    /** Sets the right child (the child's parent link is not updated). */
    public void setRight(bartMachineTreeNode node) {
        this.right = node;
    }

    /** Returns the split column index. */
    public int getSplitAttributeM() {
        return this.splitAttributeM;
    }

    /** Sets the split column index. */
    public void setSplitAttributeM(int n) {
        this.splitAttributeM = n;
    }

    /** Returns the split threshold. */
    public double getSplitValue() {
        return this.splitValue;
    }

    /** Sets the split threshold. */
    public void setSplitValue(double d) {
        this.splitValue = d;
    }
}

