package com.kotlinnlp.simplednn.core.layers.models.attention;

import com.kotlinnlp.simplednn.core.arrays.AugmentedArray;
import com.kotlinnlp.simplednn.core.functionalities.activations.ActivationFunction;
import com.kotlinnlp.simplednn.core.functionalities.activations.SoftmaxBase;
import com.kotlinnlp.simplednn.core.layers.Layer;
import com.kotlinnlp.simplednn.core.layers.LayerType;
import com.kotlinnlp.simplednn.core.layers.helpers.RelevanceHelper;
import com.kotlinnlp.simplednn.core.layers.models.attention.attentionmechanism.AttentionMechanismLayer;
import com.kotlinnlp.simplednn.core.layers.models.attention.attentionmechanism.AttentionMechanismLayerParameters;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.Shape;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArrayFactory;
import java.util.Iterator;
import java.util.List;
import kotlin.Metadata;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: AttentionLayer.kt */
@Metadata(mv = {1, 1, 15}, bv = {1, 0, 3}, k = 1, d1 = {"��T\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010 \n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0007\n\u0002\u0018\u0002\n\u0002\b\u0005\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\b\u0006\n\u0002\u0018\u0002\n\u0002\b\u0003\b��\u0018��*\u000e\b��\u0010\u0001*\b\u0012\u0004\u0012\u0002H\u00010\u00022\b\u0012\u0004\u0012\u0002H\u00010\u0003BI\u0012\u0012\u0010\u0004\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00028��0\u00060\u0005\u0012\u0006\u0010\u0007\u001a\u00020\b\u0012\u0012\u0010\t\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\n0\u00060\u0005\u0012\u0006\u0010\u000b\u001a\u00020\f\u0012\n\b\u0002\u0010\r\u001a\u0004\u0018\u00010\u000e¢\u0006\u0002\u0010\u000fR\u001d\u0010\t\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\n0\u00060\u0005¢\u0006\b\n��\u001a\u0004\b\u0010\u0010\u0011R\u0017\u0010\u0012\u001a\b\u0012\u0004\u0012\u00020\n0\u00068F¢\u0006\u0006\u001a\u0004\b\u0013\u0010\u0014R\u0014\u0010\u0015\u001a\u00020\u0016X\u0080\u0004¢\u0006\b\n��\u001a\u0004\b\u0017\u0010\u0018R\u0017\u0010\u0019\u001a\b\u0012\u0004\u0012\u00020\n0\u00068F¢\u0006\u0006\u001a\u0004\b\u001a\u0010\u0014R\u001a\u0010\u001b\u001a\b\u0012\u0004\u0012\u00028��0\u001cX\u0090\u0004¢\u0006\b\n��\u001a\u0004\b\u001d\u0010\u001eR\u001a\u0010\u001f\u001a\b\u0012\u0004\u0012\u00028��0 X\u0090\u0004¢\u0006\b\n��\u001a\u0004\b!\u0010\"R\u001d\u0010\u0004\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00028��0\u00060\u0005¢\u0006\b\n��\u001a\u0004\b#\u0010\u0011R\u0014\u0010\u000b\u001a\u00020\fX\u0096\u0004¢\u0006\b\n��\u001a\u0004\b$\u0010%R\u0016\u0010&\u001a\u0004\u0018\u00010'X\u0090\u0004¢\u0006\b\n��\u001a\u0004\b(\u0010)¨\u0006*"}, d2 = {"Lcom/kotlinnlp/simplednn/core/layers/models/attention/AttentionLayer;", "InputNDArrayType", 
"Lcom/kotlinnlp/simplednn/simplemath/ndarray/NDArray;", "Lcom/kotlinnlp/simplednn/core/layers/Layer;", "inputArrays", "", "Lcom/kotlinnlp/simplednn/core/arrays/AugmentedArray;", "inputType", "Lcom/kotlinnlp/simplednn/core/layers/LayerType$Input;", "attentionArrays", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "params", "Lcom/kotlinnlp/simplednn/core/layers/models/attention/attentionmechanism/AttentionMechanismLayerParameters;", "activation", "Lcom/kotlinnlp/simplednn/core/functionalities/activations/ActivationFunction;", "(Ljava/util/List;Lcom/kotlinnlp/simplednn/core/layers/LayerType$Input;Ljava/util/List;Lcom/kotlinnlp/simplednn/core/layers/models/attention/attentionmechanism/AttentionMechanismLayerParameters;Lcom/kotlinnlp/simplednn/core/functionalities/activations/ActivationFunction;)V", "getAttentionArrays", "()Ljava/util/List;", "attentionMatrix", "getAttentionMatrix", "()Lcom/kotlinnlp/simplednn/core/arrays/AugmentedArray;", "attentionMechanism", "Lcom/kotlinnlp/simplednn/core/layers/models/attention/attentionmechanism/AttentionMechanismLayer;", "getAttentionMechanism$simplednn", "()Lcom/kotlinnlp/simplednn/core/layers/models/attention/attentionmechanism/AttentionMechanismLayer;", "attentionScores", "getAttentionScores", "backwardHelper", "Lcom/kotlinnlp/simplednn/core/layers/models/attention/AttentionBackwardHelper;", "getBackwardHelper$simplednn", "()Lcom/kotlinnlp/simplednn/core/layers/models/attention/AttentionBackwardHelper;", "forwardHelper", "Lcom/kotlinnlp/simplednn/core/layers/models/attention/AttentionForwardHelper;", "getForwardHelper$simplednn", "()Lcom/kotlinnlp/simplednn/core/layers/models/attention/AttentionForwardHelper;", "getInputArrays", "getParams", "()Lcom/kotlinnlp/simplednn/core/layers/models/attention/attentionmechanism/AttentionMechanismLayerParameters;", "relevanceHelper", "Lcom/kotlinnlp/simplednn/core/layers/helpers/RelevanceHelper;", "getRelevanceHelper$simplednn", 
"()Lcom/kotlinnlp/simplednn/core/layers/helpers/RelevanceHelper;", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/core/layers/models/attention/AttentionLayer.class */
/*
 * Layer that pairs a list of input arrays with a list of attention arrays and owns an
 * AttentionMechanismLayer built over the attention arrays. The mechanism's output array is
 * exposed as the attention scores and its attention matrix is exposed as-is.
 *
 * This source was decompiled from Kotlin (AttentionLayer.kt); the '$simplednn' suffixes come
 * from Kotlin 'internal' visibility.
 */
public final class AttentionLayer<InputNDArrayType extends NDArray<InputNDArrayType>> extends Layer<InputNDArrayType> {

    /** Forward helper bound to this layer; created in the constructor. */
    @NotNull
    private final AttentionForwardHelper<InputNDArrayType> forwardHelper;

    /** Backward helper bound to this layer; created in the constructor. */
    @NotNull
    private final AttentionBackwardHelper<InputNDArrayType> backwardHelper;

    /** Nothing in this class ever assigns a relevance helper, so it is always {@code null}. */
    @Nullable
    private final RelevanceHelper relevanceHelper = null;

    /** Inner attention mechanism, built over {@link #attentionArrays} with a {@code SoftmaxBase} activation. */
    @NotNull
    private final AttentionMechanismLayer attentionMechanism;

    /** The input arrays received at construction (non-empty, all of the same size). */
    @NotNull
    private final List<AugmentedArray<InputNDArrayType>> inputArrays;

    /** The attention arrays received at construction (as many as the input arrays). */
    @NotNull
    private final List<AugmentedArray<DenseNDArray>> attentionArrays;

    /** The parameters shared by this layer and its inner attention mechanism. */
    @NotNull
    private final AttentionMechanismLayerParameters params;

    /**
     * Returns the forward helper of this layer.
     *
     * NOTE(review): the decompiler renamed this accessor from {@code getForwardHelper$simplednn};
     * confirm that the superclass declares a method with this exact name, otherwise the
     * {@code @Override} annotation will not compile.
     *
     * @return the forward helper created in the constructor
     */
    @Override // com.kotlinnlp.simplednn.core.layers.Layer
    @NotNull
    /* renamed from: getForwardHelper$simplednn */
    public AttentionForwardHelper<InputNDArrayType> getForwardHelper$simplednn2() {
        return this.forwardHelper;
    }

    /**
     * Returns the backward helper of this layer.
     *
     * NOTE(review): the decompiler renamed this accessor from {@code getBackwardHelper$simplednn};
     * confirm that the superclass declares a method with this exact name, otherwise the
     * {@code @Override} annotation will not compile.
     *
     * @return the backward helper created in the constructor
     */
    @Override // com.kotlinnlp.simplednn.core.layers.Layer
    @NotNull
    /* renamed from: getBackwardHelper$simplednn */
    public AttentionBackwardHelper<InputNDArrayType> getBackwardHelper$simplednn2() {
        return this.backwardHelper;
    }

    /**
     * Returns the relevance helper of this layer.
     *
     * @return always {@code null}: this layer never sets a relevance helper
     */
    @Override // com.kotlinnlp.simplednn.core.layers.Layer
    @Nullable
    public RelevanceHelper getRelevanceHelper$simplednn() {
        return this.relevanceHelper;
    }

    /**
     * Returns the inner attention mechanism.
     *
     * @return the {@code AttentionMechanismLayer} created in the constructor
     */
    @NotNull
    public final AttentionMechanismLayer getAttentionMechanism$simplednn() {
        return this.attentionMechanism;
    }

    /**
     * Returns the attention scores.
     *
     * @return the output array of the inner attention mechanism
     */
    @NotNull
    public final AugmentedArray<DenseNDArray> getAttentionScores() {
        return this.attentionMechanism.getOutputArray();
    }

    /**
     * Returns the attention matrix.
     *
     * @return the attention matrix held by the inner attention mechanism
     */
    @NotNull
    public final AugmentedArray<DenseNDArray> getAttentionMatrix() {
        return this.attentionMechanism.getAttentionMatrix$simplednn();
    }

    /**
     * Returns the input arrays.
     *
     * @return the same list instance that was passed to the constructor
     */
    @NotNull
    public final List<AugmentedArray<InputNDArrayType>> getInputArrays() {
        return this.inputArrays;
    }

    /**
     * Returns the attention arrays.
     *
     * @return the same list instance that was passed to the constructor
     */
    @NotNull
    public final List<AugmentedArray<DenseNDArray>> getAttentionArrays() {
        return this.attentionArrays;
    }

    /**
     * Returns the layer parameters.
     *
     * @return the parameters passed to the constructor
     */
    @Override // com.kotlinnlp.simplednn.core.layers.Layer
    @NotNull
    public AttentionMechanismLayerParameters getParams() {
        return this.params;
    }

    /**
     * Builds the layer and its inner attention mechanism.
     *
     * <p>The superclass receives a new input array sized by {@code params.getInputSize()} and a
     * zero-filled dense output array sized like the first input array. All arguments that the
     * {@code super(...)} call dereferences are validated first, through static helpers invoked
     * from its argument expressions, so an empty input list is reported with the intended
     * message instead of failing while the first element is read.
     *
     * @param list the input arrays; must be non-null, non-empty and all of the same size
     * @param input the input type, forwarded to the superclass and to the attention mechanism
     * @param list2 the attention arrays; must be non-null, non-empty and as many as {@code list}
     * @param attentionMechanismLayerParameters the layer parameters; must be non-null
     * @param activationFunction forwarded to the superclass constructor; may be {@code null}
     * @throws IllegalArgumentException if a list is empty, the two lists differ in length, or
     *     the input arrays do not all have the same size; null arguments are rejected by
     *     Kotlin's {@code Intrinsics.checkParameterIsNotNull}
     */
    public AttentionLayer(@NotNull List<? extends AugmentedArray<InputNDArrayType>> list, @NotNull LayerType.Input input, @NotNull List<? extends AugmentedArray<DenseNDArray>> list2, @NotNull AttentionMechanismLayerParameters attentionMechanismLayerParameters, @Nullable ActivationFunction activationFunction) {
        super(
            // Arguments are evaluated left to right, so the null checks run before anything else.
            new AugmentedArray<InputNDArrayType>(
                requireArguments(list, input, list2, attentionMechanismLayerParameters).getInputSize()),
            input,
            // Shape(dim1, 0, 2, null) is the Kotlin default-argument constructor; only dim1 is supplied.
            AugmentedArray.Companion.invoke(
                DenseNDArrayFactory.INSTANCE.zeros(new Shape(firstInputSize(list), 0, 2, null))),
            attentionMechanismLayerParameters,
            activationFunction,
            0.0d);

        // The Kotlin lists are read-only views, so narrowing the wildcard is safe as long as
        // nobody adds elements through the getters.
        @SuppressWarnings("unchecked")
        final List<AugmentedArray<InputNDArrayType>> inputs = (List<AugmentedArray<InputNDArrayType>>) list;
        @SuppressWarnings("unchecked")
        final List<AugmentedArray<DenseNDArray>> attentions = (List<AugmentedArray<DenseNDArray>>) list2;

        this.inputArrays = inputs;
        this.attentionArrays = attentions;
        this.params = attentionMechanismLayerParameters;
        this.forwardHelper = new AttentionForwardHelper<>(this);
        this.backwardHelper = new AttentionBackwardHelper<>(this);

        final AttentionMechanismLayer mechanism =
            new AttentionMechanismLayer(this.attentionArrays, input, getParams(), new SoftmaxBase());
        // The inner mechanism accumulates its parameter errors into this layer's collector.
        mechanism.setParamsErrorsCollector(getParamsErrorsCollector());
        this.attentionMechanism = mechanism;

        // The non-empty check on the input arrays already ran in firstInputSize(...).
        if (this.attentionArrays.isEmpty()) {
            throw new IllegalArgumentException("The attention array cannot be empty.");
        }
        if (this.inputArrays.size() != this.attentionArrays.size()) {
            throw new IllegalArgumentException("The input array must have the same length of the attention array.");
        }

        final int expectedSize = this.inputArrays.get(0).getSize();
        for (AugmentedArray<InputNDArrayType> array : this.inputArrays) {
            if (array.getSize() != expectedSize) {
                throw new IllegalArgumentException("Failed requirement.");
            }
        }
    }

    /**
     * Synthetic constructor generated for the Kotlin default argument of {@code activation}.
     *
     * @param i bit mask of defaulted arguments; when bit 16 is set, {@code activationFunction}
     *     is ignored and a new {@code SoftmaxBase} is used instead
     * @param defaultConstructorMarker marker that only disambiguates the signature; unused
     */
    public /* synthetic */ AttentionLayer(List list, LayerType.Input input, List list2, AttentionMechanismLayerParameters attentionMechanismLayerParameters, ActivationFunction activationFunction, int i, DefaultConstructorMarker defaultConstructorMarker) {
        this(list, input, list2, attentionMechanismLayerParameters, (i & 16) != 0 ? new SoftmaxBase() : activationFunction);
    }

    /** Runs the Kotlin non-null parameter checks in declaration order and returns {@code params}. */
    private static AttentionMechanismLayerParameters requireArguments(List<?> inputArrays, LayerType.Input inputType, List<?> attentionArrays, AttentionMechanismLayerParameters params) {
        Intrinsics.checkParameterIsNotNull(inputArrays, "inputArrays");
        Intrinsics.checkParameterIsNotNull(inputType, "inputType");
        Intrinsics.checkParameterIsNotNull(attentionArrays, "attentionArrays");
        Intrinsics.checkParameterIsNotNull(params, "params");
        return params;
    }

    /** Returns the size of the first input array, rejecting an empty list with the intended message. */
    private static int firstInputSize(List<? extends AugmentedArray<?>> inputArrays) {
        if (inputArrays.isEmpty()) {
            throw new IllegalArgumentException("The input array cannot be empty.");
        }
        return inputArrays.get(0).getSize();
    }
}
