package org.integratedmodelling.riskwiz.learning.parameter.bayes;

import java.io.IOException;
import java.util.Iterator;
import java.util.Vector;
import org.integratedmodelling.riskwiz.bn.BNNode;
import org.integratedmodelling.riskwiz.bn.BeliefNetwork;
import org.integratedmodelling.riskwiz.domain.DiscreteDomain;
import org.integratedmodelling.riskwiz.learning.IParameterLearner;
import org.integratedmodelling.riskwiz.learning.bndata.IGraphData;
import org.integratedmodelling.riskwiz.learning.bndata.IGraphDataSource;
import org.integratedmodelling.riskwiz.learning.bndata.IGraphTable;
import org.integratedmodelling.riskwiz.learning.dtable.DistTable;
import org.integratedmodelling.riskwiz.pfunction.TabularCPD;
import org.nfunk.jep.ParseException;

/* loaded from: input_file:lib/riskwiz-1.0.0.jar:org/integratedmodelling/riskwiz/learning/parameter/bayes/BayesianLearner.class */
/**
 * Bayesian parameter learner for a {@link BeliefNetwork}.
 *
 * <p>When initialized, every node whose function is a {@link TabularCPD} gets a working
 * {@link DistTable} stored in its {@code "distribution"} property. The node's function is
 * replaced by {@link DistTable#createCPF()}. Learning increments entries of that table for
 * each complete observation and then rebuilds the node's function from it.
 * {@link #getFinalResult()} removes the working tables again.
 *
 * <p>The network passed to {@link #initialize(BeliefNetwork)} or
 * {@link #initializeWithPriors(BeliefNetwork, int)} is modified in place.
 */
public class BayesianLearner implements IParameterLearner {

    /** Network being learned; {@code null} until one of the initialize methods is called. */
    BeliefNetwork bnet;

    /**
     * Binds the learner to {@code beliefNetwork} and resets every {@link TabularCPD} node to a
     * uniform working distribution.
     *
     * @param beliefNetwork network whose parameters will be learned (modified in place)
     */
    @Override
    public void initialize(BeliefNetwork beliefNetwork) {
        this.bnet = beliefNetwork;
        initializeUniformDistributions();
    }

    /**
     * Removes the working {@code "distribution"} tables from all probabilistic nodes and returns
     * the network. After this call, further learning on the same network fails because the tables
     * are gone.
     *
     * @return the learned network (the same instance passed at initialization)
     * @throws IllegalStateException if the learner has not been initialized
     */
    @Override
    public BeliefNetwork getFinalResult() {
        requireInitialized();
        for (BNNode node : this.bnet.vertexSet()) {
            if (node.isProbabilistic()) {
                node.clearProperty("distribution");
            }
        }
        return this.bnet;
    }

    /**
     * Learns from every row of an in-memory table. Only probabilistic, non-expression nodes for
     * which the table has a complete projection are updated.
     *
     * @param iGraphTable data table; it is connected to the network before use
     * @throws IllegalStateException if the learner has not been initialized, or if an eligible
     *     node has no working distribution table
     */
    @Override
    public void learnFromTable(IGraphTable iGraphTable) {
        requireInitialized();
        iGraphTable.connect(this.bnet);
        for (BNNode node : this.bnet.vertexSet()) {
            if (node.isProbabilistic() && !node.isExpression() && iGraphTable.hasCompleteProjection(node)) {
                learnNodeParameters(node, iGraphTable, iGraphTable.getValues());
            }
        }
    }

    /**
     * Learns from a streaming data source, one batch of rows at a time, until the source reports
     * no more values. The source is closed when this method exits, including when learning fails
     * part-way through.
     *
     * @param iGraphDataSource data source; it is connected to the network before use
     * @throws IOException if reading from or closing the source fails
     * @throws IllegalStateException if the learner has not been initialized, or if an eligible
     *     node has no working distribution table
     */
    @Override
    public void learnFromDataSource(IGraphDataSource iGraphDataSource) throws IOException {
        requireInitialized();
        iGraphDataSource.connect(this.bnet);
        try {
            while (iGraphDataSource.hasNextValues()) {
                Vector<Vector<String>> batch = iGraphDataSource.getNextValues();
                for (BNNode node : this.bnet.vertexSet()) {
                    if (node.isProbabilistic() && !node.isExpression()
                            && iGraphDataSource.hasCompleteProjection(node)) {
                        learnNodeParameters(node, iGraphDataSource, batch);
                    }
                }
                iGraphDataSource.readNextValues();
            }
        } finally {
            // Previously close() was skipped whenever any read/learn step threw, leaking the source.
            iGraphDataSource.close();
        }
    }

    /**
     * Updates one node's working distribution table from a set of data rows, then rebuilds the
     * node's function from the table.
     *
     * <p>Each row is converted into an index query via {@link IGraphData#getQuery}. Rows whose
     * query contains {@code -1} (presumably a missing/unknown value — confirm against
     * {@code IGraphData}) are skipped.
     *
     * @param bNNode node to update; must carry a {@code "distribution"} {@link DistTable}
     * @param iGraphData data view used to translate rows into queries
     * @param rows data rows to learn from
     * @throws IllegalStateException if the node has no working distribution table
     */
    protected void learnNodeParameters(BNNode bNNode, IGraphData iGraphData, Vector<Vector<String>> rows) {
        DistTable distTable = (DistTable) bNNode.getProperty("distribution");
        if (distTable == null) {
            // Previously surfaced as a bare NullPointerException further down.
            throw new IllegalStateException("Node " + bNNode
                    + " has no \"distribution\" table: only nodes with a TabularCPD at initialization"
                    + " can be learned, and getFinalResult() removes the tables");
        }
        for (Vector<String> row : rows) {
            int[] query = iGraphData.getQuery(bNNode, row);
            if (isCompleteQuery(query)) {
                // query[0] is passed as the entry to increment; presumably the node's own state
                // index, with the remaining entries selecting the parent configuration — confirm
                // against DistTable.getValue1.
                distTable.getValue1(query).increment(query[0]);
            }
        }
        bNNode.setFunction(distTable.createCPF());
    }

    /**
     * Reports whether a query contains no {@code -1} entries.
     *
     * @param iArr query indices
     * @return {@code true} if no element equals {@code -1}
     */
    public boolean isCompleteQuery(int[] iArr) {
        for (int index : iArr) {
            if (index == -1) {
                return false;
            }
        }
        return true;
    }

    /**
     * Binds the learner to {@code beliefNetwork} and initializes every {@link TabularCPD} node's
     * working table from its current CPD.
     *
     * @param beliefNetwork network whose parameters will be learned (modified in place)
     * @param priorParameter passed unchanged to {@code DistTable(TabularCPD, int)}; its meaning is
     *     defined there (presumably a prior weight / equivalent sample size — confirm)
     * @throws ParseException propagated from building the working tables
     */
    public void initializeWithPriors(BeliefNetwork beliefNetwork, int priorParameter) throws ParseException {
        this.bnet = beliefNetwork;
        initializeDistributions(priorParameter);
    }

    /** Fails fast when a learning method is used before the learner was bound to a network. */
    private void requireInitialized() {
        if (this.bnet == null) {
            throw new IllegalStateException(
                    "BayesianLearner is not initialized; call initialize(...) or initializeWithPriors(...) first");
        }
    }

    /** Gives every TabularCPD node a uniform working table and a matching function. */
    private void initializeUniformDistributions() {
        for (BNNode node : this.bnet.vertexSet()) {
            if (node.getFunction() instanceof TabularCPD) {
                TabularCPD cpd = (TabularCPD) node.getFunction();
                DistTable distTable = new DistTable((DiscreteDomain) cpd.getDomain(), cpd.getParentsDomains());
                distTable.setUniformDistributions();
                node.setProperty("distribution", distTable);
                node.setFunction(distTable.createCPF());
            }
        }
    }

    /** Gives every TabularCPD node a working table derived from its current CPD. */
    private void initializeDistributions(int priorParameter) throws ParseException {
        for (BNNode node : this.bnet.vertexSet()) {
            if (node.getFunction() instanceof TabularCPD) {
                DistTable distTable = new DistTable((TabularCPD) node.getFunction(), priorParameter);
                node.setProperty("distribution", distTable);
                node.setFunction(distTable.createCPF());
            }
        }
    }
}
