package com.kotlinnlp.simplednn.core.layers.helpers;

import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArrayFactory;
import com.kotlinnlp.simplednn.simplemath.ndarray.sparse.SparseNDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.sparsebinary.SparseBinaryNDArray;
import java.util.Arrays;
import java.util.List;
import kotlin.Metadata;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;

/* compiled from: RelevanceUtils.kt */
@Metadata(mv = {1, 1, 13}, bv = {1, 0, 3}, k = 1, d1 = {"��6\n\u0002\u0018\u0002\n\u0002\u0010��\n\u0002\b\u0002\n\u0002\u0010\u0006\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0010\b\n\u0002\b\u0002\bÆ\u0002\u0018��2\u00020\u0001B\u0007\b\u0002¢\u0006\u0002\u0010\u0002J2\u0010\u0005\u001a\u0006\u0012\u0002\b\u00030\u00062\n\u0010\u0007\u001a\u0006\u0012\u0002\b\u00030\u00062\u0006\u0010\b\u001a\u00020\t2\u0006\u0010\n\u001a\u00020\t2\n\u0010\u000b\u001a\u0006\u0012\u0002\b\u00030\u0006J&\u0010\f\u001a\u00020\t2\u0006\u0010\u0007\u001a\u00020\t2\u0006\u0010\b\u001a\u00020\t2\u0006\u0010\n\u001a\u00020\t2\u0006\u0010\u000b\u001a\u00020\tJ(\u0010\r\u001a\u00020\u000e2\u0006\u0010\u0007\u001a\u00020\u000f2\u0006\u0010\b\u001a\u00020\t2\u0006\u0010\n\u001a\u00020\t2\u0006\u0010\u000b\u001a\u00020\u000eH\u0002J0\u0010\u0010\u001a\u00020\t2\u0006\u0010\n\u001a\u00020\t2\u0006\u0010\b\u001a\u00020\t2\u0006\u0010\u0011\u001a\u00020\t2\u0006\u0010\u0012\u001a\u00020\t2\b\b\u0002\u0010\u0013\u001a\u00020\u0014J(\u0010\u0015\u001a\u00020\t2\u0006\u0010\n\u001a\u00020\t2\u0006\u0010\b\u001a\u00020\t2\u0006\u0010\u0012\u001a\u00020\t2\b\b\u0002\u0010\u0013\u001a\u00020\u0014R\u000e\u0010\u0003\u001a\u00020\u0004X\u0082T¢\u0006\u0002\n��¨\u0006\u0016"}, d2 = {"Lcom/kotlinnlp/simplednn/core/layers/helpers/RelevanceUtils;", "", "()V", "relevanceEps", "", "calculateRelevanceOfArray", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/NDArray;", "x", "y", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "yRelevance", "contributions", "calculateRelevanceOfDenseArray", "calculateRelevanceOfSparseArray", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/sparse/SparseNDArray;", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/sparsebinary/SparseBinaryNDArray;", "getRelevancePartition1", "yContribute1", "yContribute2", "nPartitions", "", "getRelevancePartition2", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/core/layers/helpers/RelevanceUtils.class */
/**
 * Helpers that redistribute the relevance of a layer output {@code y} back onto its input {@code x},
 * in proportion to each input's contribution to each output.
 *
 * <p>Every denominator is stabilized by adding {@link #relevanceEps} with a sign taken from an array:
 * {@code y} in the array methods, {@code yContribute2} in the partition methods.
 * NOTE(review): the arithmetic matches the epsilon-stabilized rule of Layer-wise Relevance
 * Propagation. This is inferred from the formulas and the class name; confirm against the Kotlin
 * source.
 */
public final class RelevanceUtils {
    /** Magnitude of the stabilizer added to the denominators of the relevance formulas. */
    private static final double relevanceEps = 0.01d;

    /** Singleton instance; the constructor is private. */
    public static final RelevanceUtils INSTANCE = new RelevanceUtils();

    /**
     * Returns {@link #relevanceEps} signed like {@code value}.
     *
     * <p>The result is positive when {@code value >= 0} and negative otherwise, including NaN.
     * Adding the result to {@code value} therefore moves it away from zero.
     */
    private static double signedEps(double value) {
        return value >= 0.0d ? relevanceEps : -relevanceEps;
    }

    /**
     * Computes the relevance of the input {@code x}, dispatching on its concrete type.
     *
     * @param x the layer input; must be a {@link DenseNDArray} or a {@link SparseBinaryNDArray}
     * @param y the layer output
     * @param yRelevance the relevance of each element of {@code y}
     * @param contributions the contribution of each input to each output. Must be a
     *     {@link DenseNDArray} when {@code x} is dense, or a {@link SparseNDArray} when {@code x} is
     *     sparse binary (otherwise a {@link ClassCastException} is raised).
     * @return the relevance of {@code x}: dense for a dense input, sparse for a sparse binary input
     * @throws IllegalArgumentException if {@code x} has any other type
     */
    @NotNull
    public final NDArray<?> calculateRelevanceOfArray(
            @NotNull NDArray<?> x,
            @NotNull DenseNDArray y,
            @NotNull DenseNDArray yRelevance,
            @NotNull NDArray<?> contributions) {
        Intrinsics.checkParameterIsNotNull(x, "x");
        Intrinsics.checkParameterIsNotNull(y, "y");
        Intrinsics.checkParameterIsNotNull(yRelevance, "yRelevance");
        Intrinsics.checkParameterIsNotNull(contributions, "contributions");
        if (x instanceof DenseNDArray) {
            return calculateRelevanceOfDenseArray((DenseNDArray) x, y, yRelevance, (DenseNDArray) contributions);
        }
        if (x instanceof SparseBinaryNDArray) {
            return calculateRelevanceOfSparseArray(
                    (SparseBinaryNDArray) x, y, yRelevance, (SparseNDArray) contributions);
        }
        // IllegalArgumentException is a RuntimeException, so existing catch sites still match.
        throw new IllegalArgumentException(String.format("Invalid input type '%s'", x.getClass().getName()));
    }

    /**
     * Computes the relevance of a dense input.
     *
     * <p>The formula is:
     * {@code R[i] = sum_j yRelevance[j] * (contributions[j, i] + eps_j / n) / (y[j] + eps_j)}
     * where {@code n = x.getLength()} and {@code eps_j} is {@link #relevanceEps} signed like
     * {@code y[j]}. When {@code sum_i contributions[j, i] == y[j]}, the total relevance is
     * conserved.
     *
     * @param x the layer input; only its shape and length are read
     * @param y the layer output
     * @param yRelevance the relevance of each element of {@code y}
     * @param contributions indexed as {@code (output j, input i)}
     * @return a new array with the shape of {@code x} holding the relevance of each input
     */
    @NotNull
    public final DenseNDArray calculateRelevanceOfDenseArray(
            @NotNull DenseNDArray x,
            @NotNull DenseNDArray y,
            @NotNull DenseNDArray yRelevance,
            @NotNull DenseNDArray contributions) {
        Intrinsics.checkParameterIsNotNull(x, "x");
        Intrinsics.checkParameterIsNotNull(y, "y");
        Intrinsics.checkParameterIsNotNull(yRelevance, "yRelevance");
        Intrinsics.checkParameterIsNotNull(contributions, "contributions");
        DenseNDArray relevance = DenseNDArrayFactory.INSTANCE.zeros(x.getShape());
        int inputLength = x.getLength();
        int outputLength = y.getLength();

        // Per-output terms do not depend on the input index, so they are computed only once.
        double[] epsShare = new double[outputLength];
        double[] outputRelevance = new double[outputLength];
        double[] denominator = new double[outputLength];
        for (int j = 0; j < outputLength; j++) {
            double yj = y.get(j).doubleValue();
            double eps = signedEps(yj);
            epsShare[j] = eps / inputLength;
            outputRelevance[j] = yRelevance.get(j).doubleValue();
            denominator[j] = yj + eps;
        }

        for (int i = 0; i < inputLength; i++) {
            // Accumulate in the same output order as before, then store once instead of per term.
            double sum = 0.0d;
            for (int j = 0; j < outputLength; j++) {
                sum = sum + (outputRelevance[j] * (contributions.get(j, i).doubleValue() + epsShare[j])) / denominator[j];
            }
            relevance.set(i, Double.valueOf(sum));
        }
        return relevance;
    }

    /**
     * Computes the relevance of a sparse binary input, restricted to its active indices.
     *
     * <p>Only the active indices of the first entry of {@code x.getActiveIndicesByColumn()} are
     * used. NOTE(review): this presumably assumes {@code x} is a single-column vector; confirm. The
     * map must be non-empty, or {@code first} throws.
     *
     * <p>The values of {@code contributions} are consumed sequentially: one per (active input,
     * output) pair, with the active input as the outer loop and the output as the inner loop. The
     * caller must supply them in that order.
     *
     * <p>The stabilizer is shared among the active inputs only: {@code eps_j / activeCount}.
     */
    private final SparseNDArray calculateRelevanceOfSparseArray(
            SparseBinaryNDArray x, DenseNDArray y, DenseNDArray yRelevance, SparseNDArray contributions) {
        Object firstColumn = CollectionsKt.first(x.getActiveIndicesByColumn().values());
        if (firstColumn == null) {
            Intrinsics.throwNpe();
        }
        @SuppressWarnings("unchecked") // elements are the active indices passed to toIntArray
        List<Integer> activeIndices = (List<Integer>) firstColumn;
        int activeCount = activeIndices.size();
        // Java zero-initializes new arrays, so no explicit fill is needed.
        double[] values = new double[activeCount];
        // NOTE(review): presumably the column index of each entry (all 0); confirm against SparseNDArray.
        int[] columnIndices = new int[activeCount];
        int[] rowIndices = CollectionsKt.toIntArray(activeIndices);
        int outputLength = y.getLength();

        // Per-output terms are independent of the active input, so they are computed only once.
        double[] epsShare = new double[outputLength];
        double[] outputRelevance = new double[outputLength];
        double[] denominator = new double[outputLength];
        for (int j = 0; j < outputLength; j++) {
            double yj = y.get(j).doubleValue();
            double eps = signedEps(yj);
            epsShare[j] = eps / activeCount;
            outputRelevance[j] = yRelevance.get(j).doubleValue();
            denominator[j] = yj + eps;
        }

        int next = 0; // cursor into contributions.getValues()
        for (int a = 0; a < activeCount; a++) {
            double sum = 0.0d;
            for (int j = 0; j < outputLength; j++) {
                double contribution = contributions.getValues()[next++];
                sum = sum + (outputRelevance[j] * (contribution + epsShare[j])) / denominator[j];
            }
            values[a] = sum;
        }
        return SparseNDArray.Companion.invoke(x.getShape(), values, rowIndices, columnIndices);
    }

    /**
     * Returns the share of {@code yRelevance} attributed to the first contribution term.
     *
     * <p>The formula is:
     * {@code yRelevance * (yContribute1 + eps / nPartitions) / (y + eps)}
     * where {@code eps = yContribute2.nonZeroSign() * relevanceEps}.
     *
     * @param nPartitions number of parts the stabilizer is split into; must be positive
     * @throws IllegalArgumentException if {@code nPartitions <= 0}, because dividing by it would
     *     silently produce infinite or NaN relevance
     */
    @NotNull
    public final DenseNDArray getRelevancePartition1(
            @NotNull DenseNDArray yRelevance,
            @NotNull DenseNDArray y,
            @NotNull DenseNDArray yContribute1,
            @NotNull DenseNDArray yContribute2,
            int nPartitions) {
        Intrinsics.checkParameterIsNotNull(yRelevance, "yRelevance");
        Intrinsics.checkParameterIsNotNull(y, "y");
        Intrinsics.checkParameterIsNotNull(yContribute1, "yContribute1");
        Intrinsics.checkParameterIsNotNull(yContribute2, "yContribute2");
        if (nPartitions <= 0) {
            throw new IllegalArgumentException(String.format("nPartitions must be positive, got %d", nPartitions));
        }
        DenseNDArray eps = yContribute2.nonZeroSign().assignProd(relevanceEps);
        return yRelevance.prod((NDArray<?>) yContribute1.sum(eps.div(nPartitions))).assignDiv(y.sum(eps));
    }

    /**
     * Synthetic default-argument entry point for {@link #getRelevancePartition1}.
     *
     * <p>When bit 16 of {@code mask} (the 5th parameter) is set, {@code nPartitions} defaults to 2.
     */
    @NotNull
    public static /* synthetic */ DenseNDArray getRelevancePartition1$default(
            RelevanceUtils self,
            DenseNDArray yRelevance,
            DenseNDArray y,
            DenseNDArray yContribute1,
            DenseNDArray yContribute2,
            int nPartitions,
            int mask,
            Object unused) {
        if ((mask & 16) != 0) {
            nPartitions = 2;
        }
        return self.getRelevancePartition1(yRelevance, y, yContribute1, yContribute2, nPartitions);
    }

    /**
     * Returns the share of {@code yRelevance} attributed to the {@code yContribute2} term.
     *
     * <p>This is {@link #getRelevancePartition1} with {@code yContribute2} used as both the
     * redistributed contribution and the sign source:
     * {@code yRelevance * (yContribute2 + eps / nPartitions) / (y + eps)}.
     *
     * @throws IllegalArgumentException if {@code nPartitions <= 0}
     */
    @NotNull
    public final DenseNDArray getRelevancePartition2(
            @NotNull DenseNDArray yRelevance,
            @NotNull DenseNDArray y,
            @NotNull DenseNDArray yContribute2,
            int nPartitions) {
        Intrinsics.checkParameterIsNotNull(yRelevance, "yRelevance");
        Intrinsics.checkParameterIsNotNull(y, "y");
        Intrinsics.checkParameterIsNotNull(yContribute2, "yContribute2");
        return getRelevancePartition1(yRelevance, y, yContribute2, yContribute2, nPartitions);
    }

    /**
     * Synthetic default-argument entry point for {@link #getRelevancePartition2}.
     *
     * <p>When bit 8 of {@code mask} (the 4th parameter) is set, {@code nPartitions} defaults to 2.
     */
    @NotNull
    public static /* synthetic */ DenseNDArray getRelevancePartition2$default(
            RelevanceUtils self,
            DenseNDArray yRelevance,
            DenseNDArray y,
            DenseNDArray yContribute2,
            int nPartitions,
            int mask,
            Object unused) {
        if ((mask & 8) != 0) {
            nPartitions = 2;
        }
        return self.getRelevancePartition2(yRelevance, y, yContribute2, nPartitions);
    }

    private RelevanceUtils() {
    }
}
