package com.kotlinnlp.neuralparser.parsers.lhrparser;

import com.kotlinnlp.dependencytree.DependencyTree;
import com.kotlinnlp.dependencytree.Deprel;
import com.kotlinnlp.dependencytree.POSTag;
import com.kotlinnlp.linguisticdescription.sentence.MorphoSyntacticSentence;
import com.kotlinnlp.neuralparser.NeuralParser;
import com.kotlinnlp.neuralparser.language.ParsingSentence;
import com.kotlinnlp.neuralparser.parsers.lhrparser.decoders.CosineDecoder;
import com.kotlinnlp.neuralparser.parsers.lhrparser.deprelselectors.MorphoDeprelSelector;
import com.kotlinnlp.neuralparser.parsers.lhrparser.neuralmodels.contextencoder.ContextEncoder;
import com.kotlinnlp.neuralparser.parsers.lhrparser.neuralmodels.headsencoder.HeadsEncoder;
import com.kotlinnlp.neuralparser.parsers.lhrparser.neuralmodels.labeler.DeprelLabeler;
import com.kotlinnlp.neuralparser.parsers.lhrparser.neuralmodels.labeler.DeprelLabelerModel;
import com.kotlinnlp.neuralparser.parsers.lhrparser.utils.ArcScores;
import com.kotlinnlp.neuralparser.parsers.lhrparser.utils.CyclesFixer;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;

/* compiled from: LHRParser.kt */
@Metadata(mv = {1, 1, 10}, bv = {1, 0, 2}, k = 1, d1 = {"��D\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0002\n\u0002\b\u0003\u0018��2\b\u0012\u0004\u0012\u00020\u00020\u0001B\r\u0012\u0006\u0010\u0003\u001a\u00020\u0002¢\u0006\u0002\u0010\u0004J\u0018\u0010\u000b\u001a\u00020\f2\u0006\u0010\r\u001a\u00020\u000e2\u0006\u0010\u000f\u001a\u00020\u0010H\u0002J\u0010\u0010\u0011\u001a\u00020\u00122\u0006\u0010\u0013\u001a\u00020\u0014H\u0016J\u0014\u0010\u0015\u001a\u00020\u0016*\u00020\f2\u0006\u0010\u000f\u001a\u00020\u0010H\u0002J\u0014\u0010\u0017\u001a\u00020\u0016*\u00020\f2\u0006\u0010\r\u001a\u00020\u000eH\u0002J\u0014\u0010\u0018\u001a\u00020\u0016*\u00020\f2\u0006\u0010\u000f\u001a\u00020\u0010H\u0002R\u0010\u0010\u0005\u001a\u0004\u0018\u00010\u0006X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0007\u001a\u00020\bX\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u0003\u001a\u00020\u0002X\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\t\u0010\n¨\u0006\u0019"}, d2 = {"Lcom/kotlinnlp/neuralparser/parsers/lhrparser/LHRParser;", "Lcom/kotlinnlp/neuralparser/NeuralParser;", "Lcom/kotlinnlp/neuralparser/parsers/lhrparser/LHRModel;", "model", "(Lcom/kotlinnlp/neuralparser/parsers/lhrparser/LHRModel;)V", "deprelLabeler", "Lcom/kotlinnlp/neuralparser/parsers/lhrparser/neuralmodels/labeler/DeprelLabeler;", "lssEncoder", "Lcom/kotlinnlp/neuralparser/parsers/lhrparser/LSSEncoder;", "getModel", "()Lcom/kotlinnlp/neuralparser/parsers/lhrparser/LHRModel;", "buildDependencyTree", "Lcom/kotlinnlp/dependencytree/DependencyTree;", "lss", "Lcom/kotlinnlp/neuralparser/parsers/lhrparser/LatentSyntacticStructure;", "scores", "Lcom/kotlinnlp/neuralparser/parsers/lhrparser/utils/ArcScores;", "parse", "Lcom/kotlinnlp/linguisticdescription/sentence/MorphoSyntacticSentence;", 
"sentence", "Lcom/kotlinnlp/neuralparser/language/ParsingSentence;", "assignHeads", "", "assignLabels", "fixCycles", "neuralparser"})
/* loaded from: input_file:com/kotlinnlp/neuralparser/parsers/lhrparser/LHRParser.class */
/**
 * Dependency parser driven by an {@link LHRModel}.
 *
 * <p>Parsing a sentence takes these steps:
 * <ol>
 *   <li>Encode it into a {@link LatentSyntacticStructure} with an {@link LSSEncoder}.</li>
 *   <li>Score candidate arcs with a {@link CosineDecoder}.</li>
 *   <li>Attach every token to its highest-scoring head.</li>
 *   <li>Repair cycles with {@link CyclesFixer}.</li>
 *   <li>If the model provides a labeler model, assign dependency relations through the model's
 *       {@link MorphoDeprelSelector}.</li>
 * </ol>
 *
 * <p>This class is the Java rendering of compiled Kotlin. Several calls below go through Kotlin's
 * default-argument bridges: a trailing {@code mask, null} pair, or the synthetic
 * {@code $default} method.
 */
public final class LHRParser implements NeuralParser<LHRModel> {

    /**
     * Turns a sentence into its latent syntactic structure. It is built from the model's tokens
     * encoder, context encoder, heads encoder and root embedding.
     */
    private final LSSEncoder lssEncoder;

    /**
     * Predicts dependency relations. It is {@code null} when the model has no labeler model; in
     * that case no deprels are assigned.
     */
    private final DeprelLabeler deprelLabeler;

    /** The model this parser was built from. */
    @NotNull
    private final LHRModel model;

    /**
     * Creates a parser for the given model.
     *
     * @param model the trained model; must not be {@code null}
     */
    public LHRParser(@NotNull LHRModel model) {
        Intrinsics.checkParameterIsNotNull(model, "model");
        this.model = model;

        // In the "0, 4, null" tails, 4 is Kotlin's default-argument mask and null is its
        // constructor marker. The mask leaves the third constructor parameter at its declared
        // default, so the 0 is only a placeholder.
        this.lssEncoder = new LSSEncoder(
                model.getTokensEncoderWrapperModel().buildWrapper(false),
                new ContextEncoder(model.getContextEncoderModel(), false, 0, 4, null),
                new HeadsEncoder(model.getHeadsEncoderModel(), false, 0, 4, null),
                model.getRootEmbedding().getArray().getValues());

        // Assigned exactly once, directly through `this`, so the field can stay final.
        DeprelLabelerModel labelerModel = model.getLabelerModel();
        this.deprelLabeler = labelerModel != null
                ? new DeprelLabeler(labelerModel, false, 0, 4, null)
                : null;
    }

    /**
     * Parses a sentence into a morpho-syntactic sentence.
     *
     * @param sentence the sentence to parse; must not be {@code null}
     * @return the sentence combined with its predicted dependency tree
     */
    @Override // com.kotlinnlp.neuralparser.NeuralParser
    @NotNull
    public MorphoSyntacticSentence parse(@NotNull ParsingSentence sentence) {
        Intrinsics.checkParameterIsNotNull(sentence, "sentence");
        LatentSyntacticStructure lss = this.lssEncoder.encode(sentence);
        ArcScores scores = new CosineDecoder().decode(lss);
        DependencyTree tree = buildDependencyTree(lss, scores);
        return sentence.toMorphoSyntacticSentence(tree, getModel().getMorphoDeprelSelector());
    }

    /**
     * Builds the dependency tree in three steps: heads, then cycle repair, then labels.
     *
     * @param lss the encoded sentence; {@code lss.getSize()} gives the number of tree elements
     * @param arcScores the arc scores computed from {@code lss}
     * @return the completed tree
     */
    private DependencyTree buildDependencyTree(LatentSyntacticStructure lss, ArcScores arcScores) {
        DependencyTree tree = new DependencyTree(lss.getSize());
        assignHeads(tree, arcScores);
        fixCycles(tree, arcScores);
        assignLabels(tree, lss);
        return tree;
    }

    /**
     * Attaches tokens to heads greedily.
     *
     * <p>The highest-scoring top token only has its attachment score recorded; no arc is set for
     * it here. Every other token key in {@code arcScores} is attached to its highest-scoring head.
     *
     * @param dependencyTree the tree to fill
     * @param arcScores the arc scores, keyed by dependent id
     * @throws NullPointerException (Kotlin NPE, via {@code Intrinsics.throwNpe()}) if no head
     *     candidate is found for a dependent
     */
    private void assignHeads(@NotNull DependencyTree dependencyTree, ArcScores arcScores) {
        Pair<Integer, Double> top = arcScores.findHighestScoringTop();
        int topId = top.getFirst();
        dependencyTree.setAttachmentScore(topId, top.getSecond());

        for (int depId : arcScores.keySet()) {
            if (depId == topId) {
                continue;
            }
            // NOTE(review): listOf(-1) presumably lists governors to exclude, -1 being the root;
            // confirm in ArcScores.findHighestScoringHead.
            Pair<Integer, Double> head = arcScores.findHighestScoringHead(depId, CollectionsKt.listOf(-1));
            if (head == null) {
                Intrinsics.throwNpe();
            }
            // Mask 12 (bits 2 and 3) makes the deprel and POS-tag arguments take their Kotlin
            // defaults; the nulls are placeholders.
            // NOTE(review): the `true` flag is probably "allow cycles", since fixCycles runs
            // next; confirm against DependencyTree.setArc.
            DependencyTree.setArc$default(dependencyTree, depId, head.getFirst(), (Deprel) null,
                    (POSTag) null, head.getSecond(), true, 12, (Object) null);
        }
    }

    /**
     * Removes cycles left by the greedy head assignment, delegating to {@link CyclesFixer}.
     *
     * @param dependencyTree the tree to repair in place
     * @param arcScores the arc scores used to choose replacement arcs
     */
    private void fixCycles(@NotNull DependencyTree dependencyTree, ArcScores arcScores) {
        new CyclesFixer(dependencyTree, arcScores).fixCycles();
    }

    /**
     * Assigns a dependency relation to each token, if this parser has a labeler.
     *
     * <p>For the token at index {@code i}, the labeler's {@code i}-th prediction is handed to the
     * model's {@link MorphoDeprelSelector}, together with the sentence, {@code i}, and the
     * token's current head.
     *
     * @param dependencyTree the tree whose heads are already assigned
     * @param lss the encoded sentence
     */
    @SuppressWarnings({"rawtypes", "unchecked"}) // predictions are forwarded as raw Lists
    private void assignLabels(@NotNull DependencyTree dependencyTree, LatentSyntacticStructure lss) {
        MorphoDeprelSelector selector = getModel().getMorphoDeprelSelector();
        DeprelLabeler labeler = this.deprelLabeler;
        if (labeler == null) {
            return;
        }
        int tokenIndex = 0;
        for (Object prediction : labeler.predict(new DeprelLabeler.Input(lss, dependencyTree))) {
            dependencyTree.setDeprel(tokenIndex, selector.getBestDeprel((List) prediction,
                    lss.getSentence(), tokenIndex, dependencyTree.getHeads()[tokenIndex]));
            tokenIndex++;
        }
    }

    /**
     * Returns the model this parser was built from.
     *
     * @return the model this parser was built from
     */
    @Override // com.kotlinnlp.neuralparser.NeuralParser
    @NotNull
    public LHRModel getModel() {
        return this.model;
    }
}
