package com.kotlinnlp.simplednn.core.layers.models.recurrent.lstm;

import com.kotlinnlp.simplednn.core.layers.ArrayExtensionsKt;
import com.kotlinnlp.simplednn.core.layers.Layer;
import com.kotlinnlp.simplednn.core.layers.LayerParameters;
import com.kotlinnlp.simplednn.core.layers.helpers.ForwardHelper;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import kotlin.Metadata;
import kotlin.NotImplementedError;
import kotlin.TypeCastException;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;

/* compiled from: LSTMForwardHelper.kt */
@Metadata(mv = {1, 1, 13}, bv = {1, 0, 3}, k = 1, d1 = {"��0\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0010\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0003\u0018��*\u000e\b��\u0010\u0001*\b\u0012\u0004\u0012\u0002H\u00010\u00022\b\u0012\u0004\u0012\u0002H\u00010\u0003B\u0013\u0012\f\u0010\u0004\u001a\b\u0012\u0004\u0012\u00028��0\u0005¢\u0006\u0002\u0010\u0006J\b\u0010\t\u001a\u00020\nH\u0002J\u0014\u0010\u000b\u001a\u00020\n2\n\u0010\f\u001a\u0006\u0012\u0002\b\u00030\rH\u0002J\b\u0010\u000e\u001a\u00020\nH\u0016J\u0014\u0010\u000e\u001a\u00020\n2\n\u0010\u000f\u001a\u0006\u0012\u0002\b\u00030\u0010H\u0016J\b\u0010\u0011\u001a\u00020\nH\u0002J\u0016\u0010\u0012\u001a\u00020\n2\f\u0010\f\u001a\b\u0012\u0002\b\u0003\u0018\u00010\rH\u0002R\u001a\u0010\u0004\u001a\b\u0012\u0004\u0012\u00028��0\u0005X\u0094\u0004¢\u0006\b\n��\u001a\u0004\b\u0007\u0010\b¨\u0006\u0013"}, d2 = {"Lcom/kotlinnlp/simplednn/core/layers/models/recurrent/lstm/LSTMForwardHelper;", "InputNDArrayType", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/NDArray;", "Lcom/kotlinnlp/simplednn/core/layers/helpers/ForwardHelper;", "layer", "Lcom/kotlinnlp/simplednn/core/layers/models/recurrent/lstm/LSTMLayer;", "(Lcom/kotlinnlp/simplednn/core/layers/models/recurrent/lstm/LSTMLayer;)V", "getLayer", "()Lcom/kotlinnlp/simplednn/core/layers/models/recurrent/lstm/LSTMLayer;", "activateGates", "", "addGatesRecurrentContribution", "prevStateLayer", "Lcom/kotlinnlp/simplednn/core/layers/Layer;", "forward", "layerContributions", "Lcom/kotlinnlp/simplednn/core/layers/LayerParameters;", "forwardGates", "setGates", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/core/layers/models/recurrent/lstm/LSTMForwardHelper.class */
/*
 * Forward helper for an LSTMLayer. It computes the four gates (input, output, forget, candidate),
 * then the cell, and finally the layer output for the current input. When the layer context window
 * provides a previous state, that state's recurrent contribution is added.
 */
public final class LSTMForwardHelper<InputNDArrayType extends NDArray<InputNDArrayType>> extends ForwardHelper<InputNDArrayType> {

    /** The LSTM layer this helper runs the forward pass for. */
    @NotNull
    private final LSTMLayer<InputNDArrayType> layer;

    /**
     * Creates a forward helper bound to the given LSTM layer.
     *
     * @param layer the layer to forward; must not be null
     */
    public LSTMForwardHelper(@NotNull LSTMLayer<InputNDArrayType> layer) {
        super(layer);
        Intrinsics.checkParameterIsNotNull(layer, "layer");
        this.layer = layer;
    }

    /**
     * Runs the forward pass of the layer.
     *
     * <p>The gates and the cell are set first, because the output needs their values after
     * activation. The output array is then assigned the element-wise product of the output gate
     * values and the cell values.
     */
    @Override
    public void forward() {
        LSTMLayer<InputNDArrayType> lstm = getLayer();

        // Must run before reading the gate/cell values used below.
        setGates(lstm.getLayerContextWindow().getPrevState());

        lstm.getOutputArray().getValues().assignProd(
                (DenseNDArray) lstm.getOutputGate().getValues(),
                lstm.getCell().getValues());
    }

    /**
     * Forward pass that saves the contributions of the parameters. Not supported by this helper.
     *
     * @param layerParameters the parameters whose contributions would be stored
     * @throws NotImplementedError always
     */
    @Override
    public void forward(@NotNull LayerParameters<?> layerParameters) {
        Intrinsics.checkParameterIsNotNull(layerParameters, "layerContributions");
        throw new NotImplementedError("An operation is not implemented: not implemented");
    }

    /**
     * Computes the gates and the cell of the layer.
     *
     * <p>The cell values are set to {@code inputGate * candidate}. If a previous state exists,
     * {@code forgetGate * prevCell} is added, where {@code prevCell} is the previous state's cell
     * taken before activation. The cell is activated at the end.
     *
     * @param prevStateLayer the layer of the previous state, or {@code null} when there is none
     *                       (e.g. the first step of a sequence); when not null it is expected to be an
     *                       {@link LSTMLayer}, otherwise a {@link ClassCastException} is thrown
     */
    private void setGates(Layer<?> prevStateLayer) {
        forwardGates();
        if (prevStateLayer != null) {
            addGatesRecurrentContribution(prevStateLayer);
        }
        activateGates();

        LSTMLayer<InputNDArrayType> lstm = getLayer();
        DenseNDArray cell = lstm.getCell().getValues();

        cell.assignProd(
                (DenseNDArray) lstm.getInputGate().getValues(),
                (DenseNDArray) lstm.getCandidate().getValues());

        if (prevStateLayer != null) {
            NDArray<?> prevCell = (NDArray<?>) ((LSTMLayer<?>) prevStateLayer).getCell().getValuesNotActivated();
            cell.assignSum((NDArray<?>) ((DenseNDArray) lstm.getForgetGate().getValues()).prod(prevCell));
        }

        lstm.getCell().activate();
    }

    /**
     * Forwards the current input array through the weights and biases of each gate.
     *
     * @throws TypeCastException if the layer has no parameters
     * @throws ClassCastException if the layer parameters are not {@link LSTMLayerParameters}
     */
    private void forwardGates() {
        LSTMLayerParameters params = lstmParams();
        LSTMLayer<InputNDArrayType> lstm = getLayer();
        InputNDArrayType input = lstm.getInputArray().getValues();

        ArrayExtensionsKt.forward(lstm.getInputGate(),
                params.getInputGate().getWeights().getValues(),
                params.getInputGate().getBiases().getValues(),
                input);
        ArrayExtensionsKt.forward(lstm.getOutputGate(),
                params.getOutputGate().getWeights().getValues(),
                params.getOutputGate().getBiases().getValues(),
                input);
        ArrayExtensionsKt.forward(lstm.getForgetGate(),
                params.getForgetGate().getWeights().getValues(),
                params.getForgetGate().getBiases().getValues(),
                input);
        ArrayExtensionsKt.forward(lstm.getCandidate(),
                params.getCandidate().getWeights().getValues(),
                params.getCandidate().getBiases().getValues(),
                input);
    }

    /**
     * Adds to each gate the recurrent contribution of the previous state's output.
     *
     * @param prevStateLayer the layer of the previous state (not null)
     * @throws TypeCastException if the layer has no parameters
     * @throws ClassCastException if the layer parameters are not {@link LSTMLayerParameters}
     */
    private void addGatesRecurrentContribution(Layer<?> prevStateLayer) {
        LSTMLayerParameters params = lstmParams();
        DenseNDArray prevOutput = prevStateLayer.getOutputArray().getValues();
        LSTMLayer<InputNDArrayType> lstm = getLayer();

        lstm.getInputGate().addRecurrentContribution(params.getInputGate(), prevOutput);
        lstm.getOutputGate().addRecurrentContribution(params.getOutputGate(), prevOutput);
        lstm.getForgetGate().addRecurrentContribution(params.getForgetGate(), prevOutput);
        lstm.getCandidate().addRecurrentContribution(params.getCandidate(), prevOutput);
    }

    /** Applies the activation of each of the four gates. */
    private void activateGates() {
        LSTMLayer<InputNDArrayType> lstm = getLayer();
        lstm.getInputGate().activate();
        lstm.getOutputGate().activate();
        lstm.getForgetGate().activate();
        lstm.getCandidate().activate();
    }

    /**
     * Returns the layer parameters cast once to {@link LSTMLayerParameters}. This replaces the
     * repeated unchecked casts of {@code getLayer().getParams()}.
     *
     * @return the LSTM parameters of the layer
     * @throws TypeCastException if the layer has no parameters
     * @throws ClassCastException if the parameters are not {@link LSTMLayerParameters}
     */
    @NotNull
    private LSTMLayerParameters lstmParams() {
        LayerParameters<?> params = getLayer().getParams();
        if (params == null) {
            throw new TypeCastException("null cannot be cast to non-null type com.kotlinnlp.simplednn.core.layers.models.recurrent.lstm.LSTMLayerParameters");
        }
        return (LSTMLayerParameters) params;
    }

    /**
     * Returns the LSTM layer this helper forwards.
     *
     * @return the layer given at construction
     */
    @Override
    @NotNull
    public LSTMLayer<InputNDArrayType> getLayer() {
        return this.layer;
    }
}
