package com.kotlinnlp.simplednn.core.layers.models.attention.scaleddot;

import com.kotlinnlp.simplednn.core.arrays.AugmentedArray;
import com.kotlinnlp.simplednn.core.functionalities.activations.SoftmaxBase;
import com.kotlinnlp.simplednn.core.layers.Layer;
import com.kotlinnlp.simplednn.core.layers.helpers.BackwardHelper;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArrayMask;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArrayFactory;
import java.util.List;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.functions.Function1;
import kotlin.jvm.functions.Function2;
import kotlin.jvm.internal.Intrinsics;
import kotlin.sequences.Sequence;
import kotlin.sequences.SequencesKt;
import org.jetbrains.annotations.NotNull;

/* compiled from: ScaledDotAttentionBackwardHelper.kt */
@Metadata(mv = {1, 1, 15}, bv = {1, 0, 3}, k = 1, d1 = {"��*\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0010\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0005\n\u0002\u0010\u000b\n��\b��\u0018��2\b\u0012\u0004\u0012\u00020\u00020\u0001B\r\u0012\u0006\u0010\u0003\u001a\u00020\u0004¢\u0006\u0002\u0010\u0005J\u0016\u0010\b\u001a\u00020\t2\f\u0010\n\u001a\b\u0012\u0004\u0012\u00020\u00020\u000bH\u0002J\b\u0010\f\u001a\u00020\tH\u0002J\b\u0010\r\u001a\u00020\tH\u0002J\u0016\u0010\u000e\u001a\u00020\t2\f\u0010\n\u001a\b\u0012\u0004\u0012\u00020\u00020\u000bH\u0002J\u0010\u0010\u000f\u001a\u00020\t2\u0006\u0010\u0010\u001a\u00020\u0011H\u0014R\u0014\u0010\u0003\u001a\u00020\u0004X\u0094\u0004¢\u0006\b\n��\u001a\u0004\b\u0006\u0010\u0007¨\u0006\u0012"}, d2 = {"Lcom/kotlinnlp/simplednn/core/layers/models/attention/scaleddot/ScaledDotAttentionBackwardHelper;", "Lcom/kotlinnlp/simplednn/core/layers/helpers/BackwardHelper;", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "layer", "Lcom/kotlinnlp/simplednn/core/layers/models/attention/scaleddot/ScaledDotAttentionLayer;", "(Lcom/kotlinnlp/simplednn/core/layers/models/attention/scaleddot/ScaledDotAttentionLayer;)V", "getLayer", "()Lcom/kotlinnlp/simplednn/core/layers/models/attention/scaleddot/ScaledDotAttentionLayer;", "assignAttentionGradients", "", "outputErrors", "Lkotlin/sequences/Sequence;", "assignInputGradients", "assignParamsGradients", "assignValuesGradients", "execBackward", "propagateToInput", "", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/core/layers/models/attention/scaleddot/ScaledDotAttentionBackwardHelper.class */
/**
 * Backward helper of a {@link ScaledDotAttentionLayer}.
 *
 * <p>Starting from the errors of the layer outputs, it computes, in order:
 * <ol>
 *   <li>the errors of the values matrix of the layer;</li>
 *   <li>the errors of the queries and keys matrices of the layer;</li>
 *   <li>the gradients of the queries/keys/values weights and biases;</li>
 *   <li>optionally, the errors of the input arrays.</li>
 * </ol>
 * The order matters: the last two steps read the errors assigned by the first two.
 *
 * <p>NOTE(review): this is JADX decompiler output of ScaledDotAttentionBackwardHelper.kt. It
 * contains decompilation artifacts that are not valid Java:
 * <ul>
 *   <li>{@code DenseNDArray.this} inside an anonymous class;</li>
 *   <li>{@code super(n)} calls inside instance initializers;</li>
 *   <li>renamed members ({@code getLayer2()}, {@code mo155getRows()}).</li>
 * </ul>
 * The Kotlin source is presumably the authoritative version.
 */
public final class ScaledDotAttentionBackwardHelper extends BackwardHelper<DenseNDArray> {

    /** The layer whose backward pass this helper computes. */
    @NotNull
    private final ScaledDotAttentionLayer layer;

    /* JADX WARN: Type inference failed for: r0v1, types: [com.kotlinnlp.simplednn.core.layers.models.attention.scaleddot.ScaledDotAttentionLayer] */
    /**
     * Executes the backward pass of the layer.
     *
     * <p>The errors of the output arrays are first used to assign the errors of the values matrix,
     * then the errors of the queries and keys matrices. After that, the parameters gradients are
     * assigned from those errors. Finally, if requested, the errors are propagated to the input
     * arrays.
     *
     * @param z whether to propagate the errors to the input arrays (named {@code propagateToInput}
     *     in the Kotlin source, according to the class metadata)
     */
    @Override // com.kotlinnlp.simplednn.core.layers.helpers.BackwardHelper
    protected void execBackward(boolean z) {
        // Lazy sequence of the output errors, one per output array. Each of the two following
        // calls iterates it once (via toList / map + toList).
        Sequence<DenseNDArray> map = SequencesKt.map(CollectionsKt.asSequence(getLayer2().getOutputArrays()), new Function1<AugmentedArray<DenseNDArray>, DenseNDArray>() { // from class: com.kotlinnlp.simplednn.core.layers.models.attention.scaleddot.ScaledDotAttentionBackwardHelper$execBackward$outputErrors$1
            @NotNull
            public final DenseNDArray invoke(@NotNull AugmentedArray<DenseNDArray> augmentedArray) {
                Intrinsics.checkParameterIsNotNull(augmentedArray, "it");
                return augmentedArray.getErrors();
            }
        });
        assignValuesGradients(map);
        assignAttentionGradients(map);
        // Reads the queries/keys/values errors assigned by the two calls above.
        assignParamsGradients();
        if (z) {
            assignInputGradients();
        }
    }

    /* JADX WARN: Type inference failed for: r0v16, types: [com.kotlinnlp.simplednn.core.layers.models.attention.scaleddot.ScaledDotAttentionLayer] */
    /* JADX WARN: Type inference failed for: r0v5, types: [com.kotlinnlp.simplednn.core.layers.models.attention.scaleddot.ScaledDotAttentionLayer] */
    /* JADX WARN: Type inference failed for: r0v8, types: [com.kotlinnlp.simplednn.core.layers.models.attention.scaleddot.ScaledDotAttentionLayer] */
    /* JADX WARN: Type inference failed for: r1v3, types: [com.kotlinnlp.simplednn.core.layers.models.attention.scaleddot.ScaledDotAttentionLayer] */
    /**
     * Assigns the errors of the values matrix of the layer.
     *
     * <p>The errors are the product {@code A^T · dY}, where:
     * <ul>
     *   <li>{@code dY} is the matrix whose rows are the output errors;</li>
     *   <li>{@code A^T} is the matrix whose columns are the attention vectors of the layer.</li>
     * </ul>
     * When a full dropout mask is set, the product is computed with {@code dotLeftMasked} using
     * that mask.
     *
     * @param sequence the errors of the output arrays, one per output
     */
    private final void assignValuesGradients(Sequence<DenseNDArray> sequence) {
        // dY: one row per output error.
        DenseNDArray fromRows = DenseNDArrayFactory.INSTANCE.fromRows(SequencesKt.toList(sequence));
        // A^T: one column per attention vector.
        DenseNDArray fromColumns = DenseNDArrayFactory.INSTANCE.fromColumns(getLayer2().getAttentionAct$simplednn());
        NDArrayMask dropoutMaskFull$simplednn = getLayer2().getDropoutMaskFull$simplednn();
        // NOTE(review): decompiled form of a Kotlin elvis expression. The unmasked fallback runs
        // when no full dropout mask is set. It would also run if assignErrors(...) returned null,
        // which presumably never happens — confirm.
        if (dropoutMaskFull$simplednn == null || getLayer2().getValues$simplednn().assignErrors(fromColumns.dotLeftMasked(fromRows, dropoutMaskFull$simplednn)) == null) {
            getLayer2().getValues$simplednn().assignErrorsByDot(fromColumns, fromRows);
        }
    }

    /* JADX WARN: Type inference failed for: r0v1, types: [com.kotlinnlp.simplednn.core.layers.models.attention.scaleddot.ScaledDotAttentionLayer] */
    /* JADX WARN: Type inference failed for: r0v11, types: [com.kotlinnlp.simplednn.core.layers.models.attention.scaleddot.ScaledDotAttentionLayer] */
    /* JADX WARN: Type inference failed for: r0v22, types: [com.kotlinnlp.simplednn.core.layers.models.attention.scaleddot.ScaledDotAttentionLayer] */
    /* JADX WARN: Type inference failed for: r0v26, types: [com.kotlinnlp.simplednn.core.layers.models.attention.scaleddot.ScaledDotAttentionLayer] */
    /* JADX WARN: Type inference failed for: r0v6, types: [com.kotlinnlp.simplednn.core.layers.models.attention.scaleddot.ScaledDotAttentionLayer] */
    /* JADX WARN: Type inference failed for: r1v3, types: [com.kotlinnlp.simplednn.core.layers.models.attention.scaleddot.ScaledDotAttentionLayer] */
    /**
     * Assigns the errors of the queries and keys matrices of the layer.
     *
     * <p>For each output error {@code dy_i}:
     * <ol>
     *   <li>the error of the i-th attention vector is computed as {@code V · dy_i}, where {@code V}
     *       is the values matrix;</li>
     *   <li>that error is multiplied by {@code SoftmaxBase.dfOptimized(a_i)} (presumably the
     *       softmax Jacobian of the attention vector {@code a_i}). The i-th dropout mask is
     *       applied through {@code dotRightMasked} when dropout masks are set.</li>
     * </ol>
     * The resulting rows form a matrix {@code G}, which is scaled in place by the attention factor
     * of the layer params. Then the queries errors are {@code G · K} and the keys errors are
     * {@code G^T · Q}, where {@code K} and {@code Q} are the keys and queries matrices.
     *
     * @param sequence the errors of the output arrays, one per output
     */
    private final void assignAttentionGradients(Sequence<DenseNDArray> sequence) {
        // Decompiler names: values = keys matrix (K), values2 = queries matrix (Q),
        // values3 = values matrix (V).
        DenseNDArray values = getLayer2().getKeys$simplednn().getValues();
        DenseNDArray values2 = getLayer2().getQueries$simplednn().getValues();
        final DenseNDArray values3 = getLayer2().getValues$simplednn().getValues();
        // G, built row by row: the attention vectors are zipped with their errors (V · dy_i),
        // then each pair is mapped through the softmax derivative.
        DenseNDArray fromRows = DenseNDArrayFactory.INSTANCE.fromRows(SequencesKt.toList(SequencesKt.mapIndexed(SequencesKt.zip(CollectionsKt.asSequence(getLayer2().getAttentionAct$simplednn()), SequencesKt.map(sequence, new Function1<DenseNDArray, DenseNDArray>() { // from class: com.kotlinnlp.simplednn.core.layers.models.attention.scaleddot.ScaledDotAttentionBackwardHelper$assignAttentionGradients$attentionErrors$1
            @NotNull
            public final DenseNDArray invoke(@NotNull DenseNDArray denseNDArray) {
                Intrinsics.checkParameterIsNotNull(denseNDArray, "it");
                // NOTE(review): decompiler artifact. `DenseNDArray.this` is not a valid qualified
                // `this` here. It presumably stands for the captured `values3` (the values
                // matrix), which is declared final and otherwise unused in this method.
                return DenseNDArray.this.dot((NDArray<?>) denseNDArray);
            }

            // NOTE(review): decompiler artifact. A super(...) call is not allowed in an instance
            // initializer; presumably it is the arity passed to the Kotlin lambda base class.
            /* JADX INFO: Access modifiers changed from: package-private */
            {
                super(1);
            }
        })), new Function2<Integer, Pair<? extends DenseNDArray, ? extends DenseNDArray>, DenseNDArray>() { // from class: com.kotlinnlp.simplednn.core.layers.models.attention.scaleddot.ScaledDotAttentionBackwardHelper$assignAttentionGradients$attentionInnerErrors$1
            // Generic bridge: unboxes the index and forwards to the typed invoke below.
            public /* bridge */ /* synthetic */ Object invoke(Object obj, Object obj2) {
                return invoke(((Number) obj).intValue(), (Pair<DenseNDArray, DenseNDArray>) obj2);
            }

            /* JADX WARN: Type inference failed for: r0v9, types: [com.kotlinnlp.simplednn.core.layers.models.attention.scaleddot.ScaledDotAttentionLayer] */
            /**
             * Back-propagates the error of the i-th attention vector through the softmax.
             *
             * @param i the index of the attention vector (selects the i-th dropout mask)
             * @param pair the attention vector (first) and its error (second)
             * @return the i-th row of G
             */
            @NotNull
            public final DenseNDArray invoke(int i, @NotNull Pair<DenseNDArray, DenseNDArray> pair) {
                Intrinsics.checkParameterIsNotNull(pair, "<name for destructuring parameter 1>");
                DenseNDArray denseNDArray = (DenseNDArray) pair.component1();
                DenseNDArray denseNDArray2 = (DenseNDArray) pair.component2();
                List<NDArrayMask> dropoutMasks$simplednn = ScaledDotAttentionBackwardHelper.this.getLayer2().getDropoutMasks$simplednn();
                // Masked product when dropout masks are set (decompiled elvis expression:
                // the null check on the result presumably never triggers — confirm).
                if (dropoutMasks$simplednn != null) {
                    DenseNDArray dotRightMasked = new SoftmaxBase().dfOptimized(denseNDArray).dotRightMasked(denseNDArray2, dropoutMasks$simplednn.get(i));
                    if (dotRightMasked != null) {
                        return dotRightMasked;
                    }
                }
                // Unmasked fallback.
                return new SoftmaxBase().dfOptimized(denseNDArray).dot((NDArray<?>) denseNDArray2);
            }

            // NOTE(review): decompiler artifact, same as above (presumably the lambda arity).
            /* JADX INFO: Access modifiers changed from: package-private */
            {
                super(2);
            }
        })));
        // Scale G in place by the attention factor of the params (presumably the 1/sqrt(d)
        // factor of scaled-dot attention — confirm).
        fromRows.assignProd(getLayer2().getParams().getAttentionFactor$simplednn());
        // Queries errors = G · K; keys errors = G^T · Q.
        getLayer2().getQueries$simplednn().assignErrorsByDot(fromRows, values);
        getLayer2().getKeys$simplednn().assignErrorsByDot(fromRows.getT(), values2);
    }

    /* JADX WARN: Type inference failed for: r0v11, types: [com.kotlinnlp.simplednn.core.layers.models.attention.scaleddot.ScaledDotAttentionLayer] */
    /* JADX WARN: Type inference failed for: r0v15, types: [com.kotlinnlp.simplednn.core.layers.models.attention.scaleddot.ScaledDotAttentionLayer] */
    /* JADX WARN: Type inference failed for: r0v2, types: [com.kotlinnlp.simplednn.core.layers.models.attention.scaleddot.ScaledDotAttentionLayer] */
    /* JADX WARN: Type inference failed for: r0v20, types: [com.kotlinnlp.simplednn.simplemath.ndarray.NDArray] */
    /* JADX WARN: Type inference failed for: r0v24, types: [com.kotlinnlp.simplednn.simplemath.ndarray.NDArray] */
    /* JADX WARN: Type inference failed for: r0v28, types: [com.kotlinnlp.simplednn.simplemath.ndarray.NDArray] */
    /* JADX WARN: Type inference failed for: r0v32, types: [com.kotlinnlp.simplednn.simplemath.ndarray.NDArray] */
    /* JADX WARN: Type inference failed for: r0v36, types: [com.kotlinnlp.simplednn.simplemath.ndarray.NDArray] */
    /* JADX WARN: Type inference failed for: r0v40, types: [com.kotlinnlp.simplednn.simplemath.ndarray.NDArray] */
    /* JADX WARN: Type inference failed for: r0v7, types: [com.kotlinnlp.simplednn.core.layers.models.attention.scaleddot.ScaledDotAttentionLayer] */
    /**
     * Assigns the gradients of the queries, keys and values parameters of the layer.
     *
     * <p>Let {@code X} be the input matrix of the layer and {@code E} the errors of the queries,
     * keys or values matrix, as assigned by the previous steps of the backward pass. Then:
     * <ul>
     *   <li>the weights gradients are {@code E^T · X};</li>
     *   <li>the biases gradients are obtained from {@code E} by a synthetic helper function.</li>
     * </ul>
     */
    private final void assignParamsGradients() {
        // Synthetic singleton compiled from a local Kotlin function; its body is in a separate
        // class not visible here. Presumably it reduces the errors over the rows (e.g. a column
        // sum) to get the biases gradients — confirm.
        ScaledDotAttentionBackwardHelper$assignParamsGradients$1 scaledDotAttentionBackwardHelper$assignParamsGradients$1 = ScaledDotAttentionBackwardHelper$assignParamsGradients$1.INSTANCE;
        DenseNDArray values = getLayer2().getInputMatrix().getValues();
        DenseNDArray errors = getLayer2().getQueries$simplednn().getErrors();
        DenseNDArray errors2 = getLayer2().getKeys$simplednn().getErrors();
        DenseNDArray errors3 = getLayer2().getValues$simplednn().getErrors();
        // getErrors(param) is inherited from BackwardHelper (not shown); presumably it returns the
        // error container of the given parameter.
        getErrors(getLayer2().getParams().getQueries().getWeights()).getValues().assignDot(errors.getT(), values);
        getErrors(getLayer2().getParams().getKeys().getWeights()).getValues().assignDot(errors2.getT(), values);
        getErrors(getLayer2().getParams().getValues().getWeights()).getValues().assignDot(errors3.getT(), values);
        getErrors(getLayer2().getParams().getQueries().getBiases()).getValues().assignValues(scaledDotAttentionBackwardHelper$assignParamsGradients$1.invoke(errors));
        getErrors(getLayer2().getParams().getKeys().getBiases()).getValues().assignValues(scaledDotAttentionBackwardHelper$assignParamsGradients$1.invoke(errors2));
        getErrors(getLayer2().getParams().getValues().getBiases()).getValues().assignValues(scaledDotAttentionBackwardHelper$assignParamsGradients$1.invoke(errors3));
    }

    /* JADX WARN: Type inference failed for: r0v1, types: [com.kotlinnlp.simplednn.core.layers.models.attention.scaleddot.ScaledDotAttentionLayer] */
    /* JADX WARN: Type inference failed for: r0v35, types: [com.kotlinnlp.simplednn.core.layers.models.attention.scaleddot.ScaledDotAttentionLayer] */
    /* JADX WARN: Type inference failed for: r0v5, types: [com.kotlinnlp.simplednn.core.layers.models.attention.scaleddot.ScaledDotAttentionLayer] */
    /* JADX WARN: Type inference failed for: r0v9, types: [com.kotlinnlp.simplednn.core.layers.models.attention.scaleddot.ScaledDotAttentionLayer] */
    /**
     * Assigns the errors of the input arrays of the layer.
     *
     * <p>The input errors matrix is {@code Eq · Wq + Ek · Wk + Ev · Wv}, where:
     * <ul>
     *   <li>{@code Eq}, {@code Ek} and {@code Ev} are the errors of the queries, keys and values
     *       matrices;</li>
     *   <li>{@code Wq}, {@code Wk} and {@code Wv} are the corresponding weights.</li>
     * </ul>
     * Its i-th row is assigned as the errors of the i-th input array. The rows and the input arrays
     * are paired with {@code zip}, so any surplus on either side is silently ignored.
     */
    private final void assignInputGradients() {
        DenseNDArray errors = getLayer2().getQueries$simplednn().getErrors();
        DenseNDArray errors2 = getLayer2().getKeys$simplednn().getErrors();
        DenseNDArray errors3 = getLayer2().getValues$simplednn().getErrors();
        // The three products are accumulated with assignSum into the first one. mo155getRows() is
        // presumably getRows(), renamed by the decompiler.
        for (Pair pair : SequencesKt.zip(CollectionsKt.asSequence(getLayer2().getInputArrays()), CollectionsKt.asSequence(errors.dot((NDArray<?>) getLayer2().getParams().getQueries().getWeights().getValues()).assignSum((NDArray<?>) errors2.dot((NDArray<?>) getLayer2().getParams().getKeys().getWeights().getValues())).assignSum((NDArray<?>) errors3.dot((NDArray<?>) getLayer2().getParams().getValues().getWeights().getValues())).mo155getRows()))) {
            ((AugmentedArray) pair.component1()).assignErrors((DenseNDArray) pair.component2());
        }
    }

    /**
     * Returns the layer of this helper.
     *
     * <p>NOTE(review): decompiler artifact. This is the override of {@code BackwardHelper.getLayer()}
     * (see the "renamed from" marker). The JADX type-inference warnings in this file suggest that
     * the real return type is {@link ScaledDotAttentionLayer} — confirm against the Kotlin source.
     *
     * @return the attention layer passed to the constructor
     */
    /* JADX INFO: Access modifiers changed from: protected */
    @Override // com.kotlinnlp.simplednn.core.layers.helpers.BackwardHelper
    @NotNull
    /* renamed from: getLayer */
    public Layer<DenseNDArray> getLayer2() {
        return this.layer;
    }

    /**
     * Creates a backward helper for the given layer.
     *
     * @param scaledDotAttentionLayer the layer whose backward pass is computed (named {@code layer}
     *     in the Kotlin source); must not be null
     */
    /* JADX WARN: 'super' call moved to the top of the method (can break code semantics) */
    public ScaledDotAttentionBackwardHelper(@NotNull ScaledDotAttentionLayer scaledDotAttentionLayer) {
        super(scaledDotAttentionLayer);
        // Kotlin non-null parameter check (the decompiler placed it after the super call).
        Intrinsics.checkParameterIsNotNull(scaledDotAttentionLayer, "layer");
        this.layer = scaledDotAttentionLayer;
    }
}
