package com.kotlinnlp.simplednn.core.layers.recurrent.deltarnn;

import com.kotlinnlp.simplednn.core.arrays.AugmentedArray;
import com.kotlinnlp.simplednn.core.layers.LayerParameters;
import com.kotlinnlp.simplednn.core.layers.LayerStructure;
import com.kotlinnlp.simplednn.core.layers.RelevanceUtils;
import com.kotlinnlp.simplednn.core.layers.recurrent.GatedRecurrentRelevanceHelper;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.sparse.SparseNDArray;
import kotlin.Metadata;
import kotlin.TypeCastException;
import kotlin.collections.ArraysKt;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;

/* compiled from: DeltaRNNRelevanceHelper.kt */
@Metadata(mv = {1, 1, 8}, bv = {1, 0, 2}, k = 1, d1 = {"��L\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0006\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u0002\n\u0002\b\u0002\n\u0002\u0010\u000b\n\u0002\b\u0004\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\u0018��*\u000e\b��\u0010\u0001*\b\u0012\u0004\u0012\u0002H\u00010\u00022\b\u0012\u0004\u0012\u0002H\u00010\u0003B\u0013\u0012\f\u0010\u0004\u001a\b\u0012\u0004\u0012\u00028��0\u0005¢\u0006\u0002\u0010\u0006J \u0010\t\u001a\u0006\u0012\u0002\b\u00030\u00022\n\u0010\n\u001a\u0006\u0012\u0002\b\u00030\u00022\u0006\u0010\u000b\u001a\u00020\fH\u0002J\u0010\u0010\r\u001a\u00020\f2\u0006\u0010\u000e\u001a\u00020\u000fH\u0002J\u0018\u0010\u0010\u001a\u0006\u0012\u0002\b\u00030\u00022\n\u0010\u0011\u001a\u0006\u0012\u0002\b\u00030\u0012H\u0014J\u0010\u0010\u0013\u001a\u00020\f2\u0006\u0010\u000e\u001a\u00020\u000fH\u0002J\u0014\u0010\u0014\u001a\u00020\u00152\n\u0010\u0011\u001a\u0006\u0012\u0002\b\u00030\u0012H\u0016J\u0010\u0010\u0016\u001a\u00020\u00152\u0006\u0010\u0017\u001a\u00020\u0018H\u0002J\u0014\u0010\u0019\u001a\u00020\u00152\n\u0010\u0011\u001a\u0006\u0012\u0002\b\u00030\u0012H\u0016J\u0018\u0010\u001a\u001a\u00020\u00152\u0006\u0010\u001b\u001a\u00020\f2\u0006\u0010\u001c\u001a\u00020\u001dH\u0002J\u0014\u0010\u001e\u001a\u00020\f*\u00020\f2\u0006\u0010\n\u001a\u00020\fH\u0002J\u0014\u0010\u001e\u001a\u00020\u001f*\u00020\u001f2\u0006\u0010\n\u001a\u00020\fH\u0002R\u001a\u0010\u0004\u001a\b\u0012\u0004\u0012\u00028��0\u0005X\u0094\u0004¢\u0006\b\n��\u001a\u0004\b\u0007\u0010\b¨\u0006 "}, d2 = {"Lcom/kotlinnlp/simplednn/core/layers/recurrent/deltarnn/DeltaRNNRelevanceHelper;", "InputNDArrayType", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/NDArray;", "Lcom/kotlinnlp/simplednn/core/layers/recurrent/GatedRecurrentRelevanceHelper;", "layer", 
"Lcom/kotlinnlp/simplednn/core/layers/recurrent/deltarnn/DeltaRNNLayerStructure;", "(Lcom/kotlinnlp/simplednn/core/layers/recurrent/deltarnn/DeltaRNNLayerStructure;)V", "getLayer", "()Lcom/kotlinnlp/simplednn/core/layers/recurrent/deltarnn/DeltaRNNLayerStructure;", "assignSum", "a", "b", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "getInputPartition", "contributions", "Lcom/kotlinnlp/simplednn/core/layers/recurrent/deltarnn/DeltaRNNLayerParameters;", "getInputRelevance", "layerContributions", "Lcom/kotlinnlp/simplednn/core/layers/LayerParameters;", "getRecurrentPartition", "propagateRelevanceToGates", "", "setCandidateRelevancePartitions", "previousStateExists", "", "setRecurrentRelevance", "splitCandidateRelevancePartitions", "cRelevance", "relevanceSupport", "Lcom/kotlinnlp/simplednn/core/layers/recurrent/deltarnn/DeltaRNNRelevanceSupport;", "partialAssignSum", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/sparse/SparseNDArray;", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/core/layers/recurrent/deltarnn/DeltaRNNRelevanceHelper.class */
/*
 * Relevance-propagation helper for a DeltaRNN layer (decompiled from DeltaRNNRelevanceHelper.kt).
 *
 * It extends GatedRecurrentRelevanceHelper and works on the arrays and parameters of the
 * DeltaRNNLayerStructure passed to the constructor:
 *  - propagateRelevanceToGates: splits the output relevance between the "partition" and
 *    "candidate" arrays, then distributes the candidate relevance over the relevance-support
 *    arrays (d1Input, and also d1Rec / d2 when a previous state exists);
 *  - getInputRelevance: computes the relevance of the layer input array;
 *  - setRecurrentRelevance: assigns relevance to the output array of the previous-state layer.
 *
 * NOTE(review): this is JADX output. The two `??` declarations in getInputRelevance are
 * decompiler placeholders for types it could not infer; they are not valid Java, so this file
 * does not compile as-is. Treat the Kotlin source as authoritative.
 */
public final class DeltaRNNRelevanceHelper<InputNDArrayType extends NDArray<InputNDArrayType>> extends GatedRecurrentRelevanceHelper<InputNDArrayType> {

    /** The layer structure whose arrays and parameters this helper reads and assigns relevance to. */
    @NotNull
    private final DeltaRNNLayerStructure<InputNDArrayType> layer;

    /**
     * Propagates the output relevance of the layer to its partition and candidate arrays.
     *
     * <p>Half of the output relevance is assigned to the partition array. The candidate array
     * receives half of {@code getInputPartition(...)} when a previous-state layer exists, and
     * otherwise the same half of the output relevance. The candidate relevance is then
     * distributed by {@code setCandidateRelevancePartitions}.
     *
     * @param layerParameters the layer contributions (named {@code layerContributions} in the
     *     Kotlin source); cast to {@code DeltaRNNLayerParameters} only when a previous state exists
     * @throws TypeCastException if the output array has no relevance set
     */
    @Override // com.kotlinnlp.simplednn.core.layers.recurrent.GatedRecurrentRelevanceHelper
    public void propagateRelevanceToGates(@NotNull LayerParameters<?> layerParameters) {
        Intrinsics.checkParameterIsNotNull(layerParameters, "layerContributions");
        // true when the context window provides a previous-state layer
        boolean z = getLayer().getLayerContextWindow().getPrevStateLayer() != null;
        NDArray<?> relevance = getLayer().getOutputArray().getRelevance();
        if (relevance == null) {
            throw new TypeCastException("null cannot be cast to non-null type com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray");
        }
        // half of the output relevance
        DenseNDArray div = ((DenseNDArray) relevance).div(2.0d);
        // candidate share: half of the input partition if a previous state exists, else the same half
        DenseNDArray div2 = z ? getInputPartition((DeltaRNNLayerParameters) layerParameters).div(2.0d) : div;
        getLayer().getPartition().assignRelevance(div);
        getLayer().getCandidate().assignRelevance(div2);
        setCandidateRelevancePartitions(z);
    }

    /**
     * Computes the relevance of the layer input array.
     *
     * <p>The result is the sum of up to three terms, each computed with
     * {@code RelevanceUtils.calculateRelevanceOfArray} on the input values:
     * <ol>
     *   <li>w.r.t. the partition array, using a copy of the feedforward weight contributions with
     *       the recurrent biases spread over it by {@code assignSum};</li>
     *   <li>w.r.t. {@code relevanceSupport.d1Input}, using the feedforward weight contributions
     *       combined with beta1 via {@code prod}, with the feedforward biases (halved when a
     *       previous state exists) spread over it;</li>
     *   <li>only when a previous state exists: w.r.t. {@code wx}, using half of the {@code d2}
     *       relevance and the feedforward weight contributions.</li>
     * </ol>
     *
     * @param layerParameters the layer contributions; must be a {@code DeltaRNNLayerParameters}
     * @return the first term with the other terms summed into it via {@code assignSum}
     *     (presumably in place — confirm against NDArray.assignSum)
     * @throws TypeCastException if the layer params are null, or a required bias/relevance array is null
     */
    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v15, types: [com.kotlinnlp.simplednn.simplemath.ndarray.NDArray] */
    /* JADX WARN: Type inference failed for: r0v55, types: [com.kotlinnlp.simplednn.simplemath.ndarray.NDArray<?>, com.kotlinnlp.simplednn.simplemath.ndarray.NDArray] */
    @Override // com.kotlinnlp.simplednn.core.layers.RelevanceHelper
    @NotNull
    protected NDArray<?> getInputRelevance(@NotNull LayerParameters<?> layerParameters) {
        Intrinsics.checkParameterIsNotNull(layerParameters, "layerContributions");
        // NOTE(review): `params` is never read; it only survives from the Kotlin cast as a null check.
        LayerParameters<?> params = getLayer().getParams();
        if (params == null) {
            throw new TypeCastException("null cannot be cast to non-null type com.kotlinnlp.simplednn.core.layers.recurrent.deltarnn.DeltaRNNLayerParameters");
        }
        InputNDArrayType values = getLayer().getInputArray().getValues();
        // feedforward weight contributions (from the contributions, not from the layer params)
        // NOTE(review): `??` is a JADX placeholder for an uninferred NDArray type.
        ?? values2 = ((DeltaRNNLayerParameters) layerParameters).getFeedforwardUnit().getWeights().getValues();
        DeltaRNNRelevanceSupport relevanceSupport = getLayer().getRelevanceSupport();
        boolean z = getLayer().getLayerContextWindow().getPrevStateLayer() != null;
        // recurrent biases of the layer params
        Object values3 = ((DeltaRNNLayerParameters) getLayer().getParams()).getRecurrentUnit().getBiases().getValues();
        if (values3 == null) {
            throw new TypeCastException("null cannot be cast to non-null type com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray");
        }
        DenseNDArray denseNDArray = (DenseNDArray) values3;
        // feedforward biases of the layer params
        Object values4 = ((DeltaRNNLayerParameters) getLayer().getParams()).getFeedforwardUnit().getBiases().getValues();
        if (values4 == null) {
            throw new TypeCastException("null cannot be cast to non-null type com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray");
        }
        DenseNDArray denseNDArray2 = (DenseNDArray) values4;
        DenseNDArray values5 = ((DeltaRNNLayerParameters) getLayer().getParams()).getBeta1().getValues();
        // with a previous state, only half of the feedforward biases is assigned here
        // (setRecurrentRelevance uses the other half)
        DenseNDArray div = z ? denseNDArray2.div(2.0d) : denseNDArray2;
        RelevanceUtils relevanceUtils = RelevanceUtils.INSTANCE;
        DenseNDArray valuesNotActivated = getLayer().getPartition().getValuesNotActivated();
        NDArray<?> relevance = getLayer().getPartition().getRelevance();
        if (relevance == null) {
            throw new TypeCastException("null cannot be cast to non-null type com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray");
        }
        // term 1 (partition): copy() because assignSum mutates its first argument in place
        NDArray<?> calculateRelevanceOfArray = relevanceUtils.calculateRelevanceOfArray(values, valuesNotActivated, (DenseNDArray) relevance, assignSum(values2.copy(), denseNDArray));
        RelevanceUtils relevanceUtils2 = RelevanceUtils.INSTANCE;
        DenseNDArray values6 = relevanceSupport.getD1Input().getValues();
        NDArray<?> relevance2 = relevanceSupport.getD1Input().getRelevance();
        if (relevance2 == null) {
            throw new TypeCastException("null cannot be cast to non-null type com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray");
        }
        // term 2 (d1Input), summed into term 1.
        // NOTE(review): assumes prod() returns a new array; otherwise assignSum would mutate the contributions.
        ?? assignSum = calculateRelevanceOfArray.assignSum(relevanceUtils2.calculateRelevanceOfArray(values, values6, (DenseNDArray) relevance2, assignSum(values2.prod(values5), div)));
        if (z) {
            // term 3 (d2, halved), only when a previous state exists
            RelevanceUtils relevanceUtils3 = RelevanceUtils.INSTANCE;
            DenseNDArray values7 = getLayer().getWx().getValues();
            NDArray<?> relevance3 = relevanceSupport.getD2().getRelevance();
            if (relevance3 == null) {
                throw new TypeCastException("null cannot be cast to non-null type com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray");
            }
            assignSum.assignSum(relevanceUtils3.calculateRelevanceOfArray(values, values7, ((DenseNDArray) relevance3).div(2.0d), values2));
        }
        return assignSum;
    }

    /**
     * Assigns relevance to the output array of the previous-state layer.
     *
     * <p>The assigned relevance starts as half of {@code getRecurrentPartition(layerParameters)}.
     * Two terms are then summed into it, each computed with
     * {@code RelevanceUtils.calculateRelevanceOfArray} on the previous output values:
     * <ul>
     *   <li>w.r.t. {@code relevanceSupport.d1Rec}, using the recurrent weight contributions combined
     *       with beta2 via {@code prod}, with half of the feedforward biases spread over it;</li>
     *   <li>w.r.t. {@code wyRec}, using half of the {@code d2} relevance and the recurrent weight
     *       contributions.</li>
     * </ul>
     *
     * @param layerParameters the layer contributions; must be a {@code DeltaRNNLayerParameters}
     * @throws NullPointerException (or a subtype) if the context window has no previous-state layer
     * @throws TypeCastException if the layer params, the recurrent weights, or a required
     *     relevance array are null
     */
    /* JADX WARN: Type inference failed for: r0v27, types: [com.kotlinnlp.simplednn.simplemath.ndarray.NDArray] */
    /* JADX WARN: Type inference failed for: r0v46, types: [com.kotlinnlp.simplednn.simplemath.ndarray.NDArray] */
    @Override // com.kotlinnlp.simplednn.core.layers.recurrent.RecurrentRelevanceHelper
    public void setRecurrentRelevance(@NotNull LayerParameters<?> layerParameters) {
        Intrinsics.checkParameterIsNotNull(layerParameters, "layerContributions");
        // NOTE(review): `params` is never read; it only survives from the Kotlin cast as a null check.
        LayerParameters<?> params = getLayer().getParams();
        if (params == null) {
            throw new TypeCastException("null cannot be cast to non-null type com.kotlinnlp.simplednn.core.layers.recurrent.deltarnn.DeltaRNNLayerParameters");
        }
        // Kotlin `!!` on the previous-state layer
        LayerStructure<?> prevStateLayer = getLayer().getLayerContextWindow().getPrevStateLayer();
        if (prevStateLayer == null) {
            Intrinsics.throwNpe();
        }
        AugmentedArray<DenseNDArray> outputArray = prevStateLayer.getOutputArray();
        DenseNDArray values = outputArray.getValues();
        // recurrent weight contributions
        Object values2 = ((DeltaRNNLayerParameters) layerParameters).getRecurrentUnit().getWeights().getValues();
        if (values2 == null) {
            throw new TypeCastException("null cannot be cast to non-null type com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray");
        }
        DenseNDArray denseNDArray = (DenseNDArray) values2;
        // the other half of the feedforward biases (see getInputRelevance)
        NDArray div = ((DeltaRNNLayerParameters) getLayer().getParams()).getFeedforwardUnit().getBiases().getValues().div(2.0d);
        if (div == null) {
            throw new TypeCastException("null cannot be cast to non-null type com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray");
        }
        DenseNDArray denseNDArray2 = (DenseNDArray) div;
        DenseNDArray values3 = ((DeltaRNNLayerParameters) getLayer().getParams()).getBeta2().getValues();
        DeltaRNNRelevanceSupport relevanceSupport = getLayer().getRelevanceSupport();
        RelevanceUtils relevanceUtils = RelevanceUtils.INSTANCE;
        DenseNDArray denseNDArray3 = values;
        DenseNDArray values4 = relevanceSupport.getD1Rec().getValues();
        NDArray<?> relevance = relevanceSupport.getD1Rec().getRelevance();
        if (relevance == null) {
            throw new TypeCastException("null cannot be cast to non-null type com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray");
        }
        // term 1 (d1Rec)
        NDArray<?> calculateRelevanceOfArray = relevanceUtils.calculateRelevanceOfArray(denseNDArray3, values4, (DenseNDArray) relevance, assignSum(denseNDArray.prod((NDArray<?>) values3), denseNDArray2));
        RelevanceUtils relevanceUtils2 = RelevanceUtils.INSTANCE;
        DenseNDArray denseNDArray4 = values;
        DenseNDArray values5 = getLayer().getWyRec().getValues();
        NDArray<?> relevance2 = relevanceSupport.getD2().getRelevance();
        if (relevance2 == null) {
            throw new TypeCastException("null cannot be cast to non-null type com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray");
        }
        // term 2 (d2, halved)
        NDArray<?> calculateRelevanceOfArray2 = relevanceUtils2.calculateRelevanceOfArray(denseNDArray4, values5, ((DenseNDArray) relevance2).div(2.0d), denseNDArray);
        // half of the recurrent partition, then both terms summed into it
        outputArray.assignRelevance(getRecurrentPartition((DeltaRNNLayerParameters) layerParameters).div(2.0d));
        outputArray.getRelevance().assignSum(calculateRelevanceOfArray).assignSum(calculateRelevanceOfArray2);
    }

    /**
     * Computes a partition of the output relevance via {@code RelevanceUtils.getRelevancePartition1}.
     *
     * <p>The inputs are the output values-not-activated, those values minus the recurrent biases,
     * and the recurrent biases. {@code propagateRelevanceToGates} uses half of the result as the
     * candidate relevance when a previous state exists.
     *
     * <p>The trailing {@code 0, 16, null} arguments are the Kotlin default-argument bridge.
     * Presumably the mask 16 selects the default value of the fifth parameter.
     *
     * @param deltaRNNLayerParameters parameters supplying the recurrent biases
     * @return the relevance partition returned by {@code getRelevancePartition1}
     * @throws TypeCastException if the recurrent biases or the output relevance are null
     */
    private final DenseNDArray getInputPartition(DeltaRNNLayerParameters deltaRNNLayerParameters) {
        DenseNDArray valuesNotActivated = getLayer().getOutputArray().getValuesNotActivated();
        Object values = deltaRNNLayerParameters.getRecurrentUnit().getBiases().getValues();
        if (values == null) {
            throw new TypeCastException("null cannot be cast to non-null type com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray");
        }
        DenseNDArray denseNDArray = (DenseNDArray) values;
        DenseNDArray sub = valuesNotActivated.sub(denseNDArray);
        RelevanceUtils relevanceUtils = RelevanceUtils.INSTANCE;
        NDArray<?> relevance = getLayer().getOutputArray().getRelevance();
        if (relevance == null) {
            throw new TypeCastException("null cannot be cast to non-null type com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray");
        }
        return RelevanceUtils.getRelevancePartition1$default(relevanceUtils, (DenseNDArray) relevance, valuesNotActivated, sub, denseNDArray, 0, 16, null);
    }

    /**
     * Computes a partition of the output relevance via {@code RelevanceUtils.getRelevancePartition2}.
     *
     * <p>The inputs are the output values-not-activated and the recurrent biases.
     * {@code setRecurrentRelevance} uses half of the result as the base relevance of the
     * previous output.
     *
     * <p>The trailing {@code 0, 8, null} arguments are the Kotlin default-argument bridge.
     * Presumably the mask 8 selects the default value of the fourth parameter.
     *
     * @param deltaRNNLayerParameters parameters supplying the recurrent biases
     * @return the relevance partition returned by {@code getRelevancePartition2}
     * @throws TypeCastException if the output relevance or the recurrent biases are null
     */
    private final DenseNDArray getRecurrentPartition(DeltaRNNLayerParameters deltaRNNLayerParameters) {
        RelevanceUtils relevanceUtils = RelevanceUtils.INSTANCE;
        NDArray<?> relevance = getLayer().getOutputArray().getRelevance();
        if (relevance == null) {
            throw new TypeCastException("null cannot be cast to non-null type com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray");
        }
        DenseNDArray denseNDArray = (DenseNDArray) relevance;
        DenseNDArray valuesNotActivated = getLayer().getOutputArray().getValuesNotActivated();
        Object values = deltaRNNLayerParameters.getRecurrentUnit().getBiases().getValues();
        if (values == null) {
            throw new TypeCastException("null cannot be cast to non-null type com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray");
        }
        return RelevanceUtils.getRelevancePartition2$default(relevanceUtils, denseNDArray, valuesNotActivated, (DenseNDArray) values, 0, 8, null);
    }

    /**
     * Distributes the candidate relevance over the relevance-support arrays.
     *
     * <p>With a previous state the relevance is split across d1Input, d1Rec and d2. Without one,
     * the whole candidate relevance goes to d1Input.
     *
     * @param z whether a previous-state layer exists ({@code previousStateExists} in the Kotlin source)
     * @throws TypeCastException if the candidate array has no relevance set
     */
    private final void setCandidateRelevancePartitions(boolean z) {
        NDArray<?> relevance = getLayer().getCandidate().getRelevance();
        if (relevance == null) {
            throw new TypeCastException("null cannot be cast to non-null type com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray");
        }
        DenseNDArray denseNDArray = (DenseNDArray) relevance;
        DeltaRNNRelevanceSupport relevanceSupport = getLayer().getRelevanceSupport();
        if (z) {
            splitCandidateRelevancePartitions(denseNDArray, relevanceSupport);
        } else {
            relevanceSupport.getD1Input().assignRelevance(denseNDArray);
        }
    }

    /**
     * Splits the candidate relevance into the d1Input, d1Rec and d2 relevance-support arrays.
     *
     * <p>d1Input and d1Rec use {@code getRelevancePartition1}, each with its own values and the
     * d2 values. d2 uses {@code getRelevancePartition2}. All three calls use the candidate
     * values-not-activated.
     * NOTE(review): the explicit {@code 3} argument presumably is the number of partitions
     * (d1Input, d1Rec, d2) — confirm against RelevanceUtils.
     *
     * @param denseNDArray the candidate relevance ({@code cRelevance} in the Kotlin source)
     * @param deltaRNNRelevanceSupport the layer's relevance-support arrays
     */
    private final void splitCandidateRelevancePartitions(DenseNDArray denseNDArray, DeltaRNNRelevanceSupport deltaRNNRelevanceSupport) {
        DenseNDArray valuesNotActivated = getLayer().getCandidate().getValuesNotActivated();
        DenseNDArray values = deltaRNNRelevanceSupport.getD2().getValues();
        deltaRNNRelevanceSupport.getD1Input().assignRelevance(RelevanceUtils.INSTANCE.getRelevancePartition1(denseNDArray, valuesNotActivated, deltaRNNRelevanceSupport.getD1Input().getValues(), values, 3));
        deltaRNNRelevanceSupport.getD1Rec().assignRelevance(RelevanceUtils.INSTANCE.getRelevancePartition1(denseNDArray, valuesNotActivated, deltaRNNRelevanceSupport.getD1Rec().getValues(), values, 3));
        deltaRNNRelevanceSupport.getD2().assignRelevance(RelevanceUtils.INSTANCE.getRelevancePartition2(denseNDArray, valuesNotActivated, values, 3));
    }

    /**
     * Spreads the column vector {@code b} (here {@code denseNDArray}) over the rows of {@code a}
     * (here {@code nDArray}), modifying {@code a} in place.
     *
     * <p>Dispatches to the dense or sparse {@code partialAssignSum} overload.
     *
     * @param nDArray the array to modify; must be a DenseNDArray or a SparseNDArray
     * @param denseNDArray a column vector with as many rows as {@code nDArray}
     * @return {@code nDArray} itself, after the update
     * @throws IllegalArgumentException if the row counts differ or {@code denseNDArray} has more than one column
     * @throws RuntimeException if {@code nDArray} is neither dense nor sparse
     */
    private final NDArray<?> assignSum(NDArray<?> nDArray, DenseNDArray denseNDArray) {
        if (!(nDArray.getRows() == denseNDArray.getRows())) {
            throw new IllegalArgumentException("b must be a column vector with the same number of rows of a".toString());
        }
        if (!(denseNDArray.getColumns() == 1)) {
            throw new IllegalArgumentException("b must be a column vector".toString());
        }
        if (nDArray instanceof DenseNDArray) {
            return partialAssignSum((DenseNDArray) nDArray, denseNDArray);
        }
        if (nDArray instanceof SparseNDArray) {
            return partialAssignSum((SparseNDArray) nDArray, denseNDArray);
        }
        throw new RuntimeException("Invalid NDArray type");
    }

    /**
     * Dense case of {@code assignSum}: for each row i, adds {@code b[i] / columns} to every
     * element of row i of {@code a}, in place.
     *
     * @param denseNDArray the dense array {@code a} to modify
     * @param denseNDArray2 the column vector {@code b}
     * @return {@code denseNDArray} itself
     */
    private final DenseNDArray partialAssignSum(@NotNull DenseNDArray denseNDArray, DenseNDArray denseNDArray2) {
        DenseNDArray div = denseNDArray2.div(denseNDArray.getColumns());
        int rows = denseNDArray.getRows();
        for (int i = 0; i < rows; i++) {
            double doubleValue = div.get(i).doubleValue();
            int columns = denseNDArray.getColumns();
            for (int i2 = 0; i2 < columns; i2++) {
                int i3 = i;
                int i4 = i2;
                denseNDArray.set(i3, i4, Double.valueOf(denseNDArray.get(i3, i4).doubleValue() + doubleValue));
            }
        }
        return denseNDArray;
    }

    /**
     * Sparse case of {@code assignSum}: divides {@code b} by the number of distinct column
     * indices. Then, for each stored value, it adds the scaled {@code b} entry of that value's
     * row, writing into the array returned by {@code getValues()}.
     *
     * <p>Only stored (non-implicit) entries are changed.
     *
     * @param sparseNDArray the sparse array {@code a} to modify
     * @param denseNDArray the column vector {@code b}
     * @return {@code sparseNDArray} itself
     */
    private final SparseNDArray partialAssignSum(@NotNull SparseNDArray sparseNDArray, DenseNDArray denseNDArray) {
        DenseNDArray div = denseNDArray.div(ArraysKt.toSet(sparseNDArray.getColIndices()).size());
        int length = sparseNDArray.getValues().length;
        for (int i = 0; i < length; i++) {
            int intValue = sparseNDArray.getRowIndices()[i].intValue();
            Double[] values = sparseNDArray.getValues();
            int i2 = i;
            values[i2] = Double.valueOf(values[i2].doubleValue() + div.get(intValue).doubleValue());
        }
        return sparseNDArray;
    }

    /**
     * Returns the DeltaRNN layer structure this helper operates on.
     *
     * @return the layer passed to the constructor
     */
    /* JADX INFO: Access modifiers changed from: protected */
    @Override // com.kotlinnlp.simplednn.core.layers.recurrent.RecurrentRelevanceHelper, com.kotlinnlp.simplednn.core.layers.RelevanceHelper
    @NotNull
    public DeltaRNNLayerStructure<InputNDArrayType> getLayer() {
        return this.layer;
    }

    /**
     * Creates a relevance helper for the given DeltaRNN layer.
     *
     * <p>NOTE(review): JADX moved the {@code super} call above the null check, so the superclass
     * constructor receives the argument before it is validated.
     *
     * @param deltaRNNLayerStructure the layer structure ({@code layer} in the Kotlin source); must not be null
     */
    /* JADX WARN: 'super' call moved to the top of the method (can break code semantics) */
    public DeltaRNNRelevanceHelper(@NotNull DeltaRNNLayerStructure<InputNDArrayType> deltaRNNLayerStructure) {
        super(deltaRNNLayerStructure);
        Intrinsics.checkParameterIsNotNull(deltaRNNLayerStructure, "layer");
        this.layer = deltaRNNLayerStructure;
    }
}
