package com.kotlinnlp.neuralparser.parsers.transitionbased.templates.parsers.birnn.simple;

import com.kotlinnlp.neuralparser.language.Sentence;
import com.kotlinnlp.neuralparser.language.Token;
import com.kotlinnlp.neuralparser.parsers.transitionbased.TransitionBasedParser;
import com.kotlinnlp.neuralparser.parsers.transitionbased.templates.inputcontexts.TokensEmbeddingsContext;
import com.kotlinnlp.neuralparser.parsers.transitionbased.templates.parsers.birnn.simple.BiRNNParserModel;
import com.kotlinnlp.neuralparser.utils.items.DenseItem;
import com.kotlinnlp.simplednn.deeplearning.birnn.deepbirnn.DeepBiRNNEncoder;
import com.kotlinnlp.simplednn.deeplearning.embeddings.Embedding;
import com.kotlinnlp.simplednn.deeplearning.embeddings.EmbeddingsMapByDictionary;
import com.kotlinnlp.simplednn.simplemath.SimplemathKt;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import com.kotlinnlp.syntaxdecoder.modules.featuresextractor.features.Features;
import com.kotlinnlp.syntaxdecoder.modules.featuresextractor.features.FeaturesErrors;
import com.kotlinnlp.syntaxdecoder.modules.supportstructure.DecodingSupportStructure;
import com.kotlinnlp.syntaxdecoder.transitionsystem.Transition;
import com.kotlinnlp.syntaxdecoder.transitionsystem.state.State;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import kotlin.Metadata;
import kotlin.TypeCastException;
import kotlin.collections.ArraysKt;
import kotlin.collections.CollectionsKt;
import kotlin.collections.IntIterator;
import kotlin.jvm.internal.Intrinsics;
import kotlin.ranges.RangesKt;
import org.jetbrains.annotations.NotNull;

/* compiled from: BiRNNParser.kt */
@Metadata(mv = {1, 1, 8}, bv = {1, 0, 2}, k = 1, d1 = {"��j\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u0006\n\u0002\b\u0002\n\u0002\u0010\b\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0006\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000b\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\b&\u0018��*\u000e\b��\u0010\u0001*\b\u0012\u0004\u0012\u0002H\u00010\u0002*\u0014\b\u0001\u0010\u0003*\u000e\u0012\u0004\u0012\u0002H\u0003\u0012\u0004\u0012\u0002H\u00010\u0004*\b\b\u0002\u0010\u0005*\u00020\u0006*\u0012\b\u0003\u0010\u0007*\f\u0012\u0004\u0012\u0002H\u0005\u0012\u0002\b\u00030\b*\b\b\u0004\u0010\t*\u00020\n*\n\b\u0005\u0010\u000b \u0001*\u00020\f22\u0012\u0004\u0012\u0002H\u0001\u0012\u0004\u0012\u0002H\u0003\u0012\u0004\u0012\u00020\u000e\u0012\u0004\u0012\u00020\u000f\u0012\u0004\u0012\u0002H\u0005\u0012\u0004\u0012\u0002H\u0007\u0012\u0004\u0012\u0002H\t\u0012\u0004\u0012\u0002H\u000b0\rB-\u0012\u0006\u0010\u0010\u001a\u00028\u0005\u0012\u0006\u0010\u0011\u001a\u00020\u0012\u0012\u0006\u0010\u0013\u001a\u00020\u0012\u0012\u0006\u0010\u0014\u001a\u00020\u0015\u0012\u0006\u0010\u0016\u001a\u00020\u0015¢\u0006\u0002\u0010\u0017J\u0018\u0010\u001f\u001a\u00020\u000e2\u0006\u0010 \u001a\u00020!2\u0006\u0010\"\u001a\u00020#H\u0016J\u0018\u0010$\u001a\u00020\u001a2\u0006\u0010%\u001a\u00020&2\u0006\u0010'\u001a\u00020&H\u0002R\u0017\u0010\u0018\u001a\b\u0012\u0004\u0012\u00020\u001a0\u0019¢\u0006\b\n��\u001a\u0004\b\u001b\u0010\u001cR\u0017\u0010\u001d\u001a\b\u0012\u0004\u0012\u00020\u001a0\u0019¢\u0006\b\n��\u001a\u0004\b\u001e\u0010\u001cR\u000e\u0010\u0013\u001a\u00020\u0012X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0011\u001a\u00020\u0012X\u0082\u0004¢\u0006\u0002\n��¨\u0006("}, d2 = 
{"Lcom/kotlinnlp/neuralparser/parsers/transitionbased/templates/parsers/birnn/simple/BiRNNParser;", "StateType", "Lcom/kotlinnlp/syntaxdecoder/transitionsystem/state/State;", "TransitionType", "Lcom/kotlinnlp/syntaxdecoder/transitionsystem/Transition;", "FeaturesErrorsType", "Lcom/kotlinnlp/syntaxdecoder/modules/featuresextractor/features/FeaturesErrors;", "FeaturesType", "Lcom/kotlinnlp/syntaxdecoder/modules/featuresextractor/features/Features;", "SupportStructureType", "Lcom/kotlinnlp/syntaxdecoder/modules/supportstructure/DecodingSupportStructure;", "ModelType", "Lcom/kotlinnlp/neuralparser/parsers/transitionbased/templates/parsers/birnn/simple/BiRNNParserModel;", "Lcom/kotlinnlp/neuralparser/parsers/transitionbased/TransitionBasedParser;", "Lcom/kotlinnlp/neuralparser/parsers/transitionbased/templates/inputcontexts/TokensEmbeddingsContext;", "Lcom/kotlinnlp/neuralparser/utils/items/DenseItem;", "model", "wordDropoutCoefficient", "", "posDropoutCoefficient", "beamSize", "", "maxParallelThreads", "(Lcom/kotlinnlp/neuralparser/parsers/transitionbased/templates/parsers/birnn/simple/BiRNNParserModel;DDII)V", "biRNNEncoder", "Lcom/kotlinnlp/simplednn/deeplearning/birnn/deepbirnn/DeepBiRNNEncoder;", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "getBiRNNEncoder", "()Lcom/kotlinnlp/simplednn/deeplearning/birnn/deepbirnn/DeepBiRNNEncoder;", "paddingVectorEncoder", "getPaddingVectorEncoder", "buildContext", "sentence", "Lcom/kotlinnlp/neuralparser/language/Sentence;", "trainingMode", "", "buildTokenEmbedding", "posEmbedding", "Lcom/kotlinnlp/simplednn/deeplearning/embeddings/Embedding;", "wordEmbedding", "neuralparser"})
/* loaded from: input_file:com/kotlinnlp/neuralparser/parsers/transitionbased/templates/parsers/birnn/simple/BiRNNParser.class */
public abstract class BiRNNParser<StateType extends State<StateType>, TransitionType extends Transition<TransitionType, StateType>, FeaturesErrorsType extends FeaturesErrors, FeaturesType extends Features<FeaturesErrorsType, ?>, SupportStructureType extends DecodingSupportStructure, ModelType extends BiRNNParserModel> extends TransitionBasedParser<StateType, TransitionType, TokensEmbeddingsContext, DenseItem, FeaturesErrorsType, FeaturesType, SupportStructureType, ModelType> {

    /** Encoder built on the model's deep BiRNN, used to encode the token embeddings of a sentence. */
    @NotNull
    private final DeepBiRNNEncoder<DenseNDArray> biRNNEncoder;

    /** A second encoder built on the same deep BiRNN, used only to encode the padding vector. */
    @NotNull
    private final DeepBiRNNEncoder<DenseNDArray> paddingVectorEncoder;

    /** Dropout coefficient passed to word-embedding lookups in training mode (0.0 otherwise). */
    private final double wordDropoutCoefficient;

    /** Dropout coefficient passed to POS-embedding lookups in training mode (0.0 otherwise). */
    private final double posDropoutCoefficient;

    /**
     * Creates a BiRNN-based transition parser.
     *
     * @param model the parser model; its deep BiRNN backs both encoders
     * @param wordDropoutCoefficient dropout used for word-embedding lookups while training
     * @param posDropoutCoefficient dropout used for POS-embedding lookups while training
     * @param beamSize forwarded to {@link TransitionBasedParser}
     * @param maxParallelThreads forwarded to {@link TransitionBasedParser}
     */
    public BiRNNParser(@NotNull ModelType model, double wordDropoutCoefficient, double posDropoutCoefficient, int beamSize, int maxParallelThreads) {
        // Java requires super(...) first; the Kotlin source checked `model` for null before this call.
        super(model, beamSize, maxParallelThreads);
        Intrinsics.checkParameterIsNotNull(model, "model");
        this.wordDropoutCoefficient = wordDropoutCoefficient;
        this.posDropoutCoefficient = posDropoutCoefficient;
        this.biRNNEncoder = new DeepBiRNNEncoder<>(model.getDeepBiRNN());
        this.paddingVectorEncoder = new DeepBiRNNEncoder<>(model.getDeepBiRNN());
    }

    /** @return the encoder used for sentence token embeddings */
    @NotNull
    public final DeepBiRNNEncoder<DenseNDArray> getBiRNNEncoder() {
        return this.biRNNEncoder;
    }

    /** @return the encoder used only for the padding vector */
    @NotNull
    public final DeepBiRNNEncoder<DenseNDArray> getPaddingVectorEncoder() {
        return this.paddingVectorEncoder;
    }

    /**
     * Builds the input context of a sentence.
     *
     * <p>For each token, this looks up the POS embedding and the word embedding (normalized form) and
     * concatenates them. It encodes the resulting sequence with {@link #getBiRNNEncoder()}. It also
     * encodes a padding vector from the two null embeddings with {@link #getPaddingVectorEncoder()}.
     *
     * @param sentence the sentence to parse; every token must have a non-null POS
     * @param trainingMode whether to apply the configured dropout coefficients to the lookups
     * @return the tokens-embeddings context for {@code sentence}
     */
    @Override
    @NotNull
    public TokensEmbeddingsContext buildContext(@NotNull Sentence sentence, boolean trainingMode) {
        Intrinsics.checkParameterIsNotNull(sentence, "sentence");

        BiRNNParserModel model = (BiRNNParserModel) getModel();
        List<Token> tokens = sentence.getTokens();
        int tokensCount = tokens.size();
        double posDropout = trainingMode ? this.posDropoutCoefficient : 0.0d;
        double wordDropout = trainingMode ? this.wordDropoutCoefficient : 0.0d;

        // NOTE(review): the lookups take a dropout coefficient and presumably sample randomly when it is
        // > 0. All POS lookups are therefore done before all word lookups, matching the original call order.
        EmbeddingsMapByDictionary posEmbeddingsMap = model.getPosEmbeddings();
        List<Embedding> posEmbeddings = new ArrayList<>(tokensCount);
        for (Token token : tokens) {
            String pos = token.getPos();
            if (pos == null) {
                Intrinsics.throwNpe();
            }
            posEmbeddings.add(posEmbeddingsMap.get(pos, posDropout));
        }

        List<Embedding> wordEmbeddings = new ArrayList<>(tokensCount);
        for (Token token : tokens) {
            wordEmbeddings.add(model.getWordEmbeddings().get(token.getNormalizedWord(), wordDropout));
        }

        List<DenseNDArray> tokenEmbeddings = new ArrayList<>(tokensCount);
        for (int i = 0; i < tokensCount; i++) {
            tokenEmbeddings.add(buildTokenEmbedding(posEmbeddings.get(i), wordEmbeddings.get(i)));
        }

        // The padding vector is the encoding of a one-element sequence made of the two null embeddings.
        DenseNDArray paddingVector = (DenseNDArray) ArraysKt.first(this.paddingVectorEncoder.encode(new DenseNDArray[]{
            buildTokenEmbedding(posEmbeddingsMap.getNullEmbedding(), model.getWordEmbeddings().getNullEmbedding())}));

        List<DenseItem> items = new ArrayList<>(tokensCount);
        for (Token token : tokens) {
            items.add(new DenseItem(token.getId()));
        }

        DenseNDArray[] tokenEmbeddingsArray = tokenEmbeddings.toArray(new DenseNDArray[0]);

        // NOTE(review): decompiler artifact. This appears to be Kotlin's synthetic default-arguments
        // constructor. The mask 384 (bits 7 and 8) presumably means the two `false` values are
        // placeholders replaced by the declared defaults, and `null` is the DefaultConstructorMarker.
        // Replace with an explicit call once the real defaults are confirmed.
        return new TokensEmbeddingsContext(items, tokens, posEmbeddings, wordEmbeddings, paddingVector,
            model.getDeepBiRNN().getOutputSize(), this.biRNNEncoder.encode(tokenEmbeddingsArray), false, false, 384, null);
    }

    /**
     * Builds the input vector of a token.
     *
     * @param posEmbedding the token's POS embedding
     * @param wordEmbedding the token's word embedding
     * @return the POS embedding values followed by the word embedding values, concatenated
     */
    private DenseNDArray buildTokenEmbedding(Embedding posEmbedding, Embedding wordEmbedding) {
        return SimplemathKt.concatVectorsV(new DenseNDArray[]{
            posEmbedding.getArray().getValues(), wordEmbedding.getArray().getValues()});
    }
}
