/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.bayes.net.search.local;

import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.bayes.BayesNet;
import weka.classifiers.bayes.net.ParentSet;
import weka.classifiers.bayes.net.search.SearchAlgorithm;
import weka.classifiers.bayes.net.search.local.Scoreable;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Statistics;
import weka.core.Tag;
import weka.core.Utils;

public class LocalScoreSearchAlgorithm
extends SearchAlgorithm {
    static final long serialVersionUID = 3325995552474190374L;
    BayesNet m_BayesNet;
    double m_fAlpha = 0.5;
    public static final Tag[] TAGS_SCORE_TYPE = new Tag[]{new Tag(0, "BAYES"), new Tag(1, "BDeu"), new Tag(2, "MDL"), new Tag(3, "ENTROPY"), new Tag(4, "AIC")};
    int m_nScoreType = 0;

    public LocalScoreSearchAlgorithm() {
    }

    public LocalScoreSearchAlgorithm(BayesNet bayesNet, Instances instances) {
        this.m_BayesNet = bayesNet;
    }

    public double logScore(int nType) {
        if (this.m_BayesNet.m_Distributions == null) {
            return 0.0;
        }
        if (nType < 0) {
            nType = this.m_nScoreType;
        }
        double fLogScore = 0.0;
        Instances instances = this.m_BayesNet.m_Instances;
        int iAttribute = 0;
        while (iAttribute < instances.numAttributes()) {
            int nCardinality = this.m_BayesNet.getParentSet(iAttribute).getCardinalityOfParents();
            int iParent = 0;
            while (iParent < nCardinality) {
                fLogScore += ((Scoreable)((Object)this.m_BayesNet.m_Distributions[iAttribute][iParent])).logScore(nType, nCardinality);
                ++iParent;
            }
            switch (nType) {
                case 2: {
                    fLogScore -= 0.5 * (double)this.m_BayesNet.getParentSet(iAttribute).getCardinalityOfParents() * (double)(instances.attribute(iAttribute).numValues() - 1) * Math.log(instances.numInstances());
                    break;
                }
                case 4: {
                    fLogScore -= (double)(this.m_BayesNet.getParentSet(iAttribute).getCardinalityOfParents() * (instances.attribute(iAttribute).numValues() - 1));
                }
            }
            ++iAttribute;
        }
        return fLogScore;
    }

    @Override
    public void buildStructure(BayesNet bayesNet, Instances instances) throws Exception {
        this.m_BayesNet = bayesNet;
        super.buildStructure(bayesNet, instances);
    }

    public double calcNodeScore(int nNode) {
        if (this.m_BayesNet.getUseADTree() && this.m_BayesNet.getADTree() != null) {
            return this.calcNodeScoreADTree(nNode);
        }
        return this.calcNodeScorePlain(nNode);
    }

    private double calcNodeScoreADTree(int nNode) {
        Instances instances = this.m_BayesNet.m_Instances;
        ParentSet oParentSet = this.m_BayesNet.getParentSet(nNode);
        int nNrOfParents = oParentSet.getNrOfParents();
        int[] nNodes = new int[nNrOfParents + 1];
        int iParent = 0;
        while (iParent < nNrOfParents) {
            nNodes[iParent] = oParentSet.getParent(iParent);
            ++iParent;
        }
        nNodes[nNrOfParents] = nNode;
        int[] nOffsets = new int[nNrOfParents + 1];
        int nOffset = 1;
        nOffsets[nNrOfParents] = 1;
        nOffset *= instances.attribute(nNode).numValues();
        int iNode = nNrOfParents - 1;
        while (iNode >= 0) {
            nOffsets[iNode] = nOffset;
            nOffset *= instances.attribute(nNodes[iNode]).numValues();
            --iNode;
        }
        iNode = 1;
        while (iNode < nNodes.length) {
            int iNode2 = iNode;
            while (iNode2 > 0 && nNodes[iNode2] < nNodes[iNode2 - 1]) {
                int h = nNodes[iNode2];
                nNodes[iNode2] = nNodes[iNode2 - 1];
                nNodes[iNode2 - 1] = h;
                h = nOffsets[iNode2];
                nOffsets[iNode2] = nOffsets[iNode2 - 1];
                nOffsets[iNode2 - 1] = h;
                --iNode2;
            }
            ++iNode;
        }
        int nCardinality = oParentSet.getCardinalityOfParents();
        int numValues = instances.attribute(nNode).numValues();
        int[] nCounts = new int[nCardinality * numValues];
        this.m_BayesNet.getADTree().getCounts(nCounts, nNodes, nOffsets, 0, 0, false);
        return this.calcScoreOfCounts(nCounts, nCardinality, numValues, instances);
    }

    private double calcNodeScorePlain(int nNode) {
        Instances instances = this.m_BayesNet.m_Instances;
        ParentSet oParentSet = this.m_BayesNet.getParentSet(nNode);
        int nCardinality = oParentSet.getCardinalityOfParents();
        int numValues = instances.attribute(nNode).numValues();
        int[] nCounts = new int[nCardinality * numValues];
        int iParent = 0;
        while (iParent < nCardinality * numValues) {
            nCounts[iParent] = 0;
            ++iParent;
        }
        Enumeration enumInsts = instances.enumerateInstances();
        while (enumInsts.hasMoreElements()) {
            Instance instance = (Instance)enumInsts.nextElement();
            double iCPT = 0.0;
            int iParent2 = 0;
            while (iParent2 < oParentSet.getNrOfParents()) {
                int nParent = oParentSet.getParent(iParent2);
                iCPT = iCPT * (double)instances.attribute(nParent).numValues() + instance.value(nParent);
                ++iParent2;
            }
            int n = numValues * (int)iCPT + (int)instance.value(nNode);
            nCounts[n] = nCounts[n] + 1;
        }
        return this.calcScoreOfCounts(nCounts, nCardinality, numValues, instances);
    }

    protected double calcScoreOfCounts(int[] nCounts, int nCardinality, int numValues, Instances instances) {
        double fLogScore = 0.0;
        int iParent = 0;
        while (iParent < nCardinality) {
            switch (this.m_nScoreType) {
                case 0: {
                    double nSumOfCounts = 0.0;
                    int iSymbol = 0;
                    while (iSymbol < numValues) {
                        if (this.m_fAlpha + (double)nCounts[iParent * numValues + iSymbol] != 0.0) {
                            fLogScore += Statistics.lnGamma(this.m_fAlpha + (double)nCounts[iParent * numValues + iSymbol]);
                            nSumOfCounts += this.m_fAlpha + (double)nCounts[iParent * numValues + iSymbol];
                        }
                        ++iSymbol;
                    }
                    if (nSumOfCounts != 0.0) {
                        fLogScore -= Statistics.lnGamma(nSumOfCounts);
                    }
                    if (this.m_fAlpha == 0.0) break;
                    fLogScore -= (double)numValues * Statistics.lnGamma(this.m_fAlpha);
                    fLogScore += Statistics.lnGamma((double)numValues * this.m_fAlpha);
                    break;
                }
                case 1: {
                    double nSumOfCounts = 0.0;
                    int iSymbol = 0;
                    while (iSymbol < numValues) {
                        if (this.m_fAlpha + (double)nCounts[iParent * numValues + iSymbol] != 0.0) {
                            fLogScore += Statistics.lnGamma(1.0 / (double)(numValues * nCardinality) + (double)nCounts[iParent * numValues + iSymbol]);
                            nSumOfCounts += 1.0 / (double)(numValues * nCardinality) + (double)nCounts[iParent * numValues + iSymbol];
                        }
                        ++iSymbol;
                    }
                    fLogScore -= Statistics.lnGamma(nSumOfCounts);
                    fLogScore -= (double)numValues * Statistics.lnGamma(1.0 / (double)(numValues * nCardinality));
                    fLogScore += Statistics.lnGamma(1.0 / (double)nCardinality);
                    break;
                }
                case 2: 
                case 3: 
                case 4: {
                    double nSumOfCounts = 0.0;
                    int iSymbol = 0;
                    while (iSymbol < numValues) {
                        nSumOfCounts += (double)nCounts[iParent * numValues + iSymbol];
                        ++iSymbol;
                    }
                    iSymbol = 0;
                    while (iSymbol < numValues) {
                        if (nCounts[iParent * numValues + iSymbol] > 0) {
                            fLogScore += (double)nCounts[iParent * numValues + iSymbol] * Math.log((double)nCounts[iParent * numValues + iSymbol] / nSumOfCounts);
                        }
                        ++iSymbol;
                    }
                    break;
                }
            }
            ++iParent;
        }
        switch (this.m_nScoreType) {
            case 2: {
                fLogScore -= 0.5 * (double)nCardinality * (double)(numValues - 1) * Math.log(instances.numInstances());
                break;
            }
            case 4: {
                fLogScore -= (double)(nCardinality * (numValues - 1));
            }
        }
        return fLogScore;
    }

    protected double calcScoreOfCounts2(int[][] nCounts, int nCardinality, int numValues, Instances instances) {
        double fLogScore = 0.0;
        int iParent = 0;
        while (iParent < nCardinality) {
            switch (this.m_nScoreType) {
                case 0: {
                    double nSumOfCounts = 0.0;
                    int iSymbol = 0;
                    while (iSymbol < numValues) {
                        if (this.m_fAlpha + (double)nCounts[iParent][iSymbol] != 0.0) {
                            fLogScore += Statistics.lnGamma(this.m_fAlpha + (double)nCounts[iParent][iSymbol]);
                            nSumOfCounts += this.m_fAlpha + (double)nCounts[iParent][iSymbol];
                        }
                        ++iSymbol;
                    }
                    if (nSumOfCounts != 0.0) {
                        fLogScore -= Statistics.lnGamma(nSumOfCounts);
                    }
                    if (this.m_fAlpha == 0.0) break;
                    fLogScore -= (double)numValues * Statistics.lnGamma(this.m_fAlpha);
                    fLogScore += Statistics.lnGamma((double)numValues * this.m_fAlpha);
                    break;
                }
                case 1: {
                    double nSumOfCounts = 0.0;
                    int iSymbol = 0;
                    while (iSymbol < numValues) {
                        if (this.m_fAlpha + (double)nCounts[iParent][iSymbol] != 0.0) {
                            fLogScore += Statistics.lnGamma(1.0 / (double)(numValues * nCardinality) + (double)nCounts[iParent][iSymbol]);
                            nSumOfCounts += 1.0 / (double)(numValues * nCardinality) + (double)nCounts[iParent][iSymbol];
                        }
                        ++iSymbol;
                    }
                    fLogScore -= Statistics.lnGamma(nSumOfCounts);
                    fLogScore -= (double)numValues * Statistics.lnGamma(1.0 / (double)(nCardinality * numValues));
                    fLogScore += Statistics.lnGamma(1.0 / (double)nCardinality);
                    break;
                }
                case 2: 
                case 3: 
                case 4: {
                    double nSumOfCounts = 0.0;
                    int iSymbol = 0;
                    while (iSymbol < numValues) {
                        nSumOfCounts += (double)nCounts[iParent][iSymbol];
                        ++iSymbol;
                    }
                    iSymbol = 0;
                    while (iSymbol < numValues) {
                        if (nCounts[iParent][iSymbol] > 0) {
                            fLogScore += (double)nCounts[iParent][iSymbol] * Math.log((double)nCounts[iParent][iSymbol] / nSumOfCounts);
                        }
                        ++iSymbol;
                    }
                    break;
                }
            }
            ++iParent;
        }
        switch (this.m_nScoreType) {
            case 2: {
                fLogScore -= 0.5 * (double)nCardinality * (double)(numValues - 1) * Math.log(instances.numInstances());
                break;
            }
            case 4: {
                fLogScore -= (double)(nCardinality * (numValues - 1));
            }
        }
        return fLogScore;
    }

    public double calcScoreWithExtraParent(int nNode, int nCandidateParent) {
        ParentSet oParentSet = this.m_BayesNet.getParentSet(nNode);
        if (oParentSet.contains(nCandidateParent)) {
            return -1.0E100;
        }
        oParentSet.addParent(nCandidateParent, this.m_BayesNet.m_Instances);
        double logScore = this.calcNodeScore(nNode);
        oParentSet.deleteLastParent(this.m_BayesNet.m_Instances);
        return logScore;
    }

    public double calcScoreWithMissingParent(int nNode, int nCandidateParent) {
        ParentSet oParentSet = this.m_BayesNet.getParentSet(nNode);
        if (!oParentSet.contains(nCandidateParent)) {
            return -1.0E100;
        }
        int iParent = oParentSet.deleteParent(nCandidateParent, this.m_BayesNet.m_Instances);
        double logScore = this.calcNodeScore(nNode);
        oParentSet.addParent(nCandidateParent, iParent, this.m_BayesNet.m_Instances);
        return logScore;
    }

    public void setScoreType(SelectedTag newScoreType) {
        if (newScoreType.getTags() == TAGS_SCORE_TYPE) {
            this.m_nScoreType = newScoreType.getSelectedTag().getID();
        }
    }

    public SelectedTag getScoreType() {
        return new SelectedTag(this.m_nScoreType, TAGS_SCORE_TYPE);
    }

    @Override
    public void setMarkovBlanketClassifier(boolean bMarkovBlanketClassifier) {
        super.setMarkovBlanketClassifier(bMarkovBlanketClassifier);
    }

    @Override
    public boolean getMarkovBlanketClassifier() {
        return super.getMarkovBlanketClassifier();
    }

    @Override
    public Enumeration listOptions() {
        Vector<Option> newVector = new Vector<Option>();
        newVector.addElement(new Option("\tApplies a Markov Blanket correction to the network structure, \n\tafter a network structure is learned. This ensures that all \n\tnodes in the network are part of the Markov blanket of the \n\tclassifier node.", "mbc", 0, "-mbc"));
        newVector.addElement(new Option("\tScore type (BAYES, BDeu, MDL, ENTROPY and AIC)", "S", 1, "-S [BAYES|MDL|ENTROPY|AIC|CROSS_CLASSIC|CROSS_BAYES]"));
        return newVector.elements();
    }

    @Override
    public void setOptions(String[] options) throws Exception {
        this.setMarkovBlanketClassifier(Utils.getFlag("mbc", options));
        String sScore = Utils.getOption('S', options);
        if (sScore.compareTo("BAYES") == 0) {
            this.setScoreType(new SelectedTag(0, TAGS_SCORE_TYPE));
        }
        if (sScore.compareTo("BDeu") == 0) {
            this.setScoreType(new SelectedTag(1, TAGS_SCORE_TYPE));
        }
        if (sScore.compareTo("MDL") == 0) {
            this.setScoreType(new SelectedTag(2, TAGS_SCORE_TYPE));
        }
        if (sScore.compareTo("ENTROPY") == 0) {
            this.setScoreType(new SelectedTag(3, TAGS_SCORE_TYPE));
        }
        if (sScore.compareTo("AIC") == 0) {
            this.setScoreType(new SelectedTag(4, TAGS_SCORE_TYPE));
        }
    }

    @Override
    public String[] getOptions() {
        String[] superOptions = super.getOptions();
        String[] options = new String[3 + superOptions.length];
        int current = 0;
        if (this.getMarkovBlanketClassifier()) {
            options[current++] = "-mbc";
        }
        options[current++] = "-S";
        switch (this.m_nScoreType) {
            case 0: {
                options[current++] = "BAYES";
                break;
            }
            case 1: {
                options[current++] = "BDeu";
                break;
            }
            case 2: {
                options[current++] = "MDL";
                break;
            }
            case 3: {
                options[current++] = "ENTROPY";
                break;
            }
            case 4: {
                options[current++] = "AIC";
            }
        }
        int iOption = 0;
        while (iOption < superOptions.length) {
            options[current++] = superOptions[iOption];
            ++iOption;
        }
        while (current < options.length) {
            options[current++] = "";
        }
        return options;
    }

    public String scoreTypeTipText() {
        return "The score type determines the measure used to judge the quality of a network structure. It can be one of Bayes, BDeu, Minimum Description Length (MDL), Akaike Information Criterion (AIC), and Entropy.";
    }

    @Override
    public String markovBlanketClassifierTipText() {
        return super.markovBlanketClassifierTipText();
    }

    public String globalInfo() {
        return "The ScoreBasedSearchAlgorithm class supports Bayes net structure search algorithms that are based on maximizing scores (as opposed to for example conditional independence based search algorithms).";
    }

    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 5196 $");
    }
}

