package com.kotlinnlp.simplednn.deeplearning.attentionnetwork;

import com.kotlinnlp.simplednn.core.arrays.AugmentedArray;
import com.kotlinnlp.simplednn.core.arrays.UpdatableArray;
import com.kotlinnlp.simplednn.core.functionalities.activations.Tanh;
import com.kotlinnlp.simplednn.core.layers.LayerStructure;
import com.kotlinnlp.simplednn.core.layers.LayerStructureFactory;
import com.kotlinnlp.simplednn.core.layers.LayerType;
import com.kotlinnlp.simplednn.core.layers.feedforward.FeedforwardLayerParameters;
import com.kotlinnlp.simplednn.core.layers.feedforward.FeedforwardLayerStructure;
import com.kotlinnlp.simplednn.core.optimizer.ParamsErrorsAccumulator;
import com.kotlinnlp.simplednn.deeplearning.attentionnetwork.attentionlayer.AttentionLayerParameters;
import com.kotlinnlp.simplednn.deeplearning.attentionnetwork.attentionlayer.AttentionLayerStructure;
import com.kotlinnlp.simplednn.deeplearning.mergelayers.MergeLayer;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import com.kotlinnlp.simplednn.utils.ItemsPool;
import java.util.ArrayList;
import kotlin.Metadata;
import kotlin.NoWhenBranchMatchedException;
import kotlin.Pair;
import kotlin.TypeCastException;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: AttentionNetwork.kt */
@Metadata(mv = {1, 1, 8}, bv = {1, 0, 2}, k = 1, d1 = {"��z\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0006\n��\n\u0002\u0010\b\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\t\n\u0002\u0010\u0011\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u000b\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\t\u0018��*\u000e\b��\u0010\u0001*\b\u0012\u0004\u0012\u0002H\u00010\u00022\u00020\u0003B)\u0012\u0006\u0010\u0004\u001a\u00020\u0005\u0012\u0006\u0010\u0006\u001a\u00020\u0007\u0012\b\b\u0002\u0010\b\u001a\u00020\t\u0012\b\b\u0002\u0010\n\u001a\u00020\u000b¢\u0006\u0002\u0010\fJ\b\u0010\u001e\u001a\u00020\u001fH\u0002J1\u0010 \u001a\u00020\u001f2\u0006\u0010!\u001a\u00020\"2\u0006\u0010#\u001a\u00020\u00052\b\b\u0002\u0010$\u001a\u00020%2\n\b\u0002\u0010&\u001a\u0004\u0018\u00010\t¢\u0006\u0002\u0010'J\"\u0010(\u001a\u00020\u001f2\u0006\u0010!\u001a\u00020\"2\u0006\u0010#\u001a\u00020)2\b\b\u0002\u0010$\u001a\u00020%H\u0002J)\u0010*\u001a\u00020\u001f2\u0006\u0010#\u001a\u00020\u001d2\b\b\u0002\u0010$\u001a\u00020%2\b\u0010&\u001a\u0004\u0018\u00010\tH\u0002¢\u0006\u0002\u0010+JD\u0010,\u001a\u0012\u0012\u0004\u0012\u00020\"0-j\b\u0012\u0004\u0012\u00020\"`.2\"\u0010/\u001a\u001e\u0012\n\u0012\b\u0012\u0004\u0012\u00028��000-j\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00028��00`.2\u0006\u00101\u001a\u00020%H\u0002J4\u00102\u001a\u00020\"2\"\u0010/\u001a\u001e\u0012\n\u0012\b\u0012\u0004\u0012\u00028��000-j\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00028��00`.2\b\b\u0002\u00101\u001a\u00020%J\u0010\u00103\u001a\u00020\"2\b\b\u0002\u00104\u001a\u00020%J\u0011\u00105\u001a\b\u0012\u0004\u0012\u00020\"0\u0018¢\u0006\u0002\u00106J6\u00107\u001a\u00020\u001f2\"\u0010/\u001a\u001e\u0012\n\u0012\b\u0012\u
0004\u0012\u00028��000-j\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00028��00`.2\b\b\u0002\u00101\u001a\u00020%H\u0002J\u000e\u00108\u001a\b\u0012\u0004\u0012\u00028��0\u0019H\u0002R\u0014\u0010\r\u001a\b\u0012\u0004\u0012\u00028��0\u000eX\u0082.¢\u0006\u0002\n��R\u0011\u0010\b\u001a\u00020\t¢\u0006\b\n��\u001a\u0004\b\u000f\u0010\u0010R\u0014\u0010\n\u001a\u00020\u000bX\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\u0011\u0010\u0012R\u0011\u0010\u0006\u001a\u00020\u0007¢\u0006\b\n��\u001a\u0004\b\u0013\u0010\u0014R\u0011\u0010\u0004\u001a\u00020\u0005¢\u0006\b\n��\u001a\u0004\b\u0015\u0010\u0016R\u001c\u0010\u0017\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00028��0\u00190\u0018X\u0082.¢\u0006\u0004\n\u0002\u0010\u001aR\u0014\u0010\u001b\u001a\b\u0012\u0004\u0012\u00020\u001d0\u001cX\u0082\u0004¢\u0006\u0002\n��¨\u00069"}, d2 = {"Lcom/kotlinnlp/simplednn/deeplearning/attentionnetwork/AttentionNetwork;", "InputNDArrayType", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/NDArray;", "Lcom/kotlinnlp/simplednn/utils/ItemsPool$IDItem;", "model", "Lcom/kotlinnlp/simplednn/deeplearning/attentionnetwork/AttentionNetworkParameters;", "inputType", "Lcom/kotlinnlp/simplednn/core/layers/LayerType$Input;", "dropout", "", "id", "", "(Lcom/kotlinnlp/simplednn/deeplearning/attentionnetwork/AttentionNetworkParameters;Lcom/kotlinnlp/simplednn/core/layers/LayerType$Input;DI)V", "attentionLayer", "Lcom/kotlinnlp/simplednn/deeplearning/attentionnetwork/attentionlayer/AttentionLayerStructure;", "getDropout", "()D", "getId", "()I", "getInputType", "()Lcom/kotlinnlp/simplednn/core/layers/LayerType$Input;", "getModel", "()Lcom/kotlinnlp/simplednn/deeplearning/attentionnetwork/AttentionNetworkParameters;", "transformLayers", "", "Lcom/kotlinnlp/simplednn/core/layers/feedforward/FeedforwardLayerStructure;", "[Lcom/kotlinnlp/simplednn/core/layers/feedforward/FeedforwardLayerStructure;", "transformParamsErrorsAccumulator", "Lcom/kotlinnlp/simplednn/core/optimizer/ParamsErrorsAccumulator;", 
"Lcom/kotlinnlp/simplednn/core/layers/feedforward/FeedforwardLayerParameters;", "addTransformErrorsToInput", "", "backward", "outputErrors", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "paramsErrors", "propagateToInput", "", "mePropK", "(Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;Lcom/kotlinnlp/simplednn/deeplearning/attentionnetwork/AttentionNetworkParameters;ZLjava/lang/Double;)V", "backwardAttentionLayer", "Lcom/kotlinnlp/simplednn/deeplearning/attentionnetwork/attentionlayer/AttentionLayerParameters;", "backwardTransformLayers", "(Lcom/kotlinnlp/simplednn/core/layers/feedforward/FeedforwardLayerParameters;ZLjava/lang/Double;)V", "buildAttentionSequence", "Ljava/util/ArrayList;", "Lkotlin/collections/ArrayList;", "inputSequence", "Lcom/kotlinnlp/simplednn/core/arrays/AugmentedArray;", "useDropout", "forward", "getImportanceScore", "copy", "getInputErrors", "()[Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "setInputSequence", "transformLayerFactory", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/deeplearning/attentionnetwork/AttentionNetwork.class */
/**
 * Attention network over a sequence of input arrays.
 *
 * <p>On {@link #forward} every element of the input sequence is passed through its own feed-forward
 * "transform" layer (tanh activation, {@code model.getAttentionSize()} outputs, built on the shared
 * {@code model.getTransformParams()}); the transformed sequence is then given to an
 * {@link AttentionLayerStructure}, whose output values are returned.
 *
 * <p>Instances are stateful: the transform layers and the attention layer built by the last call to
 * {@code forward} are kept in fields and reused by {@link #backward}, {@link #getInputErrors} and
 * {@link #getImportanceScore}. These fields are written without any synchronization.
 *
 * <p>This source is decompiled from a Kotlin class (see the {@code @Metadata} annotation). The static
 * {@code *$default} methods are the entry points used when Kotlin callers omit default arguments, so
 * their signatures are kept as-is.
 *
 * @param <InputNDArrayType> the type of the values of the input arrays
 */
public final class AttentionNetwork<InputNDArrayType extends NDArray<InputNDArrayType>> implements ItemsPool.IDItem {

    /** Accumulates the parameter errors of the transform layers during {@link #backward}. */
    private final ParamsErrorsAccumulator<FeedforwardLayerParameters> transformParamsErrorsAccumulator;

    /** One transform layer per element of the last input sequence; {@code null} until the first forward. */
    private FeedforwardLayerStructure<InputNDArrayType>[] transformLayers;

    /** Attention layer built on the last input sequence; {@code null} until the first forward. */
    private AttentionLayerStructure<InputNDArrayType> attentionLayer;

    /** Parameters of the transform layers and of the attention layer. */
    @NotNull
    private final AttentionNetworkParameters model;

    /** Type of the input arrays (Dense, Sparse or SparseBinary). */
    @NotNull
    private final LayerType.Input inputType;

    /** Dropout probability given to every transform layer. */
    private final double dropout;

    /** Identifier returned by {@link #getId()} (see {@link ItemsPool.IDItem}). */
    private final int id;

    /**
     * Runs the network on the given input sequence.
     *
     * <p>A new transform layer is built for each element and a new attention layer is built on top of
     * them, replacing those of any previous call; then the attention layer is run forward.
     *
     * @param arrayList the input sequence (named {@code inputSequence} in the Kotlin source)
     * @param z whether the transform layers are run forward with dropout ({@code useDropout});
     *     Kotlin default {@code false}
     * @return the values of the attention layer output array
     */
    @NotNull
    public final DenseNDArray forward(@NotNull ArrayList<AugmentedArray<InputNDArrayType>> arrayList, boolean z) {
        Intrinsics.checkParameterIsNotNull(arrayList, "inputSequence");
        setInputSequence(arrayList, z);
        AttentionLayerStructure<InputNDArrayType> layer = requireAttentionLayer();
        layer.forward();
        return layer.getOutputArray().getValues();
    }

    /** Default-argument entry point of {@link #forward}: mask bit {@code 2} selects {@code z = false}. */
    @NotNull
    public static /* bridge */ /* synthetic */ DenseNDArray forward$default(AttentionNetwork attentionNetwork, ArrayList arrayList, boolean z, int i, Object obj) {
        if ((i & 2) != 0) {
            z = false;
        }
        return attentionNetwork.forward(arrayList, z);
    }

    /**
     * Propagates the given output errors back through the attention layer and the transform layers.
     *
     * <p>Must be called after {@link #forward}; otherwise an uninitialized-property exception is thrown.
     *
     * @param denseNDArray the errors of the network output ({@code outputErrors})
     * @param attentionNetworkParameters the structure that receives the parameter errors
     *     ({@code paramsErrors}); its transform part ends up holding the average over all transform layers
     * @param z whether to propagate the errors to the input sequence ({@code propagateToInput});
     *     Kotlin default {@code false}
     * @param d optional value passed unchanged to every transform layer backward ({@code mePropK});
     *     Kotlin default {@code null}
     */
    public final void backward(@NotNull DenseNDArray denseNDArray, @NotNull AttentionNetworkParameters attentionNetworkParameters, boolean z, @Nullable Double d) {
        Intrinsics.checkParameterIsNotNull(denseNDArray, "outputErrors");
        Intrinsics.checkParameterIsNotNull(attentionNetworkParameters, "paramsErrors");
        backwardAttentionLayer(denseNDArray, attentionNetworkParameters.getAttentionParams(), z);
        backwardTransformLayers(attentionNetworkParameters.getTransformParams(), z, d);
        if (z) {
            // Add the errors that flowed back through the transform layers to the input errors.
            addTransformErrorsToInput();
        }
    }

    /**
     * Default-argument entry point of {@link #backward}: mask bit {@code 4} selects {@code z = false},
     * mask bit {@code 8} selects {@code d = null}.
     */
    public static /* bridge */ /* synthetic */ void backward$default(AttentionNetwork attentionNetwork, DenseNDArray denseNDArray, AttentionNetworkParameters attentionNetworkParameters, boolean z, Double d, int i, Object obj) {
        if ((i & 4) != 0) {
            z = false;
        }
        if ((i & 8) != 0) {
            d = (Double) null;
        }
        attentionNetwork.backward(denseNDArray, attentionNetworkParameters, z, d);
    }

    /**
     * Returns the errors of each array of the last input sequence, in sequence order.
     *
     * <p>NOTE(review): these are expected to be meaningful only after {@link #backward} was called with
     * {@code propagateToInput = true} — confirm against {@code AttentionLayerStructure}.
     *
     * @return a new array holding the error arrays of the input sequence elements
     */
    @NotNull
    public final DenseNDArray[] getInputErrors() {
        AttentionLayerStructure<InputNDArrayType> layer = requireAttentionLayer();
        int size = layer.getInputSequence().size();
        DenseNDArray[] inputErrors = new DenseNDArray[size];
        for (int index = 0; index < size; index++) {
            inputErrors[index] = layer.getInputSequence().get(index).getErrors();
        }
        return inputErrors;
    }

    /**
     * Returns the importance score of the attention layer built by the last {@link #forward}.
     *
     * @param z whether to return a copy ({@code copy}) instead of the layer's own array;
     *     Kotlin default {@code true}
     * @return the importance score array, or a copy of it when {@code z} is {@code true}
     */
    @NotNull
    public final DenseNDArray getImportanceScore(boolean z) {
        DenseNDArray importanceScore = requireAttentionLayer().getImportanceScore();
        return z ? importanceScore.copy() : importanceScore;
    }

    /** Default-argument entry point of {@link #getImportanceScore}: mask bit {@code 1} selects {@code z = true}. */
    @NotNull
    public static /* bridge */ /* synthetic */ DenseNDArray getImportanceScore$default(AttentionNetwork attentionNetwork, boolean z, int i, Object obj) {
        if ((i & 1) != 0) {
            z = true;
        }
        return attentionNetwork.getImportanceScore(z);
    }

    /**
     * Builds a new transform layer: a tanh feed-forward layer from {@code model.getInputSize()} inputs
     * to {@code model.getAttentionSize()} outputs, using the shared transform parameters and this
     * network's dropout.
     *
     * @return the new transform layer
     * @throws NoWhenBranchMatchedException if {@link #inputType} is not Dense, Sparse or SparseBinary
     */
    @SuppressWarnings({"rawtypes", "unchecked"}) // raw AugmentedArray/LayerStructure as in the decompiled code
    private FeedforwardLayerStructure<InputNDArrayType> transformLayerFactory() {
        AugmentedArray inputArray;
        switch (this.inputType) {
            case Dense:
            case Sparse:
            case SparseBinary:
                // In this (decompiled) form the input array is built the same way for every input type.
                inputArray = new AugmentedArray(this.model.getInputSize());
                break;
            default:
                throw new NoWhenBranchMatchedException();
        }
        LayerStructure layer = LayerStructureFactory.invoke$default(
                LayerStructureFactory.INSTANCE,
                inputArray,
                this.model.getAttentionSize(),
                this.model.getTransformParams(),
                new Tanh(),
                LayerType.Connection.Feedforward,
                this.dropout,
                null,
                64, // default-argument mask: use the Kotlin default for the 7th argument
                null);
        if (layer == null) {
            throw new TypeCastException("null cannot be cast to non-null type com.kotlinnlp.simplednn.core.layers.feedforward.FeedforwardLayerStructure<InputNDArrayType>");
        }
        return (FeedforwardLayerStructure<InputNDArrayType>) layer;
    }

    /**
     * Builds the transform layers on the given sequence and a new attention layer on their outputs.
     *
     * @param inputSequence the input sequence
     * @param useDropout whether the transform layers are run forward with dropout
     */
    private void setInputSequence(ArrayList<AugmentedArray<InputNDArrayType>> inputSequence, boolean useDropout) {
        ArrayList<DenseNDArray> attentionSequence = buildAttentionSequence(inputSequence, useDropout);
        this.attentionLayer = new AttentionLayerStructure<>(inputSequence, attentionSequence, this.model.getAttentionParams());
    }

    /** Default-argument entry point of {@link #setInputSequence}: mask bit {@code 2} selects {@code z = false}. */
    static /* bridge */ /* synthetic */ void setInputSequence$default(AttentionNetwork attentionNetwork, ArrayList arrayList, boolean z, int i, Object obj) {
        if ((i & 2) != 0) {
            z = false;
        }
        attentionNetwork.setInputSequence(arrayList, z);
    }

    /**
     * Creates one transform layer per input element (stored in {@link #transformLayers}) and runs each
     * one forward on its element.
     *
     * @param inputSequence the input sequence
     * @param useDropout whether the transform layers are run forward with dropout
     * @return the output values of the transform layers, in sequence order
     */
    private ArrayList<DenseNDArray> buildAttentionSequence(ArrayList<AugmentedArray<InputNDArrayType>> inputSequence, boolean useDropout) {
        int size = inputSequence.size();

        @SuppressWarnings("unchecked") // generic array creation
        FeedforwardLayerStructure<InputNDArrayType>[] layers = new FeedforwardLayerStructure[size];
        for (int index = 0; index < size; index++) {
            layers[index] = transformLayerFactory();
        }
        this.transformLayers = layers;

        ArrayList<DenseNDArray> attentionSequence = new ArrayList<>(size);
        for (int index = 0; index < size; index++) {
            FeedforwardLayerStructure<InputNDArrayType> layer = layers[index];
            layer.setInput(inputSequence.get(index).getValues());
            layer.forward(useDropout);
            attentionSequence.add(layer.getOutputArray().getValues());
        }
        return attentionSequence;
    }

    /**
     * Sets the output errors on the attention layer and runs its backward.
     *
     * @param outputErrors the errors of the network output
     * @param paramsErrors the attention parameters that receive the errors
     * @param propagateToInput passed unchanged to the attention layer backward
     */
    private void backwardAttentionLayer(DenseNDArray outputErrors, AttentionLayerParameters paramsErrors, boolean propagateToInput) {
        AttentionLayerStructure<InputNDArrayType> layer = requireAttentionLayer();
        layer.setErrors(outputErrors);
        layer.backward(paramsErrors, propagateToInput);
    }

    /** Default-argument entry point of {@link #backwardAttentionLayer}: mask bit {@code 4} selects {@code z = false}. */
    static /* bridge */ /* synthetic */ void backwardAttentionLayer$default(AttentionNetwork attentionNetwork, DenseNDArray denseNDArray, AttentionLayerParameters attentionLayerParameters, boolean z, int i, Object obj) {
        if ((i & 4) != 0) {
            z = false;
        }
        attentionNetwork.backwardAttentionLayer(denseNDArray, attentionLayerParameters, z);
    }

    /**
     * Runs backward on every transform layer and stores the average of their parameter errors.
     *
     * <p>Each transform layer gets the attention error of its element and runs backward with
     * {@code paramsErrors} as target. After each call, {@code paramsErrors} is added to the accumulator.
     * At the end the accumulated errors are averaged and copied back into {@code paramsErrors`, and the
     * accumulator is reset.
     *
     * @param paramsErrors the transform parameters used as error buffer and final result
     * @param propagateToInput passed unchanged to every transform layer backward
     * @param mePropK passed unchanged to every transform layer backward (may be {@code null})
     */
    private void backwardTransformLayers(FeedforwardLayerParameters paramsErrors, boolean propagateToInput, Double mePropK) {
        DenseNDArray[] attentionErrors = requireAttentionLayer().getAttentionErrors();
        FeedforwardLayerStructure<InputNDArrayType>[] layers = requireTransformLayers();

        for (int index = 0; index < layers.length; index++) {
            FeedforwardLayerStructure<InputNDArrayType> layer = layers[index];
            layer.setErrors(attentionErrors[index]);
            layer.backward(paramsErrors, propagateToInput, mePropK);
            ParamsErrorsAccumulator.accumulate$default(this.transformParamsErrorsAccumulator, paramsErrors, false, 2, null);
        }

        this.transformParamsErrorsAccumulator.averageErrors();
        FeedforwardLayerParameters averagedErrors = (FeedforwardLayerParameters) ParamsErrorsAccumulator.getParamsErrors$default(this.transformParamsErrorsAccumulator, false, 1, null);
        // Copy the averaged errors back into the caller's parameter-errors structure, array by array.
        for (Pair pair : CollectionsKt.zip(paramsErrors, averagedErrors)) {
            ((UpdatableArray) pair.component1()).getValues().assignValues(((UpdatableArray) pair.component2()).getValues());
        }
        this.transformParamsErrorsAccumulator.reset();
    }

    /** Default-argument entry point of {@link #backwardTransformLayers}: mask bit {@code 2} selects {@code z = false}. */
    static /* bridge */ /* synthetic */ void backwardTransformLayers$default(AttentionNetwork attentionNetwork, FeedforwardLayerParameters feedforwardLayerParameters, boolean z, Double d, int i, Object obj) {
        if ((i & 2) != 0) {
            z = false;
        }
        attentionNetwork.backwardTransformLayers(feedforwardLayerParameters, z, d);
    }

    /**
     * Adds the input errors of each transform layer to the errors of the corresponding element of the
     * input sequence, in place.
     */
    private void addTransformErrorsToInput() {
        AttentionLayerStructure<InputNDArrayType> layer = requireAttentionLayer();
        // attentionLayer is only assigned after transformLayers (see setInputSequence), so this is set too.
        FeedforwardLayerStructure<InputNDArrayType>[] layers = requireTransformLayers();
        int size = layer.getInputSequence().size();
        for (int index = 0; index < size; index++) {
            DenseNDArray inputErrors = layer.getInputSequence().get(index).getErrors();
            // Cast selects the NDArray overload of assignSum, as in the original code.
            inputErrors.assignSum((NDArray<?>) layers[index].getInputArray().getErrors());
        }
    }

    /**
     * Returns the attention layer built by the last {@link #forward}.
     *
     * @throws kotlin.UninitializedPropertyAccessException if {@code forward} was never called
     */
    private AttentionLayerStructure<InputNDArrayType> requireAttentionLayer() {
        AttentionLayerStructure<InputNDArrayType> layer = this.attentionLayer;
        if (layer == null) {
            Intrinsics.throwUninitializedPropertyAccessException("attentionLayer");
        }
        return layer;
    }

    /**
     * Returns the transform layers built by the last {@link #forward}.
     *
     * @throws kotlin.UninitializedPropertyAccessException if {@code forward} was never called
     */
    private FeedforwardLayerStructure<InputNDArrayType>[] requireTransformLayers() {
        FeedforwardLayerStructure<InputNDArrayType>[] layers = this.transformLayers;
        if (layers == null) {
            Intrinsics.throwUninitializedPropertyAccessException("transformLayers");
        }
        return layers;
    }

    /** @return the parameters of this network */
    @NotNull
    public final AttentionNetworkParameters getModel() {
        return this.model;
    }

    /** @return the type of the input arrays */
    @NotNull
    public final LayerType.Input getInputType() {
        return this.inputType;
    }

    /** @return the dropout given to the transform layers */
    public final double getDropout() {
        return this.dropout;
    }

    /** @return the identifier of this network */
    @Override // com.kotlinnlp.simplednn.utils.ItemsPool.IDItem
    public int getId() {
        return this.id;
    }

    /**
     * Creates an attention network.
     *
     * @param attentionNetworkParameters the network parameters ({@code model})
     * @param input the type of the input arrays ({@code inputType})
     * @param d the dropout of the transform layers ({@code dropout}); Kotlin default {@code 0.0}
     * @param i the identifier of the network ({@code id}); Kotlin default {@code 0}
     */
    public AttentionNetwork(@NotNull AttentionNetworkParameters attentionNetworkParameters, @NotNull LayerType.Input input, double d, int i) {
        Intrinsics.checkParameterIsNotNull(attentionNetworkParameters, "model");
        Intrinsics.checkParameterIsNotNull(input, "inputType");
        this.model = attentionNetworkParameters;
        this.inputType = input;
        this.dropout = d;
        this.id = i;
        this.transformParamsErrorsAccumulator = new ParamsErrorsAccumulator<>();
    }

    /**
     * Default-argument constructor: mask bit {@code 4} selects {@code d = 0.0}, mask bit {@code 8}
     * selects {@code i = 0}.
     */
    public /* synthetic */ AttentionNetwork(AttentionNetworkParameters attentionNetworkParameters, LayerType.Input input, double d, int i, int i2, DefaultConstructorMarker defaultConstructorMarker) {
        this(attentionNetworkParameters, input, (i2 & 4) != 0 ? 0.0d : d, (i2 & 8) != 0 ? 0 : i);
    }
}
