package org.integratedmodelling.riskwiz.influence.jensen;

import java.util.HashSet;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.TreeSet;
import java.util.Vector;
import org.integratedmodelling.riskwiz.bn.BNNode;
import org.integratedmodelling.riskwiz.bn.BeliefNetwork;
import org.integratedmodelling.riskwiz.domain.DiscreteDomain;
import org.integratedmodelling.riskwiz.domain.DomainFactory;
import org.integratedmodelling.riskwiz.graph.RiskUndirectedGraph;
import org.integratedmodelling.riskwiz.influence.JTPotential;
import org.integratedmodelling.riskwiz.influence.jensen.SJTVertex;
import org.integratedmodelling.riskwiz.jtree.IJoinTreeDecision;
import org.integratedmodelling.riskwiz.pt.CPT;
import org.integratedmodelling.riskwiz.pt.map.DomainMap2;
import org.integratedmodelling.riskwiz.pt.map.FMarginalizationMap;
import org.integratedmodelling.riskwiz.pt.map.FastMap2;
import org.jgrapht.traverse.BreadthFirstIterator;

/* loaded from: input_file:lib/riskwiz-1.0.0.jar:org/integratedmodelling/riskwiz/influence/jensen/StrongJoinTree.class */
/**
 * Strong join tree built over an influence diagram (a {@link BeliefNetwork} whose nodes are
 * nature, decision or utility nodes). Evidence is propagated by collecting messages from the
 * leaves towards a root cluster, after which the root potential is eliminated in a fixed order.
 *
 * <p>Expected call order, derived from the state each method reads:
 * <ol>
 *   <li>{@link #setRoot(SJTVertex)};</li>
 *   <li>{@link #initializeStructiure()} once;</li>
 *   <li>then, per query, {@link #initialize()}, {@link #initializeLikelihoods()} and
 *       {@link #propagateEvidence()}.</li>
 * </ol>
 * Policies can then be read with {@link #getPolicy(DiscreteDomain)}.
 *
 * <p>NOTE(review): the misspellings "Vacious", "Structiure" and "liklihood" are part of the
 * public/protected API. They are kept for compatibility with callers and subclasses.
 */
public class StrongJoinTree extends RiskUndirectedGraph<SJTVertex, SJTEdge> implements IJoinTreeDecision<SJTVertex> {
    /** Influence diagram this tree belongs to; {@code null} when built with the no-arg constructor. */
    protected BeliefNetwork bn;
    /** Cluster from which {@link #collectEvidence(SJTVertex)} starts and whose potential is eliminated last. */
    protected SJTVertex root;
    /** Cluster assignment and precomputed index maps per nature/utility node; set by {@link #initializeStructiure()}. */
    protected Hashtable<BNNode, ClusterBundle> clusterHash;
    /** Policy tables keyed by decision domain; recreated on every {@link #propagateEvidence()}. */
    protected Hashtable<DiscreteDomain, CPT> policyHash;
    /**
     * Ordered elimination steps for the root potential. Each entry is either an
     * {@link FMarginalizationMap} covering a run of nature domains, or the domain of a single
     * decision node.
     */
    protected Vector<Object> rootMmaps;

    /**
     * Where a single nature or utility node's table lives in the tree, together with the index
     * maps used to combine that table with the cluster potential.
     */
    public class ClusterBundle {
        /** Cluster the node's table was assigned to. */
        protected SJTVertex jtcluster;
        /** Maps the node's CPT/utility table into {@code jtcluster}'s potential (used by {@link #initialize()}). */
        protected DomainMap2 fopmap;
        /** Maps evidence on the node's own domain into the cluster potential; {@code null} for utility nodes. */
        protected FastMap2 liklihoodfmap;
        /** Marginalization map over the node's domain (nature) or its parents' domains (utility); not read in this class. */
        protected FMarginalizationMap mfmap;

        public ClusterBundle(SJTVertex cluster, DomainMap2 tableMap, FMarginalizationMap marginalizationMap, FastMap2 likelihoodMap) {
            this.fopmap = tableMap;
            this.jtcluster = cluster;
            this.mfmap = marginalizationMap;
            this.liklihoodfmap = likelihoodMap;
        }
    }

    /**
     * Creates an empty tree bound to the given influence diagram.
     *
     * @param beliefNetwork diagram whose nodes will be assigned to this tree's clusters
     */
    public StrongJoinTree(BeliefNetwork beliefNetwork) {
        super(SJTEdge.class);
        this.bn = beliefNetwork;
    }

    /** Creates an empty tree with no associated diagram; {@link #bn} must be assigned before {@link #initializeStructiure()}. */
    public StrongJoinTree() {
        super(SJTEdge.class);
    }

    /** Resets every cluster and every separator potential to its vacuous state. */
    public void setVaciousPotentials() {
        for (SJTVertex cluster : vertexSet()) {
            cluster.setVacious();
        }
        for (SJTEdge separator : edgeSet()) {
            separator.setVacious();
        }
    }

    /**
     * Precomputes everything propagation needs and that does not depend on the current tables or
     * evidence. It performs these steps:
     * <ul>
     *   <li>assigns each nature and utility node to a cluster containing its family;</li>
     *   <li>builds the per-edge fast maps of every cluster;</li>
     *   <li>builds the root elimination sequence.</li>
     * </ul>
     *
     * @throws IllegalStateException if no diagram is set, the root is not set, or some node's
     *     family is not contained in any cluster
     */
    public void initializeStructiure() {
        if (this.bn == null) {
            throw new IllegalStateException("no belief network is associated with this strong join tree");
        }
        setVaciousPotentials();
        this.clusterHash = new Hashtable<>();
        for (BNNode node : this.bn.vertexSet()) {
            // Decision nodes have no table of their own to place, so no bundle is needed for them.
            if (node.isNature() || node.isUtility()) {
                this.clusterHash.put(node, createClusterBundle(node));
            }
        }
        for (SJTVertex cluster : vertexSet()) {
            for (SJTEdge separator : edgesOf(cluster)) {
                cluster.createFastMaps(separator);
            }
        }
        createRootMarginalizationFastMap();
    }

    /** Finds the cluster for a nature or utility node and precomputes its index maps. */
    private ClusterBundle createClusterBundle(BNNode node) {
        SJTVertex cluster = assignParentCluster(node);
        if (cluster == null) {
            throw new IllegalStateException("no cluster of the strong join tree contains the family of node " + node.getName());
        }
        JTPotential potential = cluster.getPotential();
        if (node.isNature()) {
            return new ClusterBundle(cluster,
                    potential.createSubtableFastMap(node.getDiscreteCPT()),
                    potential.createFMarginalizationMap(node.getDiscretizedDomain()),
                    potential.createSubtableFastMap(DomainFactory.createDomainProduct(node.getDiscretizedDomain())));
        }
        // Utility node: its table ranges over the parents' domains only, and it takes no evidence.
        Vector<DiscreteDomain> parentsDomains = node.getDiscreteCPT().getParentsDomains();
        return new ClusterBundle(cluster,
                potential.createSubtableFastMap(parentsDomains),
                potential.createFMarginalizationMap(parentsDomains),
                null);
    }

    /** Returns the bundle of {@code node}, failing with a clear message if the structure is missing. */
    private ClusterBundle bundleFor(BNNode node) {
        if (this.clusterHash == null) {
            throw new IllegalStateException("initializeStructiure() must be called before the tree can be initialized");
        }
        ClusterBundle bundle = this.clusterHash.get(node);
        if (bundle == null) {
            throw new IllegalStateException("node " + node.getName() + " has no cluster assignment in the strong join tree");
        }
        return bundle;
    }

    /**
     * Resets all potentials, then loads the diagram's tables into their clusters. Nature CPTs are
     * multiplied in; utility tables are added in.
     *
     * @throws IllegalStateException if {@link #initializeStructiure()} has not been called
     */
    public void initialize() {
        setVaciousPotentials();
        for (BNNode node : this.bn.vertexSet()) {
            if (node.isNature()) {
                ClusterBundle bundle = bundleFor(node);
                bundle.jtcluster.getPotential().multiplyByProbabilitySubtable(node.getDiscreteCPT(), bundle.fopmap);
            } else if (node.isUtility()) {
                ClusterBundle bundle = bundleFor(node);
                bundle.jtcluster.getPotential().addUtilitySubtable(node.getDiscreteCPT(), bundle.fopmap);
            }
        }
    }

    /** Multiplies the evidence of every probabilistic node into its cluster's potential. */
    @Override // org.integratedmodelling.riskwiz.jtree.IJoinTreeDecision
    public void initializeLikelihoods() {
        for (BNNode node : this.bn.vertexSet()) {
            if (node.isProbabilistic()) {
                initializeLikelihood(node);
            }
        }
    }

    /**
     * Multiplies the evidence of {@code bNNode} into its cluster's potential. Does nothing when
     * the node has no evidence.
     *
     * @throws IllegalStateException if the node has evidence but no cluster assignment
     */
    public void initializeLikelihood(BNNode bNNode) {
        if (!bNNode.hasEvidence()) {
            return;
        }
        ClusterBundle bundle = bundleFor(bNNode);
        bundle.jtcluster.getPotential().multiplyByProbabilitySubtable(bNNode.getEvidence(), bundle.liklihoodfmap);
    }

    /**
     * Depth-first collect pass. It marks {@code sJTVertex}, recurses into each unmarked
     * neighbour, and then passes that neighbour's message back to {@code sJTVertex}. Clusters must
     * be unmarked beforehand (see {@link #unmarkAll()}).
     */
    public void collectEvidence(SJTVertex sJTVertex) {
        sJTVertex.isMarked = true;
        for (SJTVertex neighbour : getNeighbors(sJTVertex)) {
            if (!neighbour.isMarked) {
                collectEvidence(neighbour);
                passMessage(neighbour, sJTVertex);
            }
        }
    }

    /**
     * Collects all messages into the root, then eliminates the root potential along
     * {@link #rootMmaps}. A fresh policy table is passed through the whole run; presumably
     * {@link JTPotential#marginalizeDomainsSequence} records decision policies in it (TODO confirm).
     *
     * @throws IllegalStateException if the root is not set or {@link #initializeStructiure()} has not been called
     */
    @Override // org.integratedmodelling.riskwiz.jtree.IJoinTreeDecision
    public void propagateEvidence() {
        if (this.root == null) {
            throw new IllegalStateException("the root cluster of the strong join tree has not been set");
        }
        if (this.rootMmaps == null) {
            throw new IllegalStateException("initializeStructiure() must be called before propagating evidence");
        }
        this.policyHash = new Hashtable<>();
        unmarkAll();
        collectEvidence(this.root);
        JTPotential.marginalizeDomainsSequence(this.root.getPotential(), this.rootMmaps, this.policyHash);
    }

    /**
     * Sends a message from {@code sJTVertex} to the adjacent cluster {@code sJTVertex2}. The
     * source potential is marginalized onto the separator, stored on the edge, and multiplied
     * into the target potential.
     */
    public void passMessage(SJTVertex sJTVertex, SJTVertex sJTVertex2) {
        SJTEdge separator = getEdge(sJTVertex, sJTVertex2);
        JTPotential message = JTPotential.marginalizeDomainsSequence(
                sJTVertex.getPotential(), sJTVertex.getMarginalizationFastMap(separator), this.policyHash);
        separator.setPotential(message);
        sJTVertex2.getPotential().multiplyBySubtableFast(message, sJTVertex2.getSubtableOpFastMap(separator));
    }

    /** Clears the visited flag of every cluster before a new collect pass. */
    public void unmarkAll() {
        for (SJTVertex cluster : vertexSet()) {
            cluster.isMarked = false;
        }
    }

    /**
     * Picks the cluster that should hold {@code bNNode}'s table.
     * <ul>
     *   <li>Decision nodes go to the first cluster containing them in breadth-first order from
     *       the root.</li>
     *   <li>Any other node goes to the first cluster (in {@code vertexSet()} order) that contains
     *       its parents and, unless it is a utility node, the node itself.</li>
     * </ul>
     *
     * @return the chosen cluster, or {@code null} if none qualifies
     */
    public SJTVertex assignParentCluster(BNNode bNNode) {
        if (bNNode.isDecision()) {
            return assignParentClusterDecision(bNNode);
        }
        HashSet<BNNode> family = new HashSet<>();
        // Utility nodes are not clique members themselves; only their parents must be covered.
        if (!bNNode.isUtility()) {
            family.add(bNNode);
        }
        family.addAll(this.bn.getParents(bNNode));
        for (SJTVertex cluster : vertexSet()) {
            if (cluster.containsAll(family)) {
                return cluster;
            }
        }
        return null;
    }

    /**
     * Returns the cluster nearest to the root (breadth-first) that contains {@code bNNode}, or
     * {@code null} if none does.
     */
    protected SJTVertex assignParentClusterDecision(BNNode bNNode) {
        BreadthFirstIterator<SJTVertex, SJTEdge> fromRoot = new BreadthFirstIterator<>(this, getRoot());
        while (fromRoot.hasNext()) {
            SJTVertex cluster = fromRoot.next();
            if (cluster.contains(bNNode)) {
                return cluster;
            }
        }
        return null;
    }

    @Override // org.integratedmodelling.riskwiz.jtree.IJoinTreeDecision
    public SJTVertex getRoot() {
        return this.root;
    }

    public void setRoot(SJTVertex sJTVertex) {
        this.root = sJTVertex;
    }

    public BeliefNetwork getBeliefNetwork() {
        return this.bn;
    }

    /**
     * Returns the policy computed for the given decision domain by the last
     * {@link #propagateEvidence()}.
     *
     * @return the policy, or {@code null} if none was recorded or evidence has not been propagated yet
     */
    @Override // org.integratedmodelling.riskwiz.jtree.IJoinTreeDecision
    public CPT getPolicy(DiscreteDomain discreteDomain) {
        Hashtable<DiscreteDomain, CPT> policies = this.policyHash;
        return policies == null ? null : policies.get(discreteDomain);
    }

    /**
     * Builds {@link #rootMmaps}, the elimination sequence for the root potential. The root clique
     * is walked in {@link SJTVertex.EliminationOrder}. Each maximal run of nature nodes becomes one
     * {@link FMarginalizationMap} over the domains still remaining. Each decision node contributes
     * its own domain as a separate step.
     *
     * @throws IllegalStateException if the root is not set, or the root clique contains a node that
     *     is neither a nature nor a decision node (previously this looped forever)
     */
    public void createRootMarginalizationFastMap() {
        if (this.root == null) {
            throw new IllegalStateException("the root cluster of the strong join tree has not been set");
        }
        this.rootMmaps = new Vector<>();
        TreeSet<BNNode> pending = new TreeSet<>(new SJTVertex.EliminationOrder());
        pending.addAll(this.root.getClique());
        // Domains still present in the root potential after the steps emitted so far.
        Vector<DiscreteDomain> remaining = new Vector<>();
        remaining.addAll(this.root.getPotential().getDomainProduct());
        while (!pending.isEmpty()) {
            Vector<DiscreteDomain> natureRun = new Vector<>();
            while (!pending.isEmpty() && pending.first().isNature()) {
                natureRun.add(pending.pollFirst().getDiscretizedDomain());
            }
            if (!natureRun.isEmpty()) {
                this.rootMmaps.add(new FMarginalizationMap(remaining, natureRun));
                remaining.removeAll(natureRun);
            }
            while (!pending.isEmpty() && pending.first().isDecision()) {
                BNNode decision = pending.pollFirst();
                this.rootMmaps.add(decision.getDomain());
                remaining.remove(decision.getDomain());
            }
            // After the decision run the head is not a decision node; if it is not a nature node
            // either, neither loop above can ever consume it.
            if (!pending.isEmpty() && !pending.first().isNature()) {
                throw new IllegalStateException("cannot eliminate node " + pending.first().getName()
                        + " from the root clique: it is neither a nature nor a decision node");
            }
        }
    }

    /** Prints the names of the nodes in the given cluster's clique to standard output. */
    public void report(SJTVertex sJTVertex) {
        System.out.println(" clique:");
        for (BNNode member : sJTVertex.getClique()) {
            System.out.println(member.getName());
        }
        System.out.println();
    }

    /** Prints every cluster's clique to standard output. */
    public void reportAll() {
        for (SJTVertex cluster : vertexSet()) {
            report(cluster);
        }
    }
}
