package com.kotlinnlp.simplednn.deeplearning.transformers;

import com.kotlinnlp.simplednn.core.embeddings.EmbeddingsMap;
import com.kotlinnlp.simplednn.core.functionalities.activations.ActivationFunction;
import com.kotlinnlp.simplednn.core.functionalities.activations.GeLU;
import com.kotlinnlp.simplednn.core.functionalities.activations.Softmax;
import com.kotlinnlp.simplednn.core.functionalities.initializers.GlorotInitializer;
import com.kotlinnlp.simplednn.core.functionalities.initializers.Initializer;
import com.kotlinnlp.simplednn.core.layers.LayerInterface;
import com.kotlinnlp.simplednn.core.layers.LayerType;
import com.kotlinnlp.simplednn.core.layers.StackedLayersParameters;
import com.kotlinnlp.utils.BaseExtensionsKt;
import com.kotlinnlp.utils.DictionarySet;
import com.kotlinnlp.utils.Serializer;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import kotlin.Metadata;
import kotlin.collections.CollectionsKt;
import kotlin.collections.MapsKt;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import kotlin.ranges.RangesKt;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: BERTModel.kt */
@Metadata(mv = {1, 1, 15}, bv = {1, 0, 3}, k = 1, d1 = {"��^\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\b\n\u0002\b\u0005\n\u0002\u0018\u0002\n\u0002\u0010\u000e\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0006\n\u0002\u0018\u0002\n\u0002\b\u0007\n\u0002\u0018\u0002\n\u0002\b\u0005\n\u0002\u0010!\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010 \n\u0002\b\u0011\n\u0002\u0010\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0005\u0018�� ?2\u00020\u0001:\u0002?@Bm\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\u0006\u0010\u0004\u001a\u00020\u0003\u0012\u0006\u0010\u0005\u001a\u00020\u0003\u0012\u0006\u0010\u0006\u001a\u00020\u0003\u0012\u0006\u0010\u0007\u001a\u00020\u0003\u0012\f\u0010\b\u001a\b\u0012\u0004\u0012\u00020\n0\t\u0012\u0010\b\u0002\u0010\u000b\u001a\n\u0012\u0004\u0012\u00020\n\u0018\u00010\f\u0012\u0006\u0010\r\u001a\u00020\u0003\u0012\n\b\u0002\u0010\u000e\u001a\u0004\u0018\u00010\u000f\u0012\n\b\u0002\u0010\u0010\u001a\u0004\u0018\u00010\u000f¢\u0006\u0002\u0010\u0011J\u000e\u00109\u001a\u00020:2\u0006\u0010;\u001a\u00020<J\u000e\u0010=\u001a\u00020:2\u0006\u0010>\u001a\u00020\u0003R\u0011\u0010\u0005\u001a\u00020\u0003¢\u0006\b\n��\u001a\u0004\b\u0012\u0010\u0013R\u0011\u0010\u0004\u001a\u00020\u0003¢\u0006\b\n��\u001a\u0004\b\u0014\u0010\u0013R\u001a\u0010\u0015\u001a\u00020\u0016X\u0086\u000e¢\u0006\u000e\n��\u001a\u0004\b\u0017\u0010\u0018\"\u0004\b\u0019\u0010\u001aR\u0011\u0010\u001b\u001a\u00020\u0016¢\u0006\b\n��\u001a\u0004\b\u001c\u0010\u0018R \u0010\u001d\u001a\b\u0012\u0004\u0012\u00020\u001e0\fX\u0086\u000e¢\u0006\u000e\n��\u001a\u0004\b\u001f\u0010 
\"\u0004\b!\u0010\"R\u0014\u0010#\u001a\b\u0012\u0004\u0012\u00020%0$X\u0082\u0004¢\u0006\u0002\n��R\u0011\u0010\u0002\u001a\u00020\u0003¢\u0006\b\n��\u001a\u0004\b&\u0010\u0013R*\u0010)\u001a\b\u0012\u0004\u0012\u00020%0(2\f\u0010'\u001a\b\u0012\u0004\u0012\u00020%0(@BX\u0086\u000e¢\u0006\b\n��\u001a\u0004\b*\u0010+R\u0011\u0010\u0007\u001a\u00020\u0003¢\u0006\b\n��\u001a\u0004\b,\u0010\u0013R\u0011\u0010\u0006\u001a\u00020\u0003¢\u0006\b\n��\u001a\u0004\b-\u0010\u0013R\u0011\u0010.\u001a\u00020\u0003¢\u0006\b\n��\u001a\u0004\b/\u0010\u0013R\u0017\u00100\u001a\b\u0012\u0004\u0012\u00020\u00030\f¢\u0006\b\n��\u001a\u0004\b1\u0010 R\u0017\u00102\u001a\b\u0012\u0004\u0012\u00020\u00030\f¢\u0006\b\n��\u001a\u0004\b3\u0010 R\u0017\u0010\b\u001a\b\u0012\u0004\u0012\u00020\n0\t¢\u0006\b\n��\u001a\u0004\b4\u00105R\"\u00106\u001a\n\u0012\u0004\u0012\u00020\n\u0018\u00010\fX\u0086\u000e¢\u0006\u000e\n��\u001a\u0004\b7\u0010 \"\u0004\b8\u0010\"¨\u0006A"}, d2 = {"Lcom/kotlinnlp/simplednn/deeplearning/transformers/BERTModel;", "Ljava/io/Serializable;", "inputSize", "", "attentionSize", "attentionOutputSize", "outputHiddenSize", "numOfHeads", "vocabulary", "Lcom/kotlinnlp/utils/DictionarySet;", "", "wordEmbeddings", "Lcom/kotlinnlp/simplednn/core/embeddings/EmbeddingsMap;", "numOfLayers", "weightsInitializer", "Lcom/kotlinnlp/simplednn/core/functionalities/initializers/Initializer;", "biasesInitializer", "(IIIIILcom/kotlinnlp/utils/DictionarySet;Lcom/kotlinnlp/simplednn/core/embeddings/EmbeddingsMap;ILcom/kotlinnlp/simplednn/core/functionalities/initializers/Initializer;Lcom/kotlinnlp/simplednn/core/functionalities/initializers/Initializer;)V", "getAttentionOutputSize", "()I", "getAttentionSize", "classifier", "Lcom/kotlinnlp/simplednn/core/layers/StackedLayersParameters;", "getClassifier", "()Lcom/kotlinnlp/simplednn/core/layers/StackedLayersParameters;", "setClassifier", "(Lcom/kotlinnlp/simplednn/core/layers/StackedLayersParameters;)V", "embNorm", "getEmbNorm", "funcEmb", 
"Lcom/kotlinnlp/simplednn/deeplearning/transformers/BERTModel$FuncToken;", "getFuncEmb", "()Lcom/kotlinnlp/simplednn/core/embeddings/EmbeddingsMap;", "setFuncEmb", "(Lcom/kotlinnlp/simplednn/core/embeddings/EmbeddingsMap;)V", "initLayers", "", "Lcom/kotlinnlp/simplednn/deeplearning/transformers/BERTParameters;", "getInputSize", "<set-?>", "", "layers", "getLayers", "()Ljava/util/List;", "getNumOfHeads", "getOutputHiddenSize", "outputSize", "getOutputSize", "positionalEmb", "getPositionalEmb", "tokenTypeEmb", "getTokenTypeEmb", "getVocabulary", "()Lcom/kotlinnlp/utils/DictionarySet;", "wordEmb", "getWordEmb", "setWordEmb", "dump", "", "outputStream", "Ljava/io/OutputStream;", "reduceLayersTo", "size", "Companion", "FuncToken", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/deeplearning/transformers/BERTModel.class */
public final class BERTModel implements Serializable {
    /** Size of the model output; the constructor sets it equal to {@link #inputSize}. */
    private final int outputSize;

    /**
     * Backing list of per-layer parameters built by the constructor.
     * It is the list that {@link #reduceLayersTo(int)} shrinks in place.
     */
    private final List<BERTParameters> initLayers;

    /** Snapshot of {@link #initLayers}, refreshed every time the backing list is reduced. */
    @NotNull
    private List<BERTParameters> layers;

    /** Parameters of the embeddings normalization stack: an input layer followed by a Norm layer. */
    @NotNull
    private final StackedLayersParameters embNorm;

    /**
     * Word embeddings. The constructor uses the map it is given or, when none is given,
     * creates one with an entry per vocabulary element. The setter accepts null.
     */
    @Nullable
    private EmbeddingsMap<String> wordEmb;

    /**
     * Embeddings of the functional tokens, one entry per {@link FuncToken}, seeded from the
     * word embedding registered under the token form (when there is one).
     */
    @NotNull
    private EmbeddingsMap<FuncToken> funcEmb;

    /** Positional embeddings; created without entries by the constructor. */
    @NotNull
    private final EmbeddingsMap<Integer> positionalEmb;

    /** Token type embeddings; the constructor registers the keys 0 and 1. */
    @NotNull
    private final EmbeddingsMap<Integer> tokenTypeEmb;

    /**
     * Parameters of the classifier stack: input, Feedforward (GeLU), Norm and a final
     * Feedforward (Softmax) layer whose size is the vocabulary size.
     */
    @NotNull
    private StackedLayersParameters classifier;

    // Sizes received by the constructor and forwarded to every BERTParameters instance.
    private final int inputSize;
    private final int attentionSize;
    private final int attentionOutputSize;
    private final int outputHiddenSize;
    private final int numOfHeads;

    /** Vocabulary received by the constructor; its size determines the classifier output size. */
    @NotNull
    private final DictionarySet<String> vocabulary;

    /** Serialization version identifier of this class. */
    private static final long serialVersionUID = 1;

    /** Singleton that hosts the members of the Kotlin companion object (see {@code load}). */
    public static final Companion Companion = new Companion(null);

    /* compiled from: BERTModel.kt */
    @Metadata(mv = {1, 1, 15}, bv = {1, 0, 3}, k = 1, d1 = {"�� \n\u0002\u0018\u0002\n\u0002\u0010��\n\u0002\b\u0002\n\u0002\u0010\t\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\b\u0086\u0003\u0018��2\u00020\u0001B\u0007\b\u0002¢\u0006\u0002\u0010\u0002J\u000e\u0010\u0006\u001a\u00020\u00072\u0006\u0010\b\u001a\u00020\tR\u0014\u0010\u0003\u001a\u00020\u0004X\u0082T¢\u0006\b\n��\u0012\u0004\b\u0005\u0010\u0002¨\u0006\n"}, d2 = {"Lcom/kotlinnlp/simplednn/deeplearning/transformers/BERTModel$Companion;", "", "()V", "serialVersionUID", "", "serialVersionUID$annotations", "load", "Lcom/kotlinnlp/simplednn/deeplearning/transformers/BERTModel;", "inputStream", "Ljava/io/InputStream;", "simplednn"})
    /* loaded from: input_file:com/kotlinnlp/simplednn/deeplearning/transformers/BERTModel$Companion.class */
    /**
     * Holder of the Kotlin companion members of {@link BERTModel}.
     * It cannot be instantiated from outside: the only usable instance is created by the enclosing class.
     */
    public static final class Companion {
        /** Hidden constructor: instances are created only through the synthetic one below. */
        private Companion() {
        }

        /**
         * Synthetic constructor generated by the Kotlin compiler.
         *
         * @param defaultConstructorMarker ignored marker argument
         */
        public /* synthetic */ Companion(DefaultConstructorMarker defaultConstructorMarker) {
            this();
        }

        /** Synthetic placeholder generated by the Kotlin compiler; intentionally empty. */
        private static /* synthetic */ void serialVersionUID$annotations() {
        }

        /**
         * Reads a {@link BERTModel} from an input stream.
         *
         * <p>NOTE(review): the object is rebuilt by {@code Serializer.deserialize}; if that helper relies on
         * Java native serialization, it must never be fed with untrusted data — confirm before exposing it.
         *
         * @param inputStream the stream to read the serialized model from (must not be null)
         * @return the deserialized object, cast to {@link BERTModel}
         */
        @NotNull
        public final BERTModel load(@NotNull InputStream inputStream) {
            Intrinsics.checkParameterIsNotNull(inputStream, "inputStream");
            Object restored = Serializer.INSTANCE.deserialize(inputStream);
            return (BERTModel) restored;
        }
    }

    /* compiled from: BERTModel.kt */
    @Metadata(mv = {1, 1, 15}, bv = {1, 0, 3}, k = 1, d1 = {"��\u0012\n\u0002\u0018\u0002\n\u0002\u0010\u0010\n��\n\u0002\u0010\u000e\n\u0002\b\n\b\u0086\u0001\u0018�� \f2\b\u0012\u0004\u0012\u00020��0\u0001:\u0001\fB\u000f\b\u0002\u0012\u0006\u0010\u0002\u001a\u00020\u0003¢\u0006\u0002\u0010\u0004R\u0011\u0010\u0002\u001a\u00020\u0003¢\u0006\b\n��\u001a\u0004\b\u0005\u0010\u0006j\u0002\b\u0007j\u0002\b\bj\u0002\b\tj\u0002\b\nj\u0002\b\u000b¨\u0006\r"}, d2 = {"Lcom/kotlinnlp/simplednn/deeplearning/transformers/BERTModel$FuncToken;", "", "form", "", "(Ljava/lang/String;ILjava/lang/String;)V", "getForm", "()Ljava/lang/String;", "CLS", "SEP", "PAD", "UNK", "MASK", "Companion", "simplednn"})
    /* loaded from: input_file:com/kotlinnlp/simplednn/deeplearning/transformers/BERTModel$FuncToken.class */
    public enum FuncToken {
        CLS("[CLS]"),
        SEP("[SEP]"),
        PAD("[PAD]"),
        UNK("[UNK]"),
        MASK("[MASK]");


        /** Textual form of the token, e.g. {@code "[CLS]"}. */
        @NotNull
        private final String form;

        /** Lookup table from textual form to token, filled by the static initializer. */
        private static final Map<String, FuncToken> tokensByForm;

        /** Singleton that hosts the members of the Kotlin companion object (see {@code byForm}). */
        public static final Companion Companion = new Companion(null);

        /* compiled from: BERTModel.kt */
        @Metadata(mv = {1, 1, 15}, bv = {1, 0, 3}, k = 1, d1 = {"��\u001c\n\u0002\u0018\u0002\n\u0002\u0010��\n\u0002\b\u0002\n\u0002\u0010$\n\u0002\u0010\u000e\n\u0002\u0018\u0002\n\u0002\b\u0003\b\u0086\u0003\u0018��2\u00020\u0001B\u0007\b\u0002¢\u0006\u0002\u0010\u0002J\u000e\u0010\u0007\u001a\u00020\u00062\u0006\u0010\b\u001a\u00020\u0005R\u001a\u0010\u0003\u001a\u000e\u0012\u0004\u0012\u00020\u0005\u0012\u0004\u0012\u00020\u00060\u0004X\u0082\u0004¢\u0006\u0002\n��¨\u0006\t"}, d2 = {"Lcom/kotlinnlp/simplednn/deeplearning/transformers/BERTModel$FuncToken$Companion;", "", "()V", "tokensByForm", "", "", "Lcom/kotlinnlp/simplednn/deeplearning/transformers/BERTModel$FuncToken;", "byForm", "form", "simplednn"})
        /* loaded from: input_file:com/kotlinnlp/simplednn/deeplearning/transformers/BERTModel$FuncToken$Companion.class */
        /**
         * Holder of the Kotlin companion members of {@link FuncToken}.
         */
        public static final class Companion {
            /** Hidden constructor: instances are created only through the synthetic one below. */
            private Companion() {
            }

            /**
             * Synthetic constructor generated by the Kotlin compiler.
             *
             * @param defaultConstructorMarker ignored marker argument
             */
            public /* synthetic */ Companion(DefaultConstructorMarker defaultConstructorMarker) {
                this();
            }

            /**
             * Looks up a functional token by its textual form.
             *
             * <p>The lookup goes through Kotlin's {@code Map.getValue}, which is documented to fail with a
             * {@code NoSuchElementException} when the key is missing.
             *
             * @param str the textual form of the token (must not be null)
             * @return the token registered under the given form
             */
            @NotNull
            public final FuncToken byForm(@NotNull String str) {
                Intrinsics.checkParameterIsNotNull(str, "form");
                Object match = MapsKt.getValue(FuncToken.tokensByForm, str);
                return (FuncToken) match;
            }
        }

        // Builds the form -> token lookup table once, following the declaration order of the constants.
        static {
            FuncToken[] allTokens = values();
            // Same sizing rule used by the Kotlin stdlib when associating elements, with a floor of 16.
            int capacity = RangesKt.coerceAtLeast(MapsKt.mapCapacity(allTokens.length), 16);
            Map<String, FuncToken> index = new LinkedHashMap<>(capacity);
            for (int k = 0; k < allTokens.length; k++) {
                FuncToken token = allTokens[k];
                index.put(token.form, token);
            }
            tokensByForm = index;
        }

        /**
         * @return the textual form of this token
         */
        @NotNull
        public final String getForm() {
            return form;
        }

        /**
         * Creates a functional token.
         *
         * @param str the textual form of the token
         */
        FuncToken(String str) {
            form = str;
        }
    }

    /**
     * @return the output size, which the constructor sets equal to the input size
     */
    public final int getOutputSize() {
        return outputSize;
    }

    /**
     * @return the current list of per-layer parameters (a snapshot of the internal backing list)
     */
    @NotNull
    public final List<BERTParameters> getLayers() {
        return layers;
    }

    /**
     * @return the parameters of the embeddings normalization stack
     */
    @NotNull
    public final StackedLayersParameters getEmbNorm() {
        return embNorm;
    }

    /**
     * @return the word embeddings, or null if they have been cleared through the setter
     */
    @Nullable
    public final EmbeddingsMap<String> getWordEmb() {
        return wordEmb;
    }

    /**
     * Replaces the word embeddings.
     *
     * @param embeddingsMap the new word embeddings (null is accepted)
     */
    public final void setWordEmb(@Nullable EmbeddingsMap<String> embeddingsMap) {
        wordEmb = embeddingsMap;
    }

    /**
     * @return the embeddings of the functional tokens
     */
    @NotNull
    public final EmbeddingsMap<FuncToken> getFuncEmb() {
        return funcEmb;
    }

    /**
     * Replaces the embeddings of the functional tokens.
     *
     * @param embeddingsMap the new functional embeddings (must not be null)
     */
    public final void setFuncEmb(@NotNull EmbeddingsMap<FuncToken> embeddingsMap) {
        Intrinsics.checkParameterIsNotNull(embeddingsMap, "<set-?>");
        funcEmb = embeddingsMap;
    }

    /**
     * @return the positional embeddings (created without entries by the constructor)
     */
    @NotNull
    public final EmbeddingsMap<Integer> getPositionalEmb() {
        return positionalEmb;
    }

    /**
     * @return the token type embeddings (the constructor registers the keys 0 and 1)
     */
    @NotNull
    public final EmbeddingsMap<Integer> getTokenTypeEmb() {
        return tokenTypeEmb;
    }

    /**
     * @return the parameters of the classifier stack
     */
    @NotNull
    public final StackedLayersParameters getClassifier() {
        return classifier;
    }

    /**
     * Replaces the parameters of the classifier stack.
     *
     * @param stackedLayersParameters the new classifier parameters (must not be null)
     */
    public final void setClassifier(@NotNull StackedLayersParameters stackedLayersParameters) {
        Intrinsics.checkParameterIsNotNull(stackedLayersParameters, "<set-?>");
        classifier = stackedLayersParameters;
    }

    /**
     * Serializes this model to an output stream through the project {@code Serializer}.
     *
     * @param outputStream the stream to write the model to (must not be null)
     */
    public final void dump(@NotNull OutputStream outputStream) {
        Intrinsics.checkParameterIsNotNull(outputStream, "outputStream");
        Serializer serializer = Serializer.INSTANCE;
        serializer.serialize(this, outputStream);
    }

    /**
     * Reduces the number of stacked layers of this model.
     *
     * <p>The backing list of layers is shrunk in place with {@code removeFrom} (presumably dropping the
     * elements from index {@code i} onward — verify against the helper) and the public list of layers is
     * then refreshed from it. Both preconditions are checked before anything is modified, so a rejected
     * call leaves the model untouched.
     *
     * @param i the new number of layers; must be non-negative and strictly lower than the current one
     * @throws IllegalArgumentException if {@code i} is negative, or not lower than the current number of layers
     */
    public final void reduceLayersTo(int i) {
        // A negative size used to slip through the upper-bound check and reach the list helper directly.
        if (i < 0) {
            throw new IllegalArgumentException("The reducing size (" + i + ") must not be negative");
        }
        int currentSize = this.layers.size();
        if (i >= currentSize) {
            throw new IllegalArgumentException("The reducing size (" + i + ") must be lower than the current layer size (" + currentSize + ')');
        }
        BaseExtensionsKt.removeFrom(this.initLayers, i);
        this.layers = CollectionsKt.toList(this.initLayers);
    }

    /**
     * @return the input size received by the constructor
     */
    public final int getInputSize() {
        return inputSize;
    }

    /**
     * @return the attention size received by the constructor
     */
    public final int getAttentionSize() {
        return attentionSize;
    }

    /**
     * @return the attention output size received by the constructor
     */
    public final int getAttentionOutputSize() {
        return attentionOutputSize;
    }

    /**
     * @return the output hidden size received by the constructor
     */
    public final int getOutputHiddenSize() {
        return outputHiddenSize;
    }

    /**
     * @return the number of heads received by the constructor
     */
    public final int getNumOfHeads() {
        return numOfHeads;
    }

    /**
     * @return the vocabulary received by the constructor
     */
    @NotNull
    public final DictionarySet<String> getVocabulary() {
        return vocabulary;
    }

    /**
     * Builds a model, creating the parameters of every layer and all the embeddings maps.
     *
     * @param i the input size (also used as output size)
     * @param i2 the attention size
     * @param i3 the attention output size
     * @param i4 the output hidden size
     * @param i5 the number of heads
     * @param dictionarySet the vocabulary (must not be null)
     * @param embeddingsMap the word embeddings to use, or null to create a new map with one entry per
     *     vocabulary element
     * @param i6 the number of stacked layers to create
     * @param initializer the first initializer handed to the layer parameters and to the embeddings
     *     normalization stack (named {@code weightsInitializer} in the Kotlin metadata)
     * @param initializer2 the second initializer handed to the same components (named
     *     {@code biasesInitializer} in the Kotlin metadata)
     */
    public BERTModel(int i, int i2, int i3, int i4, int i5, @NotNull DictionarySet<String> dictionarySet, @Nullable EmbeddingsMap<String> embeddingsMap, int i6, @Nullable Initializer initializer, @Nullable Initializer initializer2) {
        Intrinsics.checkParameterIsNotNull(dictionarySet, "vocabulary");

        this.inputSize = i;
        this.attentionSize = i2;
        this.attentionOutputSize = i3;
        this.outputHiddenSize = i4;
        this.numOfHeads = i5;
        this.vocabulary = dictionarySet;
        this.outputSize = this.inputSize;

        // One set of parameters per stacked layer; the public list is a snapshot of the backing one.
        List<BERTParameters> stack = new ArrayList<>(i6);
        int created = 0;
        while (created < i6) {
            stack.add(new BERTParameters(this.inputSize, this.attentionSize, this.attentionOutputSize, this.outputHiddenSize, this.numOfHeads, initializer, initializer2));
            created++;
        }
        this.initLayers = stack;
        this.layers = CollectionsKt.toList(this.initLayers);

        // Embeddings normalization: input layer followed by a Norm layer.
        // The trailing (mask, marker) arguments select the Kotlin default values of the omitted parameters.
        LayerInterface normInput = new LayerInterface(this.inputSize, (LayerType.Input) null, (LayerType.Connection) null, (ActivationFunction) null, 14, (DefaultConstructorMarker) null);
        LayerInterface normOutput = new LayerInterface(this.inputSize, (LayerType.Input) null, LayerType.Connection.Norm, (ActivationFunction) null, 10, (DefaultConstructorMarker) null);
        this.embNorm = new StackedLayersParameters(new LayerInterface[]{normInput, normOutput}, initializer, initializer2);

        // Word embeddings: use the given map or create one with an entry for every vocabulary element.
        EmbeddingsMap<String> words = embeddingsMap;
        if (words == null) {
            words = new EmbeddingsMap<>(this.inputSize, null, false, 6, null);
            for (Iterator<?> elements = this.vocabulary.getElements().iterator(); elements.hasNext(); ) {
                EmbeddingsMap.set$default(words, (String) elements.next(), null, 2, null);
            }
        }
        this.wordEmb = words;

        // Functional embeddings: each token reuses the word embedding stored under its form, if any.
        EmbeddingsMap<FuncToken> functional = new EmbeddingsMap<>(this.inputSize, null, false, 6, null);
        FuncToken[] tokens = FuncToken.values();
        for (int k = 0; k < tokens.length; k++) {
            FuncToken token = tokens[k];
            EmbeddingsMap<String> source = this.wordEmb;
            if (source == null) {
                Intrinsics.throwNpe();
            }
            functional.set(token, source.getOrNull(token.getForm()));
        }
        this.funcEmb = functional;

        // Positional embeddings start without entries.
        this.positionalEmb = new EmbeddingsMap<>(this.inputSize, null, false, 6, null);

        // Token type embeddings: the two keys 0 and 1.
        EmbeddingsMap<Integer> tokenTypes = new EmbeddingsMap<>(this.inputSize, null, false, 6, null);
        EmbeddingsMap.set$default(tokenTypes, 0, null, 2, null);
        EmbeddingsMap.set$default(tokenTypes, 1, null, 2, null);
        this.tokenTypeEmb = tokenTypes;

        // Classifier: input -> Feedforward (GeLU) -> Norm -> Feedforward (Softmax) over the vocabulary.
        // It is built with the default initializers of StackedLayersParameters, not the ones given here.
        LayerInterface classifierInput = new LayerInterface(this.inputSize, (LayerType.Input) null, (LayerType.Connection) null, (ActivationFunction) null, 14, (DefaultConstructorMarker) null);
        LayerInterface classifierHidden = new LayerInterface(this.inputSize, (LayerType.Input) null, LayerType.Connection.Feedforward, GeLU.INSTANCE, 2, (DefaultConstructorMarker) null);
        LayerInterface classifierNorm = new LayerInterface(this.inputSize, (LayerType.Input) null, LayerType.Connection.Norm, (ActivationFunction) null, 10, (DefaultConstructorMarker) null);
        LayerInterface classifierOutput = new LayerInterface(this.vocabulary.getSize(), (LayerType.Input) null, LayerType.Connection.Feedforward, new Softmax(), 2, (DefaultConstructorMarker) null);
        this.classifier = new StackedLayersParameters(new LayerInterface[]{classifierInput, classifierHidden, classifierNorm, classifierOutput}, (Initializer) null, (Initializer) null, 6, (DefaultConstructorMarker) null);
    }

    /**
     * Synthetic constructor generated by the Kotlin compiler to support default argument values.
     *
     * <p>Each bit of {@code i7} that is set marks an argument omitted by the caller, which is then replaced
     * by its default: bit 64 selects a null word embeddings map, bits 256 and 512 select a new
     * {@link GlorotInitializer} for the first and the second initializer respectively.
     *
     * @param i7 the bit mask of the omitted arguments
     * @param defaultConstructorMarker ignored marker argument
     */
    public /* synthetic */ BERTModel(int i, int i2, int i3, int i4, int i5, DictionarySet dictionarySet, EmbeddingsMap embeddingsMap, int i6, Initializer initializer, Initializer initializer2, int i7, DefaultConstructorMarker defaultConstructorMarker) {
        this(
            i,
            i2,
            i3,
            i4,
            i5,
            dictionarySet,
            (i7 & 64) == 0 ? embeddingsMap : null,
            i6,
            (i7 & 256) == 0 ? initializer : new GlorotInitializer(0.0d, false, 0L, 7, null),
            (i7 & 512) == 0 ? initializer2 : new GlorotInitializer(0.0d, false, 0L, 7, null));
    }
}
