package com.kotlinnlp.simplednn.deeplearning.transformers;

import com.kotlinnlp.simplednn.core.arrays.ParamsArray;
import com.kotlinnlp.simplednn.core.functionalities.losses.SoftmaxCrossEntropyCalculator;
import com.kotlinnlp.simplednn.core.functionalities.updatemethods.UpdateMethod;
import com.kotlinnlp.simplednn.core.layers.helpers.ParamsErrorsCollector;
import com.kotlinnlp.simplednn.core.neuralprocessor.feedforward.FeedforwardNeuralProcessor;
import com.kotlinnlp.simplednn.core.optimizer.ParamsErrorsAccumulator;
import com.kotlinnlp.simplednn.core.optimizer.ParamsOptimizer;
import com.kotlinnlp.simplednn.deeplearning.transformers.BERTModel;
import com.kotlinnlp.simplednn.helpers.Statistics;
import com.kotlinnlp.simplednn.helpers.Trainer;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.Shape;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArrayFactory;
import com.kotlinnlp.utils.Shuffler;
import com.kotlinnlp.utils.Timer;
import com.kotlinnlp.utils.WordPieceTokenizer;
import com.kotlinnlp.utils.stats.MetricCounter;
import java.io.File;
import java.io.FileOutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.Set;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: BERTTrainer.kt */
@Metadata(mv = {1, 1, 15}, bv = {1, 0, 3}, k = 1, d1 = {"��p\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0010\u000e\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u0006\n��\n\u0002\u0010\u000b\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u001c\n��\n\u0002\u0010\b\n��\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010!\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u0002\n\u0002\b\r\u0018�� /2\b\u0012\u0004\u0012\u00020\u00020\u0001:\u0002/0B_\u0012\u0006\u0010\u0003\u001a\u00020\u0004\u0012\u0006\u0010\u0005\u001a\u00020\u0002\u0012\b\b\u0002\u0010\u0006\u001a\u00020\u0007\u0012\u0006\u0010\b\u001a\u00020\t\u0012\n\u0010\n\u001a\u0006\u0012\u0002\b\u00030\u000b\u0012\f\u0010\f\u001a\b\u0012\u0004\u0012\u00020\u00020\r\u0012\u0006\u0010\u000e\u001a\u00020\u000f\u0012\n\b\u0002\u0010\u0010\u001a\u0004\u0018\u00010\u0011\u0012\b\b\u0002\u0010\u0012\u001a\u00020\t¢\u0006\u0002\u0010\u0013J\b\u0010#\u001a\u00020$H\u0014J\u0018\u0010%\u001a\u00020\u00182\u0006\u0010&\u001a\u00020\u00182\u0006\u0010'\u001a\u00020\u000fH\u0002J\b\u0010(\u001a\u00020$H\u0014J\u0010\u0010)\u001a\u00020$2\u0006\u0010*\u001a\u00020\u0002H\u0014J\b\u0010+\u001a\u00020$H\u0002J\u0018\u0010,\u001a\u00020$2\u0006\u0010-\u001a\u00020\u00182\u0006\u0010.\u001a\u00020\u0018H\u0002R\u000e\u0010\u0014\u001a\u00020\u0015X\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u0016\u001a\b\u0012\u0004\u0012\u00020\u00180\u0017X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0019\u001a\u00020\u000fX\u0082\u000e¢\u0006\u0002\n��R\u0014\u0010\u001a\u001a\b\u0012\u0004\u0012\u00020\u00070\u001bX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0003\u001a\u00020\u0004X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\b\u001a\u00020\tX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u001c\u001a\u00020\u001dX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0006\u001a\u00020\u0007X\u0082\u0004¢\u0006\u0002\n��R\
u000e\u0010\u001e\u001a\u00020\u001fX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010 \u001a\u00020!X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\"\u001a\u00020\u0018X\u0082\u0004¢\u0006\u0002\n��¨\u00061"}, d2 = {"Lcom/kotlinnlp/simplednn/deeplearning/transformers/BERTTrainer;", "Lcom/kotlinnlp/simplednn/helpers/Trainer;", "", "model", "Lcom/kotlinnlp/simplednn/deeplearning/transformers/BERTModel;", "modelFilename", "termsDropout", "", "optimizeEmbeddings", "", "updateMethod", "Lcom/kotlinnlp/simplednn/core/functionalities/updatemethods/UpdateMethod;", "examples", "", "epochs", "", "shuffler", "Lcom/kotlinnlp/utils/Shuffler;", "verbose", "(Lcom/kotlinnlp/simplednn/deeplearning/transformers/BERTModel;Ljava/lang/String;DZLcom/kotlinnlp/simplednn/core/functionalities/updatemethods/UpdateMethod;Ljava/lang/Iterable;ILcom/kotlinnlp/utils/Shuffler;Z)V", "bert", "Lcom/kotlinnlp/simplednn/deeplearning/transformers/BERT;", "classifier", "Lcom/kotlinnlp/simplednn/core/neuralprocessor/feedforward/FeedforwardNeuralProcessor;", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "examplesCount", "lastLosses", "", "stats", "Lcom/kotlinnlp/simplednn/helpers/Statistics$Simple;", "timer", "Lcom/kotlinnlp/utils/Timer;", "tokenizer", "Lcom/kotlinnlp/utils/WordPieceTokenizer;", "zeroErrors", "accumulateErrors", "", "classifyVector", "vector", "goldIndex", "dumpModel", "learnFromExample", "example", "printProgressAndStats", "updateStats", "classification", "goldOutput", "Companion", "EncodedTerm", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/deeplearning/transformers/BERTTrainer.class */
public final class BERTTrainer extends Trainer<String> {
    /** Tokenizer used to split each example; built in the constructor from the model vocabulary. */
    private final WordPieceTokenizer tokenizer;
    /** BERT processor created from {@link #model}; its forward and backward are called once per example. */
    private final BERT bert;
    /** Feed-forward processor over the model's classifier, applied to the encodings of masked terms. */
    private final FeedforwardNeuralProcessor<DenseNDArray> classifier;
    /** Zeros array built from the model input size, used as the output errors of non-masked terms. */
    private final DenseNDArray zeroErrors;
    /** Counters of matching / non-matching predictions, reset after every progress report. */
    private final Statistics.Simple stats;
    /** Losses of the classified terms, cleared after every progress report. */
    private final List<Double> lastLosses;
    /** Number of examples counted by the progress printer (only incremented when verbose). */
    private int examplesCount;
    /** Timer whose elapsed time is printed in the progress report. */
    private final Timer timer;
    /** The BERT model being trained and dumped. */
    private final BERTModel model;
    /** Threshold compared against the masking draw of each in-vocabulary term; validated to be in (0.0, 1.0). */
    private final double termsDropout;
    /** Flag forwarded to the BERT processor constructor. */
    private final boolean optimizeEmbeddings;
    /** Forms of all the {@code BERTModel.FuncToken} values, passed to the tokenizer on every call. */
    private static final Set<String> SPECIAL_TOKENS;
    /** Instance of the Kotlin companion object. */
    public static final Companion Companion = new Companion(null);
    /** Fixed-seed generator drawn once per in-vocabulary term to decide whether it is masked. */
    private static final Random maskRandom = new Random(739);
    /** Fixed-seed generator drawn once per masked term to choose its replacement kind. */
    private static final Random replaceRandom = new Random(743);
    /** Fixed-seed generator used to pick a random vocabulary element as replacement form. */
    private static final Random formRandom = new Random(751);

    /* compiled from: BERTTrainer.kt */
    @Metadata(mv = {1, 1, 15}, bv = {1, 0, 3}, k = 1, d1 = {"��\u001e\n\u0002\u0018\u0002\n\u0002\u0010��\n\u0002\b\u0002\n\u0002\u0010\"\n\u0002\u0010\u000e\n��\n\u0002\u0018\u0002\n\u0002\b\u0003\b\u0086\u0003\u0018��2\u00020\u0001B\u0007\b\u0002¢\u0006\u0002\u0010\u0002R\u0014\u0010\u0003\u001a\b\u0012\u0004\u0012\u00020\u00050\u0004X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0006\u001a\u00020\u0007X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\b\u001a\u00020\u0007X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\t\u001a\u00020\u0007X\u0082\u0004¢\u0006\u0002\n��¨\u0006\n"}, d2 = {"Lcom/kotlinnlp/simplednn/deeplearning/transformers/BERTTrainer$Companion;", "", "()V", "SPECIAL_TOKENS", "", "", "formRandom", "Ljava/util/Random;", "maskRandom", "replaceRandom", "simplednn"})
    /* loaded from: input_file:com/kotlinnlp/simplednn/deeplearning/transformers/BERTTrainer$Companion.class */
    /**
     * Kotlin companion object of {@link BERTTrainer}.
     *
     * <p>In this decompiled form the class declares no members besides its constructors.
     */
    public static final class Companion {
        /** Private no-op constructor; reached only through the synthetic constructor below. */
        private Companion() {
        }

        /** Synthetic constructor; the marker argument is not used and only disambiguates the signature. */
        public /* synthetic */ Companion(DefaultConstructorMarker defaultConstructorMarker) {
            this();
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* compiled from: BERTTrainer.kt */
    @Metadata(mv = {1, 1, 15}, bv = {1, 0, 3}, k = 1, d1 = {"��*\n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0010\u000e\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0007\n\u0002\u0010\b\n\u0002\b\u0003\n\u0002\u0010\u000b\n\u0002\b\u0004\b\u0082\u0004\u0018��2\u00020\u0001B\r\u0012\u0006\u0010\u0002\u001a\u00020\u0003¢\u0006\u0002\u0010\u0004J\u0006\u0010\u0014\u001a\u00020\u0003J\b\u0010\u0015\u001a\u00020\u0003H\u0002R\u001a\u0010\u0005\u001a\u00020\u0006X\u0086.¢\u0006\u000e\n��\u001a\u0004\b\u0007\u0010\b\"\u0004\b\t\u0010\nR\u0011\u0010\u0002\u001a\u00020\u0003¢\u0006\b\n��\u001a\u0004\b\u000b\u0010\fR\u0011\u0010\r\u001a\u00020\u000e¢\u0006\b\n��\u001a\u0004\b\u000f\u0010\u0010R\u0011\u0010\u0011\u001a\u00020\u0012¢\u0006\b\n��\u001a\u0004\b\u0011\u0010\u0013¨\u0006\u0016"}, d2 = {"Lcom/kotlinnlp/simplednn/deeplearning/transformers/BERTTrainer$EncodedTerm;", "", "form", "", "(Lcom/kotlinnlp/simplednn/deeplearning/transformers/BERTTrainer;Ljava/lang/String;)V", "encoding", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "getEncoding", "()Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "setEncoding", "(Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;)V", "getForm", "()Ljava/lang/String;", "index", "", "getIndex", "()I", "isMasked", "", "()Z", "getMaskedForm", "getRandomForm", "simplednn"})
    /* loaded from: input_file:com/kotlinnlp/simplednn/deeplearning/transformers/BERTTrainer$EncodedTerm.class */
    /**
     * A token of the example being processed, with its vocabulary index, its masking decision
     * (taken once, at construction) and the encoding assigned to it after the forward pass.
     */
    public final class EncodedTerm {

        /** Encoding of this term; remains unset until {@link #setEncoding} is called. */
        @NotNull
        public DenseNDArray encoding;
        private final int index;
        private final boolean isMasked;

        @NotNull
        private final String form;
        final /* synthetic */ BERTTrainer this$0;

        /**
         * Creates the term, looking up its vocabulary id and deciding whether it is masked.
         *
         * @param bERTTrainer the enclosing trainer, whose model and terms dropout are read
         * @param str the form of the term
         */
        public EncodedTerm(@NotNull BERTTrainer bERTTrainer, String str) {
            Intrinsics.checkParameterIsNotNull(str, "form");
            this.this$0 = bERTTrainer;
            this.form = str;

            // Forms without a vocabulary id get the index -1.
            Integer vocabularyId = bERTTrainer.model.getVocabulary().getId(this.form);
            if (vocabularyId == null) {
                this.index = -1;
            } else {
                this.index = vocabularyId.intValue();
            }

            // The masking draw is consumed only for forms contained in the vocabulary.
            boolean inVocabulary = bERTTrainer.model.getVocabulary().contains(this.form);
            this.isMasked = inVocabulary && BERTTrainer.maskRandom.nextDouble() < bERTTrainer.termsDropout;
        }

        /**
         * @return the encoding of this term
         */
        @NotNull
        public final DenseNDArray getEncoding() {
            DenseNDArray current = this.encoding;
            if (current == null) {
                // Mirrors the access check of a Kotlin lateinit property.
                Intrinsics.throwUninitializedPropertyAccessException("encoding");
            }
            return current;
        }

        /**
         * @param denseNDArray the encoding to assign to this term
         */
        public final void setEncoding(@NotNull DenseNDArray denseNDArray) {
            Intrinsics.checkParameterIsNotNull(denseNDArray, "<set-?>");
            this.encoding = denseNDArray;
        }

        /**
         * @return the vocabulary id of the form, or -1 if the vocabulary returned no id for it
         */
        public final int getIndex() {
            return this.index;
        }

        /**
         * @return whether this term was selected for masking at construction
         */
        public final boolean isMasked() {
            return this.isMasked;
        }

        /**
         * Returns the form to give in input to the encoder.
         *
         * <p>A non-masked term keeps its form. For a masked term one value is drawn: below 0.8 the
         * MASK token form is returned, below 0.9 a random vocabulary form, otherwise the original form.
         *
         * @return the form to use in place of this term
         */
        @NotNull
        public final String getMaskedForm() {
            if (this.isMasked) {
                double draw = BERTTrainer.replaceRandom.nextDouble();
                if (draw < 0.8d) {
                    return BERTModel.FuncToken.MASK.getForm();
                }
                if (draw < 0.9d) {
                    return getRandomForm();
                }
            }
            return this.form;
        }

        /**
         * Picks the vocabulary element whose id is the floor of a random draw scaled by the
         * vocabulary size.
         *
         * @return the form of the picked element
         */
        private final String getRandomForm() {
            BERTModel trainedModel = this.this$0.model;
            double draw = BERTTrainer.formRandom.nextDouble();
            int randomId = (int) Math.floor(draw * trainedModel.getVocabulary().getSize());
            Object picked = trainedModel.getVocabulary().getElement(randomId);
            if (picked == null) {
                Intrinsics.throwNpe();
            }
            return (String) picked;
        }

        /**
         * @return the original form of this term
         */
        @NotNull
        public final String getForm() {
            return this.form;
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Learns from a single example.
     *
     * <p>The example is tokenized, every token is wrapped in an {@link EncodedTerm} (which decides
     * its masking), the possibly replaced forms are given to the BERT forward, and each masked
     * term's encoding is classified. The classifier input errors (zeros for non-masked terms) are
     * then given to the BERT backward.
     *
     * @param str the example text
     */
    @Override // com.kotlinnlp.simplednn.helpers.Trainer
    public void learnFromExample(@NotNull String str) {
        Intrinsics.checkParameterIsNotNull(str, "example");

        // Build all the terms first: each construction may consume a masking draw.
        List tokens = WordPieceTokenizer.tokenize$default(this.tokenizer, str, 0, SPECIAL_TOKENS, 2, (Object) null);
        ArrayList<EncodedTerm> terms = new ArrayList<>(CollectionsKt.collectionSizeOrDefault(tokens, 10));
        for (Object token : tokens) {
            terms.add(new EncodedTerm(this, (String) token));
        }

        // Only afterwards ask each term for the form to feed to the encoder.
        ArrayList<String> maskedForms = new ArrayList<>(CollectionsKt.collectionSizeOrDefault(terms, 10));
        for (EncodedTerm term : terms) {
            maskedForms.add(term.getMaskedForm());
        }

        // Pair each returned encoding with its term, in order.
        for (Pair encoded : CollectionsKt.zip(this.bert.forward2((List<String>) maskedForms), terms)) {
            ((EncodedTerm) encoded.component2()).setEncoding((DenseNDArray) encoded.component1());
        }

        // Masked terms contribute the classifier input errors, the others contribute zeros.
        ArrayList<DenseNDArray> outputErrors = new ArrayList<>(CollectionsKt.collectionSizeOrDefault(terms, 10));
        for (EncodedTerm term : terms) {
            if (term.isMasked()) {
                outputErrors.add(classifyVector(term.getEncoding(), term.getIndex()));
            } else {
                outputErrors.add(this.zeroErrors);
            }
        }
        this.bert.backward2((List<DenseNDArray>) outputErrors);

        if (getVerbose()) {
            printProgressAndStats();
        }
    }

    /**
     * Accumulates the params errors of the BERT processor into the single optimizer of this trainer.
     */
    @Override // com.kotlinnlp.simplednn.helpers.Trainer
    protected void accumulateErrors() {
        ParamsOptimizer optimizer = (ParamsOptimizer) CollectionsKt.single(getOptimizers());
        List<? extends ParamsArray.Errors<?>> bertErrors =
            (List<? extends ParamsArray.Errors<?>>) this.bert.getParamsErrors(false);
        optimizer.accumulate(bertErrors, false);
    }

    /**
     * Dumps the model to the file named by {@code getModelFilename()}.
     *
     * <p>The stream is opened in a try-with-resources block so that the file handle is released
     * even when {@code dump} throws. Closing an already closed stream has no effect, so this is
     * safe whether or not {@code dump} closes the stream itself. I/O failures propagate to the
     * caller as before.
     */
    @Override // com.kotlinnlp.simplednn.helpers.Trainer
    protected void dumpModel() {
        try (FileOutputStream output = new FileOutputStream(new File(getModelFilename()))) {
            this.model.dump(output);
        }
    }

    /**
     * Classifies the encoding of a masked term and returns the errors to propagate to the encoder.
     *
     * <p>The steps run in a fixed order: classifier forward, statistics update against the one-hot
     * gold output, classifier backward with the softmax cross-entropy errors, accumulation of the
     * classifier params errors into the single optimizer, and finally retrieval of the input errors.
     *
     * @param denseNDArray the encoded vector of the term
     * @param i the index of the gold term, used to build the one-hot gold output
     * @return the input errors of the classifier
     */
    private final DenseNDArray classifyVector(DenseNDArray denseNDArray, int i) {
        // The vector is passed as-is: the decompiler had wrapped it in a cast to the processor
        // type, which an array instance can never satisfy.
        this.classifier.forward(denseNDArray);
        DenseNDArray prediction = this.classifier.getOutput(false);
        DenseNDArray goldOutput = DenseNDArrayFactory.INSTANCE.oneHotEncoder(this.classifier.getModel().getOutputSize(), i);
        updateStats(prediction, goldOutput);
        this.classifier.backward(SoftmaxCrossEntropyCalculator.INSTANCE.calculateErrors(prediction, goldOutput));
        ParamsErrorsAccumulator.accumulate$default((ParamsErrorsAccumulator) CollectionsKt.single(getOptimizers()), (List) this.classifier.getParamsErrors(false), false, 2, (Object) null);
        return this.classifier.getInputErrors2(true);
    }

    /**
     * Records the loss of a classification and counts it as matching or non-matching.
     *
     * @param denseNDArray the classifier output
     * @param denseNDArray2 the expected (gold) output
     */
    private final void updateStats(DenseNDArray denseNDArray, DenseNDArray denseNDArray2) {
        int predictedIndex = NDArray.DefaultImpls.argMaxIndex$default(denseNDArray, 0, 1, null);
        int goldIndex = NDArray.DefaultImpls.argMaxIndex$default(denseNDArray2, 0, 1, null);

        double loss = SoftmaxCrossEntropyCalculator.INSTANCE.calculateLoss(denseNDArray, denseNDArray2).sum();
        this.lastLosses.add(Double.valueOf(loss));

        MetricCounter counter = this.stats.getMetric();
        if (predictedIndex != goldIndex) {
            // A differing arg-max is counted as a false positive.
            counter.setFalsePos(counter.getFalsePos() + 1);
            return;
        }
        counter.setTruePos(counter.getTruePos() + 1);
    }

    /**
     * Counts one more example and prints the progress.
     *
     * <p>A dot is printed every 100 examples. Every 1000 examples a report with the elapsed time,
     * the average of the collected losses and the metric is printed, the model is validated and
     * saved, and the collected losses and statistics are reset.
     */
    private final void printProgressAndStats() {
        this.examplesCount++;

        if (this.examplesCount % 100 == 0) {
            System.out.print(".");
        }

        if (this.examplesCount % 1000 != 0) {
            return;
        }

        double meanLoss = CollectionsKt.averageOfDouble(this.lastLosses);
        String lossText = String.format("loss %.2f", Double.valueOf(meanLoss));

        StringBuilder report = new StringBuilder();
        report.append("\n[").append(this.timer.formatElapsedTime());
        report.append("] After ").append(this.examplesCount);
        report.append(" examples: ").append(lossText);
        report.append(" | ").append(this.stats.getMetric());
        System.out.println(report.toString());

        validateAndSaveModel();

        this.lastLosses.clear();
        this.stats.reset();
    }

    /* JADX WARN: 'super' call moved to the top of the method (can break code semantics) */
    /**
     * Builds a trainer for the given BERT model.
     *
     * <p>The parent trainer receives a single {@code ParamsOptimizer} wrapping the update method.
     * The tokenizer, the BERT processor, the classifier processor and the zeros array are all
     * derived from the model.
     *
     * @param bERTModel the model to train (Kotlin name: {@code model})
     * @param str the name of the file where the model is dumped (Kotlin name: {@code modelFilename})
     * @param d the terms dropout compared against the masking draw of each term
     *     (Kotlin name: {@code termsDropout}); must be strictly between 0.0 and 1.0
     * @param z flag forwarded to the BERT processor (Kotlin name: {@code optimizeEmbeddings})
     * @param updateMethod the update method wrapped by the optimizer
     * @param iterable the training examples (Kotlin name: {@code examples})
     * @param i the number of epochs (Kotlin name: {@code epochs})
     * @param shuffler the shuffler forwarded to the parent trainer, may be null
     * @param z2 the verbose flag forwarded to the parent trainer (Kotlin name: {@code verbose})
     * @throws IllegalArgumentException if the terms dropout is not in the open range (0.0, 1.0);
     *     as decompiled, the check runs after the parent constructor and all field initializations
     */
    public BERTTrainer(@NotNull BERTModel bERTModel, @NotNull String str, double d, boolean z, @NotNull UpdateMethod<?> updateMethod, @NotNull Iterable<String> iterable, int i, @Nullable Shuffler shuffler, boolean z2) {
        // NOTE(review): the literal 1 and null passed to the parent are presumably the batch size
        // and an absent evaluator — confirm against the Trainer constructor.
        super(str, CollectionsKt.listOf(new ParamsOptimizer(updateMethod, null, 2, null)), iterable, i, 1, null, shuffler, z2);
        Intrinsics.checkParameterIsNotNull(bERTModel, "model");
        Intrinsics.checkParameterIsNotNull(str, "modelFilename");
        Intrinsics.checkParameterIsNotNull(updateMethod, "updateMethod");
        Intrinsics.checkParameterIsNotNull(iterable, "examples");
        this.model = bERTModel;
        this.termsDropout = d;
        this.optimizeEmbeddings = z;
        // The trailing int + marker pairs below are Kotlin default-argument masks: the flagged
        // parameters take their declared defaults.
        this.tokenizer = new WordPieceTokenizer(this.model.getVocabulary(), (String) null, (String) null, 6, (DefaultConstructorMarker) null);
        this.bert = new BERT(this.model, false, true, true, this.optimizeEmbeddings, 0, 34, null);
        this.classifier = new FeedforwardNeuralProcessor<>(this.model.getClassifier(), 0.0d, true, (ParamsErrorsCollector) null, 0, 26, (DefaultConstructorMarker) null);
        this.zeroErrors = DenseNDArrayFactory.INSTANCE.zeros(new Shape(this.model.getInputSize(), 0, 2, null));
        this.stats = new Statistics.Simple();
        this.lastLosses = new ArrayList();
        this.timer = new Timer();
        // The negated form also rejects NaN, which fails both comparisons.
        double d2 = this.termsDropout;
        if (!(d2 > 0.0d && d2 < 1.0d)) {
            throw new IllegalArgumentException("The terms dropout must be in the range (0.0, 1.0)".toString());
        }
    }

    /**
     * Synthetic constructor generated for the Kotlin default arguments.
     *
     * <p>Each bit set in {@code i2} replaces the corresponding argument with its default:
     * bit 4 selects a terms dropout of 0.15, bit 128 a default-constructed shuffler and
     * bit 256 a verbose flag of {@code true}. The marker argument is not used.
     */
    public /* synthetic */ BERTTrainer(BERTModel bERTModel, String str, double d, boolean z, UpdateMethod updateMethod, Iterable iterable, int i, Shuffler shuffler, boolean z2, int i2, DefaultConstructorMarker defaultConstructorMarker) {
        this(
            bERTModel,
            str,
            (i2 & 4) != 0 ? 0.15d : d,
            z,
            updateMethod,
            iterable,
            i,
            (i2 & 128) != 0 ? new Shuffler(false, 0L, 3, (DefaultConstructorMarker) null) : shuffler,
            (i2 & 256) != 0 || z2);
    }

    static {
        // Collect the form of every functional token of the model, in declaration order.
        BERTModel.FuncToken[] funcTokens = BERTModel.FuncToken.values();
        ArrayList<String> tokenForms = new ArrayList<>(funcTokens.length);
        for (int k = 0; k < funcTokens.length; k++) {
            tokenForms.add(funcTokens[k].getForm());
        }
        SPECIAL_TOKENS = CollectionsKt.toSet(tokenForms);
    }
}
