package org.integratedmodelling.riskwiz.influence.jensen;

import java.util.Hashtable;
import java.util.Iterator;
import org.integratedmodelling.riskwiz.bn.BNNode;
import org.integratedmodelling.riskwiz.bn.BeliefNetwork;
import org.integratedmodelling.riskwiz.domain.DiscreteDomain;
import org.integratedmodelling.riskwiz.domain.DomainFactory;
import org.integratedmodelling.riskwiz.influence.JTPotential;
import org.integratedmodelling.riskwiz.influence.jensen.StrongJoinTree;
import org.integratedmodelling.riskwiz.jtree.IJoinTreePN;
import org.integratedmodelling.riskwiz.pt.PT;
import org.integratedmodelling.riskwiz.pt.TableFactory;
import org.integratedmodelling.riskwiz.pt.map.FMarginalizationMap;
import org.integratedmodelling.riskwiz.pt.map.SubtableFastMap2;
import org.jgrapht.Graphs;

/* loaded from: input_file:lib/riskwiz-1.0.0.jar:org/integratedmodelling/riskwiz/influence/jensen/PolicyNetworkJoinTree.class */
public class PolicyNetworkJoinTree extends StrongJoinTree implements IJoinTreePN<SJTVertex> {

    /**
     * Decision potentials keyed by the decision node's discretized domain.
     * Entries are either a fixed observation (see {@link #setDecision(BNNode, int)}) or the
     * optimal policy taken from {@code policyHash} (see {@link #setOptimalPolicy(BNNode)}).
     * Created lazily by {@link #decisionPotentials()}: it is deliberately not given a field
     * initializer, because such an initializer would run after {@code super(...)} and could
     * discard entries added while the superclass constructor runs.
     */
    protected Hashtable<DiscreteDomain, PT> decisionPotentialHash;

    /**
     * Builds the join tree from a belief network via the {@link StrongJoinTree} constructor.
     *
     * @param beliefNetwork the network to compile
     */
    public PolicyNetworkJoinTree(BeliefNetwork beliefNetwork) {
        super(beliefNetwork);
        decisionPotentials();
    }

    /**
     * Wraps an existing strong join tree. The network, root, policy table and cluster table
     * are shared with {@code strongJoinTree} (not copied), and its vertices and edges are
     * added to this graph.
     *
     * @param strongJoinTree the tree to adopt
     */
    public PolicyNetworkJoinTree(StrongJoinTree strongJoinTree) {
        this.bn = strongJoinTree.getBeliefNetwork();
        this.root = strongJoinTree.getRoot();
        this.policyHash = strongJoinTree.policyHash;
        this.clusterHash = strongJoinTree.clusterHash;
        Graphs.addGraph(this, strongJoinTree);
        decisionPotentials();
    }

    /** Returns the decision-potential map, creating it on first use. */
    private Hashtable<DiscreteDomain, PT> decisionPotentials() {
        if (this.decisionPotentialHash == null) {
            this.decisionPotentialHash = new Hashtable<>();
        }
        return this.decisionPotentialHash;
    }

    /**
     * Fixes a decision node to the state with the given index.
     *
     * @param node  the decision node
     * @param state index of the chosen state
     */
    @Override // org.integratedmodelling.riskwiz.jtree.IJoinTreePN
    public void setDecision(BNNode node, int state) {
        DiscreteDomain domain = node.getDiscretizedDomain();
        decisionPotentials().put(domain, TableFactory.createObservation(domain, state));
    }

    /**
     * Fixes a decision node to the state with the given name.
     *
     * @param node  the decision node
     * @param state name of the chosen state
     */
    @Override // org.integratedmodelling.riskwiz.jtree.IJoinTreePN
    public void setDecision(BNNode node, String state) {
        DiscreteDomain domain = node.getDiscretizedDomain();
        decisionPotentials().put(domain, TableFactory.createObservation(domain, state));
    }

    /**
     * Replaces any fixed decision for {@code node} with the policy stored in
     * {@code policyHash} for that node.
     *
     * @param node the decision node
     * @throws IllegalStateException if {@code policyHash} holds no policy for the node
     */
    @Override // org.integratedmodelling.riskwiz.jtree.IJoinTreePN
    public void setOptimalPolicy(BNNode node) {
        DiscreteDomain key = node.getDiscretizedDomain();
        // Decisions are stored under the discretized domain, so clear that key (the original
        // removed by getDomain(), which is not the key setDecision uses).
        decisionPotentials().remove(key);
        PT policy = this.policyHash.get(node.getDomain());
        if (policy == null) {
            // Hashtable rejects null values; report the real cause instead of a bare NPE.
            throw new IllegalStateException("No optimal policy available for decision node " + node);
        }
        decisionPotentials().put(key, policy);
    }

    /**
     * Resets all potentials to vacuous and (re)builds the node-to-cluster assignment with
     * the fast maps each node kind needs. It then builds the per-edge fast maps of every
     * vertex.
     */
    @Override // org.integratedmodelling.riskwiz.influence.jensen.StrongJoinTree, org.integratedmodelling.riskwiz.jtree.IJoinTreeDecision, org.integratedmodelling.riskwiz.jtree.IJoinTree
    public void initializeStructiure() {
        setVaciousProbabilityPotentials();
        this.clusterHash = new Hashtable<>();
        for (BNNode node : this.bn.vertexSet()) {
            SJTVertex cluster = assignParentCluster(node);
            if (node.isNature()) {
                JTPotential potential = cluster.getPotential();
                this.clusterHash.put(node, new StrongJoinTree.ClusterBundle(cluster,
                        potential.createSubtableFastMap(node.getDiscreteCPT()),
                        potential.createFMarginalizationMap(node.getDiscretizedDomain()),
                        potential.createSubtableFastMap(DomainFactory.createDomainProduct(node.getDiscretizedDomain()))));
            } else if (node.isUtility()) {
                JTPotential potential = cluster.getPotential();
                // Utility nodes are indexed by their parents' domains only; no likelihood map.
                this.clusterHash.put(node, new StrongJoinTree.ClusterBundle(cluster,
                        potential.createSubtableFastMap(node.getDiscreteCPT().getParentsDomains()),
                        potential.createFMarginalizationMap(node.getDiscreteCPT().getParentsDomains()),
                        null));
            } else if (node.isDecision()) {
                JTPotential potential = cluster.getPotential();
                // Decision nodes have no CPT to multiply in here, hence no CPT fast map.
                this.clusterHash.put(node, new StrongJoinTree.ClusterBundle(cluster,
                        null,
                        potential.createFMarginalizationMap(node.getDiscretizedDomain()),
                        potential.createSubtableFastMap(DomainFactory.createDomainProduct(node.getDiscretizedDomain()))));
            }
        }
        for (SJTVertex vertex : vertexSet()) {
            for (SJTEdge edge : edgesOf(vertex)) {
                vertex.createFastMapsPN(edge);
            }
        }
    }

    /**
     * Resets potentials to vacuous, then multiplies each nature node's CPT into its cluster.
     * It also multiplies in each decision node's registered decision potential, if any.
     * A decision node with no registered potential (no {@code setDecision} /
     * {@code setOptimalPolicy} call) keeps a vacuous potential.
     */
    @Override // org.integratedmodelling.riskwiz.influence.jensen.StrongJoinTree, org.integratedmodelling.riskwiz.jtree.IJoinTreeDecision, org.integratedmodelling.riskwiz.jtree.IJoinTree
    public void initialize() {
        setVaciousProbabilityPotentials();
        for (BNNode node : this.bn.vertexSet()) {
            if (node.isNature()) {
                StrongJoinTree.ClusterBundle bundle = this.clusterHash.get(node);
                bundle.jtcluster.getPotential().multiplyByProbabilitySubtable(node.getDiscreteCPT(), bundle.fopmap);
            } else if (node.isDecision()) {
                // The map may still be null if this runs from the superclass constructor.
                PT decisionPotential = this.decisionPotentialHash == null
                        ? null
                        : this.decisionPotentialHash.get(node.getDiscretizedDomain());
                if (decisionPotential != null) {
                    JTPotential potential = this.clusterHash.get(node).jtcluster.getPotential();
                    potential.multiplyByProbabilitySubtable(decisionPotential,
                            potential.createSubtableFastMap(decisionPotential));
                }
            }
        }
    }

    /** @return the belief network this join tree was built from */
    @Override // org.integratedmodelling.riskwiz.influence.jensen.StrongJoinTree, org.integratedmodelling.riskwiz.jtree.IJoinTreeDecision, org.integratedmodelling.riskwiz.jtree.IJoinTree
    public BeliefNetwork getBeliefNetwork() {
        return this.bn;
    }

    /**
     * Runs a full collect pass toward {@code root} followed by a distribute pass away from it,
     * clearing vertex marks before each pass.
     *
     * @param root the vertex to propagate from
     */
    @Override // org.integratedmodelling.riskwiz.jtree.IJoinTree
    public void propagateEvidence(SJTVertex root) {
        unmarkAll();
        collectEvidence(root);
        unmarkAll();
        distributeEvidence(root);
    }

    /**
     * Multiplies the node's evidence into its assigned cluster, then propagates from that
     * cluster.
     *
     * @param node the node whose evidence is entered
     */
    @Override // org.integratedmodelling.riskwiz.jtree.IJoinTree
    public void propagateEvidence(BNNode node) {
        StrongJoinTree.ClusterBundle bundle = this.clusterHash.get(node);
        SJTVertex cluster = bundle.jtcluster;
        cluster.getPotential().multiplyByProbabilitySubtable(node.getEvidence(), bundle.liklihoodfmap);
        propagateEvidence(cluster);
    }

    /**
     * Depth-first collect: recursively collects from every unmarked neighbour and then
     * passes that neighbour's message into {@code vertex}.
     *
     * @param vertex the vertex collecting messages
     */
    @Override // org.integratedmodelling.riskwiz.jtree.IJoinTree
    public void collectEvidence(SJTVertex vertex) {
        vertex.isMarked = true;
        for (SJTVertex neighbour : getNeighbors(vertex)) {
            if (!neighbour.isMarked) {
                collectEvidence(neighbour);
                passMessage(neighbour, vertex);
            }
        }
    }

    /**
     * Depth-first distribute: passes a message from {@code vertex} to every unmarked
     * neighbour and then recurses into it.
     *
     * @param vertex the vertex distributing messages
     */
    @Override // org.integratedmodelling.riskwiz.jtree.IJoinTree
    public void distributeEvidence(SJTVertex vertex) {
        vertex.isMarked = true;
        for (SJTVertex neighbour : getNeighbors(vertex)) {
            if (!neighbour.isMarked) {
                passMessage(vertex, neighbour);
                distributeEvidence(neighbour);
            }
        }
    }

    /**
     * Passes a probability message across the separator edge from {@code from} to {@code to}.
     * The separator potential is overwritten in place with the marginal of {@code from}.
     * Then {@code to} is multiplied by the new separator and divided by a copy of the old one.
     *
     * @param from the sending vertex
     * @param to   the receiving vertex
     */
    @Override // org.integratedmodelling.riskwiz.jtree.IJoinTree
    public void passMessage(SJTVertex from, SJTVertex to) {
        SJTEdge edge = getEdge(from, to);
        PT separator = edge.getPotential().getProbabilityPotential();
        // Keep the previous separator so it can be divided out after the update.
        PT oldSeparator = separator.m10066clone();
        PT.marginalizeDomainsFast(separator, from.getPotential().getProbabilityPotential(), from.getFMarginalizationMapPN(edge));
        PT target = to.getPotential().getProbabilityPotential();
        SubtableFastMap2 map = to.getSubtableOpFastMap(edge);
        target.multiplyBySubtableFast(separator, map);
        target.divideBySubtableFast(oldSeparator, map);
    }

    /**
     * Stores on each node its unnormalized marginal (nature and decision nodes) or its
     * expected utility under the unnormalized cluster potential (utility nodes).
     */
    public void setNodeMarginals() {
        computeNodeMarginals(false);
    }

    /**
     * Like {@link #setNodeMarginals()}, but normalizes the marginals and the cluster potential
     * used for expected utilities first.
     */
    @Override // org.integratedmodelling.riskwiz.jtree.IJoinTree
    public void setNodeConditionalMarginals() {
        computeNodeMarginals(true);
    }

    /**
     * Shared implementation of the two marginal setters.
     *
     * @param normalize whether to normalize before storing (marginals) or before weighting
     *                  the utility table (utility nodes)
     */
    private void computeNodeMarginals(boolean normalize) {
        for (BNNode node : this.bn.vertexSet()) {
            if (node.isNature() || node.isDecision()) {
                StrongJoinTree.ClusterBundle bundle = this.clusterHash.get(node);
                FMarginalizationMap map = bundle.mfmap;
                PT marginal = new PT(map.getProjectionDomainProduct());
                PT.marginalizeDomainsFast(marginal, bundle.jtcluster.getPotential().getProbabilityPotential(), map);
                if (normalize) {
                    marginal.normalize();
                }
                node.setMarginal(marginal);
            } else if (node.isUtility()) {
                StrongJoinTree.ClusterBundle bundle = this.clusterHash.get(node);
                // Work on a copy so the cluster potential itself is not modified.
                PT weighted = bundle.jtcluster.getPotential().getProbabilityPotential().m10066clone();
                if (normalize) {
                    weighted.normalize();
                }
                weighted.multiplyBySubtable(node.getDiscreteCPT(), bundle.fopmap);
                node.setMarginalUtility(weighted.sum());
            }
        }
    }

    /** Resets the probability potential of every vertex and every edge to vacuous. */
    public void setVaciousProbabilityPotentials() {
        for (SJTVertex vertex : vertexSet()) {
            vertex.getPotential().setVaciousProb();
        }
        for (SJTEdge edge : edgeSet()) {
            edge.getPotential().setVaciousProb();
        }
    }
}
