package com.kotlinnlp.neuralparser.helpers;

import com.kotlinnlp.dependencytree.DependencyTree;
import com.kotlinnlp.neuralparser.NeuralParser;
import com.kotlinnlp.neuralparser.NeuralParserModel;
import com.kotlinnlp.neuralparser.language.Sentence;
import com.kotlinnlp.neuralparser.utils.Timer;
import com.kotlinnlp.progressindicator.ProgressIndicator;
import com.kotlinnlp.progressindicator.ProgressIndicatorBar;
import com.kotlinnlp.simplednn.dataset.Shuffler;
import com.kotlinnlp.simplednn.helpers.training.utils.ExamplesIndices;
import com.kotlinnlp.syntaxdecoder.SyntaxDecoderTrainer;
import com.kotlinnlp.syntaxdecoder.context.InputContext;
import com.kotlinnlp.syntaxdecoder.context.items.StateItem;
import com.kotlinnlp.syntaxdecoder.modules.actionserrorssetter.ActionsErrorsSetter;
import com.kotlinnlp.syntaxdecoder.modules.bestactionselector.BestActionSelector;
import com.kotlinnlp.syntaxdecoder.modules.featuresextractor.features.Features;
import com.kotlinnlp.syntaxdecoder.modules.featuresextractor.features.FeaturesErrors;
import com.kotlinnlp.syntaxdecoder.modules.supportstructure.DecodingSupportStructure;
import com.kotlinnlp.syntaxdecoder.transitionsystem.Transition;
import com.kotlinnlp.syntaxdecoder.transitionsystem.oracle.OracleFactory;
import com.kotlinnlp.syntaxdecoder.transitionsystem.state.State;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import kotlin.Metadata;
import kotlin.collections.IntIterator;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import kotlin.ranges.RangesKt;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: Trainer.kt */
@Metadata(mv = {1, 1, 8}, bv = {1, 0, 2}, k = 1, d1 = {"��¦\u0001\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\b\n��\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000e\n��\n\u0002\u0010\u000b\n\u0002\b\u0002\n\u0002\u0010\u0006\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0002\n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0002\b\u000b\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0006\b&\u0018��*\u000e\b��\u0010\u0001*\b\u0012\u0004\u0012\u0002H\u00010\u0002*\u0014\b\u0001\u0010\u0003*\u000e\u0012\u0004\u0012\u0002H\u0003\u0012\u0004\u0012\u0002H\u00010\u0004*\u0014\b\u0002\u0010\u0005*\u000e\u0012\u0004\u0012\u0002H\u0005\u0012\u0004\u0012\u0002H\u00070\u0006*\u0016\b\u0003\u0010\u0007*\u0010\u0012\u0004\u0012\u0002H\u0007\u0012\u0002\b\u0003\u0012\u0002\b\u00030\b*\b\b\u0004\u0010\t*\u00020\n*\u0012\b\u0005\u0010\u000b*\f\u0012\u0004\u0012\u0002H\t\u0012\u0002\b\u00030\f*\n\b\u0006\u0010\r 
\u0001*\u00020\u000e*\b\b\u0007\u0010\u000f*\u00020\u00102\u00020\u0011BÅ\u0001\u00126\u0010\u0012\u001a2\u0012\u0004\u0012\u00028��\u0012\u0004\u0012\u00028\u0001\u0012\u0004\u0012\u00028\u0002\u0012\u0004\u0012\u00028\u0003\u0012\u0004\u0012\u00028\u0004\u0012\u0004\u0012\u00028\u0005\u0012\u0004\u0012\u00028\u0006\u0012\u0004\u0012\u00028\u00070\u0013\u0012\u001e\u0010\u0014\u001a\u001a\u0012\u0004\u0012\u00028��\u0012\u0004\u0012\u00028\u0001\u0012\u0004\u0012\u00028\u0003\u0012\u0004\u0012\u00028\u00020\u0015\u0012\u0012\u0010\u0016\u001a\u000e\u0012\u0004\u0012\u00028��\u0012\u0004\u0012\u00028\u00010\u0017\u0012\u0006\u0010\u0018\u001a\u00020\u0019\u0012\u001e\u0010\u001a\u001a\u001a\u0012\u0004\u0012\u00028��\u0012\u0004\u0012\u00028\u0001\u0012\u0004\u0012\u00028\u0003\u0012\u0004\u0012\u00028\u00020\u001b\u0012\u0006\u0010\u001c\u001a\u00020\u0019\u0012\u0006\u0010\u001d\u001a\u00020\u0019\u0012\b\u0010\u001e\u001a\u0004\u0018\u00010\u001f\u0012\u0006\u0010 \u001a\u00020!\u0012\b\b\u0002\u0010\"\u001a\u00020#¢\u0006\u0002\u0010$J\u0015\u0010+\u001a\u00020,2\u0006\u0010-\u001a\u00028\u0002H$¢\u0006\u0002\u0010.J-\u0010/\u001a\u00020,2\u0016\u00100\u001a\u001201R\u000e\u0012\u0004\u0012\u00028\u0001\u0012\u0004\u0012\u00028��0\u00042\u0006\u0010-\u001a\u00028\u0002H$¢\u0006\u0002\u00102J\u0015\u00103\u001a\u00020,2\u0006\u0010-\u001a\u00028\u0002H$¢\u0006\u0002\u0010.J\b\u00104\u001a\u00020,H\u0002J\u0010\u00105\u001a\u00020,2\u0006\u00106\u001a\u00020\u0019H\u0002J\b\u00107\u001a\u00020,H\u0002J\b\u00108\u001a\u00020,H\u0002J\b\u00109\u001a\u00020,H\u0002J\b\u0010:\u001a\u00020,H\u0002J*\u0010;\u001a\u00020,2\u0016\u0010<\u001a\u0012\u0012\u0004\u0012\u00020>0=j\b\u0012\u0004\u0012\u00020>`?2\n\b\u0002\u0010@\u001a\u0004\u0018\u00010AJ*\u0010B\u001a\u00020,2\u0016\u0010<\u001a\u0012\u0012\u0004\u0012\u00020>0=j\b\u0012\u0004\u0012\u00020>`?2\b\u0010@\u001a\u0004\u0018\u00010AH\u0002J\u0010\u0010C\u001a\u00020,2\u0006\u0010D\u001a\u00020>H\u0002J\b\u0010E\u0
01a\u00020,H$J\b\u0010F\u001a\u00020,H\u0002R\u000e\u0010\u001c\u001a\u00020\u0019X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010%\u001a\u00020&X\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010\u0018\u001a\u00020\u0019X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u001d\u001a\u00020\u0019X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010 \u001a\u00020!X\u0082\u0004¢\u0006\u0002\n��R>\u0010\u0012\u001a2\u0012\u0004\u0012\u00028��\u0012\u0004\u0012\u00028\u0001\u0012\u0004\u0012\u00028\u0002\u0012\u0004\u0012\u00028\u0003\u0012\u0004\u0012\u00028\u0004\u0012\u0004\u0012\u00028\u0005\u0012\u0004\u0012\u00028\u0006\u0012\u0004\u0012\u00028\u00070\u0013X\u0088\u0004¢\u0006\u0002\n��R8\u0010'\u001a,\u0012\u0004\u0012\u00028��\u0012\u0004\u0012\u00028\u0001\u0012\u0004\u0012\u00028\u0002\u0012\u0004\u0012\u00028\u0003\u0012\u0004\u0012\u00028\u0004\u0012\u0004\u0012\u00028\u0005\u0012\u0004\u0012\u00028\u00060(X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010)\u001a\u00020*X\u0082\u000e¢\u0006\u0002\n��R\u0010\u0010\u001e\u001a\u0004\u0018\u00010\u001fX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\"\u001a\u00020#X\u0082\u0004¢\u0006\u0002\n��¨\u0006G"}, d2 = {"Lcom/kotlinnlp/neuralparser/helpers/Trainer;", "StateType", "Lcom/kotlinnlp/syntaxdecoder/transitionsystem/state/State;", "TransitionType", "Lcom/kotlinnlp/syntaxdecoder/transitionsystem/Transition;", "InputContextType", "Lcom/kotlinnlp/syntaxdecoder/context/InputContext;", "ItemType", "Lcom/kotlinnlp/syntaxdecoder/context/items/StateItem;", "FeaturesErrorsType", "Lcom/kotlinnlp/syntaxdecoder/modules/featuresextractor/features/FeaturesErrors;", "FeaturesType", "Lcom/kotlinnlp/syntaxdecoder/modules/featuresextractor/features/Features;", "SupportStructureType", "Lcom/kotlinnlp/syntaxdecoder/modules/supportstructure/DecodingSupportStructure;", "ModelType", "Lcom/kotlinnlp/neuralparser/NeuralParserModel;", "", "neuralParser", "Lcom/kotlinnlp/neuralparser/NeuralParser;", "actionsErrorsSetter", 
"Lcom/kotlinnlp/syntaxdecoder/modules/actionserrorssetter/ActionsErrorsSetter;", "oracleFactory", "Lcom/kotlinnlp/syntaxdecoder/transitionsystem/oracle/OracleFactory;", "epochs", "", "bestActionSelector", "Lcom/kotlinnlp/syntaxdecoder/modules/bestactionselector/BestActionSelector;", "batchSize", "minRelevantErrorsCountToUpdate", "validator", "Lcom/kotlinnlp/neuralparser/helpers/Validator;", "modelFilename", "", "verbose", "", "(Lcom/kotlinnlp/neuralparser/NeuralParser;Lcom/kotlinnlp/syntaxdecoder/modules/actionserrorssetter/ActionsErrorsSetter;Lcom/kotlinnlp/syntaxdecoder/transitionsystem/oracle/OracleFactory;ILcom/kotlinnlp/syntaxdecoder/modules/bestactionselector/BestActionSelector;IILcom/kotlinnlp/neuralparser/helpers/Validator;Ljava/lang/String;Z)V", "bestAccuracy", "", "syntaxDecoderTrainer", "Lcom/kotlinnlp/syntaxdecoder/SyntaxDecoderTrainer;", "timer", "Lcom/kotlinnlp/neuralparser/utils/Timer;", "afterSentenceLearning", "", "context", "(Lcom/kotlinnlp/syntaxdecoder/context/InputContext;)V", "beforeApplyAction", "action", "Lcom/kotlinnlp/syntaxdecoder/transitionsystem/Transition$Action;", "(Lcom/kotlinnlp/syntaxdecoder/transitionsystem/Transition$Action;Lcom/kotlinnlp/syntaxdecoder/context/InputContext;)V", "beforeSentenceLearning", "logTrainingEnd", "logTrainingStart", "epochIndex", "logValidationEnd", "logValidationStart", "newBatch", "newEpoch", "train", "trainingSentences", "Ljava/util/ArrayList;", "Lcom/kotlinnlp/neuralparser/language/Sentence;", "Lkotlin/collections/ArrayList;", "shuffler", "Lcom/kotlinnlp/simplednn/dataset/Shuffler;", "trainEpoch", "trainSentence", "sentence", "update", "validateAndSaveModel", "neuralparser"})
/* loaded from: input_file:com/kotlinnlp/neuralparser/helpers/Trainer.class */
/**
 * Base class for training a {@link NeuralParser}.
 *
 * <p>Runs a fixed number of epochs over gold-annotated sentences. Per-sentence learning is delegated
 * to a {@link SyntaxDecoderTrainer}. After each epoch, if a {@link Validator} was supplied, the model
 * is evaluated and written to {@code modelFilename} whenever its no-punctuation UAS improves.
 *
 * <p>Subclasses plug in through the hooks {@link #beforeSentenceLearning},
 * {@link #afterSentenceLearning}, {@link #beforeApplyAction} and {@link #update}.
 */
public abstract class Trainer<StateType extends State<StateType>, TransitionType extends Transition<TransitionType, StateType>, InputContextType extends InputContext<InputContextType, ItemType>, ItemType extends StateItem<ItemType, ?, ?>, FeaturesErrorsType extends FeaturesErrors, FeaturesType extends Features<FeaturesErrorsType, ?>, SupportStructureType extends DecodingSupportStructure, ModelType extends NeuralParserModel> {

    /** Measures the elapsed time printed by the verbose log lines. */
    private final Timer timer;

    /** Drives learning of the parser's syntax decoder; accumulates errors and applies updates. */
    private final SyntaxDecoderTrainer<StateType, TransitionType, InputContextType, ItemType, FeaturesErrorsType, FeaturesType, SupportStructureType> syntaxDecoderTrainer;

    /** Best no-punctuation UAS percentage seen so far; starts at -1 so the first validation always saves. */
    private double bestAccuracy;

    /** The parser being trained. Declared with the same wildcard as the constructor parameter. */
    private final NeuralParser<StateType, TransitionType, InputContextType, ItemType, FeaturesErrorsType, FeaturesType, SupportStructureType, ? extends ModelType> neuralParser;

    /** Number of epochs run by {@link #train}. */
    private final int epochs;

    /** Stored from the constructor but not read anywhere in this class. */
    private final int batchSize;

    /** An update is applied once the decoder trainer's relevant-error count reaches this value (> 0). */
    private final int minRelevantErrorsCountToUpdate;

    /** Optional validator; when null, no validation or model saving happens. */
    private final Validator validator;

    /** Path the model is written to when validation accuracy improves. */
    private final String modelFilename;

    /** Enables the epoch and elapsed-time log lines. */
    private final boolean verbose;

    /**
     * Trains the parser for {@code epochs} epochs over the given sentences.
     *
     * <p>After each epoch, if a validator is set, the model is evaluated and saved when it improves.
     *
     * @param trainingSentences sentences to learn from. Each must carry a gold dependency tree,
     *     otherwise an {@link IllegalStateException} is thrown while training it.
     * @param shuffler forwarded to {@link ExamplesIndices} to decide the visiting order of sentences
     *     (presumably no shuffling when null — TODO confirm)
     */
    public final void train(@NotNull ArrayList<Sentence> trainingSentences, @Nullable Shuffler shuffler) {
        Intrinsics.checkParameterIsNotNull(trainingSentences, "trainingSentences");
        for (int epochIndex = 0; epochIndex < this.epochs; epochIndex++) {
            logTrainingStart(epochIndex);
            newEpoch();
            trainEpoch(trainingSentences, shuffler);
            logTrainingEnd();
            if (this.validator != null) {
                logValidationStart();
                validateAndSaveModel();
                logValidationEnd();
            }
        }
    }

    /**
     * Synthetic Kotlin bridge for the default arguments of {@link #train}.
     *
     * <p>When bit 1 of {@code defaultsMask} is set, {@code shuffler} is replaced by
     * {@code new Shuffler(true, 743L)}. A non-null {@code superCallMarker} signals an unsupported
     * super call with defaults.
     */
    public static /* bridge */ /* synthetic */ void train$default(Trainer trainer, ArrayList trainingSentences, Shuffler shuffler, int defaultsMask, Object superCallMarker) {
        if (superCallMarker != null) {
            throw new UnsupportedOperationException("Super calls with default arguments not supported in this target, function: train");
        }
        if ((defaultsMask & 2) != 0) {
            shuffler = new Shuffler(true, 743L);
        }
        trainer.train(trainingSentences, shuffler);
    }

    /**
     * Hook called for each sentence after its input context is built and before the decoder learns it.
     *
     * @param inputcontexttype the context built for the sentence
     */
    protected abstract void beforeSentenceLearning(@NotNull InputContextType inputcontexttype);

    /**
     * Hook called for each sentence after the decoder trainer has finished learning it.
     *
     * @param inputcontexttype the context built for the sentence
     */
    protected abstract void afterSentenceLearning(@NotNull InputContextType inputcontexttype);

    /**
     * Hook receiving an action and the current context. It is presumably invoked by the learning
     * callback passed to {@link SyntaxDecoderTrainer#learn} before each action is applied; that
     * callback class is not visible here — TODO confirm.
     *
     * <p>JADX reports the original modifier was {@code protected}. It is {@code public} in this
     * compiled form.
     */
    public abstract void beforeApplyAction(@NotNull Transition<TransitionType, StateType>.Action action, @NotNull InputContextType inputcontexttype);

    /**
     * Runs one pass over all training sentences, updating whenever enough relevant errors accumulate.
     */
    private final void trainEpoch(ArrayList<Sentence> trainingSentences, Shuffler shuffler) {
        ProgressIndicatorBar progress = new ProgressIndicatorBar(trainingSentences.size(), null, 0, 6, null);
        newBatch();
        Iterator<Integer> sentenceIndices = new ExamplesIndices(trainingSentences.size(), shuffler).iterator();
        while (sentenceIndices.hasNext()) {
            int sentenceIndex = sentenceIndices.next();
            ProgressIndicator.tick$default(progress, 0, 1, null);
            Sentence sentence = trainingSentences.get(sentenceIndex);
            Intrinsics.checkExpressionValueIsNotNull(sentence, "trainingSentences[sentenceIndex]");
            trainSentence(sentence);
            if (this.syntaxDecoderTrainer.getRelevantErrorsCount() >= this.minRelevantErrorsCountToUpdate) {
                this.syntaxDecoderTrainer.update();
                update();
                newBatch();
            }
        }
        // NOTE(review): no update() follows the loop. Errors of a final batch that never reached
        // minRelevantErrorsCountToUpdate may be dropped by the next newEpoch()/newBatch() — confirm
        // this is intended.
    }

    /**
     * Learns a single sentence against its gold dependency tree.
     *
     * @throws IllegalStateException if the sentence has no gold dependency tree
     */
    private final void trainSentence(Sentence sentence) {
        // The context is built before the gold-tree check, as in the original order.
        InputContextType context = this.neuralParser.buildContext(sentence, true);
        DependencyTree goldTree = sentence.getDependencyTree();
        if (goldTree == null) {
            throw new IllegalStateException("The gold dependency tree of a sentence was null during its training.");
        }
        beforeSentenceLearning(context);
        // Trainer$trainSentence$1 is the compiled Kotlin lambda passed as the learning callback.
        this.syntaxDecoderTrainer.learn(context, goldTree, true, new Trainer$trainSentence$1(this));
        afterSentenceLearning(context);
    }

    /**
     * Evaluates the model with the validator, prints the statistics, and saves the model when its
     * no-punctuation UAS percentage beats the best seen so far.
     */
    private final void validateAndSaveModel() {
        Validator validator = this.validator;
        if (validator == null) {
            Intrinsics.throwNpe();
        }
        Statistics statistics = Validator.evaluate$default(validator, false, 1, null);
        System.out.println("\n" + statistics);
        double accuracy = statistics.getNoPunctuation().getUas().getPerc();
        if (accuracy > this.bestAccuracy) {
            saveModel();
            System.out.println("\nNEW BEST ACCURACY! Model saved to \"" + this.modelFilename + '"');
            this.bestAccuracy = accuracy;
        }
    }

    /**
     * Writes the current model to {@code modelFilename}. The stream is closed even if
     * {@code dump} fails.
     *
     * @throws UncheckedIOException if the file cannot be opened, written or closed; the original
     *     {@link IOException} is kept as the cause
     */
    private void saveModel() {
        try (FileOutputStream out = new FileOutputStream(new File(this.modelFilename))) {
            this.neuralParser.getModel().dump(out);
        } catch (IOException e) {
            throw new UncheckedIOException("Cannot save the model to \"" + this.modelFilename + '"', e);
        }
    }

    /** Starts a new batch in the decoder trainer. */
    private final void newBatch() {
        this.syntaxDecoderTrainer.newBatch();
    }

    /** Starts a new epoch in the decoder trainer. */
    private final void newEpoch() {
        this.syntaxDecoderTrainer.newEpoch();
    }

    /**
     * Hook called right after {@code syntaxDecoderTrainer.update()} once a batch reaches the error
     * threshold. Subclasses presumably update their own parameters here — confirm.
     */
    protected abstract void update();

    /** When verbose, resets the timer and prints the epoch header. */
    private final void logTrainingStart(int epochIndex) {
        if (this.verbose) {
            this.timer.reset();
            System.out.println("\nEpoch " + (epochIndex + 1) + " of " + this.epochs);
            System.out.println("\nStart training...");
        }
    }

    /** When verbose, prints the time spent training the epoch. */
    private final void logTrainingEnd() {
        logElapsedTime();
    }

    /** When verbose, resets the timer and prints an empty line before validation output. */
    private final void logValidationStart() {
        if (this.verbose) {
            this.timer.reset();
            System.out.println();
        }
    }

    /** When verbose, prints the time spent on validation. */
    private final void logValidationEnd() {
        logElapsedTime();
    }

    /** Shared by the "end" log methods: prints the timer's elapsed time when verbose. */
    private void logElapsedTime() {
        if (this.verbose) {
            System.out.println(String.format("Elapsed time: %s", this.timer.formatElapsedTime()));
        }
    }

    /**
     * Creates a trainer.
     *
     * @param neuralParser the parser to train. Its syntax decoder is wrapped by the decoder trainer,
     *     and its model is the one saved.
     * @param actionsErrorsSetter forwarded to {@link SyntaxDecoderTrainer}
     * @param oracleFactory forwarded to {@link SyntaxDecoderTrainer}
     * @param epochs number of training epochs
     * @param bestActionSelector forwarded to {@link SyntaxDecoderTrainer}
     * @param batchSize stored but not used by this class
     * @param minRelevantErrorsCountToUpdate relevant-error count that triggers an update; must be > 0
     * @param validator optional validator; when null, no validation or saving happens
     * @param modelFilename path where the best model is written
     * @param verbose whether to print epoch and timing information
     * @throws IllegalArgumentException if {@code minRelevantErrorsCountToUpdate <= 0}
     */
    public Trainer(@NotNull NeuralParser<StateType, TransitionType, InputContextType, ItemType, FeaturesErrorsType, FeaturesType, SupportStructureType, ? extends ModelType> neuralParser, @NotNull ActionsErrorsSetter<StateType, TransitionType, ItemType, InputContextType> actionsErrorsSetter, @NotNull OracleFactory<StateType, TransitionType> oracleFactory, int epochs, @NotNull BestActionSelector<StateType, TransitionType, ItemType, InputContextType> bestActionSelector, int batchSize, int minRelevantErrorsCountToUpdate, @Nullable Validator validator, @NotNull String modelFilename, boolean verbose) {
        Intrinsics.checkParameterIsNotNull(neuralParser, "neuralParser");
        Intrinsics.checkParameterIsNotNull(actionsErrorsSetter, "actionsErrorsSetter");
        Intrinsics.checkParameterIsNotNull(oracleFactory, "oracleFactory");
        Intrinsics.checkParameterIsNotNull(bestActionSelector, "bestActionSelector");
        Intrinsics.checkParameterIsNotNull(modelFilename, "modelFilename");
        this.neuralParser = neuralParser;
        this.epochs = epochs;
        this.batchSize = batchSize;
        this.minRelevantErrorsCountToUpdate = minRelevantErrorsCountToUpdate;
        this.validator = validator;
        this.modelFilename = modelFilename;
        this.verbose = verbose;
        this.timer = new Timer();
        this.syntaxDecoderTrainer = new SyntaxDecoderTrainer<>(this.neuralParser.getSyntaxDecoder(), actionsErrorsSetter, bestActionSelector, oracleFactory);
        this.bestAccuracy = -1.0d;
        if (this.minRelevantErrorsCountToUpdate <= 0) {
            throw new IllegalArgumentException("minRelevantErrorsCountToUpdate must be > 0");
        }
    }

    /**
     * Synthetic Kotlin constructor for default arguments. When bit 9 (512) of {@code defaultsMask}
     * is set, {@code verbose} defaults to {@code true}.
     */
    public /* synthetic */ Trainer(NeuralParser neuralParser, ActionsErrorsSetter actionsErrorsSetter, OracleFactory oracleFactory, int epochs, BestActionSelector bestActionSelector, int batchSize, int minRelevantErrorsCountToUpdate, Validator validator, String modelFilename, boolean verbose, int defaultsMask, DefaultConstructorMarker marker) {
        this(neuralParser, actionsErrorsSetter, oracleFactory, epochs, bestActionSelector, batchSize, minRelevantErrorsCountToUpdate, validator, modelFilename, (defaultsMask & 512) != 0 ? true : verbose);
    }
}
