package org.integratedmodelling.riskwiz.inference.ls;

import com.ibm.icu.text.PluralRules;
import java.util.HashSet;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.Set;
import java.util.Vector;
import org.integratedmodelling.riskwiz.bn.BNNode;
import org.integratedmodelling.riskwiz.bn.BeliefNetwork;
import org.integratedmodelling.riskwiz.domain.DomainFactory;
import org.integratedmodelling.riskwiz.graph.RiskUndirectedGraph;
import org.integratedmodelling.riskwiz.jtree.IJoinTree;
import org.integratedmodelling.riskwiz.jtree.JTEdge;
import org.integratedmodelling.riskwiz.pt.CPT;
import org.integratedmodelling.riskwiz.pt.PT;
import org.integratedmodelling.riskwiz.pt.map.DomainMap2;
import org.integratedmodelling.riskwiz.pt.map.FMarginalizationMap;
import org.integratedmodelling.riskwiz.pt.map.SubtableFastMap2;

/* loaded from: input_file:lib/riskwiz-1.0.0.jar:org/integratedmodelling/riskwiz/inference/ls/JoinTree.class */
/**
 * Join tree over a {@link BeliefNetwork}: vertices are clusters ({@link JTVertexHugin}) and edges
 * are separators ({@link JTEdgeHugin}), each carrying a potential table ({@link PT}).
 *
 * <p>Every nature node of the network is assigned to one cluster whose clique contains the node
 * and all of its parents; the node's CPT and evidence are multiplied into that cluster's
 * potential, evidence is propagated with a collect/distribute pass, and node marginals are read
 * back out of the assigned clusters.
 *
 * <p>NOTE(review): presumed call order is {@link #initializeStructiure()} once, then
 * {@link #initialize()}, {@link #initializeLikelihoods()}, {@link #propagateEvidence()} and
 * {@link #setNodeMarginals()} / {@link #setNodeConditionalMarginals()} per query — confirm against
 * the inference engine. Potentials are mutated in place without any synchronization.
 */
public class JoinTree extends RiskUndirectedGraph<JTVertexHugin, JTEdgeHugin> implements IJoinTree<JTVertexHugin> {

    /** Network this tree was built for; fixed at construction. */
    final BeliefNetwork bn;

    /**
     * Cluster assignment and precomputed index maps for each nature node. Built by
     * {@link #initializeStructiure()}; {@code null} until that method has run.
     */
    Hashtable<BNNode, ClusterBundle> clusterHash;

    /**
     * The cluster a nature node was assigned to, together with the maps used to combine the
     * node's own tables with that cluster's potential.
     */
    public static class ClusterBundle {
        /** Cluster whose clique contains the node and all of its parents. */
        private final JTVertexHugin jtcluster;
        /** Map used to multiply the node's CPT into the cluster potential. */
        private final DomainMap2 fopmap;
        /** Map used to multiply the node's evidence table into the cluster potential. */
        private final DomainMap2 liklihoodfmap;
        /** Map used to marginalize the cluster potential onto the node's own domain. */
        private final FMarginalizationMap mfmap;

        public ClusterBundle(JTVertexHugin jtcluster, DomainMap2 fopmap, FMarginalizationMap mfmap, DomainMap2 liklihoodfmap) {
            this.jtcluster = jtcluster;
            this.fopmap = fopmap;
            this.mfmap = mfmap;
            this.liklihoodfmap = liklihoodfmap;
        }
    }

    /**
     * Creates a join tree bound to {@code beliefNetwork}. No clusters or separators are created
     * here; NOTE(review): presumably the tree builder adds them before
     * {@link #initializeStructiure()} is called.
     *
     * @param beliefNetwork the network whose nodes the clusters are built from
     */
    public JoinTree(BeliefNetwork beliefNetwork) {
        super(JTEdgeHugin.class);
        this.bn = beliefNetwork;
    }

    /**
     * One-time setup of the tree's numeric structure.
     *
     * <p>Resets every cluster and separator potential to 1, assigns each nature node to a cluster
     * containing its family (node plus parents), precomputes the CPT, marginalization and evidence
     * maps for that assignment, multiplies the node's CPT into the cluster, and finally asks every
     * cluster to build its fast maps for each incident separator.
     *
     * @throws IllegalStateException if no cluster contains some nature node together with all of
     *     its parents
     */
    @Override // org.integratedmodelling.riskwiz.jtree.IJoinTree
    public void initializeStructiure() {
        resetPotentials();
        this.clusterHash = new Hashtable<>();
        for (BNNode node : this.bn.vertexSet()) {
            if (!node.isNature()) {
                continue;
            }
            JTVertexHugin cluster = assignParentCluster(node);
            if (cluster == null) {
                // Would otherwise surface as an anonymous NullPointerException on getPt().
                throw new IllegalStateException("No join tree cluster contains nature node "
                        + node.getName() + " together with all of its parents");
            }
            PT clusterPt = cluster.getPt();
            CPT cpt = node.getDiscreteCPT();
            SubtableFastMap2 cptMap = clusterPt.createSubtableFastMap(cpt);
            FMarginalizationMap marginalMap =
                    new FMarginalizationMap(clusterPt.getDomainProduct(), node.getDiscretizedDomain());
            DomainMap2 evidenceMap =
                    clusterPt.createSubtableFastMap(DomainFactory.createDomainProduct(node.getDiscretizedDomain()));
            this.clusterHash.put(node, new ClusterBundle(cluster, cptMap, marginalMap, evidenceMap));
            clusterPt.multiplyBySubtable(cpt, cptMap);
        }
        for (JTVertexHugin cluster : vertexSet()) {
            for (JTEdgeHugin separator : edgesOf(cluster)) {
                cluster.createFastMaps(separator);
            }
        }
    }

    /**
     * Resets all potentials to 1 and re-multiplies every nature node's CPT into its assigned
     * cluster, reusing the maps precomputed by {@link #initializeStructiure()} (which must have
     * run first).
     */
    @Override // org.integratedmodelling.riskwiz.jtree.IJoinTree
    public void initialize() {
        resetPotentials();
        for (BNNode node : this.bn.vertexSet()) {
            if (node.isNature()) {
                ClusterBundle bundle = this.clusterHash.get(node);
                // fopmap is always created as a SubtableFastMap2 in initializeStructiure().
                bundle.jtcluster.getPt().multiplyBySubtableFast(node.getDiscreteCPT(), (SubtableFastMap2) bundle.fopmap);
            }
        }
    }

    /** Multiplies the evidence of every nature node into its assigned cluster. */
    @Override // org.integratedmodelling.riskwiz.jtree.IJoinTree
    public void initializeLikelihoods() {
        for (BNNode node : this.bn.vertexSet()) {
            if (node.isNature()) {
                initializeLikelihood(node);
            }
        }
    }

    /**
     * Multiplies {@code bNNode}'s evidence into its assigned cluster; does nothing when the node
     * has no evidence. Uses the fast path when the precomputed evidence map supports it.
     *
     * @param bNNode a nature node known to {@link #initializeStructiure()}
     */
    public void initializeLikelihood(BNNode bNNode) {
        ClusterBundle bundle = this.clusterHash.get(bNNode);
        if (!bNNode.hasEvidence()) {
            return;
        }
        PT clusterPt = bundle.jtcluster.getPt();
        if (bundle.liklihoodfmap instanceof SubtableFastMap2) {
            clusterPt.multiplyBySubtableFast(bNNode.getEvidence(), (SubtableFastMap2) bundle.liklihoodfmap);
        } else {
            clusterPt.multiplyBySubtable(bNNode.getEvidence(), bundle.liklihoodfmap);
        }
    }

    /**
     * Runs a full propagation rooted at {@code jTVertexHugin}: a collect pass towards the root
     * followed by a distribute pass away from it, clearing visit marks before each pass.
     *
     * @param jTVertexHugin the cluster used as root of both passes
     */
    @Override // org.integratedmodelling.riskwiz.jtree.IJoinTree
    public void propagateEvidence(JTVertexHugin jTVertexHugin) {
        unmarkAll();
        collectEvidence(jTVertexHugin);
        unmarkAll();
        distributeEvidence(jTVertexHugin);
    }

    /**
     * Multiplies {@code bNNode}'s evidence into its assigned cluster and then propagates from that
     * cluster.
     *
     * <p>NOTE(review): unlike {@link #initializeLikelihood(BNNode)} this does not check
     * {@code hasEvidence()} first.
     *
     * @param bNNode a nature node known to {@link #initializeStructiure()}
     */
    @Override // org.integratedmodelling.riskwiz.jtree.IJoinTree
    public void propagateEvidence(BNNode bNNode) {
        ClusterBundle bundle = this.clusterHash.get(bNNode);
        JTVertexHugin root = bundle.jtcluster;
        root.getPt().multiplyBySubtable(bNNode.getEvidence(), bundle.liklihoodfmap);
        propagateEvidence(root);
    }

    /**
     * Propagates using an arbitrary cluster (the first one returned by {@code vertexSet()}) as
     * root.
     *
     * @throws java.util.NoSuchElementException if the tree has no clusters
     */
    @Override // org.integratedmodelling.riskwiz.jtree.IJoinTree
    public void propagateEvidence() {
        propagateEvidence(vertexSet().iterator().next());
    }

    /**
     * Collect pass: depth-first from {@code jTVertexHugin}; after finishing each unvisited
     * neighbour's subtree, passes a message from that neighbour into {@code jTVertexHugin}.
     * Callers must clear marks first (see {@link #unmarkAll()}).
     */
    @Override // org.integratedmodelling.riskwiz.jtree.IJoinTree
    public void collectEvidence(JTVertexHugin jTVertexHugin) {
        jTVertexHugin.isMarked = true;
        for (JTVertexHugin neighbor : getNeighbors(jTVertexHugin)) {
            if (!neighbor.isMarked) {
                collectEvidence(neighbor);
                passMessage(neighbor, jTVertexHugin);
            }
        }
    }

    /**
     * Distribute pass: depth-first from {@code jTVertexHugin}; passes a message into each
     * unvisited neighbour before descending into it. Callers must clear marks first (see
     * {@link #unmarkAll()}).
     */
    @Override // org.integratedmodelling.riskwiz.jtree.IJoinTree
    public void distributeEvidence(JTVertexHugin jTVertexHugin) {
        jTVertexHugin.isMarked = true;
        for (JTVertexHugin neighbor : getNeighbors(jTVertexHugin)) {
            if (!neighbor.isMarked) {
                passMessage(jTVertexHugin, neighbor);
                distributeEvidence(neighbor);
            }
        }
    }

    /**
     * Passes a message across the separator between two adjacent clusters.
     *
     * <p>The separator potential is overwritten with the marginal of {@code jTVertexHugin}'s
     * potential, and {@code jTVertexHugin2}'s potential is updated with the new and old separator
     * potentials via {@code multiplyAndDivideBySubtableFast} (NOTE(review): presumably multiply by
     * new, divide by old — confirm in {@code PT}).
     *
     * @param jTVertexHugin sending cluster
     * @param jTVertexHugin2 receiving cluster; must be adjacent to the sender
     */
    @Override // org.integratedmodelling.riskwiz.jtree.IJoinTree
    public void passMessage(JTVertexHugin jTVertexHugin, JTVertexHugin jTVertexHugin2) {
        JTEdgeHugin separator = getEdge(jTVertexHugin, jTVertexHugin2);
        PT separatorPt = separator.getPt();
        // Snapshot the old separator potential before marginalization overwrites it.
        // NOTE(review): m10066clone() is the decompiler's name for PT's clone method.
        PT oldSeparatorPt = separatorPt.m10066clone();
        PT.marginalizeDomainsFast(separatorPt, jTVertexHugin.getPt(), jTVertexHugin.getFMarginalizationMap(separator));
        jTVertexHugin2.getPt().multiplyAndDivideBySubtableFast(separatorPt, oldSeparatorPt, jTVertexHugin2.getSubtableOpFastMap(separator));
    }

    /** Clears the visit mark on every cluster. */
    public void unmarkAll() {
        for (JTVertexHugin cluster : vertexSet()) {
            cluster.isMarked = false;
        }
    }

    /** Stores on every nature node the (unnormalized) marginal of its assigned cluster. */
    public void setNodeMarginals() {
        setMarginals(false);
    }

    /** Stores on every nature node the normalized marginal of its assigned cluster. */
    @Override // org.integratedmodelling.riskwiz.jtree.IJoinTree
    public void setNodeConditionalMarginals() {
        setMarginals(true);
    }

    /**
     * Marginalizes each nature node's assigned cluster onto the node's own domain and stores the
     * result with {@code setMarginal}.
     *
     * @param normalize whether to normalize each marginal before storing it
     */
    private void setMarginals(boolean normalize) {
        for (BNNode node : this.bn.vertexSet()) {
            if (node.isNature()) {
                ClusterBundle bundle = this.clusterHash.get(node);
                FMarginalizationMap marginalMap = bundle.mfmap;
                PT marginal = new PT(marginalMap.getProjectionDomainProduct());
                PT.marginalizeDomainsFast(marginal, bundle.jtcluster.getPt(), marginalMap);
                if (normalize) {
                    marginal.normalize();
                }
                node.setMarginal(marginal);
            }
        }
    }

    /** Sets every cluster and separator potential to 1 (clusters first, then separators). */
    private void resetPotentials() {
        for (JTVertexHugin cluster : vertexSet()) {
            cluster.setAll(1.0d);
        }
        for (JTEdgeHugin separator : edgeSet()) {
            separator.setAll(1.0d);
        }
    }

    /**
     * Finds the first cluster (in {@code vertexSet()} iteration order) containing {@code bNNode}
     * and all of its parents.
     *
     * @return the cluster, or {@code null} if none contains the whole family
     */
    private JTVertexHugin assignParentCluster(BNNode bNNode) {
        Set<BNNode> family = new HashSet<>();
        family.add(bNNode);
        family.addAll(this.bn.getParents(bNNode));
        for (JTVertexHugin cluster : vertexSet()) {
            if (cluster.containsAll(family)) {
                return cluster;
            }
        }
        return null;
    }

    @Override // org.integratedmodelling.riskwiz.jtree.IJoinTree
    public BeliefNetwork getBeliefNetwork() {
        return this.bn;
    }

    /** Debug: prints every edge as {@code source<--->target}, then every cluster with its clique. */
    public void dump() {
        System.out.println("Join Tree");
        System.out.println("Edges:");
        for (JTEdgeHugin edge : edgeSet()) {
            System.out.print(getEdgeSource(edge).getName());
            System.out.print("<--->");
            System.out.print(getEdgeTarget(edge).getName());
            System.out.println("");
            System.out.println("");
        }
        System.out.println("Nodes:");
        for (JTVertexHugin cluster : vertexSet()) {
            System.out.print(cluster.getName() + ": ");
            printNodes(cluster.getClique());
        }
    }

    /**
     * Debug: reports every edge whose separator is empty or not contained in both endpoint
     * cliques, printing the source clique, the separator, and the target clique.
     */
    public void check() {
        System.out.println("Join Tree check");
        for (JTEdgeHugin edge : edgeSet()) {
            JTVertexHugin source = getEdgeSource(edge);
            JTVertexHugin target = getEdgeTarget(edge);
            Set<BNNode> sepset = edge.getSepset();
            boolean valid = !sepset.isEmpty()
                    && source.getClique().containsAll(sepset)
                    && target.getClique().containsAll(sepset);
            if (!valid) {
                System.out.println("Error in" + source.getName() + "<-->" + target.getName());
                System.out.print(source.getName() + " Vertex nodes:");
                printNodes(source.getClique());
                System.out.println("SepsetNodes:");
                printNodes(sepset);
                System.out.print(target.getName() + " Vertex nodes:");
                // Previously printed the source clique a second time.
                printNodes(target.getClique());
                System.out.println("");
            }
        }
    }

    /** Debug: prints the node names of {@code set} on one line, each followed by {@code ", "}. */
    public void printNodes(Set<BNNode> set) {
        for (BNNode node : set) {
            System.out.print(node.getName() + ", ");
        }
        System.out.println("");
    }

    /** Debug: prints each nature node followed by the clique of its assigned cluster. */
    public void checkClusterAssignment() {
        for (BNNode node : this.bn.vertexSet()) {
            if (node.isNature()) {
                JTVertexHugin cluster = this.clusterHash.get(node).jtcluster;
                System.out.println(node.getName() + ": ");
                printNodes(cluster.getClique());
            }
        }
    }

    /**
     * Debug: prints each nature node with its assigned cluster, then lists the clusters that no
     * nature node was assigned to.
     */
    public void checkClusterAssignment2() {
        Set<JTVertexHugin> unassigned = new HashSet<>(vertexSet());
        for (BNNode node : this.bn.vertexSet()) {
            if (node.isNature()) {
                JTVertexHugin cluster = this.clusterHash.get(node).jtcluster;
                unassigned.remove(cluster);
                System.out.println(node.getName() + ":  ");
                System.out.print(cluster.getName() + ": ");
                printNodes(cluster.getClique());
            }
        }
        System.out.println("Not Assigned:");
        for (JTVertexHugin cluster : unassigned) {
            System.out.println(cluster.getName());
            printNodes(cluster.getClique());
        }
    }

    /** Debug: prints the potential of every cluster whose name equals {@code str}, ignoring case. */
    public void printCluster(String str) {
        for (JTVertexHugin cluster : vertexSet()) {
            if (cluster.getName().equalsIgnoreCase(str)) {
                System.out.println(cluster.getName());
                System.out.println(cluster.getPt().toString());
            }
        }
    }

    /** Debug: prints every edge with its source clique, separator nodes and target clique. */
    public void printEdges() {
        for (JTEdgeHugin edge : edgeSet()) {
            System.out.println("source: " + edge.getSourceVertex().getName() + "  :");
            printNodes(edge.getSourceVertex().getClique());
            System.out.println("edge has nodes: ");
            printNodes(edge.getSepset());
            System.out.println("target: " + edge.getTargetVertex().getName() + " :");
            printNodes(edge.getTargetVertex().getClique());
        }
    }

    /**
     * Debug: for every pair of clusters with overlapping cliques, walks the simple paths between
     * them and reports whether every cluster on the path contains the overlap (the running
     * intersection property).
     */
    public void checkJTProperty() {
        JTVertexHugin[] clusters = vertexSet().toArray(new JTVertexHugin[0]);
        for (int i = 0; i < clusters.length; i++) {
            for (int j = i + 1; j < clusters.length; j++) {
                Set<BNNode> shared = JTEdge.intersection(clusters[i].getClique(), clusters[j].getClique());
                if (!shared.isEmpty()) {
                    checkPath(clusters[i], clusters[j], shared);
                }
            }
        }
    }

    /** Debug: prints every cluster followed by the names of its neighbours. */
    public void checkNeighbours() {
        for (JTVertexHugin cluster : vertexSet()) {
            System.out.println(cluster.getName() + "  has neighbors:");
            for (JTVertexHugin neighbor : getNeighbors(cluster)) {
                System.out.print(neighbor.getName() + ",");
            }
            System.out.println("");
        }
    }

    /** Debug: prints the sum of every cluster potential. */
    public void checkNormalization() {
        for (JTVertexHugin cluster : vertexSet()) {
            System.out.print(cluster.getName() + ", sum is  ");
            System.out.println(cluster.getPt().sum());
            System.out.println("");
        }
    }

    /** Starts a path walk from each neighbour of {@code start} towards {@code end}. */
    private void checkPath(JTVertexHugin start, JTVertexHugin end, Set<BNNode> shared) {
        for (JTVertexHugin neighbor : getNeighbors(start)) {
            Vector<JTVertexHugin> path = new Vector<>();
            path.add(start);
            path.add(neighbor);
            checkPath(neighbor, end, shared, neighbor.getClique().containsAll(shared), path);
        }
    }

    /**
     * Extends {@code path} depth-first without revisiting clusters. On reaching {@code end}
     * prints "OK:" if every cluster after the start contained {@code shared}, otherwise "Error:"
     * with the shared nodes, followed by the path.
     *
     * @param ok whether every cluster visited so far (after the start) contains {@code shared}
     */
    private void checkPath(JTVertexHugin current, JTVertexHugin end, Set<BNNode> shared, boolean ok, Vector<JTVertexHugin> path) {
        if (current == end) {
            if (ok) {
                System.out.println("OK:");
            } else {
                System.out.println("Error:");
                printNodes(shared);
            }
            reportPath(path);
            return;
        }
        for (JTVertexHugin next : getNeighbors(current)) {
            if (!path.contains(next)) {
                Vector<JTVertexHugin> extended = new Vector<>(path);
                extended.add(next);
                checkPath(next, end, shared, next.getClique().containsAll(shared) && ok, extended);
            }
        }
    }

    /** Debug: prints each cluster on {@code vector} with its clique, followed by a blank line. */
    public void reportPath(Vector<JTVertexHugin> vector) {
        for (JTVertexHugin cluster : vector) {
            System.out.print(cluster.getName() + ",");
            printNodes(cluster.getClique());
        }
        System.out.println("");
    }
}
