package com.kotlinnlp.languagemodel.training;

import com.kotlinnlp.languagemodel.CharLM;
import com.kotlinnlp.simplednn.core.functionalities.losses.SoftmaxCrossEntropyCalculator;
import com.kotlinnlp.simplednn.core.functionalities.updatemethods.UpdateMethod;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArrayFactory;
import com.kotlinnlp.simplednn.utils.scheduling.BatchScheduling;
import com.kotlinnlp.simplednn.utils.scheduling.EpochScheduling;
import com.kotlinnlp.simplednn.utils.scheduling.ExampleScheduling;
import com.kotlinnlp.utils.Timer;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import kotlin.Metadata;
import kotlin.Unit;
import kotlin.collections.CollectionsKt;
import kotlin.collections.IntIterator;
import kotlin.jvm.functions.Function2;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import kotlin.ranges.IntRange;
import kotlin.ranges.RangesKt;
import kotlin.text.StringsKt;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: Trainer.kt */
@Metadata(mv = {1, 1, 10}, bv = {1, 0, 2}, k = 1, d1 = {"��Z\n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000e\n\u0002\b\u0002\n\u0002\u0010\b\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000b\n\u0002\b\u0003\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0006\n��\n\u0002\u0010 \n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0010\u0002\n\u0002\b\n\u0018��2\u00020\u0001BG\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\u0006\u0010\u0004\u001a\u00020\u0005\u0012\u0006\u0010\u0006\u001a\u00020\u0005\u0012\n\b\u0002\u0010\u0007\u001a\u0004\u0018\u00010\b\u0012\u0006\u0010\t\u001a\u00020\b\u0012\n\u0010\n\u001a\u0006\u0012\u0002\b\u00030\u000b\u0012\b\b\u0002\u0010\f\u001a\u00020\r¢\u0006\u0002\u0010\u000eJ\u0016\u0010\u0016\u001a\u00020\u00172\f\u0010\u0018\u001a\b\u0012\u0004\u0012\u00020\u001a0\u0019H\u0002J\u0016\u0010\u001b\u001a\b\u0012\u0004\u0012\u00020\b0\u00192\u0006\u0010\u001c\u001a\u00020\u0005H\u0002J\b\u0010\u001d\u001a\u00020\u001eH\u0002J\u0010\u0010\u001f\u001a\u00020\u001e2\u0006\u0010 
\u001a\u00020\bH\u0002J\b\u0010!\u001a\u00020\u001eH\u0002J\b\u0010\"\u001a\u00020\u001eH\u0002J\b\u0010#\u001a\u00020\u001eH\u0002J\u0006\u0010$\u001a\u00020\u001eJ\b\u0010%\u001a\u00020\u001eH\u0002J\u0010\u0010&\u001a\u00020\u001e2\u0006\u0010'\u001a\u00020\u0005H\u0002R\u000e\u0010\u0006\u001a\u00020\u0005X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\t\u001a\u00020\bX\u0082\u0004¢\u0006\u0002\n��R\u0012\u0010\u0007\u001a\u0004\u0018\u00010\bX\u0082\u0004¢\u0006\u0004\n\u0002\u0010\u000fR\u000e\u0010\u0002\u001a\u00020\u0003X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0004\u001a\u00020\u0005X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0010\u001a\u00020\u0011X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0012\u001a\u00020\u0013X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0014\u001a\u00020\u0015X\u0082\u000e¢\u0006\u0002\n��R\u0012\u0010\n\u001a\u0006\u0012\u0002\b\u00030\u000bX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\f\u001a\u00020\rX\u0082\u0004¢\u0006\u0002\n��¨\u0006("}, d2 = {"Lcom/kotlinnlp/languagemodel/training/Trainer;", "", "model", "Lcom/kotlinnlp/languagemodel/CharLM;", "modelFilename", "", "corpusFilePath", "maxSentences", "", "epochs", "updateMethod", "Lcom/kotlinnlp/simplednn/core/functionalities/updatemethods/UpdateMethod;", "verbose", "", "(Lcom/kotlinnlp/languagemodel/CharLM;Ljava/lang/String;Ljava/lang/String;Ljava/lang/Integer;ILcom/kotlinnlp/simplednn/core/functionalities/updatemethods/UpdateMethod;Z)V", "Ljava/lang/Integer;", "optimizer", "Lcom/kotlinnlp/languagemodel/training/Optimizer;", "processor", "Lcom/kotlinnlp/languagemodel/training/Processor;", "timer", "Lcom/kotlinnlp/utils/Timer;", "calculatePerplexity", "", "prediction", "", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "getExpectedCharsSequence", "s", "logTrainingEnd", "", "logTrainingStart", "epochIndex", "newBatch", "newEpoch", "newExample", "train", "trainEpoch", "trainSentence", "sentence", "languagemodel"})
/* loaded from: input_file:com/kotlinnlp/languagemodel/training/Trainer.class */
public final class Trainer {
    private Timer timer;
    private final Processor processor;
    private final Optimizer optimizer;
    private final CharLM model;
    private final String modelFilename;
    private final String corpusFilePath;
    private final Integer maxSentences;
    private final int epochs;
    private final UpdateMethod<?> updateMethod;
    private final boolean verbose;

    public final void train() {
        Iterator<Integer> it = RangesKt.until(0, this.epochs).iterator();
        while (it.hasNext()) {
            logTrainingStart(((IntIterator) it).nextInt());
            newEpoch();
            trainEpoch();
            logTrainingEnd();
        }
    }

    private final void trainEpoch() {
        ExtensionsKt.forEachIndexedSentence(new File(this.corpusFilePath), this.maxSentences, new Function2<Integer, String, Unit>() { // from class: com.kotlinnlp.languagemodel.training.Trainer$trainEpoch$1
            @Override // kotlin.jvm.functions.Function2
            public /* bridge */ /* synthetic */ Unit invoke(Integer num, String str) {
                invoke(num.intValue(), str);
                return Unit.INSTANCE;
            }

            public final void invoke(int i, @NotNull String sentence) {
                CharLM charLM;
                Optimizer optimizer;
                CharLM charLM2;
                String str;
                String str2;
                Intrinsics.checkParameterIsNotNull(sentence, "sentence");
                Trainer.this.newBatch();
                Trainer.this.newExample();
                charLM = Trainer.this.model;
                if (charLM.getReverseModel()) {
                    Trainer.this.trainSentence(StringsKt.reversed((CharSequence) sentence).toString());
                } else {
                    Trainer.this.trainSentence(sentence);
                }
                optimizer = Trainer.this.optimizer;
                optimizer.update();
                if (i <= 0 || i % 100 != 0) {
                    return;
                }
                charLM2 = Trainer.this.model;
                str = Trainer.this.modelFilename;
                charLM2.dump(new FileOutputStream(new File(str)));
                StringBuilder append = new StringBuilder().append("\n[").append(i).append("] Model saved to \"");
                str2 = Trainer.this.modelFilename;
                System.out.println((Object) append.append(str2).append('\"').toString());
            }

            /* JADX INFO: Access modifiers changed from: package-private */
            {
                super(2);
            }
        });
    }

    /* JADX INFO: Access modifiers changed from: private */
    public final void trainSentence(String str) {
        List<DenseNDArray> forward = this.processor.forward(str);
        List<Integer> expectedCharsSequence = getExpectedCharsSequence(str);
        ArrayList arrayList = new ArrayList(CollectionsKt.collectionSizeOrDefault(expectedCharsSequence, 10));
        Iterator<T> it = expectedCharsSequence.iterator();
        while (it.hasNext()) {
            arrayList.add(DenseNDArrayFactory.INSTANCE.oneHotEncoder(this.model.getClassifier().getOutputSize(), ((Number) it.next()).intValue()));
        }
        ArrayList arrayList2 = arrayList;
        List<DenseNDArray> calculateErrors = new SoftmaxCrossEntropyCalculator().calculateErrors(forward, arrayList2);
        if (this.verbose) {
            System.out.println((Object) ("Loss: " + new SoftmaxCrossEntropyCalculator().calculateMeanLoss(forward, arrayList2) + " Perplexity: " + calculatePerplexity(forward)));
        }
        this.processor.backward2(calculateErrors);
        com.kotlinnlp.simplednn.core.optimizer.Optimizer.accumulate$default(this.optimizer, this.processor.getParamsErrors2(false), false, 2, null);
    }

    private final double calculatePerplexity(List<DenseNDArray> list) {
        double d = 0.0d;
        Iterator<T> it = list.iterator();
        while (it.hasNext()) {
            d += Math.log(((DenseNDArray) it.next()).max());
        }
        return Math.exp(-(d / this.model.getClassifierOutputSize()));
    }

    private final List<Integer> getExpectedCharsSequence(String str) {
        IntRange until = RangesKt.until(0, str.length());
        ArrayList arrayList = new ArrayList(CollectionsKt.collectionSizeOrDefault(until, 10));
        Iterator<Integer> it = until.iterator();
        while (it.hasNext()) {
            int nextInt = ((IntIterator) it).nextInt();
            arrayList.add(Integer.valueOf(nextInt < StringsKt.getLastIndex(str) ? this.model.getCharId(str.charAt(nextInt + 1)) : this.model.getEosId()));
        }
        return arrayList;
    }

    /* JADX INFO: Access modifiers changed from: private */
    public final void newExample() {
        if (this.updateMethod instanceof ExampleScheduling) {
            ((ExampleScheduling) this.updateMethod).newExample();
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    public final void newBatch() {
        if (this.updateMethod instanceof BatchScheduling) {
            ((BatchScheduling) this.updateMethod).newBatch();
        }
    }

    private final void newEpoch() {
        if (this.updateMethod instanceof EpochScheduling) {
            ((EpochScheduling) this.updateMethod).newEpoch();
        }
    }

    private final void logTrainingStart(int i) {
        if (this.verbose) {
            this.timer.reset();
            System.out.println((Object) ("\nEpoch " + (i + 1) + " of " + this.epochs));
            System.out.println((Object) "\nStart training...");
        }
    }

    private final void logTrainingEnd() {
        if (this.verbose) {
            Object[] objArr = {this.timer.formatElapsedTime()};
            String format = String.format("Elapsed time: %s", Arrays.copyOf(objArr, objArr.length));
            Intrinsics.checkExpressionValueIsNotNull(format, "java.lang.String.format(this, *args)");
            System.out.println((Object) format);
        }
    }

    public Trainer(@NotNull CharLM model, @NotNull String modelFilename, @NotNull String corpusFilePath, @Nullable Integer num, int i, @NotNull UpdateMethod<?> updateMethod, boolean z) {
        Intrinsics.checkParameterIsNotNull(model, "model");
        Intrinsics.checkParameterIsNotNull(modelFilename, "modelFilename");
        Intrinsics.checkParameterIsNotNull(corpusFilePath, "corpusFilePath");
        Intrinsics.checkParameterIsNotNull(updateMethod, "updateMethod");
        this.model = model;
        this.modelFilename = modelFilename;
        this.corpusFilePath = corpusFilePath;
        this.maxSentences = num;
        this.epochs = i;
        this.updateMethod = updateMethod;
        this.verbose = z;
        this.timer = new Timer();
        this.processor = new Processor(this.model, true);
        this.optimizer = new Optimizer(this.model, this.updateMethod);
        if (!(this.epochs > 0)) {
            throw new IllegalArgumentException("The number of epochs must be > 0".toString());
        }
    }

    public /* synthetic */ Trainer(CharLM charLM, String str, String str2, Integer num, int i, UpdateMethod updateMethod, boolean z, int i2, DefaultConstructorMarker defaultConstructorMarker) {
        this(charLM, str, str2, (i2 & 8) != 0 ? (Integer) null : num, i, updateMethod, (i2 & 64) != 0 ? true : z);
    }
}
