package com.kotlinnlp.tokensencoder.ensemble;

import com.kotlinnlp.linguisticdescription.sentence.Sentence;
import com.kotlinnlp.linguisticdescription.sentence.token.Token;
import com.kotlinnlp.simplednn.core.functionalities.activations.ActivationFunction;
import com.kotlinnlp.simplednn.core.functionalities.initializers.GlorotInitializer;
import com.kotlinnlp.simplednn.core.functionalities.initializers.Initializer;
import com.kotlinnlp.simplednn.core.layers.LayerInterface;
import com.kotlinnlp.simplednn.core.layers.LayerType;
import com.kotlinnlp.simplednn.core.layers.StackedLayersParameters;
import com.kotlinnlp.simplednn.core.layers.models.merge.mergeconfig.AffineMerge;
import com.kotlinnlp.simplednn.core.layers.models.merge.mergeconfig.AvgMerge;
import com.kotlinnlp.simplednn.core.layers.models.merge.mergeconfig.BiaffineMerge;
import com.kotlinnlp.simplednn.core.layers.models.merge.mergeconfig.ConcatFeedforwardMerge;
import com.kotlinnlp.simplednn.core.layers.models.merge.mergeconfig.ConcatMerge;
import com.kotlinnlp.simplednn.core.layers.models.merge.mergeconfig.MergeConfiguration;
import com.kotlinnlp.simplednn.core.layers.models.merge.mergeconfig.ProductMerge;
import com.kotlinnlp.simplednn.core.layers.models.merge.mergeconfig.SumMerge;
import com.kotlinnlp.tokensencoder.TokensEncoder;
import com.kotlinnlp.tokensencoder.TokensEncoderModel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import kotlin.Metadata;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: EnsembleTokensEncoderModel.kt */
@Metadata(mv = {1, 1, 13}, bv = {1, 0, 3}, k = 1, d1 = {"��R\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010 \n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0005\n\u0002\u0010\b\n��\n\u0002\u0018\u0002\n\u0002\b\u0006\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000b\n\u0002\b\u0002\n\u0002\u0010\u000e\n\u0002\b\u0003\u0018�� !*\b\b��\u0010\u0001*\u00020\u0002*\u000e\b\u0001\u0010\u0003*\b\u0012\u0004\u0012\u0002H\u00010\u00042\u000e\u0012\u0004\u0012\u0002H\u0001\u0012\u0004\u0012\u0002H\u00030\u0005:\u0002!\"BA\u0012\u0018\u0010\u0006\u001a\u0014\u0012\u0010\u0012\u000e\u0012\u0004\u0012\u00028��\u0012\u0004\u0012\u00028\u00010\b0\u0007\u0012\b\b\u0002\u0010\t\u001a\u00020\n\u0012\n\b\u0002\u0010\u000b\u001a\u0004\u0018\u00010\f\u0012\n\b\u0002\u0010\r\u001a\u0004\u0018\u00010\f¢\u0006\u0002\u0010\u000eJ$\u0010\u001a\u001a\u000e\u0012\u0004\u0012\u00028��\u0012\u0004\u0012\u00028\u00010\u001b2\u0006\u0010\u001c\u001a\u00020\u001d2\u0006\u0010\u001e\u001a\u00020\u0012H\u0016J\b\u0010\u001f\u001a\u00020 H\u0016R#\u0010\u0006\u001a\u0014\u0012\u0010\u0012\u000e\u0012\u0004\u0012\u00028��\u0012\u0004\u0012\u00028\u00010\b0\u0007¢\u0006\b\n��\u001a\u0004\b\u000f\u0010\u0010R\u000e\u0010\u0011\u001a\u00020\u0012X\u0082\u0004¢\u0006\u0002\n��R\u0011\u0010\u0013\u001a\u00020\u0014¢\u0006\b\n��\u001a\u0004\b\u0015\u0010\u0016R\u0014\u0010\u0017\u001a\u00020\u0012X\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\u0018\u0010\u0019¨\u0006#"}, d2 = {"Lcom/kotlinnlp/tokensencoder/ensemble/EnsembleTokensEncoderModel;", "TokenType", "Lcom/kotlinnlp/linguisticdescription/sentence/token/Token;", "SentenceType", "Lcom/kotlinnlp/linguisticdescription/sentence/Sentence;", "Lcom/kotlinnlp/tokensencoder/TokensEncoderModel;", "components", "", "Lcom/kotlinnlp/tokensencoder/ensemble/EnsembleTokensEncoderModel$ComponentModel;", "outputMergeConfiguration", 
"Lcom/kotlinnlp/simplednn/core/layers/models/merge/mergeconfig/MergeConfiguration;", "weightsInitializer", "Lcom/kotlinnlp/simplednn/core/functionalities/initializers/Initializer;", "biasesInitializer", "(Ljava/util/List;Lcom/kotlinnlp/simplednn/core/layers/models/merge/mergeconfig/MergeConfiguration;Lcom/kotlinnlp/simplednn/core/functionalities/initializers/Initializer;Lcom/kotlinnlp/simplednn/core/functionalities/initializers/Initializer;)V", "getComponents", "()Ljava/util/List;", "mergeOutputSize", "", "outputMergeNetwork", "Lcom/kotlinnlp/simplednn/core/layers/StackedLayersParameters;", "getOutputMergeNetwork", "()Lcom/kotlinnlp/simplednn/core/layers/StackedLayersParameters;", "tokenEncodingSize", "getTokenEncodingSize", "()I", "buildEncoder", "Lcom/kotlinnlp/tokensencoder/ensemble/EnsembleTokensEncoder;", "useDropout", "", "id", "toString", "", "Companion", "ComponentModel", "tokensencoder"})
/* loaded from: input_file:com/kotlinnlp/tokensencoder/ensemble/EnsembleTokensEncoderModel.class */
public final class EnsembleTokensEncoderModel<TokenType extends Token, SentenceType extends Sentence<TokenType>> implements TokensEncoderModel<TokenType, SentenceType> {
    /** Serialization version of this model. */
    private static final long serialVersionUID = 1L;

    /** Kotlin companion-object instance (compiler generated). */
    public static final Companion Companion = new Companion(null);

    /** The component models whose token encodings are merged by this ensemble. */
    @NotNull
    private final List<ComponentModel<TokenType, SentenceType>> components;

    /** Output size of the merge, derived in the constructor from the merge configuration. */
    private final int mergeOutputSize;

    /** Size of the token encodings produced by this model; the constructor sets it equal to {@link #mergeOutputSize}. */
    private final int tokenEncodingSize;

    /** Parameters of the stacked network that merges the component encodings. */
    @NotNull
    private final StackedLayersParameters outputMergeNetwork;

    /* compiled from: EnsembleTokensEncoderModel.kt */
    @Metadata(mv = {1, 1, 13}, bv = {1, 0, 3}, k = 1, d1 = {"��\u0014\n\u0002\u0018\u0002\n\u0002\u0010��\n\u0002\b\u0002\n\u0002\u0010\t\n\u0002\b\u0002\b\u0086\u0003\u0018��2\u00020\u0001B\u0007\b\u0002¢\u0006\u0002\u0010\u0002R\u0014\u0010\u0003\u001a\u00020\u0004X\u0082T¢\u0006\b\n��\u0012\u0004\b\u0005\u0010\u0002¨\u0006\u0006"}, d2 = {"Lcom/kotlinnlp/tokensencoder/ensemble/EnsembleTokensEncoderModel$Companion;", "", "()V", "serialVersionUID", "", "serialVersionUID$annotations", "tokensencoder"})
    /* loaded from: input_file:com/kotlinnlp/tokensencoder/ensemble/EnsembleTokensEncoderModel$Companion.class */
    public static final class Companion {

        /**
         * Bridge constructor emitted by the Kotlin compiler; the marker argument is ignored and
         * construction is delegated to the private no-arg constructor.
         */
        public /* synthetic */ Companion(DefaultConstructorMarker marker) {
            this();
        }

        /** Private: the only instance is {@code EnsembleTokensEncoderModel.Companion}. */
        private Companion() {
        }

        /**
         * Empty Kotlin-generated placeholder, presumably carrying annotations of
         * {@code serialVersionUID}; it is never invoked here.
         */
        private static /* synthetic */ void serialVersionUID$annotations() {
        }
    }

    /* compiled from: EnsembleTokensEncoderModel.kt */
    @Metadata(mv = {1, 1, 13}, bv = {1, 0, 3}, k = 1, d1 = {"��0\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u000b\n\u0002\b\u0004\n\u0002\u0010\b\n\u0002\b\u0005\n\u0002\u0018\u0002\n\u0002\b\u0003\u0018��*\b\b\u0002\u0010\u0001*\u00020\u0002*\u000e\b\u0003\u0010\u0003*\b\u0012\u0004\u0012\u0002H\u00010\u00042\u000e\u0012\u0004\u0012\u0002H\u0001\u0012\u0004\u0012\u0002H\u00030\u0005B!\u0012\u0012\u0010\u0006\u001a\u000e\u0012\u0004\u0012\u00028\u0002\u0012\u0004\u0012\u00028\u00030\u0005\u0012\u0006\u0010\u0007\u001a\u00020\b¢\u0006\u0002\u0010\tJ'\u0010\u0012\u001a\u000e\u0012\u0004\u0012\u00028\u0002\u0012\u0004\u0012\u00028\u00030\u00132\u0006\u0010\u0014\u001a\u00020\b2\b\b\u0002\u0010\u0015\u001a\u00020\rH\u0096\u0001R\u001d\u0010\u0006\u001a\u000e\u0012\u0004\u0012\u00028\u0002\u0012\u0004\u0012\u00028\u00030\u0005¢\u0006\b\n��\u001a\u0004\b\n\u0010\u000bR\u0012\u0010\f\u001a\u00020\rX\u0096\u0005¢\u0006\u0006\u001a\u0004\b\u000e\u0010\u000fR\u0011\u0010\u0007\u001a\u00020\b¢\u0006\b\n��\u001a\u0004\b\u0010\u0010\u0011¨\u0006\u0016"}, d2 = {"Lcom/kotlinnlp/tokensencoder/ensemble/EnsembleTokensEncoderModel$ComponentModel;", "TokenType", "Lcom/kotlinnlp/linguisticdescription/sentence/token/Token;", "SentenceType", "Lcom/kotlinnlp/linguisticdescription/sentence/Sentence;", "Lcom/kotlinnlp/tokensencoder/TokensEncoderModel;", "model", "trainable", "", "(Lcom/kotlinnlp/tokensencoder/TokensEncoderModel;Z)V", "getModel", "()Lcom/kotlinnlp/tokensencoder/TokensEncoderModel;", "tokenEncodingSize", "", "getTokenEncodingSize", "()I", "getTrainable", "()Z", "buildEncoder", "Lcom/kotlinnlp/tokensencoder/TokensEncoder;", "useDropout", "id", "tokensencoder"})
    /* loaded from: input_file:com/kotlinnlp/tokensencoder/ensemble/EnsembleTokensEncoderModel$ComponentModel.class */
    public static final class ComponentModel<TokenType extends Token, SentenceType extends Sentence<TokenType>> implements TokensEncoderModel<TokenType, SentenceType> {

        /** The wrapped tokens-encoder model, to which all encoder operations are delegated. */
        @NotNull
        private final TokensEncoderModel<TokenType, SentenceType> model;

        /** Whether this component is flagged as trainable. */
        private final boolean trainable;

        /**
         * Wraps a tokens-encoder model as a component of the ensemble.
         *
         * @param model the model to wrap; checked for null via Kotlin {@code Intrinsics}
         * @param trainable whether this component is flagged as trainable
         */
        public ComponentModel(@NotNull TokensEncoderModel<TokenType, SentenceType> model, boolean trainable) {
            Intrinsics.checkParameterIsNotNull(model, "model");
            this.model = model;
            this.trainable = trainable;
        }

        /** Returns the wrapped model. */
        @NotNull
        public final TokensEncoderModel<TokenType, SentenceType> getModel() {
            return model;
        }

        /** Returns whether this component is flagged as trainable. */
        public final boolean getTrainable() {
            return trainable;
        }

        /** Returns the token encoding size of the wrapped model. */
        @Override // com.kotlinnlp.tokensencoder.TokensEncoderModel
        public int getTokenEncodingSize() {
            return model.getTokenEncodingSize();
        }

        /**
         * Builds an encoder by delegating to the wrapped model.
         *
         * @param useDropout forwarded unchanged to the wrapped model
         * @param id forwarded unchanged to the wrapped model
         * @return the encoder built by the wrapped model
         */
        @Override // com.kotlinnlp.tokensencoder.TokensEncoderModel
        @NotNull
        /* renamed from: buildEncoder */
        public TokensEncoder<TokenType, SentenceType> buildEncoder2(boolean useDropout, int id) {
            final TokensEncoder<TokenType, SentenceType> encoder = model.buildEncoder2(useDropout, id);
            return encoder;
        }
    }

    /**
     * Returns the size of the token encodings produced by this model. The constructor sets it equal
     * to the output size of the merge.
     */
    @Override // com.kotlinnlp.tokensencoder.TokensEncoderModel
    public int getTokenEncodingSize() {
        final int size = tokenEncodingSize;
        return size;
    }

    /**
     * Returns the parameters of the stacked network that merges the component encodings.
     */
    @NotNull
    public final StackedLayersParameters getOutputMergeNetwork() {
        final StackedLayersParameters network = outputMergeNetwork;
        return network;
    }

    /**
     * Returns a short description of the model, e.g. {@code "encoding size 300"}.
     */
    @Override
    @NotNull
    public String toString() {
        return String.format("encoding size %d", getTokenEncodingSize());
    }

    /**
     * Builds a new {@link EnsembleTokensEncoder} backed by this model.
     *
     * @param useDropout forwarded unchanged to the encoder constructor
     * @param id forwarded unchanged to the encoder constructor
     * @return a freshly constructed ensemble encoder
     */
    @Override // com.kotlinnlp.tokensencoder.TokensEncoderModel
    @NotNull
    /* renamed from: buildEncoder */
    public EnsembleTokensEncoder<TokenType, SentenceType> buildEncoder2(boolean useDropout, int id) {
        final EnsembleTokensEncoder<TokenType, SentenceType> encoder = new EnsembleTokensEncoder<>(this, useDropout, id);
        return encoder;
    }

    /**
     * Returns the component models of this ensemble (the same list instance given to the constructor).
     */
    @NotNull
    public final List<ComponentModel<TokenType, SentenceType>> getComponents() {
        final List<ComponentModel<TokenType, SentenceType>> result = components;
        return result;
    }

    /**
     * Creates an ensemble model that merges the token encodings of the given components.
     *
     * @param components the component models whose encodings are merged (stored as-is, not copied)
     * @param outputMergeConfiguration how the component encodings are merged; it determines the
     *     output size and the layers of the output merge network
     * @param weightsInitializer first initializer passed to the output merge network (nullable)
     * @param biasesInitializer second initializer passed to the output merge network (nullable)
     * @throws IllegalArgumentException if a Sum/Product/Avg merge is requested with no components,
     *     or with components whose token encoding sizes differ
     * @throws RuntimeException if the merge configuration type is not supported
     */
    public EnsembleTokensEncoderModel(@NotNull List<ComponentModel<TokenType, SentenceType>> components, @NotNull MergeConfiguration outputMergeConfiguration, @Nullable Initializer weightsInitializer, @Nullable Initializer biasesInitializer) {
        Intrinsics.checkParameterIsNotNull(components, "components");
        Intrinsics.checkParameterIsNotNull(outputMergeConfiguration, "outputMergeConfiguration");
        this.components = components;
        this.mergeOutputSize = computeMergeOutputSize(components, outputMergeConfiguration);
        this.tokenEncodingSize = this.mergeOutputSize;
        this.outputMergeNetwork = new StackedLayersParameters(
                buildMergeLayers(components, outputMergeConfiguration, this.mergeOutputSize),
                weightsInitializer,
                biasesInitializer);
    }

    /**
     * Computes the output size of the merge described by {@code mergeConfiguration}.
     * The instanceof checks keep the original order, in case the configuration types overlap.
     */
    private static int computeMergeOutputSize(List<? extends ComponentModel<?, ?>> components, MergeConfiguration mergeConfiguration) {
        if (mergeConfiguration instanceof AffineMerge) {
            return ((AffineMerge) mergeConfiguration).getOutputSize();
        }
        if (mergeConfiguration instanceof BiaffineMerge) {
            return ((BiaffineMerge) mergeConfiguration).getOutputSize();
        }
        if (mergeConfiguration instanceof ConcatFeedforwardMerge) {
            return ((ConcatFeedforwardMerge) mergeConfiguration).getOutputSize();
        }
        if (mergeConfiguration instanceof ConcatMerge) {
            return sumOfEncodingSizes(components);
        }
        if (mergeConfiguration instanceof SumMerge || mergeConfiguration instanceof ProductMerge || mergeConfiguration instanceof AvgMerge) {
            // Element-wise merges need at least one input, and all inputs must have the same size.
            if (components.isEmpty()) {
                throw new IllegalArgumentException(
                        "A Sum, Product or Avg output merge requires at least one component.");
            }
            int expectedSize = components.get(0).getTokenEncodingSize();
            for (ComponentModel<?, ?> component : components) {
                if (component.getTokenEncodingSize() != expectedSize) {
                    throw new IllegalArgumentException(
                            "All components must have the same token encoding size for a Sum, Product or Avg merge (found "
                                    + encodingSizes(components) + ").");
                }
            }
            return expectedSize;
        }
        throw new RuntimeException("Invalid output merge configuration.");
    }

    /**
     * Builds the layer list of the output merge network. There is one input layer per component
     * size; a ConcatFeedforwardMerge then adds a concat layer and a feedforward output layer,
     * and any other configuration adds a single merge layer of its own type.
     */
    private static List<LayerInterface> buildMergeLayers(List<? extends ComponentModel<?, ?>> components, MergeConfiguration mergeConfiguration, int mergeOutputSize) {
        LayerInterface inputLayer = new LayerInterface(encodingSizes(components), (LayerType.Input) null, (LayerType.Connection) null, (ActivationFunction) null, mergeConfiguration.getDropout(), 14, (DefaultConstructorMarker) null);
        if (mergeConfiguration instanceof ConcatFeedforwardMerge) {
            LayerInterface concatLayer = new LayerInterface(sumOfEncodingSizes(components), (LayerType.Input) null, LayerType.Connection.Concat, (ActivationFunction) null, 0.0d, 26, (DefaultConstructorMarker) null);
            LayerInterface outputLayer = new LayerInterface(((ConcatFeedforwardMerge) mergeConfiguration).getOutputSize(), (LayerType.Input) null, LayerType.Connection.Feedforward, (ActivationFunction) null, 0.0d, 26, (DefaultConstructorMarker) null);
            return CollectionsKt.listOf(inputLayer, concatLayer, outputLayer);
        }
        LayerInterface mergeLayer = new LayerInterface(mergeOutputSize, (LayerType.Input) null, mergeConfiguration.getType(), (ActivationFunction) null, 0.0d, 26, (DefaultConstructorMarker) null);
        return CollectionsKt.listOf(inputLayer, mergeLayer);
    }

    /** Returns the token encoding size of each component, in component order. */
    private static List<Integer> encodingSizes(List<? extends ComponentModel<?, ?>> components) {
        List<Integer> sizes = new ArrayList<>(components.size());
        for (ComponentModel<?, ?> component : components) {
            sizes.add(component.getTokenEncodingSize());
        }
        return sizes;
    }

    /** Returns the sum of the token encoding sizes of all components (0 if there are none). */
    private static int sumOfEncodingSizes(List<? extends ComponentModel<?, ?>> components) {
        int total = 0;
        for (ComponentModel<?, ?> component : components) {
            total += component.getTokenEncodingSize();
        }
        return total;
    }

    /**
     * Kotlin default-arguments bridge constructor. Each set bit in {@code i} means the matching
     * argument was omitted and must be replaced by its default:
     * bit 2 gives a {@code ConcatMerge}, bit 4 a {@code GlorotInitializer}, and bit 8 {@code null}.
     * The defaults are resolved left to right, in the same order as the arguments.
     */
    public /* synthetic */ EnsembleTokensEncoderModel(List list, MergeConfiguration mergeConfiguration, Initializer initializer, Initializer initializer2, int i, DefaultConstructorMarker defaultConstructorMarker) {
        this(list,
                mergeConfigurationOrDefault(mergeConfiguration, i),
                weightsInitializerOrDefault(initializer, i),
                biasesInitializerOrDefault(initializer2, i));
    }

    /** Returns {@code given}, or a default {@code ConcatMerge} when bit 2 of {@code mask} is set. */
    private static MergeConfiguration mergeConfigurationOrDefault(MergeConfiguration given, int mask) {
        if ((mask & 2) == 0) {
            return given;
        }
        return new ConcatMerge(0.0d, 1, (DefaultConstructorMarker) null);
    }

    /** Returns {@code given}, or a default {@code GlorotInitializer} when bit 4 of {@code mask} is set. */
    private static Initializer weightsInitializerOrDefault(Initializer given, int mask) {
        if ((mask & 4) == 0) {
            return given;
        }
        return new GlorotInitializer(0.0d, false, 0L, 7, (DefaultConstructorMarker) null);
    }

    /** Returns {@code given}, or {@code null} when bit 8 of {@code mask} is set. */
    private static Initializer biasesInitializerOrDefault(Initializer given, int mask) {
        if ((mask & 8) == 0) {
            return given;
        }
        return null;
    }
}
