package com.kotlinnlp.simplednn.core.layers;

import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.sparse.SparseNDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.sparsebinary.SparseBinaryNDArray;
import java.util.ArrayList;
import java.util.Arrays;
import kotlin.Metadata;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: ForwardHelper.kt */
@Metadata(mv = {1, 1, 8}, bv = {1, 0, 2}, k = 1, d1 = {"��8\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0010\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0007\n\u0002\u0018\u0002\n\u0002\b\u0005\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\b&\u0018��*\u000e\b��\u0010\u0001*\b\u0012\u0004\u0012\u0002H\u00010\u00022\u00020\u0003B\u0013\u0012\f\u0010\u0004\u001a\b\u0012\u0004\u0012\u00028��0\u0005¢\u0006\u0002\u0010\u0006J8\u0010\t\u001a\u00020\n2\u0006\u0010\u000b\u001a\u00020\f2\u0006\u0010\r\u001a\u00020\f2\u0006\u0010\u000e\u001a\u00020\f2\u0006\u0010\u000f\u001a\u00020\f2\u0006\u0010\u0010\u001a\u00020\f2\u0006\u0010\u0011\u001a\u00020\fH\u0004J\b\u0010\u0012\u001a\u00020\nH&J\u0014\u0010\u0012\u001a\u00020\n2\n\u0010\u0013\u001a\u0006\u0012\u0002\b\u00030\u0014H&J<\u0010\u0015\u001a\u00020\n2\n\u0010\u000b\u001a\u0006\u0012\u0002\b\u00030\u00022\n\u0010\u0016\u001a\u0006\u0012\u0002\b\u00030\u00022\u0006\u0010\u000f\u001a\u00020\f2\u0006\u0010\u0017\u001a\u00020\f2\n\b\u0002\u0010\u0011\u001a\u0004\u0018\u00010\fH\u0004J4\u0010\u0018\u001a\u00020\n2\u0006\u0010\u000b\u001a\u00020\f2\u0006\u0010\u0016\u001a\u00020\f2\u0006\u0010\u000f\u001a\u00020\f2\u0006\u0010\u0017\u001a\u00020\f2\n\b\u0002\u0010\u0011\u001a\u0004\u0018\u00010\fH\u0004J4\u0010\u0019\u001a\u00020\n2\u0006\u0010\u000b\u001a\u00020\u001a2\u0006\u0010\u0016\u001a\u00020\u001b2\u0006\u0010\u000f\u001a\u00020\f2\u0006\u0010\u0017\u001a\u00020\f2\n\b\u0002\u0010\u0011\u001a\u0004\u0018\u00010\fH\u0002R\u001a\u0010\u0004\u001a\b\u0012\u0004\u0012\u00028��0\u0005X\u0094\u0004¢\u0006\b\n��\u001a\u0004\b\u0007\u0010\b¨\u0006\u001c"}, d2 = {"Lcom/kotlinnlp/simplednn/core/layers/ForwardHelper;", "InputNDArrayType", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/NDArray;", "", "layer", "Lcom/kotlinnlp/simplednn/core/layers/LayerStructure;", "(Lcom/kotlinnlp/simplednn/core/layers/LayerStructure;)V", "getLayer", 
"()Lcom/kotlinnlp/simplednn/core/layers/LayerStructure;", "addRecurrentContribution", "", "contributions", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "yPrev", "yRec", "y", "wRec", "b", "forward", "layerContributions", "Lcom/kotlinnlp/simplednn/core/layers/LayerParameters;", "forwardArray", "x", "w", "forwardDenseArray", "forwardSparseBinaryArray", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/sparse/SparseNDArray;", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/sparsebinary/SparseBinaryNDArray;", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/core/layers/ForwardHelper.class */
public abstract class ForwardHelper<InputNDArrayType extends NDArray<InputNDArrayType>> {

    /** The layer structure this helper operates on, exposed to subclasses via {@link #getLayer()}. */
    @NotNull
    private final LayerStructure<InputNDArrayType> layer;

    /**
     * Performs the forward step of the layer. Implemented by subclasses.
     */
    public abstract void forward();

    /**
     * Performs the forward step of the layer using the given parameters structure. Implemented by
     * subclasses.
     *
     * @param layerParameters a parameters structure; the Kotlin metadata of this class names the
     *     argument {@code layerContributions}, so it presumably receives the contributions computed
     *     during the forward step (TODO confirm against the implementations)
     */
    public abstract void forward(@NotNull LayerParameters<?> layerParameters);

    /**
     * Computes the output {@code y} from the input {@code x}, the weights {@code w} and an optional
     * bias, choosing the dense or sparse-binary routine according to the runtime type of {@code x}.
     * The contribution of every input element to every output element is written to
     * {@code contributions}.
     *
     * @param contributions where the contributions are stored; it is cast to {@link DenseNDArray}
     *     when {@code x} is dense and to {@link SparseNDArray} when {@code x} is sparse-binary, so a
     *     mismatching type fails with a {@link ClassCastException}
     * @param x the input array, either a {@link DenseNDArray} or a {@link SparseBinaryNDArray}
     * @param y the output array, overwritten with the result
     * @param w the weights matrix, indexed as {@code w[outputIndex, inputIndex]}
     * @param denseNDArray the optional bias (one value per output element), or {@code null}
     * @throws IllegalArgumentException if {@code x} is neither dense nor sparse-binary
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public final void forwardArray(@NotNull NDArray<?> contributions, @NotNull NDArray<?> x, @NotNull DenseNDArray y, @NotNull DenseNDArray w, @Nullable DenseNDArray denseNDArray) {
        Intrinsics.checkParameterIsNotNull(contributions, "contributions");
        Intrinsics.checkParameterIsNotNull(x, "x");
        Intrinsics.checkParameterIsNotNull(y, "y");
        Intrinsics.checkParameterIsNotNull(w, "w");
        if (x instanceof DenseNDArray) {
            forwardDenseArray((DenseNDArray) contributions, (DenseNDArray) x, y, w, denseNDArray);
        } else if (x instanceof SparseBinaryNDArray) {
            forwardSparseBinaryArray((SparseNDArray) contributions, (SparseBinaryNDArray) x, y, w, denseNDArray);
        } else {
            throw new IllegalArgumentException(String.format("Invalid input type '%s'", x.getClass().getName()));
        }
    }

    /**
     * Kotlin default-argument bridge for {@link #forwardArray}.
     *
     * <p>When bit 4 (value 16) of {@code i} is set, the bias argument was omitted by the caller and
     * is replaced with {@code null}. A non-null {@code obj} marks an unsupported super call.
     */
    public static /* bridge */ /* synthetic */ void forwardArray$default(ForwardHelper forwardHelper, NDArray nDArray, NDArray nDArray2, DenseNDArray denseNDArray, DenseNDArray denseNDArray2, DenseNDArray denseNDArray3, int i, Object obj) {
        if (obj != null) {
            throw new UnsupportedOperationException("Super calls with default arguments not supported in this target, function: forwardArray");
        }
        if ((i & 16) != 0) {
            denseNDArray3 = null;
        }
        forwardHelper.forwardArray(nDArray, nDArray2, denseNDArray, denseNDArray2, denseNDArray3);
    }

    /**
     * Dense forward: {@code y[i] = sum_j (w[i, j] * x[j] + b[i] / x.length)} for every row
     * {@code i} of {@code w}.
     *
     * <p>Each term of the sum is also stored as {@code contributions[i, j]}, so every row of
     * {@code contributions} sums exactly to {@code y[i]}. The bias of each output is split evenly
     * across the input elements, so the per-row bias shares add up to {@code b[i]} when {@code w}
     * has as many columns as {@code x} has elements.
     *
     * @param contributions matrix receiving {@code contributions[i, j]}
     * @param x the dense input array
     * @param y the output array; its first {@code w.getRows()} elements are overwritten
     * @param w the weights matrix
     * @param denseNDArray the optional bias, or {@code null}
     */
    protected final void forwardDenseArray(@NotNull DenseNDArray contributions, @NotNull DenseNDArray x, @NotNull DenseNDArray y, @NotNull DenseNDArray w, @Nullable DenseNDArray denseNDArray) {
        Intrinsics.checkParameterIsNotNull(contributions, "contributions");
        Intrinsics.checkParameterIsNotNull(x, "x");
        Intrinsics.checkParameterIsNotNull(y, "y");
        Intrinsics.checkParameterIsNotNull(w, "w");
        int xLength = x.getLength();
        int rows = w.getRows();
        int columns = w.getColumns();
        for (int i = 0; i < rows; i++) {
            y.set(i, Double.valueOf(0.0d));
            for (int j = 0; j < columns; j++) {
                double contrib = w.get(i, j).doubleValue() * x.get(j).doubleValue();
                if (denseNDArray != null) {
                    contrib += denseNDArray.get(i).doubleValue() / xLength;
                }
                contributions.set(i, j, Double.valueOf(contrib));
                y.set(i, Double.valueOf(y.get(i).doubleValue() + contrib));
            }
        }
    }

    /**
     * Kotlin default-argument bridge for {@link #forwardDenseArray}.
     *
     * <p>When bit 4 (value 16) of {@code i} is set, the bias argument was omitted by the caller and
     * is replaced with {@code null}. A non-null {@code obj} marks an unsupported super call.
     */
    public static /* bridge */ /* synthetic */ void forwardDenseArray$default(ForwardHelper forwardHelper, DenseNDArray denseNDArray, DenseNDArray denseNDArray2, DenseNDArray denseNDArray3, DenseNDArray denseNDArray4, DenseNDArray denseNDArray5, int i, Object obj) {
        if (obj != null) {
            throw new UnsupportedOperationException("Super calls with default arguments not supported in this target, function: forwardDenseArray");
        }
        if ((i & 16) != 0) {
            denseNDArray5 = null;
        }
        forwardHelper.forwardDenseArray(denseNDArray, denseNDArray2, denseNDArray3, denseNDArray4, denseNDArray5);
    }

    /**
     * Sparse-binary forward: only the active (non-zero) input indices take part.
     *
     * <p>For every output index {@code i} and active input index {@code j}, the contribution is
     * {@code w[i, j] + b[i] / activeCount}. {@code y} is zeroed and then receives the sum of those
     * contributions. The contributions are written into {@code contributions} as a sparse matrix
     * with row {@code i} and column {@code j}.
     *
     * <p>Only the first entry of {@code x.getActiveIndicesByColumn().values()} is read (presumably
     * {@code x} has a single column; TODO confirm).
     *
     * @param contributions sparse matrix receiving the contributions
     * @param x the sparse-binary input array
     * @param y the output array, zeroed and then filled
     * @param w the weights matrix
     * @param b the optional bias, or {@code null}
     */
    private final void forwardSparseBinaryArray(SparseNDArray contributions, SparseBinaryNDArray x, DenseNDArray y, DenseNDArray w, DenseNDArray b) {
        Object firstColumn = CollectionsKt.first(x.getActiveIndicesByColumn().values());
        if (firstColumn == null) {
            Intrinsics.throwNpe();
        }
        ArrayList<?> xActiveIndices = (ArrayList<?>) firstColumn;
        int activeCount = xActiveIndices.size();
        int yLength = y.getLength();
        int contribCount = activeCount * yLength;

        y.zeros();

        // Entry k of the three arrays describes the contribution of input xActiveIndices[k / yLength]
        // to output k % yLength; all three are filled in a single pass.
        Double[] values = new Double[contribCount];
        Integer[] rowIndices = new Integer[contribCount];
        Integer[] columnIndices = new Integer[contribCount];

        for (int k = 0; k < contribCount; k++) {
            int i = k % yLength;
            Object j = xActiveIndices.get(k / yLength);
            Intrinsics.checkExpressionValueIsNotNull(j, "xActiveIndices[k / yLength]");
            double contrib = w.get(i, ((Number) j).intValue()).doubleValue();
            if (b != null) {
                // Divide the bias evenly across the active inputs so it sums to b[i] in y[i].
                contrib += b.get(i).doubleValue() / activeCount;
            }
            y.set(i, Double.valueOf(y.get(i).doubleValue() + contrib));
            values[k] = Double.valueOf(contrib);
            rowIndices[k] = Integer.valueOf(i);
            columnIndices[k] = (Integer) j;
        }

        contributions.assignValues(values, rowIndices, columnIndices);
    }

    /**
     * Kotlin default-argument bridge for {@link #forwardSparseBinaryArray}.
     *
     * <p>When bit 4 (value 16) of {@code i} is set, the bias argument was omitted by the caller and
     * is replaced with {@code null}. A non-null {@code obj} marks an unsupported super call.
     */
    static /* bridge */ /* synthetic */ void forwardSparseBinaryArray$default(ForwardHelper forwardHelper, SparseNDArray sparseNDArray, SparseBinaryNDArray sparseBinaryNDArray, DenseNDArray denseNDArray, DenseNDArray denseNDArray2, DenseNDArray denseNDArray3, int i, Object obj) {
        if (obj != null) {
            throw new UnsupportedOperationException("Super calls with default arguments not supported in this target, function: forwardSparseBinaryArray");
        }
        if ((i & 16) != 0) {
            denseNDArray3 = null;
        }
        forwardHelper.forwardSparseBinaryArray(sparseNDArray, sparseBinaryNDArray, denseNDArray, denseNDArray2, denseNDArray3);
    }

    /**
     * Adds the recurrent contribution to the output.
     *
     * <p>First computes {@code yRec} from the previous output {@code yPrev}, the recurrent weights
     * {@code wRec} and the bias {@code b}, storing the contributions. It then calls
     * {@code y.assignSum(yRec)} (presumably {@code y += yRec} in place).
     *
     * @param contributions matrix receiving the recurrent contributions
     * @param yPrev the output of the previous state (dense, so the dense routine is always used)
     * @param yRec overwritten with the recurrent output
     * @param y the output array the recurrent output is summed into
     * @param wRec the recurrent weights matrix
     * @param b the bias passed to the recurrent computation
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public final void addRecurrentContribution(@NotNull DenseNDArray contributions, @NotNull DenseNDArray yPrev, @NotNull DenseNDArray yRec, @NotNull DenseNDArray y, @NotNull DenseNDArray wRec, @NotNull DenseNDArray b) {
        Intrinsics.checkParameterIsNotNull(contributions, "contributions");
        Intrinsics.checkParameterIsNotNull(yPrev, "yPrev");
        Intrinsics.checkParameterIsNotNull(yRec, "yRec");
        Intrinsics.checkParameterIsNotNull(y, "y");
        Intrinsics.checkParameterIsNotNull(wRec, "wRec");
        Intrinsics.checkParameterIsNotNull(b, "b");
        forwardArray(contributions, yPrev, yRec, wRec, b);
        y.assignSum((NDArray<?>) yRec);
    }

    /**
     * Returns the layer structure this helper operates on.
     *
     * @return the layer given to the constructor
     */
    @NotNull
    protected LayerStructure<InputNDArrayType> getLayer() {
        return this.layer;
    }

    /**
     * Creates a helper bound to the given layer.
     *
     * @param layer the layer structure to work on; must not be {@code null}
     */
    public ForwardHelper(@NotNull LayerStructure<InputNDArrayType> layer) {
        Intrinsics.checkParameterIsNotNull(layer, "layer");
        this.layer = layer;
    }
}
