package com.kotlinnlp.simplednn.core.neuralnetwork.structure.feedforward;

import com.kotlinnlp.simplednn.core.arrays.AugmentedArray;
import com.kotlinnlp.simplednn.core.arrays.DistributionArray;
import com.kotlinnlp.simplednn.core.functionalities.activations.ActivationFunction;
import com.kotlinnlp.simplednn.core.layers.LayerConfiguration;
import com.kotlinnlp.simplednn.core.layers.LayerParameters;
import com.kotlinnlp.simplednn.core.layers.LayerStructure;
import com.kotlinnlp.simplednn.core.layers.LayerStructureFactory;
import com.kotlinnlp.simplednn.core.layers.LayerType;
import com.kotlinnlp.simplednn.core.neuralnetwork.NetworkParameters;
import com.kotlinnlp.simplednn.core.neuralnetwork.structure.NetworkStructure;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import java.util.Arrays;
import java.util.List;
import kotlin.Metadata;
import kotlin.TypeCastException;
import kotlin.collections.ArraysKt;
import kotlin.collections.CollectionsKt;
import kotlin.collections.IndexedValue;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;

/* compiled from: FeedforwardNetworkStructure.kt */
@Metadata(mv = {1, 1, 8}, bv = {1, 0, 2}, k = 1, d1 = {"��H\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010 \n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0006\n��\n\u0002\u0010\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\u0018��*\u000e\b��\u0010\u0001*\b\u0012\u0004\u0012\u0002H\u00010\u00022\b\u0012\u0004\u0012\u0002H\u00010\u0003B\u001b\u0012\f\u0010\u0004\u001a\b\u0012\u0004\u0012\u00020\u00060\u0005\u0012\u0006\u0010\u0007\u001a\u00020\b¢\u0006\u0002\u0010\tJH\u0010\n\u001a\b\u0012\u0004\u0012\u0002H\u00010\u000b\"\u000e\b\u0001\u0010\u0001*\b\u0012\u0004\u0012\u0002H\u00010\u00022\f\u0010\f\u001a\b\u0012\u0004\u0012\u0002H\u00010\r2\u0006\u0010\u000e\u001a\u00020\u00062\n\u0010\u0007\u001a\u0006\u0012\u0002\b\u00030\u000f2\u0006\u0010\u0010\u001a\u00020\u0011H\u0014J\u0016\u0010\u0012\u001a\u00020\u00132\u0006\u0010\u0014\u001a\u00020\b2\u0006\u0010\u0015\u001a\u00020\u0016¨\u0006\u0017"}, d2 = {"Lcom/kotlinnlp/simplednn/core/neuralnetwork/structure/feedforward/FeedforwardNetworkStructure;", "InputNDArrayType", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/NDArray;", "Lcom/kotlinnlp/simplednn/core/neuralnetwork/structure/NetworkStructure;", "layersConfiguration", "", "Lcom/kotlinnlp/simplednn/core/layers/LayerConfiguration;", "params", "Lcom/kotlinnlp/simplednn/core/neuralnetwork/NetworkParameters;", "(Ljava/util/List;Lcom/kotlinnlp/simplednn/core/neuralnetwork/NetworkParameters;)V", "layerFactory", "Lcom/kotlinnlp/simplednn/core/layers/LayerStructure;", "inputArray", "Lcom/kotlinnlp/simplednn/core/arrays/AugmentedArray;", "outputConfiguration", "Lcom/kotlinnlp/simplednn/core/layers/LayerParameters;", "dropout", "", "propagateRelevance", "", "networkContributions", "relevantOutcomesDistribution", "Lcom/kotlinnlp/simplednn/core/arrays/DistributionArray;", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/core/neuralnetwork/structure/feedforward/FeedforwardNetworkStructure.class */
/**
 * Network structure whose layers may only use the feedforward connection type.
 *
 * <p>This class was decompiled from {@code FeedforwardNetworkStructure.kt}. The Kotlin-style
 * non-null parameter checks ({@code Intrinsics.checkParameterIsNotNull}) are kept so that callers
 * observe the same argument validation.
 *
 * @param <InputNDArrayType> the type of the input arrays handled by the network
 */
public final class FeedforwardNetworkStructure<InputNDArrayType extends NDArray<InputNDArrayType>> extends NetworkStructure<InputNDArrayType> {

    /**
     * Builds the structure of one layer. Only feedforward connections are accepted.
     *
     * @param augmentedArray the input array of the layer (Kotlin name {@code inputArray})
     * @param layerConfiguration the layer configuration, which supplies the size, activation function
     *     and connection type (Kotlin name {@code outputConfiguration})
     * @param layerParameters the parameters of the layer (Kotlin name {@code params})
     * @param d forwarded unchanged to the layer factory (Kotlin name {@code dropout})
     * @return the layer structure created by {@code LayerStructureFactory}
     * @throws IllegalArgumentException if a non-null parameter is null, or if the configured
     *     connection type is not {@code LayerType.Connection.Feedforward}
     */
    @Override // com.kotlinnlp.simplednn.core.neuralnetwork.structure.NetworkStructure
    @NotNull
    protected <InputNDArrayType extends NDArray<InputNDArrayType>> LayerStructure<InputNDArrayType> layerFactory(@NotNull AugmentedArray<InputNDArrayType> augmentedArray, @NotNull LayerConfiguration layerConfiguration, @NotNull LayerParameters<?> layerParameters, double d) {
        Intrinsics.checkParameterIsNotNull(augmentedArray, "inputArray");
        Intrinsics.checkParameterIsNotNull(layerConfiguration, "outputConfiguration");
        Intrinsics.checkParameterIsNotNull(layerParameters, "params");

        // Read once: the same value is validated here and then forwarded to the factory.
        LayerType.Connection connectionType = layerConfiguration.getConnectionType();
        if (!Intrinsics.areEqual(connectionType, LayerType.Connection.Feedforward)) {
            throw new IllegalArgumentException(String.format(
                    "Layer connection of type %s not allowed [only %s]",
                    connectionType, LayerType.Connection.Feedforward));
        }

        // Past the check above, connectionType equals Feedforward and is therefore non-null,
        // so no further null guard is needed before passing it on.
        // NOTE(review): the trailing (null, 64, null) arguments look like Kotlin's synthetic
        // `$default` call convention, where bit 6 (value 64) presumably selects the default for the
        // 7th factory argument. Confirm against LayerStructureFactory before changing them.
        return LayerStructureFactory.invoke$default(
                LayerStructureFactory.INSTANCE,
                augmentedArray,
                layerConfiguration.getSize(),
                layerParameters,
                layerConfiguration.getActivationFunction(),
                connectionType,
                d,
                null,
                64,
                null);
    }

    /**
     * Propagates relevance backwards through the layers, from the last layer to the first.
     *
     * <p>The last layer receives {@code distributionArray} as its output relevance. The layers are
     * then visited in reverse index order. Each visited layer is made current via
     * {@code setCurLayerIndex}, and its input relevance is set from the entry at the same index of
     * {@code networkParameters.getParamsPerLayer()}.
     *
     * @param networkParameters the per-layer contributions (Kotlin name {@code networkContributions})
     * @param distributionArray the distribution of the relevant outcomes (Kotlin name
     *     {@code relevantOutcomesDistribution})
     * @throws IllegalArgumentException if either argument is null
     * @throws TypeCastException if a layer in the array is null
     */
    public final void propagateRelevance(@NotNull NetworkParameters networkParameters, @NotNull DistributionArray distributionArray) {
        Intrinsics.checkParameterIsNotNull(networkParameters, "networkContributions");
        Intrinsics.checkParameterIsNotNull(distributionArray, "relevantOutcomesDistribution");

        Object[] layers = getLayers();
        // ArraysKt.last throws on an empty array, before any layer is touched.
        ((LayerStructure) ArraysKt.last(layers)).setOutputRelevance(distributionArray);

        for (int i = layers.length - 1; i >= 0; i--) {
            LayerStructure layer = (LayerStructure) layers[i];
            if (layer == null) {
                throw new TypeCastException("null cannot be cast to non-null type com.kotlinnlp.simplednn.core.layers.feedforward.FeedforwardLayerStructure<out com.kotlinnlp.simplednn.simplemath.ndarray.NDArray<*>>");
            }
            setCurLayerIndex(i);
            layer.setInputRelevance(networkParameters.getParamsPerLayer()[i]);
        }
    }

    /**
     * Creates a feedforward network structure.
     *
     * @param list the configuration of each layer (Kotlin name {@code layersConfiguration})
     * @param networkParameters the network parameters (Kotlin name {@code params})
     * @throws IllegalArgumentException if either argument is null
     */
    public FeedforwardNetworkStructure(@NotNull List<LayerConfiguration> list, @NotNull NetworkParameters networkParameters) {
        // Java forbids statements before super(...), so the null checks run inside the argument
        // list. A null argument is then rejected before it reaches NetworkStructure's constructor.
        super(requireParameter(list, "layersConfiguration"), requireParameter(networkParameters, "params"));
    }

    /**
     * Applies the Kotlin non-null parameter check to {@code value} and returns it unchanged, so the
     * check can be evaluated inside an explicit constructor invocation.
     */
    private static <T> T requireParameter(T value, String name) {
        Intrinsics.checkParameterIsNotNull(value, name);
        return value;
    }
}
