package com.kotlinnlp.simplednn.core.layers.recurrent.simple;

import com.kotlinnlp.simplednn.core.layers.ForwardHelper;
import com.kotlinnlp.simplednn.core.layers.LayerParameters;
import com.kotlinnlp.simplednn.core.layers.LayerStructure;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import kotlin.Metadata;
import kotlin.TypeCastException;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;

/* compiled from: SimpleRecurrentForwardHelper.kt */
@Metadata(mv = {1, 1, 8}, bv = {1, 0, 2}, k = 1, d1 = {"��$\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0010\u0002\n��\n\u0002\u0018\u0002\n��\u0018��*\u000e\b��\u0010\u0001*\b\u0012\u0004\u0012\u0002H\u00010\u00022\b\u0012\u0004\u0012\u0002H\u00010\u0003B\u0013\u0012\f\u0010\u0004\u001a\b\u0012\u0004\u0012\u00028��0\u0005¢\u0006\u0002\u0010\u0006J\b\u0010\t\u001a\u00020\nH\u0016J\u0014\u0010\t\u001a\u00020\n2\n\u0010\u000b\u001a\u0006\u0012\u0002\b\u00030\fH\u0016R\u001a\u0010\u0004\u001a\b\u0012\u0004\u0012\u00028��0\u0005X\u0094\u0004¢\u0006\b\n��\u001a\u0004\b\u0007\u0010\b¨\u0006\r"}, d2 = {"Lcom/kotlinnlp/simplednn/core/layers/recurrent/simple/SimpleRecurrentForwardHelper;", "InputNDArrayType", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/NDArray;", "Lcom/kotlinnlp/simplednn/core/layers/ForwardHelper;", "layer", "Lcom/kotlinnlp/simplednn/core/layers/recurrent/simple/SimpleRecurrentLayerStructure;", "(Lcom/kotlinnlp/simplednn/core/layers/recurrent/simple/SimpleRecurrentLayerStructure;)V", "getLayer", "()Lcom/kotlinnlp/simplednn/core/layers/recurrent/simple/SimpleRecurrentLayerStructure;", "forward", "", "layerContributions", "Lcom/kotlinnlp/simplednn/core/layers/LayerParameters;", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/core/layers/recurrent/simple/SimpleRecurrentForwardHelper.class */
/**
 * Forward helper for a {@link SimpleRecurrentLayerStructure}.
 *
 * <p>Computes the layer output from its input and the layer parameters. When the layer's context
 * window provides a previous-state layer, it also adds that layer's output as a recurrent
 * contribution. Finally it activates the output.
 *
 * @param <InputNDArrayType> the type of the layer input array
 */
public final class SimpleRecurrentForwardHelper<InputNDArrayType extends NDArray<InputNDArrayType>> extends ForwardHelper<InputNDArrayType> {

    /** Message of the {@link TypeCastException} thrown when the layer parameters are null (mirrors Kotlin's {@code as}). */
    private static final String PARAMS_CAST_MESSAGE =
        "null cannot be cast to non-null type com.kotlinnlp.simplednn.core.layers.recurrent.simple.SimpleRecurrentLayerParameters";

    /** Message of the {@link TypeCastException} thrown when a parameter array is null (mirrors Kotlin's {@code as}). */
    private static final String DENSE_CAST_MESSAGE =
        "null cannot be cast to non-null type com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray";

    /** The simple recurrent layer this helper operates on. */
    @NotNull
    private final SimpleRecurrentLayerStructure<InputNDArrayType> layer;

    /**
     * Creates a forward helper for the given layer.
     *
     * @param layer the simple recurrent layer to forward; must not be null
     */
    public SimpleRecurrentForwardHelper(@NotNull SimpleRecurrentLayerStructure<InputNDArrayType> layer) {
        // Validate inside the super(...) argument so the null check runs before the superclass
        // constructor, in the same order as the Kotlin-compiled constructor.
        super(checkedLayer(layer));
        this.layer = layer;
    }

    /**
     * Forwards the layer input to its output.
     *
     * <p>Uses the output array's own {@code forward} and {@code addRecurrentContribution}
     * operations with the layer parameters' unit, then activates the output. The recurrent step
     * runs only when a previous-state layer exists. Presumably this computes
     * {@code y = f(W x + b + Wrec yPrev)}; TODO confirm against the output array implementation.
     *
     * @throws TypeCastException if the layer parameters are null
     */
    @Override // com.kotlinnlp.simplednn.core.layers.ForwardHelper
    public void forward() {
        SimpleRecurrentLayerParameters params = simpleRecurrentParams();

        getLayer().getOutputArray().forward(params.getUnit(), getLayer().getInputArray().getValues());

        LayerStructure<?> prevStateLayer = getLayer().getLayerContextWindow().getPrevStateLayer();
        if (prevStateLayer != null) {
            getLayer().getOutputArray().addRecurrentContribution(
                params.getUnit(), prevStateLayer.getOutputArray().getValues());
        }

        getLayer().getOutputArray().activate();
    }

    /**
     * Forwards the layer input to its output while recording contributions into
     * {@code layerContributions}.
     *
     * <p>Delegates to the inherited {@code forwardArray} and, when a previous-state layer exists,
     * to {@code addRecurrentContribution}, then activates the output.
     *
     * @param layerContributions structure (expected to be a {@link SimpleRecurrentLayerParameters})
     *     receiving the contributions computed during the forward
     * @throws TypeCastException if the layer parameters or one of the parameter arrays read here is null
     */
    @Override // com.kotlinnlp.simplednn.core.layers.ForwardHelper
    public void forward(@NotNull LayerParameters<?> layerContributions) {
        Intrinsics.checkParameterIsNotNull(layerContributions, "layerContributions");
        SimpleRecurrentLayerParameters params = simpleRecurrentParams();
        LayerStructure<?> prevStateLayer = getLayer().getLayerContextWindow().getPrevStateLayer();

        // With a previous state the biases are halved before being passed to both helpers below;
        // presumably the bias is split evenly between the input and recurrent terms (TODO confirm).
        DenseNDArray biases = castToDense(params.getUnit().getBiases().getValues());
        DenseNDArray splitBiases = prevStateLayer != null ? biases.div(2.0d) : biases;

        SimpleRecurrentLayerParameters contributions = (SimpleRecurrentLayerParameters) layerContributions;

        // Arguments are evaluated left to right, in the same order as the original code.
        forwardArray(
            contributions.getUnit().getWeights().getValues(),
            getLayer().getInputArray().getValues(),
            (DenseNDArray) getLayer().getOutputArray().getValues(),
            castToDense(params.getUnit().getWeights().getValues()),
            splitBiases);

        if (prevStateLayer != null) {
            DenseNDArray prevOutput = prevStateLayer.getOutputArray().getValues();
            // NOTE(review): the contributions' bias array is passed here; it looks like a support
            // buffer for the recurrent step. Confirm against ForwardHelper.addRecurrentContribution.
            DenseNDArray contributionBiases = castToDense(contributions.getUnit().getBiases().getValues());
            DenseNDArray output = (DenseNDArray) getLayer().getOutputArray().getValues();
            DenseNDArray recurrentWeights = castToDense(params.getUnit().getRecurrentWeights().getValues());
            DenseNDArray recurrentContributions =
                castToDense(contributions.getUnit().getRecurrentWeights().getValues());

            addRecurrentContribution(
                recurrentContributions, prevOutput, contributionBiases, output, recurrentWeights, splitBiases);
        }

        getLayer().getOutputArray().activate();
    }

    /**
     * Returns the simple recurrent layer this helper operates on.
     *
     * <p>Declared protected in the Kotlin superclass; kept public here so existing callers keep working.
     *
     * @return the layer passed to the constructor
     */
    @Override // com.kotlinnlp.simplednn.core.layers.ForwardHelper
    @NotNull
    public SimpleRecurrentLayerStructure<InputNDArrayType> getLayer() {
        return this.layer;
    }

    /**
     * Casts the layer parameters to {@link SimpleRecurrentLayerParameters}.
     *
     * @throws TypeCastException if the parameters are null
     */
    private SimpleRecurrentLayerParameters simpleRecurrentParams() {
        LayerParameters<?> params = getLayer().getParams();
        if (params == null) {
            throw new TypeCastException(PARAMS_CAST_MESSAGE);
        }
        return (SimpleRecurrentLayerParameters) params;
    }

    /**
     * Casts a parameter array to {@link DenseNDArray}.
     *
     * @throws TypeCastException if {@code values} is null
     */
    private static DenseNDArray castToDense(Object values) {
        if (values == null) {
            throw new TypeCastException(DENSE_CAST_MESSAGE);
        }
        return (DenseNDArray) values;
    }

    /**
     * Null-checks the constructor argument and returns it unchanged.
     *
     * <p>The check is done in a static method so it can run inside the {@code super(...)} call.
     */
    private static <T> T checkedLayer(T layer) {
        Intrinsics.checkParameterIsNotNull(layer, "layer");
        return layer;
    }
}
