package com.kotlinnlp.simplednn.core.attention;

import com.kotlinnlp.simplednn.core.arrays.AugmentedArray;
import com.kotlinnlp.simplednn.core.functionalities.activations.Softmax;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArrayFactory;
import com.kotlinnlp.utils.ItemsPool;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import kotlin.Metadata;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;

/* compiled from: AttentionMechanism.kt */
@Metadata(mv = {1, 1, 10}, bv = {1, 0, 2}, k = 1, d1 = {"��2\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010 \n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\b\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u000e\n\u0002\u0010\u0002\n\u0002\b\u0005\b\u0016\u0018��2\u00020\u0001B%\u0012\f\u0010\u0002\u001a\b\u0012\u0004\u0012\u00020\u00040\u0003\u0012\u0006\u0010\u0005\u001a\u00020\u0006\u0012\b\b\u0002\u0010\u0007\u001a\u00020\b¢\u0006\u0002\u0010\tJ\u0016\u0010\u0019\u001a\u00020\u001a2\u0006\u0010\u001b\u001a\u00020\u00062\u0006\u0010\u001c\u001a\u00020\u0004J\u0006\u0010\u001d\u001a\u00020\u0004J\f\u0010\u001e\u001a\b\u0012\u0004\u0012\u00020\u00040\u0003R\u0017\u0010\n\u001a\b\u0012\u0004\u0012\u00020\u00040\u000b¢\u0006\b\n��\u001a\u0004\b\f\u0010\rR\u0017\u0010\u0002\u001a\b\u0012\u0004\u0012\u00020\u00040\u0003¢\u0006\b\n��\u001a\u0004\b\u000e\u0010\u000fR\u0014\u0010\u0007\u001a\u00020\bX\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\u0010\u0010\u0011R\u001a\u0010\u0012\u001a\u00020\u0004X\u0086.¢\u0006\u000e\n��\u001a\u0004\b\u0013\u0010\u0014\"\u0004\b\u0015\u0010\u0016R\u0011\u0010\u0005\u001a\u00020\u0006¢\u0006\b\n��\u001a\u0004\b\u0017\u0010\u0018¨\u0006\u001f"}, d2 = {"Lcom/kotlinnlp/simplednn/core/attention/AttentionMechanism;", "Lcom/kotlinnlp/utils/ItemsPool$IDItem;", "attentionSequence", "", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "params", "Lcom/kotlinnlp/simplednn/core/attention/AttentionParameters;", "id", "", "(Ljava/util/List;Lcom/kotlinnlp/simplednn/core/attention/AttentionParameters;I)V", "attentionMatrix", "Lcom/kotlinnlp/simplednn/core/arrays/AugmentedArray;", "getAttentionMatrix", "()Lcom/kotlinnlp/simplednn/core/arrays/AugmentedArray;", "getAttentionSequence", "()Ljava/util/List;", "getId", "()I", "importanceScore", "getImportanceScore", "()Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "setImportanceScore", 
"(Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;)V", "getParams", "()Lcom/kotlinnlp/simplednn/core/attention/AttentionParameters;", "backwardImportanceScore", "", "paramsErrors", "importanceScoreErrors", "forwardImportanceScore", "getAttentionErrors", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/core/attention/AttentionMechanism.class */
public class AttentionMechanism implements ItemsPool.IDItem {

    /*
     * Attention mechanism over a sequence of arrays.
     *
     * The arrays of the attention sequence are stacked (in order) as the rows of an attention
     * matrix. The forward step scores each row against the context vector of the parameters and
     * normalizes the scores with a softmax. The backward step propagates the errors of those scores
     * to the context vector (into a given errors container) and to the attention matrix.
     */

    /**
     * Matrix whose i-th row holds the values of the i-th array of the attention sequence.
     * Its errors are assigned by {@link #backwardImportanceScore}.
     */
    @NotNull
    private final AugmentedArray<DenseNDArray> attentionMatrix;

    /**
     * Importance score set by {@link #forwardImportanceScore()} or {@link #setImportanceScore}.
     * It is null until one of them is called; {@link #getImportanceScore()} throws in that case.
     * Kept public because it is the backing field of a Kotlin {@code lateinit} property.
     */
    @NotNull
    public DenseNDArray importanceScore;

    /** The sequence of arrays to attend over, as given to the constructor. */
    @NotNull
    private final List<DenseNDArray> attentionSequence;

    /** The parameters providing the context vector and the expected attention size. */
    @NotNull
    private final AttentionParameters params;

    /** Identifier of this item within an {@code ItemsPool}. */
    private final int id;

    /**
     * @return the attention matrix built from the attention sequence
     */
    @NotNull
    public final AugmentedArray<DenseNDArray> getAttentionMatrix() {
        return this.attentionMatrix;
    }

    /**
     * @return the importance score computed by the last forward step (or explicitly set)
     * @throws kotlin.UninitializedPropertyAccessException if the score has not been set yet
     */
    @NotNull
    public final DenseNDArray getImportanceScore() {
        DenseNDArray score = this.importanceScore;
        if (score == null) {
            Intrinsics.throwUninitializedPropertyAccessException("importanceScore");
        }
        return score;
    }

    /**
     * Overrides the stored importance score.
     *
     * @param denseNDArray the new importance score (must not be null)
     */
    public final void setImportanceScore(@NotNull DenseNDArray denseNDArray) {
        Intrinsics.checkParameterIsNotNull(denseNDArray, "<set-?>");
        this.importanceScore = denseNDArray;
    }

    /**
     * Forward step: computes {@code softmax(attentionMatrix . contextVector)}, stores it as the
     * importance score and returns it.
     *
     * @return the newly computed importance score
     */
    @NotNull
    public final DenseNDArray forwardImportanceScore() {
        DenseNDArray rawScores =
            this.attentionMatrix.getValues().dot((NDArray<?>) this.params.getContextVector().getValues());
        DenseNDArray score = new Softmax().f(rawScores);
        this.importanceScore = score;
        return score;
    }

    /**
     * Backward step: propagates the errors of the importance score.
     *
     * <p>First the score errors are multiplied by the softmax derivative evaluated at the stored
     * importance score. Then the context-vector errors (transposed product of those gradients with
     * the attention matrix) are written into {@code paramsErrors}. Finally the attention matrix
     * errors are assigned from the same gradients and the transposed context vector.
     *
     * @param paramsErrors container whose context-vector values receive the computed errors
     * @param importanceScoreErrors errors of the importance score
     * @throws kotlin.UninitializedPropertyAccessException if no importance score has been set yet
     *     (e.g. {@link #forwardImportanceScore()} was never called)
     */
    public final void backwardImportanceScore(
            @NotNull AttentionParameters paramsErrors, @NotNull DenseNDArray importanceScoreErrors) {
        Intrinsics.checkParameterIsNotNull(paramsErrors, "paramsErrors");
        Intrinsics.checkParameterIsNotNull(importanceScoreErrors, "importanceScoreErrors");

        DenseNDArray contextVector = this.params.getContextVector().getValues();

        // Gradients of the pre-softmax scores.
        DenseNDArray scoreGradients =
            new Softmax().df(getImportanceScore()).dot((NDArray<?>) importanceScoreErrors);

        // d(contextVector) = (gradients^T . attentionMatrix)^T
        DenseNDArray contextVectorErrors =
            scoreGradients.getT().dot((NDArray<?>) this.attentionMatrix.getValues()).getT();
        paramsErrors.getContextVector().getValues().assignValues((NDArray<?>) contextVectorErrors);

        // d(attentionMatrix) = gradients . contextVector^T
        this.attentionMatrix.assignErrorsByDot(scoreGradients, contextVector.getT());
    }

    /**
     * Splits the attention matrix errors back into one array per element of the attention
     * sequence (the transpose of each row), in sequence order.
     *
     * @return a new list with one errors array per row of the attention matrix
     */
    @NotNull
    public final List<DenseNDArray> getAttentionErrors() {
        int rowsCount = this.attentionMatrix.getValues().getShape().getDim1();
        List<DenseNDArray> errors = new ArrayList<>(rowsCount);
        for (int row = 0; row < rowsCount; row++) {
            errors.add(this.attentionMatrix.getErrors().getRow(row).getT());
        }
        return errors;
    }

    /**
     * @return the attention sequence given to the constructor
     */
    @NotNull
    public final List<DenseNDArray> getAttentionSequence() {
        return this.attentionSequence;
    }

    /**
     * @return the attention parameters given to the constructor
     */
    @NotNull
    public final AttentionParameters getParams() {
        return this.params;
    }

    /**
     * @return the identifier of this item
     */
    public int getId() {
        return this.id;
    }

    /**
     * Creates an attention mechanism over the given sequence.
     *
     * @param attentionSequence the arrays to attend over; must be non-empty and every array must
     *     have length {@code params.getAttentionSize()}
     * @param params the attention parameters
     * @param id the item identifier
     * @throws IllegalArgumentException if the sequence is empty or an array has the wrong length
     */
    public AttentionMechanism(
            @NotNull List<DenseNDArray> attentionSequence, @NotNull AttentionParameters params, int id) {
        Intrinsics.checkParameterIsNotNull(attentionSequence, "attentionSequence");
        Intrinsics.checkParameterIsNotNull(params, "params");
        this.attentionSequence = attentionSequence;
        this.params = params;
        this.id = id;

        // Validate before building the matrix, so that invalid input always yields the descriptive
        // IllegalArgumentException instead of whatever the matrix factory throws on it.
        if (attentionSequence.isEmpty()) {
            throw new IllegalArgumentException("The attention sequence cannot be empty.");
        }
        int expectedSize = params.getAttentionSize();
        for (DenseNDArray element : attentionSequence) {
            if (element.getLength() != expectedSize) {
                throw new IllegalArgumentException(
                    String.format("The attention arrays must have the expected size (%d).", expectedSize));
            }
        }

        List<double[]> rows = new ArrayList<>(attentionSequence.size());
        for (DenseNDArray element : attentionSequence) {
            rows.add(element.toDoubleArray());
        }
        this.attentionMatrix = AugmentedArray.Companion.invoke(DenseNDArrayFactory.INSTANCE.arrayOf(rows));
    }

    /**
     * Synthetic constructor generated by Kotlin for the default value of {@code id} (0).
     * Bit 2 (value 4) of {@code mask} set means "use the default id".
     */
    public /* synthetic */ AttentionMechanism(
            List<DenseNDArray> attentionSequence, AttentionParameters params, int id, int mask,
            DefaultConstructorMarker marker) {
        this(attentionSequence, params, (mask & 4) != 0 ? 0 : id);
    }
}
