package com.kotlinnlp.neuralparser.parsers.arcstandard.atpdjoint;

import com.kotlinnlp.dependencytree.Deprel;
import com.kotlinnlp.dependencytree.POSTag;
import com.kotlinnlp.neuralparser.parsers.arcstandard.atpdjoint.actionsembeddings.ActionsEmbeddingsMap;
import com.kotlinnlp.neuralparser.parsers.arcstandard.atpdjoint.actionsembeddings.ActionsEmbeddingsOptimizer;
import com.kotlinnlp.neuralparser.templates.featuresextractor.TokensWindowFeaturesExtractorTrainable;
import com.kotlinnlp.neuralparser.templates.inputcontexts.TokensAmbiguousPOSContext;
import com.kotlinnlp.neuralparser.templates.inputcontexts.TokensEncodingContext;
import com.kotlinnlp.neuralparser.templates.supportstructure.compositeprediction.ATPDJointSupportStructure;
import com.kotlinnlp.neuralparser.utils.features.DenseFeatures;
import com.kotlinnlp.neuralparser.utils.features.DenseFeaturesErrors;
import com.kotlinnlp.neuralparser.utils.items.DenseItem;
import com.kotlinnlp.simplednn.core.neuralnetwork.NetworkParameters;
import com.kotlinnlp.simplednn.core.neuralprocessor.recurrent.RecurrentNeuralProcessor;
import com.kotlinnlp.simplednn.core.optimizer.ParamsOptimizer;
import com.kotlinnlp.simplednn.deeplearning.embeddings.Embedding;
import com.kotlinnlp.simplednn.simplemath.SimplemathKt;
import com.kotlinnlp.simplednn.simplemath.ndarray.Shape;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArrayFactory;
import com.kotlinnlp.simplednn.utils.DictionarySet;
import com.kotlinnlp.syntaxdecoder.modules.featuresextractor.features.Features;
import com.kotlinnlp.syntaxdecoder.modules.supportstructure.DecodingSupportStructure;
import com.kotlinnlp.syntaxdecoder.syntax.DependencyRelation;
import com.kotlinnlp.syntaxdecoder.transitionsystem.Transition;
import com.kotlinnlp.syntaxdecoder.transitionsystem.models.arcstandard.ArcStandardTransition;
import com.kotlinnlp.syntaxdecoder.transitionsystem.models.arcstandard.transitions.ArcLeft;
import com.kotlinnlp.syntaxdecoder.transitionsystem.models.arcstandard.transitions.ArcRight;
import com.kotlinnlp.syntaxdecoder.transitionsystem.models.arcstandard.transitions.Root;
import com.kotlinnlp.syntaxdecoder.transitionsystem.models.arcstandard.transitions.Shift;
import com.kotlinnlp.syntaxdecoder.transitionsystem.state.StateView;
import com.kotlinnlp.syntaxdecoder.transitionsystem.state.templates.StackBufferState;
import com.kotlinnlp.syntaxdecoder.utils.DecodingContext;
import com.kotlinnlp.syntaxdecoder.utils.ExtensionsKt;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin.TypeCastException;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: ArcStandardATPDFeaturesExtractor.kt */
@Metadata(mv = {1, 1, 8}, bv = {1, 0, 2}, k = 1, d1 = {"��\u0092\u0001\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\b\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010!\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\n\n\u0002\u0010\u0002\n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010 \n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0010\u000b\n\u0002\b\r\u0018��2,\u0012\u0004\u0012\u00020\u0002\u0012\u0004\u0012\u00020\u0003\u0012\u0004\u0012\u00020\u0004\u0012\u0004\u0012\u00020\u0005\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00020\u0006\u0012\u0004\u0012\u00020\u00070\u0001BG\u0012\u0006\u0010\b\u001a\u00020\t\u0012\u0006\u0010\n\u001a\u00020\u000b\u0012\u0006\u0010\f\u001a\u00020\r\u0012\f\u0010\u000e\u001a\b\u0012\u0004\u0012\u00020\u00100\u000f\u0012\f\u0010\u0011\u001a\b\u0012\u0004\u0012\u00020\u00130\u0012\u0012\f\u0010\u0014\u001a\b\u0012\u0004\u0012\u00020\u00150\u0012¢\u0006\u0002\u0010\u0016J\u0018\u0010(\u001a\u00020)2\u0006\u0010*\u001a\u00020\u00192\u0006\u0010+\u001a\u00020\rH\u0002JJ\u0010,\u001a\u00020)2$\u0010-\u001a \u0012\u0004\u0012\u00020\u0002\u0012\u0004\u0012\u00020\u0003\u0012\u0004\u0012\u00020\u0004\u0012\u0004\u0012\u00020/\u0012\u0004\u0012\u00020\u00050.2\u001a\u00100\u001a\u0016\u0012\u0012\u0012\u0010\u0012\u0006\u0012\u0004\u0018\u00010\r\u0012\u0004\u0012\u00020\u00190201H\u0002J>\u00103\u001a\u00020)2$\u0010-\u001a 
\u0012\u0004\u0012\u00020\u0002\u0012\u0004\u0012\u00020\u0003\u0012\u0004\u0012\u00020\u0004\u0012\u0004\u0012\u00020/\u0012\u0004\u0012\u00020\u00050.2\u0006\u00104\u001a\u00020\u00072\u0006\u00105\u001a\u000206H\u0016J\b\u00107\u001a\u00020)H\u0002J6\u00108\u001a\u00020\u00052$\u0010-\u001a \u0012\u0004\u0012\u00020\u0002\u0012\u0004\u0012\u00020\u0003\u0012\u0004\u0012\u00020\u0004\u0012\u0004\u0012\u00020/\u0012\u0004\u0012\u00020\u00050.2\u0006\u00104\u001a\u00020\u0007H\u0014J\u001f\u00109\u001a\u00020\u00192\b\u0010:\u001a\u0004\u0018\u00010\r2\u0006\u0010;\u001a\u00020\u0004H\u0014¢\u0006\u0002\u0010<J\u001e\u0010=\u001a\n\u0012\u0006\u0012\u0004\u0018\u00010\r012\f\u0010>\u001a\b\u0012\u0004\u0012\u00020\u00020\u0006H\u0014J\b\u0010?\u001a\u00020)H\u0016J\b\u0010@\u001a\u00020)H\u0016J\b\u0010A\u001a\u00020)H\u0016J\b\u0010B\u001a\u00020)H\u0016R\u000e\u0010\b\u001a\u00020\tX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\n\u001a\u00020\u000bX\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u0017\u001a\b\u0012\u0004\u0012\u00020\u00190\u0018X\u0082.¢\u0006\u0002\n��R\u0014\u0010\u000e\u001a\b\u0012\u0004\u0012\u00020\u00100\u000fX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\f\u001a\u00020\rX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u001a\u001a\u00020\u0019X\u0082\u0004¢\u0006\u0002\n��R&\u0010\u001b\u001a\u001a\u0012\u0016\u0012\u0014\u0018\u00010\u001dR\u000e\u0012\u0004\u0012\u00020\u0003\u0012\u0004\u0012\u00020\u00020\u001e0\u001cX\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u001f\u001a\b\u0012\u0004\u0012\u00020\u00190\u001cX\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u0011\u001a\b\u0012\u0004\u0012\u00020\u00130\u0012X\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u0014\u001a\b\u0012\u0004\u0012\u00020\u00150\u0012X\u0082\u0004¢\u0006\u0002\n��R(\u0010 
\u001a\u00020\r*\u00120\u001dR\u000e\u0012\u0004\u0012\u00020\u0003\u0012\u0004\u0012\u00020\u00020\u001e8BX\u0082\u0004¢\u0006\u0006\u001a\u0004\b!\u0010\"R$\u0010#\u001a\u00020\r*\u000e\u0012\u0004\u0012\u00020\u0003\u0012\u0004\u0012\u00020\u00020\u001e8BX\u0082\u0004¢\u0006\u0006\u001a\u0004\b$\u0010%R(\u0010&\u001a\u00020\r*\u00120\u001dR\u000e\u0012\u0004\u0012\u00020\u0003\u0012\u0004\u0012\u00020\u00020\u001e8BX\u0082\u0004¢\u0006\u0006\u001a\u0004\b'\u0010\"¨\u0006C"}, d2 = {"Lcom/kotlinnlp/neuralparser/parsers/arcstandard/atpdjoint/ArcStandardATPDFeaturesExtractor;", "Lcom/kotlinnlp/neuralparser/templates/featuresextractor/TokensWindowFeaturesExtractorTrainable;", "Lcom/kotlinnlp/syntaxdecoder/transitionsystem/state/templates/StackBufferState;", "Lcom/kotlinnlp/syntaxdecoder/transitionsystem/models/arcstandard/ArcStandardTransition;", "Lcom/kotlinnlp/neuralparser/templates/inputcontexts/TokensAmbiguousPOSContext;", "Lcom/kotlinnlp/neuralparser/utils/features/DenseFeatures;", "Lcom/kotlinnlp/syntaxdecoder/transitionsystem/state/StateView;", "Lcom/kotlinnlp/neuralparser/templates/supportstructure/compositeprediction/ATPDJointSupportStructure;", "actionsEmbeddings", "Lcom/kotlinnlp/neuralparser/parsers/arcstandard/atpdjoint/actionsembeddings/ActionsEmbeddingsMap;", "actionsEmbeddingsOptimizer", "Lcom/kotlinnlp/neuralparser/parsers/arcstandard/atpdjoint/actionsembeddings/ActionsEmbeddingsOptimizer;", "actionsEncodingSize", "", "actionsEncoderOptimizer", "Lcom/kotlinnlp/simplednn/core/optimizer/ParamsOptimizer;", "Lcom/kotlinnlp/simplednn/core/neuralnetwork/NetworkParameters;", "deprelTags", "Lcom/kotlinnlp/simplednn/utils/DictionarySet;", "Lcom/kotlinnlp/dependencytree/Deprel;", "posTags", "Lcom/kotlinnlp/dependencytree/POSTag;", 
"(Lcom/kotlinnlp/neuralparser/parsers/arcstandard/atpdjoint/actionsembeddings/ActionsEmbeddingsMap;Lcom/kotlinnlp/neuralparser/parsers/arcstandard/atpdjoint/actionsembeddings/ActionsEmbeddingsOptimizer;ILcom/kotlinnlp/simplednn/core/optimizer/ParamsOptimizer;Lcom/kotlinnlp/simplednn/utils/DictionarySet;Lcom/kotlinnlp/simplednn/utils/DictionarySet;)V", "actionsEncoder", "Lcom/kotlinnlp/simplednn/core/neuralprocessor/recurrent/RecurrentNeuralProcessor;", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "appliedActionZerosErrors", "appliedActions", "", "Lcom/kotlinnlp/syntaxdecoder/transitionsystem/Transition$Action;", "Lcom/kotlinnlp/syntaxdecoder/transitionsystem/Transition;", "appliedActionsEncodingErrors", "deprelKey", "getDeprelKey", "(Lcom/kotlinnlp/syntaxdecoder/transitionsystem/Transition$Action;)I", "key", "getKey", "(Lcom/kotlinnlp/syntaxdecoder/transitionsystem/Transition;)I", "posTagKey", "getPosTagKey", "accumulateActionEmbeddingErrors", "", "errors", "actionIndex", "accumulateItemsErrors", "decodingContext", "Lcom/kotlinnlp/syntaxdecoder/utils/DecodingContext;", "Lcom/kotlinnlp/neuralparser/utils/items/DenseItem;", "itemsErrors", "", "Lkotlin/Pair;", "backward", "supportStructure", "propagateToInput", "", "backwardActionsEncoding", "extract", "getTokenEncoding", "tokenId", "context", "(Ljava/lang/Integer;Lcom/kotlinnlp/neuralparser/templates/inputcontexts/TokensAmbiguousPOSContext;)Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "getTokensWindow", "stateView", "newBatch", "newEpoch", "newExample", "update", "neuralparser"})
/* loaded from: input_file:com/kotlinnlp/neuralparser/parsers/arcstandard/atpdjoint/ArcStandardATPDFeaturesExtractor.class */
/**
 * Trainable features extractor of the Arc-Standard ATPD joint parser.
 *
 * <p>Each {@link DenseFeatures} it extracts is the vertical concatenation of:
 * <ol>
 *   <li>the output of the recurrent actions encoder (taken from the support structure) fed with the
 *       embedding of the last applied action, or with the null embedding when no action has been
 *       applied yet in the current sentence;</li>
 *   <li>the view features computed by {@code extractViewFeatures} on the current state, using the
 *       tokens selected by {@link #getTokensWindow}.</li>
 * </ol>
 *
 * <p>In training mode the actions applied in the current sequence are buffered together with the
 * errors of their encodings. The buffer is back-propagated through the actions encoder, and from
 * there into the actions embeddings, when a new sequence starts or when {@link #update()} is called.
 *
 * <p>This class was decompiled from Kotlin ({@code ArcStandardATPDFeaturesExtractor.kt}): the
 * {@code Intrinsics} calls and the {@code $default} helpers are the compiled forms of Kotlin
 * null-safety checks and default arguments.
 */
public final class ArcStandardATPDFeaturesExtractor extends TokensWindowFeaturesExtractorTrainable<StackBufferState, ArcStandardTransition, TokensAmbiguousPOSContext, DenseFeatures, StateView<StackBufferState>, ATPDJointSupportStructure> {

    /** Actions recorded in training mode for the current sequence; a {@code null} entry marks the sequence start. */
    private final List<Transition<ArcStandardTransition, StackBufferState>.Action> appliedActions;

    /** Errors of the actions-encoder outputs, aligned by index with {@link #appliedActions}. */
    private final List<DenseNDArray> appliedActionsEncodingErrors;

    /** Zero errors used for a step whose {@code backward} has not been called. */
    private final DenseNDArray appliedActionZerosErrors;

    /** The actions encoder used by the last {@code extract} call; read by {@link #backwardActionsEncoding()}. */
    private RecurrentNeuralProcessor<DenseNDArray> actionsEncoder;

    private final ActionsEmbeddingsMap actionsEmbeddings;
    private final ActionsEmbeddingsOptimizer actionsEmbeddingsOptimizer;

    /** Size of the actions-encoder output, i.e. of the leading part of each features vector. */
    private final int actionsEncodingSize;

    private final ParamsOptimizer<NetworkParameters> actionsEncoderOptimizer;
    private final DictionarySet<Deprel> deprelTags;
    private final DictionarySet<POSTag> posTags;

    /**
     * Maps a transition to the key of its type in the actions embeddings map.
     *
     * @param transition an Arc-Standard transition
     * @return 0 for Shift, 1 for Root, 2 for ArcLeft, 3 for ArcRight
     * @throws RuntimeException if the transition is of any other type
     */
    private final int getKey(@NotNull Transition<ArcStandardTransition, StackBufferState> transition) {
        if (transition instanceof Shift) {
            return 0;
        }
        if (transition instanceof Root) {
            return 1;
        }
        if (transition instanceof ArcLeft) {
            return 2;
        }
        if (transition instanceof ArcRight) {
            return 3;
        }
        throw new RuntimeException("unknown transition");
    }

    /**
     * Maps an action to the key of its deprel in the actions embeddings map.
     *
     * @param action an applied action
     * @return 0 for a Shift action, otherwise the id of its deprel in {@link #deprelTags} plus 1
     * @throws RuntimeException if a non-Shift action is not a {@link DependencyRelation}
     */
    private final int getDeprelKey(@NotNull Transition<ArcStandardTransition, StackBufferState>.Action action) {
        if (action.getTransition() instanceof Shift) {
            return 0;
        }
        if (!(action instanceof DependencyRelation)) {
            throw new RuntimeException("unknown action");
        }
        Deprel deprel = ((DependencyRelation) action).getDeprel();
        if (deprel == null) {
            Intrinsics.throwNpe();
        }
        Integer id = this.deprelTags.getId(deprel);
        if (id == null) {
            Intrinsics.throwNpe();
        }
        // Key 0 is reserved for Shift, so dictionary ids are shifted by one.
        return id.intValue() + 1;
    }

    /**
     * Maps an action to the key of its POS tag in the actions embeddings map.
     *
     * @param action an applied action
     * @return 0 for a Shift action, otherwise the id of its POS tag in {@link #posTags} plus 1
     * @throws RuntimeException if a non-Shift action is not a {@link DependencyRelation}
     */
    private final int getPosTagKey(@NotNull Transition<ArcStandardTransition, StackBufferState>.Action action) {
        if (action.getTransition() instanceof Shift) {
            return 0;
        }
        if (!(action instanceof DependencyRelation)) {
            throw new RuntimeException("unknown action");
        }
        POSTag posTag = ((DependencyRelation) action).getPosTag();
        if (posTag == null) {
            Intrinsics.throwNpe();
        }
        Integer id = this.posTags.getId(posTag);
        if (id == null) {
            Intrinsics.throwNpe();
        }
        // Key 0 is reserved for Shift, so dictionary ids are shifted by one.
        return id.intValue() + 1;
    }

    /**
     * Extracts the features of the current state of the decoding context.
     *
     * <p>In training mode this also records the last applied action (or {@code null} at a sequence
     * start). Before a new sequence is started, it back-propagates the previous one.
     *
     * @param decodingContext the decoding context of the current state
     * @param supportStructure the support structure that provides the actions encoder
     * @return the action encoding concatenated with the view features
     */
    @NotNull
    protected DenseFeatures extract(@NotNull DecodingContext<StackBufferState, ArcStandardTransition, TokensAmbiguousPOSContext, DenseItem, DenseFeatures> decodingContext, @NotNull ATPDJointSupportStructure supportStructure) {
        Intrinsics.checkParameterIsNotNull(decodingContext, "decodingContext");
        Intrinsics.checkParameterIsNotNull(supportStructure, "supportStructure");

        boolean trainingMode = ((TokensAmbiguousPOSContext) decodingContext.getExtendedState().getContext()).getTrainingMode();

        // The backward of the previous step was not called: its encoding gets zero errors.
        if (trainingMode && this.appliedActions.size() > this.appliedActionsEncodingErrors.size()) {
            this.appliedActionsEncodingErrors.add(this.appliedActionZerosErrors);
        }

        List contextActions = decodingContext.getExtendedState().getAppliedActions();
        Embedding embedding;

        if (contextActions.isEmpty()) {
            // No action applied yet: a new sequence starts.
            if (trainingMode) {
                if (!this.appliedActions.isEmpty()) {
                    // Flush the buffered previous sequence before starting a new one.
                    backwardActionsEncoding();
                }
                this.appliedActions.add(null);
            }
            embedding = this.actionsEmbeddings.getNullEmbedding();
        } else {
            Transition<ArcStandardTransition, StackBufferState>.Action lastAction = (Transition.Action) CollectionsKt.last(contextActions);
            if (trainingMode) {
                this.appliedActions.add(lastAction);
            }
            embedding = this.actionsEmbeddings.get(getKey(lastAction.getTransition()), getPosTagKey(lastAction), getDeprelKey(lastAction));
        }

        // Keep a reference to the encoder so that backwardActionsEncoding() can run it backward later.
        RecurrentNeuralProcessor<DenseNDArray> encoder = supportStructure.getActionProcessor();
        this.actionsEncoder = encoder;

        // NOTE(review): the third argument presumably flags the first state of a new sequence — confirm
        // against RecurrentNeuralProcessor.forward; mask 12 leaves the last two arguments at their defaults.
        DenseNDArray actionEncoding = RecurrentNeuralProcessor.forward$default(encoder, embedding.getArray().getValues(), contextActions.isEmpty(), false, false, 12, (Object) null);
        DenseNDArray viewFeatures = extractViewFeatures(new StateView(decodingContext.getExtendedState().getState()), (TokensEncodingContext) decodingContext.getExtendedState().getContext());

        return new DenseFeatures(SimplemathKt.concatVectorsV(new DenseNDArray[]{actionEncoding, viewFeatures}));
    }

    /** Synthetic bridge that forwards to the typed {@code extract}. */
    public /* bridge */ /* synthetic */ Features extract(DecodingContext decodingContext, DecodingSupportStructure decodingSupportStructure) {
        return extract((DecodingContext<StackBufferState, ArcStandardTransition, TokensAmbiguousPOSContext, DenseItem, DenseFeatures>) decodingContext, (ATPDJointSupportStructure) decodingSupportStructure);
    }

    /**
     * Returns the encoding of the given token, delegating to the context.
     *
     * <p>Declared protected in the Kotlin source; its access was widened by the decompiler.
     *
     * @param tokenId the token id, possibly {@code null}
     * @param context the tokens context
     * @return the token encoding provided by {@code context.getTokenEncoding(tokenId)}
     */
    @Override // com.kotlinnlp.neuralparser.templates.featuresextractor.TokensWindowFeaturesExtractorTrainable
    @NotNull
    public DenseNDArray getTokenEncoding(@Nullable Integer tokenId, @NotNull TokensAmbiguousPOSContext context) {
        Intrinsics.checkParameterIsNotNull(context, "context");
        return context.getTokenEncoding(tokenId);
    }

    /**
     * Splits the errors of the last extracted features and stores them.
     *
     * <p>The leading {@code actionsEncodingSize} components are buffered as errors of the action
     * encoding. The rest is split per token (chunks of the context encoding size) and accumulated
     * into the tokens of the current window.
     *
     * @param decodingContext the decoding context whose features carry the errors
     * @param supportStructure the support structure (only null-checked here)
     * @param propagateToInput if false, nothing is done
     */
    public void backward(@NotNull DecodingContext<StackBufferState, ArcStandardTransition, TokensAmbiguousPOSContext, DenseItem, DenseFeatures> decodingContext, @NotNull ATPDJointSupportStructure supportStructure, boolean propagateToInput) {
        Intrinsics.checkParameterIsNotNull(decodingContext, "decodingContext");
        Intrinsics.checkParameterIsNotNull(supportStructure, "supportStructure");
        if (!propagateToInput) {
            return;
        }

        List<Integer> tokensWindow = getTokensWindow(new StateView<>(decodingContext.getExtendedState().getState()));
        DenseNDArray featuresErrors = ((DenseFeaturesErrors) ((DenseFeatures) decodingContext.getFeatures()).getErrors()).getArray();
        DenseNDArray[] splitErrors = featuresErrors.splitV(new int[]{this.actionsEncodingSize, featuresErrors.getLength() - this.actionsEncodingSize});

        this.appliedActionsEncodingErrors.add(splitErrors[0]);

        int tokenEncodingSize = ((TokensAmbiguousPOSContext) decodingContext.getExtendedState().getContext()).getEncodingSize();
        accumulateItemsErrors(decodingContext, CollectionsKt.zip(tokensWindow, splitErrors[1].splitV(new int[]{tokenEncodingSize})));
    }

    /** Synthetic bridge that forwards to the typed {@code backward}. */
    public /* bridge */ /* synthetic */ void backward(DecodingContext decodingContext, DecodingSupportStructure decodingSupportStructure, boolean z) {
        backward((DecodingContext<StackBufferState, ArcStandardTransition, TokensAmbiguousPOSContext, DenseItem, DenseFeatures>) decodingContext, (ATPDJointSupportStructure) decodingSupportStructure, z);
    }

    // NOTE(review): the three hooks below notify only actionsEncoderOptimizer, not
    // actionsEmbeddingsOptimizer — confirm this is intended.

    /** Notifies the actions encoder optimizer that a new example starts. */
    public void newExample() {
        this.actionsEncoderOptimizer.newExample();
    }

    /** Notifies the actions encoder optimizer that a new batch starts. */
    public void newBatch() {
        this.actionsEncoderOptimizer.newBatch();
    }

    /** Notifies the actions encoder optimizer that a new epoch starts. */
    public void newEpoch() {
        this.actionsEncoderOptimizer.newEpoch();
    }

    /**
     * Back-propagates any buffered actions and then updates the actions embeddings and the actions
     * encoder parameters.
     */
    public void update() {
        if (!this.appliedActions.isEmpty()) {
            backwardActionsEncoding();
        }
        this.actionsEmbeddingsOptimizer.update();
        this.actionsEncoderOptimizer.update();
    }

    /**
     * Returns the token ids observed by the features window.
     *
     * <p>The window is, in order: stack positions -3, -2, -1 and buffer position 0. Each entry is
     * read with {@code ExtensionsKt.getItemOrNull}, so it is presumably {@code null} when that
     * position does not exist. Negative indices are presumably counted from the stack top; confirm
     * against getItemOrNull.
     *
     * <p>Declared protected in the Kotlin source; its access was widened by the decompiler.
     *
     * @param stateView the view of the current state
     * @return a list of four (possibly null) token ids
     */
    @Override // com.kotlinnlp.neuralparser.templates.featuresextractor.TokensWindowFeaturesExtractorTrainable
    @NotNull
    public List<Integer> getTokensWindow(@NotNull StateView<StackBufferState> stateView) {
        Intrinsics.checkParameterIsNotNull(stateView, "stateView");
        StackBufferState state = stateView.getState();
        return CollectionsKt.listOf(new Integer[]{
            (Integer) ExtensionsKt.getItemOrNull(state.getStack(), -3),
            (Integer) ExtensionsKt.getItemOrNull(state.getStack(), -2),
            (Integer) ExtensionsKt.getItemOrNull(state.getStack(), -1),
            (Integer) ExtensionsKt.getItemOrNull(state.getBuffer(), 0)
        });
    }

    /**
     * Runs the actions encoder backward over the buffered sequence.
     *
     * <p>It accumulates the encoder parameter errors and the per-action embedding errors, then clears
     * both buffers. A missing error for the last step is padded with zeros first.
     *
     * @throws kotlin.UninitializedPropertyAccessException if no {@code extract} call has set the encoder yet
     */
    private final void backwardActionsEncoding() {
        if (this.appliedActions.size() > this.appliedActionsEncodingErrors.size()) {
            this.appliedActionsEncodingErrors.add(this.appliedActionZerosErrors);
        }

        RecurrentNeuralProcessor<DenseNDArray> encoder = this.actionsEncoder;
        if (encoder == null) {
            Intrinsics.throwUninitializedPropertyAccessException("actionsEncoder");
        }

        DenseNDArray[] outputErrors = this.appliedActionsEncodingErrors.toArray(new DenseNDArray[0]);

        // Mask 4 leaves the third optional argument at its default.
        RecurrentNeuralProcessor.backward$default(encoder, outputErrors, true, (List) null, 4, (Object) null);
        ParamsOptimizer.accumulate$default(this.actionsEncoderOptimizer, encoder.getParamsErrors(false), false, 2, (Object) null);

        int actionIndex = 0;
        for (DenseNDArray inputErrors : encoder.getInputSequenceErrors(false)) {
            accumulateActionEmbeddingErrors(inputErrors, actionIndex++);
        }

        this.appliedActions.clear();
        this.appliedActionsEncodingErrors.clear();
    }

    /**
     * Accumulates the errors of one encoder input into the matching action embedding.
     *
     * <p>A {@code null} recorded action, which marks the sequence start, maps to the null embedding.
     *
     * @param errors the errors of the encoder input at {@code actionIndex}
     * @param actionIndex the index of the action in {@link #appliedActions}
     */
    private final void accumulateActionEmbeddingErrors(DenseNDArray errors, int actionIndex) {
        Transition<ArcStandardTransition, StackBufferState>.Action action = this.appliedActions.get(actionIndex);
        if (action == null) {
            this.actionsEmbeddingsOptimizer.accumulateNullEmbeddingsErrors(errors);
            return;
        }
        this.actionsEmbeddingsOptimizer.accumulate(getKey(action.getTransition()), getPosTagKey(action), getDeprelKey(action), errors);
    }

    /**
     * Accumulates per-token errors into the items of the decoding context.
     *
     * @param decodingContext the decoding context whose context receives the errors
     * @param itemsErrors pairs of (token id, possibly null; errors of that token's encoding)
     */
    private final void accumulateItemsErrors(DecodingContext<StackBufferState, ArcStandardTransition, TokensAmbiguousPOSContext, DenseItem, DenseFeatures> decodingContext, List<Pair<Integer, DenseNDArray>> itemsErrors) {
        TokensAmbiguousPOSContext context = (TokensAmbiguousPOSContext) decodingContext.getExtendedState().getContext();
        for (Pair<Integer, DenseNDArray> itemErrors : itemsErrors) {
            context.accumulateItemErrors(itemErrors.getFirst(), itemErrors.getSecond());
        }
    }

    /**
     * Creates the features extractor.
     *
     * @param actionsEmbeddings the embeddings of the actions
     * @param actionsEmbeddingsOptimizer the optimizer of the actions embeddings
     * @param actionsEncodingSize the size of the actions-encoder output
     * @param actionsEncoderOptimizer the optimizer of the actions-encoder parameters
     * @param deprelTags the dictionary of deprels (used to build deprel keys)
     * @param posTags the dictionary of POS tags (used to build POS tag keys)
     */
    public ArcStandardATPDFeaturesExtractor(@NotNull ActionsEmbeddingsMap actionsEmbeddings, @NotNull ActionsEmbeddingsOptimizer actionsEmbeddingsOptimizer, int actionsEncodingSize, @NotNull ParamsOptimizer<NetworkParameters> actionsEncoderOptimizer, @NotNull DictionarySet<Deprel> deprelTags, @NotNull DictionarySet<POSTag> posTags) {
        Intrinsics.checkParameterIsNotNull(actionsEmbeddings, "actionsEmbeddings");
        Intrinsics.checkParameterIsNotNull(actionsEmbeddingsOptimizer, "actionsEmbeddingsOptimizer");
        Intrinsics.checkParameterIsNotNull(actionsEncoderOptimizer, "actionsEncoderOptimizer");
        Intrinsics.checkParameterIsNotNull(deprelTags, "deprelTags");
        Intrinsics.checkParameterIsNotNull(posTags, "posTags");
        this.actionsEmbeddings = actionsEmbeddings;
        this.actionsEmbeddingsOptimizer = actionsEmbeddingsOptimizer;
        this.actionsEncodingSize = actionsEncodingSize;
        this.actionsEncoderOptimizer = actionsEncoderOptimizer;
        this.deprelTags = deprelTags;
        this.posTags = posTags;
        this.appliedActions = new ArrayList<>();
        this.appliedActionsEncodingErrors = new ArrayList<>();
        // Must come after actionsEncodingSize is assigned; mask 2 leaves the second Shape dimension at its default.
        this.appliedActionZerosErrors = DenseNDArrayFactory.INSTANCE.zeros(new Shape(this.actionsEncodingSize, 0, 2, (DefaultConstructorMarker) null));
    }
}
