package com.kotlinnlp.simplednn.helpers.training;

import com.kotlinnlp.simplednn.core.functionalities.activations.ActivationFunction;
import com.kotlinnlp.simplednn.core.functionalities.activations.Softmax;
import com.kotlinnlp.simplednn.core.functionalities.losses.LossCalculator;
import com.kotlinnlp.simplednn.core.functionalities.losses.SoftmaxCrossEntropyCalculator;
import com.kotlinnlp.simplednn.core.layers.LayerConfiguration;
import com.kotlinnlp.simplednn.core.neuralnetwork.NetworkParameters;
import com.kotlinnlp.simplednn.core.neuralprocessor.recurrent.RecurrentNeuralProcessor;
import com.kotlinnlp.simplednn.core.optimizer.ParamsOptimizer;
import com.kotlinnlp.simplednn.dataset.SequenceExample;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import java.util.ArrayList;
import java.util.List;
import kotlin.Metadata;
import kotlin.TypeCastException;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: SequenceTrainingHelper.kt */
@Metadata(mv = {1, 1, 8}, bv = {1, 0, 2}, k = 1, d1 = {"��<\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010 \n\u0002\u0010\u0006\n��\n\u0002\u0010\u000b\n\u0002\b\u0006\u0018��*\u000e\b��\u0010\u0001*\b\u0012\u0004\u0012\u0002H\u00010\u00022\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u0002H\u00010\u00040\u0003BG\u0012\f\u0010\u0005\u001a\b\u0012\u0004\u0012\u00028��0\u0006\u0012\f\u0010\u0007\u001a\b\u0012\u0004\u0012\u00020\t0\b\u0012\u0006\u0010\n\u001a\u00020\u000b\u0012\u0012\b\u0002\u0010\f\u001a\f\u0012\u0006\u0012\u0004\u0018\u00010\u000e\u0018\u00010\r\u0012\b\b\u0002\u0010\u000f\u001a\u00020\u0010¢\u0006\u0002\u0010\u0011J\u0016\u0010\u0014\u001a\u00020\u000e2\f\u0010\u0015\u001a\b\u0012\u0004\u0012\u00028��0\u0004H\u0014R\u0018\u0010\f\u001a\f\u0012\u0006\u0012\u0004\u0018\u00010\u000e\u0018\u00010\rX\u0082\u0004¢\u0006\u0002\n��R\u001a\u0010\u0005\u001a\b\u0012\u0004\u0012\u00028��0\u0006X\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\u0012\u0010\u0013¨\u0006\u0016"}, d2 = {"Lcom/kotlinnlp/simplednn/helpers/training/SequenceTrainingHelper;", "NDArrayType", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/NDArray;", "Lcom/kotlinnlp/simplednn/helpers/training/TrainingHelper;", "Lcom/kotlinnlp/simplednn/dataset/SequenceExample;", "neuralProcessor", "Lcom/kotlinnlp/simplednn/core/neuralprocessor/recurrent/RecurrentNeuralProcessor;", "optimizer", "Lcom/kotlinnlp/simplednn/core/optimizer/ParamsOptimizer;", "Lcom/kotlinnlp/simplednn/core/neuralnetwork/NetworkParameters;", "lossCalculator", "Lcom/kotlinnlp/simplednn/core/functionalities/losses/LossCalculator;", "mePropK", "", "", "verbose", "", "(Lcom/kotlinnlp/simplednn/core/neuralprocessor/recurrent/RecurrentNeuralProcessor;Lcom/kotlinnlp/simplednn/core/optimizer/ParamsOptimizer;Lcom/kotlinnlp/simplednn/core/functionalities/losses/LossCalculator;Ljava/util/List;Z)V", 
"getNeuralProcessor", "()Lcom/kotlinnlp/simplednn/core/neuralprocessor/recurrent/RecurrentNeuralProcessor;", "learnFromExample", "example", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/helpers/training/SequenceTrainingHelper.class */
/**
 * Training helper that trains a {@link RecurrentNeuralProcessor} on {@link SequenceExample}s.
 *
 * <p>For each example it runs a forward pass over the input sequence and computes the output
 * errors against the gold output sequence. It back-propagates those errors and reports the mean
 * loss. On construction it checks that the softmax cross-entropy loss and a softmax output
 * activation are used together or not at all.
 *
 * @param <NDArrayType> the type of the input arrays of the sequence examples
 */
public final class SequenceTrainingHelper<NDArrayType extends NDArray<NDArrayType>> extends TrainingHelper<SequenceExample<NDArrayType>> {

    /** The recurrent neural processor trained by this helper. */
    @NotNull
    private final RecurrentNeuralProcessor<NDArrayType> neuralProcessor;

    /**
     * Optional list of values forwarded as-is to the processor's backward pass; may be {@code null}.
     * NOTE(review): the name suggests meProp "k" values — confirm against RecurrentNeuralProcessor.
     */
    @Nullable
    private final List<Double> mePropK;

    /**
     * Learns from a single sequence example and returns its loss.
     *
     * <p>Runs a forward pass on the example's sequence features. It then computes the output errors
     * against the gold output sequence with the configured loss calculator and back-propagates them,
     * passing {@link #mePropK}. Finally it returns the mean loss of the same outputs against the same
     * gold sequence.
     *
     * <p>Declared protected in the original Kotlin source; the decompiler widened it to public.
     *
     * @param example the sequence example to learn from
     * @return the mean loss computed by the loss calculator for this example
     * @throws TypeCastException if the example's gold output sequence is {@code null}
     */
    @Override // com.kotlinnlp.simplednn.helpers.training.TrainingHelper
    public double learnFromExample(@NotNull SequenceExample<NDArrayType> example) {
        Intrinsics.checkParameterIsNotNull(example, "example");

        RecurrentNeuralProcessor<NDArrayType> processor = getNeuralProcessor();

        RecurrentNeuralProcessor.forward$default(processor, example.getSequenceFeatures(), null, false, false, 14, null);
        DenseNDArray[] outputSequence = RecurrentNeuralProcessor.getOutputSequence$default(processor, false, 1, null);

        // Convert the gold sequence once and reuse it for both the error and the loss computation.
        ArrayList<DenseNDArray> goldList = example.getSequenceOutputGold();
        if (goldList == null) {
            // Same exception type and message raised by the compiled Kotlin null-check.
            throw new TypeCastException("null cannot be cast to non-null type java.util.Collection<T>");
        }
        DenseNDArray[] goldSequence = goldList.toArray(new DenseNDArray[0]);

        LossCalculator lossCalculator = getLossCalculator();
        RecurrentNeuralProcessor.backward$default(processor, lossCalculator.calculateErrors(outputSequence, goldSequence), false, this.mePropK, 2, null);

        return lossCalculator.calculateMeanLoss(outputSequence, goldSequence);
    }

    /**
     * Returns the recurrent neural processor trained by this helper.
     *
     * @return the neural processor given at construction
     */
    @Override // com.kotlinnlp.simplednn.helpers.training.TrainingHelper
    @NotNull
    public RecurrentNeuralProcessor<NDArrayType> getNeuralProcessor() {
        return this.neuralProcessor;
    }

    /**
     * Creates a sequence training helper.
     *
     * <p>Decompiler note: the {@code super} call was moved before the parameter null-checks.
     *
     * @param neuralProcessor the recurrent neural processor to train
     * @param optimizer the optimizer of the network parameters (passed to the superclass)
     * @param lossCalculator the loss calculator (passed to the superclass)
     * @param list the optional {@code mePropK} list forwarded to the backward pass; may be {@code null}
     * @param z the {@code verbose} flag (passed to the superclass)
     * @throws IllegalArgumentException if the loss is a {@link SoftmaxCrossEntropyCalculator} but the
     *     last layer's activation is not {@link Softmax}, or vice versa
     */
    public SequenceTrainingHelper(@NotNull RecurrentNeuralProcessor<NDArrayType> neuralProcessor, @NotNull ParamsOptimizer<NetworkParameters> optimizer, @NotNull LossCalculator lossCalculator, @Nullable List<Double> list, boolean z) {
        super(neuralProcessor, optimizer, lossCalculator, z);
        Intrinsics.checkParameterIsNotNull(neuralProcessor, "neuralProcessor");
        Intrinsics.checkParameterIsNotNull(optimizer, "optimizer");
        Intrinsics.checkParameterIsNotNull(lossCalculator, "lossCalculator");
        this.neuralProcessor = neuralProcessor;
        this.mePropK = list;

        ActivationFunction outputActivation = ((LayerConfiguration) CollectionsKt.last((List) getNeuralProcessor().getNeuralNetwork().getLayersConfiguration())).getActivationFunction();
        boolean softmaxCrossEntropyLoss = getLossCalculator() instanceof SoftmaxCrossEntropyCalculator;
        boolean softmaxOutput = outputActivation instanceof Softmax;

        // The softmax cross-entropy loss and a softmax output layer must be used together or not at all.
        if (softmaxCrossEntropyLoss != softmaxOutput) {
            throw new IllegalArgumentException("Softmax cross-entropy loss must be used with the softmax as output activation function and vice versa");
        }
    }

    /**
     * Synthetic constructor emitted by the Kotlin compiler for default arguments.
     *
     * <p>Bit 8 of {@code i} selects the default {@code null} for the list, and bit 16 selects the
     * default {@code false} for the boolean flag.
     */
    public /* synthetic */ SequenceTrainingHelper(RecurrentNeuralProcessor recurrentNeuralProcessor, ParamsOptimizer paramsOptimizer, LossCalculator lossCalculator, List list, boolean z, int i, DefaultConstructorMarker defaultConstructorMarker) {
        this(recurrentNeuralProcessor, paramsOptimizer, lossCalculator, (i & 8) != 0 ? (List) null : list, (i & 16) != 0 ? false : z);
    }
}
