package com.kotlinnlp.neuralparser.helpers;

import com.kotlinnlp.neuralparser.NeuralParser;
import com.kotlinnlp.neuralparser.language.Sentence;
import com.kotlinnlp.neuralparser.utils.Timer;
import com.kotlinnlp.progressindicator.ProgressIndicator;
import com.kotlinnlp.progressindicator.ProgressIndicatorBar;
import com.kotlinnlp.simplednn.dataset.Shuffler;
import com.kotlinnlp.simplednn.helpers.training.utils.ExamplesIndices;
import java.io.File;
import java.io.FileOutputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import kotlin.Metadata;
import kotlin.collections.CollectionsKt;
import kotlin.collections.IntIterator;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import kotlin.ranges.RangesKt;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: Trainer.kt */
@Metadata(mv = {1, 1, 8}, bv = {1, 0, 2}, k = 1, d1 = {"��\\\n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\b\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000e\n\u0002\b\u0002\n\u0002\u0010\u000b\n\u0002\b\u0002\n\u0002\u0010\u0006\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u0002\n\u0002\b\b\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0007\b&\u0018��2\u00020\u0001BG\u0012\n\u0010\u0002\u001a\u0006\u0012\u0002\b\u00030\u0003\u0012\u0006\u0010\u0004\u001a\u00020\u0005\u0012\u0006\u0010\u0006\u001a\u00020\u0005\u0012\b\u0010\u0007\u001a\u0004\u0018\u00010\b\u0012\u0006\u0010\t\u001a\u00020\n\u0012\b\b\u0002\u0010\u000b\u001a\u00020\u0005\u0012\b\b\u0002\u0010\f\u001a\u00020\r¢\u0006\u0002\u0010\u000eJ\b\u0010\u0013\u001a\u00020\u0005H$J\b\u0010\u0014\u001a\u00020\u0015H\u0002J\u0010\u0010\u0016\u001a\u00020\u00152\u0006\u0010\u0017\u001a\u00020\u0005H\u0002J\b\u0010\u0018\u001a\u00020\u0015H\u0002J\b\u0010\u0019\u001a\u00020\u0015H\u0002J\b\u0010\u001a\u001a\u00020\u0015H\u0014J\b\u0010\u001b\u001a\u00020\u0015H\u0014JH\u0010\u001c\u001a\u00020\u00152\u0016\u0010\u001d\u001a\u0012\u0012\u0004\u0012\u00020\u001f0\u001ej\b\u0012\u0004\u0012\u00020\u001f` 2\u001c\b\u0002\u0010!\u001a\u0016\u0012\u0004\u0012\u00020\u001f\u0018\u00010\u001ej\n\u0012\u0004\u0012\u00020\u001f\u0018\u0001` 2\n\b\u0002\u0010\"\u001a\u0004\u0018\u00010#JF\u0010$\u001a\u00020\u00152\u0016\u0010\u001d\u001a\u0012\u0012\u0004\u0012\u00020\u001f0\u001ej\b\u0012\u0004\u0012\u00020\u001f` 2\u001a\u0010!\u001a\u0016\u0012\u0004\u0012\u00020\u001f\u0018\u00010\u001ej\n\u0012\u0004\u0012\u00020\u001f\u0018\u0001` 
2\b\u0010\"\u001a\u0004\u0018\u00010#H\u0002J\u001a\u0010%\u001a\u00020\u00152\u0006\u0010&\u001a\u00020\u001f2\b\u0010'\u001a\u0004\u0018\u00010\u001fH$J\b\u0010(\u001a\u00020\u0015H$J\b\u0010)\u001a\u00020\u0015H\u0002R\u000e\u0010\u0004\u001a\u00020\u0005X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u000f\u001a\u00020\u0010X\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010\u0006\u001a\u00020\u0005X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u000b\u001a\u00020\u0005X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\t\u001a\u00020\nX\u0082\u0004¢\u0006\u0002\n��R\u0012\u0010\u0002\u001a\u0006\u0012\u0002\b\u00030\u0003X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0011\u001a\u00020\u0012X\u0082\u000e¢\u0006\u0002\n��R\u0010\u0010\u0007\u001a\u0004\u0018\u00010\bX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\f\u001a\u00020\rX\u0082\u0004¢\u0006\u0002\n��¨\u0006*"}, d2 = {"Lcom/kotlinnlp/neuralparser/helpers/Trainer;", "", "neuralParser", "Lcom/kotlinnlp/neuralparser/NeuralParser;", "batchSize", "", "epochs", "validator", "Lcom/kotlinnlp/neuralparser/helpers/Validator;", "modelFilename", "", "minRelevantErrorsCountToUpdate", "verbose", "", "(Lcom/kotlinnlp/neuralparser/NeuralParser;IILcom/kotlinnlp/neuralparser/helpers/Validator;Ljava/lang/String;IZ)V", "bestAccuracy", "", "timer", "Lcom/kotlinnlp/neuralparser/utils/Timer;", "getRelevantErrorsCount", "logTrainingEnd", "", "logTrainingStart", "epochIndex", "logValidationEnd", "logValidationStart", "newBatch", "newEpoch", "train", "trainingSentences", "Ljava/util/ArrayList;", "Lcom/kotlinnlp/neuralparser/language/Sentence;", "Lkotlin/collections/ArrayList;", "goldPOSSentences", "shuffler", "Lcom/kotlinnlp/simplednn/dataset/Shuffler;", "trainEpoch", "trainSentence", "sentence", "goldPOSSentence", "update", "validateAndSaveModel", "neuralparser"})
/* loaded from: input_file:com/kotlinnlp/neuralparser/helpers/Trainer.class */
/*
 * Base class that drives the training of a NeuralParser for a fixed number of epochs.
 *
 * For every epoch the sentences are visited in the order produced by ExamplesIndices,
 * each one is handed to trainSentence(), and update() is invoked at the end of a batch
 * when getRelevantErrorsCount() reaches the configured threshold. When a validator is
 * present the model is evaluated after each epoch and written to the model file every
 * time the no-punctuation UAS percentage improves.
 *
 * (Plain block comment on purpose: the class annotation precedes this declaration.)
 */
public abstract class Trainer {
    /** Stopwatch reset at the start of every training / validation phase (verbose mode only). */
    private Timer timer;
    /** Best no-punctuation UAS percentage seen so far; starts at -1.0. */
    private double bestAccuracy;
    private final NeuralParser<?> neuralParser;
    private final int batchSize;
    private final int epochs;
    /** Optional validator; when {@code null} no validation (and no model saving) takes place. */
    private final Validator validator;
    private final String modelFilename;
    private final int minRelevantErrorsCountToUpdate;
    private final boolean verbose;

    /**
     * Trains for the configured number of epochs, validating after each epoch when a
     * validator was supplied.
     *
     * @param arrayList  the training sentences (must not be {@code null})
     * @param arrayList2 the optional gold-POS sentences; when given it must have the same
     *                   size as {@code arrayList}
     * @param shuffler   passed to {@code ExamplesIndices} to decide the visiting order; may be
     *                   {@code null}
     * @throws IllegalArgumentException if {@code arrayList2} is non-null and its size differs
     *                                  from the size of {@code arrayList}
     */
    public final void train(@NotNull ArrayList<Sentence> arrayList, @Nullable ArrayList<Sentence> arrayList2, @Nullable Shuffler shuffler) {
        Intrinsics.checkParameterIsNotNull(arrayList, "trainingSentences");
        if (arrayList2 != null && arrayList.size() != arrayList2.size()) {
            throw new IllegalArgumentException("Failed requirement.");
        }
        for (int epochIndex = 0; epochIndex < this.epochs; epochIndex++) {
            logTrainingStart(epochIndex);
            newEpoch();
            trainEpoch(arrayList, arrayList2, shuffler);
            logTrainingEnd();
            if (this.validator != null) {
                logValidationStart();
                validateAndSaveModel();
                logValidationEnd();
            }
        }
    }

    /**
     * Compiler-generated bridge that fills in the default arguments of {@code train}:
     * bit 2 of {@code i} selects a {@code null} gold-POS list, bit 4 selects
     * {@code new Shuffler(true, 743L)}.
     */
    public static /* bridge */ /* synthetic */ void train$default(Trainer trainer, ArrayList arrayList, ArrayList arrayList2, Shuffler shuffler, int i, Object obj) {
        if (obj != null) {
            throw new UnsupportedOperationException("Super calls with default arguments not supported in this target, function: train");
        }
        if ((i & 2) != 0) {
            arrayList2 = null;
        }
        if ((i & 4) != 0) {
            shuffler = new Shuffler(true, 743L);
        }
        trainer.train(arrayList, arrayList2, shuffler);
    }

    /**
     * Runs a single epoch over all the training sentences.
     *
     * A batch ends every {@code batchSize} visited sentences and on the last sentence.
     * At the end of a batch {@code update()} and {@code newBatch()} are called only if
     * {@code getRelevantErrorsCount()} has reached {@code minRelevantErrorsCountToUpdate};
     * otherwise the current batch simply keeps going.
     */
    private final void trainEpoch(ArrayList<Sentence> trainingSentences, ArrayList<Sentence> goldPOSSentences, Shuffler shuffler) {
        ProgressIndicatorBar progress = new ProgressIndicatorBar(trainingSentences.size(), (OutputStream) null, 0, 6, (DefaultConstructorMarker) null);
        newBatch();
        int position = 0;
        for (Object boxedIndex : new ExamplesIndices(trainingSentences.size(), shuffler)) {
            int sentenceIndex = ((Number) boxedIndex).intValue();
            // "position" counts visited sentences; "sentenceIndex" is the (possibly shuffled) list index.
            boolean endOfBatch = (position + 1) % this.batchSize == 0
                    || position == CollectionsKt.getLastIndex(trainingSentences);
            ProgressIndicator.tick$default(progress, 0, 1, (Object) null);
            Sentence sentence = trainingSentences.get(sentenceIndex);
            Intrinsics.checkExpressionValueIsNotNull(sentence, "trainingSentences[sentenceIndex]");
            Sentence goldPOSSentence = goldPOSSentences != null ? goldPOSSentences.get(sentenceIndex) : null;
            trainSentence(sentence, goldPOSSentence);
            if (endOfBatch && getRelevantErrorsCount() >= this.minRelevantErrorsCountToUpdate) {
                update();
                newBatch();
            }
            position++;
        }
    }

    /* JADX WARN: Type inference failed for: r0v15, types: [com.kotlinnlp.neuralparser.NeuralParserModel] */
    /**
     * Evaluates the parser with the validator, prints the statistics and, when the
     * no-punctuation UAS percentage beats the best value seen so far, dumps the model to
     * {@code modelFilename} and records the new best value.
     *
     * The best accuracy is updated only after the dump has completed, so a failed save
     * does not hide the next improvement.
     */
    private final void validateAndSaveModel() {
        Validator validator = this.validator;
        if (validator == null) {
            Intrinsics.throwNpe();
        }
        Statistics stats = Validator.evaluate$default(validator, false, 1, null);
        System.out.println("\n" + stats);
        double accuracy = stats.getNoPunctuation().getUas().getPerc();
        if (accuracy > this.bestAccuracy) {
            // NOTE(review): the FileOutputStream constructor and close() throw checked
            // IOExceptions that the original Kotlin source lets propagate undeclared; this
            // decompiled Java keeps that behaviour rather than changing the exception type.
            FileOutputStream modelStream = new FileOutputStream(new File(this.modelFilename));
            try {
                this.neuralParser.getModel().dump(modelStream);
            } finally {
                // Always release the file handle, even if dump() throws.
                // Closing an already-closed FileOutputStream is a no-op.
                modelStream.close();
            }
            System.out.println("\nNEW BEST ACCURACY! Model saved to \"" + this.modelFilename + '\"');
            this.bestAccuracy = accuracy;
        }
    }

    /**
     * In verbose mode, resets the timer and announces the (1-based) epoch being started.
     *
     * @param epochIndex zero-based index of the epoch
     */
    private final void logTrainingStart(int epochIndex) {
        if (this.verbose) {
            this.timer.reset();
            System.out.println("\nEpoch " + (epochIndex + 1) + " of " + this.epochs);
            System.out.println("\nStart training...");
        }
    }

    /** In verbose mode, prints the time elapsed since the training of the epoch started. */
    private final void logTrainingEnd() {
        if (this.verbose) {
            logElapsedTime();
        }
    }

    /** In verbose mode, resets the timer and prints an empty line before validation. */
    private final void logValidationStart() {
        if (this.verbose) {
            this.timer.reset();
            System.out.println();
        }
    }

    /** In verbose mode, prints the time elapsed since the validation started. */
    private final void logValidationEnd() {
        if (this.verbose) {
            logElapsedTime();
        }
    }

    /** Prints the timer's formatted elapsed time as "Elapsed time: ...". */
    private void logElapsedTime() {
        Object[] args = {this.timer.formatElapsedTime()};
        System.out.println(String.format("Elapsed time: %s", args));
    }

    /**
     * Hook called at the start of every epoch and after every {@link #update()}.
     * Does nothing by default.
     */
    protected void newBatch() {
    }

    /** Hook called once per epoch, before its sentences are processed. Does nothing by default. */
    protected void newEpoch() {
    }

    /**
     * Called at the end of a batch when the relevant errors count has reached the
     * configured minimum. Presumably applies the accumulated errors to the model -
     * the actual behaviour is defined by subclasses.
     */
    protected abstract void update();

    /**
     * Trains on a single sentence.
     *
     * @param sentence  the training sentence
     * @param sentence2 the gold-POS sentence at the same index, or {@code null} when no
     *                  gold-POS sentences were given to {@code train}
     */
    protected abstract void trainSentence(@NotNull Sentence sentence, @Nullable Sentence sentence2);

    /**
     * @return the count compared against {@code minRelevantErrorsCountToUpdate} at the end
     *         of each batch to decide whether {@link #update()} is called
     */
    protected abstract int getRelevantErrorsCount();

    /**
     * @param neuralParser the parser whose model is dumped on a new best accuracy
     * @param i            the batch size, must be &gt; 0
     * @param i2           the number of epochs, must be &gt; 0
     * @param validator    the optional validator run after each epoch
     * @param str          the path of the file the best model is written to
     * @param i3           the minimum relevant errors count required to call
     *                     {@link #update()}, must be &gt; 0
     * @param z            whether to print progress information (verbose mode)
     * @throws IllegalArgumentException if any of the numeric arguments is not &gt; 0
     */
    public Trainer(@NotNull NeuralParser<?> neuralParser, int i, int i2, @Nullable Validator validator, @NotNull String str, int i3, boolean z) {
        Intrinsics.checkParameterIsNotNull(neuralParser, "neuralParser");
        Intrinsics.checkParameterIsNotNull(str, "modelFilename");
        this.neuralParser = neuralParser;
        this.batchSize = i;
        this.epochs = i2;
        this.validator = validator;
        this.modelFilename = str;
        this.minRelevantErrorsCountToUpdate = i3;
        this.verbose = z;
        this.timer = new Timer();
        this.bestAccuracy = -1.0d;
        if (this.epochs <= 0) {
            throw new IllegalArgumentException("The number of epochs must be > 0");
        }
        if (this.batchSize <= 0) {
            throw new IllegalArgumentException("The size of the batch must be > 0");
        }
        if (this.minRelevantErrorsCountToUpdate <= 0) {
            throw new IllegalArgumentException("minRelevantErrorsCountToUpdate must be > 0");
        }
    }

    /**
     * Compiler-generated constructor that fills in the default arguments: bit 32 of
     * {@code i4} selects a minimum relevant errors count of 1, bit 64 selects verbose mode.
     */
    public /* synthetic */ Trainer(NeuralParser neuralParser, int i, int i2, Validator validator, String str, int i3, boolean z, int i4, DefaultConstructorMarker defaultConstructorMarker) {
        this(neuralParser, i, i2, validator, str, (i4 & 32) != 0 ? 1 : i3, (i4 & 64) != 0 ? true : z);
    }
}
