package com.kotlinnlp.tokensencoder.ensemble;

import com.kotlinnlp.linguisticdescription.sentence.Sentence;
import com.kotlinnlp.linguisticdescription.sentence.token.Token;
import com.kotlinnlp.simplednn.core.arrays.ParamsArray;
import com.kotlinnlp.simplednn.core.layers.StackedLayersParameters;
import com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor;
import com.kotlinnlp.simplednn.core.neuralprocessor.batchfeedforward.BatchFeedforwardProcessor;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import com.kotlinnlp.tokensencoder.TokensEncoder;
import com.kotlinnlp.tokensencoder.ensemble.EnsembleTokensEncoderModel;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import kotlin.Metadata;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;

/* compiled from: EnsembleTokensEncoder.kt */
@Metadata(mv = {1, 1, 13}, bv = {1, 0, 3}, k = 1, d1 = {"��^\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000b\n��\n\u0002\u0010\b\n\u0002\b\u0002\n\u0002\u0010 \n\u0002\b\u0005\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0010\u0002\n\u0002\b\u0005\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\u0018��*\b\b��\u0010\u0001*\u00020\u0002*\u000e\b\u0001\u0010\u0003*\b\u0012\u0004\u0012\u0002H\u00010\u00042\u000e\u0012\u0004\u0012\u0002H\u0001\u0012\u0004\u0012\u0002H\u00030\u0005B+\u0012\u0012\u0010\u0006\u001a\u000e\u0012\u0004\u0012\u00028��\u0012\u0004\u0012\u00028\u00010\u0007\u0012\u0006\u0010\b\u001a\u00020\t\u0012\b\b\u0002\u0010\n\u001a\u00020\u000b¢\u0006\u0002\u0010\fJ\u0016\u0010\u0018\u001a\u00020\u00192\f\u0010\u001a\u001a\b\u0012\u0004\u0012\u00020\u00150\u000eH\u0016J\u001b\u0010\u001b\u001a\b\u0012\u0004\u0012\u00020\u00150\u000e2\u0006\u0010\u001c\u001a\u00028\u0001H\u0016¢\u0006\u0002\u0010\u001dJ\"\u0010\u001e\u001a\u0014\u0012\f\u0012\n\u0012\u0002\b\u00030\u001fR\u00020 0\u000ej\u0002`!2\u0006\u0010\"\u001a\u00020\tH\u0002J\u0010\u0010#\u001a\u00020$2\u0006\u0010\"\u001a\u00020\tH\u0016J\"\u0010%\u001a\u0014\u0012\f\u0012\n\u0012\u0002\b\u00030\u001fR\u00020 0\u000ej\u0002`!2\u0006\u0010\"\u001a\u00020\tH\u0016R \u0010\r\u001a\u0014\u0012\u0010\u0012\u000e\u0012\u0004\u0012\u00028��\u0012\u0004\u0012\u00028\u00010\u00050\u000eX\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\n\u001a\u00020\u000bX\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\u000f\u0010\u0010R 
\u0010\u0006\u001a\u000e\u0012\u0004\u0012\u00028��\u0012\u0004\u0012\u00028\u00010\u0007X\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\u0011\u0010\u0012R\u0014\u0010\u0013\u001a\b\u0012\u0004\u0012\u00020\u00150\u0014X\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\b\u001a\u00020\tX\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\u0016\u0010\u0017¨\u0006&"}, d2 = {"Lcom/kotlinnlp/tokensencoder/ensemble/EnsembleTokensEncoder;", "TokenType", "Lcom/kotlinnlp/linguisticdescription/sentence/token/Token;", "SentenceType", "Lcom/kotlinnlp/linguisticdescription/sentence/Sentence;", "Lcom/kotlinnlp/tokensencoder/TokensEncoder;", "model", "Lcom/kotlinnlp/tokensencoder/ensemble/EnsembleTokensEncoderModel;", "useDropout", "", "id", "", "(Lcom/kotlinnlp/tokensencoder/ensemble/EnsembleTokensEncoderModel;ZI)V", "encoders", "", "getId", "()I", "getModel", "()Lcom/kotlinnlp/tokensencoder/ensemble/EnsembleTokensEncoderModel;", "outputMergeProcessors", "Lcom/kotlinnlp/simplednn/core/neuralprocessor/batchfeedforward/BatchFeedforwardProcessor;", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "getUseDropout", "()Z", "backward", "", "outputErrors", "forward", "input", "(Lcom/kotlinnlp/linguisticdescription/sentence/Sentence;)Ljava/util/List;", "getEncodersParamsErrors", "Lcom/kotlinnlp/simplednn/core/arrays/ParamsArray$Errors;", "Lcom/kotlinnlp/simplednn/core/arrays/ParamsArray;", "Lcom/kotlinnlp/simplednn/core/optimizer/ParamsErrorsList;", "copy", "getInputErrors", "Lcom/kotlinnlp/simplednn/core/neuralprocessor/NeuralProcessor$NoInputErrors;", "getParamsErrors", "tokensencoder"})
/* loaded from: input_file:com/kotlinnlp/tokensencoder/ensemble/EnsembleTokensEncoder.class */
/**
 * Tokens encoder that combines several component encoders.
 *
 * <p>One encoder is built for every component of the model. On {@link #forward} each encoder
 * processes the same sentence, the vectors they produce are grouped by token position, and the
 * groups are passed to a merge network ({@code outputMergeProcessors}). On {@link #backward} the
 * errors flow back through the merge network and are then forwarded only to the encoders whose
 * component is flagged as trainable.
 *
 * <p>This source was reconstructed from compiled Kotlin; calls to {@code ...$default} members and
 * to the {@link DefaultConstructorMarker} constructors are compiler-generated entry points for
 * Kotlin default arguments and are kept exactly as decompiled.
 *
 * @param <TokenType> the type of the tokens of the input sentence
 * @param <SentenceType> the type of the input sentence
 */
public final class EnsembleTokensEncoder<TokenType extends Token, SentenceType extends Sentence<TokenType>> extends TokensEncoder<TokenType, SentenceType> {
    /** One encoder per model component, in the same order as {@code model.getComponents()}. */
    private final List<TokensEncoder<TokenType, SentenceType>> encoders;

    /** Processor that merges, token by token, the vectors produced by all the encoders. */
    private final BatchFeedforwardProcessor<DenseNDArray> outputMergeProcessors;

    @NotNull
    private final EnsembleTokensEncoderModel<TokenType, SentenceType> model;
    private final boolean useDropout;
    private final int id;

    /**
     * Builds the component encoders and the merge processor described by the given model.
     *
     * <p>NOTE: the decompiler reported that the {@code super} call was moved to the top of this
     * constructor, so in the compiled class some statement (presumably the null check of the
     * model) runs before it.
     *
     * @param ensembleTokensEncoderModel the model of this encoder (must not be null)
     * @param z the value stored as {@code useDropout} and passed to every component encoder and to
     *     the merge processor
     * @param i the value stored as the id of this encoder
     */
    public EnsembleTokensEncoder(@NotNull EnsembleTokensEncoderModel<TokenType, SentenceType> ensembleTokensEncoderModel, boolean z, int i) {
        super(ensembleTokensEncoderModel);
        Intrinsics.checkParameterIsNotNull(ensembleTokensEncoderModel, "model");
        this.model = ensembleTokensEncoderModel;
        this.useDropout = z;
        this.id = i;

        List<EnsembleTokensEncoderModel.ComponentModel<TokenType, SentenceType>> components = this.model.getComponents();
        List<TokensEncoder<TokenType, SentenceType>> builtEncoders = new ArrayList<>(components.size());
        for (EnsembleTokensEncoderModel.ComponentModel<TokenType, SentenceType> component : components) {
            // The generic type of the built encoder was erased by decompilation; the cast only
            // restores the element type already declared for the `encoders` field.
            @SuppressWarnings("unchecked")
            TokensEncoder<TokenType, SentenceType> encoder =
                (TokensEncoder<TokenType, SentenceType>) component.buildEncoder2(this.useDropout, 0);
            builtEncoders.add(encoder);
        }
        this.encoders = builtEncoders;

        StackedLayersParameters outputMergeNetwork = this.model.getOutputMergeNetwork();
        boolean dropout = this.useDropout;

        // The third constructor argument of the merge processor is true as soon as one component
        // is trainable (and false for an empty component list).
        boolean anyTrainable = false;
        for (EnsembleTokensEncoderModel.ComponentModel<TokenType, SentenceType> component : this.model.getComponents()) {
            if (component.getTrainable()) {
                anyTrainable = true;
                break;
            }
        }

        // Compiler-generated constructor for Kotlin default arguments: bit mask 8 marks the fourth
        // parameter as "use the default", so the literal 0 passed for it is only a placeholder.
        this.outputMergeProcessors = new BatchFeedforwardProcessor<>(outputMergeNetwork, dropout, anyTrainable, 0, 8, (DefaultConstructorMarker) null);
    }

    /**
     * Compiler-generated constructor that applies the Kotlin default arguments.
     *
     * @param ensembleTokensEncoderModel the model of this encoder
     * @param z the {@code useDropout} flag
     * @param i the id, ignored when bit 4 of {@code i2} is set
     * @param i2 bit mask of the defaulted parameters: when bit 4 is set the id defaults to 0
     * @param defaultConstructorMarker unused marker that distinguishes this overload
     */
    public /* synthetic */ EnsembleTokensEncoder(EnsembleTokensEncoderModel ensembleTokensEncoderModel, boolean z, int i, int i2, DefaultConstructorMarker defaultConstructorMarker) {
        this(ensembleTokensEncoderModel, z, (i2 & 4) != 0 ? 0 : i);
    }

    /**
     * Encodes a sentence with every component encoder and merges the results.
     *
     * @param sentencetype the input sentence (must not be null)
     * @return the output of the merge processor for the per-token groups of encoder vectors
     */
    @NotNull
    public List<DenseNDArray> forward(@NotNull SentenceType sentencetype) {
        Intrinsics.checkParameterIsNotNull(sentencetype, "input");

        // One (initially empty) group of vectors per token of the sentence.
        int tokenCount = sentencetype.getTokens().size();
        List<List<DenseNDArray>> vectorsPerToken = new ArrayList<>(tokenCount);
        for (int t = 0; t < tokenCount; t++) {
            vectorsPerToken.add(new ArrayList<DenseNDArray>());
        }

        // The k-th vector returned by an encoder is appended to the k-th group, so every group
        // ends up holding its vectors in encoder order.
        for (TokensEncoder<TokenType, SentenceType> encoder : this.encoders) {
            int tokenIndex = 0;
            for (DenseNDArray vector : encoder.forward(sentencetype)) {
                vectorsPerToken.get(tokenIndex).add(vector);
                tokenIndex++;
            }
        }

        // Compiler-generated bridge for Kotlin default arguments: bit mask 2 marks the second
        // parameter as "use the default", so the literal false is only a placeholder.
        return BatchFeedforwardProcessor.forward$default(this.outputMergeProcessors, new ArrayList<List<DenseNDArray>>(vectorsPerToken), false, 2, (Object) null);
    }

    /**
     * Propagates the output errors through the merge processor and then through every encoder
     * whose component is trainable.
     *
     * @param list the errors of the output (must not be null)
     */
    public void backward(@NotNull List<DenseNDArray> list) {
        Intrinsics.checkParameterIsNotNull(list, "outputErrors");
        this.outputMergeProcessors.backward(list);

        // Outer list: one entry per element of the merged batch; inner list: one error per encoder.
        List<List<DenseNDArray>> mergeInputErrors = this.outputMergeProcessors.getInputsErrors(false);

        for (int encoderIndex = 0; encoderIndex < this.encoders.size(); encoderIndex++) {
            if (!isTrainable(encoderIndex)) {
                continue;
            }
            // Select, from every group, the error that belongs to this encoder.
            List<DenseNDArray> encoderErrors = new ArrayList<>(mergeInputErrors.size());
            for (List<DenseNDArray> groupErrors : mergeInputErrors) {
                encoderErrors.add(groupErrors.get(encoderIndex));
            }
            this.encoders.get(encoderIndex).backward(encoderErrors);
        }
    }

    /**
     * Returns the parameters errors of the merge processor followed by those of the trainable
     * encoders.
     *
     * @param z flag forwarded unchanged to the merge processor and to the encoders
     * @return the concatenated list of parameters errors
     */
    @NotNull
    public List<ParamsArray.Errors<?>> getParamsErrors(boolean z) {
        return CollectionsKt.plus(this.outputMergeProcessors.getParamsErrors(z), getEncodersParamsErrors(z));
    }

    /**
     * Returns the constant that stands for "no input errors".
     *
     * @param z ignored
     * @return {@code NeuralProcessor.NoInputErrors.INSTANCE}
     */
    @NotNull
    /* renamed from: getInputErrors, reason: merged with bridge method [inline-methods] */
    public NeuralProcessor.NoInputErrors m15getInputErrors(boolean z) {
        return NeuralProcessor.NoInputErrors.INSTANCE;
    }

    /**
     * Collects, in encoder order, the parameters errors of the encoders whose component is
     * trainable; the other encoders contribute nothing.
     *
     * @param z flag forwarded unchanged to every queried encoder
     * @return a new list with the parameters errors of all the trainable encoders
     */
    private List<ParamsArray.Errors<?>> getEncodersParamsErrors(boolean z) {
        List<ParamsArray.Errors<?>> errors = new ArrayList<>();
        for (int encoderIndex = 0; encoderIndex < this.encoders.size(); encoderIndex++) {
            if (isTrainable(encoderIndex)) {
                errors.addAll(this.encoders.get(encoderIndex).getParamsErrors(z));
            }
        }
        return errors;
    }

    /**
     * Tells whether the model component at the given position is flagged as trainable.
     *
     * @param componentIndex the position of the component, which is also the position of its encoder
     * @return the trainable flag of that component
     */
    private boolean isTrainable(int componentIndex) {
        return getModel2().getComponents().get(componentIndex).getTrainable();
    }

    /**
     * Returns the model this encoder was built with.
     *
     * @return the ensemble model
     */
    @Override // com.kotlinnlp.tokensencoder.TokensEncoder
    @NotNull
    /* renamed from: getModel */
    public EnsembleTokensEncoderModel<TokenType, SentenceType> getModel2() {
        return this.model;
    }

    /**
     * Returns the dropout flag given at construction time.
     *
     * @return the {@code useDropout} flag
     */
    public boolean getUseDropout() {
        return this.useDropout;
    }

    /**
     * Returns the id given at construction time.
     *
     * @return the id of this encoder
     */
    public int getId() {
        return this.id;
    }
}
