package com.kotlinnlp.neuralparser.parsers.arcstandard.atpdjoint2;

import com.kotlinnlp.dependencytree.Deprel;
import com.kotlinnlp.neuralparser.parsers.arcstandard.tpdjoint.ArcStandardTPDJointActionsScorer;
import com.kotlinnlp.neuralparser.templates.inputcontexts.TokensAmbiguousPOSContext;
import com.kotlinnlp.neuralparser.templates.parsers.birnn.ambiguouspos.BiRNNAmbiguousPOSParser;
import com.kotlinnlp.neuralparser.templates.supportstructure.OutputErrorsInit;
import com.kotlinnlp.neuralparser.templates.supportstructure.compositeprediction.ATPDJointStructureFactory;
import com.kotlinnlp.neuralparser.templates.supportstructure.compositeprediction.ATPDJointSupportStructure;
import com.kotlinnlp.neuralparser.utils.actionsembeddings.ActionsVectorsOptimizer;
import com.kotlinnlp.neuralparser.utils.features.DenseFeatures;
import com.kotlinnlp.neuralparser.utils.features.DenseFeaturesErrors;
import com.kotlinnlp.neuralparser.utils.items.DenseItem;
import com.kotlinnlp.simplednn.core.functionalities.activations.Softmax;
import com.kotlinnlp.simplednn.core.functionalities.regularization.WeightsRegularization;
import com.kotlinnlp.simplednn.core.functionalities.updatemethods.adam.ADAMMethod;
import com.kotlinnlp.simplednn.core.layers.LayerConfiguration;
import com.kotlinnlp.simplednn.core.optimizer.ParamsOptimizer;
import com.kotlinnlp.syntaxdecoder.modules.bestactionselector.HighestScoreActionSelector;
import com.kotlinnlp.syntaxdecoder.modules.bestactionselector.MultliActionsSelectorByScore;
import com.kotlinnlp.syntaxdecoder.transitionsystem.ActionsGenerator;
import com.kotlinnlp.syntaxdecoder.transitionsystem.Transition;
import com.kotlinnlp.syntaxdecoder.transitionsystem.models.arcstandard.ArcStandard;
import com.kotlinnlp.syntaxdecoder.transitionsystem.models.arcstandard.ArcStandardTransition;
import com.kotlinnlp.syntaxdecoder.transitionsystem.state.templates.StackBufferState;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Set;
import kotlin.Metadata;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;

/* compiled from: BiRNNATPDJoint2ArcStandardParser.kt */
@Metadata(mv = {1, 1, 8}, bv = {1, 0, 2}, k = 1, d1 = {"��|\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u0006\n��\n\u0002\u0010\b\n\u0002\b\u0003\n\u0002\u0010\u000b\n��\n\u0002\u0010\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\u0018��2&\u0012\u0004\u0012\u00020\u0002\u0012\u0004\u0012\u00020\u0003\u0012\u0004\u0012\u00020\u0004\u0012\u0004\u0012\u00020\u0005\u0012\u0004\u0012\u00020\u0006\u0012\u0004\u0012\u00020\u00070\u0001B+\u0012\u0006\u0010\b\u001a\u00020\u0007\u0012\b\b\u0002\u0010\t\u001a\u00020\n\u0012\b\b\u0002\u0010\u000b\u001a\u00020\f\u0012\b\b\u0002\u0010\r\u001a\u00020\f¢\u0006\u0002\u0010\u000eJ(\u0010\u0011\u001a\u00020\u00122\u0016\u0010\u0013\u001a\u00120\u0014R\u000e\u0012\u0004\u0012\u00020\u0003\u0012\u0004\u0012\u00020\u00020\u00152\u0006\u0010\u0016\u001a\u00020\u0017H\u0014J\u0014\u0010\u0018\u001a\u000e\u0012\u0004\u0012\u00020\u0002\u0012\u0004\u0012\u00020\u00030\u0019H\u0014J\b\u0010\u001a\u001a\u00020\u001bH\u0014J \u0010\u001c\u001a\u001a\u0012\u0004\u0012\u00020\u0002\u0012\u0004\u0012\u00020\u0003\u0012\u0004\u0012\u00020\u001e\u0012\u0004\u0012\u00020\u00170\u001dH\u0014J\b\u0010\u001f\u001a\u00020 H\u0014J \u0010!\u001a\u001a\u0012\u0004\u0012\u00020\u0002\u0012\u0004\u0012\u00020\u0003\u0012\u0004\u0012\u00020\u001e\u0012\u0004\u0012\u00020\u00170\"H\u0014J\b\u0010#\u001a\u00020$H\u0014J\b\u0010%\u001a\u00020&H\u0014R\u000e\u0010\u000f\u001a\u00020\u0010X\u0082\u0004¢\u0006\u0002\n��¨\u0006'"}, d2 = {"Lcom/kotlinnlp/neuralparser/parsers/arcstandard/atpdjoint2/BiRNNATPDJoint2ArcStandardParser;", 
"Lcom/kotlinnlp/neuralparser/templates/parsers/birnn/ambiguouspos/BiRNNAmbiguousPOSParser;", "Lcom/kotlinnlp/syntaxdecoder/transitionsystem/state/templates/StackBufferState;", "Lcom/kotlinnlp/syntaxdecoder/transitionsystem/models/arcstandard/ArcStandardTransition;", "Lcom/kotlinnlp/neuralparser/utils/features/DenseFeaturesErrors;", "Lcom/kotlinnlp/neuralparser/utils/features/DenseFeatures;", "Lcom/kotlinnlp/neuralparser/templates/supportstructure/compositeprediction/ATPDJointSupportStructure;", "Lcom/kotlinnlp/neuralparser/parsers/arcstandard/atpdjoint2/BiRNNATPDJoint2ArcStandardParserModel;", "model", "wordDropoutCoefficient", "", "beamSize", "", "maxParallelThreads", "(Lcom/kotlinnlp/neuralparser/parsers/arcstandard/atpdjoint2/BiRNNATPDJoint2ArcStandardParserModel;DII)V", "useSoftmaxOutput", "", "beforeApplyAction", "", "action", "Lcom/kotlinnlp/syntaxdecoder/transitionsystem/Transition$Action;", "Lcom/kotlinnlp/syntaxdecoder/transitionsystem/Transition;", "context", "Lcom/kotlinnlp/neuralparser/templates/inputcontexts/TokensAmbiguousPOSContext;", "buildActionsGenerator", "Lcom/kotlinnlp/syntaxdecoder/transitionsystem/ActionsGenerator$MorphoSyntacticLabeled;", "buildActionsScorer", "Lcom/kotlinnlp/neuralparser/parsers/arcstandard/tpdjoint/ArcStandardTPDJointActionsScorer;", "buildBestActionSelector", "Lcom/kotlinnlp/syntaxdecoder/modules/bestactionselector/HighestScoreActionSelector;", "Lcom/kotlinnlp/neuralparser/utils/items/DenseItem;", "buildFeaturesExtractor", "Lcom/kotlinnlp/neuralparser/parsers/arcstandard/atpdjoint2/ArcStandardATPDFeaturesExtractor;", "buildMultiActionsSelector", "Lcom/kotlinnlp/syntaxdecoder/modules/bestactionselector/MultliActionsSelectorByScore;", "buildSupportStructureFactory", "Lcom/kotlinnlp/neuralparser/templates/supportstructure/compositeprediction/ATPDJointStructureFactory;", "buildTransitionSystem", "Lcom/kotlinnlp/syntaxdecoder/transitionsystem/models/arcstandard/ArcStandard;", "neuralparser"})
/* loaded from: input_file:com/kotlinnlp/neuralparser/parsers/arcstandard/atpdjoint2/BiRNNATPDJoint2ArcStandardParser.class */
public final class BiRNNATPDJoint2ArcStandardParser extends BiRNNAmbiguousPOSParser<StackBufferState, ArcStandardTransition, DenseFeaturesErrors, DenseFeatures, ATPDJointSupportStructure, BiRNNATPDJoint2ArcStandardParserModel> {

    /**
     * Whether the last layer of the transition scorer network uses a {@link Softmax} activation.
     * Selects the output-errors initialization used by {@link #buildSupportStructureFactory()}.
     */
    private final boolean useSoftmaxOutput;

    /**
     * Returns the parser model cast to its concrete type.
     *
     * <p>The decompiled code cast every {@code getModel()} result; this centralizes that cast.
     */
    private BiRNNATPDJoint2ArcStandardParserModel typedModel() {
        return (BiRNNATPDJoint2ArcStandardParserModel) getModel();
    }

    /**
     * Creates a new ADAM update method with the configuration used by every optimizer of this parser.
     *
     * <p>This is the Kotlin synthetic default-arguments constructor call: the arguments
     * {@code 0.001, 0.9, 0.999} are passed explicitly, while the bitmask {@code 24} (bits 3 and 4)
     * asks for the declared defaults of the 4th and 5th parameters. The {@code 0.0} and
     * {@code null} values are therefore placeholders (same mask convention as the synthetic
     * constructor of this class).
     *
     * <p>A fresh instance is returned on each call, as each optimizer received its own instance
     * in the original code.
     */
    private static ADAMMethod newAdamMethod() {
        return new ADAMMethod(0.001d, 0.9d, 0.999d, 0.0d, (WeightsRegularization) null, 24, (DefaultConstructorMarker) null);
    }

    /**
     * Builds the transition system (Kotlin name: {@code buildTransitionSystem}).
     *
     * @return a new {@link ArcStandard} transition system
     */
    @Override
    @NotNull
    public ArcStandard mo7buildTransitionSystem() {
        return new ArcStandard();
    }

    /**
     * Builds the actions generator (Kotlin name: {@code buildActionsGenerator}).
     *
     * <p>The deprel tags of the corpus dictionary ({@code getElementsReversedSet()}) are grouped by
     * their {@link Deprel.Position} direction. Insertion order of both directions and deprels is
     * preserved. The grouping is passed together with the dictionary's deprel/POS-tag
     * combinations.
     *
     * @return a new morpho-syntactic labeled actions generator
     */
    @Override
    @NotNull
    public ActionsGenerator.MorphoSyntacticLabeled<StackBufferState, ArcStandardTransition> mo8buildActionsGenerator() {
        BiRNNATPDJoint2ArcStandardParserModel model = typedModel();
        Set<?> reversedDeprels = model.getCorpusDictionary().getDeprelTags().getElementsReversedSet();
        LinkedHashMap<Deprel.Position, List<Deprel>> deprelsByDirection = new LinkedHashMap<>();

        for (Object element : reversedDeprels) {
            Deprel deprel = (Deprel) element;
            Deprel.Position direction = deprel.getDirection();
            List<Deprel> group = deprelsByDirection.get(direction);
            if (group == null) {
                group = new ArrayList<>();
                deprelsByDirection.put(direction, group);
            }
            group.add(deprel);
        }

        return new ActionsGenerator.MorphoSyntacticLabeled<>(deprelsByDirection, model.getCorpusDictionary().getDeprelPosTagCombinations());
    }

    /**
     * Builds the actions scorer (Kotlin name: {@code buildActionsScorer}).
     *
     * <p>The transition scorer network and the POS-deprel scorer network each get their own
     * {@link ParamsOptimizer}, backed by a separate ADAM instance.
     *
     * @return a new {@link ArcStandardTPDJointActionsScorer}
     */
    @Override
    @NotNull
    public ArcStandardTPDJointActionsScorer mo9buildActionsScorer() {
        BiRNNATPDJoint2ArcStandardParserModel model = typedModel();
        return new ArcStandardTPDJointActionsScorer(
            model.getActionsScoresActivation(),
            model.getTransitionScorerNetwork(),
            model.getPosDeprelScorerNetworkModel(),
            new ParamsOptimizer(model.getTransitionScorerNetwork().getModel(), newAdamMethod()),
            new ParamsOptimizer(model.getPosDeprelScorerNetworkModel().getParams(), newAdamMethod()),
            model.getCorpusDictionary().getDeprelTags(),
            model.getCorpusDictionary().getPosTags());
    }

    /**
     * Builds the support structure factory.
     *
     * <p>Output errors are initialized with {@link OutputErrorsInit#AllErrors} when the transition
     * scorer ends with a Softmax layer, and with {@link OutputErrorsInit#AllZeros} otherwise.
     *
     * @return a new {@link ATPDJointStructureFactory}
     */
    @Override
    @NotNull
    public ATPDJointStructureFactory buildSupportStructureFactory() {
        BiRNNATPDJoint2ArcStandardParserModel model = typedModel();
        OutputErrorsInit errorsInit = this.useSoftmaxOutput ? OutputErrorsInit.AllErrors : OutputErrorsInit.AllZeros;
        return new ATPDJointStructureFactory(
            model.getAppliedActionsNetwork(),
            model.getTransitionScorerNetwork(),
            model.getPosDeprelScorerNetworkModel(),
            errorsInit);
    }

    /**
     * Builds the features extractor (Kotlin name: {@code buildFeaturesExtractor}).
     *
     * <p>The actions vectors and the applied-actions network each get their own optimizer. The
     * size of the applied-actions network's last layer is passed as the last argument.
     *
     * @return a new {@link ArcStandardATPDFeaturesExtractor}
     */
    @Override
    @NotNull
    public ArcStandardATPDFeaturesExtractor mo10buildFeaturesExtractor() {
        BiRNNATPDJoint2ArcStandardParserModel model = typedModel();
        return new ArcStandardATPDFeaturesExtractor(
            model.getActionsVectors(),
            new ActionsVectorsOptimizer(model.getActionsVectors(), newAdamMethod()),
            new ParamsOptimizer(model.getAppliedActionsNetwork().getModel(), newAdamMethod()),
            model.getCorpusDictionary().getDeprelTags(),
            model.getCorpusDictionary().getPosTags(),
            ((LayerConfiguration) CollectionsKt.last(model.getAppliedActionsNetwork().getLayersConfiguration())).getSize());
    }

    /**
     * Builds the best-action selector (Kotlin name: {@code buildBestActionSelector}).
     *
     * @return a new {@link HighestScoreActionSelector}
     */
    @Override
    @NotNull
    public HighestScoreActionSelector<StackBufferState, ArcStandardTransition, DenseItem, TokensAmbiguousPOSContext> mo11buildBestActionSelector() {
        return new HighestScoreActionSelector<>();
    }

    /**
     * Builds the multi-actions selector (Kotlin name: {@code buildMultiActionsSelector}).
     *
     * @return a new {@link MultliActionsSelectorByScore}
     */
    @Override
    @NotNull
    public MultliActionsSelectorByScore<StackBufferState, ArcStandardTransition, DenseItem, TokensAmbiguousPOSContext> mo12buildMultiActionsSelector() {
        return new MultliActionsSelectorByScore<>();
    }

    /**
     * Hook called before an action is applied (Kotlin name: {@code beforeApplyAction}).
     *
     * <p>This parser performs no work here beyond the Kotlin non-null parameter checks.
     *
     * @param action the action about to be applied
     * @param tokensAmbiguousPOSContext the input context of the sentence being parsed
     */
    protected void beforeApplyAction2(@NotNull Transition<ArcStandardTransition, StackBufferState>.Action action, @NotNull TokensAmbiguousPOSContext tokensAmbiguousPOSContext) {
        Intrinsics.checkParameterIsNotNull(action, "action");
        Intrinsics.checkParameterIsNotNull(tokensAmbiguousPOSContext, "context");
    }

    /**
     * Bridge method for the generic {@code beforeApplyAction} override; delegates to
     * {@link #beforeApplyAction2}.
     */
    @Override
    public /* bridge */ /* synthetic */ void beforeApplyAction(Transition.Action action, TokensAmbiguousPOSContext tokensAmbiguousPOSContext) {
        beforeApplyAction2((Transition<ArcStandardTransition, StackBufferState>.Action) action, tokensAmbiguousPOSContext);
    }

    /**
     * Creates the parser.
     *
     * @param model the parser model
     * @param wordDropoutCoefficient the word dropout coefficient, forwarded to the superclass
     * @param beamSize the beam size, forwarded to the superclass
     * @param maxParallelThreads the maximum number of parallel threads, forwarded to the superclass
     */
    public BiRNNATPDJoint2ArcStandardParser(@NotNull BiRNNATPDJoint2ArcStandardParserModel model, double wordDropoutCoefficient, int beamSize, int maxParallelThreads) {
        super(model, wordDropoutCoefficient, beamSize, maxParallelThreads);
        Intrinsics.checkParameterIsNotNull(model, "model");
        LayerConfiguration outputLayer = (LayerConfiguration) CollectionsKt.last(typedModel().getTransitionScorerNetwork().getLayersConfiguration());
        // Type check instead of equals(new Softmax()): a fresh instance only compares equal if
        // Softmax overrides equals, otherwise the flag would always be false.
        this.useSoftmaxOutput = outputLayer.getActivationFunction() instanceof Softmax;
    }

    /**
     * Kotlin synthetic default-arguments constructor.
     *
     * <p>Each set bit of {@code mask} selects the default for the corresponding parameter:
     * bit 1 gives {@code wordDropoutCoefficient = 0.0}, bit 2 gives {@code beamSize = 1}, and
     * bit 3 gives {@code maxParallelThreads = 1}.
     */
    public /* synthetic */ BiRNNATPDJoint2ArcStandardParser(BiRNNATPDJoint2ArcStandardParserModel model, double wordDropoutCoefficient, int beamSize, int maxParallelThreads, int mask, DefaultConstructorMarker marker) {
        this(model,
            (mask & 2) != 0 ? 0.0d : wordDropoutCoefficient,
            (mask & 4) != 0 ? 1 : beamSize,
            (mask & 8) != 0 ? 1 : maxParallelThreads);
    }
}
