package com.kotlinnlp.neuralparser.templates.parsers.birnn.ambiguouspos;

import com.kotlinnlp.neuralparser.helpers.Trainer;
import com.kotlinnlp.neuralparser.helpers.Validator;
import com.kotlinnlp.neuralparser.templates.inputcontexts.TokensAmbiguousPOSContext;
import com.kotlinnlp.neuralparser.templates.parsers.birnn.ambiguouspos.BiRNNAmbiguousPOSParserModel;
import com.kotlinnlp.neuralparser.utils.items.DenseItem;
import com.kotlinnlp.neuralparser.utils.items.DenseItemErrors;
import com.kotlinnlp.simplednn.core.functionalities.regularization.WeightsRegularization;
import com.kotlinnlp.simplednn.core.functionalities.updatemethods.adam.ADAMMethod;
import com.kotlinnlp.simplednn.deeplearning.birnn.deepbirnn.DeepBiRNNEncoder;
import com.kotlinnlp.simplednn.deeplearning.birnn.deepbirnn.DeepBiRNNOptimizer;
import com.kotlinnlp.simplednn.deeplearning.embeddings.Embedding;
import com.kotlinnlp.simplednn.deeplearning.embeddings.EmbeddingsOptimizer;
import com.kotlinnlp.simplednn.simplemath.ndarray.Shape;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArrayFactory;
import com.kotlinnlp.syntaxdecoder.modules.actionserrorssetter.ActionsErrorsSetter;
import com.kotlinnlp.syntaxdecoder.modules.bestactionselector.HighestScoreCorrectActionSelector;
import com.kotlinnlp.syntaxdecoder.modules.featuresextractor.features.Features;
import com.kotlinnlp.syntaxdecoder.modules.featuresextractor.features.FeaturesErrors;
import com.kotlinnlp.syntaxdecoder.modules.supportstructure.DecodingSupportStructure;
import com.kotlinnlp.syntaxdecoder.transitionsystem.Transition;
import com.kotlinnlp.syntaxdecoder.transitionsystem.oracle.OracleFactory;
import com.kotlinnlp.syntaxdecoder.transitionsystem.state.State;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin.TypeCastException;
import kotlin.collections.ArraysKt;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: BiRNNAmbiguousPOSParserTrainer.kt */
@Metadata(mv = {1, 1, 8}, bv = {1, 0, 2}, k = 1, d1 = {"��\u0090\u0001\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\b\n\u0002\b\u0003\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000e\n��\n\u0002\u0010\u000b\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0002\n\u0002\b\u0007\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\b\u0002\b\u0016\u0018��*\u000e\b��\u0010\u0001*\b\u0012\u0004\u0012\u0002H\u00010\u0002*\u0014\b\u0001\u0010\u0003*\u000e\u0012\u0004\u0012\u0002H\u0003\u0012\u0004\u0012\u0002H\u00010\u0004*\b\b\u0002\u0010\u0005*\u00020\u0006*\u0012\b\u0003\u0010\u0007*\f\u0012\u0004\u0012\u0002H\u0005\u0012\u0002\b\u00030\b*\n\b\u0004\u0010\t 
\u0001*\u00020\n*\b\b\u0005\u0010\u000b*\u00020\f22\u0012\u0004\u0012\u0002H\u0001\u0012\u0004\u0012\u0002H\u0003\u0012\u0004\u0012\u00020\u000e\u0012\u0004\u0012\u00020\u000f\u0012\u0004\u0012\u0002H\u0005\u0012\u0004\u0012\u0002H\u0007\u0012\u0004\u0012\u0002H\t\u0012\u0004\u0012\u0002H\u000b0\rB\u009b\u0001\u0012*\u0010\u0010\u001a&\u0012\u0004\u0012\u00028��\u0012\u0004\u0012\u00028\u0001\u0012\u0004\u0012\u00028\u0002\u0012\u0004\u0012\u00028\u0003\u0012\u0004\u0012\u00028\u0004\u0012\u0004\u0012\u00028\u00050\u0011\u0012\u001e\u0010\u0012\u001a\u001a\u0012\u0004\u0012\u00028��\u0012\u0004\u0012\u00028\u0001\u0012\u0004\u0012\u00020\u000f\u0012\u0004\u0012\u00020\u000e0\u0013\u0012\u0012\u0010\u0014\u001a\u000e\u0012\u0004\u0012\u00028��\u0012\u0004\u0012\u00028\u00010\u0015\u0012\u0006\u0010\u0016\u001a\u00020\u0017\u0012\u0006\u0010\u0018\u001a\u00020\u0017\u0012\b\b\u0002\u0010\u0019\u001a\u00020\u0017\u0012\b\u0010\u001a\u001a\u0004\u0018\u00010\u001b\u0012\u0006\u0010\u001c\u001a\u00020\u001d\u0012\b\b\u0002\u0010\u001e\u001a\u00020\u001f¢\u0006\u0002\u0010 J\u0010\u0010(\u001a\u00020)2\u0006\u0010*\u001a\u00020'H\u0002J 
\u0010+\u001a\u00020)2\u0006\u0010,\u001a\u00020\u000e2\u0006\u0010-\u001a\u00020\u00172\u0006\u0010*\u001a\u00020'H\u0002J\u0010\u0010.\u001a\u00020)2\u0006\u0010,\u001a\u00020\u000eH\u0014J(\u0010/\u001a\u00020)2\u0016\u00100\u001a\u001201R\u000e\u0012\u0004\u0012\u00028\u0001\u0012\u0004\u0012\u00028��0\u00042\u0006\u0010,\u001a\u00020\u000eH\u0014J\u0010\u00102\u001a\u00020)2\u0006\u0010,\u001a\u00020\u000eH\u0014J\u0010\u00103\u001a\u00020)2\u0006\u0010,\u001a\u00020\u000eH\u0002J\u001e\u00104\u001a\u0010\u0012\u0004\u0012\u00020'\u0012\u0006\u0012\u0004\u0018\u00010'052\u0006\u0010*\u001a\u00020'H\u0002J\b\u00106\u001a\u00020)H\u0014R\u000e\u0010!\u001a\u00020\"X\u0082\u0004¢\u0006\u0002\n��R2\u0010\u0010\u001a&\u0012\u0004\u0012\u00028��\u0012\u0004\u0012\u00028\u0001\u0012\u0004\u0012\u00028\u0002\u0012\u0004\u0012\u00028\u0003\u0012\u0004\u0012\u00028\u0004\u0012\u0004\u0012\u00028\u00050\u0011X\u0088\u0004¢\u0006\u0002\n��R\u0016\u0010#\u001a\n\u0012\u0004\u0012\u00020\u001d\u0018\u00010$X\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010%\u001a\b\u0012\u0004\u0012\u00020\u00170$X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010&\u001a\u00020'X\u0082\u0004¢\u0006\u0002\n��¨\u00067"}, d2 = {"Lcom/kotlinnlp/neuralparser/templates/parsers/birnn/ambiguouspos/BiRNNAmbiguousPOSParserTrainer;", "StateType", "Lcom/kotlinnlp/syntaxdecoder/transitionsystem/state/State;", "TransitionType", "Lcom/kotlinnlp/syntaxdecoder/transitionsystem/Transition;", "FeaturesErrorsType", "Lcom/kotlinnlp/syntaxdecoder/modules/featuresextractor/features/FeaturesErrors;", "FeaturesType", "Lcom/kotlinnlp/syntaxdecoder/modules/featuresextractor/features/Features;", "SupportStructureType", "Lcom/kotlinnlp/syntaxdecoder/modules/supportstructure/DecodingSupportStructure;", "ModelType", "Lcom/kotlinnlp/neuralparser/templates/parsers/birnn/ambiguouspos/BiRNNAmbiguousPOSParserModel;", "Lcom/kotlinnlp/neuralparser/helpers/Trainer;", 
"Lcom/kotlinnlp/neuralparser/templates/inputcontexts/TokensAmbiguousPOSContext;", "Lcom/kotlinnlp/neuralparser/utils/items/DenseItem;", "neuralParser", "Lcom/kotlinnlp/neuralparser/templates/parsers/birnn/ambiguouspos/BiRNNAmbiguousPOSParser;", "actionsErrorsSetter", "Lcom/kotlinnlp/syntaxdecoder/modules/actionserrorssetter/ActionsErrorsSetter;", "oracleFactory", "Lcom/kotlinnlp/syntaxdecoder/transitionsystem/oracle/OracleFactory;", "epochs", "", "batchSize", "minRelevantErrorsCountToUpdate", "validator", "Lcom/kotlinnlp/neuralparser/helpers/Validator;", "modelFilename", "", "verbose", "", "(Lcom/kotlinnlp/neuralparser/templates/parsers/birnn/ambiguouspos/BiRNNAmbiguousPOSParser;Lcom/kotlinnlp/syntaxdecoder/modules/actionserrorssetter/ActionsErrorsSetter;Lcom/kotlinnlp/syntaxdecoder/transitionsystem/oracle/OracleFactory;IIILcom/kotlinnlp/neuralparser/helpers/Validator;Ljava/lang/String;Z)V", "biRNNOptimizer", "Lcom/kotlinnlp/simplednn/deeplearning/birnn/deepbirnn/DeepBiRNNOptimizer;", "preTrainedEmbeddingsOptimizer", "Lcom/kotlinnlp/simplednn/deeplearning/embeddings/EmbeddingsOptimizer;", "wordEmbeddingsOptimizer", "zerosErrors", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "accumulateNullTokenErrors", "", "errors", "accumulateTokenErrors", "context", "tokenIndex", "afterSentenceLearning", "beforeApplyAction", "action", "Lcom/kotlinnlp/syntaxdecoder/transitionsystem/Transition$Action;", "beforeSentenceLearning", "propagateErrors", "splitTokenErrors", "Lkotlin/Pair;", "update", "neuralparser"})
/* loaded from: input_file:com/kotlinnlp/neuralparser/templates/parsers/birnn/ambiguouspos/BiRNNAmbiguousPOSParserTrainer.class */
/**
 * Trainer of a {@link BiRNNAmbiguousPOSParser}.
 *
 * <p>At the end of each sentence the errors collected on the context items are back-propagated
 * through the BiRNN encoder (and through the padding-vector encoder when the null item received
 * errors). The resulting gradients are accumulated into the BiRNN optimizer, the word-embeddings
 * optimizer and, only when the model has pre-trained word embeddings, the pre-trained-embeddings
 * optimizer. The best action is selected with a {@link HighestScoreCorrectActionSelector}.
 *
 * @param <StateType> the transition system state type
 * @param <TransitionType> the transition type
 * @param <FeaturesErrorsType> the features errors type
 * @param <FeaturesType> the features type
 * @param <SupportStructureType> the decoding support structure type
 * @param <ModelType> the parser model type
 */
public class BiRNNAmbiguousPOSParserTrainer<
        StateType extends State<StateType>,
        TransitionType extends Transition<TransitionType, StateType>,
        FeaturesErrorsType extends FeaturesErrors,
        FeaturesType extends Features<FeaturesErrorsType, ?>,
        SupportStructureType extends DecodingSupportStructure,
        ModelType extends BiRNNAmbiguousPOSParserModel>
    extends Trainer<StateType, TransitionType, TokensAmbiguousPOSContext, DenseItem, FeaturesErrorsType, FeaturesType, SupportStructureType, ModelType> {

    /** Errors used for the items that did not receive any error (sized as the BiRNN output). */
    private final DenseNDArray zerosErrors;

    /** Optimizer of the model word embeddings. */
    private final EmbeddingsOptimizer<Integer> wordEmbeddingsOptimizer;

    /** Optimizer of the model pre-trained word embeddings; null when the model has none. */
    @Nullable
    private final EmbeddingsOptimizer<String> preTrainedEmbeddingsOptimizer;

    /** Optimizer of the BiRNN parameters (shared by the token encoder and the padding-vector encoder). */
    private final DeepBiRNNOptimizer biRNNOptimizer;

    /**
     * The parser to train. Typed with {@code ? extends ModelType} to match the constructor parameter:
     * Java generics are invariant, so an invariant field could not hold it.
     */
    private final BiRNNAmbiguousPOSParser<StateType, TransitionType, FeaturesErrorsType, FeaturesType, SupportStructureType, ? extends ModelType> neuralParser;

    /**
     * Prepares all the optimizers for a new sentence.
     *
     * @param context the input context of the sentence about to be learned
     */
    @Override // com.kotlinnlp.neuralparser.helpers.Trainer
    public void beforeSentenceLearning(@NotNull TokensAmbiguousPOSContext context) {
        Intrinsics.checkParameterIsNotNull(context, "context");
        this.biRNNOptimizer.newBatch();
        this.biRNNOptimizer.newExample();
        this.wordEmbeddingsOptimizer.newBatch();
        this.wordEmbeddingsOptimizer.newExample();
        if (this.preTrainedEmbeddingsOptimizer != null) {
            this.preTrainedEmbeddingsOptimizer.newBatch();
            this.preTrainedEmbeddingsOptimizer.newExample();
        }
    }

    /**
     * Back-propagates the errors accumulated while decoding the sentence.
     *
     * @param context the input context of the sentence just learned
     */
    @Override // com.kotlinnlp.neuralparser.helpers.Trainer
    public void afterSentenceLearning(@NotNull TokensAmbiguousPOSContext context) {
        Intrinsics.checkParameterIsNotNull(context, "context");
        propagateErrors(context);
    }

    /**
     * No-op: this trainer does nothing before an action is applied.
     *
     * @param action the action about to be applied
     * @param context the input context
     */
    @Override // com.kotlinnlp.neuralparser.helpers.Trainer
    public void beforeApplyAction(@NotNull Transition<TransitionType, StateType>.Action action, @NotNull TokensAmbiguousPOSContext context) {
        Intrinsics.checkParameterIsNotNull(action, "action");
        Intrinsics.checkParameterIsNotNull(context, "context");
    }

    /** Applies the accumulated gradients through every optimizer. */
    @Override // com.kotlinnlp.neuralparser.helpers.Trainer
    protected void update() {
        this.biRNNOptimizer.update();
        this.wordEmbeddingsOptimizer.update();
        if (this.preTrainedEmbeddingsOptimizer != null) {
            this.preTrainedEmbeddingsOptimizer.update();
        }
    }

    /**
     * Back-propagates the item errors through the BiRNN encoder and accumulates the resulting
     * parameter and input errors. If the null item received errors, they are back-propagated
     * through the padding-vector encoder as well.
     *
     * @param context the input context holding the items and their errors
     */
    private void propagateErrors(TokensAmbiguousPOSContext context) {
        List<DenseItem> items = context.getItems();

        // One output-errors row per item, aligned with the token sequence.
        DenseNDArray[] outputErrors = new DenseNDArray[items.size()];
        int itemIndex = 0;
        for (DenseItem item : items) {
            outputErrors[itemIndex++] = itemErrorsOrZeros(item);
        }

        DeepBiRNNEncoder<DenseNDArray> biRNNEncoder = this.neuralParser.getBiRNNEncoder();
        biRNNEncoder.backward(outputErrors, true);
        this.biRNNOptimizer.accumulate(biRNNEncoder.getParamsErrors(false));

        int tokenIndex = 0;
        for (DenseNDArray tokenErrors : biRNNEncoder.getInputSequenceErrors(false)) {
            accumulateTokenErrors(context, tokenIndex++, tokenErrors);
        }

        DenseNDArray nullItemErrors = context.getNullItemErrors();
        if (nullItemErrors != null) {
            DeepBiRNNEncoder<DenseNDArray> paddingEncoder = this.neuralParser.getPaddingVectorEncoder();
            paddingEncoder.backward(new DenseNDArray[] {nullItemErrors}, true);
            // The padding encoder's parameter errors go to the same BiRNN optimizer.
            this.biRNNOptimizer.accumulate(paddingEncoder.getParamsErrors(false));
            accumulateNullTokenErrors((DenseNDArray) ArraysKt.first(paddingEncoder.getInputSequenceErrors(false)));
        }
    }

    /**
     * Returns the errors array of an item, or {@link #zerosErrors} when the item has none.
     *
     * @param item a context item
     * @return the item errors or the shared zeros errors
     */
    private DenseNDArray itemErrorsOrZeros(DenseItem item) {
        DenseItemErrors errors = item.m45getErrors();
        DenseNDArray array = (errors != null) ? errors.getArray() : null;
        return (array != null) ? array : this.zerosErrors;
    }

    /**
     * Accumulates the input errors of a token into the embeddings optimizers.
     *
     * @param context the input context
     * @param tokenIndex the index of the token in the sentence
     * @param errors the BiRNN input errors of the token
     */
    private void accumulateTokenErrors(TokensAmbiguousPOSContext context, int tokenIndex, DenseNDArray errors) {
        Pair<DenseNDArray, DenseNDArray> splitErrors = splitTokenErrors(errors);

        this.wordEmbeddingsOptimizer.accumulate(context.getWordEmbeddings().get(tokenIndex), splitErrors.getFirst());

        if (this.preTrainedEmbeddingsOptimizer != null) {
            List<Embedding> preTrainedEmbeddings = context.getPreTrainedWordEmbeddings();
            if (preTrainedEmbeddings == null) {
                // The optimizer exists only when the model has pre-trained embeddings, so the context must have them too.
                throw new NullPointerException("The context has no pre-trained word embeddings but the model uses them");
            }
            this.preTrainedEmbeddingsOptimizer.accumulate(preTrainedEmbeddings.get(tokenIndex), splitErrors.getSecond());
        }
    }

    /**
     * Accumulates the input errors of the padding (null) token into the null embeddings.
     *
     * @param errors the padding-vector encoder input errors
     */
    private void accumulateNullTokenErrors(DenseNDArray errors) {
        Pair<DenseNDArray, DenseNDArray> splitErrors = splitTokenErrors(errors);
        BiRNNAmbiguousPOSParserModel model = model();

        this.wordEmbeddingsOptimizer.accumulate(model.getWordEmbeddings().getNullEmbedding(), splitErrors.getFirst());

        if (this.preTrainedEmbeddingsOptimizer != null) {
            // The pre-trained part of the errors must update the null embedding of the pre-trained map,
            // i.e. the same map this optimizer was built on (not the word embeddings map).
            if (model.getPreTrainedWordEmbeddings() == null) {
                throw new NullPointerException("The model has no pre-trained word embeddings");
            }
            this.preTrainedEmbeddingsOptimizer.accumulate(model.getPreTrainedWordEmbeddings().getNullEmbedding(), splitErrors.getSecond());
        }
    }

    /**
     * Splits the input errors of a token into its word-embedding part and its pre-trained-embedding
     * part. The leading {@code posTagsSize} values (the POS-tag part) are dropped.
     *
     * @param errors the input errors of a token
     * @return a pair (word embedding errors, pre-trained embedding errors)
     */
    private Pair<DenseNDArray, DenseNDArray> splitTokenErrors(DenseNDArray errors) {
        BiRNNAmbiguousPOSParserModel model = model();
        int wordEmbeddingStart = model.getPosTagsSize();
        int wordEmbeddingEnd = wordEmbeddingStart + model.getWordEmbeddingSize();
        return new Pair<>(
            errors.getRange(wordEmbeddingStart, wordEmbeddingEnd),
            errors.getRange(wordEmbeddingEnd, errors.getLength()));
    }

    /** Returns the model of the trained parser as its common base type. */
    private BiRNNAmbiguousPOSParserModel model() {
        return (BiRNNAmbiguousPOSParserModel) this.neuralParser.getModel();
    }

    /**
     * Builds a new ADAM update method. Each call returns a distinct instance, so the optimizers never
     * share one.
     */
    private static ADAMMethod newDefaultUpdateMethod() {
        // Kotlin default-arguments constructor: bit mask 24 selects the declared defaults for the 4th and
        // 5th parameters, so the 0.0d and null passed in those slots are only placeholders.
        return new ADAMMethod(0.001d, 0.9d, 0.999d, 0.0d, (WeightsRegularization) null, 24, (DefaultConstructorMarker) null);
    }

    /**
     * @param neuralParser the parser to train
     * @param actionsErrorsSetter the actions errors setter
     * @param oracleFactory the oracle factory
     * @param epochs the number of training epochs
     * @param batchSize the batch size
     * @param minRelevantErrorsCountToUpdate the minimum count of relevant errors required to update
     * @param validator the validator, can be null
     * @param modelFilename the name of the file in which the model is saved
     * @param verbose whether to print training details
     */
    /* JADX WARN: 'super' call moved to the top of the method (can break code semantics) */
    public BiRNNAmbiguousPOSParserTrainer(
            @NotNull BiRNNAmbiguousPOSParser<StateType, TransitionType, FeaturesErrorsType, FeaturesType, SupportStructureType, ? extends ModelType> neuralParser,
            @NotNull ActionsErrorsSetter<StateType, TransitionType, DenseItem, TokensAmbiguousPOSContext> actionsErrorsSetter,
            @NotNull OracleFactory<StateType, TransitionType> oracleFactory,
            int epochs,
            int batchSize,
            int minRelevantErrorsCountToUpdate,
            @Nullable Validator validator,
            @NotNull String modelFilename,
            boolean verbose) {
        super(neuralParser, actionsErrorsSetter, oracleFactory, epochs, new HighestScoreCorrectActionSelector(), batchSize, minRelevantErrorsCountToUpdate, validator, modelFilename, verbose);
        Intrinsics.checkParameterIsNotNull(neuralParser, "neuralParser");
        Intrinsics.checkParameterIsNotNull(actionsErrorsSetter, "actionsErrorsSetter");
        Intrinsics.checkParameterIsNotNull(oracleFactory, "oracleFactory");
        Intrinsics.checkParameterIsNotNull(modelFilename, "modelFilename");
        this.neuralParser = neuralParser;

        BiRNNAmbiguousPOSParserModel model = model();
        // Shape synthetic constructor: mask 2 selects the default for its 2nd dimension.
        this.zerosErrors = DenseNDArrayFactory.INSTANCE.zeros(new Shape(model.getBiRNN().getOutputSize(), 0, 2, (DefaultConstructorMarker) null));
        this.wordEmbeddingsOptimizer = new EmbeddingsOptimizer<>(model.getWordEmbeddings(), newDefaultUpdateMethod());
        this.preTrainedEmbeddingsOptimizer = model.getPreTrainedWordEmbeddings() != null
            ? new EmbeddingsOptimizer<>(model.getPreTrainedWordEmbeddings(), newDefaultUpdateMethod())
            : null;
        this.biRNNOptimizer = new DeepBiRNNOptimizer(model.getBiRNN(), newDefaultUpdateMethod());
    }

    /**
     * Kotlin default-arguments bridge: bit 5 of {@code mask} selects the default
     * {@code minRelevantErrorsCountToUpdate} (1) and bit 8 the default {@code verbose} (true).
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public /* synthetic */ BiRNNAmbiguousPOSParserTrainer(BiRNNAmbiguousPOSParser neuralParser, ActionsErrorsSetter actionsErrorsSetter, OracleFactory oracleFactory, int epochs, int batchSize, int minRelevantErrorsCountToUpdate, Validator validator, String modelFilename, boolean verbose, int mask, DefaultConstructorMarker marker) {
        this(neuralParser, actionsErrorsSetter, oracleFactory, epochs, batchSize, (mask & 32) != 0 ? 1 : minRelevantErrorsCountToUpdate, validator, modelFilename, (mask & 256) != 0 ? true : verbose);
    }
}
