/*
 * Decompiled with CFR 0.152.
 */
package eu.amidst.core.learning.parametric.bayesian;

import eu.amidst.core.datastream.DataInstance;
import eu.amidst.core.datastream.DataOnMemory;
import eu.amidst.core.exponentialfamily.EF_TruncatedExponential;
import eu.amidst.core.exponentialfamily.MomentParameters;
import eu.amidst.core.inference.messagepassing.Node;
import eu.amidst.core.inference.messagepassing.VMP;
import eu.amidst.core.io.BayesianNetworkLoader;
import eu.amidst.core.learning.parametric.bayesian.SVB;
import eu.amidst.core.models.BayesianNetwork;
import eu.amidst.core.utils.BayesianNetworkSampler;
import eu.amidst.core.utils.CompoundVector;
import eu.amidst.core.utils.Serialization;
import eu.amidst.core.variables.DistributionType;
import eu.amidst.core.variables.Variable;
import eu.amidst.core.variables.Variables;
import java.io.IOException;
import java.util.Map;
import java.util.Random;
import java.util.stream.Collectors;

public class MultiDriftSVB
extends SVB {
    EF_TruncatedExponential ef_TExpP;
    EF_TruncatedExponential ef_TExpQ;
    Variable truncatedExpVar;
    boolean firstBatch = true;
    CompoundVector posteriorT_1 = null;
    CompoundVector prior = null;
    double delta = 0.1;

    public double getDelta() {
        return this.delta;
    }

    public void setDelta(double delta) {
        this.delta = delta;
    }

    @Override
    public void initLearning() {
        super.initLearning();
        this.truncatedExpVar = new Variables().newTruncatedExponential("TruncatedExponentialVar");
        this.ef_TExpP = (EF_TruncatedExponential)((DistributionType)this.truncatedExpVar.getDistributionType()).newEFUnivariateDistribution(this.getDelta());
        this.ef_TExpQ = (EF_TruncatedExponential)((DistributionType)this.truncatedExpVar.getDistributionType()).newEFUnivariateDistribution(this.getDelta());
        this.firstBatch = true;
        this.prior = this.plateuStructure.getPlateauNaturalParameterPrior();
    }

    @Override
    public BayesianNetwork getLearntBayesianNetwork() {
        CompoundVector prior = this.plateuStructure.getPlateauNaturalParameterPrior();
        this.updateNaturalParameterPrior(this.plateuStructure.getPlateauNaturalParameterPosterior());
        BayesianNetwork learntBN = new BayesianNetwork(this.dag, this.ef_extendedBN.toConditionalDistribution());
        this.updateNaturalParameterPrior(prior);
        return learntBN;
    }

    public double updateModelWithConceptDrift(DataOnMemory<DataInstance> batch) {
        this.plateuStructure.setEvidence(batch.getList());
        if (this.firstBatch) {
            this.firstBatch = false;
            this.plateuStructure.runInference();
            this.posteriorT_1 = this.plateuStructure.getPlateauNaturalParameterPosterior();
            return this.plateuStructure.getLogProbabilityOfEvidence();
        }
        this.ef_TExpQ = (EF_TruncatedExponential)((DistributionType)this.truncatedExpVar.getDistributionType()).newEFUnivariateDistribution(this.getDelta());
        boolean convergence = false;
        double elbo = Double.NaN;
        for (double niter = 0.0; !convergence && niter < 100.0; niter += 1.0) {
            double lambda = this.ef_TExpQ.getMomentParameters().get(0);
            CompoundVector newPrior = Serialization.deepCopy(this.prior);
            newPrior.multiplyBy(1.0 - lambda);
            CompoundVector newPosterior = Serialization.deepCopy(this.posteriorT_1);
            newPosterior.multiplyBy(lambda);
            newPrior.sum(newPosterior);
            this.plateuStructure.updateNaturalParameterPrior(newPrior);
            this.plateuStructure.runInference();
            double newELBO = this.plateuStructure.getLogProbabilityOfEvidence();
            double[] kl_q_p0_vals = new double[(int)this.plateuStructure.getNonReplictedNodes().count()];
            double[] kl_q_pt_1_vals = new double[(int)this.plateuStructure.getNonReplictedNodes().count()];
            double kl_q_p0 = 0.0;
            int count = 0;
            this.plateuStructure.updateNaturalParameterPrior(this.prior);
            for (Node node : this.plateuStructure.getNonReplictedNodes().collect(Collectors.toList())) {
                Map<Variable, MomentParameters> momentParents = node.getMomentParents();
                kl_q_p0_vals[count] = node.getQDist().kl(node.getPDist().getExpectedNaturalFromParents(momentParents), node.getPDist().getExpectedLogNormalizer(momentParents));
                kl_q_p0 += kl_q_p0_vals[count];
                ++count;
            }
            double kl_q_pt_1 = 0.0;
            count = 0;
            this.plateuStructure.updateNaturalParameterPrior(this.posteriorT_1);
            for (Node node : this.plateuStructure.getNonReplictedNodes().collect(Collectors.toList())) {
                Map<Variable, MomentParameters> momentParents = node.getMomentParents();
                kl_q_pt_1_vals[count] = node.getQDist().kl(node.getPDist().getExpectedNaturalFromParents(momentParents), node.getPDist().getExpectedLogNormalizer(momentParents));
                kl_q_pt_1 += kl_q_pt_1_vals[count];
                ++count;
            }
            this.ef_TExpQ.getNaturalParameters().set(0, -kl_q_pt_1 + kl_q_p0 + this.ef_TExpP.getNaturalParameters().get(0));
            this.ef_TExpQ.fixNumericalInstability();
            this.ef_TExpQ.updateMomentFromNaturalParameters();
            newELBO -= this.ef_TExpQ.kl(this.ef_TExpP.getNaturalParameters(), this.ef_TExpP.computeLogNormalizer());
            if (!Double.isNaN(elbo) && newELBO < elbo) {
                new IllegalStateException("Non increasing lower bound");
            }
            double percentageIncrease = 100.0 * Math.abs((newELBO - elbo) / elbo);
            System.out.print("Delta: " + niter + ", " + newELBO + ", " + elbo + ", " + percentageIncrease + ", " + lambda + ", " + (-kl_q_pt_1 + kl_q_p0));
            for (int i = 0; i < kl_q_p0_vals.length; ++i) {
                System.out.print(", " + (-kl_q_pt_1_vals[i] + kl_q_p0_vals[i]));
            }
            System.out.println();
            if (!Double.isNaN(elbo) && percentageIncrease < this.plateuStructure.getVMP().getThreshold()) {
                convergence = true;
            }
            elbo = newELBO;
        }
        this.posteriorT_1 = this.plateuStructure.getPlateauNaturalParameterPosterior();
        return elbo;
    }

    public double getLambdaValue() {
        return this.ef_TExpQ.getMomentParameters().get(0);
    }

    public static void main(String[] args) throws IOException, ClassNotFoundException {
        BayesianNetwork oneNormalVarBN = BayesianNetworkLoader.loadFromFile("./networks/simulated/Normal.bn");
        System.out.println(oneNormalVarBN);
        int batchSize = 1000;
        MultiDriftSVB svb = new MultiDriftSVB();
        svb.setWindowsSize(batchSize);
        svb.setSeed(0);
        VMP vmp = svb.getPlateuStructure().getVMP();
        vmp.setOutput(false);
        vmp.setTestELBO(true);
        vmp.setMaxIter(1000);
        vmp.setThreshold(1.0E-4);
        svb.setDAG(oneNormalVarBN.getDAG());
        svb.initLearning();
        double pred = 0.0;
        for (int i = 0; i < 10; ++i) {
            if (i % 3 == 0) {
                oneNormalVarBN.randomInitialization(new Random(i));
                System.out.println(oneNormalVarBN);
            }
            BayesianNetworkSampler sampler = new BayesianNetworkSampler(oneNormalVarBN);
            sampler.setSeed(i);
            DataOnMemory<DataInstance> batch = sampler.sampleToDataStream(batchSize).toDataOnMemory();
            if (i > 0) {
                pred += svb.predictedLogLikelihood(batch);
            }
            svb.updateModelWithConceptDrift(batch);
            System.out.println(svb.getLogMarginalProbability());
            System.out.println(svb.getLearntBayesianNetwork());
        }
        System.out.println(pred);
    }
}

