/*
 * Decompiled with CFR 0.152.
 */
package moa.classifiers.core.attributeclassobservers;

import java.io.Serializable;
import moa.classifiers.core.AttributeSplitSuggestion;
import moa.classifiers.core.attributeclassobservers.BinaryTreeNumericAttributeClassObserver;
import moa.classifiers.core.attributeclassobservers.NumericAttributeClassObserver;
import moa.classifiers.core.conditionaltests.NumericAttributeBinaryTest;
import moa.classifiers.core.splitcriteria.SplitCriterion;
import moa.core.DoubleVector;
import moa.core.ObjectRepository;
import moa.tasks.TaskMonitor;

/**
 * Numeric attribute observer for regression targets, named after the FIMT-DD regression
 * tree learner.
 *
 * <p>Observed attribute values are stored in an (unbalanced) binary search tree whose nodes
 * double as candidate split points. Every node keeps two {@link DoubleVector}s laid out as
 * {@code [0] = total weight, [1] = weighted sum of targets, [2] = weighted sum of squared
 * targets}: {@code leftStatistics} for the values routed through the node that are
 * {@code <= cut_point}, and {@code rightStatistics} for those that are {@code > cut_point}.
 */
public class FIMTDDNumericAttributeClassObserver
extends BinaryTreeNumericAttributeClassObserver
implements NumericAttributeClassObserver {
    private static final long serialVersionUID = 1L;

    /** Root of the value tree; {@code null} until the first non-NaN value is observed. */
    protected Node root;

    // Running left/right target statistics used as scratch state while
    // searchForBestSplitOption walks the tree; (re)initialised by
    // getBestEvaluatedSplitSuggestion before every search. Because the search keeps its
    // state in these fields, two searches must not run on the same instance at once.
    double sumTotalLeft;
    double sumTotalRight;
    double sumSqTotalLeft;
    double sumSqTotalRight;
    double countRightTotal;
    double countLeftTotal;

    /**
     * Records one observation of this attribute. {@code NaN} attribute values are ignored.
     *
     * @param attVal   the attribute value
     * @param classVal the numeric target value of the instance
     * @param weight   the instance weight
     */
    public void observeAttributeClass(double attVal, double classVal, double weight) {
        if (Double.isNaN(attVal)) {
            return;
        }
        if (this.root == null) {
            this.root = new Node(attVal, classVal, weight);
        } else {
            this.root.insertValue(attVal, classVal, weight);
        }
    }

    /**
     * Not supported by this regression observer.
     *
     * @return always {@code 0.0}
     */
    @Override
    public double probabilityOfAttributeValueGivenClass(double attVal, int classVal) {
        return 0.0;
    }

    /**
     * Finds the best binary split of the form {@code attribute <= cut_point} among all values
     * stored in the tree.
     *
     * @param criterion    criterion used to score each candidate split
     * @param preSplitDist target statistics of the data before splitting, laid out as
     *                     {@code [count, sum, sumOfSquares]} (same layout as the node
     *                     statistics)
     * @param attIndex     index of the observed attribute
     * @param binaryOnly   ignored; every suggestion produced here is a binary split
     * @return the best split found, or {@code null} if no value has been observed yet or
     *         {@code preSplitDist[0]} is zero
     */
    @Override
    public AttributeSplitSuggestion getBestEvaluatedSplitSuggestion(SplitCriterion criterion, double[] preSplitDist, int attIndex, boolean binaryOnly) {
        // Start with everything on the right; the in-order walk moves mass to the left.
        this.sumTotalLeft = 0.0;
        this.sumTotalRight = preSplitDist[1];
        this.sumSqTotalLeft = 0.0;
        this.sumSqTotalRight = preSplitDist[2];
        this.countLeftTotal = 0.0;
        this.countRightTotal = preSplitDist[0];
        return this.searchForBestSplitOption(this.root, null, criterion, attIndex);
    }

    /**
     * In-order walk that scores the split at every node's {@code cut_point}.
     *
     * <p>On entry the running totals hold the statistics of everything to the left of the
     * current subtree. At each node its {@code leftStatistics} (all values of the subtree
     * that are {@code <= cut_point}) are moved to the left side, the split is scored, and
     * the right subtree is visited. The move is then undone, so the totals are unchanged for
     * the caller.
     *
     * @param currentNode       root of the subtree to search; may be {@code null}
     * @param currentBestOption best suggestion found so far; may be {@code null}
     * @param criterion         criterion used to score each candidate split
     * @param attIndex          index of the observed attribute
     * @return the better of {@code currentBestOption} and the best split in this subtree
     */
    protected AttributeSplitSuggestion searchForBestSplitOption(Node currentNode, AttributeSplitSuggestion currentBestOption, SplitCriterion criterion, int attIndex) {
        if (currentNode == null || this.countRightTotal == 0.0) {
            return currentBestOption;
        }
        AttributeSplitSuggestion best = currentBestOption;
        if (currentNode.left != null) {
            best = this.searchForBestSplitOption(currentNode.left, best, criterion, attIndex);
        }
        this.shiftToLeft(currentNode.leftStatistics, 1.0);
        double[][] postSplitDists = new double[][]{
            {this.countLeftTotal, this.sumTotalLeft, this.sumSqTotalLeft},
            {this.countRightTotal, this.sumTotalRight, this.sumSqTotalRight}};
        double[] preSplitDist = new double[]{
            this.countLeftTotal + this.countRightTotal,
            this.sumTotalLeft + this.sumTotalRight,
            this.sumSqTotalLeft + this.sumSqTotalRight};
        double merit = criterion.getMeritOfSplit(preSplitDist, postSplitDists);
        if (best == null || merit > best.merit) {
            best = new AttributeSplitSuggestion(new NumericAttributeBinaryTest(attIndex, currentNode.cut_point, true), postSplitDists, merit);
        }
        if (currentNode.right != null) {
            best = this.searchForBestSplitOption(currentNode.right, best, criterion, attIndex);
        }
        // Undo this node's contribution so the caller sees its totals unchanged.
        this.shiftToLeft(currentNode.leftStatistics, -1.0);
        return best;
    }

    /**
     * Moves {@code direction} times the given statistics from the right running totals to
     * the left running totals. Use {@code 1.0} to move and {@code -1.0} to undo.
     */
    private void shiftToLeft(DoubleVector stats, double direction) {
        double count = direction * stats.getValue(0);
        double sum = direction * stats.getValue(1);
        double sumSq = direction * stats.getValue(2);
        this.countLeftTotal += count;
        this.countRightTotal -= count;
        this.sumTotalLeft += sum;
        this.sumTotalRight -= sum;
        this.sumSqTotalLeft += sumSq;
        this.sumSqTotalRight -= sumSq;
    }

    /**
     * Prunes candidate split points whose merit has fallen below the threshold.
     *
     * <p>The tree is visited bottom-up. A node is evaluated only once both of its children
     * are absent or have themselves been pruned. It is then pruned when
     * {@code merit / lastCheckSDR < lastCheckRatio - 2 * lastCheckE}. Here {@code merit} is
     * the criterion's score for splitting the values routed through that node at its
     * {@code cut_point}, using the node's own left/right statistics.
     *
     * <p>Pruned nodes are unlinked from their parent, or from {@link #root}. Ancestor
     * statistics are left untouched, because a parent's {@code leftStatistics} already cover
     * its whole left subtree. Later split searches therefore stay exact; only the pruned cut
     * points stop being candidates.
     *
     * <p>If {@code lastCheckSDR} is zero the ratio is infinite or {@code NaN}, and no node is
     * pruned.
     *
     * @param criterion      criterion used to score each node's split
     * @param lastCheckRatio reference ratio in the pruning condition above
     * @param lastCheckSDR   normaliser for the node's merit in the pruning condition
     * @param lastCheckE     margin term in the pruning condition
     */
    public void removeBadSplits(SplitCriterion criterion, double lastCheckRatio, double lastCheckSDR, double lastCheckE) {
        if (this.removeBadSplitNodes(criterion, this.root, lastCheckRatio, lastCheckSDR, lastCheckE)) {
            this.root = null;
        }
    }

    /**
     * Prunes bad descendants of {@code currentNode} and reports whether the node itself
     * should be removed by its parent.
     *
     * @return {@code true} if {@code currentNode} is {@code null} or is itself a bad split
     *         with no remaining children
     */
    private boolean removeBadSplitNodes(SplitCriterion criterion, Node currentNode, double lastCheckRatio, double lastCheckSDR, double lastCheckE) {
        if (currentNode == null) {
            return true;
        }
        boolean leftRemovable = this.removeBadSplitNodes(criterion, currentNode.left, lastCheckRatio, lastCheckSDR, lastCheckE);
        if (leftRemovable) {
            currentNode.left = null;
        }
        boolean rightRemovable = this.removeBadSplitNodes(criterion, currentNode.right, lastCheckRatio, lastCheckSDR, lastCheckE);
        if (rightRemovable) {
            currentNode.right = null;
        }
        if (!(leftRemovable && rightRemovable)) {
            // A node that still has a surviving child must stay so that child remains reachable.
            return false;
        }
        DoubleVector leftStats = currentNode.leftStatistics;
        DoubleVector rightStats = currentNode.rightStatistics;
        double[][] postSplitDists = new double[][]{
            {leftStats.getValue(0), leftStats.getValue(1), leftStats.getValue(2)},
            {rightStats.getValue(0), rightStats.getValue(1), rightStats.getValue(2)}};
        double[] preSplitDist = new double[]{
            leftStats.getValue(0) + rightStats.getValue(0),
            leftStats.getValue(1) + rightStats.getValue(1),
            leftStats.getValue(2) + rightStats.getValue(2)};
        double merit = criterion.getMeritOfSplit(preSplitDist, postSplitDists);
        return merit / lastCheckSDR < lastCheckRatio - 2.0 * lastCheckE;
    }

    @Override
    public void getDescription(StringBuilder sb, int indent) {
        // No description is provided for this observer.
    }

    @Override
    protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
        // Nothing to prepare: the tree is built lazily by observeAttributeClass.
    }

    /**
     * Node of the value tree. Its {@code cut_point} is a candidate split value.
     *
     * <p>Statistics vectors are laid out as {@code [total weight, weighted sum of targets,
     * weighted sum of squared targets]}.
     */
    protected static class Node
    implements Serializable {
        private static final long serialVersionUID = 1L;
        /** Attribute value that created this node; values {@code <=} it are routed left. */
        public double cut_point;
        /** Statistics of the values routed through this node that are {@code <= cut_point}. */
        public DoubleVector leftStatistics = new DoubleVector();
        /** Statistics of the values routed through this node that are {@code > cut_point}. */
        public DoubleVector rightStatistics = new DoubleVector();
        public Node left;
        public Node right;

        /**
         * Creates a node for a first observation with attribute value {@code val}.
         *
         * @param val    attribute value, which becomes this node's cut point
         * @param label  numeric target value
         * @param weight instance weight
         */
        public Node(double val, double label, double weight) {
            this.cut_point = val;
            addObservation(this.leftStatistics, label, weight);
        }

        /**
         * Routes one observation down the tree and updates every node on the path. A new
         * leaf is created unless a node with exactly this value already exists.
         *
         * <p>The walk is iterative, so monotone streams, which build a degenerate chain, do
         * not exhaust the call stack during insertion.
         *
         * @param val    attribute value
         * @param label  numeric target value
         * @param weight instance weight
         */
        public void insertValue(double val, double label, double weight) {
            Node node = this;
            while (true) {
                if (val <= node.cut_point) {
                    addObservation(node.leftStatistics, label, weight);
                    if (val == node.cut_point) {
                        // Equal values are absorbed by the existing node.
                        return;
                    }
                    if (node.left == null) {
                        node.left = new Node(val, label, weight);
                        return;
                    }
                    node = node.left;
                } else {
                    addObservation(node.rightStatistics, label, weight);
                    if (node.right == null) {
                        node.right = new Node(val, label, weight);
                        return;
                    }
                    node = node.right;
                }
            }
        }

        /** Adds one weighted observation of {@code label} to {@code stats}. */
        private static void addObservation(DoubleVector stats, double label, double weight) {
            stats.addToValue(0, weight);
            stats.addToValue(1, weight * label);
            stats.addToValue(2, weight * label * label);
        }
    }
}

