package com.kotlinnlp.simplednn.core.neuralnetwork.structure;

import com.kotlinnlp.simplednn.core.arrays.AugmentedArray;
import com.kotlinnlp.simplednn.core.layers.LayerConfiguration;
import com.kotlinnlp.simplednn.core.layers.LayerParameters;
import com.kotlinnlp.simplednn.core.layers.LayerStructure;
import com.kotlinnlp.simplednn.core.layers.LayerType;
import com.kotlinnlp.simplednn.core.neuralnetwork.NetworkParameters;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import kotlin.Metadata;
import kotlin.NoWhenBranchMatchedException;
import kotlin.TypeCastException;
import kotlin.collections.CollectionsKt;
import kotlin.collections.IndexedValue;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: NetworkStructure.kt */
@Metadata(mv = {1, 1, 8}, bv = {1, 0, 2}, k = 1, d1 = {"��^\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0010 \n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\b\n\u0002\b\u0005\n\u0002\u0018\u0002\n\u0002\b\t\n\u0002\u0018\u0002\n\u0002\b\u0005\n\u0002\u0010\u0002\n\u0002\b\u0003\n\u0002\u0010\u000b\n��\n\u0002\u0010\u0006\n\u0002\b\t\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0003\b&\u0018��*\u000e\b��\u0010\u0001*\b\u0012\u0004\u0012\u0002H\u00010\u00022\u00020\u0003B\u001b\u0012\f\u0010\u0004\u001a\b\u0012\u0004\u0012\u00020\u00060\u0005\u0012\u0006\u0010\u0007\u001a\u00020\b¢\u0006\u0002\u0010\tJ2\u0010 \u001a\u00020!2\u0006\u0010\"\u001a\u00020\u001b2\u0006\u0010#\u001a\u00020\b2\b\b\u0002\u0010$\u001a\u00020%2\u0010\u0010&\u001a\f\u0012\u0006\u0012\u0004\u0018\u00010'\u0018\u00010\u0005J\u0012\u0010(\u001a\f\u0012\b\u0012\u0006\u0012\u0002\b\u00030\u00110\u0005H\u0004J%\u0010)\u001a\u00020\u001b2\u0006\u0010*\u001a\u00028��2\u0006\u0010+\u001a\u00020\b2\b\b\u0002\u0010,\u001a\u00020%¢\u0006\u0002\u0010-J\u001d\u0010)\u001a\u00020\u001b2\u0006\u0010*\u001a\u00028��2\b\b\u0002\u0010,\u001a\u00020%¢\u0006\u0002\u0010.JH\u0010/\u001a\b\u0012\u0004\u0012\u0002H\u00010\u0011\"\u000e\b\u0001\u0010\u0001*\b\u0012\u0004\u0012\u0002H\u00010\u00022\f\u00100\u001a\b\u0012\u0004\u0012\u0002H\u0001012\u0006\u00102\u001a\u00020\u00062\n\u0010\u0007\u001a\u0006\u0012\u0002\b\u0003032\u0006\u00104\u001a\u00020'H$J(\u0010/\u001a\u0006\u0012\u0002\b\u00030\u00112\u0006\u00105\u001a\u00020\u00062\u0006\u00102\u001a\u00020\u00062\n\u0010\u0007\u001a\u0006\u0012\u0002\b\u000303H\u0002R\u001a\u0010\n\u001a\u00020\u000bX\u0086\u000e¢\u0006\u000e\n��\u001a\u0004\b\f\u0010\r\"\u0004\b\u000e\u0010\u000fR 
\u0010\u0010\u001a\b\u0012\u0004\u0012\u00028��0\u00118FX\u0087\u0004¢\u0006\f\u0012\u0004\b\u0012\u0010\u0013\u001a\u0004\b\u0014\u0010\u0015R\u001b\u0010\u0016\u001a\f\u0012\b\u0012\u0006\u0012\u0002\b\u00030\u00110\u0005¢\u0006\b\n��\u001a\u0004\b\u0017\u0010\u0018R\u0017\u0010\u0004\u001a\b\u0012\u0004\u0012\u00020\u00060\u0005¢\u0006\b\n��\u001a\u0004\b\u0019\u0010\u0018R \u0010\u001a\u001a\b\u0012\u0004\u0012\u00020\u001b0\u00118FX\u0087\u0004¢\u0006\f\u0012\u0004\b\u001c\u0010\u0013\u001a\u0004\b\u001d\u0010\u0015R\u0011\u0010\u0007\u001a\u00020\b¢\u0006\b\n��\u001a\u0004\b\u001e\u0010\u001f¨\u00066"}, d2 = {"Lcom/kotlinnlp/simplednn/core/neuralnetwork/structure/NetworkStructure;", "InputNDArrayType", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/NDArray;", "", "layersConfiguration", "", "Lcom/kotlinnlp/simplednn/core/layers/LayerConfiguration;", "params", "Lcom/kotlinnlp/simplednn/core/neuralnetwork/NetworkParameters;", "(Ljava/util/List;Lcom/kotlinnlp/simplednn/core/neuralnetwork/NetworkParameters;)V", "curLayerIndex", "", "getCurLayerIndex", "()I", "setCurLayerIndex", "(I)V", "inputLayer", "Lcom/kotlinnlp/simplednn/core/layers/LayerStructure;", "inputLayer$annotations", "()V", "getInputLayer", "()Lcom/kotlinnlp/simplednn/core/layers/LayerStructure;", "layers", "getLayers", "()Ljava/util/List;", "getLayersConfiguration", "outputLayer", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "outputLayer$annotations", "getOutputLayer", "getParams", "()Lcom/kotlinnlp/simplednn/core/neuralnetwork/NetworkParameters;", "backward", "", "outputErrors", "paramsErrors", "propagateToInput", "", "mePropK", "", "buildLayers", "forward", "features", "networkContributions", "useDropout", "(Lcom/kotlinnlp/simplednn/simplemath/ndarray/NDArray;Lcom/kotlinnlp/simplednn/core/neuralnetwork/NetworkParameters;Z)Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", 
"(Lcom/kotlinnlp/simplednn/simplemath/ndarray/NDArray;Z)Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "layerFactory", "inputArray", "Lcom/kotlinnlp/simplednn/core/arrays/AugmentedArray;", "outputConfiguration", "Lcom/kotlinnlp/simplednn/core/layers/LayerParameters;", "dropout", "inputConfiguration", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/core/neuralnetwork/structure/NetworkStructure.class */
/**
 * Base class for a network made of a sequential chain of layers.
 *
 * <p>The chain is described by {@code layersConfiguration}: the element at index 0 describes the
 * input of the first layer and every following element describes the output of one layer, so the
 * network holds {@code layersConfiguration.size() - 1} layer structures. Every layer after the
 * first one is fed the output array of the previous layer.
 *
 * <p>Concrete subclasses choose which {@link LayerStructure} is created by implementing
 * {@link #layerFactory(AugmentedArray, LayerConfiguration, LayerParameters, double)}. The layers are
 * built from this class's constructor, so that factory runs before any subclass field has been
 * initialized.
 *
 * @param <InputNDArrayType> the type of the array fed to the input layer
 */
public abstract class NetworkStructure<InputNDArrayType extends NDArray<InputNDArrayType>> {

    /** The layer structures, ordered from the input layer to the output layer. */
    @NotNull
    private final List<LayerStructure<?>> layers;

    /** Index of the layer most recently processed by {@code forward} or {@code backward}. */
    private int curLayerIndex;

    /** Index 0 configures the network input; each following element configures a layer output. */
    @NotNull
    private final List<LayerConfiguration> layersConfiguration;

    /** Network parameters; {@code getParamsPerLayer()[i]} is used to build layer {@code i}. */
    @NotNull
    private final NetworkParameters params;

    /**
     * Validates the configuration and builds the layers.
     *
     * @param layersConfiguration the input configuration followed by one configuration per layer;
     *     it must contain at least two elements, and every configuration strictly between the first
     *     and the last one must declare a {@code Dense} input type
     * @param params the network parameters, one entry per layer
     * @throws IllegalArgumentException if a parameter is null, if fewer than two configurations are
     *     given, or if a hidden configuration does not declare a {@code Dense} input type
     */
    public NetworkStructure(@NotNull List<LayerConfiguration> layersConfiguration, @NotNull NetworkParameters params) {
        Intrinsics.checkParameterIsNotNull(layersConfiguration, "layersConfiguration");
        Intrinsics.checkParameterIsNotNull(params, "params");
        // Without an input and at least one output configuration there is no layer to build; fail
        // with a clear message instead of an obscure ArrayList/subList IllegalArgumentException.
        if (layersConfiguration.size() < 2) {
            throw new IllegalArgumentException(
                    "At least two layer configurations are required (input and output), found "
                            + layersConfiguration.size() + ".");
        }
        // Configurations strictly between the first and the last must be Dense (presumably because
        // those layers are fed the previous layer's output array). Checked before building layers.
        for (LayerConfiguration hidden : layersConfiguration.subList(1, layersConfiguration.size() - 1)) {
            if (hidden.getInputType() != LayerType.Input.Dense) {
                throw new IllegalArgumentException("Failed requirement.");
            }
        }
        this.layersConfiguration = layersConfiguration;
        this.params = params;
        this.layers = buildLayers();
    }

    /**
     * Returns the layer structures, ordered from input to output.
     *
     * @return the internal list of layers (not a copy)
     */
    @NotNull
    public final List<LayerStructure<?>> getLayers() {
        return this.layers;
    }

    /**
     * Returns the index of the layer most recently processed by a forward or backward pass.
     *
     * @return the current layer index
     */
    public final int getCurLayerIndex() {
        return this.curLayerIndex;
    }

    /**
     * Overrides the current layer index.
     *
     * @param curLayerIndex the new index
     */
    public final void setCurLayerIndex(int curLayerIndex) {
        this.curLayerIndex = curLayerIndex;
    }

    /** Synthetic holder emitted by the Kotlin compiler for annotations on {@code inputLayer}; intentionally empty. */
    public static /* synthetic */ void inputLayer$annotations() {
    }

    /**
     * Returns the first layer of the chain.
     *
     * @return the input layer
     * @throws java.util.NoSuchElementException if there are no layers
     * @throws TypeCastException if the first element is null
     */
    @NotNull
    @SuppressWarnings("unchecked")
    public final LayerStructure<InputNDArrayType> getInputLayer() {
        LayerStructure<?> first = CollectionsKt.first(this.layers);
        if (first == null) {
            throw new TypeCastException("null cannot be cast to non-null type com.kotlinnlp.simplednn.core.layers.LayerStructure<InputNDArrayType>");
        }
        // Unchecked: the first layer is built from the network input configuration.
        return (LayerStructure<InputNDArrayType>) first;
    }

    /** Synthetic holder emitted by the Kotlin compiler for annotations on {@code outputLayer}; intentionally empty. */
    public static /* synthetic */ void outputLayer$annotations() {
    }

    /**
     * Returns the last layer of the chain.
     *
     * @return the output layer
     * @throws java.util.NoSuchElementException if there are no layers
     * @throws TypeCastException if the last element is null
     */
    @NotNull
    @SuppressWarnings("unchecked")
    public final LayerStructure<DenseNDArray> getOutputLayer() {
        LayerStructure<?> last = CollectionsKt.last(this.layers);
        if (last == null) {
            throw new TypeCastException("null cannot be cast to non-null type com.kotlinnlp.simplednn.core.layers.LayerStructure<com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray>");
        }
        return (LayerStructure<DenseNDArray>) last;
    }

    /**
     * Runs every layer forward, from input to output, using each layer's own parameters.
     *
     * <p>{@link #getCurLayerIndex()} is updated to each layer's index before it runs.
     *
     * @param features the input fed to the input layer
     * @param useDropout forwarded unchanged to every layer
     * @return the values of the output layer's output array
     */
    @NotNull
    public final DenseNDArray forward(@NotNull InputNDArrayType features, boolean useDropout) {
        Intrinsics.checkParameterIsNotNull(features, "features");
        getInputLayer().setInput(features);
        for (int index = 0; index < this.layers.size(); index++) {
            this.curLayerIndex = index;
            this.layers.get(index).forward(useDropout);
        }
        return getOutputLayer().getOutputArray().getValues();
    }

    /**
     * Kotlin default-argument bridge for {@link #forward(NDArray, boolean)}: when bit {@code 2} of
     * {@code mask} is set, {@code useDropout} defaults to {@code false}.
     *
     * @throws UnsupportedOperationException if {@code marker} is non-null (super call with defaults)
     */
    @NotNull
    @SuppressWarnings({"rawtypes", "unchecked"})
    public static /* bridge */ /* synthetic */ DenseNDArray forward$default(NetworkStructure networkStructure, NDArray features, boolean useDropout, int mask, Object marker) {
        if (marker != null) {
            throw new UnsupportedOperationException("Super calls with default arguments not supported in this target, function: forward");
        }
        boolean dropout = (mask & 2) != 0 ? false : useDropout;
        return networkStructure.forward(features, dropout);
    }

    /**
     * Runs every layer forward using the given per-layer contributions.
     *
     * <p>Layer {@code i} receives {@code networkContributions.getParamsPerLayer()[i]}.
     * {@link #getCurLayerIndex()} is updated to each layer's index before it runs.
     *
     * @param features the input fed to the input layer
     * @param networkContributions the contributions, one entry per layer
     * @param useDropout forwarded unchanged to every layer
     * @return the values of the output layer's output array
     */
    @NotNull
    public final DenseNDArray forward(@NotNull InputNDArrayType features, @NotNull NetworkParameters networkContributions, boolean useDropout) {
        Intrinsics.checkParameterIsNotNull(features, "features");
        Intrinsics.checkParameterIsNotNull(networkContributions, "networkContributions");
        getInputLayer().setInput(features);
        for (int index = 0; index < this.layers.size(); index++) {
            this.curLayerIndex = index;
            this.layers.get(index).forward(networkContributions.getParamsPerLayer()[index], useDropout);
        }
        return getOutputLayer().getOutputArray().getValues();
    }

    /**
     * Kotlin default-argument bridge for {@link #forward(NDArray, NetworkParameters, boolean)}: when
     * bit {@code 4} of {@code mask} is set, {@code useDropout} defaults to {@code false}.
     *
     * @throws UnsupportedOperationException if {@code marker} is non-null (super call with defaults)
     */
    @NotNull
    @SuppressWarnings({"rawtypes", "unchecked"})
    public static /* bridge */ /* synthetic */ DenseNDArray forward$default(NetworkStructure networkStructure, NDArray features, NetworkParameters networkContributions, boolean useDropout, int mask, Object marker) {
        if (marker != null) {
            throw new UnsupportedOperationException("Super calls with default arguments not supported in this target, function: forward");
        }
        boolean dropout = (mask & 4) != 0 ? false : useDropout;
        return networkStructure.forward(features, networkContributions, dropout);
    }

    /**
     * Propagates the output errors backward, from the output layer to the input layer.
     *
     * <p>Layer {@code i} accumulates into {@code paramsErrors.getParamsPerLayer()[i]}. Every layer
     * except the first always propagates to its input; the first one does so only when
     * {@code propagateToInput} is true. {@link #getCurLayerIndex()} is updated to each layer's index
     * before it runs, so it ends at 0.
     *
     * @param outputErrors the errors set on the output layer
     * @param paramsErrors where the per-layer parameter errors are written
     * @param propagateToInput whether the first layer propagates errors to its input
     * @param mePropK optional per-layer factors forwarded to each layer, or {@code null}
     * @throws IllegalArgumentException if {@code mePropK} is non-null and its size differs from the
     *     number of layers
     */
    public final void backward(@NotNull DenseNDArray outputErrors, @NotNull NetworkParameters paramsErrors, boolean propagateToInput, @Nullable List<Double> mePropK) {
        Intrinsics.checkParameterIsNotNull(outputErrors, "outputErrors");
        Intrinsics.checkParameterIsNotNull(paramsErrors, "paramsErrors");
        if (mePropK != null && mePropK.size() != this.layers.size()) {
            throw new IllegalArgumentException("Invalid size of the list of mePropK factors: needed one factor per layer.");
        }
        getOutputLayer().setErrors(outputErrors);
        for (int index = this.layers.size() - 1; index >= 0; index--) {
            this.curLayerIndex = index;
            Double layerMePropK = mePropK != null ? mePropK.get(index) : null;
            this.layers.get(index).backward(paramsErrors.getParamsPerLayer()[index], index > 0 || propagateToInput, layerMePropK);
        }
    }

    /**
     * Kotlin default-argument bridge for {@link #backward(DenseNDArray, NetworkParameters, boolean, List)}:
     * when bit {@code 4} of {@code mask} is set, {@code propagateToInput} defaults to {@code false}.
     *
     * @throws UnsupportedOperationException if {@code marker} is non-null (super call with defaults)
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    public static /* bridge */ /* synthetic */ void backward$default(NetworkStructure networkStructure, DenseNDArray outputErrors, NetworkParameters paramsErrors, boolean propagateToInput, List mePropK, int mask, Object marker) {
        if (marker != null) {
            throw new UnsupportedOperationException("Super calls with default arguments not supported in this target, function: backward");
        }
        boolean propagate = (mask & 4) != 0 ? false : propagateToInput;
        networkStructure.backward(outputErrors, paramsErrors, propagate, mePropK);
    }

    /**
     * Creates the first layer, whose input array is sized from {@code inputConfiguration}.
     *
     * @param inputConfiguration the configuration of the network input
     * @param outputConfiguration the configuration of the first layer's output
     * @param layerParameters the parameters of the first layer
     * @return the new layer
     * @throws IllegalArgumentException if {@code outputConfiguration} has a null connection type
     */
    private LayerStructure<?> layerFactory(LayerConfiguration inputConfiguration, LayerConfiguration outputConfiguration, LayerParameters<?> layerParameters) {
        if (outputConfiguration.getConnectionType() == null) {
            throw new IllegalArgumentException("Output layer configurations must have a not null connectionType");
        }
        switch (inputConfiguration.getInputType()) {
            case Dense:
            case Sparse:
            case SparseBinary:
                // The Kotlin source presumably created a differently-typed AugmentedArray per input
                // type; after type erasure the three branches are identical.
                return layerFactory(new AugmentedArray<>(inputConfiguration.getSize()), outputConfiguration, layerParameters, inputConfiguration.getDropout());
            default:
                throw new NoWhenBranchMatchedException();
        }
    }

    /**
     * Creates a single layer structure; implemented by concrete network types.
     *
     * <p>Called from the constructor (through {@link #buildLayers()}), before subclass fields are
     * initialized.
     *
     * @param inputArray the array the new layer reads its input from
     * @param outputConfiguration the configuration of the new layer's output
     * @param params the parameters of the new layer
     * @param dropout the dropout taken from the configuration that describes this layer's input
     * @param <ArrayType> the type of the layer input array
     * @return the new layer
     */
    @NotNull
    protected abstract <ArrayType extends NDArray<ArrayType>> LayerStructure<ArrayType> layerFactory(@NotNull AugmentedArray<ArrayType> inputArray, @NotNull LayerConfiguration outputConfiguration, @NotNull LayerParameters<?> params, double dropout);

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Builds one layer structure per consecutive pair of configurations.
     *
     * <p>Layer 0 is created from the input configuration (index 0). Every following layer {@code i}
     * is fed the output array of layer {@code i - 1}, uses the dropout of configuration {@code i},
     * and is built with {@code params.getParamsPerLayer()[i]}.
     *
     * @return a new list of {@code layersConfiguration.size() - 1} layers, ordered input to output
     */
    @NotNull
    public final List<LayerStructure<?>> buildLayers() {
        int layersCount = this.layersConfiguration.size() - 1;
        List<LayerStructure<?>> built = new ArrayList<>(layersCount);
        LayerStructure<?> previous = null;
        for (int i = 0; i < layersCount; i++) {
            LayerConfiguration outputConfiguration = this.layersConfiguration.get(i + 1);
            LayerStructure<?> layer;
            if (i == 0) {
                layer = layerFactory(this.layersConfiguration.get(0), outputConfiguration, this.params.getParamsPerLayer()[0]);
            } else {
                layer = layerFactory(previous.getOutputArray(), outputConfiguration, this.params.getParamsPerLayer()[i], this.layersConfiguration.get(i).getDropout());
            }
            built.add(layer);
            previous = layer;
        }
        return built;
    }

    /**
     * Returns the configurations this network was built from.
     *
     * @return the list passed to the constructor (not a copy)
     */
    @NotNull
    public final List<LayerConfiguration> getLayersConfiguration() {
        return this.layersConfiguration;
    }

    /**
     * Returns the network parameters.
     *
     * @return the parameters passed to the constructor
     */
    @NotNull
    public final NetworkParameters getParams() {
        return this.params;
    }
}
