package com.kotlinnlp.simplednn.core.layers.models.attention.scaleddot;

import com.kotlinnlp.simplednn.core.arrays.AugmentedArray;
import com.kotlinnlp.simplednn.core.layers.Layer;
import com.kotlinnlp.simplednn.core.layers.LayerType;
import com.kotlinnlp.simplednn.core.layers.helpers.BackwardHelper;
import com.kotlinnlp.simplednn.core.layers.helpers.ForwardHelper;
import com.kotlinnlp.simplednn.core.layers.helpers.RelevanceHelper;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArrayMask;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArrayFactory;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin.TuplesKt;
import kotlin.collections.ArraysKt;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.functions.Function1;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import kotlin.sequences.SequencesKt;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: ScaledDotAttentionLayer.kt */
@Metadata(mv = {1, 1, 15}, bv = {1, 0, 3}, k = 1, d1 = {"��F\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010 \n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0006\n\u0002\b\f\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\b\b\n\u0002\u0018\u0002\n\u0002\b\u000f\n\u0002\u0018\u0002\n\u0002\b\u0006\b��\u0018��2\b\u0012\u0004\u0012\u00020\u00020\u0001B+\u0012\u0012\u0010\u0003\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00020\u00050\u0004\u0012\u0006\u0010\u0006\u001a\u00020\u0007\u0012\b\b\u0002\u0010\b\u001a\u00020\t¢\u0006\u0002\u0010\nJ\u000e\u00108\u001a\b\u0012\u0004\u0012\u00020\u001a0\u0004H\u0002R\u001a\u0010\u000b\u001a\u00020\u0002X\u0080.¢\u0006\u000e\n��\u001a\u0004\b\f\u0010\r\"\u0004\b\u000e\u0010\u000fR \u0010\u0010\u001a\b\u0012\u0004\u0012\u00020\u00020\u0004X\u0080.¢\u0006\u000e\n��\u001a\u0004\b\u0011\u0010\u0012\"\u0004\b\u0013\u0010\u0014R\u000e\u0010\b\u001a\u00020\tX\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u0015\u001a\u00020\u0016X\u0090\u0004¢\u0006\b\n��\u001a\u0004\b\u0017\u0010\u0018R\u001c\u0010\u0019\u001a\u0004\u0018\u00010\u001aX\u0080\u000e¢\u0006\u000e\n��\u001a\u0004\b\u001b\u0010\u001c\"\u0004\b\u001d\u0010\u001eR\"\u0010\u001f\u001a\n\u0012\u0004\u0012\u00020\u001a\u0018\u00010\u0004X\u0080\u000e¢\u0006\u000e\n��\u001a\u0004\b 
\u0010\u0012\"\u0004\b!\u0010\u0014R\u0014\u0010\"\u001a\u00020#X\u0090\u0004¢\u0006\b\n��\u001a\u0004\b$\u0010%R\u001d\u0010\u0003\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00020\u00050\u0004¢\u0006\b\n��\u001a\u0004\b&\u0010\u0012R\u0017\u0010'\u001a\b\u0012\u0004\u0012\u00020\u00020\u0005¢\u0006\b\n��\u001a\u0004\b(\u0010)R\u001a\u0010*\u001a\b\u0012\u0004\u0012\u00020\u00020\u0005X\u0080\u0004¢\u0006\b\n��\u001a\u0004\b+\u0010)R\u001d\u0010,\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00020\u00050\u0004¢\u0006\b\n��\u001a\u0004\b-\u0010\u0012R\u0014\u0010\u0006\u001a\u00020\u0007X\u0096\u0004¢\u0006\b\n��\u001a\u0004\b.\u0010/R\u001a\u00100\u001a\b\u0012\u0004\u0012\u00020\u00020\u0005X\u0080\u0004¢\u0006\b\n��\u001a\u0004\b1\u0010)R\u0016\u00102\u001a\u0004\u0018\u000103X\u0090\u0004¢\u0006\b\n��\u001a\u0004\b4\u00105R\u001a\u00106\u001a\b\u0012\u0004\u0012\u00020\u00020\u0005X\u0080\u0004¢\u0006\b\n��\u001a\u0004\b7\u0010)¨\u00069"}, d2 = {"Lcom/kotlinnlp/simplednn/core/layers/models/attention/scaleddot/ScaledDotAttentionLayer;", "Lcom/kotlinnlp/simplednn/core/layers/Layer;", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "inputArrays", "", "Lcom/kotlinnlp/simplednn/core/arrays/AugmentedArray;", "params", "Lcom/kotlinnlp/simplednn/core/layers/models/attention/scaleddot/ScaledDotAttentionLayerParameters;", "attentionDropout", "", "(Ljava/util/List;Lcom/kotlinnlp/simplednn/core/layers/models/attention/scaleddot/ScaledDotAttentionLayerParameters;D)V", "attention", "getAttention$simplednn", "()Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "setAttention$simplednn", "(Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;)V", "attentionAct", "getAttentionAct$simplednn", "()Ljava/util/List;", "setAttentionAct$simplednn", "(Ljava/util/List;)V", "backwardHelper", "Lcom/kotlinnlp/simplednn/core/layers/models/attention/scaleddot/ScaledDotAttentionBackwardHelper;", "getBackwardHelper$simplednn", 
"()Lcom/kotlinnlp/simplednn/core/layers/models/attention/scaleddot/ScaledDotAttentionBackwardHelper;", "dropoutMaskFull", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/NDArrayMask;", "getDropoutMaskFull$simplednn", "()Lcom/kotlinnlp/simplednn/simplemath/ndarray/NDArrayMask;", "setDropoutMaskFull$simplednn", "(Lcom/kotlinnlp/simplednn/simplemath/ndarray/NDArrayMask;)V", "dropoutMasks", "getDropoutMasks$simplednn", "setDropoutMasks$simplednn", "forwardHelper", "Lcom/kotlinnlp/simplednn/core/layers/models/attention/scaleddot/ScaledDotAttentionForwardHelper;", "getForwardHelper$simplednn", "()Lcom/kotlinnlp/simplednn/core/layers/models/attention/scaleddot/ScaledDotAttentionForwardHelper;", "getInputArrays", "inputMatrix", "getInputMatrix", "()Lcom/kotlinnlp/simplednn/core/arrays/AugmentedArray;", "keys", "getKeys$simplednn", "outputArrays", "getOutputArrays", "getParams", "()Lcom/kotlinnlp/simplednn/core/layers/models/attention/scaleddot/ScaledDotAttentionLayerParameters;", "queries", "getQueries$simplednn", "relevanceHelper", "Lcom/kotlinnlp/simplednn/core/layers/helpers/RelevanceHelper;", "getRelevanceHelper$simplednn", "()Lcom/kotlinnlp/simplednn/core/layers/helpers/RelevanceHelper;", "values", "getValues$simplednn", "buildDropoutMasks", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/core/layers/models/attention/scaleddot/ScaledDotAttentionLayer.class */
/**
 * Layer applying scaled dot-product attention to a sequence of dense input arrays.
 *
 * <p>Java rendering of the Kotlin {@code ScaledDotAttentionLayer}. The forward and backward
 * passes are implemented by {@link ScaledDotAttentionForwardHelper} and
 * {@link ScaledDotAttentionBackwardHelper}, which receive this layer in their constructors. This
 * class allocates the arrays exposed to them, builds the optional attention-dropout masks and
 * validates the input sequence.
 *
 * <p>Members whose names end in {@code $simplednn} are the JVM names of Kotlin {@code internal}
 * declarations of the {@code simplednn} module; they are not meant for use outside that module.
 */
public final class ScaledDotAttentionLayer extends Layer<DenseNDArray> {

    /** The input arrays of the sequence, in order; all of length {@code params.getInputSize()}. */
    @NotNull
    private final List<AugmentedArray<DenseNDArray>> inputArrays;

    /** The parameters of this layer. */
    @NotNull
    private final ScaledDotAttentionLayerParameters params;

    /** Dropout threshold for attention; masks are built only when this is {@code > 0}. */
    private final double attentionDropout;

    /** Matrix built from the values of {@link #inputArrays}, one row per input array. */
    @NotNull
    private final AugmentedArray<DenseNDArray> inputMatrix;

    /** One output array of size {@code params.getOutputSize()} per input array. */
    @NotNull
    private final List<AugmentedArray<DenseNDArray>> outputArrays;

    /** Array allocated with size {@code inputArrays.size() * params.getAttentionSize()}. */
    @NotNull
    private final AugmentedArray<DenseNDArray> queries;

    /** Array allocated with size {@code inputArrays.size() * params.getAttentionSize()}. */
    @NotNull
    private final AugmentedArray<DenseNDArray> keys;

    /** Array allocated with size {@code inputArrays.size() * params.getOutputSize()}. */
    @NotNull
    private final AugmentedArray<DenseNDArray> values;

    /**
     * One mask per input position, or {@code null} when dropout is disabled. The mask built for
     * position {@code i} holds in {@code dim1} the positions that survived the random draw and in
     * {@code dim2} all zeros.
     */
    @Nullable
    private List<NDArrayMask> dropoutMasks;

    /**
     * All {@link #dropoutMasks} merged into one: {@code dim1} is their concatenated {@code dim1}
     * and {@code dim2} repeats, for each mask, its index in the list. {@code null} when dropout is
     * disabled.
     */
    @Nullable
    private NDArrayMask dropoutMaskFull;

    /**
     * Backing field of a Kotlin {@code lateinit} property; it stays {@code null} until assigned
     * (presumably by the forward helper — TODO confirm). Kept {@code public} because that is how
     * Kotlin exposes {@code lateinit} backing fields.
     */
    @NotNull
    public DenseNDArray attention;

    /** Backing field of a Kotlin {@code lateinit} property; see {@link #attention}. */
    @NotNull
    public List<DenseNDArray> attentionAct;

    @NotNull
    private final ScaledDotAttentionForwardHelper forwardHelper;

    @NotNull
    private final ScaledDotAttentionBackwardHelper backwardHelper;

    /** This layer provides no relevance helper. */
    @Nullable
    private final RelevanceHelper relevanceHelper;

    /**
     * Creates the layer.
     *
     * @param inputArrays the input sequence; must be non-empty and every array must have length
     *     {@code params.getInputSize()}
     * @param params the parameters of this layer
     * @param attentionDropout dropout threshold; when {@code > 0} random dropout masks are built
     * @throws IllegalArgumentException if an argument is null, the sequence is empty, or an input
     *     array has the wrong length
     */
    public ScaledDotAttentionLayer(
            @NotNull List<? extends AugmentedArray<DenseNDArray>> inputArrays,
            @NotNull ScaledDotAttentionLayerParameters params,
            double attentionDropout) {
        // Validation must happen inside the super(...) arguments: otherwise an empty list would
        // fail on get(0) before the intended error message could be raised.
        super(checkInputsAndGetFirst(inputArrays, params), LayerType.Input.Dense,
                new AugmentedArray<DenseNDArray>(inputArrays.size()), params, null, 0.0d);

        this.inputArrays = asInvariantList(inputArrays);
        this.params = params;
        this.attentionDropout = attentionDropout;

        int sequenceLength = inputArrays.size();

        List<DenseNDArray> rows = new ArrayList<>(sequenceLength);
        for (AugmentedArray<DenseNDArray> input : inputArrays) {
            rows.add(input.getValues());
        }
        this.inputMatrix = AugmentedArray.Companion.invoke(DenseNDArrayFactory.INSTANCE.fromRows(rows));

        int outputSize = params.getOutputSize();
        List<AugmentedArray<DenseNDArray>> outputs = new ArrayList<>(sequenceLength);
        for (int i = 0; i < sequenceLength; i++) {
            outputs.add(new AugmentedArray<DenseNDArray>(outputSize));
        }
        this.outputArrays = outputs;

        this.queries = new AugmentedArray<>(sequenceLength * params.getAttentionSize());
        this.keys = new AugmentedArray<>(sequenceLength * params.getAttentionSize());
        this.values = new AugmentedArray<>(sequenceLength * outputSize);

        // buildDropoutMasks() reads inputArrays and attentionDropout, both assigned above.
        this.dropoutMasks = attentionDropout > 0.0d ? buildDropoutMasks() : null;
        this.dropoutMaskFull = this.dropoutMasks != null ? concatenateMasks(this.dropoutMasks) : null;

        // The helpers receive `this`, so they are created after every other field is assigned.
        this.forwardHelper = new ScaledDotAttentionForwardHelper(this);
        this.backwardHelper = new ScaledDotAttentionBackwardHelper(this);
        this.relevanceHelper = null;
    }

    /**
     * Creates the layer with attention dropout disabled ({@code attentionDropout = 0.0}), matching
     * the Kotlin default argument.
     *
     * @param inputArrays the input sequence (see the main constructor)
     * @param params the parameters of this layer
     */
    public ScaledDotAttentionLayer(
            @NotNull List<? extends AugmentedArray<DenseNDArray>> inputArrays,
            @NotNull ScaledDotAttentionLayerParameters params) {
        this(inputArrays, params, 0.0d);
    }

    /**
     * Kotlin default-arguments constructor: bit 2 (value 4) of {@code defaultsMask} set means
     * {@code attentionDropout} was omitted and defaults to {@code 0.0}.
     */
    public ScaledDotAttentionLayer(
            List<? extends AugmentedArray<DenseNDArray>> inputArrays,
            ScaledDotAttentionLayerParameters params,
            double attentionDropout,
            int defaultsMask,
            DefaultConstructorMarker marker) {
        this(inputArrays, params, (defaultsMask & 4) != 0 ? 0.0d : attentionDropout);
    }

    /**
     * Validates the constructor arguments and returns the first input array, which is passed to
     * {@link Layer} as the layer input array.
     *
     * @throws IllegalArgumentException if an argument is null, the sequence is empty or an input
     *     array length differs from {@code params.getInputSize()}
     */
    @NotNull
    private static AugmentedArray<DenseNDArray> checkInputsAndGetFirst(
            List<? extends AugmentedArray<DenseNDArray>> inputArrays,
            ScaledDotAttentionLayerParameters params) {
        Intrinsics.checkParameterIsNotNull(inputArrays, "inputArrays");
        Intrinsics.checkParameterIsNotNull(params, "params");

        if (inputArrays.isEmpty()) {
            throw new IllegalArgumentException("The attention sequence cannot be empty.");
        }

        int inputSize = params.getInputSize();
        for (AugmentedArray<DenseNDArray> input : inputArrays) {
            if (input.getValues().getLength() != inputSize) {
                throw new IllegalArgumentException(
                        String.format("All the input arrays must have the same size (%d).", inputSize));
            }
        }

        return inputArrays.get(0);
    }

    /**
     * Views a covariant list as an invariant one without copying, so the field keeps the identity
     * of the caller's list (as the Kotlin read-only {@code List<out E>} did).
     */
    @SuppressWarnings("unchecked") // this class only reads from the list, never adds to it
    private static List<AugmentedArray<DenseNDArray>> asInvariantList(
            List<? extends AugmentedArray<DenseNDArray>> list) {
        return (List<AugmentedArray<DenseNDArray>>) list;
    }

    /**
     * Builds one dropout mask per input position. For each position, every position of the
     * sequence is kept when a fresh {@link Math#random()} draw is {@code >= attentionDropout}.
     *
     * @return a mutable list of {@code inputArrays.size()} masks, in position order
     */
    @NotNull
    private List<NDArrayMask> buildDropoutMasks() {
        int sequenceLength = this.inputArrays.size();
        List<NDArrayMask> masks = new ArrayList<>(sequenceLength);

        for (int row = 0; row < sequenceLength; row++) {
            int[] candidates = new int[sequenceLength];
            int kept = 0;
            // One random draw per position, in ascending order.
            for (int col = 0; col < sequenceLength; col++) {
                if (Math.random() >= this.attentionDropout) {
                    candidates[kept++] = col;
                }
            }
            masks.add(new NDArrayMask(Arrays.copyOf(candidates, kept), new int[kept]));
        }

        return masks;
    }

    /**
     * Merges the given masks into one. Its {@code dim1} is the concatenation of their {@code dim1},
     * and its {@code dim2} repeats each mask's list index {@code mask.getSize()} times.
     */
    @NotNull
    private static NDArrayMask concatenateMasks(@NotNull List<NDArrayMask> masks) {
        int dim1Length = 0;
        int dim2Length = 0;
        for (NDArrayMask mask : masks) {
            dim1Length += mask.getDim1().length;
            dim2Length += mask.getSize();
        }

        int[] dim1 = new int[dim1Length];
        int[] dim2 = new int[dim2Length];
        int dim1Pos = 0;
        int dim2Pos = 0;

        for (int maskIndex = 0; maskIndex < masks.size(); maskIndex++) {
            NDArrayMask mask = masks.get(maskIndex);

            int[] maskDim1 = mask.getDim1();
            System.arraycopy(maskDim1, 0, dim1, dim1Pos, maskDim1.length);
            dim1Pos += maskDim1.length;

            int size = mask.getSize();
            Arrays.fill(dim2, dim2Pos, dim2Pos + size, maskIndex);
            dim2Pos += size;
        }

        return new NDArrayMask(dim1, dim2);
    }

    /** @return the input arrays given at construction, in order */
    @NotNull
    public final List<AugmentedArray<DenseNDArray>> getInputArrays() {
        return this.inputArrays;
    }

    /** @return the parameters of this layer */
    @Override
    @NotNull
    public ScaledDotAttentionLayerParameters getParams() {
        return this.params;
    }

    /** @return the matrix whose rows are the values of the input arrays */
    @NotNull
    public final AugmentedArray<DenseNDArray> getInputMatrix() {
        return this.inputMatrix;
    }

    /** @return one output array per input array */
    @NotNull
    public final List<AugmentedArray<DenseNDArray>> getOutputArrays() {
        return this.outputArrays;
    }

    @NotNull
    public final AugmentedArray<DenseNDArray> getQueries$simplednn() {
        return this.queries;
    }

    @NotNull
    public final AugmentedArray<DenseNDArray> getKeys$simplednn() {
        return this.keys;
    }

    @NotNull
    public final AugmentedArray<DenseNDArray> getValues$simplednn() {
        return this.values;
    }

    @Nullable
    public final List<NDArrayMask> getDropoutMasks$simplednn() {
        return this.dropoutMasks;
    }

    public final void setDropoutMasks$simplednn(@Nullable List<NDArrayMask> list) {
        this.dropoutMasks = list;
    }

    @Nullable
    public final NDArrayMask getDropoutMaskFull$simplednn() {
        return this.dropoutMaskFull;
    }

    public final void setDropoutMaskFull$simplednn(@Nullable NDArrayMask nDArrayMask) {
        this.dropoutMaskFull = nDArrayMask;
    }

    /**
     * @return the attention array
     * @throws kotlin.UninitializedPropertyAccessException if it has not been assigned yet
     */
    @NotNull
    public final DenseNDArray getAttention$simplednn() {
        DenseNDArray current = this.attention;
        if (current == null) {
            Intrinsics.throwUninitializedPropertyAccessException("attention");
        }
        return current;
    }

    public final void setAttention$simplednn(@NotNull DenseNDArray denseNDArray) {
        Intrinsics.checkParameterIsNotNull(denseNDArray, "<set-?>");
        this.attention = denseNDArray;
    }

    /**
     * @return the attention activation arrays
     * @throws kotlin.UninitializedPropertyAccessException if they have not been assigned yet
     */
    @NotNull
    public final List<DenseNDArray> getAttentionAct$simplednn() {
        List<DenseNDArray> current = this.attentionAct;
        if (current == null) {
            Intrinsics.throwUninitializedPropertyAccessException("attentionAct");
        }
        return current;
    }

    public final void setAttentionAct$simplednn(@NotNull List<DenseNDArray> list) {
        Intrinsics.checkParameterIsNotNull(list, "<set-?>");
        this.attentionAct = list;
    }

    /** The forward helper, with the covariant return type of the original Kotlin declaration. */
    @Override
    @NotNull
    public ScaledDotAttentionForwardHelper getForwardHelper$simplednn() {
        return this.forwardHelper;
    }

    /** The backward helper, with the covariant return type of the original Kotlin declaration. */
    @Override
    @NotNull
    public ScaledDotAttentionBackwardHelper getBackwardHelper$simplednn() {
        return this.backwardHelper;
    }

    /**
     * Decompiler-generated alias of {@link #getForwardHelper$simplednn()}, kept for source
     * compatibility.
     *
     * @deprecated use {@link #getForwardHelper$simplednn()}
     */
    @Deprecated
    @NotNull
    public ForwardHelper<DenseNDArray> getForwardHelper$simplednn2() {
        return this.forwardHelper;
    }

    /**
     * Decompiler-generated alias of {@link #getBackwardHelper$simplednn()}, kept for source
     * compatibility.
     *
     * @deprecated use {@link #getBackwardHelper$simplednn()}
     */
    @Deprecated
    @NotNull
    public BackwardHelper<DenseNDArray> getBackwardHelper$simplednn2() {
        return this.backwardHelper;
    }

    /** @return always {@code null}: this layer does not support relevance propagation */
    @Override
    @Nullable
    public RelevanceHelper getRelevanceHelper$simplednn() {
        return this.relevanceHelper;
    }
}
