package com.kotlinnlp.simplednn.core.layers;

import com.kotlinnlp.simplednn.core.arrays.AugmentedArray;
import com.kotlinnlp.simplednn.core.functionalities.activations.ActivationFunction;
import com.kotlinnlp.simplednn.core.layers.LayerType;
import com.kotlinnlp.simplednn.core.layers.models.LayerUnit;
import com.kotlinnlp.simplednn.core.layers.models.feedforward.highway.HighwayLayerStructure;
import com.kotlinnlp.simplednn.core.layers.models.feedforward.simple.FeedforwardLayerStructure;
import com.kotlinnlp.simplednn.core.layers.models.merge.affine.AffineLayerParameters;
import com.kotlinnlp.simplednn.core.layers.models.merge.affine.AffineLayerStructure;
import com.kotlinnlp.simplednn.core.layers.models.merge.avg.AvgLayerParameters;
import com.kotlinnlp.simplednn.core.layers.models.merge.avg.AvgLayerStructure;
import com.kotlinnlp.simplednn.core.layers.models.merge.biaffine.BiaffineLayerParameters;
import com.kotlinnlp.simplednn.core.layers.models.merge.biaffine.BiaffineLayerStructure;
import com.kotlinnlp.simplednn.core.layers.models.merge.concat.ConcatLayerParameters;
import com.kotlinnlp.simplednn.core.layers.models.merge.concat.ConcatLayerStructure;
import com.kotlinnlp.simplednn.core.layers.models.merge.product.ProductLayerParameters;
import com.kotlinnlp.simplednn.core.layers.models.merge.product.ProductLayerStructure;
import com.kotlinnlp.simplednn.core.layers.models.merge.sum.SumLayerParameters;
import com.kotlinnlp.simplednn.core.layers.models.merge.sum.SumLayerStructure;
import com.kotlinnlp.simplednn.core.layers.models.recurrent.LayerContextWindow;
import com.kotlinnlp.simplednn.core.layers.models.recurrent.RecurrentLayerUnit;
import com.kotlinnlp.simplednn.core.layers.models.recurrent.cfn.CFNLayerStructure;
import com.kotlinnlp.simplednn.core.layers.models.recurrent.deltarnn.DeltaRNNLayerStructure;
import com.kotlinnlp.simplednn.core.layers.models.recurrent.gru.GRULayerStructure;
import com.kotlinnlp.simplednn.core.layers.models.recurrent.indrnn.IndRNNLayerStructure;
import com.kotlinnlp.simplednn.core.layers.models.recurrent.lstm.LSTMLayerStructure;
import com.kotlinnlp.simplednn.core.layers.models.recurrent.ran.RANLayerStructure;
import com.kotlinnlp.simplednn.core.layers.models.recurrent.simple.SimpleRecurrentLayerStructure;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.Shape;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArrayFactory;
import java.util.List;
import kotlin.Metadata;
import kotlin.NoWhenBranchMatchedException;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: LayerStructureFactory.kt */
@Metadata(mv = {1, 1, 10}, bv = {1, 0, 2}, k = 1, d1 = {"��F\n\u0002\u0018\u0002\n\u0002\u0010��\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010 \n\u0002\u0018\u0002\n��\n\u0002\u0010\b\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0006\n��\n\u0002\u0018\u0002\n��\bÆ\u0002\u0018��2\u00020\u0001B\u0007\b\u0002¢\u0006\u0002\u0010\u0002Jq\u0010\u0003\u001a\b\u0012\u0004\u0012\u0002H\u00050\u0004\"\u000e\b��\u0010\u0005*\b\u0012\u0004\u0012\u0002H\u00050\u00062\u0012\u0010\u0007\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u0002H\u00050\t0\b2\u0006\u0010\n\u001a\u00020\u000b2\n\u0010\f\u001a\u0006\u0012\u0002\b\u00030\r2\u0006\u0010\u000e\u001a\u00020\u000f2\n\b\u0002\u0010\u0010\u001a\u0004\u0018\u00010\u00112\b\b\u0002\u0010\u0012\u001a\u00020\u00132\n\b\u0002\u0010\u0014\u001a\u0004\u0018\u00010\u0015H\u0086\u0002¨\u0006\u0016"}, d2 = {"Lcom/kotlinnlp/simplednn/core/layers/LayerStructureFactory;", "", "()V", "invoke", "Lcom/kotlinnlp/simplednn/core/layers/LayerStructure;", "InputNDArrayType", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/NDArray;", "inputArrays", "", "Lcom/kotlinnlp/simplednn/core/arrays/AugmentedArray;", "outputSize", "", "params", "Lcom/kotlinnlp/simplednn/core/layers/LayerParameters;", "connectionType", "Lcom/kotlinnlp/simplednn/core/layers/LayerType$Connection;", "activationFunction", "Lcom/kotlinnlp/simplednn/core/functionalities/activations/ActivationFunction;", "dropout", "", "contextWindow", "Lcom/kotlinnlp/simplednn/core/layers/models/recurrent/LayerContextWindow;", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/core/layers/LayerStructureFactory.class */
/*
 * Singleton factory (decompiled from the Kotlin source LayerStructureFactory.kt) that builds the
 * LayerStructure implementation matching a given LayerType.Connection value.
 */
public final class LayerStructureFactory {
    /** The single shared instance; the constructor is private, so no other instance is created here. */
    public static final LayerStructureFactory INSTANCE = new LayerStructureFactory();

    private LayerStructureFactory() {
    }

    /**
     * Creates the layer structure that corresponds to {@code connectionType}.
     *
     * <p>How the inputs are consumed depends on the connection type:
     * <ul>
     *   <li>Feedforward, Highway and every recurrent type use only the first element of
     *       {@code inputArrays};</li>
     *   <li>Biaffine uses the elements at index 0 and 1;</li>
     *   <li>Affine, Concat, Sum, Avg and Product receive the whole list.</li>
     * </ul>
     * Concat, Sum, Avg and Product are built without {@code activationFunction} and {@code d}.
     *
     * @param inputArrays the input arrays of the layer
     * @param i the size used for the output unit / output array (named {@code outputSize} in the
     *     Kotlin metadata of this class)
     * @param params the layer parameters; cast to the specific parameters class for the merge types
     * @param connectionType selects which structure class is instantiated
     * @param activationFunction forwarded unchanged to the created structure, may be null
     * @param d forwarded unchanged to the created structure (named {@code dropout} in the Kotlin
     *     metadata of this class)
     * @param layerContextWindow required (non-null) by the recurrent connection types, unused by
     *     the others
     * @return the newly created layer structure
     * @throws NoWhenBranchMatchedException if {@code connectionType} matches none of the cases
     */
    @NotNull
    public final <InputNDArrayType extends NDArray<InputNDArrayType>> LayerStructure<InputNDArrayType> invoke(@NotNull List<? extends AugmentedArray<InputNDArrayType>> inputArrays, int i, @NotNull LayerParameters<?> params, @NotNull LayerType.Connection connectionType, @Nullable ActivationFunction activationFunction, double d, @Nullable LayerContextWindow layerContextWindow) {
        Intrinsics.checkParameterIsNotNull(inputArrays, "inputArrays");
        Intrinsics.checkParameterIsNotNull(params, "params");
        Intrinsics.checkParameterIsNotNull(connectionType, "connectionType");

        // NOTE(review): the trailing "0, <mask>, null" constructor arguments look like Kotlin's
        // synthetic default-argument constructors (last declared parameter left at its default);
        // confirm against the layer structure classes.
        switch (connectionType) {
            // ---- single-input, non-recurrent layers ----
            case Feedforward: {
                return new FeedforwardLayerStructure(firstInput(inputArrays), new LayerUnit(i), params, activationFunction, d, 0, 32, null);
            }
            case Highway: {
                return new HighwayLayerStructure(firstInput(inputArrays), emptyOutputArray(i), params, activationFunction, d, 0, 32, null);
            }

            // ---- merge layers ----
            case Affine: {
                return new AffineLayerStructure(inputArrays, emptyOutputArray(i), (AffineLayerParameters) params, activationFunction, d, 0, 32, null);
            }
            case Biaffine: {
                return new BiaffineLayerStructure(inputArrays.get(0), inputArrays.get(1), emptyOutputArray(i), (BiaffineLayerParameters) params, activationFunction, d, 0, 64, null);
            }
            case Concat: {
                return new ConcatLayerStructure(inputArrays, emptyOutputArray(i), (ConcatLayerParameters) params, 0, 8, null);
            }
            case Sum: {
                return new SumLayerStructure(inputArrays, emptyOutputArray(i), (SumLayerParameters) params, 0, 8, null);
            }
            case Avg: {
                return new AvgLayerStructure(inputArrays, emptyOutputArray(i), (AvgLayerParameters) params, 0, 8, null);
            }
            case Product: {
                return new ProductLayerStructure(inputArrays, emptyOutputArray(i), (ProductLayerParameters) params, 0, 8, null);
            }

            // ---- recurrent layers ----
            // In every case below the context window is null-checked only after the input and the
            // output unit/array have been obtained, immediately before the structure is constructed.
            case SimpleRecurrent: {
                AugmentedArray input = firstInput(inputArrays);
                RecurrentLayerUnit unit = new RecurrentLayerUnit(i);
                LayerContextWindow window = requireContextWindow(layerContextWindow);
                return new SimpleRecurrentLayerStructure(input, unit, params, window, activationFunction, d);
            }
            case GRU: {
                AugmentedArray input = firstInput(inputArrays);
                AugmentedArray output = emptyOutputArray(i);
                LayerContextWindow window = requireContextWindow(layerContextWindow);
                return new GRULayerStructure(input, output, params, window, activationFunction, d);
            }
            case LSTM: {
                AugmentedArray input = firstInput(inputArrays);
                AugmentedArray output = emptyOutputArray(i);
                LayerContextWindow window = requireContextWindow(layerContextWindow);
                return new LSTMLayerStructure(input, output, params, window, activationFunction, d);
            }
            case CFN: {
                AugmentedArray input = firstInput(inputArrays);
                AugmentedArray output = emptyOutputArray(i);
                LayerContextWindow window = requireContextWindow(layerContextWindow);
                return new CFNLayerStructure(input, output, params, window, activationFunction, d);
            }
            case RAN: {
                AugmentedArray input = firstInput(inputArrays);
                AugmentedArray output = emptyOutputArray(i);
                LayerContextWindow window = requireContextWindow(layerContextWindow);
                return new RANLayerStructure(input, output, params, window, activationFunction, d);
            }
            case DeltaRNN: {
                AugmentedArray input = firstInput(inputArrays);
                AugmentedArray output = emptyOutputArray(i);
                LayerContextWindow window = requireContextWindow(layerContextWindow);
                return new DeltaRNNLayerStructure(input, output, params, window, activationFunction, d);
            }
            case IndRNN: {
                AugmentedArray input = firstInput(inputArrays);
                LayerUnit unit = new LayerUnit(i);
                LayerContextWindow window = requireContextWindow(layerContextWindow);
                return new IndRNNLayerStructure(input, unit, params, window, activationFunction, d);
            }

            default:
                throw new NoWhenBranchMatchedException();
        }
    }

    /**
     * Static bridge that fills in omitted optional arguments before delegating to
     * {@link #invoke}.
     *
     * <p>{@code i2} is a bit mask: when bit 16 is set {@code activationFunction} is replaced by
     * null, when bit 32 is set {@code d} is replaced by {@code 0.0}, and when bit 64 is set
     * {@code layerContextWindow} is replaced by null. {@code obj} is not read.
     */
    @NotNull
    public static /* bridge */ /* synthetic */ LayerStructure invoke$default(LayerStructureFactory layerStructureFactory, List list, int i, LayerParameters layerParameters, LayerType.Connection connection, ActivationFunction activationFunction, double d, LayerContextWindow layerContextWindow, int i2, Object obj) {
        ActivationFunction resolvedActivation = (i2 & 16) != 0 ? null : activationFunction;
        double resolvedD = (i2 & 32) != 0 ? 0.0d : d;
        LayerContextWindow resolvedWindow = (i2 & 64) != 0 ? null : layerContextWindow;
        return layerStructureFactory.invoke(list, i, layerParameters, connection, resolvedActivation, resolvedD, resolvedWindow);
    }

    /** Returns the first element of {@code inputs}, obtained through {@code CollectionsKt.first}. */
    private static AugmentedArray firstInput(List inputs) {
        return (AugmentedArray) CollectionsKt.first(inputs);
    }

    /**
     * Wraps a new empty dense array, whose shape is built from {@code size}, into an AugmentedArray.
     */
    private static AugmentedArray emptyOutputArray(int size) {
        // NOTE(review): Shape(size, 0, 2, null) looks like the synthetic default-argument
        // constructor with the second dimension left at its default; confirm against Shape.
        return AugmentedArray.Companion.invoke(DenseNDArrayFactory.INSTANCE.emptyArray(new Shape(size, 0, 2, null)));
    }

    /**
     * Returns {@code window} unchanged, calling {@code Intrinsics.throwNpe()} first when it is null.
     */
    private static LayerContextWindow requireContextWindow(LayerContextWindow window) {
        if (window == null) {
            Intrinsics.throwNpe();
        }
        return window;
    }
}
