package com.kotlinnlp.simplednn.deeplearning.attention.multihead;

import com.kotlinnlp.simplednn.core.arrays.AugmentedArray;
import com.kotlinnlp.simplednn.core.arrays.ParamsArray;
import com.kotlinnlp.simplednn.core.layers.models.attention.scaleddot.ScaledDotAttentionLayer;
import com.kotlinnlp.simplednn.core.layers.models.attention.scaleddot.ScaledDotAttentionLayerParameters;
import com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor;
import com.kotlinnlp.simplednn.core.neuralprocessor.batchfeedforward.BatchFeedforwardProcessor;
import com.kotlinnlp.simplednn.core.optimizer.ParamsErrorsAccumulator;
import com.kotlinnlp.simplednn.core.optimizer.ParamsOptimizer;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import com.kotlinnlp.utils.BaseExtensionsKt;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin.TypeCastException;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;

/* compiled from: MultiHeadAttentionNetwork.kt */
@Metadata(mv = {1, 1, 15}, bv = {1, 0, 3}, k = 1, d1 = {"��R\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0010 \n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000b\n��\n\u0002\u0010\b\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\b\u0005\n\u0002\u0010\u0002\n\u0002\b\u0006\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\u0018��22\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00030\u0002\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00030\u0002\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00030\u0002\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00030\u00020\u0001B\u001f\u0012\u0006\u0010\u0004\u001a\u00020\u0005\u0012\u0006\u0010\u0006\u001a\u00020\u0007\u0012\b\b\u0002\u0010\b\u001a\u00020\t¢\u0006\u0002\u0010\nJ\u0016\u0010\u0017\u001a\u00020\u00182\f\u0010\u0019\u001a\b\u0012\u0004\u0012\u00020\u00030\u0002H\u0016J\u001c\u0010\u001a\u001a\b\u0012\u0004\u0012\u00020\u00030\u00022\f\u0010\u001b\u001a\b\u0012\u0004\u0012\u00020\u00030\u0002H\u0016J\u0016\u0010\u001c\u001a\b\u0012\u0004\u0012\u00020\u00030\u00022\u0006\u0010\u001d\u001a\u00020\u0007H\u0016J\"\u0010\u001e\u001a\u0014\u0012\f\u0012\n\u0012\u0002\b\u00030\u001fR\u00020 0\u0002j\u0002`!2\u0006\u0010\u001d\u001a\u00020\u0007H\u0016R\u0014\u0010\u000b\u001a\b\u0012\u0004\u0012\u00020\f0\u0002X\u0082.¢\u0006\u0002\n��R\u000e\u0010\r\u001a\u00020\u000eX\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\b\u001a\u00020\tX\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\u000f\u0010\u0010R\u0014\u0010\u0011\u001a\b\u0012\u0004\u0012\u00020\u00030\u0012X\u0082\u0004¢\u0006\u0002\n��R\u0011\u0010\u0004\u001a\u00020\u0005¢\u0006\b\n��\u001a\u0004\b\u0013\u0010\u0014R\u0014\u0010\u0006\u001a\u00020\u0007X\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\u0015\u0010\u0016¨\u0006\""}, d2 = {"Lcom/kotlinnlp/simplednn/deeplearning/attention/multihead/MultiHeadAttentionNetwork;", "Lcom/kotlinnlp/simplednn/core/neuralprocessor/NeuralProcessor;", "", 
"Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "model", "Lcom/kotlinnlp/simplednn/deeplearning/attention/multihead/MultiHeadAttentionParameters;", "propagateToInput", "", "id", "", "(Lcom/kotlinnlp/simplednn/deeplearning/attention/multihead/MultiHeadAttentionParameters;ZI)V", "attentionLayers", "Lcom/kotlinnlp/simplednn/core/layers/models/attention/scaleddot/ScaledDotAttentionLayer;", "errorsAccumulator", "Lcom/kotlinnlp/simplednn/core/optimizer/ParamsErrorsAccumulator;", "getId", "()I", "mergeProcessor", "Lcom/kotlinnlp/simplednn/core/neuralprocessor/batchfeedforward/BatchFeedforwardProcessor;", "getModel", "()Lcom/kotlinnlp/simplednn/deeplearning/attention/multihead/MultiHeadAttentionParameters;", "getPropagateToInput", "()Z", "backward", "", "outputErrors", "forward", "input", "getInputErrors", "copy", "getParamsErrors", "Lcom/kotlinnlp/simplednn/core/arrays/ParamsArray$Errors;", "Lcom/kotlinnlp/simplednn/core/arrays/ParamsArray;", "Lcom/kotlinnlp/simplednn/core/optimizer/ParamsErrorsList;", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/deeplearning/attention/multihead/MultiHeadAttentionNetwork.class */
/**
 * Multi-head attention processor: runs one {@link ScaledDotAttentionLayer} per head (one per entry of
 * {@code model.getAttention()}) over the same input sequence, then feeds the heads' outputs to a
 * {@link BatchFeedforwardProcessor} built on {@code model.getMerge()}.
 *
 * <p>Input, output, output errors and input errors are all lists of {@link DenseNDArray}.
 *
 * <p>NOTE(review): this is JADX-decompiled output of {@code MultiHeadAttentionNetwork.kt}. It contains
 * decompiler artifacts, so it is not expected to compile as hand-written Java:
 * <ul>
 *   <li>the renamed methods {@code forward2}, {@code backward2} and {@code getInputErrors2};</li>
 *   <li>synthetic bridge methods;</li>
 *   <li>{@code $default} calls;</li>
 *   <li>references to a type variable {@code T} that is not declared on the class or methods.</li>
 * </ul>
 * Make behavioral changes in the Kotlin source instead.
 */
public final class MultiHeadAttentionNetwork implements NeuralProcessor<List<? extends DenseNDArray>, List<? extends DenseNDArray>, List<? extends DenseNDArray>, List<? extends DenseNDArray>> {
    /** Collects the parameter errors of the merge processor and of every attention head during backward. */
    private final ParamsErrorsAccumulator errorsAccumulator;
    /**
     * One attention layer per head, rebuilt on every forward call.
     * Before the first forward it is unset, and reads throw via
     * {@code Intrinsics.throwUninitializedPropertyAccessException} (the Kotlin lateinit pattern).
     */
    private List<ScaledDotAttentionLayer> attentionLayers;
    /** Processor over {@code model.getMerge()} that combines the outputs of all heads. */
    private final BatchFeedforwardProcessor<DenseNDArray> mergeProcessor;

    /** Parameters of the heads ({@code getAttention()}) and of the merge layer ({@code getMerge()}). */
    @NotNull
    private final MultiHeadAttentionParameters model;
    /** Passed to each head's backward; returned by {@link #getPropagateToInput()}. */
    private final boolean propagateToInput;
    private final int id;

    /**
     * Forward pass.
     *
     * <p>Steps:
     * <ol>
     *   <li>For each head's parameters in {@code model.getAttention()}, build a fresh
     *       {@link ScaledDotAttentionLayer}. Each head wraps the input arrays in its own
     *       {@link AugmentedArray} instances, so the heads hold separate error buffers for the
     *       same input values.</li>
     *   <li>Run every head forward and collect the values of its output arrays.</li>
     *   <li>Pass the collected outputs to the merge processor.</li>
     * </ol>
     *
     * @param list the input sequence
     * @return the output of the merge processor
     */
    /* JADX WARN: Multi-variable type inference failed */
    @NotNull
    /* renamed from: forward, reason: avoid collision after fix types in other method */
    public List<DenseNDArray> forward2(@NotNull List<DenseNDArray> list) {
        Intrinsics.checkParameterIsNotNull(list, "input");
        List<ScaledDotAttentionLayerParameters> attention = this.model.getAttention();
        ArrayList arrayList = new ArrayList(CollectionsKt.collectionSizeOrDefault(attention, 10));
        for (ScaledDotAttentionLayerParameters scaledDotAttentionLayerParameters : attention) {
            List<DenseNDArray> list2 = list;
            ArrayList arrayList2 = new ArrayList(CollectionsKt.collectionSizeOrDefault(list2, 10));
            Iterator<T> it = list2.iterator();
            while (it.hasNext()) {
                arrayList2.add(AugmentedArray.Companion.invoke((DenseNDArray) it.next()));
            }
            // Kotlin default-argument call: bitmask 4 marks the third argument (the 0.0d placeholder)
            // as using its declared default. Presumably that is the dropout value — confirm.
            arrayList.add(new ScaledDotAttentionLayer(arrayList2, scaledDotAttentionLayerParameters, 0.0d, 4, null));
        }
        this.attentionLayers = arrayList;
        List<ScaledDotAttentionLayer> list3 = this.attentionLayers;
        if (list3 == null) {
            Intrinsics.throwUninitializedPropertyAccessException("attentionLayers");
        }
        List<ScaledDotAttentionLayer> list4 = list3;
        // One list of output values per head, in head order.
        ArrayList arrayList3 = new ArrayList(CollectionsKt.collectionSizeOrDefault(list4, 10));
        for (ScaledDotAttentionLayer scaledDotAttentionLayer : list4) {
            scaledDotAttentionLayer.forward();
            List<AugmentedArray<DenseNDArray>> outputArrays = scaledDotAttentionLayer.getOutputArrays();
            ArrayList arrayList4 = new ArrayList(CollectionsKt.collectionSizeOrDefault(outputArrays, 10));
            Iterator<T> it2 = outputArrays.iterator();
            while (it2.hasNext()) {
                arrayList4.add((DenseNDArray) ((AugmentedArray) it2.next()).getValues());
            }
            arrayList3.add(arrayList4);
        }
        ArrayList arrayList5 = arrayList3;
        BatchFeedforwardProcessor<DenseNDArray> batchFeedforwardProcessor = this.mergeProcessor;
        // NOTE(review): foldUp presumably regroups the per-head lists into per-position lists.
        // Each merge input would then hold every head's output for one position — confirm
        // against BaseExtensions.foldUp.
        List foldUp = BaseExtensionsKt.foldUp(arrayList5);
        if (foldUp == null) {
            throw new TypeCastException("null cannot be cast to non-null type java.util.Collection<T>");
        }
        Object[] array = foldUp.toArray(new List[0]);
        if (array == null) {
            throw new TypeCastException("null cannot be cast to non-null type kotlin.Array<T>");
        }
        // Kotlin default-argument call: bitmask 2 marks the second argument (the false placeholder)
        // as using its declared default.
        return BatchFeedforwardProcessor.forward$default((BatchFeedforwardProcessor) batchFeedforwardProcessor, (List[]) array, false, 2, (Object) null);
    }

    /** Bridge from the generic interface signature to {@link #forward2(List)} (decompiler artifact). */
    @Override // com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor
    public /* bridge */ /* synthetic */ List<? extends DenseNDArray> forward(List<? extends DenseNDArray> list) {
        return forward2((List<DenseNDArray>) list);
    }

    /**
     * Backward pass.
     *
     * <p>Steps:
     * <ol>
     *   <li>Clear the accumulated parameter errors.</li>
     *   <li>Back-propagate {@code list} through the merge processor and accumulate its parameter
     *       errors.</li>
     *   <li>Assign the merge processor's input errors to the output arrays of the matching head.</li>
     *   <li>Back-propagate each head and accumulate its parameter errors.</li>
     * </ol>
     *
     * <p>Requires a previous call to forward; otherwise {@code attentionLayers} is unset and an
     * uninitialized-property exception is thrown.
     *
     * @param list the errors of the output returned by the last forward
     */
    /* renamed from: backward, reason: avoid collision after fix types in other method */
    public void backward2(@NotNull List<DenseNDArray> list) {
        Intrinsics.checkParameterIsNotNull(list, "outputErrors");
        this.errorsAccumulator.clear();
        this.mergeProcessor.backward2(list);
        ParamsErrorsAccumulator.accumulate$default(this.errorsAccumulator, (List) this.mergeProcessor.getParamsErrors(false), false, 2, (Object) null);
        // NOTE(review): foldUp presumably turns the merge input errors back into one list per head,
        // matching the order of attentionLayers — confirm.
        List foldUp = BaseExtensionsKt.foldUp(this.mergeProcessor.getInputsErrors(false));
        List<ScaledDotAttentionLayer> list2 = this.attentionLayers;
        if (list2 == null) {
            Intrinsics.throwUninitializedPropertyAccessException("attentionLayers");
        }
        for (Pair pair : CollectionsKt.zip(list2, foldUp)) {
            ScaledDotAttentionLayer scaledDotAttentionLayer = (ScaledDotAttentionLayer) pair.component1();
            // Seed the head's output errors with its share of the merge input errors.
            for (Pair pair2 : CollectionsKt.zip(scaledDotAttentionLayer.getOutputArrays(), (List) pair.component2())) {
                ((AugmentedArray) pair2.component1()).assignErrors((DenseNDArray) pair2.component2());
            }
            this.errorsAccumulator.accumulate((List<? extends ParamsArray.Errors<?>>) scaledDotAttentionLayer.backward(getPropagateToInput()), false);
        }
    }

    /** Bridge from the generic interface signature to {@link #backward2(List)} (decompiler artifact). */
    @Override // com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor
    public /* bridge */ /* synthetic */ void backward(List<? extends DenseNDArray> list) {
        backward2((List<DenseNDArray>) list);
    }

    /**
     * Returns the parameter errors accumulated by the last backward (merge processor plus all heads).
     *
     * @param z flag forwarded to {@link ParamsErrorsAccumulator#getParamsErrors}; presumably whether to
     *          return copies — confirm
     * @return the accumulated parameter errors
     */
    @Override // com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor
    @NotNull
    public List<ParamsArray.Errors<?>> getParamsErrors(boolean z) {
        return this.errorsAccumulator.getParamsErrors(z);
    }

    /**
     * Returns the errors of the input arrays, summed over all heads.
     *
     * <p>The result starts from copies of the first head's input errors; the input errors of every
     * other head are then added element-wise. The returned arrays are always fresh copies.
     *
     * <p>NOTE(review): presumably only meaningful after a backward with propagateToInput = true —
     * confirm.
     *
     * @param z the interface's flag (presumably "copy"); not used by this implementation
     * @return one error array per input array
     * @throws java.util.NoSuchElementException if there are no attention heads
     *     ({@code CollectionsKt.first} on an empty list)
     */
    @Override // com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor
    @NotNull
    /* renamed from: getInputErrors */
    public List<? extends DenseNDArray> getInputErrors2(boolean z) {
        List<ScaledDotAttentionLayer> list = this.attentionLayers;
        if (list == null) {
            Intrinsics.throwUninitializedPropertyAccessException("attentionLayers");
        }
        List<AugmentedArray<DenseNDArray>> inputArrays = ((ScaledDotAttentionLayer) CollectionsKt.first(list)).getInputArrays();
        ArrayList arrayList = new ArrayList(CollectionsKt.collectionSizeOrDefault(inputArrays, 10));
        Iterator<T> it = inputArrays.iterator();
        while (it.hasNext()) {
            // Copy so the in-place sums below do not modify the first head's error buffers.
            arrayList.add(((AugmentedArray) it.next()).getErrors().copy());
        }
        ArrayList arrayList2 = arrayList;
        List<ScaledDotAttentionLayer> list2 = this.attentionLayers;
        if (list2 == null) {
            Intrinsics.throwUninitializedPropertyAccessException("attentionLayers");
        }
        List<ScaledDotAttentionLayer> list3 = this.attentionLayers;
        if (list3 == null) {
            Intrinsics.throwUninitializedPropertyAccessException("attentionLayers");
        }
        // Every head except the first, accumulated into the copies.
        Iterator it2 = CollectionsKt.takeLast(list2, list3.size() - 1).iterator();
        while (it2.hasNext()) {
            for (Pair pair : CollectionsKt.zip(arrayList2, ((ScaledDotAttentionLayer) it2.next()).getInputArrays())) {
                ((DenseNDArray) pair.component1()).assignSum((NDArray<?>) ((AugmentedArray) pair.component2()).getErrors());
            }
        }
        return arrayList2;
    }

    /** @return the parameters of the heads and of the merge layer */
    @NotNull
    public final MultiHeadAttentionParameters getModel() {
        return this.model;
    }

    /** @return the flag passed to each head's backward */
    @Override // com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor
    public boolean getPropagateToInput() {
        return this.propagateToInput;
    }

    /** @return the id given at construction (0 when omitted via the default-argument constructor) */
    public int getId() {
        return this.id;
    }

    /**
     * Creates the network and its merge processor.
     *
     * @param multiHeadAttentionParameters the model parameters (stored as {@code model})
     * @param z stored as {@code propagateToInput}
     * @param i stored as {@code id}
     */
    public MultiHeadAttentionNetwork(@NotNull MultiHeadAttentionParameters multiHeadAttentionParameters, boolean z, int i) {
        Intrinsics.checkParameterIsNotNull(multiHeadAttentionParameters, "model");
        this.model = multiHeadAttentionParameters;
        this.propagateToInput = z;
        this.id = i;
        this.errorsAccumulator = new ParamsErrorsAccumulator();
        // Kotlin default-argument call with bitmask 10 (args 2 and 4 defaulted, so the 0.0d and 0
        // placeholders are presumably dropout and id — confirm).
        // The explicit `true` is presumably propagateToInput; backward2 relies on it, since it reads
        // mergeProcessor.getInputsErrors to feed the heads.
        this.mergeProcessor = new BatchFeedforwardProcessor<>(this.model.getMerge(), 0.0d, true, 0, 10, (DefaultConstructorMarker) null);
    }

    /**
     * Synthetic Kotlin default-argument constructor: when bit 4 of {@code i2} is set, the id
     * defaults to 0.
     */
    public /* synthetic */ MultiHeadAttentionNetwork(MultiHeadAttentionParameters multiHeadAttentionParameters, boolean z, int i, int i2, DefaultConstructorMarker defaultConstructorMarker) {
        this(multiHeadAttentionParameters, z, (i2 & 4) != 0 ? 0 : i);
    }

    /**
     * Delegates to the interface's default {@code propagateErrors} implementation
     * ({@code NeuralProcessor.DefaultImpls}).
     *
     * @param list the output errors
     * @param paramsOptimizer the optimizer
     * @param z forwarded unchanged to the default implementation
     * @return the result of the default implementation, cast to a list
     */
    @NotNull
    /* renamed from: propagateErrors, reason: avoid collision after fix types in other method */
    public List<DenseNDArray> propagateErrors2(@NotNull List<DenseNDArray> list, @NotNull ParamsOptimizer paramsOptimizer, boolean z) {
        Intrinsics.checkParameterIsNotNull(list, "errors");
        Intrinsics.checkParameterIsNotNull(paramsOptimizer, "optimizer");
        return (List) NeuralProcessor.DefaultImpls.propagateErrors(this, list, paramsOptimizer, z);
    }

    /** Bridge from the generic interface signature to {@code propagateErrors2} (decompiler artifact). */
    @Override // com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor
    public /* bridge */ /* synthetic */ List<? extends DenseNDArray> propagateErrors(List<? extends DenseNDArray> list, ParamsOptimizer paramsOptimizer, boolean z) {
        return propagateErrors2((List<DenseNDArray>) list, paramsOptimizer, z);
    }
}
