package com.kotlinnlp.simplednn.core.layers.recurrent.ran;

import com.kotlinnlp.simplednn.core.arrays.AugmentedArray;
import com.kotlinnlp.simplednn.core.layers.LayerParameters;
import com.kotlinnlp.simplednn.core.layers.LayerStructure;
import com.kotlinnlp.simplednn.core.layers.RelevanceUtils;
import com.kotlinnlp.simplednn.core.layers.recurrent.GatedRecurrentRelevanceHelper;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin.TypeCastException;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;

/* compiled from: RANRelevanceHelper.kt */
@Metadata(mv = {1, 1, 8}, bv = {1, 0, 2}, k = 1, d1 = {"��6\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0005\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0002\n\u0002\b\u0004\u0018��*\u000e\b��\u0010\u0001*\b\u0012\u0004\u0012\u0002H\u00010\u00022\b\u0012\u0004\u0012\u0002H\u00010\u0003B\u0013\u0012\f\u0010\u0004\u001a\b\u0012\u0004\u0012\u00028��0\u0005¢\u0006\u0002\u0010\u0006J\u0018\u0010\t\u001a\u0006\u0012\u0002\b\u00030\u00022\n\u0010\n\u001a\u0006\u0012\u0002\b\u00030\u000bH\u0014J\u001e\u0010\f\u001a\u0010\u0012\u0004\u0012\u00020\u000e\u0012\u0006\u0012\u0004\u0018\u00010\u000e0\r2\u0006\u0010\u000f\u001a\u00020\u0010H\u0002J\u0014\u0010\u0011\u001a\u00020\u00122\n\u0010\n\u001a\u0006\u0012\u0002\b\u00030\u000bH\u0016J\u0014\u0010\u0013\u001a\u00020\u00122\n\u0010\n\u001a\u0006\u0012\u0002\b\u00030\u000bH\u0016J&\u0010\u0014\u001a\u0010\u0012\u0004\u0012\u00020\u000e\u0012\u0006\u0012\u0004\u0018\u00010\u000e0\r2\u0006\u0010\u0015\u001a\u00020\u000e2\u0006\u0010\u000f\u001a\u00020\u0010H\u0002R\u001a\u0010\u0004\u001a\b\u0012\u0004\u0012\u00028��0\u0005X\u0094\u0004¢\u0006\b\n��\u001a\u0004\b\u0007\u0010\b¨\u0006\u0016"}, d2 = {"Lcom/kotlinnlp/simplednn/core/layers/recurrent/ran/RANRelevanceHelper;", "InputNDArrayType", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/NDArray;", "Lcom/kotlinnlp/simplednn/core/layers/recurrent/GatedRecurrentRelevanceHelper;", "layer", "Lcom/kotlinnlp/simplednn/core/layers/recurrent/ran/RANLayerStructure;", "(Lcom/kotlinnlp/simplednn/core/layers/recurrent/ran/RANLayerStructure;)V", "getLayer", "()Lcom/kotlinnlp/simplednn/core/layers/recurrent/ran/RANLayerStructure;", "getInputRelevance", "layerContributions", "Lcom/kotlinnlp/simplednn/core/layers/LayerParameters;", "getRelevancePartitions", "Lkotlin/Pair;", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "contributions", 
"Lcom/kotlinnlp/simplednn/core/layers/recurrent/ran/RANLayerParameters;", "propagateRelevanceToGates", "", "setRecurrentRelevance", "splitRelevancePartitions", "yRelevance", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/core/layers/recurrent/ran/RANRelevanceHelper.class */
public final class RANRelevanceHelper<InputNDArrayType extends NDArray<InputNDArrayType>> extends GatedRecurrentRelevanceHelper<InputNDArrayType> {

    @NotNull
    private final RANLayerStructure<InputNDArrayType> layer;

    /**
     * Distributes the relevance partitions of the layer output to the gates.
     *
     * <p>The first partition is halved in place and assigned to both the candidate and the
     * input gate (the very same array instance is passed to both). When a second partition is
     * available it is halved in place as well and assigned to the forget gate.
     *
     * @param layerContributions the contributions of the layer; cast to {@link RANLayerParameters}
     */
    @Override // com.kotlinnlp.simplednn.core.layers.recurrent.GatedRecurrentRelevanceHelper
    public void propagateRelevanceToGates(@NotNull LayerParameters<?> layerContributions) {
        Intrinsics.checkParameterIsNotNull(layerContributions, "layerContributions");

        final Pair<DenseNDArray, DenseNDArray> partitions =
                getRelevancePartitions((RANLayerParameters) layerContributions);
        final DenseNDArray inputPartition = partitions.getFirst();
        final DenseNDArray recurrentPartition = partitions.getSecond();

        // assignDiv works in place; its result is shared by the candidate and the input gate.
        final DenseNDArray halvedInput = inputPartition.assignDiv(2.0d);
        final RANLayerStructure<InputNDArrayType> ranLayer = getLayer();
        ranLayer.getCandidate().assignRelevance(halvedInput);
        ranLayer.getInputGate().assignRelevance(halvedInput);

        // The second partition is null when the layer has no previous state.
        if (recurrentPartition == null) {
            return;
        }
        ranLayer.getForgetGate().assignRelevance(recurrentPartition.assignDiv(2.0d));
    }

    /* JADX WARN: Type inference failed for: r0v22, types: [com.kotlinnlp.simplednn.simplemath.ndarray.NDArray<?>, com.kotlinnlp.simplednn.simplemath.ndarray.NDArray] */
    @Override // com.kotlinnlp.simplednn.core.layers.RelevanceHelper
    @NotNull
    protected NDArray<?> getInputRelevance(@NotNull LayerParameters<?> layerContributions) {
        Intrinsics.checkParameterIsNotNull(layerContributions, "layerContributions");
        InputNDArrayType values = getLayer().getInputArray().getValues();
        boolean z = getLayer().getLayerContextWindow().getPrevStateLayer() != null;
        ?? assignSum = getLayer().getInputGate().getInputRelevance(values, ((RANLayerParameters) layerContributions).getInputGate(), z).assignSum(getLayer().getCandidate().getInputRelevance(values, ((RANLayerParameters) layerContributions).getCandidate()));
        if (z) {
            assignSum.assignSum(getLayer().getForgetGate().getInputRelevance(values, ((RANLayerParameters) layerContributions).getForgetGate(), true));
        }
        return assignSum;
    }

    /**
     * Sets the relevance of the output array of the previous state layer.
     *
     * <p>The relevance is initialised with the second relevance partition halved in place, then
     * the recurrent relevances of the input gate and of the forget gate are added to it.
     *
     * <p>Fails through {@code Intrinsics.throwNpe()} when there is no previous state layer or
     * when the second relevance partition is null.
     *
     * @param layerContributions the contributions of the layer; cast to {@link RANLayerParameters}
     */
    @Override // com.kotlinnlp.simplednn.core.layers.recurrent.RecurrentRelevanceHelper
    public void setRecurrentRelevance(@NotNull LayerParameters<?> layerContributions) {
        Intrinsics.checkParameterIsNotNull(layerContributions, "layerContributions");

        final LayerStructure<?> previousLayer = getLayer().getLayerContextWindow().getPrevStateLayer();
        if (previousLayer == null) {
            Intrinsics.throwNpe();
        }
        final AugmentedArray<DenseNDArray> previousOutput = previousLayer.getOutputArray();

        final RANLayerParameters contributions = (RANLayerParameters) layerContributions;
        final DenseNDArray recurrentPartition = getRelevancePartitions(contributions).getSecond();
        if (recurrentPartition == null) {
            Intrinsics.throwNpe();
        }
        final DenseNDArray halvedRecurrent = recurrentPartition.assignDiv(2.0d);

        // Both gate relevances are computed before the previous output relevance is overwritten.
        final RANLayerStructure<InputNDArrayType> ranLayer = getLayer();
        final DenseNDArray inputGateShare = ranLayer.getInputGate()
                .getRecurrentRelevance(contributions.getInputGate(), previousOutput.getValues());
        final DenseNDArray forgetGateShare = ranLayer.getForgetGate()
                .getRecurrentRelevance(contributions.getForgetGate(), previousOutput.getValues());

        previousOutput.assignRelevance(halvedRecurrent);
        previousOutput.getRelevance().assignSum(inputGateShare).assignSum(forgetGateShare);
    }

    /**
     * Returns the relevance of the layer output as a pair of partitions.
     *
     * <p>Without a previous state layer the whole output relevance is returned as the first
     * element and the second element is null. Otherwise the relevance is divided by
     * {@link #splitRelevancePartitions}.
     *
     * @param contributions the contributions of the layer
     * @return the pair of partitions; the second element may be null
     * @throws TypeCastException if the relevance of the output array is null
     */
    private final Pair<DenseNDArray, DenseNDArray> getRelevancePartitions(RANLayerParameters contributions) {
        final NDArray<?> rawRelevance = getLayer().getOutputArray().getRelevance();
        if (rawRelevance == null) {
            throw new TypeCastException("null cannot be cast to non-null type com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray");
        }
        final DenseNDArray yRelevance = (DenseNDArray) rawRelevance;

        if (getLayer().getLayerContextWindow().getPrevStateLayer() == null) {
            // No previous state: nothing to split.
            return new Pair<>(yRelevance, null);
        }
        return splitRelevancePartitions(yRelevance, contributions);
    }

    /**
     * Splits the given output relevance into two partitions using {@link RelevanceUtils}.
     *
     * <p>The split uses the not-activated values of the output array and the values of the
     * candidate biases taken from the given contributions.
     * NOTE(review): presumably the candidate biases of the contributions hold the recurrent
     * contribution — confirm against the code that builds the contributions.
     *
     * @param yRelevance the relevance of the layer output
     * @param contributions the contributions of the layer
     * @return the pair (partition 1, partition 2) computed by {@link RelevanceUtils}
     * @throws TypeCastException if the values of the candidate biases are null
     */
    private final Pair<DenseNDArray, DenseNDArray> splitRelevancePartitions(DenseNDArray yRelevance, RANLayerParameters contributions) {
        final DenseNDArray notActivated = getLayer().getOutputArray().getValuesNotActivated();

        final Object rawBiases = contributions.getCandidate().getBiases().getValues();
        if (rawBiases == null) {
            throw new TypeCastException("null cannot be cast to non-null type com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray");
        }
        final DenseNDArray biasValues = (DenseNDArray) rawBiases;

        // The "$default" bridges are called with a bit mask (16 and 8), presumably so that the
        // trailing numeric parameter falls back to its declared default; the literal 0 is a filler.
        final DenseNDArray firstPartition = RelevanceUtils.getRelevancePartition1$default(
                RelevanceUtils.INSTANCE, yRelevance, notActivated, notActivated.sub(biasValues), biasValues, 0, 16, null);
        final DenseNDArray secondPartition = RelevanceUtils.getRelevancePartition2$default(
                RelevanceUtils.INSTANCE, yRelevance, notActivated, biasValues, 0, 8, null);

        return new Pair<>(firstPartition, secondPartition);
    }

    /**
     * Returns the RAN layer structure this helper operates on.
     *
     * <p>The decompiler reports that the access modifier of this method was changed from
     * {@code protected}.
     *
     * @return the layer passed to the constructor
     */
    @Override // com.kotlinnlp.simplednn.core.layers.recurrent.RecurrentRelevanceHelper, com.kotlinnlp.simplednn.core.layers.RelevanceHelper
    @NotNull
    public RANLayerStructure<InputNDArrayType> getLayer() {
        return layer;
    }

    /**
     * Creates a relevance helper bound to the given RAN layer structure.
     *
     * @param layer the layer this helper operates on; must not be null
     */
    /* JADX WARN: 'super' call moved to the top of the method (can break code semantics) */
    public RANRelevanceHelper(@NotNull RANLayerStructure<InputNDArrayType> layer) {
        // The layer is handed to the superclass before the null check below runs
        // (the decompiler moved the super call to the top of the constructor).
        super(layer);
        Intrinsics.checkParameterIsNotNull(layer, "layer");
        this.layer = layer;
    }
}
