package com.kotlinnlp.simplednn.deeplearning.attention.attentiverecurrentnetwork;

import com.kotlinnlp.simplednn.core.layers.models.feedforward.simple.FeedforwardLayerParameters;
import com.kotlinnlp.simplednn.core.layers.models.feedforward.simple.FeedforwardLayerStructure;
import com.kotlinnlp.simplednn.core.neuralnetwork.NetworkParameters;
import com.kotlinnlp.simplednn.core.neuralprocessor.feedforward.FeedforwardNeuralProcessor;
import com.kotlinnlp.simplednn.core.neuralprocessor.recurrent.RecurrentNeuralProcessor;
import com.kotlinnlp.simplednn.core.optimizer.ParamsErrorsAccumulator;
import com.kotlinnlp.simplednn.deeplearning.attention.attentionnetwork.AttentionNetwork;
import com.kotlinnlp.simplednn.deeplearning.attention.attentionnetwork.AttentionNetworkParameters;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.Shape;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArrayFactory;
import java.util.ArrayList;
import java.util.List;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin.TypeCastException;
import kotlin.UninitializedPropertyAccessException;
import kotlin.collections.CollectionsKt;
import kotlin.collections.IntIterator;
import kotlin.jvm.internal.Intrinsics;
import kotlin.ranges.RangesKt;
import org.jetbrains.annotations.NotNull;

/* compiled from: BackwardHelper.kt */
@Metadata(mv = {1, 1, 10}, bv = {1, 0, 2}, k = 1, d1 = {"��n\n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010!\n\u0002\u0018\u0002\n\u0002\b\u0006\n\u0002\u0010 \n\u0002\b\b\n\u0002\u0010\b\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0010\u000b\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\b\r\u0018��2\u00020\u0001B\r\u0012\u0006\u0010\u0002\u001a\u00020\u0003¢\u0006\u0002\u0010\u0004J\u0014\u0010!\u001a\u00020\"2\f\u0010#\u001a\b\u0012\u0004\u0012\u00020\r0\u0014J\"\u0010$\u001a\u0014\u0012\u0010\u0012\u000e\u0012\u0004\u0012\u00020\r\u0012\u0004\u0012\u00020\r0%0\u00142\u0006\u0010&\u001a\u00020\rH\u0002J \u0010'\u001a\u00020\"2\u0006\u0010#\u001a\u00020\r2\u0006\u0010(\u001a\u00020)2\u0006\u0010*\u001a\u00020)H\u0002J\u001e\u0010+\u001a\u00020\r2\f\u0010,\u001a\b\u0012\u0004\u0012\u00020\r0-2\u0006\u0010#\u001a\u00020\rH\u0002J\b\u0010.\u001a\u00020\u0007H\u0002J\u0010\u0010/\u001a\u00020\r2\u0006\u0010#\u001a\u00020\rH\u0002J\u0010\u00100\u001a\u0002012\b\b\u0002\u00102\u001a\u00020)J\b\u00103\u001a\u00020\u001fH\u0002J\b\u00104\u001a\u00020\"H\u0002J\b\u00105\u001a\u00020\"H\u0002J\"\u00106\u001a\u00020\"2\u0018\u00107\u001a\u0014\u0012\u0010\u0012\u000e\u0012\u0004\u0012\u00020\r\u0012\u0004\u0012\u00020\r0%0\u0014H\u0002J\u0014\u00108\u001a\u000e\u0012\u0004\u0012\u00020\r\u0012\u0004\u0012\u00020\r0%H\u0002J\u001c\u00109\u001a\u000e\u0012\u0004\u0012\u00020\r\u0012\u0004\u0012\u00020\r0%2\u0006\u0010:\u001a\u00020\rH\u0002J\u001c\u0010;\u001a\u000e\u0012\u0004\u0012\u00020\r\u0012\u0004\u0012\u00020\r0%2\u0006\u0010<\u001a\u00020\rH\u0002J\u001c\u0010=\u001a\u000e\u0012\u0004\u0012\u00020\r\u0012\u0004\u0012\u00020\r0%2\u0006\u0010<\u001a\u00020\rH\u0002R\u0014\u0010\u0005\u001a\b\u0012\u0004\u0012\u00020\u00070\u0006X\u0082\u000e
¢\u0006\u0002\n��R\u000e\u0010\b\u001a\u00020\u0007X\u0082.¢\u0006\u0002\n��R\u0014\u0010\t\u001a\b\u0012\u0004\u0012\u00020\n0\u0006X\u0082\u000e¢\u0006\u0002\n��R\u0019\u0010\u000b\u001a\n\u0012\u0006\u0012\u0004\u0018\u00010\r0\f¢\u0006\b\n��\u001a\u0004\b\u000e\u0010\u000fR\u0011\u0010\u0010\u001a\u00020\r8F¢\u0006\u0006\u001a\u0004\b\u0011\u0010\u0012R0\u0010\u0015\u001a\b\u0012\u0004\u0012\u00020\r0\u00142\f\u0010\u0013\u001a\b\u0012\u0004\u0012\u00020\r0\u0014@BX\u0086.¢\u0006\u000e\n��\u001a\u0004\b\u0016\u0010\u000f\"\u0004\b\u0017\u0010\u0018R\u000e\u0010\u0002\u001a\u00020\u0003X\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u0019\u001a\b\u0012\u0004\u0012\u00020\n0\u0006X\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010\u001a\u001a\u00020\rX\u0082.¢\u0006\u0002\n��R\u000e\u0010\u001b\u001a\u00020\rX\u0082.¢\u0006\u0002\n��R\u000e\u0010\u001c\u001a\u00020\u001dX\u0082\u000e¢\u0006\u0002\n��R\u0014\u0010\u001e\u001a\b\u0012\u0004\u0012\u00020\u001f0\u0006X\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010 \u001a\u00020\u001fX\u0082.¢\u0006\u0002\n��¨\u0006>"}, d2 = {"Lcom/kotlinnlp/simplednn/deeplearning/attention/attentiverecurrentnetwork/BackwardHelper;", "", "network", "Lcom/kotlinnlp/simplednn/deeplearning/attention/attentiverecurrentnetwork/AttentiveRecurrentNetwork;", "(Lcom/kotlinnlp/simplednn/deeplearning/attention/attentiverecurrentnetwork/AttentiveRecurrentNetwork;)V", "attentionNetworkAccumulator", "Lcom/kotlinnlp/simplednn/core/optimizer/ParamsErrorsAccumulator;", "Lcom/kotlinnlp/simplednn/deeplearning/attention/attentionnetwork/AttentionNetworkParameters;", "attentionNetworkParamsErrors", "contextErrorsAccumulator", "Lcom/kotlinnlp/simplednn/core/neuralnetwork/NetworkParameters;", "contextLabelsErrors", "", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "getContextLabelsErrors", "()Ljava/util/List;", "initHiddenErrors", "getInitHiddenErrors", "()Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "<set-?>", "", 
"inputSequenceErrors", "getInputSequenceErrors", "setInputSequenceErrors", "(Ljava/util/List;)V", "outputErrorsAccumulator", "recurrentContextErrors", "recurrentStateEncodingErrors", "stateIndex", "", "transformLayerAccumulator", "Lcom/kotlinnlp/simplednn/core/layers/models/feedforward/simple/FeedforwardLayerParameters;", "transformLayerParamsErrors", "backward", "", "outputErrors", "backwardStateEncoder", "Lkotlin/Pair;", "stateEncodingErrors", "backwardStep", "isFirstState", "", "isLastState", "backwardTransformLayer", "layer", "Lcom/kotlinnlp/simplednn/core/layers/models/feedforward/simple/FeedforwardLayerStructure;", "getAttentionParamsErrors", "getOutputNetworkInputErrors", "getParamsErrors", "Lcom/kotlinnlp/simplednn/deeplearning/attention/attentiverecurrentnetwork/AttentiveRecurrentNetworkParameters;", "copy", "getTransformParamsErrors", "initBackward", "initSequenceErrors", "propagateStateEncodingErrors", "stateEncoderInputErrors", "recurrentContextBackwardStep", "splitOutputNetworkErrors", "outputNetworkErrors", "splitRNNInputErrors", "errors", "splitTransformErrors", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/deeplearning/attention/attentiverecurrentnetwork/BackwardHelper.class */
/**
 * Executes the backward of an {@link AttentiveRecurrentNetwork} over a sequence of output errors.
 *
 * <p>The errors are propagated, state by state from the last to the first, through the output
 * processors, the attention networks, the transform layers and the recurrent context processor of
 * the network. The errors of their parameters are accumulated and can be read with
 * {@link #getParamsErrors(boolean)}.
 */
public final class BackwardHelper {

    /**
     * The errors of the context labels, one element per state, ordered by state index after a
     * {@link #backward(List)}. The element of the first state is always {@code null}.
     */
    @NotNull
    private final List<DenseNDArray> contextLabelsErrors;

    /** The errors of the input sequence: {@code null} until {@link #backward(List)} is called. */
    private List<DenseNDArray> inputSequenceErrors;

    /** The index of the state processed by the current backward step. */
    private int stateIndex;

    private ParamsErrorsAccumulator<FeedforwardLayerParameters> transformLayerAccumulator;
    private ParamsErrorsAccumulator<AttentionNetworkParameters> attentionNetworkAccumulator;
    private ParamsErrorsAccumulator<NetworkParameters> contextErrorsAccumulator;
    private ParamsErrorsAccumulator<NetworkParameters> outputErrorsAccumulator;

    /** Reusable params-errors structure for the transform layers, lazily created. */
    private FeedforwardLayerParameters transformLayerParamsErrors;

    /** Reusable params-errors structure for the attention networks, lazily created. */
    private AttentionNetworkParameters attentionNetworkParamsErrors;

    /** The errors of the recurrent context for the state currently processed. */
    private DenseNDArray recurrentContextErrors;

    /** The state-encoding errors coming from the recurrent step of the following state. */
    private DenseNDArray recurrentStateEncodingErrors;

    private final AttentiveRecurrentNetwork network;

    /**
     * Returns the errors of the context labels.
     *
     * @return the live internal list, one element per state (ordered by state index after a
     *     backward); the element of the first state is {@code null}
     */
    @NotNull
    public final List<DenseNDArray> getContextLabelsErrors() {
        return this.contextLabelsErrors;
    }

    /**
     * Returns the errors of the input sequence computed by the last {@link #backward(List)}.
     *
     * @return one error array per element of the input sequence
     * @throws kotlin.UninitializedPropertyAccessException if no backward has been executed yet
     */
    @NotNull
    public final List<DenseNDArray> getInputSequenceErrors() {
        return requireInitialized(this.inputSequenceErrors, "inputSequenceErrors");
    }

    private final void setInputSequenceErrors(List<DenseNDArray> list) {
        this.inputSequenceErrors = list;
    }

    /**
     * Returns the first of the init hidden errors of the recurrent context processor.
     *
     * @return the first init hidden errors array
     * @throws NullPointerException if that element is {@code null}
     */
    @NotNull
    public final DenseNDArray getInitHiddenErrors() {
        Object first = CollectionsKt.first(this.network.getRecurrentContextProcessor().getInitHiddenErrors());
        if (first == null) {
            Intrinsics.throwNpe();
        }
        return (DenseNDArray) first;
    }

    /**
     * Executes the backward over the whole sequence, from the last state to the first.
     *
     * <p>The errors of the transform layers, the attention networks and the output networks are
     * averaged at the end. The errors of the recurrent context processor are accumulated once,
     * after all the steps.
     *
     * @param outputErrors the errors of the network outputs, one per state
     */
    public final void backward(@NotNull List<DenseNDArray> outputErrors) {
        Intrinsics.checkParameterIsNotNull(outputErrors, "outputErrors");
        initBackward();

        int lastIndex = outputErrors.size() - 1;

        // Reverse order: each step reads the recurrent state-encoding errors produced by the
        // step of the following state.
        for (int index = lastIndex; index >= 0; index--) {
            this.stateIndex = index;
            backwardStep(outputErrors.get(index), index == 0, index == lastIndex);
        }

        ParamsErrorsAccumulator.accumulate$default(this.contextErrorsAccumulator, this.network.getRecurrentContextProcessor().getParamsErrors(false), false, 2, null);
        this.transformLayerAccumulator.averageErrors();
        this.attentionNetworkAccumulator.averageErrors();
        this.outputErrorsAccumulator.averageErrors();
    }

    /**
     * Returns the accumulated errors of all the network parameters.
     *
     * @param copy forwarded as-is to the {@code getParamsErrors} of each accumulator
     * @return the params errors of the attention, transform, context and output components
     */
    @NotNull
    public final AttentiveRecurrentNetworkParameters getParamsErrors(boolean copy) {
        return new AttentiveRecurrentNetworkParameters(
            this.attentionNetworkAccumulator.getParamsErrors(copy),
            this.transformLayerAccumulator.getParamsErrors(copy),
            this.contextErrorsAccumulator.getParamsErrors(copy),
            this.outputErrorsAccumulator.getParamsErrors(copy));
    }

    @NotNull
    public static /* bridge */ /* synthetic */ AttentiveRecurrentNetworkParameters getParamsErrors$default(BackwardHelper backwardHelper, boolean z, int i, Object obj) {
        if ((i & 1) != 0) {
            z = true;
        }
        return backwardHelper.getParamsErrors(z);
    }

    /** Resets the per-sequence state before a new backward. */
    private final void initBackward() {
        initSequenceErrors();
        this.contextLabelsErrors.clear();
        this.transformLayerAccumulator.reset();
        this.attentionNetworkAccumulator.reset();
        this.contextErrorsAccumulator.reset();
        this.outputErrorsAccumulator.reset();
    }

    /** Initializes the input sequence errors with one zeros array per sequence element. */
    private final void initSequenceErrors() {
        int sequenceSize = this.network.getSequenceSize();
        int inputSize = this.network.getModel().getInputSize();
        List<DenseNDArray> errors = new ArrayList<>(sequenceSize);

        for (int i = 0; i < sequenceSize; i++) {
            // Synthetic Kotlin default-args constructor: Shape(inputSize) with a defaulted 2nd dim.
            errors.add(DenseNDArrayFactory.INSTANCE.zeros(new Shape(inputSize, 0, 2, null)));
        }

        setInputSequenceErrors(errors);
    }

    /**
     * Executes the backward step of the state at {@link #stateIndex}.
     *
     * @param outputErrors the errors of the output of this state
     * @param isFirstState whether this is the first state of the sequence (no recurrent step)
     * @param isLastState whether this is the last state (no errors coming from a following state)
     */
    private final void backwardStep(DenseNDArray outputErrors, boolean isFirstState, boolean isLastState) {
        Pair<DenseNDArray, DenseNDArray> outputNetworkErrors = splitOutputNetworkErrors(getOutputNetworkInputErrors(outputErrors));
        DenseNDArray stateEncodingErrors = outputNetworkErrors.getFirst();
        DenseNDArray contextErrors = outputNetworkErrors.getSecond();

        if (!isLastState) {
            DenseNDArray fromNextState = requireInitialized(this.recurrentStateEncodingErrors, "recurrentStateEncodingErrors");
            stateEncodingErrors = stateEncodingErrors.assignSum((NDArray<?>) fromNextState);
        }

        propagateStateEncodingErrors(backwardStateEncoder(stateEncodingErrors));

        requireInitialized(this.recurrentContextErrors, "recurrentContextErrors").assignSum((NDArray<?>) contextErrors);

        if (isFirstState) {
            this.contextLabelsErrors.add(0, null);
            return;
        }

        Pair<DenseNDArray, DenseNDArray> rnnInputErrors = recurrentContextBackwardStep();
        this.contextLabelsErrors.add(0, rnnInputErrors.getSecond());
        this.recurrentStateEncodingErrors = rnnInputErrors.getFirst();
    }

    /** Splits the output-network input errors into (state-encoding errors, context errors). */
    private final Pair<DenseNDArray, DenseNDArray> splitOutputNetworkErrors(DenseNDArray errors) {
        List<DenseNDArray> parts = errors.splitV(
            this.network.getModel().getAttentionParams().getOutputSize(),
            this.network.getModel().getRecurrentContextSize());
        return new Pair<>(parts.get(0), parts.get(1));
    }

    /** Backwards the output processor of the current state and returns its input errors. */
    private final DenseNDArray getOutputNetworkInputErrors(DenseNDArray outputErrors) {
        FeedforwardNeuralProcessor<DenseNDArray> processor = this.network.getUsedOutputProcessors().get(this.stateIndex);
        FeedforwardNeuralProcessor.backward$default(processor, outputErrors, true, null, 4, null);
        ParamsErrorsAccumulator.accumulate$default(this.outputErrorsAccumulator, processor.getParamsErrors(false), false, 2, null);
        return processor.getInputErrors(false);
    }

    /**
     * Backwards the attention network of the current state.
     *
     * @return the attention network input errors zipped with its attention errors
     */
    private final List<Pair<DenseNDArray, DenseNDArray>> backwardStateEncoder(DenseNDArray stateEncodingErrors) {
        AttentionNetwork<DenseNDArray> attentionNetwork = this.network.getUsedAttentionNetworks().get(this.stateIndex);
        AttentionNetworkParameters paramsErrors = getAttentionParamsErrors();
        AttentionNetwork.backward$default(attentionNetwork, stateEncodingErrors, paramsErrors, true, null, 8, null);
        ParamsErrorsAccumulator.accumulate$default(this.attentionNetworkAccumulator, paramsErrors, false, 2, null);
        return CollectionsKt.zip(attentionNetwork.getInputErrors(), attentionNetwork.getAttentionErrors());
    }

    /**
     * Propagates the state-encoder errors through the transform layers of the current state.
     * Each element's input errors are summed into {@link #inputSequenceErrors}, and the context
     * part of all the transform errors is summed into {@link #recurrentContextErrors}, which is
     * replaced (not accumulated) by the first element.
     */
    private final void propagateStateEncodingErrors(List<Pair<DenseNDArray, DenseNDArray>> stateEncoderInputErrors) {
        List<FeedforwardLayerStructure<DenseNDArray>> transformLayers = this.network.getUsedTransformLayers().get(this.stateIndex);
        List<DenseNDArray> sequenceErrors = requireInitialized(this.inputSequenceErrors, "inputSequenceErrors");

        int index = 0;
        for (Pair<DenseNDArray, DenseNDArray> errors : stateEncoderInputErrors) {
            DenseNDArray inputErrors = errors.getFirst();
            Pair<DenseNDArray, DenseNDArray> transformErrors = splitTransformErrors(backwardTransformLayer(transformLayers.get(index), errors.getSecond()));
            DenseNDArray transformInputErrors = transformErrors.getFirst();
            DenseNDArray contextErrors = transformErrors.getSecond();

            this.recurrentContextErrors = (index == 0)
                ? contextErrors
                : requireInitialized(this.recurrentContextErrors, "recurrentContextErrors").assignSum((NDArray<?>) contextErrors);

            sequenceErrors.get(index).assignSum((NDArray<?>) transformInputErrors.sum(inputErrors));
            index++;
        }
    }

    /** Backwards a transform layer with the given output errors and returns its input errors. */
    private final DenseNDArray backwardTransformLayer(FeedforwardLayerStructure<DenseNDArray> layer, DenseNDArray outputErrors) {
        FeedforwardLayerParameters paramsErrors = getTransformParamsErrors();
        layer.setErrors(outputErrors);
        layer.backward(paramsErrors, true, null);
        ParamsErrorsAccumulator.accumulate$default(this.transformLayerAccumulator, paramsErrors, false, 2, null);
        return layer.getInputArray().getErrors();
    }

    /**
     * Executes one backward step of the recurrent context processor with the current recurrent
     * context errors and splits the input errors of the previous state.
     */
    private final Pair<DenseNDArray, DenseNDArray> recurrentContextBackwardStep() {
        RecurrentNeuralProcessor<DenseNDArray> processor = this.network.getRecurrentContextProcessor();
        DenseNDArray contextErrors = requireInitialized(this.recurrentContextErrors, "recurrentContextErrors");
        RecurrentNeuralProcessor.backwardStep$default(processor, contextErrors, true, null, 4, null);
        return splitRNNInputErrors(processor.getInputErrors(this.stateIndex - 1, false));
    }

    /** Splits the transform-layer input errors into (input errors, recurrent context errors). */
    private final Pair<DenseNDArray, DenseNDArray> splitTransformErrors(DenseNDArray errors) {
        List<DenseNDArray> parts = errors.splitV(
            this.network.getModel().getInputSize(),
            this.network.getModel().getRecurrentContextSize());
        return new Pair<>(parts.get(0), parts.get(1));
    }

    /** Splits the RNN input errors into (state-encoding errors, context label errors). */
    private final Pair<DenseNDArray, DenseNDArray> splitRNNInputErrors(DenseNDArray errors) {
        List<DenseNDArray> parts = errors.splitV(
            this.network.getModel().getAttentionParams().getOutputSize(),
            this.network.getModel().getContextLabelSize());
        return new Pair<>(parts.get(0), parts.get(1));
    }

    /**
     * Returns the reusable transform-layer params errors, creating them on first use as a copy of
     * the params of the last transform layer of the last used state.
     *
     * @throws TypeCastException if the copied params are {@code null}
     */
    private final FeedforwardLayerParameters getTransformParamsErrors() {
        if (this.transformLayerParamsErrors == null) {
            List<FeedforwardLayerStructure<DenseNDArray>> lastStateLayers = CollectionsKt.last(this.network.getUsedTransformLayers());
            Object copy = CollectionsKt.last(lastStateLayers).getParams().copy();
            if (copy == null) {
                throw new TypeCastException("null cannot be cast to non-null type com.kotlinnlp.simplednn.core.layers.models.feedforward.simple.FeedforwardLayerParameters");
            }
            this.transformLayerParamsErrors = (FeedforwardLayerParameters) copy;
        }
        return this.transformLayerParamsErrors;
    }

    /**
     * Returns the reusable attention-network params errors, creating them on first use as a copy
     * of the model of the last used attention network.
     */
    private final AttentionNetworkParameters getAttentionParamsErrors() {
        if (this.attentionNetworkParamsErrors == null) {
            AttentionNetwork<DenseNDArray> lastNetwork = CollectionsKt.last(this.network.getUsedAttentionNetworks());
            this.attentionNetworkParamsErrors = lastNetwork.getModel().copy();
        }
        return requireInitialized(this.attentionNetworkParamsErrors, "attentionNetworkParamsErrors");
    }

    /**
     * Returns {@code value}, or throws the Kotlin uninitialized-property exception when it is
     * {@code null}.
     */
    private static <T> T requireInitialized(T value, String propertyName) {
        if (value == null) {
            Intrinsics.throwUninitializedPropertyAccessException(propertyName);
        }
        return value;
    }

    /**
     * Creates a backward helper for the given network.
     *
     * @param attentiveRecurrentNetwork the network whose used processors are backwarded
     */
    public BackwardHelper(@NotNull AttentiveRecurrentNetwork attentiveRecurrentNetwork) {
        Intrinsics.checkParameterIsNotNull(attentiveRecurrentNetwork, "network");
        this.network = attentiveRecurrentNetwork;
        this.contextLabelsErrors = new ArrayList<>();
        this.transformLayerAccumulator = new ParamsErrorsAccumulator<>();
        this.attentionNetworkAccumulator = new ParamsErrorsAccumulator<>();
        this.contextErrorsAccumulator = new ParamsErrorsAccumulator<>();
        this.outputErrorsAccumulator = new ParamsErrorsAccumulator<>();
    }
}
