package com.kotlinnlp.simplednn.deeplearning.birnn;

import com.kotlinnlp.simplednn.core.functionalities.activations.ActivationFunction;
import com.kotlinnlp.simplednn.core.functionalities.initializers.GlorotInitializer;
import com.kotlinnlp.simplednn.core.functionalities.initializers.Initializer;
import com.kotlinnlp.simplednn.core.layers.LayerInterface;
import com.kotlinnlp.simplednn.core.layers.LayerType;
import com.kotlinnlp.simplednn.core.layers.models.merge.mergeconfig.AffineMerge;
import com.kotlinnlp.simplednn.core.layers.models.merge.mergeconfig.AvgMerge;
import com.kotlinnlp.simplednn.core.layers.models.merge.mergeconfig.BiaffineMerge;
import com.kotlinnlp.simplednn.core.layers.models.merge.mergeconfig.ConcatFeedforwardMerge;
import com.kotlinnlp.simplednn.core.layers.models.merge.mergeconfig.ConcatMerge;
import com.kotlinnlp.simplednn.core.layers.models.merge.mergeconfig.MergeConfiguration;
import com.kotlinnlp.simplednn.core.layers.models.merge.mergeconfig.ProductMerge;
import com.kotlinnlp.simplednn.core.layers.models.merge.mergeconfig.SumMerge;
import com.kotlinnlp.simplednn.core.neuralnetwork.NeuralNetwork;
import com.kotlinnlp.utils.Serializer;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.util.List;
import kotlin.Metadata;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: BiRNN.kt */
@Metadata(mv = {1, 1, 10}, bv = {1, 0, 2}, k = 1, d1 = {"��V\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\b\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0006\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\f\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\b\u000b\n\u0002\u0010\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\u0018�� 02\u00020\u0001:\u00010B[\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\u0006\u0010\u0004\u001a\u00020\u0005\u0012\u0006\u0010\u0006\u001a\u00020\u0005\u0012\b\u0010\u0007\u001a\u0004\u0018\u00010\b\u0012\b\b\u0002\u0010\t\u001a\u00020\n\u0012\u0006\u0010\u000b\u001a\u00020\f\u0012\b\b\u0002\u0010\r\u001a\u00020\u000e\u0012\n\b\u0002\u0010\u000f\u001a\u0004\u0018\u00010\u0010\u0012\n\b\u0002\u0010\u0011\u001a\u0004\u0018\u00010\u0010¢\u0006\u0002\u0010\u0012J\u000e\u0010,\u001a\u00020-2\u0006\u0010.\u001a\u00020/R\u0011\u0010\t\u001a\u00020\n¢\u0006\b\n��\u001a\u0004\b\u0013\u0010\u0014R\u0013\u0010\u0007\u001a\u0004\u0018\u00010\b¢\u0006\b\n��\u001a\u0004\b\u0015\u0010\u0016R\u0011\u0010\u0006\u001a\u00020\u0005¢\u0006\b\n��\u001a\u0004\b\u0017\u0010\u0018R\u0011\u0010\u0004\u001a\u00020\u0005¢\u0006\b\n��\u001a\u0004\b\u0019\u0010\u0018R\u0011\u0010\u0002\u001a\u00020\u0003¢\u0006\b\n��\u001a\u0004\b\u001a\u0010\u001bR\u0011\u0010\u001c\u001a\u00020\u001d¢\u0006\b\n��\u001a\u0004\b\u001e\u0010\u001fR\u0011\u0010 \u001a\u00020!¢\u0006\b\n��\u001a\u0004\b\"\u0010#R\u0011\u0010$\u001a\u00020\u001d¢\u0006\b\n��\u001a\u0004\b%\u0010\u001fR\u0011\u0010&\u001a\u00020\u0005¢\u0006\b\n��\u001a\u0004\b'\u0010\u0018R\u0011\u0010\u000b\u001a\u00020\f¢\u0006\b\n��\u001a\u0004\b(\u0010)R\u0011\u0010*\u001a\u00020\u001d¢\u0006\b\n��\u001a\u0004\b+\u0010\u001f¨\u00061"}, d2 = {"Lcom/kotlinnlp/simplednn/deeplearning/birnn/BiRNN;", "Ljava/io/Serializable;", "inputType", "Lcom/kotlinnlp/simplednn/core/layers/LayerType$Input;", "inputSize", "", "hiddenSize", 
"hiddenActivation", "Lcom/kotlinnlp/simplednn/core/functionalities/activations/ActivationFunction;", "dropout", "", "recurrentConnectionType", "Lcom/kotlinnlp/simplednn/core/layers/LayerType$Connection;", "outputMergeConfiguration", "Lcom/kotlinnlp/simplednn/core/layers/models/merge/mergeconfig/MergeConfiguration;", "weightsInitializer", "Lcom/kotlinnlp/simplednn/core/functionalities/initializers/Initializer;", "biasesInitializer", "(Lcom/kotlinnlp/simplednn/core/layers/LayerType$Input;IILcom/kotlinnlp/simplednn/core/functionalities/activations/ActivationFunction;DLcom/kotlinnlp/simplednn/core/layers/LayerType$Connection;Lcom/kotlinnlp/simplednn/core/layers/models/merge/mergeconfig/MergeConfiguration;Lcom/kotlinnlp/simplednn/core/functionalities/initializers/Initializer;Lcom/kotlinnlp/simplednn/core/functionalities/initializers/Initializer;)V", "getDropout", "()D", "getHiddenActivation", "()Lcom/kotlinnlp/simplednn/core/functionalities/activations/ActivationFunction;", "getHiddenSize", "()I", "getInputSize", "getInputType", "()Lcom/kotlinnlp/simplednn/core/layers/LayerType$Input;", "leftToRightNetwork", "Lcom/kotlinnlp/simplednn/core/neuralnetwork/NeuralNetwork;", "getLeftToRightNetwork", "()Lcom/kotlinnlp/simplednn/core/neuralnetwork/NeuralNetwork;", "model", "Lcom/kotlinnlp/simplednn/deeplearning/birnn/BiRNNParameters;", "getModel", "()Lcom/kotlinnlp/simplednn/deeplearning/birnn/BiRNNParameters;", "outputMergeNetwork", "getOutputMergeNetwork", "outputSize", "getOutputSize", "getRecurrentConnectionType", "()Lcom/kotlinnlp/simplednn/core/layers/LayerType$Connection;", "rightToLeftNetwork", "getRightToLeftNetwork", "dump", "", "outputStream", "Ljava/io/OutputStream;", "Companion", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/deeplearning/birnn/BiRNN.class */
/*
 * Bidirectional recurrent neural network (BiRNN) configuration and networks.
 *
 * Holds two recurrent NeuralNetworks built from the same configuration (named
 * left-to-right and right-to-left; the actual sequence traversal is not done in
 * this class) plus a third network whose input layer takes two vectors of
 * hiddenSize each and merges them according to the given MergeConfiguration.
 *
 * NOTE(review): this is decompiled Kotlin. The synthetic constructors taking an
 * int bit-mask and a DefaultConstructorMarker implement Kotlin default arguments
 * and must keep their exact shape.
 */
public final class BiRNN implements Serializable {
    /** Size of the vectors produced by the output merge network. */
    private final int outputSize;

    /** Recurrent network named for the left-to-right direction. */
    @NotNull
    private final NeuralNetwork leftToRightNetwork;

    /** Recurrent network named for the right-to-left direction. */
    @NotNull
    private final NeuralNetwork rightToLeftNetwork;

    /** Network merging the two hiddenSize outputs into outputSize. */
    @NotNull
    private final NeuralNetwork outputMergeNetwork;

    /** Parameters of the three networks above, grouped together. */
    @NotNull
    private final BiRNNParameters model;

    @NotNull
    private final LayerType.Input inputType;
    private final int inputSize;
    private final int hiddenSize;

    @Nullable
    private final ActivationFunction hiddenActivation;
    private final double dropout;

    @NotNull
    private final LayerType.Connection recurrentConnectionType;
    private static final long serialVersionUID = 1;
    public static final Companion Companion = new Companion(null);

    /* compiled from: BiRNN.kt */
    @Metadata(mv = {1, 1, 10}, bv = {1, 0, 2}, k = 1, d1 = {"�� \n\u0002\u0018\u0002\n\u0002\u0010��\n\u0002\b\u0002\n\u0002\u0010\t\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\b\u0086\u0003\u0018��2\u00020\u0001B\u0007\b\u0002¢\u0006\u0002\u0010\u0002J\u000e\u0010\u0006\u001a\u00020\u00072\u0006\u0010\b\u001a\u00020\tR\u0016\u0010\u0003\u001a\u00020\u00048\u0002X\u0083T¢\u0006\b\n��\u0012\u0004\b\u0005\u0010\u0002¨\u0006\n"}, d2 = {"Lcom/kotlinnlp/simplednn/deeplearning/birnn/BiRNN$Companion;", "", "()V", "serialVersionUID", "", "serialVersionUID$annotations", "load", "Lcom/kotlinnlp/simplednn/deeplearning/birnn/BiRNN;", "inputStream", "Ljava/io/InputStream;", "simplednn"})
    /* loaded from: input_file:com/kotlinnlp/simplednn/deeplearning/birnn/BiRNN$Companion.class */
    public static final class Companion {
        private static /* synthetic */ void serialVersionUID$annotations() {
        }

        /**
         * Reads a BiRNN from a stream, using the same Serializer that
         * {@link BiRNN#dump(OutputStream)} writes with.
         *
         * SECURITY NOTE(review): Serializer.deserialize presumably performs Java
         * native deserialization (BiRNN is Serializable with a fixed
         * serialVersionUID). If so, never call this on untrusted input. Confirm.
         *
         * @param inputStream the stream to read from (must not be null)
         * @return the deserialized BiRNN
         */
        @NotNull
        public final BiRNN load(@NotNull InputStream inputStream) {
            Intrinsics.checkParameterIsNotNull(inputStream, "inputStream");
            return (BiRNN) Serializer.INSTANCE.deserialize(inputStream);
        }

        private Companion() {
        }

        public /* synthetic */ Companion(DefaultConstructorMarker defaultConstructorMarker) {
            this();
        }
    }

    /** @return the size of the merged output vectors */
    public final int getOutputSize() {
        return this.outputSize;
    }

    /** @return the network named for the left-to-right direction */
    @NotNull
    public final NeuralNetwork getLeftToRightNetwork() {
        return this.leftToRightNetwork;
    }

    /** @return the network named for the right-to-left direction */
    @NotNull
    public final NeuralNetwork getRightToLeftNetwork() {
        return this.rightToLeftNetwork;
    }

    /** @return the network that merges the two directional outputs */
    @NotNull
    public final NeuralNetwork getOutputMergeNetwork() {
        return this.outputMergeNetwork;
    }

    /** @return the parameters of the three networks, grouped together */
    @NotNull
    public final BiRNNParameters getModel() {
        return this.model;
    }

    /**
     * Writes this BiRNN to the given stream through Serializer.
     * The result can be read back with {@code BiRNN.Companion.load}.
     *
     * @param outputStream the destination stream (must not be null); not closed here
     */
    public final void dump(@NotNull OutputStream outputStream) {
        Intrinsics.checkParameterIsNotNull(outputStream, "outputStream");
        Serializer.INSTANCE.serialize(this, outputStream);
    }

    @NotNull
    public final LayerType.Input getInputType() {
        return this.inputType;
    }

    public final int getInputSize() {
        return this.inputSize;
    }

    public final int getHiddenSize() {
        return this.hiddenSize;
    }

    /** @return the activation of the recurrent hidden layers; may be null */
    @Nullable
    public final ActivationFunction getHiddenActivation() {
        return this.hiddenActivation;
    }

    /** @return the dropout applied to the input layer of both recurrent networks */
    public final double getDropout() {
        return this.dropout;
    }

    @NotNull
    public final LayerType.Connection getRecurrentConnectionType() {
        return this.recurrentConnectionType;
    }

    /**
     * Builds the two recurrent networks and the output merge network.
     *
     * @param inputType the input layer type of both recurrent networks
     * @param inputSize the input layer size of both recurrent networks
     * @param hiddenSize the size of the recurrent hidden layer of each network
     * @param hiddenActivation the activation of the hidden layers; may be null
     * @param dropout the dropout of the input layer of both recurrent networks
     * @param recurrentConnectionType the hidden layer connection; must have the
     *     Recurrent property
     * @param outputMergeConfiguration how the two hidden outputs are merged
     * @param weightsInitializer passed to every network built here; may be null
     * @param biasesInitializer passed to every network built here; may be null
     * @throws IllegalArgumentException if recurrentConnectionType is not
     *     Recurrent, or the merge configuration is not a supported type
     */
    public BiRNN(@NotNull LayerType.Input inputType, int inputSize, int hiddenSize,
                 @Nullable ActivationFunction hiddenActivation, double dropout,
                 @NotNull LayerType.Connection recurrentConnectionType,
                 @NotNull MergeConfiguration outputMergeConfiguration,
                 @Nullable Initializer weightsInitializer, @Nullable Initializer biasesInitializer) {
        Intrinsics.checkParameterIsNotNull(inputType, "inputType");
        Intrinsics.checkParameterIsNotNull(recurrentConnectionType, "recurrentConnectionType");
        Intrinsics.checkParameterIsNotNull(outputMergeConfiguration, "outputMergeConfiguration");

        // Fail fast: validate before allocating and initializing any network weights.
        if (recurrentConnectionType.getProperty() != LayerType.Property.Recurrent) {
            throw new IllegalArgumentException("required recurrentConnectionType with Recurrent property");
        }

        this.inputType = inputType;
        this.inputSize = inputSize;
        this.hiddenSize = hiddenSize;
        this.hiddenActivation = hiddenActivation;
        this.dropout = dropout;
        this.recurrentConnectionType = recurrentConnectionType;
        this.outputSize = outputSizeOf(outputMergeConfiguration, hiddenSize);

        // Both directions share the same configuration but get distinct instances.
        this.leftToRightNetwork = buildRecurrentNetwork(inputType, inputSize, hiddenSize, hiddenActivation,
                dropout, recurrentConnectionType, weightsInitializer, biasesInitializer);
        this.rightToLeftNetwork = buildRecurrentNetwork(inputType, inputSize, hiddenSize, hiddenActivation,
                dropout, recurrentConnectionType, weightsInitializer, biasesInitializer);
        this.outputMergeNetwork = buildOutputMergeNetwork(outputMergeConfiguration, hiddenSize, this.outputSize,
                weightsInitializer, biasesInitializer);

        this.model = new BiRNNParameters(
                this.leftToRightNetwork.getModel(),
                this.rightToLeftNetwork.getModel(),
                this.outputMergeNetwork.getModel());
    }

    /** Returns the merge output size implied by the given merge configuration. */
    private static int outputSizeOf(MergeConfiguration mergeConfiguration, int hiddenSize) {
        if (mergeConfiguration instanceof AffineMerge) {
            return ((AffineMerge) mergeConfiguration).getOutputSize();
        }
        if (mergeConfiguration instanceof BiaffineMerge) {
            return ((BiaffineMerge) mergeConfiguration).getOutputSize();
        }
        if (mergeConfiguration instanceof ConcatFeedforwardMerge) {
            return ((ConcatFeedforwardMerge) mergeConfiguration).getOutputSize();
        }
        if (mergeConfiguration instanceof ConcatMerge) {
            return 2 * hiddenSize;
        }
        if (mergeConfiguration instanceof SumMerge
                || mergeConfiguration instanceof ProductMerge
                || mergeConfiguration instanceof AvgMerge) {
            return hiddenSize;
        }
        throw new IllegalArgumentException("Invalid output merge configuration.");
    }

    /** Builds one two-layer recurrent network (input layer + recurrent hidden layer). */
    private static NeuralNetwork buildRecurrentNetwork(LayerType.Input inputType, int inputSize, int hiddenSize,
                                                       ActivationFunction hiddenActivation, double dropout,
                                                       LayerType.Connection recurrentConnectionType,
                                                       Initializer weightsInitializer,
                                                       Initializer biasesInitializer) {
        // Masks 28 / 50 select the Kotlin default values for the omitted arguments.
        LayerInterface inputLayer = new LayerInterface(inputSize, inputType, (LayerType.Connection) null,
                (ActivationFunction) null, false, dropout, 28, (DefaultConstructorMarker) null);
        LayerInterface hiddenLayer = new LayerInterface(hiddenSize, (LayerType.Input) null, recurrentConnectionType,
                hiddenActivation, false, 0.0d, 50, (DefaultConstructorMarker) null);
        return new NeuralNetwork(new LayerInterface[]{inputLayer, hiddenLayer}, weightsInitializer, biasesInitializer);
    }

    /** Builds the network merging two hiddenSize vectors as described by mergeConfiguration. */
    private static NeuralNetwork buildOutputMergeNetwork(MergeConfiguration mergeConfiguration, int hiddenSize,
                                                         int outputSize, Initializer weightsInitializer,
                                                         Initializer biasesInitializer) {
        // Input layer with two inputs of hiddenSize each (presumably one per direction).
        LayerInterface inputLayer = new LayerInterface(
                CollectionsKt.listOf(new Integer[]{Integer.valueOf(hiddenSize), Integer.valueOf(hiddenSize)}),
                (LayerType.Input) null, (LayerType.Connection) null, (ActivationFunction) null, false,
                mergeConfiguration.getDropout(), 30, (DefaultConstructorMarker) null);

        List<LayerInterface> layers;
        if (mergeConfiguration instanceof ConcatFeedforwardMerge) {
            // Concatenate the two inputs, then project with a feedforward layer.
            LayerInterface concatLayer = new LayerInterface(2 * hiddenSize, (LayerType.Input) null,
                    LayerType.Connection.Concat, (ActivationFunction) null, false, 0.0d, 58,
                    (DefaultConstructorMarker) null);
            LayerInterface feedforwardLayer = new LayerInterface(
                    ((ConcatFeedforwardMerge) mergeConfiguration).getOutputSize(), (LayerType.Input) null,
                    LayerType.Connection.Feedforward, (ActivationFunction) null, false, 0.0d, 58,
                    (DefaultConstructorMarker) null);
            layers = CollectionsKt.listOf(new LayerInterface[]{inputLayer, concatLayer, feedforwardLayer});
        } else {
            // A single merge layer whose connection type comes from the configuration.
            LayerInterface mergeLayer = new LayerInterface(outputSize, (LayerType.Input) null,
                    mergeConfiguration.getType(), (ActivationFunction) null, false, 0.0d, 58,
                    (DefaultConstructorMarker) null);
            layers = CollectionsKt.listOf(new LayerInterface[]{inputLayer, mergeLayer});
        }
        return new NeuralNetwork(layers, weightsInitializer, biasesInitializer);
    }

    /*
     * Kotlin default-argument bridge. Bits of `mask` mark arguments to replace
     * with their defaults: 16 -> dropout (0.0), 64 -> outputMergeConfiguration
     * (ConcatMerge), 128 -> weightsInitializer (Glorot), 256 -> biasesInitializer (Glorot).
     */
    public /* synthetic */ BiRNN(LayerType.Input inputType, int inputSize, int hiddenSize,
                                 ActivationFunction hiddenActivation, double dropout,
                                 LayerType.Connection recurrentConnectionType,
                                 MergeConfiguration outputMergeConfiguration,
                                 Initializer weightsInitializer, Initializer biasesInitializer,
                                 int mask, DefaultConstructorMarker defaultConstructorMarker) {
        this(inputType, inputSize, hiddenSize, hiddenActivation,
                (mask & 16) != 0 ? 0.0d : dropout,
                recurrentConnectionType,
                (mask & 64) != 0 ? new ConcatMerge(0.0d, 1, null) : outputMergeConfiguration,
                (mask & 128) != 0 ? new GlorotInitializer(0.0d, false, 0L, 7, null) : weightsInitializer,
                (mask & 256) != 0 ? new GlorotInitializer(0.0d, false, 0L, 7, null) : biasesInitializer);
    }
}
