package com.kotlinnlp.tokenslabeler.helpers;

import com.kotlinnlp.simplednn.core.functionalities.gradientclipping.GradientClipping;
import com.kotlinnlp.simplednn.core.functionalities.losses.SoftmaxCrossEntropyCalculator;
import com.kotlinnlp.simplednn.core.functionalities.regularization.WeightsRegularization;
import com.kotlinnlp.simplednn.core.functionalities.updatemethods.UpdateMethod;
import com.kotlinnlp.simplednn.core.functionalities.updatemethods.adam.ADAMMethod;
import com.kotlinnlp.simplednn.core.optimizer.ParamsOptimizer;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import com.kotlinnlp.simplednn.utils.scheduling.BatchScheduling;
import com.kotlinnlp.simplednn.utils.scheduling.EpochScheduling;
import com.kotlinnlp.tokenslabeler.TokensLabeler;
import com.kotlinnlp.tokenslabeler.TokensLabelerModel;
import com.kotlinnlp.tokenslabeler.language.AnnotatedSentence;
import com.kotlinnlp.tokenslabeler.language.AnnotatedToken;
import com.kotlinnlp.tokenslabeler.language.BaseSentence;
import com.kotlinnlp.tokenslabeler.language.Label;
import com.kotlinnlp.utils.ExamplesIndices;
import com.kotlinnlp.utils.Shuffler;
import com.kotlinnlp.utils.Timer;
import com.kotlinnlp.utils.progressindicator.ProgressIndicator;
import com.kotlinnlp.utils.progressindicator.ProgressIndicatorBar;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin.collections.CollectionsKt;
import kotlin.collections.IntIterator;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import kotlin.ranges.RangesKt;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: Trainer.kt */
@Metadata(mv = {1, 1, 15}, bv = {1, 0, 3}, k = 1, d1 = {"��d\n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000e\n��\n\u0002\u0010\b\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000b\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0006\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0002\n\u0002\b\b\n\u0002\u0010 \n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0005\u0018��2\u00020\u0001B=\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\u0006\u0010\u0004\u001a\u00020\u0005\u0012\u0006\u0010\u0006\u001a\u00020\u0007\u0012\f\b\u0002\u0010\b\u001a\u0006\u0012\u0002\b\u00030\t\u0012\u0006\u0010\n\u001a\u00020\u000b\u0012\b\b\u0002\u0010\f\u001a\u00020\r¢\u0006\u0002\u0010\u000eJ\b\u0010\u0018\u001a\u00020\u0019H\u0002J\u0010\u0010\u001a\u001a\u00020\u00192\u0006\u0010\u001b\u001a\u00020\u0007H\u0002J\b\u0010\u001c\u001a\u00020\u0019H\u0002J\b\u0010\u001d\u001a\u00020\u0019H\u0002J\b\u0010\u001e\u001a\u00020\u0019H\u0002J\b\u0010\u001f\u001a\u00020\u0019H\u0002J \u0010 \u001a\u00020\u00192\f\u0010!\u001a\b\u0012\u0004\u0012\u00020#0\"2\n\b\u0002\u0010$\u001a\u0004\u0018\u00010%J 
\u0010&\u001a\u00020\u00192\f\u0010!\u001a\b\u0012\u0004\u0012\u00020#0\"2\b\u0010$\u001a\u0004\u0018\u00010%H\u0002J\u0010\u0010'\u001a\u00020\u00192\u0006\u0010(\u001a\u00020#H\u0002J\b\u0010)\u001a\u00020\u0019H\u0002R\u000e\u0010\u000f\u001a\u00020\u0010X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0011\u001a\u00020\u0012X\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010\u0013\u001a\u00020\u0007X\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010\u0006\u001a\u00020\u0007X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0002\u001a\u00020\u0003X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0004\u001a\u00020\u0005X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0014\u001a\u00020\u0015X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0016\u001a\u00020\u0017X\u0082\u000e¢\u0006\u0002\n��R\u0012\u0010\b\u001a\u0006\u0012\u0002\b\u00030\tX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\n\u001a\u00020\u000bX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\f\u001a\u00020\rX\u0082\u0004¢\u0006\u0002\n��¨\u0006*"}, d2 = {"Lcom/kotlinnlp/tokenslabeler/helpers/Trainer;", Label.EMPTY_VALUE, "model", "Lcom/kotlinnlp/tokenslabeler/TokensLabelerModel;", "modelFilename", Label.EMPTY_VALUE, "epochs", Label.EMPTY_VALUE, "updateMethod", "Lcom/kotlinnlp/simplednn/core/functionalities/updatemethods/UpdateMethod;", "validator", "Lcom/kotlinnlp/tokenslabeler/helpers/Validator;", "verbose", Label.EMPTY_VALUE, "(Lcom/kotlinnlp/tokenslabeler/TokensLabelerModel;Ljava/lang/String;ILcom/kotlinnlp/simplednn/core/functionalities/updatemethods/UpdateMethod;Lcom/kotlinnlp/tokenslabeler/helpers/Validator;Z)V", "annotator", "Lcom/kotlinnlp/tokenslabeler/TokensLabeler;", "bestAccuracy", Label.EMPTY_VALUE, "epochCount", "optimizer", "Lcom/kotlinnlp/simplednn/core/optimizer/ParamsOptimizer;", "timer", "Lcom/kotlinnlp/utils/Timer;", "logTrainingEnd", Label.EMPTY_VALUE, "logTrainingStart", "epochIndex", "logValidationEnd", "logValidationStart", "newBatch", "newEpoch", "train", "dataset", Label.EMPTY_VALUE, 
"Lcom/kotlinnlp/tokenslabeler/language/AnnotatedSentence;", "shuffler", "Lcom/kotlinnlp/utils/Shuffler;", "trainEpoch", "trainExample", "example", "validateAndSaveModel", "tokenslabeler"})
/* loaded from: input_file:com/kotlinnlp/tokenslabeler/helpers/Trainer.class */
public final class Trainer {
    // Number of epochs started so far; incremented by newEpoch().
    private int epochCount;
    // Stopwatch reset at the start of each training/validation phase (verbose mode only)
    // and read to print the elapsed time at the end of the phase.
    private Timer timer;
    // Best mean F1 score observed during validation; the model is saved only when a
    // validation run exceeds it. Never assigned in the constructor, so it starts at 0.0.
    private double bestAccuracy;
    // Labeler built on top of `model`; used for the forward and backward passes.
    private final TokensLabeler annotator;
    // Optimizer that accumulates the labeler's parameter errors and applies updates.
    private final ParamsOptimizer optimizer;
    // Model being trained; also the object dumped to `modelFilename`.
    private final TokensLabelerModel model;
    // Path of the file the model is written to when validation improves.
    private final String modelFilename;
    // Total number of training epochs; the constructor requires it to be > 0.
    private final int epochs;
    // Update method handed to the optimizer; notified of new batches/epochs when it
    // implements BatchScheduling / EpochScheduling.
    private final UpdateMethod<?> updateMethod;
    // Produces the per-label statistics used to score the model after each epoch.
    private final Validator validator;
    // Enables progress and timing output on standard output.
    private final boolean verbose;

    /**
     * Trains the model for the configured number of epochs.
     *
     * <p>Every epoch runs, in order: training-start logging, epoch bookkeeping, one pass over
     * the dataset, training-end logging, and a validation step that may save the model.
     *
     * @param list the training sentences; must not be null
     * @param shuffler passed through to the per-epoch example index iteration; may be null
     */
    public final void train(@NotNull List<AnnotatedSentence> list, @Nullable Shuffler shuffler) {
        Intrinsics.checkParameterIsNotNull(list, "dataset");
        for (int epochIndex = 0; epochIndex < this.epochs; epochIndex++) {
            logTrainingStart(epochIndex);
            newEpoch();
            trainEpoch(list, shuffler);
            logTrainingEnd();

            logValidationStart();
            validateAndSaveModel();
            logValidationEnd();
        }
    }

    /**
     * Compiler-generated bridge for the default value of the {@code shuffler} argument of
     * {@link #train(List, Shuffler)}.
     *
     * @param trainer the instance to train
     * @param list the training sentences
     * @param shuffler the caller-supplied shuffler, ignored when bit 2 of {@code i} is set
     * @param i defaults bit mask; when {@code (i & 2) != 0} a {@code new Shuffler(true, 743L)} is used
     * @param obj unused marker argument
     */
    public static /* synthetic */ void train$default(Trainer trainer, List list, Shuffler shuffler, int i, Object obj) {
        final boolean shufflerOmitted = (i & 2) != 0;
        final Shuffler effectiveShuffler = shufflerOmitted ? new Shuffler(true, 743L) : shuffler;
        trainer.train(list, effectiveShuffler);
    }

    /**
     * Runs one pass over the dataset, training on one example at a time.
     *
     * <p>For each index produced by {@code ExamplesIndices}, the update method is notified of
     * a new batch, the example is trained, and the optimizer applies an update.
     *
     * @param list the training sentences
     * @param shuffler handed to {@code ExamplesIndices} together with the dataset size
     */
    private final void trainEpoch(List<AnnotatedSentence> list, Shuffler shuffler) {
        final int exampleCount = list.size();
        // The bar is created regardless of verbosity; it is only ticked in verbose mode.
        final ProgressIndicatorBar progress = new ProgressIndicatorBar(exampleCount, (OutputStream) null, 0, 6, (DefaultConstructorMarker) null);
        final Iterator<?> indices = new ExamplesIndices(exampleCount, shuffler).iterator();

        while (indices.hasNext()) {
            final int exampleIndex = ((Number) indices.next()).intValue();

            if (this.verbose) {
                ProgressIndicator.tick$default(progress, 0, 1, (Object) null);
            }

            newBatch();
            trainExample(list.get(exampleIndex));
            this.optimizer.update();
        }
    }

    /**
     * Trains the labeler on a single annotated sentence.
     *
     * <p>The sentence is forwarded through the labeler, each output is paired with the
     * corresponding token, and the softmax cross-entropy errors against the id of the token's
     * label are computed. The errors are then propagated backward and the resulting parameter
     * errors are accumulated in the optimizer.
     *
     * @param annotatedSentence the training example
     */
    private final void trainExample(AnnotatedSentence annotatedSentence) {
        final List<Pair> outputsWithTokens = CollectionsKt.zip(this.annotator.forward(new BaseSentence(annotatedSentence)), annotatedSentence.getTokens());
        final ArrayList outputErrors = new ArrayList(CollectionsKt.collectionSizeOrDefault(outputsWithTokens, 10));

        for (Pair outputAndToken : outputsWithTokens) {
            final DenseNDArray prediction = (DenseNDArray) outputAndToken.component1();
            final Integer goldLabelId = this.model.getOutputLabels().getId(((AnnotatedToken) outputAndToken.component2()).getLabel());

            // A label with no id in the model's output labels cannot be trained on.
            if (goldLabelId == null) {
                Intrinsics.throwNpe();
            }

            outputErrors.add(new SoftmaxCrossEntropyCalculator().calculateErrors(prediction, goldLabelId.intValue()));
        }

        this.annotator.backward((List<DenseNDArray>) outputErrors);
        this.optimizer.accumulate(this.annotator.getParamsErrors(false), false);
    }

    /**
     * Notifies the update method that a new batch is starting, when it supports batch
     * scheduling. Does nothing otherwise.
     */
    private final void newBatch() {
        if (this.updateMethod instanceof BatchScheduling) {
            // The call goes through the scheduling interface that the instanceof check
            // established, not through the UpdateMethod<?> static type of the field.
            ((BatchScheduling) this.updateMethod).newBatch();
        }
    }

    /**
     * Notifies the update method that a new epoch is starting, when it supports epoch
     * scheduling, and increments the epoch counter.
     */
    private final void newEpoch() {
        if (this.updateMethod instanceof EpochScheduling) {
            // The call goes through the scheduling interface that the instanceof check
            // established, not through the UpdateMethod<?> static type of the field.
            ((EpochScheduling) this.updateMethod).newEpoch();
        }
        this.epochCount++;
    }

    /**
     * Validates the model and saves it when the score improves.
     *
     * <p>The score is the mean of the F1 values of all the label statistics returned by the
     * validator. In verbose mode every statistics object is printed. When the score is
     * strictly greater than the best one seen so far, the model is dumped to
     * {@code modelFilename} and the best score is updated.
     *
     * @throws UncheckedIOException if the model file cannot be opened or closed
     */
    private final void validateAndSaveModel() {
        final Map<String, LabelStatistics> statsByLabel = this.validator.evaluate();

        double f1Sum = 0.0d;
        for (LabelStatistics stats : statsByLabel.values()) {
            f1Sum += stats.getF1();
        }
        // With no labels this is 0.0 / 0 = NaN, which never compares greater than the best
        // score, so nothing is saved in that case.
        final double meanF1 = f1Sum / statsByLabel.size();

        if (this.verbose) {
            for (LabelStatistics stats : statsByLabel.values()) {
                System.out.println(stats);
            }
            System.out.println();
        }

        if (meanF1 > this.bestAccuracy) {
            // The file name is passed as a format argument rather than spliced into the
            // format string, so a '%' in the path cannot be parsed as a format specifier.
            System.out.println(String.format("\nNEW BEST ACCURACY (%.2f)! Saving model to \"%s\"...", meanF1, this.modelFilename));

            // try-with-resources guarantees the file handle is released even if dump fails.
            try (FileOutputStream modelStream = new FileOutputStream(new File(this.modelFilename))) {
                this.model.dump(modelStream);
            } catch (IOException e) {
                throw new UncheckedIOException("Cannot save the model to \"" + this.modelFilename + "\"", e);
            }

            this.bestAccuracy = meanF1;
        }
    }

    /**
     * In verbose mode, resets the timer and prints the header of the training phase.
     *
     * @param i zero-based index of the epoch; printed one-based
     */
    private final void logTrainingStart(int i) {
        if (!this.verbose) {
            return;
        }

        this.timer.reset();

        final int epochNumber = i + 1;
        System.out.println("\nEpoch " + epochNumber + " of " + this.epochs);
        System.out.println("\nStart training...");
    }

    /**
     * In verbose mode, prints the time elapsed since the timer was last reset.
     */
    private final void logTrainingEnd() {
        if (!this.verbose) {
            return;
        }

        final Object elapsed = this.timer.formatElapsedTime();
        System.out.println(String.format("Elapsed time: %s", elapsed));
    }

    /**
     * In verbose mode, resets the timer and prints the header of the validation phase.
     */
    private final void logValidationStart() {
        if (!this.verbose) {
            return;
        }

        this.timer.reset();
        System.out.println("\nStart validation...");
    }

    /**
     * In verbose mode, prints the time elapsed since the timer was last reset.
     */
    private final void logValidationEnd() {
        if (!this.verbose) {
            return;
        }

        final Object elapsed = this.timer.formatElapsedTime();
        System.out.println(String.format("Elapsed time: %s", elapsed));
    }

    /**
     * Creates a trainer for the given model.
     *
     * <p>Besides storing its arguments, the constructor creates the timer, a
     * {@code TokensLabeler} on top of the model, and a {@code ParamsOptimizer} driven by the
     * update method. Null arguments are rejected by {@code Intrinsics.checkParameterIsNotNull}.
     *
     * @param tokensLabelerModel the model to train
     * @param str the path of the file the model is saved to
     * @param i the number of training epochs; must be greater than zero
     * @param updateMethod the update method used by the optimizer
     * @param validator the validator used to score the model after each epoch
     * @param z whether to print progress and timing information
     * @throws IllegalArgumentException if the number of epochs is not greater than zero
     */
    public Trainer(@NotNull TokensLabelerModel tokensLabelerModel, @NotNull String str, int i, @NotNull UpdateMethod<?> updateMethod, @NotNull Validator validator, boolean z) {
        Intrinsics.checkParameterIsNotNull(tokensLabelerModel, "model");
        Intrinsics.checkParameterIsNotNull(str, "modelFilename");
        Intrinsics.checkParameterIsNotNull(updateMethod, "updateMethod");
        Intrinsics.checkParameterIsNotNull(validator, "validator");

        this.model = tokensLabelerModel;
        this.modelFilename = str;
        this.epochs = i;
        this.updateMethod = updateMethod;
        this.validator = validator;
        this.verbose = z;

        this.timer = new Timer();
        this.annotator = new TokensLabeler(tokensLabelerModel, 0, false, 6, null);
        this.optimizer = new ParamsOptimizer(updateMethod, (GradientClipping) null, 2, (DefaultConstructorMarker) null);

        // The epoch count is checked last, after the helper objects have been created.
        if (i <= 0) {
            throw new IllegalArgumentException("The number of epochs must be > 0");
        }
    }

    /**
     * Compiler-generated constructor that fills in default argument values before delegating
     * to the primary constructor.
     *
     * <p>When {@code (i2 & 8) != 0} the update method is replaced by
     * {@code new ADAMMethod(0.001, 0.9, 0.999, 0.0, null, ...)}; when {@code (i2 & 32) != 0}
     * the verbose flag is forced to {@code true}.
     *
     * @param i2 defaults bit mask
     * @param defaultConstructorMarker unused marker argument
     */
    public /* synthetic */ Trainer(TokensLabelerModel tokensLabelerModel, String str, int i, UpdateMethod updateMethod, Validator validator, boolean z, int i2, DefaultConstructorMarker defaultConstructorMarker) {
        this(
            tokensLabelerModel,
            str,
            i,
            (i2 & 8) == 0
                ? updateMethod
                : (UpdateMethod) new ADAMMethod(0.001d, 0.9d, 0.999d, 0.0d, (WeightsRegularization) null, 24, (DefaultConstructorMarker) null),
            validator,
            (i2 & 32) != 0 || z);
    }
}
