package com.kotlinnlp.tokensencoder.charlm;

import com.kotlinnlp.languagemodel.CharLM;
import com.kotlinnlp.linguisticdescription.sentence.Sentence;
import com.kotlinnlp.linguisticdescription.sentence.token.FormToken;
import com.kotlinnlp.simplednn.core.functionalities.activations.ActivationFunction;
import com.kotlinnlp.simplednn.core.functionalities.initializers.GlorotInitializer;
import com.kotlinnlp.simplednn.core.functionalities.initializers.Initializer;
import com.kotlinnlp.simplednn.core.functionalities.updatemethods.UpdateMethod;
import com.kotlinnlp.simplednn.core.layers.LayerInterface;
import com.kotlinnlp.simplednn.core.layers.LayerType;
import com.kotlinnlp.simplednn.core.layers.models.merge.mergeconfig.AffineMerge;
import com.kotlinnlp.simplednn.core.layers.models.merge.mergeconfig.AvgMerge;
import com.kotlinnlp.simplednn.core.layers.models.merge.mergeconfig.BiaffineMerge;
import com.kotlinnlp.simplednn.core.layers.models.merge.mergeconfig.ConcatFeedforwardMerge;
import com.kotlinnlp.simplednn.core.layers.models.merge.mergeconfig.ConcatMerge;
import com.kotlinnlp.simplednn.core.layers.models.merge.mergeconfig.MergeConfiguration;
import com.kotlinnlp.simplednn.core.layers.models.merge.mergeconfig.ProductMerge;
import com.kotlinnlp.simplednn.core.layers.models.merge.mergeconfig.SumMerge;
import com.kotlinnlp.simplednn.core.neuralnetwork.NeuralNetwork;
import com.kotlinnlp.tokensencoder.TokensEncoder;
import com.kotlinnlp.tokensencoder.TokensEncoderModel;
import com.kotlinnlp.tokensencoder.TokensEncoderOptimizer;
import java.util.List;
import kotlin.Metadata;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: CharLMEncoderModel.kt */
@Metadata(mv = {1, 1, 10}, bv = {1, 0, 2}, k = 1, d1 = {"��T\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0005\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0010\b\n\u0002\b\u0003\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000b\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\u0018�� !2\u0014\u0012\u0004\u0012\u00020\u0002\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00020\u00030\u0001:\u0001!B7\u0012\u0006\u0010\u0004\u001a\u00020\u0005\u0012\u0006\u0010\u0006\u001a\u00020\u0005\u0012\b\b\u0002\u0010\u0007\u001a\u00020\b\u0012\n\b\u0002\u0010\t\u001a\u0004\u0018\u00010\n\u0012\n\b\u0002\u0010\u000b\u001a\u0004\u0018\u00010\n¢\u0006\u0002\u0010\fJ\u0018\u0010\u0018\u001a\u00020\u00192\u0006\u0010\u001a\u001a\u00020\u001b2\u0006\u0010\u001c\u001a\u00020\u0015H\u0016J\u0014\u0010\u001d\u001a\u00020\u001e2\n\u0010\u001f\u001a\u0006\u0012\u0002\b\u00030 H\u0016R\u0011\u0010\u0004\u001a\u00020\u0005¢\u0006\b\n��\u001a\u0004\b\r\u0010\u000eR\u0011\u0010\u000f\u001a\u00020\u0010¢\u0006\b\n��\u001a\u0004\b\u0011\u0010\u0012R\u0011\u0010\u0006\u001a\u00020\u0005¢\u0006\b\n��\u001a\u0004\b\u0013\u0010\u000eR\u0014\u0010\u0014\u001a\u00020\u0015X\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\u0016\u0010\u0017¨\u0006\""}, d2 = {"Lcom/kotlinnlp/tokensencoder/charlm/CharLMEncoderModel;", "Lcom/kotlinnlp/tokensencoder/TokensEncoderModel;", "Lcom/kotlinnlp/linguisticdescription/sentence/token/FormToken;", "Lcom/kotlinnlp/linguisticdescription/sentence/Sentence;", "charLM", "Lcom/kotlinnlp/languagemodel/CharLM;", "revCharLM", "outputMergeConfiguration", "Lcom/kotlinnlp/simplednn/core/layers/models/merge/mergeconfig/MergeConfiguration;", "weightsInitializer", "Lcom/kotlinnlp/simplednn/core/functionalities/initializers/Initializer;", "biasesInitializer", 
"(Lcom/kotlinnlp/languagemodel/CharLM;Lcom/kotlinnlp/languagemodel/CharLM;Lcom/kotlinnlp/simplednn/core/layers/models/merge/mergeconfig/MergeConfiguration;Lcom/kotlinnlp/simplednn/core/functionalities/initializers/Initializer;Lcom/kotlinnlp/simplednn/core/functionalities/initializers/Initializer;)V", "getCharLM", "()Lcom/kotlinnlp/languagemodel/CharLM;", "outputMergeNetwork", "Lcom/kotlinnlp/simplednn/core/neuralnetwork/NeuralNetwork;", "getOutputMergeNetwork", "()Lcom/kotlinnlp/simplednn/core/neuralnetwork/NeuralNetwork;", "getRevCharLM", "tokenEncodingSize", "", "getTokenEncodingSize", "()I", "buildEncoder", "Lcom/kotlinnlp/tokensencoder/charlm/CharLMEncoder;", "useDropout", "", "id", "buildOptimizer", "Lcom/kotlinnlp/tokensencoder/charlm/CharLMEncoderOptimizer;", "updateMethod", "Lcom/kotlinnlp/simplednn/core/functionalities/updatemethods/UpdateMethod;", "Companion", "tokensencoder"})
/* loaded from: input_file:com/kotlinnlp/tokensencoder/charlm/CharLMEncoderModel.class */
/**
 * Model of a tokens encoder built on a pair of character language models.
 *
 * <p>It pairs a left-to-right {@link CharLM} with a right-to-left one and checks that they have the
 * same recurrent output size. It also owns a small {@link NeuralNetwork} whose input layer takes the
 * two LMs' recurrent outputs and whose connection type comes from the given {@link MergeConfiguration}.
 */
public final class CharLMEncoderModel implements TokensEncoderModel<FormToken, Sentence<FormToken>> {

    /** Output size of the merge network, i.e. the size of each token encoding. */
    private final int tokenEncodingSize;

    /** Network that merges the recurrent outputs of {@link #charLM} and {@link #revCharLM}. */
    @NotNull
    private final NeuralNetwork outputMergeNetwork;

    /** Char LM that must NOT be a reverse model (checked in the constructor). */
    @NotNull
    private final CharLM charLM;

    /** Char LM that MUST be a reverse model (checked in the constructor). */
    @NotNull
    private final CharLM revCharLM;

    private static final long serialVersionUID = 1;
    public static final Companion Companion = new Companion(null);

    /* compiled from: CharLMEncoderModel.kt */
    @Metadata(mv = {1, 1, 10}, bv = {1, 0, 2}, k = 1, d1 = {"��\u0014\n\u0002\u0018\u0002\n\u0002\u0010��\n\u0002\b\u0002\n\u0002\u0010\t\n\u0002\b\u0002\b\u0086\u0003\u0018��2\u00020\u0001B\u0007\b\u0002¢\u0006\u0002\u0010\u0002R\u0016\u0010\u0003\u001a\u00020\u00048\u0002X\u0083T¢\u0006\b\n��\u0012\u0004\b\u0005\u0010\u0002¨\u0006\u0006"}, d2 = {"Lcom/kotlinnlp/tokensencoder/charlm/CharLMEncoderModel$Companion;", "", "()V", "serialVersionUID", "", "serialVersionUID$annotations", "tokensencoder"})
    /* loaded from: input_file:com/kotlinnlp/tokensencoder/charlm/CharLMEncoderModel$Companion.class */
    public static final class Companion {
        private static /* synthetic */ void serialVersionUID$annotations() {
        }

        private Companion() {
        }

        public /* synthetic */ Companion(DefaultConstructorMarker defaultConstructorMarker) {
            this();
        }
    }

    /**
     * @return the size of the encoding produced for each token (output size of the merge network)
     */
    @Override // com.kotlinnlp.tokensencoder.TokensEncoderModel
    public int getTokenEncodingSize() {
        return this.tokenEncodingSize;
    }

    /**
     * @return the network that merges the forward and reverse char LM outputs
     */
    @NotNull
    public final NeuralNetwork getOutputMergeNetwork() {
        return this.outputMergeNetwork;
    }

    /**
     * Builds a {@link CharLMEncoder} bound to this model.
     *
     * <p>NOTE(review): {@code useDropout} is accepted for interface compatibility but is not
     * forwarded to the encoder; only {@code id} is passed on.
     *
     * @param useDropout ignored by this implementation
     * @param id the identifier given to the new encoder
     * @return a new encoder for this model
     */
    @Override // com.kotlinnlp.tokensencoder.TokensEncoderModel
    @NotNull
    /* renamed from: buildEncoder */
    public TokensEncoder<FormToken, Sentence<FormToken>> buildEncoder2(boolean useDropout, int id) {
        return new CharLMEncoder(this, id);
    }

    /**
     * Builds an optimizer for this model.
     *
     * @param updateMethod the update method applied to the model parameters
     * @return a new {@link CharLMEncoderOptimizer} for this model
     */
    @Override // com.kotlinnlp.tokensencoder.TokensEncoderModel
    @NotNull
    public CharLMEncoderOptimizer buildOptimizer(@NotNull UpdateMethod<?> updateMethod) {
        Intrinsics.checkParameterIsNotNull(updateMethod, "updateMethod");
        return new CharLMEncoderOptimizer(this, updateMethod);
    }

    // Compiler-generated bridge for the raw-typed interface signature; delegates to the method above.
    @Override // com.kotlinnlp.tokensencoder.TokensEncoderModel
    public /* bridge */ /* synthetic */ TokensEncoderOptimizer buildOptimizer(UpdateMethod updateMethod) {
        return buildOptimizer((UpdateMethod<?>) updateMethod);
    }

    /**
     * @return the left-to-right char language model
     */
    @NotNull
    public final CharLM getCharLM() {
        return this.charLM;
    }

    /**
     * @return the right-to-left char language model
     */
    @NotNull
    public final CharLM getRevCharLM() {
        return this.revCharLM;
    }

    /**
     * Creates the model and builds its output merge network.
     *
     * @param charLM a char LM whose {@code getReverseModel()} is {@code false}
     * @param revCharLM a char LM whose {@code getReverseModel()} is {@code true}
     * @param outputMergeConfiguration how the two LMs' recurrent outputs are merged
     * @param weightsInitializer initializer of the merge network weights (nullable)
     * @param biasesInitializer initializer of the merge network biases (nullable)
     * @throws IllegalArgumentException if the LM directions are wrong, their recurrent output sizes
     *     differ, or the merge configuration type is not supported
     */
    public CharLMEncoderModel(@NotNull CharLM charLM, @NotNull CharLM revCharLM, @NotNull MergeConfiguration outputMergeConfiguration, @Nullable Initializer weightsInitializer, @Nullable Initializer biasesInitializer) {
        Intrinsics.checkParameterIsNotNull(charLM, "charLM");
        Intrinsics.checkParameterIsNotNull(revCharLM, "revCharLM");
        Intrinsics.checkParameterIsNotNull(outputMergeConfiguration, "outputMergeConfiguration");
        this.charLM = charLM;
        this.revCharLM = revCharLM;
        if (charLM.getReverseModel()) {
            throw new IllegalArgumentException("The charLM must be trained to process the sequence from left to right.");
        }
        if (!revCharLM.getReverseModel()) {
            throw new IllegalArgumentException("The revCharLM must be trained to process the sequence from right to left.");
        }
        int hiddenSize = charLM.getRecurrentNetwork().getOutputSize();
        int revHiddenSize = revCharLM.getRecurrentNetwork().getOutputSize();
        if (hiddenSize != revHiddenSize) {
            throw new IllegalArgumentException("The charLM and the reverse CharLM must have the same recurrent hidden size.");
        }
        int outputSize = outputSizeOf(outputMergeConfiguration, hiddenSize);
        this.tokenEncodingSize = outputSize;
        this.outputMergeNetwork = new NeuralNetwork(
            buildMergeLayers(outputMergeConfiguration, hiddenSize, revHiddenSize, outputSize),
            weightsInitializer,
            biasesInitializer);
    }

    /**
     * Computes the output size of the merge network for the given configuration.
     *
     * <p>The {@code instanceof} checks keep the original order, in case any merge configuration type
     * extends another.
     *
     * @param config the output merge configuration
     * @param hiddenSize the recurrent output size shared by both char LMs
     * @return the configured output size, or a size derived from {@code hiddenSize}
     * @throws IllegalArgumentException if the configuration type is not supported
     */
    private static int outputSizeOf(@NotNull MergeConfiguration config, int hiddenSize) {
        if (config instanceof AffineMerge) {
            return ((AffineMerge) config).getOutputSize();
        }
        if (config instanceof BiaffineMerge) {
            return ((BiaffineMerge) config).getOutputSize();
        }
        if (config instanceof ConcatFeedforwardMerge) {
            return ((ConcatFeedforwardMerge) config).getOutputSize();
        }
        if (config instanceof ConcatMerge) {
            // Concatenation of two vectors of equal size.
            return 2 * hiddenSize;
        }
        if (config instanceof SumMerge || config instanceof ProductMerge || config instanceof AvgMerge) {
            // Element-wise merges keep the input size.
            return hiddenSize;
        }
        throw new IllegalArgumentException(
            "Invalid output merge configuration: " + config.getClass().getName() + ".");
    }

    /**
     * Builds the layer configuration of the output merge network.
     *
     * <p>The first layer always takes the two recurrent outputs with the configuration's dropout. A
     * {@link ConcatFeedforwardMerge} adds a concatenation layer followed by a feed-forward layer. Any
     * other configuration adds a single layer whose connection type is {@code config.getType()}.
     *
     * @param config the output merge configuration
     * @param hiddenSize recurrent output size of the forward char LM
     * @param revHiddenSize recurrent output size of the reverse char LM
     * @param outputSize output size already computed for {@code config}
     * @return the list of layer interfaces, in network order
     */
    @NotNull
    private static List<LayerInterface> buildMergeLayers(@NotNull MergeConfiguration config, int hiddenSize, int revHiddenSize, int outputSize) {
        // The int masks (30, 58, 50) select Kotlin default arguments of LayerInterface's synthetic
        // constructor; they are copied unchanged from the compiled code.
        LayerInterface inputLayer = new LayerInterface(
            CollectionsKt.listOf(new Integer[]{Integer.valueOf(hiddenSize), Integer.valueOf(revHiddenSize)}),
            (LayerType.Input) null, (LayerType.Connection) null, (ActivationFunction) null, false,
            config.getDropout(), 30, (DefaultConstructorMarker) null);
        if (config instanceof ConcatFeedforwardMerge) {
            ConcatFeedforwardMerge feedforwardConfig = (ConcatFeedforwardMerge) config;
            return CollectionsKt.listOf(new LayerInterface[]{
                inputLayer,
                new LayerInterface(2 * hiddenSize, (LayerType.Input) null, LayerType.Connection.Concat,
                    (ActivationFunction) null, false, 0.0d, 58, (DefaultConstructorMarker) null),
                new LayerInterface(feedforwardConfig.getOutputSize(), (LayerType.Input) null,
                    LayerType.Connection.Feedforward, feedforwardConfig.getActivationFunction(), false,
                    0.0d, 50, (DefaultConstructorMarker) null)
            });
        }
        return CollectionsKt.listOf(new LayerInterface[]{
            inputLayer,
            new LayerInterface(outputSize, (LayerType.Input) null, config.getType(),
                (ActivationFunction) null, false, 0.0d, 58, (DefaultConstructorMarker) null)
        });
    }

    /**
     * Kotlin default-arguments constructor (compiler-generated).
     *
     * <p>When the matching bit of {@code mask} is set, the argument is replaced by its default:
     * <ul>
     *   <li>bit 4: {@code new ConcatMerge(0.0d, 1, null)};</li>
     *   <li>bit 8: {@code new GlorotInitializer(0.0d, false, 0L, 7, null)};</li>
     *   <li>bit 16: {@code null}.</li>
     * </ul>
     */
    public /* synthetic */ CharLMEncoderModel(CharLM charLM, CharLM revCharLM, MergeConfiguration outputMergeConfiguration, Initializer weightsInitializer, Initializer biasesInitializer, int mask, DefaultConstructorMarker marker) {
        this(charLM, revCharLM, (mask & 4) != 0 ? new ConcatMerge(0.0d, 1, null) : outputMergeConfiguration, (mask & 8) != 0 ? new GlorotInitializer(0.0d, false, 0L, 7, null) : weightsInitializer, (mask & 16) != 0 ? (Initializer) null : biasesInitializer);
    }
}
