package com.kotlinnlp.neuralparser.parsers.transitionbased.models.arceagerspine;

import com.kotlinnlp.neuralparser.parsers.transitionbased.templates.featuresextractor.TWFeaturesExtractorTrainable;
import com.kotlinnlp.neuralparser.parsers.transitionbased.templates.inputcontexts.TokensEmbeddingsContext;
import com.kotlinnlp.neuralparser.parsers.transitionbased.templates.supportstructure.multiprediction.MPSupportStructure;
import com.kotlinnlp.neuralparser.parsers.transitionbased.utils.features.GroupedDenseFeatures;
import com.kotlinnlp.neuralparser.parsers.transitionbased.utils.features.GroupedDenseFeaturesErrors;
import com.kotlinnlp.neuralparser.parsers.transitionbased.utils.items.DenseItem;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import com.kotlinnlp.syntaxdecoder.modules.featuresextractor.features.Features;
import com.kotlinnlp.syntaxdecoder.modules.supportstructure.DecodingSupportStructure;
import com.kotlinnlp.syntaxdecoder.transitionsystem.Transition;
import com.kotlinnlp.syntaxdecoder.transitionsystem.models.arceagerspine.ArcEagerSpineState;
import com.kotlinnlp.syntaxdecoder.transitionsystem.models.arceagerspine.ArcEagerSpineStateView;
import com.kotlinnlp.syntaxdecoder.transitionsystem.models.arceagerspine.ArcEagerSpineTransition;
import com.kotlinnlp.syntaxdecoder.transitionsystem.state.StateView;
import com.kotlinnlp.syntaxdecoder.utils.DecodingContext;
import com.kotlinnlp.syntaxdecoder.utils.ExtensionsKt;
import com.kotlinnlp.utils.MultiMap;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin.Unit;
import kotlin.collections.CollectionsKt;
import kotlin.collections.MapsKt;
import kotlin.jvm.functions.Function3;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: ArcEagerSpineEmbeddingsFeaturesExtractor.kt */
@Metadata(mv = {1, 1, 10}, bv = {1, 0, 2}, k = 1, d1 = {"��T\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\b\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0010\u0002\n��\n\u0002\u0010 \n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\r\u0018��2&\u0012\u0004\u0012\u00020\u0002\u0012\u0004\u0012\u00020\u0003\u0012\u0004\u0012\u00020\u0004\u0012\u0004\u0012\u00020\u0005\u0012\u0004\u0012\u00020\u0006\u0012\u0004\u0012\u00020\u00070\u0001B\u0005¢\u0006\u0002\u0010\bJ2\u0010\u000e\u001a\u00020\u000f2\f\u0010\u0010\u001a\b\u0012\u0004\u0012\u00020\u00120\u00112\u001a\u0010\u0013\u001a\u0016\u0012\u0012\u0012\u0010\u0012\u0006\u0012\u0004\u0018\u00010\n\u0012\u0004\u0012\u00020\u00150\u00140\u0011H\u0002J6\u0010\u0016\u001a\u00020\u000f2$\u0010\u0017\u001a \u0012\u0004\u0012\u00020\u0002\u0012\u0004\u0012\u00020\u0003\u0012\u0004\u0012\u00020\u0004\u0012\u0004\u0012\u00020\u0012\u0012\u0004\u0012\u00020\u00050\u00182\u0006\u0010\u0019\u001a\u00020\u0007H\u0016J6\u0010\u001a\u001a\u00020\u00052$\u0010\u0017\u001a \u0012\u0004\u0012\u00020\u0002\u0012\u0004\u0012\u00020\u0003\u0012\u0004\u0012\u00020\u0004\u0012\u0004\u0012\u00020\u0012\u0012\u0004\u0012\u00020\u00050\u00182\u0006\u0010\u0019\u001a\u00020\u0007H\u0014J\u001f\u0010\u001b\u001a\u00020\u00152\b\u0010\u001c\u001a\u0004\u0018\u00010\n2\u0006\u0010\u001d\u001a\u00020\u0004H\u0014¢\u0006\u0002\u0010\u001eJ\u0018\u0010\u001f\u001a\n\u0012\u0006\u0012\u0004\u0018\u00010\n0\u00112\u0006\u0010 
\u001a\u00020\u0006H\u0014J\b\u0010!\u001a\u00020\u000fH\u0016J\b\u0010\"\u001a\u00020\u000fH\u0016J\b\u0010#\u001a\u00020\u000fH\u0016J\b\u0010$\u001a\u00020\u000fH\u0016R$\u0010\t\u001a\u00020\n*\u000e\u0012\u0004\u0012\u00020\u0003\u0012\u0004\u0012\u00020\u00020\u000b8BX\u0082\u0004¢\u0006\u0006\u001a\u0004\b\f\u0010\r¨\u0006%"}, d2 = {"Lcom/kotlinnlp/neuralparser/parsers/transitionbased/models/arceagerspine/ArcEagerSpineEmbeddingsFeaturesExtractor;", "Lcom/kotlinnlp/neuralparser/parsers/transitionbased/templates/featuresextractor/TWFeaturesExtractorTrainable;", "Lcom/kotlinnlp/syntaxdecoder/transitionsystem/models/arceagerspine/ArcEagerSpineState;", "Lcom/kotlinnlp/syntaxdecoder/transitionsystem/models/arceagerspine/ArcEagerSpineTransition;", "Lcom/kotlinnlp/neuralparser/parsers/transitionbased/templates/inputcontexts/TokensEmbeddingsContext;", "Lcom/kotlinnlp/neuralparser/parsers/transitionbased/utils/features/GroupedDenseFeatures;", "Lcom/kotlinnlp/syntaxdecoder/transitionsystem/models/arceagerspine/ArcEagerSpineStateView;", "Lcom/kotlinnlp/neuralparser/parsers/transitionbased/templates/supportstructure/multiprediction/MPSupportStructure;", "()V", "groupId", "", "Lcom/kotlinnlp/syntaxdecoder/transitionsystem/Transition;", "getGroupId", "(Lcom/kotlinnlp/syntaxdecoder/transitionsystem/Transition;)I", "accumulateItemsErrors", "", "items", "", "Lcom/kotlinnlp/neuralparser/parsers/transitionbased/utils/items/DenseItem;", "itemsErrors", "Lkotlin/Pair;", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "backward", "decodingContext", "Lcom/kotlinnlp/syntaxdecoder/utils/DecodingContext;", "supportStructure", "extract", "getTokenEncoding", "tokenId", "context", "(Ljava/lang/Integer;Lcom/kotlinnlp/neuralparser/parsers/transitionbased/templates/inputcontexts/TokensEmbeddingsContext;)Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "getTokensWindow", "stateView", "newBatch", "newEpoch", "newExample", "update", "neuralparser"})
/* loaded from: input_file:com/kotlinnlp/neuralparser/parsers/transitionbased/models/arceagerspine/ArcEagerSpineEmbeddingsFeaturesExtractor.class */
/*
 * Features extractor for the ArcEagerSpine transition system whose input context is a
 * TokensEmbeddingsContext.
 *
 * For every allowed transition it builds a dense features vector from the window of tokens selected
 * by the transition type (see getTokensWindow), grouping the vectors by transition group id. During
 * backward the errors of each transition's features are split per token and accumulated into the
 * matching DenseItem of the context.
 *
 * NOTE(review): this source was decompiled from Kotlin (JADX). Compiler artifacts that are not valid
 * Java source (explicit bridge methods, the `$WhenMappings` ordinal switch table, a Lambda-style
 * anonymous class calling super(3), an undeclared type variable) have been replaced by their plain
 * Java equivalents; javac regenerates the needed bridge methods automatically.
 */
public final class ArcEagerSpineEmbeddingsFeaturesExtractor extends TWFeaturesExtractorTrainable<ArcEagerSpineState, ArcEagerSpineTransition, TokensEmbeddingsContext, GroupedDenseFeatures, ArcEagerSpineStateView, MPSupportStructure> {

    /**
     * Returns the group id of the given transition, as computed by {@code Utils.getGroupId}.
     *
     * @param transition the transition to classify
     * @return the id of the group the transition belongs to
     */
    private final int getGroupId(@NotNull Transition<ArcEagerSpineTransition, ArcEagerSpineState> transition) {
        return Utils.INSTANCE.getGroupId(transition);
    }

    /**
     * Extracts the features of every allowed transition of the decoding context.
     *
     * Only actions whose {@code isAllowed()} is true are considered. Their transitions are grouped by
     * {@link #getGroupId} (groups and transitions keep their original order) and, for each transition,
     * the features are computed by {@code extractViewFeatures} on a state view built from the current
     * state and that transition.
     *
     * @param decodingContext the decoding context holding actions, state and input context
     * @param mPSupportStructure the support structure (not used by this extractor)
     * @return the features, keyed by group id and then by transition id
     */
    @NotNull
    protected GroupedDenseFeatures extract(@NotNull DecodingContext<ArcEagerSpineState, ArcEagerSpineTransition, TokensEmbeddingsContext, DenseItem, GroupedDenseFeatures> decodingContext, @NotNull MPSupportStructure mPSupportStructure) {
        Intrinsics.checkParameterIsNotNull(decodingContext, "decodingContext");
        Intrinsics.checkParameterIsNotNull(mPSupportStructure, "supportStructure");

        // Keep only the actions that are allowed in the current state.
        // Raw list: ExtensionsKt.toTransitions is applied through an unchecked conversion, as before.
        List allowedActions = new ArrayList();
        for (Object action : decodingContext.getActions()) {
            if (((Transition.Action) action).isAllowed()) {
                allowedActions.add(action);
            }
        }

        // Group the allowed transitions by group id, preserving encounter order.
        Map<Integer, List<ArcEagerSpineTransition>> transitionsByGroup = new LinkedHashMap<>();
        for (Object element : ExtensionsKt.toTransitions(allowedActions)) {
            ArcEagerSpineTransition transition = (ArcEagerSpineTransition) element;
            transitionsByGroup
                .computeIfAbsent(getGroupId((Transition) transition), groupId -> new ArrayList<>())
                .add(transition);
        }

        // The state and the input context do not change while extracting: read them once.
        ArcEagerSpineState state = (ArcEagerSpineState) decodingContext.getExtendedState().getState();
        TokensEmbeddingsContext context = (TokensEmbeddingsContext) decodingContext.getExtendedState().getContext();

        Map<Integer, Map<Integer, DenseNDArray>> featuresByGroup = new LinkedHashMap<>();
        for (Map.Entry<Integer, List<ArcEagerSpineTransition>> group : transitionsByGroup.entrySet()) {
            Map<Integer, DenseNDArray> groupFeatures = new LinkedHashMap<>();
            featuresByGroup.put(group.getKey(), groupFeatures);
            for (ArcEagerSpineTransition transition : group.getValue()) {
                groupFeatures.put(transition.getId(), extractViewFeatures(new ArcEagerSpineStateView(state, transition), context));
            }
        }

        return new GroupedDenseFeatures(new MultiMap(featuresByGroup));
    }

    /**
     * Propagates the errors of the extracted features back to the token items of the context.
     *
     * For each (group, transition id, errors) entry of the features errors, the transition is looked up
     * among the decoding context actions, its tokens window is computed, and the errors vector is split
     * with {@code splitV(encodingSize)}. Each chunk is then accumulated into the DenseItem of the
     * corresponding window token; tokens with a null id are skipped. Extra chunks beyond the window
     * length are ignored, since Kotlin's {@code zip} stops at the shorter list.
     *
     * @param decodingContext the decoding context whose features hold the errors to propagate
     * @param mPSupportStructure the support structure (not used by this extractor)
     */
    public void backward(@NotNull DecodingContext<ArcEagerSpineState, ArcEagerSpineTransition, TokensEmbeddingsContext, DenseItem, GroupedDenseFeatures> decodingContext, @NotNull MPSupportStructure mPSupportStructure) {
        Intrinsics.checkParameterIsNotNull(decodingContext, "decodingContext");
        Intrinsics.checkParameterIsNotNull(mPSupportStructure, "supportStructure");

        final Map transitionsMap = ExtensionsKt.toTransitionsMap(decodingContext.getActions());

        // Loop-invariant lookups, hoisted out of the per-transition callback.
        final ArcEagerSpineState state = (ArcEagerSpineState) decodingContext.getExtendedState().getState();
        final TokensEmbeddingsContext context = (TokensEmbeddingsContext) decodingContext.getExtendedState().getContext();
        final List<DenseItem> items = context.getItems();
        final int encodingSize = context.getEncodingSize();

        GroupedDenseFeaturesErrors featuresErrors = (GroupedDenseFeaturesErrors) ((GroupedDenseFeatures) decodingContext.getFeatures()).getErrors();

        featuresErrors.getErrorsMap().forEach(new Function3<Object, Integer, DenseNDArray, Unit>() {
            @Override
            public Unit invoke(@NotNull Object ignoredGroupId, @NotNull Integer transitionId, @NotNull DenseNDArray errors) {
                Intrinsics.checkParameterIsNotNull(ignoredGroupId, "<anonymous parameter 0>");
                Intrinsics.checkParameterIsNotNull(errors, "errors");

                ArcEagerSpineTransition transition = (ArcEagerSpineTransition) MapsKt.getValue(transitionsMap, transitionId);
                List<Integer> tokensWindow = getTokensWindow(new ArcEagerSpineStateView(state, transition));

                accumulateItemsErrors(items, CollectionsKt.zip(tokensWindow, errors.splitV(new int[]{encodingSize})));

                return Unit.INSTANCE;
            }
        });
    }

    /**
     * Returns the ids of the tokens whose encodings make up the features of the given state view.
     *
     * SHIFT, ARC_LEFT and ARC_RIGHT use {@code [stack[0], buffer[0]]}; ROOT uses {@code [stack[0]]}.
     * NOTE(review): presumably index 0 is the top of the stack / front of the buffer — confirm against
     * ArcEagerSpineStateView.
     *
     * @param stateView the state view built for a specific transition
     * @return the window of token ids (elements may be null)
     * @throws RuntimeException if the transition type is none of the supported ones
     */
    @Override // com.kotlinnlp.neuralparser.parsers.transitionbased.templates.featuresextractor.TWFeaturesExtractorTrainable
    @NotNull
    public List<Integer> getTokensWindow(@NotNull ArcEagerSpineStateView stateView) {
        Intrinsics.checkParameterIsNotNull(stateView, "stateView");
        Transition.Type type = stateView.getTransition().getType();
        switch (type) {
            case ROOT:
                return CollectionsKt.listOf(stateView.getStack()[0]);
            case SHIFT:
            case ARC_LEFT:
            case ARC_RIGHT:
                return CollectionsKt.listOf(new Integer[]{stateView.getStack()[0], stateView.getBuffer()[0]});
            default:
                throw new RuntimeException("Unexpected transition type: " + type);
        }
    }

    /**
     * Returns the encoding of the given token, delegating to
     * {@code TokensEmbeddingsContext.getTokenEncoding}.
     *
     * @param num the token id, possibly null (the null case is handled by the context)
     * @param tokensEmbeddingsContext the input context holding the token encodings
     * @return the token encoding
     */
    @Override // com.kotlinnlp.neuralparser.parsers.transitionbased.templates.featuresextractor.TWFeaturesExtractorTrainable
    @NotNull
    public DenseNDArray getTokenEncoding(@Nullable Integer num, @NotNull TokensEmbeddingsContext tokensEmbeddingsContext) {
        Intrinsics.checkParameterIsNotNull(tokensEmbeddingsContext, "context");
        return tokensEmbeddingsContext.getTokenEncoding(num);
    }

    /** No-op: this class declares no fields, so there is nothing to reset per example. */
    public void newExample() {
    }

    /** No-op: this class declares no fields, so there is nothing to reset per batch. */
    public void newBatch() {
    }

    /** No-op: this class declares no fields, so there is nothing to reset per epoch. */
    public void newEpoch() {
    }

    /** No-op: this class declares no trainable parameters of its own to update. */
    public void update() {
    }

    /**
     * Accumulates each token's errors into the item of that token.
     *
     * Visibility kept public for compatibility; the Kotlin original declared it private.
     *
     * @param items the context items, indexed by token id
     * @param itemsErrors pairs of (token id, errors); pairs with a null token id are skipped
     */
    public final void accumulateItemsErrors(List<DenseItem> items, List<Pair<Integer, DenseNDArray>> itemsErrors) {
        for (Pair<Integer, DenseNDArray> tokenErrors : itemsErrors) {
            Integer tokenId = tokenErrors.component1();
            if (tokenId != null) {
                items.get(tokenId).accumulateErrors(tokenErrors.component2());
            }
        }
    }
}
