/*
 * Decompiled with CFR 0.152.
 */
package eu.amidst.core.learning.parametric.bayesian;

import eu.amidst.core.datastream.DataInstance;
import eu.amidst.core.datastream.DataOnMemory;
import eu.amidst.core.datastream.DataStream;
import eu.amidst.core.learning.parametric.bayesian.BayesianParameterLearningAlgorithm;
import eu.amidst.core.learning.parametric.bayesian.SVB;
import eu.amidst.core.learning.parametric.bayesian.utils.DataPosterior;
import eu.amidst.core.learning.parametric.bayesian.utils.PlateuStructure;
import eu.amidst.core.models.BayesianNetwork;
import eu.amidst.core.models.DAG;
import eu.amidst.core.utils.CompoundVector;
import eu.amidst.core.utils.Serialization;
import eu.amidst.core.variables.Variable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.stream.IntStream;

public class ParallelSVB
implements BayesianParameterLearningAlgorithm {
    DataStream<DataInstance> data;
    SVB[] svbEngines;
    DAG dag;
    int nCores = -1;
    SVB SVBEngine = new SVB();
    double logLikelihood;
    int seed = 0;
    boolean activateOutput = false;

    @Override
    public void setSeed(int seed_) {
        this.seed = seed_;
    }

    public void setNCores(int nCores) {
        this.nCores = nCores;
    }

    public SVB getSVBEngine() {
        return this.SVBEngine;
    }

    public void setSVBEngine(SVB SVBEngine) {
        this.SVBEngine = SVBEngine;
    }

    @Override
    public void initLearning() {
        if (this.nCores == -1) {
            this.nCores = Runtime.getRuntime().availableProcessors();
        }
        this.SVBEngine.setDAG(this.dag);
        this.SVBEngine.setSeed(this.seed);
        this.SVBEngine.initLearning();
        this.svbEngines = new SVB[this.nCores];
        for (int i = 0; i < this.nCores; ++i) {
            this.svbEngines[i] = Serialization.deepCopy(this.SVBEngine);
            this.svbEngines[i].initLearning();
        }
        this.SVBEngine = this.svbEngines[0];
    }

    @Override
    public double updateModel(DataOnMemory<DataInstance> batch) {
        throw new UnsupportedOperationException("Use standard StreamingSVB for sequential updating");
    }

    @Override
    public void setDataStream(DataStream<DataInstance> data_) {
        this.data = data_;
    }

    @Override
    public double getLogMarginalProbability() {
        return this.logLikelihood;
    }

    @Override
    public void setWindowsSize(int batchSize_) {
        this.SVBEngine.setWindowsSize(batchSize_);
    }

    @Override
    public int getWindowsSize() {
        return this.SVBEngine.getWindowsSize();
    }

    @Override
    public void runLearning() {
        this.initLearning();
        Iterator<DataOnMemory<DataInstance>> iterator = this.data.iterableOverBatches(this.SVBEngine.getWindowsSize()).iterator();
        this.logLikelihood = 0.0;
        while (iterator.hasNext()) {
            CompoundVector posterior = this.svbEngines[0].getNaturalParameterPrior();
            ArrayList<DataOnMemory<DataInstance>> dataBatches = new ArrayList<DataOnMemory<DataInstance>>();
            for (int cont = 0; iterator.hasNext() && cont < this.nCores; ++cont) {
                dataBatches.add(iterator.next());
            }
            SVB.BatchOutput out = IntStream.range(0, dataBatches.size()).parallel().mapToObj(i -> this.svbEngines[i].updateModelOnBatchParallel((DataOnMemory)dataBatches.get(i))).reduce(SVB.BatchOutput::sumNonStateless).get();
            this.logLikelihood += out.getElbo();
            posterior.sum(out.getVector());
            for (int i2 = 0; i2 < this.nCores; ++i2) {
                this.svbEngines[i2].updateNaturalParameterPrior(posterior);
            }
        }
    }

    @Override
    public double updateModel(DataStream<DataInstance> data) {
        this.logLikelihood = Double.NEGATIVE_INFINITY;
        boolean convergence = false;
        while (!convergence) {
            CompoundVector posterior = this.svbEngines[0].getNaturalParameterPrior();
            Iterator<DataOnMemory<DataInstance>> iterator = data.iterableOverBatches(this.SVBEngine.getWindowsSize()).iterator();
            double local_loglikelihood = 0.0;
            while (iterator.hasNext()) {
                ArrayList<DataOnMemory<DataInstance>> dataBatches = new ArrayList<DataOnMemory<DataInstance>>();
                for (int cont = 0; iterator.hasNext() && cont < this.nCores; ++cont) {
                    dataBatches.add(iterator.next());
                }
                SVB.BatchOutput out = IntStream.range(0, dataBatches.size()).parallel().mapToObj(i -> this.svbEngines[i].updateModelOnBatchParallel((DataOnMemory)dataBatches.get(i))).reduce(SVB.BatchOutput::sumNonStateless).get();
                posterior.sum(out.getVector());
                local_loglikelihood += out.getElbo();
            }
            for (int i2 = 0; i2 < this.nCores; ++i2) {
                this.svbEngines[i2].updateNaturalParameterPrior(posterior);
            }
            if (Math.abs(this.logLikelihood - local_loglikelihood) / (double)this.SVBEngine.getPlateuStructure().getNumberOfReplications() < 0.01) {
                convergence = true;
                continue;
            }
            if ((this.logLikelihood - local_loglikelihood) / (double)this.SVBEngine.getPlateuStructure().getNumberOfReplications() > 0.01) {
                throw new IllegalStateException("Non increasing log likelihood: " + local_loglikelihood + " , " + this.logLikelihood);
            }
            this.logLikelihood = local_loglikelihood;
        }
        for (SVB svbEngine : this.svbEngines) {
            svbEngine.applyTransition();
        }
        return this.logLikelihood;
    }

    @Override
    public List<DataPosterior> computePosterior(DataOnMemory<DataInstance> batch) {
        throw new UnsupportedOperationException("Method not implemented");
    }

    @Override
    public List<DataPosterior> computePosterior(DataOnMemory<DataInstance> batch, List<Variable> latentVariables) {
        throw new UnsupportedOperationException("Method not implemented");
    }

    @Override
    public double predictedLogLikelihood(DataOnMemory<DataInstance> batch) {
        return this.SVBEngine.predictedLogLikelihood(batch);
    }

    @Override
    public void setPlateuStructure(PlateuStructure plateuStructure) {
        this.SVBEngine.setPlateuStructure(plateuStructure);
    }

    @Override
    public void setDAG(DAG dag_) {
        this.dag = dag_;
    }

    @Override
    public BayesianNetwork getLearntBayesianNetwork() {
        return this.svbEngines[0].getLearntBayesianNetwork();
    }

    @Override
    public void setParallelMode(boolean parallelMode) {
    }

    @Override
    public void setOutput(boolean activateOutput_) {
        this.activateOutput = activateOutput_;
    }
}

