/*
 * Decompiled with CFR 0.152.
 */
package eu.amidst.core.inference.messagepassing;

import com.google.common.base.Stopwatch;
import eu.amidst.core.distribution.ConditionalDistribution;
import eu.amidst.core.exponentialfamily.MomentParameters;
import eu.amidst.core.exponentialfamily.NaturalParameters;
import eu.amidst.core.inference.InferenceAlgorithm;
import eu.amidst.core.inference.InferenceEngine;
import eu.amidst.core.inference.Sampler;
import eu.amidst.core.inference.messagepassing.Message;
import eu.amidst.core.inference.messagepassing.MessagePassingAlgorithm;
import eu.amidst.core.inference.messagepassing.Node;
import eu.amidst.core.io.BayesianNetworkLoader;
import eu.amidst.core.models.BayesianNetwork;
import eu.amidst.core.models.DAG;
import eu.amidst.core.variables.Variable;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

/**
 * Variational Message Passing (VMP) on top of {@link MessagePassingAlgorithm}.
 *
 * <p>Messages carry {@link NaturalParameters}. Convergence is monitored through the evidence
 * lower bound (ELBO), accumulated node by node in {@link #computeELBO(Node)}. When
 * {@link #setTestELBO(boolean)} is enabled, an ELBO decrease between iterations raises an
 * {@link IllegalStateException}.
 *
 * <p>Fields such as {@code nodes}, {@code model}, {@code random}, {@code assignment},
 * {@code threshold}, {@code local_elbo} and {@code local_iter} are inherited from
 * {@link MessagePassingAlgorithm} (not visible here).
 */
public class VMP
extends MessagePassingAlgorithm<NaturalParameters>
implements InferenceAlgorithm,
Sampler {
    /** When {@code true}, {@link #testConvergence()} also checks that the ELBO never decreases. */
    boolean testELBO = false;

    /**
     * Returns the random number generator stored in the inherited {@code random} field.
     *
     * @return the generator used by this algorithm
     */
    public Random getRandom() {
        return this.random;
    }

    /**
     * Enables or disables the ELBO monotonicity check performed in {@link #testConvergence()}.
     *
     * @param testELBO {@code true} to throw when the ELBO decreases between iterations
     */
    public void setTestELBO(boolean testELBO) {
        this.testELBO = testELBO;
    }

    /**
     * Builds the message a node receives from its own parents.
     *
     * <p>The vector is the node's expected natural parameters given the parents' moment
     * parameters. The done flag comes from {@code node.messageDoneFromParents()}.
     *
     * @param node the node whose self-message is built
     * @return a message addressed to {@code node}
     */
    @Override
    public Message<NaturalParameters> newSelfMessage(Node node) {
        Map<Variable, MomentParameters> momentParents = node.getMomentParents();
        Message<NaturalParameters> message = new Message<NaturalParameters>(node);
        message.setVector(node.getPDist().getExpectedNaturalFromParents(momentParents));
        message.setDone(node.messageDoneFromParents());
        return message;
    }

    /**
     * Builds the message sent from {@code child} to one of its parents.
     *
     * <p>The vector is the expected natural parameters for {@code parent}. It is computed from
     * the child's distribution using the moment parameters of the child's parents (the
     * co-parents of {@code parent}).
     *
     * @param child  the sending node
     * @param parent the receiving parent node
     * @return a message addressed to {@code parent}
     */
    @Override
    public Message<NaturalParameters> newMessageToParent(Node child, Node parent) {
        Map<Variable, MomentParameters> momentChildCoParents = child.getMomentParents();
        Message<NaturalParameters> message = new Message<NaturalParameters>(parent);
        message.setVector(child.getPDist().getExpectedNaturalToParent(child.nodeParentToVariable(parent), momentChildCoParents));
        message.setDone(child.messageDoneToParent(parent.getMainVariable()));
        return message;
    }

    /**
     * Applies a combined message to a node.
     *
     * <p>The message vector becomes the natural parameters of the node's Q distribution, and the
     * message's done flag is copied onto the node.
     *
     * @param node    the node to update
     * @param message the combined incoming message
     */
    @Override
    public void updateCombinedMessage(Node node, Message<NaturalParameters> message) {
        node.getQDist().setNaturalParameters(message.getVector());
        node.setIsDone(message.isDone());
    }

    /**
     * Recomputes the ELBO and decides whether the iteration has converged.
     *
     * <p>Convergence holds when the relative ELBO change (in percent) is below
     * {@code threshold}, or when {@code local_iter} exceeds {@link #getMaxIter()}. The new ELBO
     * is stored in {@code local_elbo}.
     *
     * @return {@code true} if the algorithm should stop iterating
     * @throws IllegalStateException if the new or previous ELBO is NaN, or if
     *         {@code testELBO} is enabled and the per-node ELBO decreased by more than 0.01
     *         without convergence
     */
    @Override
    public boolean testConvergence() {
        double newelbo = this.computeLogProbabilityOfEvidence();

        // An unchanged ELBO is a 0 % change by definition. Computing it as |new-old|/|old|
        // would give 0.0/0.0 = NaN when both values are 0, and the threshold test could
        // then never succeed.
        double change = Math.abs(newelbo - this.local_elbo);
        double percentage = change == 0.0 ? 0.0 : 100.0 * change / Math.abs(this.local_elbo);

        boolean convergence = percentage < this.threshold || this.local_iter > this.getMaxIter();

        // A NaN ELBO is always fatal, regardless of testELBO. The freshly computed value is
        // checked so the NaN is reported now, rather than being stored and possibly returned
        // as "converged" once the iteration cap is reached.
        if (Double.isNaN(newelbo) || Double.isNaN(this.local_elbo)) {
            throw new IllegalStateException("The elbo is NaN at iter " + this.local_iter + ": " + percentage + ", " + this.local_elbo + ", " + newelbo);
        }

        // Optional monotonicity check on the per-node ELBO, with a 0.01 tolerance.
        // NOTE(review): `local_iter > -1` looks always true once iteration starts; confirm
        // against MessagePassingAlgorithm before removing.
        double nodeCount = (double) this.nodes.size();
        if (this.testELBO && !convergence
                && newelbo / nodeCount < this.local_elbo / nodeCount - 0.01
                && this.local_iter > -1) {
            throw new IllegalStateException("The elbo is not monotonically increasing at iter " + this.local_iter + ": " + percentage + ", " + this.local_elbo + ", " + newelbo);
        }

        this.local_elbo = newelbo;
        return convergence;
    }

    /**
     * Returns the ELBO, i.e. the sum of {@link #computeELBO(Node)} over all active nodes. It
     * serves as this algorithm's estimate of the log-probability of the evidence.
     *
     * @return the summed ELBO over active nodes ({@code 0.0} if none are active)
     */
    @Override
    public double computeLogProbabilityOfEvidence() {
        return this.nodes.stream().filter(node -> node.isActive()).mapToDouble(node -> this.computeELBO(node)).sum();
    }

    /**
     * Computes one node's contribution to the ELBO.
     *
     * <ul>
     *   <li>Unobserved node: {@code -KL(q || p)}. Here {@code p} is given by the node's
     *       expected natural parameters and expected log-normalizer under its parents' moments.</li>
     *   <li>Observed node: {@code <eta, s(x)> - E[A] + log h(x)}. These are the expected natural
     *       parameters dotted with the sufficient statistics, minus the expected log-normalizer,
     *       plus the log base measure at {@code assignment}.</li>
     * </ul>
     *
     * @param node the node to evaluate
     * @return the node's ELBO term
     * @throws IllegalStateException if the term is NaN, or exceeds 0.1 for an unobserved node
     *         (a negative KL term can only be positive through numerical error)
     */
    public double computeELBO(Node node) {
        Map<Variable, MomentParameters> momentParents = node.getMomentParents();
        double elbo = 0.0;
        if (!node.isObserved()) {
            elbo -= node.getQDist().kl(node.getPDist().getExpectedNaturalFromParents(momentParents), node.getPDist().getExpectedLogNormalizer(momentParents));
        } else {
            NaturalParameters expectedNatural = node.getPDist().getExpectedNaturalFromParents(momentParents);
            elbo += expectedNatural.dotProduct(node.getSufficientStatistics());
            elbo -= node.getPDist().getExpectedLogNormalizer(momentParents);
            elbo += node.getPDist().computeLogBaseMeasure(this.assignment);
        }
        if (elbo > 0.1 && !node.isObserved() || Double.isNaN(elbo)) {
            throw new IllegalStateException("NUMERICAL ERROR!!!!!!!!: " + node.getMainVariable().getName() + ", " + elbo);
        }
        return elbo;
    }

    /**
     * Builds a {@link BayesianNetwork} for sampling from the current posteriors.
     *
     * <p>The DAG is constructed from the model's variables alone (presumably with no links;
     * confirm against {@code DAG}). Each variable's distribution is its posterior from
     * {@code getPosterior}.
     *
     * @return a network whose distributions are the per-variable posteriors
     */
    @Override
    public BayesianNetwork getSamplingModel() {
        DAG dag = new DAG(this.model.getVariables());
        List<ConditionalDistribution> distributionList = this.model.getVariables().getListOfVariables().stream()
                .map(variable -> this.getPosterior(variable))
                .collect(Collectors.toList());
        return new BayesianNetwork(dag, distributionList);
    }

    /**
     * Small timing demo. It loads a network, prints basic statistics, runs posterior inference
     * for variable 0 several times, and prints each run's time and the average in milliseconds.
     *
     * @param arguments optional; {@code arguments[0]} overrides the default network path
     */
    public static void main(String[] arguments) throws IOException, ClassNotFoundException {
        String networkPath = arguments.length > 0 ? arguments[0] : "./networks/dataWeka/Munin1.bn";
        final int runs = 20;

        BayesianNetwork bn = BayesianNetworkLoader.loadFromFile(networkPath);
        System.out.println(bn.getNumberOfVars());
        System.out.println(bn.getDAG().getNumberOfLinks());
        System.out.println(bn.getConditionalDistributions().stream().mapToInt(p -> p.getNumberOfParameters()).max().getAsInt());

        VMP vmp = new VMP();
        InferenceEngine.setInferenceAlgorithm(vmp);
        Variable var = bn.getVariables().getVariableById(0);

        Object posterior = null;
        double totalMillis = 0.0;
        for (int i = 0; i < runs; ++i) {
            Stopwatch watch = Stopwatch.createStarted();
            posterior = InferenceEngine.getPosterior(var, bn);
            System.out.println(watch.stop());
            totalMillis += (double) watch.elapsed(TimeUnit.MILLISECONDS);
        }
        System.out.println(totalMillis / runs);
        System.out.println(posterior);
    }
}

