package com.kotlinnlp.simplednn.core.neuralprocessor.recurrent;

import com.kotlinnlp.simplednn.core.arrays.DistributionArray;
import com.kotlinnlp.simplednn.core.layers.LayerConfiguration;
import com.kotlinnlp.simplednn.core.layers.LayerParameters;
import com.kotlinnlp.simplednn.core.layers.LayerStructure;
import com.kotlinnlp.simplednn.core.layers.recurrent.GatedRecurrentLayerStructure;
import com.kotlinnlp.simplednn.core.layers.recurrent.RecurrentLayerStructure;
import com.kotlinnlp.simplednn.core.neuralnetwork.NetworkParameters;
import com.kotlinnlp.simplednn.core.neuralnetwork.NeuralNetwork;
import com.kotlinnlp.simplednn.core.neuralnetwork.structure.recurrent.RecurrentNetworkStructure;
import com.kotlinnlp.simplednn.core.neuralnetwork.structure.recurrent.StructureContextWindow;
import com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor;
import com.kotlinnlp.simplednn.core.optimizer.ParamsErrorsAccumulator;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.Shape;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArrayFactory;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import kotlin.Metadata;
import kotlin.TypeCastException;
import kotlin.collections.ArraysKt;
import kotlin.collections.CollectionsKt;
import kotlin.collections.IndexedValue;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import kotlin.ranges.IntProgression;
import kotlin.ranges.IntRange;
import kotlin.ranges.RangesKt;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: RecurrentNeuralProcessor.kt */
@Metadata(mv = {1, 1, 8}, bv = {1, 0, 2}, k = 1, d1 = {"��\u0082\u0001\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\b\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0002\n��\n\u0002\u0010\u000b\n\u0002\b\u0004\n\u0002\u0010 \n\u0002\u0010\u0006\n��\n\u0002\u0010\u0011\n\u0002\b\u0005\n\u0002\u0018\u0002\n\u0002\b\u0007\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0007\n\u0002\u0018\u0002\n\u0002\b\u0006\n\u0002\u0018\u0002\n\u0002\b\b\u0018��*\u000e\b��\u0010\u0001*\b\u0012\u0004\u0012\u0002H\u00010\u00022\b\u0012\u0004\u0012\u0002H\u00010\u00032\u00020\u0004B\u0017\u0012\u0006\u0010\u0005\u001a\u00020\u0006\u0012\b\b\u0002\u0010\u0007\u001a\u00020\b¢\u0006\u0002\u0010\tJ\u0012\u0010\u0017\u001a\u00020\u00182\b\b\u0002\u0010\u0019\u001a\u00020\u001aH\u0002J,\u0010\u001b\u001a\u00020\u00182\u0006\u0010\u001c\u001a\u00020\u00162\b\b\u0002\u0010\u001d\u001a\u00020\u001a2\u0012\b\u0002\u0010\u001e\u001a\f\u0012\u0006\u0012\u0004\u0018\u00010 \u0018\u00010\u001fJ7\u0010\u001b\u001a\u00020\u00182\f\u0010!\u001a\b\u0012\u0004\u0012\u00020\u00160\"2\b\b\u0002\u0010\u001d\u001a\u00020\u001a2\u0012\b\u0002\u0010\u001e\u001a\f\u0012\u0006\u0012\u0004\u0018\u00010 
\u0018\u00010\u001f¢\u0006\u0002\u0010#J,\u0010$\u001a\u0006\u0012\u0002\b\u00030\u00022\u0006\u0010%\u001a\u00020\b2\u0006\u0010&\u001a\u00020\b2\u0006\u0010'\u001a\u00020(2\b\b\u0002\u0010)\u001a\u00020\u001aJ/\u0010*\u001a\u00020\u00162\u0006\u0010+\u001a\u00028��2\u0006\u0010,\u001a\u00020\u001a2\b\b\u0002\u0010\u0019\u001a\u00020\u001a2\b\b\u0002\u0010-\u001a\u00020\u001a¢\u0006\u0002\u0010.J2\u0010*\u001a\u00020\u00162\u0016\u0010/\u001a\u0012\u0012\u0004\u0012\u00028��00j\b\u0012\u0004\u0012\u00028��`12\b\b\u0002\u0010\u0019\u001a\u00020\u001a2\b\b\u0002\u0010-\u001a\u00020\u001aJ'\u00102\u001a\u00020\u00182\u0006\u0010+\u001a\u00028��2\u0006\u0010\u0019\u001a\u00020\u001a2\b\b\u0002\u0010-\u001a\u00020\u001aH\u0002¢\u0006\u0002\u00103J\u001e\u00104\u001a\u0006\u0012\u0002\b\u00030\u00022\u0006\u00105\u001a\u00020\b2\b\b\u0002\u0010)\u001a\u00020\u001aH\u0002J\u001b\u00106\u001a\b\u0012\u0004\u0012\u00020\u00160\"2\b\b\u0002\u0010)\u001a\u00020\u001a¢\u0006\u0002\u00107J\u0010\u00108\u001a\n\u0012\u0004\u0012\u00028��\u0018\u000109H\u0016J\u0010\u0010:\u001a\u00020\u00162\u0006\u0010)\u001a\u00020\u001aH\u0016J\u001b\u0010;\u001a\b\u0012\u0004\u0012\u00020\u00160\"2\b\b\u0002\u0010)\u001a\u00020\u001a¢\u0006\u0002\u00107J\u0010\u0010<\u001a\u00020\u000b2\u0006\u0010)\u001a\u00020\u001aH\u0016J\u0010\u0010=\u001a\n\u0012\u0004\u0012\u00028��\u0018\u000109H\u0016J4\u0010>\u001a\u00020\u00182\n\u0010?\u001a\u0006\u0012\u0002\b\u00030@2\u0006\u0010A\u001a\u00020\b2\u0006\u0010B\u001a\u00020\u001a2\u0006\u0010\u001d\u001a\u00020\u001a2\u0006\u0010C\u001a\u00020\u001aH\u0002J\u0018\u0010D\u001a\u00020\u00182\u0006\u0010E\u001a\u00020\u001a2\u0006\u0010F\u001a\u00020\u001aH\u0002J\b\u0010G\u001a\u00020\u0018H\u0002R\u000e\u0010\n\u001a\u00020\u000bX\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010\f\u001a\u00020\bX\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010\r\u001a\u00020\bX\u0082\u000e¢\u0006\u0002\n��R\u0014\u0010\u000e\u001a\b\u0012\u0004\u0012\u00020\u000b0\u000fX\u0082
\u000e¢\u0006\u0002\n��R\u0014\u0010\u0010\u001a\b\u0012\u0004\u0012\u00028��0\u0011X\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u0012\u001a\u00020\b8BX\u0082\u0004¢\u0006\u0006\u001a\u0004\b\u0013\u0010\u0014R\u000e\u0010\u0015\u001a\u00020\u0016X\u0082\u0004¢\u0006\u0002\n��¨\u0006H"}, d2 = {"Lcom/kotlinnlp/simplednn/core/neuralprocessor/recurrent/RecurrentNeuralProcessor;", "InputNDArrayType", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/NDArray;", "Lcom/kotlinnlp/simplednn/core/neuralnetwork/structure/recurrent/StructureContextWindow;", "Lcom/kotlinnlp/simplednn/core/neuralprocessor/NeuralProcessor;", "neuralNetwork", "Lcom/kotlinnlp/simplednn/core/neuralnetwork/NeuralNetwork;", "id", "", "(Lcom/kotlinnlp/simplednn/core/neuralnetwork/NeuralNetwork;I)V", "backwardParamsErrors", "Lcom/kotlinnlp/simplednn/core/neuralnetwork/NetworkParameters;", "curStateIndex", "lastStateIndex", "paramsErrorsAccumulator", "Lcom/kotlinnlp/simplednn/core/optimizer/ParamsErrorsAccumulator;", "sequence", "Lcom/kotlinnlp/simplednn/core/neuralprocessor/recurrent/NNSequence;", "statesSize", "getStatesSize", "()I", "zeroErrors", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "addNewState", "", "saveContributions", "", "backward", "outputErrors", "propagateToInput", "mePropK", "", "", "outputErrorsSequence", "", "([Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;ZLjava/util/List;)V", "calculateRelevance", "stateFrom", "stateTo", "relevantOutcomesDistribution", "Lcom/kotlinnlp/simplednn/core/arrays/DistributionArray;", "copy", "forward", "featuresArray", "firstState", "useDropout", "(Lcom/kotlinnlp/simplednn/simplemath/ndarray/NDArray;ZZZ)Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "sequenceFeaturesArray", "Ljava/util/ArrayList;", "Lkotlin/collections/ArrayList;", "forwardCurrentState", "(Lcom/kotlinnlp/simplednn/simplemath/ndarray/NDArray;ZZ)V", "getInputRelevance", "stateIndex", "getInputSequenceErrors", 
"(Z)[Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "getNextStateStructure", "Lcom/kotlinnlp/simplednn/core/neuralnetwork/structure/recurrent/RecurrentNetworkStructure;", "getOutput", "getOutputSequence", "getParamsErrors", "getPrevStateStructure", "propagateLayerRelevance", "layer", "Lcom/kotlinnlp/simplednn/core/layers/LayerStructure;", "layerIndex", "propagateToPrevState", "replaceInputRelevance", "propagateRelevanceOnCurrentState", "isFirstState", "isLastState", "reset", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/core/neuralprocessor/recurrent/RecurrentNeuralProcessor.class */
/**
 * Neural processor that runs a recurrent network over a sequence of inputs, one network state
 * (a {@link RecurrentNetworkStructure}) per sequence element.
 *
 * <p>States are kept in an {@code NNSequence} and reused across sequences. {@code reset()}
 * (triggered by the first element of a new sequence) only rewinds {@code lastStateIndex}, and a new
 * structure is allocated only when the sequence grows past the states created so far. Every method
 * that addresses a state by index must therefore stay within {@code 0..lastStateIndex}; higher
 * indices may still hold structures left over from a previous, longer sequence.
 *
 * <p>As a {@link StructureContextWindow}, this processor exposes to each state its neighbouring
 * states, relative to {@code curStateIndex}.
 *
 * <p>This source appears to be decompiled from Kotlin. The static {@code *$default} methods mirror
 * Kotlin default-argument bridges: bit {@code k} of the mask selects the default for value parameter
 * {@code k}. They keep their exact signatures for compiled Kotlin callers.
 */
public final class RecurrentNeuralProcessor<InputNDArrayType extends NDArray<InputNDArrayType>> extends NeuralProcessor implements StructureContextWindow<InputNDArrayType> {
    /** The states of the sequence; structures persist across sequences and are reused. */
    private final NNSequence<InputNDArrayType> sequence;
    /** Index of the state currently being processed; read by the prev/next state accessors. */
    private int curStateIndex;
    /** Index of the last state of the current sequence; -1 when no state has been added yet. */
    private int lastStateIndex;
    /** Params errors passed to each state's backward, then accumulated into the accumulator. */
    private NetworkParameters backwardParamsErrors;
    /** Accumulates (and finally averages) the params errors of all the states. */
    private ParamsErrorsAccumulator<NetworkParameters> paramsErrorsAccumulator;
    /** Zeros array used as output errors of the non-final states in {@code backward(DenseNDArray, ...)}. */
    private final DenseNDArray zeroErrors;

    /** Returns the number of states of the current sequence ({@code lastStateIndex + 1}). */
    private final int getStatesSize() {
        return this.lastStateIndex + 1;
    }

    /**
     * Returns the structure of the state preceding the current one.
     *
     * @return the previous state structure, or {@code null} when {@code curStateIndex} is not in
     *     {@code 1..lastStateIndex}
     */
    @Override // com.kotlinnlp.simplednn.core.neuralnetwork.structure.recurrent.StructureContextWindow
    @Nullable
    public RecurrentNetworkStructure<InputNDArrayType> getPrevStateStructure() {
        int cur = this.curStateIndex;
        if (cur >= 1 && cur <= this.lastStateIndex) {
            return this.sequence.getStateStructure(cur - 1);
        }
        return null;
    }

    /**
     * Returns the structure of the state following the current one.
     *
     * @return the next state structure, or {@code null} when {@code curStateIndex} is not in
     *     {@code 0..lastStateIndex - 1}
     */
    @Override // com.kotlinnlp.simplednn.core.neuralnetwork.structure.recurrent.StructureContextWindow
    @Nullable
    public RecurrentNetworkStructure<InputNDArrayType> getNextStateStructure() {
        int cur = this.curStateIndex;
        if (cur >= 0 && cur < this.lastStateIndex) {
            return this.sequence.getStateStructure(cur + 1);
        }
        return null;
    }

    /**
     * Returns the output values of the last state of the current sequence.
     *
     * @param copy whether to return a copy instead of the internal array
     * @return the output of the last state
     */
    @Override // com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor
    @NotNull
    public DenseNDArray getOutput(boolean copy) {
        DenseNDArray output = this.sequence.getStateStructure(this.lastStateIndex).getOutputLayer().getOutputArray().getValues();
        return copy ? output.copy() : output;
    }

    /**
     * Returns the params errors accumulated by the last backward.
     *
     * @param copy forwarded to the accumulator
     */
    @Override // com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor
    @NotNull
    public NetworkParameters getParamsErrors(boolean copy) {
        return this.paramsErrorsAccumulator.getParamsErrors(copy);
    }

    /**
     * Returns the errors of the input array of every state of the current sequence, in order.
     *
     * @param copy whether each element is a copy instead of the internal array
     * @return one errors array per state
     */
    @NotNull
    public final DenseNDArray[] getInputSequenceErrors(boolean copy) {
        DenseNDArray[] errorsSequence = new DenseNDArray[getStatesSize()];
        for (int stateIndex = 0; stateIndex < errorsSequence.length; stateIndex++) {
            DenseNDArray errors = this.sequence.getStateStructure(stateIndex).getInputLayer().getInputArray().getErrors();
            errorsSequence[stateIndex] = copy ? errors.copy() : errors;
        }
        return errorsSequence;
    }

    @NotNull
    public static /* bridge */ /* synthetic */ DenseNDArray[] getInputSequenceErrors$default(RecurrentNeuralProcessor recurrentNeuralProcessor, boolean z, int i, Object obj) {
        if ((i & 1) != 0) {
            z = true;
        }
        return recurrentNeuralProcessor.getInputSequenceErrors(z);
    }

    /**
     * Returns the output values of every state of the current sequence, in order.
     *
     * @param copy whether each element is a copy instead of the internal array
     * @return one output array per state
     */
    @NotNull
    public final DenseNDArray[] getOutputSequence(boolean copy) {
        DenseNDArray[] outputSequence = new DenseNDArray[getStatesSize()];
        for (int stateIndex = 0; stateIndex < outputSequence.length; stateIndex++) {
            DenseNDArray output = this.sequence.getStateStructure(stateIndex).getOutputLayer().getOutputArray().getValues();
            outputSequence[stateIndex] = copy ? output.copy() : output;
        }
        return outputSequence;
    }

    @NotNull
    public static /* bridge */ /* synthetic */ DenseNDArray[] getOutputSequence$default(RecurrentNeuralProcessor recurrentNeuralProcessor, boolean z, int i, Object obj) {
        if ((i & 1) != 0) {
            z = true;
        }
        return recurrentNeuralProcessor.getOutputSequence(z);
    }

    /**
     * Forwards a whole sequence, starting a new sequence with its first element.
     *
     * @param sequenceFeaturesArray the input of each state, in order; must not be empty
     * @param saveContributions whether to save the contributions of each state (needed by
     *     {@link #calculateRelevance})
     * @param useDropout forwarded to each state's forward
     * @return the output of the last state, obtained through {@code getOutput} with its default
     *     {@code copy} argument (declared in {@code NeuralProcessor})
     * @throws IllegalArgumentException if {@code sequenceFeaturesArray} is empty. Without this check
     *     no reset would happen and the output of a previous sequence would be returned.
     */
    @NotNull
    public final DenseNDArray forward(@NotNull ArrayList<InputNDArrayType> sequenceFeaturesArray, boolean saveContributions, boolean useDropout) {
        Intrinsics.checkParameterIsNotNull(sequenceFeaturesArray, "sequenceFeaturesArray");
        if (sequenceFeaturesArray.isEmpty()) {
            throw new IllegalArgumentException("sequenceFeaturesArray must not be empty");
        }
        boolean firstState = true;
        for (InputNDArrayType featuresArray : sequenceFeaturesArray) {
            forward(featuresArray, firstState, saveContributions, useDropout);
            firstState = false;
        }
        return NeuralProcessor.getOutput$default(this, false, 1, null);
    }

    @NotNull
    public static /* bridge */ /* synthetic */ DenseNDArray forward$default(RecurrentNeuralProcessor recurrentNeuralProcessor, ArrayList arrayList, boolean z, boolean z2, int i, Object obj) {
        if ((i & 2) != 0) {
            z = false;
        }
        if ((i & 4) != 0) {
            z2 = false;
        }
        return recurrentNeuralProcessor.forward(arrayList, z, z2);
    }

    /**
     * Forwards a single element, appending a new state to the current sequence.
     *
     * @param featuresArray the input of the new state
     * @param firstState if true, the processor is reset first so that a new sequence starts
     * @param saveContributions whether to save the contributions of the new state
     * @param useDropout forwarded to the state's forward
     * @return the output of the new (last) state, obtained through {@code getOutput} with its
     *     default {@code copy} argument
     */
    @NotNull
    public final DenseNDArray forward(@NotNull InputNDArrayType featuresArray, boolean firstState, boolean saveContributions, boolean useDropout) {
        Intrinsics.checkParameterIsNotNull(featuresArray, "featuresArray");
        if (firstState) {
            reset();
        }
        addNewState(saveContributions);
        this.curStateIndex = this.lastStateIndex;
        forwardCurrentState(featuresArray, saveContributions, useDropout);
        return NeuralProcessor.getOutput$default(this, false, 1, null);
    }

    @NotNull
    public static /* bridge */ /* synthetic */ DenseNDArray forward$default(RecurrentNeuralProcessor recurrentNeuralProcessor, NDArray nDArray, boolean z, boolean z2, boolean z3, int i, Object obj) {
        if ((i & 4) != 0) {
            z2 = false;
        }
        if ((i & 8) != 0) {
            z3 = false;
        }
        return recurrentNeuralProcessor.forward(nDArray, z, z2, z3);
    }

    /**
     * Calculates the relevance of the input of {@code stateFrom}. Relevance is propagated
     * backwards, from the output of {@code stateTo} down to {@code stateFrom}.
     *
     * <p>Requires that the states were forwarded with {@code saveContributions = true}, since the
     * saved contributions of each state are read during the propagation.
     *
     * @param stateFrom index of the first state of the window
     * @param stateTo index of the last state of the window; its output relevance is set to
     *     {@code relevantOutcomesDistribution}
     * @param relevantOutcomesDistribution the distribution the relevance is computed against
     * @param copy whether to return a copy of the input relevance
     * @return the relevance of the input of {@code stateFrom}
     * @throws IllegalArgumentException if {@code stateFrom > stateTo}, or if either index lies
     *     outside {@code 0..lastStateIndex}
     */
    @NotNull
    public final NDArray<?> calculateRelevance(int stateFrom, int stateTo, @NotNull DistributionArray relevantOutcomesDistribution, boolean copy) {
        Intrinsics.checkParameterIsNotNull(relevantOutcomesDistribution, "relevantOutcomesDistribution");
        if (stateFrom > stateTo) {
            throw new IllegalArgumentException(String.format("stateFrom (%d) must be <= stateTo (%d)", stateFrom, stateTo));
        }
        if (stateFrom < 0 || stateFrom > this.lastStateIndex) {
            throw new IllegalArgumentException(String.format("stateFrom (%d) index exceeded sequence size (%d)", stateFrom, getStatesSize()));
        }
        // stateFrom is already known to be >= 0 and <= stateTo, so only the upper bound of stateTo is
        // left to check. Without it, stale states of a previous, longer sequence could be read.
        if (stateTo > this.lastStateIndex) {
            throw new IllegalArgumentException(String.format("stateTo (%d) index exceeded sequence size (%d)", stateTo, getStatesSize()));
        }
        ((LayerStructure) ArraysKt.last(this.sequence.getStateStructure(stateTo).getLayers())).setOutputRelevance(relevantOutcomesDistribution);
        for (int stateIndex = stateTo; stateIndex >= stateFrom; stateIndex--) {
            this.curStateIndex = stateIndex;
            propagateRelevanceOnCurrentState(stateIndex == stateFrom, stateIndex == stateTo);
        }
        return getInputRelevance(stateFrom, copy);
    }

    @NotNull
    public static /* bridge */ /* synthetic */ NDArray calculateRelevance$default(RecurrentNeuralProcessor recurrentNeuralProcessor, int i, int i2, DistributionArray distributionArray, boolean z, int i3, Object obj) {
        if ((i3 & 8) != 0) {
            z = true;
        }
        return recurrentNeuralProcessor.calculateRelevance(i, i2, distributionArray, z);
    }

    /**
     * Backward with errors given only on the output of the last state; every other state receives
     * {@code zeroErrors}.
     *
     * <p>NOTE(review): the same {@code zeroErrors} instance is passed to every non-final state. This
     * assumes that a state's backward never mutates its output errors argument; confirm.
     *
     * @param outputErrors the errors of the output of the last state
     * @param propagateToInput forwarded to each state's backward
     * @param mePropK forwarded to each state's backward (may be null)
     */
    public final void backward(@NotNull DenseNDArray outputErrors, boolean propagateToInput, @Nullable List<Double> mePropK) {
        Intrinsics.checkParameterIsNotNull(outputErrors, "outputErrors");
        DenseNDArray[] outputErrorsSequence = new DenseNDArray[getStatesSize()];
        for (int stateIndex = 0; stateIndex < outputErrorsSequence.length; stateIndex++) {
            outputErrorsSequence[stateIndex] = stateIndex == this.lastStateIndex ? outputErrors : this.zeroErrors;
        }
        backward(outputErrorsSequence, propagateToInput, mePropK);
    }

    public static /* bridge */ /* synthetic */ void backward$default(RecurrentNeuralProcessor recurrentNeuralProcessor, DenseNDArray denseNDArray, boolean z, List list, int i, Object obj) {
        if ((i & 2) != 0) {
            z = false;
        }
        if ((i & 4) != 0) {
            list = (List) null;
        }
        recurrentNeuralProcessor.backward(denseNDArray, z, (List<Double>) list);
    }

    /**
     * Backward from the last state to the first. The params errors of each state are accumulated
     * and finally averaged.
     *
     * @param outputErrorsSequence the output errors of each state; its length must equal the
     *     number of states
     * @param propagateToInput forwarded to each state's backward
     * @param mePropK forwarded to each state's backward (may be null)
     * @throws IllegalArgumentException if the number of errors differs from the number of states
     */
    public final void backward(@NotNull DenseNDArray[] outputErrorsSequence, boolean propagateToInput, @Nullable List<Double> mePropK) {
        Intrinsics.checkParameterIsNotNull(outputErrorsSequence, "outputErrorsSequence");
        if (outputErrorsSequence.length != getStatesSize()) {
            throw new IllegalArgumentException(String.format("Number of errors (%d) does not reflect the length of the sequence (%d)", outputErrorsSequence.length, getStatesSize()));
        }
        for (int stateIndex = this.lastStateIndex; stateIndex >= 0; stateIndex--) {
            // curStateIndex must point at the state being processed: the prev/next accessors read it.
            this.curStateIndex = stateIndex;
            this.sequence.getStateStructure(stateIndex).backward(outputErrorsSequence[stateIndex], this.backwardParamsErrors, propagateToInput, mePropK);
            ParamsErrorsAccumulator.accumulate$default(this.paramsErrorsAccumulator, this.backwardParamsErrors, false, 2, null);
        }
        this.paramsErrorsAccumulator.averageErrors();
    }

    public static /* bridge */ /* synthetic */ void backward$default(RecurrentNeuralProcessor recurrentNeuralProcessor, DenseNDArray[] denseNDArrayArr, boolean z, List list, int i, Object obj) {
        if ((i & 2) != 0) {
            z = false;
        }
        if ((i & 4) != 0) {
            list = (List) null;
        }
        recurrentNeuralProcessor.backward(denseNDArrayArr, z, (List<Double>) list);
    }

    /**
     * Appends a state to the current sequence. A new structure is allocated only when all the
     * structures already held by the sequence are in use; otherwise an existing one is reused.
     *
     * @param saveContributions forwarded to {@code NNSequence.add} when a structure is allocated
     */
    private final void addNewState(boolean saveContributions) {
        if (this.lastStateIndex == this.sequence.getLastIndex()) {
            this.sequence.add(new RecurrentNetworkStructure<>(getNeuralNetwork().getLayersConfiguration(), getNeuralNetwork().getModel(), this), saveContributions);
        }
        this.lastStateIndex++;
    }

    static /* bridge */ /* synthetic */ void addNewState$default(RecurrentNeuralProcessor recurrentNeuralProcessor, boolean z, int i, Object obj) {
        if ((i & 1) != 0) {
            z = false;
        }
        recurrentNeuralProcessor.addNewState(z);
    }

    /**
     * Forwards the last state with the given input, saving its contributions if requested.
     *
     * @param featuresArray the input of the state
     * @param saveContributions whether to forward with the state's contributions
     * @param useDropout forwarded to the state's forward
     */
    private final void forwardCurrentState(InputNDArrayType featuresArray, boolean saveContributions, boolean useDropout) {
        RecurrentNetworkStructure<InputNDArrayType> stateStructure = this.sequence.getStateStructure(this.lastStateIndex);
        if (saveContributions) {
            stateStructure.forward(featuresArray, this.sequence.getStateContributions(this.lastStateIndex), useDropout);
        } else {
            stateStructure.forward(featuresArray, useDropout);
        }
    }

    static /* bridge */ /* synthetic */ void forwardCurrentState$default(RecurrentNeuralProcessor recurrentNeuralProcessor, NDArray nDArray, boolean z, boolean z2, int i, Object obj) {
        if ((i & 4) != 0) {
            z2 = false;
        }
        recurrentNeuralProcessor.forwardCurrentState(nDArray, z, z2);
    }

    /**
     * Propagates relevance backwards through the layers of the state at {@code curStateIndex},
     * from the last layer to the first.
     *
     * <p>In the last state of the window, every layer is processed. In earlier states, layers are
     * processed only from the topmost recurrent layer downwards. The layers above it are skipped,
     * presumably because relevance reaches an earlier state only through recurrent connections.
     *
     * @param isFirstState whether this is the first state of the window ({@code stateFrom})
     * @param isLastState whether this is the last state of the window ({@code stateTo})
     */
    private final void propagateRelevanceOnCurrentState(boolean isFirstState, boolean isLastState) {
        RecurrentNetworkStructure<InputNDArrayType> stateStructure = this.sequence.getStateStructure(this.curStateIndex);
        LayerStructure<?>[] layers = stateStructure.getLayers();
        // Becomes true once relevance is flowing: immediately in the last state, otherwise from the
        // topmost recurrent layer downwards.
        boolean relevanceReached = isLastState;
        for (int layerIndex = layers.length - 1; layerIndex >= 0; layerIndex--) {
            LayerStructure<?> layer = layers[layerIndex];
            stateStructure.setCurLayerIndex(layerIndex);
            boolean isRecurrent = layer instanceof RecurrentLayerStructure;
            boolean prevLayerIsRecurrent = layerIndex > 0 && layers[layerIndex - 1] instanceof RecurrentLayerStructure;
            relevanceReached = relevanceReached || isRecurrent;
            if (relevanceReached) {
                // NOTE(review): if layerIndex == 0, the state is not the first and the layer is not
                // recurrent, then both propagation flags are false and propagateLayerRelevance fails
                // its requirement. Confirm that such architectures are not expected here.
                propagateLayerRelevance(
                        layer,
                        layerIndex,
                        !isFirstState && isRecurrent,          // propagateToPrevState
                        layerIndex > 0 || isFirstState,        // propagateToInput
                        isLastState || !prevLayerIsRecurrent); // replaceInputRelevance: add instead when
                                                               // the layer below is recurrent (presumably it
                                                               // already got relevance from the next state)
            }
        }
    }

    /**
     * Propagates the relevance of a single layer of the current state, using the contributions
     * saved for it.
     *
     * @param layer the layer whose relevance is propagated
     * @param layerIndex its index, used to select the saved contributions
     * @param propagateToPrevState whether to set the recurrent relevance; requires {@code layer} to
     *     be a {@link RecurrentLayerStructure}
     * @param propagateToInput whether to set or add the relevance of the layer input
     * @param replaceInputRelevance if true the input relevance is set, otherwise it is added
     * @throws IllegalArgumentException if neither propagation flag is set
     */
    private final void propagateLayerRelevance(LayerStructure<?> layer, int layerIndex, boolean propagateToPrevState, boolean propagateToInput, boolean replaceInputRelevance) {
        if (!(propagateToInput || propagateToPrevState)) {
            throw new IllegalArgumentException("Failed requirement.");
        }
        LayerParameters<?> contributions = this.sequence.getStateContributions(this.curStateIndex).getParamsPerLayer()[layerIndex];
        if (layer instanceof GatedRecurrentLayerStructure) {
            ((GatedRecurrentLayerStructure) layer).propagateRelevanceToGates(contributions);
        }
        if (propagateToInput) {
            if (replaceInputRelevance) {
                layer.setInputRelevance(contributions);
            } else {
                layer.addInputRelevance(contributions);
            }
        }
        if (propagateToPrevState) {
            ((RecurrentLayerStructure) layer).setRecurrentRelevance(contributions);
        }
    }

    /**
     * Returns the relevance of the input array of the given state.
     *
     * @param stateIndex the state index
     * @param copy whether to return a copy instead of the internal array
     */
    /* JADX WARN: Type inference failed for: r0v13, types: [com.kotlinnlp.simplednn.simplemath.ndarray.NDArray<?>, com.kotlinnlp.simplednn.simplemath.ndarray.NDArray] */
    private final NDArray<?> getInputRelevance(int stateIndex, boolean copy) {
        return copy ? this.sequence.getStateStructure(stateIndex).getInputLayer().getInputArray().getRelevance().copy() : this.sequence.getStateStructure(stateIndex).getInputLayer().getInputArray().getRelevance();
    }

    static /* bridge */ /* synthetic */ NDArray getInputRelevance$default(RecurrentNeuralProcessor recurrentNeuralProcessor, int i, boolean z, int i2, Object obj) {
        if ((i2 & 2) != 0) {
            z = true;
        }
        return recurrentNeuralProcessor.getInputRelevance(i, z);
    }

    /**
     * Starts a new sequence: rewinds {@code lastStateIndex} (allocated state structures are kept
     * for reuse) and resets the params errors accumulator.
     */
    private final void reset() {
        this.lastStateIndex = -1;
        this.paramsErrorsAccumulator.reset();
    }

    /**
     * @param neuralNetwork the recurrent network to process
     * @param id an identifier forwarded to {@link NeuralProcessor}
     */
    /* JADX WARN: 'super' call moved to the top of the method (can break code semantics) */
    public RecurrentNeuralProcessor(@NotNull NeuralNetwork neuralNetwork, int id) {
        super(neuralNetwork, id);
        Intrinsics.checkParameterIsNotNull(neuralNetwork, "neuralNetwork");
        this.sequence = new NNSequence<>(neuralNetwork);
        this.lastStateIndex = -1;
        this.backwardParamsErrors = NeuralNetwork.parametersFactory$default(getNeuralNetwork(), false, null, null, 6, null);
        this.paramsErrorsAccumulator = new ParamsErrorsAccumulator<>();
        // Sized on the last layer configuration, i.e. presumably the network output size.
        this.zeroErrors = DenseNDArrayFactory.INSTANCE.zeros(new Shape(((LayerConfiguration) CollectionsKt.last((List) getNeuralNetwork().getLayersConfiguration())).getSize(), 0, 2, null));
    }

    public /* synthetic */ RecurrentNeuralProcessor(NeuralNetwork neuralNetwork, int i, int i2, DefaultConstructorMarker defaultConstructorMarker) {
        this(neuralNetwork, (i2 & 2) != 0 ? 0 : i);
    }
}
