package edu.emory.mathcs.nlp.component.ner;

import edu.emory.mathcs.nlp.common.collection.tuple.ObjectIntIntTriple;
import edu.emory.mathcs.nlp.component.template.eval.Eval;
import edu.emory.mathcs.nlp.component.template.eval.F1Eval;
import edu.emory.mathcs.nlp.component.template.node.AbstractNLPNode;
import edu.emory.mathcs.nlp.component.template.state.L2RState;
import edu.emory.mathcs.nlp.component.template.util.BILOU;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.objects.ObjectIterator;
import java.util.List;

/* loaded from: input_file:edu/emory/mathcs/nlp/component/ner/NERState.class */
/**
 * Left-to-right tagging state whose per-node label is the node's named-entity tag.
 *
 * <p>Besides wiring {@link #getLabel} / {@link #setLabel} to the named-entity tag, this class
 * scores entity spans against the oracle labels ({@link #evaluate}) and can rewrite the current
 * tag sequence so that every collected entity carries consistent B/I/L/U tags
 * ({@link #postProcess}).
 *
 * <p>All loops start at index 1 (or 2 when a predecessor is needed), so index 0 of the node
 * array is never read or modified here — presumably it holds an artificial root node; confirm
 * against {@code L2RState}.
 *
 * @param <N> the concrete node type
 */
public class NERState<N extends AbstractNLPNode<N>> extends L2RState<N> {
    /** Label written to every node before the collected entities are re-applied. */
    private static final String OUTSIDE_TAG = "O";

    /**
     * Creates a state over the given nodes.
     *
     * @param nodes the nodes handed to the {@code L2RState} constructor
     */
    public NERState(N[] nodes) {
        super(nodes);
    }

    /**
     * Returns the label this state tracks for a node.
     *
     * @param node the node to read
     * @return the node's named-entity tag
     */
    @Override
    protected String getLabel(N node) {
        return node.getNamedEntityTag();
    }

    /**
     * Replaces the named-entity tag of a node.
     *
     * @param node  the node to update
     * @param label the new named-entity tag
     * @return the tag the node carried before the update
     */
    @Override
    protected String setLabel(N node, String label) {
        String previous = node.getNamedEntityTag();
        node.setNamedEntityTag(label);
        return previous;
    }

    /**
     * Collects the entities described by the oracle labels and by the current node labels
     * (indices 1 up to {@code nodes.length}) and reports three numbers to the evaluator: the
     * number of matching entities, the number of entities found in the node labels, and the
     * number of entities found in the oracle labels.
     *
     * @param eval the evaluator to update; it is cast to {@link F1Eval}, so any other
     *             implementation results in a {@link ClassCastException}
     */
    @Override
    public void evaluate(Eval eval) {
        Int2ObjectMap<ObjectIntIntTriple<String>> oracleEntities =
                BILOU.collectEntityMap(this.oracle, tag -> tag.toString(), 1, this.nodes.length);
        Int2ObjectMap<ObjectIntIntTriple<String>> systemEntities =
                BILOU.collectEntityMap(this.nodes, this::getLabel, 1, this.nodes.length);

        int correct = countCorrect(systemEntities, oracleEntities);
        ((F1Eval) eval).add(correct, systemEntities.size(), oracleEntities.size());
    }

    /**
     * Counts the entries of {@code systemEntities} that have a counterpart in
     * {@code oracleEntities}: the same integer key must be present in both maps and the
     * {@code o} component of both triples must be equal.
     *
     * <p>The key is whatever {@code BILOU.collectEntityMap} assigns — presumably an encoding of
     * the entity span; verify against that helper.
     *
     * @param systemEntities entities whose entries are iterated
     * @param oracleEntities entities that are looked up by key
     * @return the number of matching entries
     */
    private int countCorrect(Int2ObjectMap<ObjectIntIntTriple<String>> systemEntities,
                             Int2ObjectMap<ObjectIntIntTriple<String>> oracleEntities) {
        int matches = 0;
        for (Int2ObjectMap.Entry<ObjectIntIntTriple<String>> entry : systemEntities.int2ObjectEntrySet()) {
            ObjectIntIntTriple<String> oracleEntity = oracleEntities.get(entry.getIntKey());
            if (oracleEntity != null && oracleEntity.o.equals(entry.getValue().o)) {
                matches++;
            }
        }
        return matches;
    }

    /**
     * Normalizes the current tag sequence in three passes:
     * <ol>
     *   <li>adjusts tags of adjacent node pairs via {@link #postProcessBILOUAux(int)};</li>
     *   <li>collects the entities from the adjusted sequence and resets every node from index 1
     *       onward to {@code "O"};</li>
     *   <li>re-tags each collected entity: a single-node entity gets a U tag, a longer one gets
     *       B on its first node, L on its last node, and I on every node in between.</li>
     * </ol>
     */
    public void postProcess() {
        final int size = this.nodes.length;

        // Starts at 2 because each step inspects the pair (index - 1, index) and index 0 is skipped.
        for (int index = 2; index < size; index++) {
            postProcessBILOUAux(index);
        }

        // The entities must be captured before the labels are wiped below.
        List<ObjectIntIntTriple<String>> entities =
                BILOU.collectEntityList(this.nodes, this::getLabel, 1, size);

        for (int index = 1; index < size; index++) {
            setLabel(this.nodes[index], OUTSIDE_TAG);
        }

        for (ObjectIntIntTriple<String> entity : entities) {
            final int first = entity.i1;
            final int last = entity.i2;

            if (first == last) {
                setLabel(this.nodes[first], BILOU.toBILOUTag(BILOU.U, entity.o));
                continue;
            }

            setLabel(this.nodes[first], BILOU.toBILOUTag(BILOU.B, entity.o));
            setLabel(this.nodes[last], BILOU.toBILOUTag(BILOU.L, entity.o));
            for (int inner = first + 1; inner < last; inner++) {
                setLabel(this.nodes[inner], BILOU.toBILOUTag(BILOU.I, entity.o));
            }
        }
    }

    /**
     * Adjusts the tags of the node pair {@code (index - 1, index)} for selected combinations of
     * their BILOU kinds; every other combination is left untouched.
     *
     * <table>
     *   <caption>Handled transitions</caption>
     *   <tr><th>previous</th><th>current</th><th>action</th></tr>
     *   <tr><td>U</td><td>I</td><td>previous tag is changed to kind I</td></tr>
     *   <tr><td>U</td><td>L</td><td>previous tag is changed to kind B</td></tr>
     *   <tr><td>I</td><td>U</td><td>current tag is changed to kind I</td></tr>
     *   <tr><td>L</td><td>I</td><td>previous tag is changed to kind I</td></tr>
     *   <tr><td>O</td><td>I</td><td>previous node receives the current node's tag</td></tr>
     * </table>
     *
     * <p>"Changed to kind X" means the result of {@code BILOU.changeChunkType(X, tag)} —
     * presumably the same entity type with a different BILOU prefix; confirm against
     * {@code BILOU}.
     *
     * @param index index of the current node; must be at least 1 so that a predecessor exists
     */
    private void postProcessBILOUAux(int index) {
        N prev = this.nodes[index - 1];
        N curr = this.nodes[index];
        BILOU prevKind = BILOU.toBILOU(getLabel(prev));
        BILOU currKind = BILOU.toBILOU(getLabel(curr));

        switch (prevKind) {
            case U:
                switch (currKind) {
                    case I:
                        setLabel(prev, BILOU.changeChunkType(BILOU.I, getLabel(prev)));
                        break;
                    case L:
                        setLabel(prev, BILOU.changeChunkType(BILOU.B, getLabel(prev)));
                        break;
                    default:
                        break;
                }
                break;
            case I:
                switch (currKind) {
                    case U:
                        setLabel(curr, BILOU.changeChunkType(BILOU.I, getLabel(curr)));
                        break;
                    default:
                        break;
                }
                break;
            case L:
                switch (currKind) {
                    case I:
                        setLabel(prev, BILOU.changeChunkType(BILOU.I, getLabel(prev)));
                        break;
                    default:
                        break;
                }
                break;
            case O:
                switch (currKind) {
                    case I:
                        // The preceding "outside" node simply adopts the current node's tag.
                        setLabel(prev, getLabel(curr));
                        break;
                    default:
                        break;
                }
                break;
            default:
                break;
        }
    }
}
