package edu.emory.mathcs.nlp.component.dep;

import edu.emory.mathcs.nlp.common.treebank.DEPTagEn;
import edu.emory.mathcs.nlp.component.template.OnlineComponent;
import edu.emory.mathcs.nlp.component.template.eval.Eval;
import edu.emory.mathcs.nlp.component.template.node.AbstractNLPNode;
import edu.emory.mathcs.nlp.learning.util.MLUtils;
import it.unimi.dsi.fastutil.ints.IntCollection;
import it.unimi.dsi.fastutil.ints.IntSet;
import java.io.InputStream;
import java.util.List;

/**
 * Dependency parser built on {@link OnlineComponent}.
 *
 * <p>Label prediction is restricted through a {@link DEPLabelCandidate}, which is given the current
 * stack and input of the state together with the raw scores and returns the label indices to use.
 * After decoding, any token still lacking a head is attached by {@link #postProcess(DEPState)}.
 *
 * @param <N> node type
 */
public class DEPParser<N extends AbstractNLPNode<N>> extends OnlineComponent<N, DEPState<N>> {
    private static final long serialVersionUID = 7031031976396726276L;

    /** Maximum number of tokens inspected in each direction when re-attaching a headless token. */
    private static final int HEADLESS_WINDOW = 5;

    private DEPLabelCandidate<N> label_candidates;

    /**
     * Mutable holder for the best head candidate found so far for a headless token: the candidate
     * head's node index, the predicted label index, and that label's score. A negative
     * {@code headId} means no candidate has been recorded yet.
     */
    public class DEPTriple {
        int headId;
        int yhat;
        double score;

        /** Creates an empty triple: no head, no label, score {@code -Double.MAX_VALUE}. */
        public DEPTriple() {
            set(-1, -1, -Double.MAX_VALUE);
        }

        /** Overwrites all three fields at once. */
        public void set(int headId, int yhat, double score) {
            this.headId = headId;
            this.yhat = yhat;
            this.score = score;
        }

        /** Returns {@code true} if no head candidate has been recorded. */
        public boolean isNull() {
            return this.headId < 0;
        }
    }

    /** Creates a parser with an empty label-candidate table. */
    public DEPParser() {
        super(false);
        this.label_candidates = new DEPLabelCandidate<>();
    }

    /**
     * Creates a parser from the given input stream (handled by {@link OnlineComponent}) with an
     * empty label-candidate table.
     *
     * @param inputStream stream forwarded to the superclass constructor
     */
    public DEPParser(InputStream inputStream) {
        super(false, inputStream);
        this.label_candidates = new DEPLabelCandidate<>();
    }

    /** Returns a new {@link DEPEval} evaluator. */
    @Override
    public Eval createEvaluator() {
        return new DEPEval();
    }

    /** Builds a parsing state over the nodes of a single sentence. */
    @Override
    public DEPState<N> initState(N[] nodes) {
        return new DEPState<>(nodes);
    }

    /** Always returns {@code null}: this parser builds no state from a list of sentences. */
    @Override
    public DEPState<N> initState(List<N[]> document) {
        return null;
    }

    /** Registers {@code label} under index {@code index} in the label-candidate table. */
    @Override
    protected void putLabel(String label, int index) {
        this.label_candidates.add(label, index);
    }

    /**
     * Returns the label indices chosen by the label-candidate table for the current stack and input
     * of {@code state}, given the classifier {@code scores}.
     */
    @Override
    public int[] getPrediction(DEPState<N> state, float[] scores) {
        return this.label_candidates.getLabelIndices(state.getStack(), state.getInput(), scores);
    }

    /** Returns the label-candidate table used by this parser. */
    public DEPLabelCandidate<N> getLabelCandidates() {
        return this.label_candidates;
    }

    /**
     * Attaches every node from index 1 onward that has no dependency head after decoding.
     *
     * <p>For each such node, candidate heads up to {@value #HEADLESS_WINDOW} positions to the left
     * and to the right are scored, and the best (head, label) pair is applied. If no candidate is
     * found, the node is attached to {@code nodes[0]} with {@link DEPTagEn#DEP_ROOT}.
     *
     * @param state decoded state; it is reset repeatedly while candidates are scored
     */
    @Override
    public void postProcess(DEPState<N> state) {
        N[] nodes = state.getNodes();

        for (int i = 1; i < nodes.length; i++) {
            N node = nodes[i];
            if (node.hasDependencyHead()) {
                continue;
            }

            DEPParser<N>.DEPTriple best = new DEPTriple();
            processHeadless(state, best, nodes, i, -1);
            processHeadless(state, best, nodes, i, 1);

            if (best.isNull()) {
                node.setDependencyHead(nodes[0], DEPTagEn.DEP_ROOT);
            } else {
                String deprel = new DEPLabel(this.optimizer.getLabel(best.yhat)).getDeprel();
                node.setDependencyHead(nodes[best.headId], deprel);
            }
        }
    }

    /**
     * Scans up to {@value #HEADLESS_WINDOW} nodes away from {@code currId} in direction
     * {@code dir}. For each node that is not a descendant of the current node, it scores that node
     * as the head and stores the result in {@code best} if it beats the score already there.
     *
     * @param state  parser state; reset for every candidate pair before feature extraction
     * @param best   running best candidate, updated in place
     * @param nodes  all nodes of the sentence
     * @param currId index of the node without a head
     * @param dir    step direction: positive scans right, otherwise scans left
     */
    void processHeadless(DEPState<N> state, DEPParser<N>.DEPTriple best, N[] nodes, int currId, int dir) {
        // Heads to the right are scored against the left-arc label set; heads to the left against
        // the right-arc label set.
        IntSet labels = dir > 0 ? this.label_candidates.getLeftArcs() : this.label_candidates.getRightArcs();
        N curr = nodes[currId];

        for (int i = currId + dir, window = 0; 0 <= i && i < nodes.length; i += dir) {
            if (++window > HEADLESS_WINDOW) {
                break;
            }
            // Attaching the node to one of its own descendants would create a cycle.
            if (nodes[i].isDescendantOf(curr)) {
                continue;
            }

            if (dir > 0) {
                state.reset(currId, i);
            } else {
                state.reset(i, currId);
            }

            // Keep the score array in a local: both the argmax and the comparison below index into it.
            float[] scores = this.optimizer.scores(this.feature_template.createFeatureVector(state, isTrain()));
            int yhat = MLUtils.argmax(scores, (IntCollection) labels);

            if (best.score < scores[yhat]) {
                best.set(i, yhat, scores[yhat]);
            }
        }
    }
}
