package com.kotlinnlp.simplednn.deeplearning.attention.attentionnetwork;

import com.kotlinnlp.simplednn.core.arrays.AugmentedArray;
import com.kotlinnlp.simplednn.core.arrays.ParamsArray;
import com.kotlinnlp.simplednn.core.layers.LayerType;
import com.kotlinnlp.simplednn.core.layers.models.attention.AttentionLayer;
import com.kotlinnlp.simplednn.core.neuralprocessor.batchfeedforward.BatchFeedforwardProcessor;
import com.kotlinnlp.simplednn.core.optimizer.ParamsErrorsAccumulator;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import com.kotlinnlp.utils.ItemsPool;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;

/* compiled from: AttentionNetwork.kt */
@Metadata(mv = {1, 1, 15}, bv = {1, 0, 3}, k = 1, d1 = {"��d\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0006\n��\n\u0002\u0010\u000b\n��\n\u0002\u0010\b\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\b\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0002\n��\n\u0002\u0010 \n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\f\u0018��*\u000e\b��\u0010\u0001*\b\u0012\u0004\u0012\u0002H\u00010\u00022\u00020\u0003B1\u0012\u0006\u0010\u0004\u001a\u00020\u0005\u0012\u0006\u0010\u0006\u001a\u00020\u0007\u0012\b\b\u0002\u0010\b\u001a\u00020\t\u0012\u0006\u0010\n\u001a\u00020\u000b\u0012\b\b\u0002\u0010\f\u001a\u00020\r¢\u0006\u0002\u0010\u000eJ\b\u0010\u001c\u001a\u00020\u001dH\u0002J \u0010\u001e\u001a\u0014\u0012\f\u0012\n\u0012\u0002\b\u00030 R\u00020!0\u001fj\u0002`\"2\u0006\u0010#\u001a\u00020$J,\u0010%\u001a\u0014\u0012\f\u0012\n\u0012\u0002\b\u00030 R\u00020!0\u001fj\u0002`\"2\u0006\u0010#\u001a\u00020$2\b\b\u0002\u0010\n\u001a\u00020\u000bH\u0002J\u001a\u0010&\u001a\u0014\u0012\f\u0012\n\u0012\u0002\b\u00030 
R\u00020!0\u001fj\u0002`\"H\u0002J\u0014\u0010'\u001a\u00020$2\f\u0010(\u001a\b\u0012\u0004\u0012\u00028��0\u001fJ\"\u0010'\u001a\u00020$2\f\u0010(\u001a\b\u0012\u0004\u0012\u00028��0\u001f2\f\u0010)\u001a\b\u0012\u0004\u0012\u00020$0\u001fJ\f\u0010*\u001a\b\u0012\u0004\u0012\u00020$0\u001fJ\u0010\u0010+\u001a\u00020$2\b\b\u0002\u0010,\u001a\u00020\u000bJ\f\u0010-\u001a\b\u0012\u0004\u0012\u00020$0\u001fJ\u0010\u0010.\u001a\u00020$2\b\b\u0002\u0010,\u001a\u00020\u000bJ\u0016\u0010/\u001a\u00020\u001d2\f\u0010(\u001a\b\u0012\u0004\u0012\u00028��0\u001fH\u0002R\u0014\u0010\u000f\u001a\b\u0012\u0004\u0012\u00028��0\u0010X\u0082.¢\u0006\u0002\n��R\u0014\u0010\f\u001a\u00020\rX\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\u0011\u0010\u0012R\u0011\u0010\u0006\u001a\u00020\u0007¢\u0006\b\n��\u001a\u0004\b\u0013\u0010\u0014R\u000e\u0010\u0015\u001a\u00020\u000bX\u0082\u000e¢\u0006\u0002\n��R\u0011\u0010\u0004\u001a\u00020\u0005¢\u0006\b\n��\u001a\u0004\b\u0016\u0010\u0017R\u000e\u0010\n\u001a\u00020\u000bX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0018\u001a\u00020\u0019X\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u001a\u001a\b\u0012\u0004\u0012\u00028��0\u001bX\u0082\u0004¢\u0006\u0002\n��¨\u00060"}, d2 = {"Lcom/kotlinnlp/simplednn/deeplearning/attention/attentionnetwork/AttentionNetwork;", "InputNDArrayType", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/NDArray;", "Lcom/kotlinnlp/utils/ItemsPool$IDItem;", "model", "Lcom/kotlinnlp/simplednn/deeplearning/attention/attentionnetwork/AttentionNetworkParameters;", "inputType", "Lcom/kotlinnlp/simplednn/core/layers/LayerType$Input;", "dropout", "", "propagateToInput", "", "id", "", "(Lcom/kotlinnlp/simplednn/deeplearning/attention/attentionnetwork/AttentionNetworkParameters;Lcom/kotlinnlp/simplednn/core/layers/LayerType$Input;DZI)V", "attentionLayer", "Lcom/kotlinnlp/simplednn/core/layers/models/attention/AttentionLayer;", "getId", "()I", "getInputType", "()Lcom/kotlinnlp/simplednn/core/layers/LayerType$Input;", 
"internalAttentionArraysUsed", "getModel", "()Lcom/kotlinnlp/simplednn/deeplearning/attention/attentionnetwork/AttentionNetworkParameters;", "transformParamsErrorsAccumulator", "Lcom/kotlinnlp/simplednn/core/optimizer/ParamsErrorsAccumulator;", "transformProcessor", "Lcom/kotlinnlp/simplednn/core/neuralprocessor/batchfeedforward/BatchFeedforwardProcessor;", "addTransformErrorsToInput", "", "backward", "", "Lcom/kotlinnlp/simplednn/core/arrays/ParamsArray$Errors;", "Lcom/kotlinnlp/simplednn/core/arrays/ParamsArray;", "Lcom/kotlinnlp/simplednn/core/optimizer/ParamsErrorsList;", "outputErrors", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "backwardAttentionLayer", "backwardTransformLayers", "forward", "inputSequence", "attentionSequence", "getAttentionErrors", "getImportanceScore", "copy", "getInputErrors", "getOutput", "setInputSequence", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/deeplearning/attention/attentionnetwork/AttentionNetwork.class */
/**
 * Attention network.
 *
 * <p>An {@link AttentionLayer} (built on {@code model.getAttention()}) combines a sequence of
 * input arrays with one attention array per input. By default the attention arrays are derived
 * from the inputs by a batch feed-forward "transform" processor built on
 * {@code model.getTransform()}. Alternatively, the caller may supply them directly through
 * {@link #forward(List, List)}. In that case the transform processor is skipped on both the
 * forward and the backward pass.
 *
 * <p>This class was decompiled from Kotlin ({@code AttentionNetwork.kt}). The {@code $default}
 * methods and the constructor taking a {@link DefaultConstructorMarker} are the bridges Kotlin
 * emits for default arguments. They are kept so that existing callers stay binary-compatible.
 *
 * @param <InputNDArrayType> the type of the input arrays
 */
public final class AttentionNetwork<InputNDArrayType extends NDArray<InputNDArrayType>> implements ItemsPool.IDItem {

    /** Collects the transform processor's parameter errors before they are handed back from {@link #backward}. */
    private final ParamsErrorsAccumulator transformParamsErrorsAccumulator;

    /** Batch processor over {@code model.getTransform()} that produces the internal attention arrays. */
    private final BatchFeedforwardProcessor<InputNDArrayType> transformProcessor;

    /**
     * Attention layer, rebuilt on every {@code forward} call.
     *
     * <p>It is {@code null} until the first {@code forward}; this mirrors a Kotlin {@code lateinit}
     * property. Always read it through {@link #requireAttentionLayer()}.
     */
    private AttentionLayer<InputNDArrayType> attentionLayer;

    /**
     * Whether the last forward pass derived the attention arrays through {@link #transformProcessor}.
     * If so, {@link #backward} must also back-propagate through the transform layers.
     */
    private boolean internalAttentionArraysUsed;

    @NotNull
    private final AttentionNetworkParameters model;

    @NotNull
    private final LayerType.Input inputType;

    /** Whether {@link #backward} propagates the errors down to the input arrays. */
    private final boolean propagateToInput;

    private final int id;

    /**
     * Runs the forward pass, deriving one attention array per input through the transform processor.
     *
     * @param inputSequence the input arrays
     * @return the output values of the attention layer (the layer's own array, not a copy)
     */
    @NotNull
    public final DenseNDArray forward(@NotNull List<? extends InputNDArrayType> inputSequence) {
        Intrinsics.checkParameterIsNotNull(inputSequence, "inputSequence");
        this.internalAttentionArraysUsed = true;
        setInputSequence(inputSequence);
        return runAttentionLayer();
    }

    /**
     * Runs the forward pass using caller-supplied attention arrays.
     *
     * <p>The transform processor is bypassed here and will not be back-propagated by {@link #backward}.
     *
     * @param inputSequence the input arrays
     * @param attentionSequence the attention arrays, paired with {@code inputSequence} by position
     * @return the output values of the attention layer (the layer's own array, not a copy)
     */
    @NotNull
    public final DenseNDArray forward(@NotNull List<? extends InputNDArrayType> inputSequence,
                                      @NotNull List<DenseNDArray> attentionSequence) {
        Intrinsics.checkParameterIsNotNull(inputSequence, "inputSequence");
        Intrinsics.checkParameterIsNotNull(attentionSequence, "attentionSequence");
        this.internalAttentionArraysUsed = false;
        this.attentionLayer = buildAttentionLayer(inputSequence, attentionSequence);
        return runAttentionLayer();
    }

    /**
     * Back-propagates the given output errors.
     *
     * <p>The attention layer is always back-propagated, with propagation to its inputs enabled only
     * when {@code propagateToInput} is set. If the last forward used internal attention arrays, the
     * transform layers are back-propagated too. When {@code propagateToInput} is also set, their
     * input errors are then added to the errors of the attention layer's input arrays.
     *
     * @param outputErrors the errors of the network output
     * @return the parameter errors: those of the attention layer first, then those of the transform
     *     layers (if any)
     */
    @NotNull
    public final List<ParamsArray.Errors<?>> backward(@NotNull DenseNDArray outputErrors) {
        Intrinsics.checkParameterIsNotNull(outputErrors, "outputErrors");
        List<ParamsArray.Errors<?>> paramsErrors =
                new ArrayList<>(backwardAttentionLayer(outputErrors, this.propagateToInput));
        if (this.internalAttentionArraysUsed) {
            paramsErrors.addAll(backwardTransformLayers());
            if (this.propagateToInput) {
                addTransformErrorsToInput();
            }
        }
        return paramsErrors;
    }

    /**
     * Returns the output of the last forward pass.
     *
     * @param copy whether to return a copy of the output values (the Kotlin default is {@code true})
     * @return the output values, or a copy of them
     */
    @NotNull
    public final DenseNDArray getOutput(boolean copy) {
        DenseNDArray output = requireAttentionLayer().getOutputArray().getValues();
        return copy ? output.copy() : output;
    }

    /** Kotlin default-argument bridge for {@link #getOutput(boolean)}: bit 1 of the mask selects {@code copy = true}. */
    @NotNull
    public static /* synthetic */ DenseNDArray getOutput$default(AttentionNetwork<?> attentionNetwork, boolean copy, int defaultsMask, Object marker) {
        if ((defaultsMask & 1) != 0) {
            copy = true;
        }
        return attentionNetwork.getOutput(copy);
    }

    /**
     * Returns the errors of the attention layer's input arrays, one per input, in input order.
     *
     * @return a new list of the input errors
     */
    @NotNull
    public final List<DenseNDArray> getInputErrors() {
        List<AugmentedArray<InputNDArrayType>> inputArrays = requireAttentionLayer().getInputArrays();
        List<DenseNDArray> errors = new ArrayList<>(inputArrays.size());
        for (AugmentedArray<InputNDArrayType> inputArray : inputArrays) {
            errors.add(inputArray.getErrors());
        }
        return errors;
    }

    /**
     * Returns the errors of the attention layer's attention arrays, one per input, in input order.
     *
     * @return a new list of the attention errors
     */
    @NotNull
    public final List<DenseNDArray> getAttentionErrors() {
        List<AugmentedArray<DenseNDArray>> attentionArrays = requireAttentionLayer().getAttentionArrays();
        List<DenseNDArray> errors = new ArrayList<>(attentionArrays.size());
        for (AugmentedArray<DenseNDArray> attentionArray : attentionArrays) {
            errors.add(attentionArray.getErrors());
        }
        return errors;
    }

    /**
     * Returns the attention scores computed by the last forward pass.
     *
     * @param copy whether to return a copy of the scores (the Kotlin default is {@code true})
     * @return the attention scores, or a copy of them
     */
    @NotNull
    public final DenseNDArray getImportanceScore(boolean copy) {
        DenseNDArray scores = requireAttentionLayer().getAttentionScores().getValues();
        return copy ? scores.copy() : scores;
    }

    /** Kotlin default-argument bridge for {@link #getImportanceScore(boolean)}: bit 1 of the mask selects {@code copy = true}. */
    @NotNull
    public static /* synthetic */ DenseNDArray getImportanceScore$default(AttentionNetwork<?> attentionNetwork, boolean copy, int defaultsMask, Object marker) {
        if ((defaultsMask & 1) != 0) {
            copy = true;
        }
        return attentionNetwork.getImportanceScore(copy);
    }

    /**
     * Derives the attention arrays from the inputs through the transform processor.
     * It then builds a fresh attention layer over the inputs and those arrays.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private void setInputSequence(List<? extends InputNDArrayType> inputSequence) {
        // Raw cast kept from the decompiled original: the processor's Java-visible parameter type is not
        // known here to accept a wildcard list.
        List<DenseNDArray> attentionSequence = this.transformProcessor.forward((List) inputSequence);
        this.attentionLayer = buildAttentionLayer(inputSequence, attentionSequence);
    }

    /**
     * Wraps each input and each attention array in an {@link AugmentedArray}.
     * It then creates the attention layer over them.
     */
    private AttentionLayer<InputNDArrayType> buildAttentionLayer(List<? extends InputNDArrayType> inputSequence,
                                                                 List<? extends DenseNDArray> attentionSequence) {
        List<AugmentedArray<InputNDArrayType>> inputArrays = new ArrayList<>(inputSequence.size());
        for (InputNDArrayType input : inputSequence) {
            inputArrays.add(AugmentedArray.Companion.invoke(input));
        }
        List<AugmentedArray<DenseNDArray>> attentionArrays = new ArrayList<>(attentionSequence.size());
        for (DenseNDArray attention : attentionSequence) {
            attentionArrays.add(AugmentedArray.Companion.invoke(attention));
        }
        // (null, 16, null) targets Kotlin's synthetic default-argument constructor. Mask bit 16 makes
        // the 5th parameter fall back to its declared default.
        return new AttentionLayer<>(inputArrays, this.inputType, attentionArrays, this.model.getAttention(), null, 16, null);
    }

    /** Runs the current attention layer forward and returns its output values (not a copy). */
    private DenseNDArray runAttentionLayer() {
        AttentionLayer<InputNDArrayType> layer = requireAttentionLayer();
        layer.forward();
        return layer.getOutputArray().getValues();
    }

    /**
     * Returns the attention layer.
     *
     * <p>If no forward pass has run yet, this fails via
     * {@code Intrinsics.throwUninitializedPropertyAccessException}, exactly as the Kotlin
     * {@code lateinit} access would.
     */
    private AttentionLayer<InputNDArrayType> requireAttentionLayer() {
        AttentionLayer<InputNDArrayType> layer = this.attentionLayer;
        if (layer == null) {
            Intrinsics.throwUninitializedPropertyAccessException("attentionLayer");
        }
        return layer;
    }

    /**
     * Assigns the output errors to the attention layer and back-propagates it.
     *
     * @param outputErrors the errors of the attention layer's output
     * @param propagate whether to propagate the errors to the layer's input arrays
     * @return the attention layer's parameter errors
     */
    private List<ParamsArray.Errors<?>> backwardAttentionLayer(DenseNDArray outputErrors, boolean propagate) {
        AttentionLayer<InputNDArrayType> layer = requireAttentionLayer();
        layer.getOutputArray().assignErrors(outputErrors);
        return layer.backward(propagate);
    }

    /** Kotlin default-argument bridge for {@link #backwardAttentionLayer}: bit 2 of the mask selects {@code propagate = false}. */
    static /* synthetic */ List backwardAttentionLayer$default(AttentionNetwork attentionNetwork, DenseNDArray outputErrors, boolean propagate, int defaultsMask, Object marker) {
        if ((defaultsMask & 2) != 0) {
            propagate = false;
        }
        return attentionNetwork.backwardAttentionLayer(outputErrors, propagate);
    }

    /**
     * Back-propagates the transform layers, using the attention-array errors as their output errors.
     *
     * @return the transform layers' parameter errors
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private List<ParamsArray.Errors<?>> backwardTransformLayers() {
        this.transformProcessor.backward2(getAttentionErrors());
        ParamsErrorsAccumulator.accumulate$default(this.transformParamsErrorsAccumulator, (List) this.transformProcessor.getParamsErrors(false), false, 2, (Object) null);
        // NOTE(review): the `true` argument presumably requests a copy, so that the clear() below does not
        // wipe the returned errors. Confirm against ParamsErrorsAccumulator.
        List<ParamsArray.Errors<?>> paramsErrors = this.transformParamsErrorsAccumulator.getParamsErrors(true);
        this.transformParamsErrorsAccumulator.clear();
        return paramsErrors;
    }

    /**
     * Adds the transform processor's input errors to the errors of the attention layer's input arrays.
     *
     * <p>The two lists are paired by position. If their sizes differ, the extra elements of the longer
     * list are ignored, matching the truncation of Kotlin's {@code zip}.
     */
    private void addTransformErrorsToInput() {
        List<? extends DenseNDArray> transformInputErrors = this.transformProcessor.getInputErrors2(false);
        Iterator<AugmentedArray<InputNDArrayType>> inputArrays = requireAttentionLayer().getInputArrays().iterator();
        Iterator<? extends DenseNDArray> inputErrors = transformInputErrors.iterator();
        while (inputArrays.hasNext() && inputErrors.hasNext()) {
            AugmentedArray<InputNDArrayType> inputArray = inputArrays.next();
            // The cast pins the same assignSum overload the decompiled original called.
            inputArray.getErrors().assignSum((NDArray<?>) inputErrors.next());
        }
    }

    /** @return the parameters of this network */
    @NotNull
    public final AttentionNetworkParameters getModel() {
        return this.model;
    }

    /** @return the type of the input arrays */
    @NotNull
    public final LayerType.Input getInputType() {
        return this.inputType;
    }

    /** @return the id given at construction */
    public int getId() {
        return this.id;
    }

    /**
     * Creates an attention network.
     *
     * @param model the network parameters (transform and attention sub-parameters)
     * @param inputType the type of the input arrays
     * @param dropout the dropout value passed to the transform processor
     * @param propagateToInput whether {@link #backward} propagates errors down to the input arrays
     * @param id the id of this network
     */
    public AttentionNetwork(@NotNull AttentionNetworkParameters model, @NotNull LayerType.Input inputType,
                            double dropout, boolean propagateToInput, int id) {
        Intrinsics.checkParameterIsNotNull(model, "model");
        Intrinsics.checkParameterIsNotNull(inputType, "inputType");
        this.model = model;
        this.inputType = inputType;
        this.propagateToInput = propagateToInput;
        this.id = id;
        this.transformParamsErrorsAccumulator = new ParamsErrorsAccumulator();
        // (0, 8, null) targets Kotlin's synthetic default-argument constructor. Mask bit 8 makes the
        // 4th parameter fall back to its declared default.
        this.transformProcessor = new BatchFeedforwardProcessor<>(this.model.getTransform(), dropout, this.propagateToInput, 0, 8, (DefaultConstructorMarker) null);
    }

    /**
     * Kotlin default-argument bridge constructor.
     *
     * <p>Mask bit 4 selects {@code dropout = 0.0}, and mask bit 16 selects {@code id = 0}.
     */
    public /* synthetic */ AttentionNetwork(AttentionNetworkParameters model, LayerType.Input inputType, double dropout,
                                            boolean propagateToInput, int id, int defaultsMask, DefaultConstructorMarker marker) {
        this(model, inputType, (defaultsMask & 4) != 0 ? 0.0d : dropout, propagateToInput, (defaultsMask & 16) != 0 ? 0 : id);
    }
}
