package com.kotlinnlp.simplednn.deeplearning.transformers;

import com.kotlinnlp.simplednn.core.arrays.ParamsArray;
import com.kotlinnlp.simplednn.core.embeddings.EmbeddingsMap;
import com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor;
import com.kotlinnlp.simplednn.core.neuralprocessor.batchfeedforward.BatchFeedforwardProcessor;
import com.kotlinnlp.simplednn.core.optimizer.ParamsErrorsAccumulator;
import com.kotlinnlp.simplednn.core.optimizer.ParamsOptimizer;
import com.kotlinnlp.simplednn.deeplearning.transformers.BERTModel;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.Shape;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArrayFactory;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin.TuplesKt;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.functions.Function0;
import kotlin.jvm.functions.Function3;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;

/* compiled from: BERT.kt */
@Metadata(mv = {1, 1, 15}, bv = {1, 0, 3}, k = 1, d1 = {"��d\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0010 \n\u0002\u0010\u000e\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000b\n\u0002\b\u0004\n\u0002\u0010\b\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0007\n\u0002\u0010\u0002\n\u0002\b\u000b\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\u0018��2,\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00030\u0002\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00040\u0002\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00040\u0002\u0012\u0004\u0012\u00020\u00050\u0001B?\u0012\u0006\u0010\u0006\u001a\u00020\u0007\u0012\b\b\u0002\u0010\b\u001a\u00020\t\u0012\b\b\u0002\u0010\n\u001a\u00020\t\u0012\b\b\u0002\u0010\u000b\u001a\u00020\t\u0012\b\b\u0002\u0010\f\u001a\u00020\t\u0012\b\b\u0002\u0010\r\u001a\u00020\u000e¢\u0006\u0002\u0010\u000fJ\u0016\u0010 \u001a\u00020!2\f\u0010\"\u001a\b\u0012\u0004\u0012\u00020\u00040\u0002H\u0016J\u0016\u0010#\u001a\u00020!2\f\u0010$\u001a\b\u0012\u0004\u0012\u00020\u00040\u0002H\u0002J\u0010\u0010%\u001a\u00020\u00042\u0006\u0010&\u001a\u00020\u000eH\u0002J\u001c\u0010'\u001a\b\u0012\u0004\u0012\u00020\u00040\u00022\f\u0010(\u001a\b\u0012\u0004\u0012\u00020\u00030\u0002H\u0002J\u001c\u0010)\u001a\b\u0012\u0004\u0012\u00020\u00040\u00022\f\u0010(\u001a\b\u0012\u0004\u0012\u00020\u00030\u0002H\u0016J\u0010\u0010*\u001a\u00020\u00052\u0006\u0010+\u001a\u00020\tH\u0016J\"\u0010,\u001a\u0014\u0012\f\u0012\n\u0012\u0002\b\u00030-R\u00020.0\u0002j\u0002`/2\u0006\u0010+\u001a\u00020\tH\u0016J\u0010\u00100\u001a\u00020\u00042\u0006\u0010&\u001a\u00020\u000eH\u0002R\u000e\u0010\u000b\u001a\u00020\tX\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u0010\u001a\b\u0012\u0004\u0012\u00020\u00040\u0011X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0012\u001a\u00020\u0013X\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\r\u001a\
u00020\u000eX\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\u0014\u0010\u0015R \u0010\u0016\u001a\u0014\u0012\u0010\u0012\u000e\u0012\u0004\u0012\u00020\u0003\u0012\u0004\u0012\u00020\u00040\u00170\u0002X\u0082.¢\u0006\u0002\n��R\u0014\u0010\u0018\u001a\b\u0012\u0004\u0012\u00020\u00190\u0002X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\n\u001a\u00020\tX\u0082\u0004¢\u0006\u0002\n��R\u0011\u0010\u0006\u001a\u00020\u0007¢\u0006\b\n��\u001a\u0004\b\u001a\u0010\u001bR\u0014\u0010\f\u001a\u00020\tX\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\u001c\u0010\u001dR\u0014\u0010\u001e\u001a\b\u0012\u0004\u0012\u00020\u00190\u0002X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u001f\u001a\u00020\u0004X\u0082\u0004¢\u0006\u0002\n��¨\u00061"}, d2 = {"Lcom/kotlinnlp/simplednn/deeplearning/transformers/BERT;", "Lcom/kotlinnlp/simplednn/core/neuralprocessor/NeuralProcessor;", "", "", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "Lcom/kotlinnlp/simplednn/core/neuralprocessor/NeuralProcessor$NoInputErrors;", "model", "Lcom/kotlinnlp/simplednn/deeplearning/transformers/BERTModel;", "fineTuning", "", "masksEnabled", "autoPadding", "propagateToInput", "id", "", "(Lcom/kotlinnlp/simplednn/deeplearning/transformers/BERTModel;ZZZZI)V", "embNorm", "Lcom/kotlinnlp/simplednn/core/neuralprocessor/batchfeedforward/BatchFeedforwardProcessor;", "errorsAccumulator", "Lcom/kotlinnlp/simplednn/core/optimizer/ParamsErrorsAccumulator;", "getId", "()I", "inputSequence", "Lkotlin/Pair;", "layers", "Lcom/kotlinnlp/simplednn/deeplearning/transformers/BERTLayer;", "getModel", "()Lcom/kotlinnlp/simplednn/deeplearning/transformers/BERTModel;", "getPropagateToInput", "()Z", "trainableLayers", "zeroErrors", "backward", "", "outputErrors", "backwardInput", "errors", "buildPositionalEncoding", "pos", "encodeSequence", "input", "forward", "getInputErrors", "copy", "getParamsErrors", "Lcom/kotlinnlp/simplednn/core/arrays/ParamsArray$Errors;", "Lcom/kotlinnlp/simplednn/core/arrays/ParamsArray;", 
"Lcom/kotlinnlp/simplednn/core/optimizer/ParamsErrorsList;", "getPositionalEncoding", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/deeplearning/transformers/BERT.class */
public final class BERT implements NeuralProcessor<List<? extends String>, List<? extends DenseNDArray>, List<? extends DenseNDArray>, NeuralProcessor.NoInputErrors> {
    // Collects the parameter errors of the trainable layers; cleared at the start of every backward2 call.
    private final ParamsErrorsAccumulator errorsAccumulator;
    // One BERTLayer per entry of model.getLayers(), applied in list order by forward2.
    private final List<BERTLayer> layers;
    // The layers visited (in reverse order) by backward2: only the last layer when the
    // constructor's fine-tuning flag is set, otherwise the same list as `layers`.
    private final List<BERTLayer> trainableLayers;
    // Processor built on model.getEmbNorm(); applied to the encoded input sequence before the layers.
    private final BatchFeedforwardProcessor<DenseNDArray> embNorm;
    // (token form, encoded vector) pairs produced by the last encodeSequence call,
    // including the leading CLS and trailing SEP functional tokens.
    private List<Pair<String, DenseNDArray>> inputSequence;
    // Zero-filled array sized with model.getInputSize(); used by backward2 as the errors of the
    // two positions that forward2 strips from the output when autoPadding is enabled.
    private final DenseNDArray zeroErrors;

    // The model holding the embeddings maps and the layers parameters.
    @NotNull
    private final BERTModel model;
    // When true, a token equal to the MASK form is encoded with the functional embeddings.
    private final boolean masksEnabled;
    // When true, forward2 drops the first and last output vectors and backward2 re-adds zero errors for them.
    private final boolean autoPadding;
    // Returned by getPropagateToInput(); also decides whether the first layer propagates errors to its input.
    private final boolean propagateToInput;
    // Returned by getId().
    private final int id;

    /**
     * Runs the forward pass over a sequence of token forms.
     *
     * <p>The tokens are encoded by {@code encodeSequence} (which always surrounds them with the
     * CLS and SEP functional tokens), passed through the embeddings normalization processor and
     * then through every layer in order.
     *
     * @param list the token forms of the input sequence
     * @return one vector per encoded position; when {@code autoPadding} is enabled the first and
     *         the last vectors (the CLS and SEP positions) are excluded from the returned view
     */
    @NotNull
    /* renamed from: forward, reason: avoid collision after fix types in other method */
    public List<DenseNDArray> forward2(@NotNull List<String> list) {
        Intrinsics.checkParameterIsNotNull(list, "input");
        List<DenseNDArray> states = this.embNorm.forward((List<? extends DenseNDArray>) encodeSequence(list));
        // Typed for-each: the decompiled `Iterator<T>` referenced an undeclared type variable.
        for (BERTLayer layer : this.layers) {
            states = layer.forward2(states);
        }
        if (!this.autoPadding) {
            return states;
        }
        // subList's upper bound is exclusive, so [1, lastIndex) drops both the first and the last element.
        return states.subList(1, CollectionsKt.getLastIndex(states));
    }

    /**
     * Bridge for {@code NeuralProcessor.forward}: narrows the wildcard list type and delegates to
     * {@link #forward2(List)}.
     */
    @Override // com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor
    public /* bridge */ /* synthetic */ List<? extends DenseNDArray> forward(List<? extends String> list) {
        List<String> tokens = (List<String>) list;
        return forward2(tokens);
    }

    /**
     * Runs the backward pass.
     *
     * <p>Clears the errors accumulator, then walks the trainable layers from the last to the
     * first: each layer receives the current errors, its parameter errors are accumulated, and
     * its input errors become the errors for the next (previous) layer. The errors left after the
     * loop are handed to {@code backwardInput}.
     *
     * @param list the output errors; when {@code autoPadding} is enabled they are surrounded with
     *             {@code zeroErrors} for the two positions that {@code forward2} stripped
     */
    /* renamed from: backward, reason: avoid collision after fix types in other method */
    public void backward2(@NotNull List<DenseNDArray> list) {
        Intrinsics.checkParameterIsNotNull(list, "outputErrors");
        this.errorsAccumulator.clear();

        List<DenseNDArray> errors;
        if (this.autoPadding) {
            List<DenseNDArray> leftPadded = CollectionsKt.plus(CollectionsKt.listOf(this.zeroErrors), list);
            errors = CollectionsKt.plus(leftPadded, CollectionsKt.listOf(this.zeroErrors));
        } else {
            errors = list;
        }

        // Visit the trainable layers back-to-front.
        for (int k = this.trainableLayers.size() - 1; k >= 0; k--) {
            BERTLayer layer = this.trainableLayers.get(k);
            layer.backward2(errors);
            this.errorsAccumulator.accumulate((List<? extends ParamsArray.Errors<?>>) layer.getParamsErrors(false), false);
            errors = layer.getInputErrors2(false);
        }

        backwardInput(errors);
    }

    /**
     * Bridge for {@code NeuralProcessor.backward}: narrows the wildcard list type and delegates to
     * {@link #backward2(List)}.
     */
    @Override // com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor
    public /* bridge */ /* synthetic */ void backward(List<? extends DenseNDArray> list) {
        List<DenseNDArray> outputErrors = (List<DenseNDArray>) list;
        backward2(outputErrors);
    }

    /**
     * Returns the parameter errors collected by the errors accumulator during the last backward.
     *
     * @param z forwarded unchanged to {@code ParamsErrorsAccumulator.getParamsErrors}
     *          (named {@code copy} in the Kotlin metadata of this class)
     * @return the accumulated parameter errors
     */
    @Override // com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor
    @NotNull
    public List<ParamsArray.Errors<?>> getParamsErrors(boolean z) {
        List<ParamsArray.Errors<?>> accumulated = errorsAccumulator.getParamsErrors(z);
        return accumulated;
    }

    /**
     * Returns the input errors of this processor, which are always the
     * {@code NoInputErrors} singleton; the argument is ignored.
     *
     * NOTE(review): the decompiler renamed this method from {@code getInputErrors} while keeping
     * {@code @Override}; confirm the name against the {@code NeuralProcessor} interface.
     */
    /* JADX WARN: Can't rename method to resolve collision */
    @Override // com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor
    @NotNull
    /* renamed from: getInputErrors */
    public NeuralProcessor.NoInputErrors getInputErrors2(boolean z) {
        return NeuralProcessor.NoInputErrors.INSTANCE;
    }

    /* JADX WARN: Type inference failed for: r0v0, types: [com.kotlinnlp.simplednn.deeplearning.transformers.BERT$encodeSequence$1] */
    private final List<DenseNDArray> encodeSequence(List<String> list) {
        ?? r0 = new Function3<String, Integer, Boolean, Pair<? extends String, ? extends DenseNDArray>>() { // from class: com.kotlinnlp.simplednn.deeplearning.transformers.BERT$encodeSequence$1
            public /* bridge */ /* synthetic */ Object invoke(Object obj, Object obj2, Object obj3) {
                return invoke((String) obj, ((Number) obj2).intValue(), ((Boolean) obj3).booleanValue());
            }

            @NotNull
            public final Pair<String, DenseNDArray> invoke(@NotNull String str, int i, boolean z) {
                ParamsArray paramsArray;
                DenseNDArray positionalEncoding;
                boolean z2;
                Intrinsics.checkParameterIsNotNull(str, "token");
                if (!z) {
                    z2 = BERT.this.masksEnabled;
                    if (!z2 || !Intrinsics.areEqual(str, BERTModel.FuncToken.MASK.getForm())) {
                        EmbeddingsMap<String> wordEmb = BERT.this.getModel().getWordEmb();
                        if (wordEmb == null) {
                            Intrinsics.throwNpe();
                        }
                        paramsArray = wordEmb.get(str);
                        DenseNDArray values = paramsArray.getValues();
                        positionalEncoding = BERT.this.getPositionalEncoding(i);
                        return TuplesKt.to(str, values.sum(positionalEncoding).assignSum((NDArray<?>) BERT.this.getModel().getTokenTypeEmb().get(0).getValues()));
                    }
                }
                paramsArray = BERT.this.getModel().getFuncEmb().get(BERTModel.FuncToken.Companion.byForm(str));
                DenseNDArray values2 = paramsArray.getValues();
                positionalEncoding = BERT.this.getPositionalEncoding(i);
                return TuplesKt.to(str, values2.sum(positionalEncoding).assignSum((NDArray<?>) BERT.this.getModel().getTokenTypeEmb().get(0).getValues()));
            }

            @NotNull
            public static /* synthetic */ Pair invoke$default(BERT$encodeSequence$1 bERT$encodeSequence$1, String str, int i, boolean z, int i2, Object obj) {
                if ((i2 & 4) != 0) {
                    z = false;
                }
                return bERT$encodeSequence$1.invoke(str, i, z);
            }

            /* JADX INFO: Access modifiers changed from: package-private */
            {
                super(3);
            }
        };
        List listOf = CollectionsKt.listOf(r0.invoke(BERTModel.FuncToken.CLS.getForm(), 0, true));
        List<String> list2 = list;
        ArrayList arrayList = new ArrayList(CollectionsKt.collectionSizeOrDefault(list2, 10));
        int i = 0;
        for (Object obj : list2) {
            int i2 = i;
            i++;
            if (i2 < 0) {
                CollectionsKt.throwIndexOverflow();
            }
            arrayList.add(BERT$encodeSequence$1.invoke$default(r0, (String) obj, i2 + 1, false, 4, null));
        }
        this.inputSequence = CollectionsKt.plus(CollectionsKt.plus(listOf, arrayList), CollectionsKt.listOf(r0.invoke(BERTModel.FuncToken.SEP.getForm(), CollectionsKt.getLastIndex(list) + 2, true)));
        List<Pair<String, DenseNDArray>> list3 = this.inputSequence;
        if (list3 == null) {
            Intrinsics.throwUninitializedPropertyAccessException("inputSequence");
        }
        return (List) CollectionsKt.unzip(list3).getSecond();
    }

    /**
     * Returns the values of the positional embedding stored in the model under key {@code i}.
     *
     * <p>The lookup goes through {@code EmbeddingsMap.getOrSet}, passing a factory that wraps the
     * result of {@code buildPositionalEncoding(i)} in a new {@code ParamsArray}. Presumably the
     * factory is only invoked when the key is missing — confirm against {@code EmbeddingsMap}.
     *
     * @param i the position in the sequence
     * @return the values of the positional embedding for {@code i}
     */
    /* JADX INFO: Access modifiers changed from: private */
    public final DenseNDArray getPositionalEncoding(final int i) {
        // Plain anonymous Function0: the decompiled `super(0)` initializer block is not valid Java.
        Function0<ParamsArray> factory = new Function0<ParamsArray>() { // from class: com.kotlinnlp.simplednn.deeplearning.transformers.BERT$getPositionalEncoding$1
            @NotNull
            public final ParamsArray invoke() {
                DenseNDArray encoding = BERT.this.buildPositionalEncoding(i);
                return new ParamsArray(encoding, (ParamsArray.ErrorsType) null, 2, (DefaultConstructorMarker) null);
            }
        };
        // The `0.0d` argument is a placeholder: the mask value 2 marks it as "use the default".
        return EmbeddingsMap.getOrSet$default(this.model.getPositionalEmb(), Integer.valueOf(i), 0.0d, factory, 2, (Object) null).getValues();
    }

    /**
     * Builds the sinusoidal positional encoding of a position.
     *
     * <p>Component {@code k} of the result is {@code sin(i / 10000^(k / size))} for even
     * {@code k} and {@code cos(i / 10000^(k / size))} for odd {@code k}, where {@code size} is
     * {@code model.getInputSize()}.
     *
     * <p>The exponent {@code k / size} is computed in floating point. The previous integer
     * division truncated it to 0 for every {@code k < size}, so every component collapsed to
     * {@code sin(i)} or {@code cos(i)}.
     *
     * NOTE(review): this changes the values of newly built encodings; encodings already stored
     * in the model's positional embeddings map are not rebuilt here.
     *
     * @param i the position in the sequence
     * @return a dense array with {@code model.getInputSize()} components
     */
    /* JADX INFO: Access modifiers changed from: private */
    public final DenseNDArray buildPositionalEncoding(int i) {
        final int size = this.model.getInputSize();
        final double[] encoding = new double[size];
        for (int k = 0; k < size; k++) {
            // Cast before dividing so that the exponent is a fraction in [0, 1).
            final double angle = i / Math.pow(10000.0d, (double) k / size);
            encoding[k] = (k % 2 == 0) ? Math.sin(angle) : Math.cos(angle);
        }
        return DenseNDArrayFactory.INSTANCE.arrayOf(encoding);
    }

    /**
     * Last step of the backward pass: called by {@code backward2} with the errors left after the
     * trainable layers have been visited.
     *
     * <p>WARNING: the decompiler could not recover the body of this method. As written here it
     * always throws {@link UnsupportedOperationException}, so any call to {@code backward2}
     * fails at this point. The real logic must be restored from the original {@code BERT.kt}
     * source or from the bytecode dump.
     *
     * @param r7 the errors of the encoded input sequence
     */
    /* JADX WARN: Removed duplicated region for block: B:20:0x00d2  */
    /* JADX WARN: Removed duplicated region for block: B:29:0x00e7  */
    /*
        Code decompiled incorrectly, please refer to instructions dump.
        To view partially-correct add '--show-bad-code' argument
    */
    private final void backwardInput(java.util.List<com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray> r7) {
        /*
            Method dump skipped, instructions count: 329
            To view this dump add '--comments-level debug' option
        */
        throw new UnsupportedOperationException("Method not decompiled: com.kotlinnlp.simplednn.deeplearning.transformers.BERT.backwardInput(java.util.List):void");
    }

    /**
     * @return the model this processor was built with
     */
    @NotNull
    public final BERTModel getModel() {
        return model;
    }

    /**
     * @return the {@code propagateToInput} flag given at construction time
     */
    @Override // com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor
    public boolean getPropagateToInput() {
        return propagateToInput;
    }

    /**
     * @return the id given at construction time
     */
    public int getId() {
        return id;
    }

    /**
     * Builds a BERT processor on top of a model.
     *
     * <p>Parameter names in parentheses are the ones recorded in the Kotlin metadata of this class.
     *
     * @param bERTModel the model (model)
     * @param z when true only the last layer is trainable, otherwise all layers are (fineTuning)
     * @param z2 whether the MASK form is encoded with the functional embeddings (masksEnabled)
     * @param z3 whether the first and last positions are stripped from the output and
     *           zero-padded in the output errors (autoPadding)
     * @param z4 whether the first layer propagates errors to its input (propagateToInput)
     * @param i the id of this processor (id)
     */
    public BERT(@NotNull BERTModel bERTModel, boolean z, boolean z2, boolean z3, boolean z4, int i) {
        Intrinsics.checkParameterIsNotNull(bERTModel, "model");
        this.model = bERTModel;
        this.masksEnabled = z2;
        this.autoPadding = z3;
        this.propagateToInput = z4;
        this.id = i;
        this.errorsAccumulator = new ParamsErrorsAccumulator();

        List<BERTParameters> layersParams = this.model.getLayers();
        ArrayList<BERTLayer> builtLayers = new ArrayList<>(CollectionsKt.collectionSizeOrDefault(layersParams, 10));
        int index = 0;
        for (BERTParameters params : layersParams) {
            if (index < 0) {
                CollectionsKt.throwIndexOverflow();
            }
            // Every layer after the first always propagates; the first one follows the processor flag.
            boolean propagate = index > 0 || getPropagateToInput();
            builtLayers.add(new BERTLayer(params, propagate, 0, 4, null));
            index++;
        }
        this.layers = builtLayers;

        if (z) {
            this.trainableLayers = CollectionsKt.takeLast(this.layers, 1);
        } else {
            this.trainableLayers = this.layers;
        }

        this.embNorm = new BatchFeedforwardProcessor<>(this.model.getEmbNorm(), 0.0d, true, 0, 10, (DefaultConstructorMarker) null);

        Shape zeroShape = new Shape(this.model.getInputSize(), 0, 2, null);
        this.zeroErrors = DenseNDArrayFactory.INSTANCE.zeros(zeroShape);
    }

    /**
     * Synthetic constructor implementing the Kotlin default arguments: for each bit set in
     * {@code i2} the corresponding argument is replaced by its default ({@code false} for the
     * four flags, {@code 0} for the id) before delegating to the primary constructor.
     */
    public /* synthetic */ BERT(BERTModel bERTModel, boolean z, boolean z2, boolean z3, boolean z4, int i, int i2, DefaultConstructorMarker defaultConstructorMarker) {
        this(
            bERTModel,
            (i2 & 2) == 0 && z,
            (i2 & 4) == 0 && z2,
            (i2 & 8) == 0 && z3,
            (i2 & 16) == 0 && z4,
            (i2 & 32) == 0 ? i : 0);
    }

    /**
     * Delegates to the default implementation of {@code NeuralProcessor.propagateErrors}.
     *
     * @param list the output errors
     * @param paramsOptimizer the optimizer handed to the default implementation
     * @param z forwarded unchanged to the default implementation
     * @return the input errors, cast to {@code NoInputErrors}
     */
    @NotNull
    /* renamed from: propagateErrors, reason: avoid collision after fix types in other method */
    public NeuralProcessor.NoInputErrors propagateErrors2(@NotNull List<DenseNDArray> list, @NotNull ParamsOptimizer paramsOptimizer, boolean z) {
        Intrinsics.checkParameterIsNotNull(list, "errors");
        Intrinsics.checkParameterIsNotNull(paramsOptimizer, "optimizer");
        Object inputErrors = NeuralProcessor.DefaultImpls.propagateErrors(this, list, paramsOptimizer, z);
        return (NeuralProcessor.NoInputErrors) inputErrors;
    }

    /**
     * Bridge for {@code NeuralProcessor.propagateErrors}: narrows the wildcard list type and
     * delegates to {@link #propagateErrors2(List, ParamsOptimizer, boolean)}.
     */
    @Override // com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor
    public /* bridge */ /* synthetic */ NeuralProcessor.NoInputErrors propagateErrors(List<? extends DenseNDArray> list, ParamsOptimizer paramsOptimizer, boolean z) {
        List<DenseNDArray> errors = (List<DenseNDArray>) list;
        return propagateErrors2(errors, paramsOptimizer, z);
    }
}
