package com.kotlinnlp.hanclassifier.helpers;

import com.kotlinnlp.hanclassifier.EncodedSentence;
import com.kotlinnlp.hanclassifier.HANClassifier;
import com.kotlinnlp.hanclassifier.dataset.Example;
import com.kotlinnlp.linguisticdescription.sentence.Sentence;
import com.kotlinnlp.linguisticdescription.sentence.token.FormToken;
import com.kotlinnlp.simplednn.core.functionalities.updatemethods.UpdateMethod;
import com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor;
import com.kotlinnlp.simplednn.core.optimizer.Optimizer;
import com.kotlinnlp.simplednn.core.optimizer.ParamsOptimizer;
import com.kotlinnlp.simplednn.dataset.Shuffler;
import com.kotlinnlp.simplednn.deeplearning.attention.han.HANParameters;
import com.kotlinnlp.simplednn.helpers.training.utils.ExamplesIndices;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import com.kotlinnlp.tokensencoder.TokensEncoder;
import com.kotlinnlp.tokensencoder.TokensEncoderOptimizer;
import com.kotlinnlp.utils.progressindicator.ProgressIndicator;
import com.kotlinnlp.utils.progressindicator.ProgressIndicatorBar;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import kotlin.Metadata;
import kotlin.Unit;
import kotlin.collections.CollectionsKt;
import kotlin.collections.IntIterator;
import kotlin.jvm.functions.Function0;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import kotlin.ranges.RangesKt;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: Trainer.kt */
@Metadata(mv = {1, 1, 10}, bv = {1, 0, 2}, k = 1, d1 = {"��~\n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0010\u0002\n\u0002\b\u0002\n\u0002\u0010\u0006\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\t\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000e\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0006\n\u0002\u0010 \n��\n\u0002\u0010\b\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0007\u0018��2\u00020\u0001BO\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\u0018\u0010\u0004\u001a\u0014\u0012\u0004\u0012\u00020\u0006\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00060\u00070\u0005\u0012\n\b\u0002\u0010\b\u001a\u0004\u0018\u00010\t\u0012\n\u0010\n\u001a\u0006\u0012\u0002\b\u00030\u000b\u0012\u000e\b\u0002\u0010\f\u001a\b\u0012\u0004\u0012\u00020\u000e0\r¢\u0006\u0002\u0010\u000fJ\b\u0010\u001b\u001a\u00020\u001cH\u0002J\u0010\u0010\u001d\u001a\u00020\u000e2\u0006\u0010\u001e\u001a\u00020\u001fH\u0002J\b\u0010 \u001a\u00020\u000eH\u0002J\b\u0010!\u001a\u00020\u000eH\u0002J\b\u0010\"\u001a\u00020\u000eH\u0002J\b\u0010#\u001a\u00020\u000eH\u0002JP\u0010$\u001a\u00020\u000e2\f\u0010%\u001a\b\u0012\u0004\u0012\u00020\u001f0&2\u0006\u0010'\u001a\u00020(2\b\b\u0002\u0010)\u001a\u00020(2\n\b\u0002\u0010*\u001a\u0004\u0018\u00010+2\u0010\b\u0002\u0010,\u001a\n\u0012\u0004\u0012\u00020\u001f\u0018\u00010&2\n\b\u0002\u0010-\u001a\u0004\u0018\u00010\u001cJ(\u0010.\u001a\u00020\u000e2\f\u0010%\u001a\b\u0012\u0004\u0012\u00020\u001f0&2\u0006\u0010)\u001a\u00020(2\b\u0010*\u001a\u0004\u0018\u00010+H\u0002J\b\u0010/\u001a\u00020\u000eH\u0002J 
\u00100\u001a\u00020\u000e2\f\u0010,\u001a\b\u0012\u0004\u0012\u00020\u001f0&2\b\u0010-\u001a\u0004\u0018\u00010\u001cH\u0002J\u0016\u00101\u001a\u00020\u00112\f\u0010,\u001a\b\u0012\u0004\u0012\u00020\u001f0&H\u0002R\u000e\u0010\u0010\u001a\u00020\u0011X\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010\u0002\u001a\u00020\u0003X\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u0012\u001a\b\u0012\u0004\u0012\u00020\u00140\u0013X\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\f\u001a\b\u0012\u0004\u0012\u00020\u000e0\rX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0015\u001a\u00020\u0016X\u0082\u000e¢\u0006\u0002\n��R\u0010\u0010\b\u001a\u0004\u0018\u00010\tX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0017\u001a\u00020\u0018X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0019\u001a\u00020\u001aX\u0082\u0004¢\u0006\u0002\n��¨\u00062"}, d2 = {"Lcom/kotlinnlp/hanclassifier/helpers/Trainer;", "", "classifier", "Lcom/kotlinnlp/hanclassifier/HANClassifier;", "tokensEncoder", "Lcom/kotlinnlp/tokensencoder/TokensEncoder;", "Lcom/kotlinnlp/linguisticdescription/sentence/token/FormToken;", "Lcom/kotlinnlp/linguisticdescription/sentence/Sentence;", "tokensEncoderOptimizer", "Lcom/kotlinnlp/tokensencoder/TokensEncoderOptimizer;", "updateMethod", "Lcom/kotlinnlp/simplednn/core/functionalities/updatemethods/UpdateMethod;", "onSaveModel", "Lkotlin/Function0;", "", "(Lcom/kotlinnlp/hanclassifier/HANClassifier;Lcom/kotlinnlp/tokensencoder/TokensEncoder;Lcom/kotlinnlp/tokensencoder/TokensEncoderOptimizer;Lcom/kotlinnlp/simplednn/core/functionalities/updatemethods/UpdateMethod;Lkotlin/jvm/functions/Function0;)V", "bestAccuracy", "", "classifierOptimizer", "Lcom/kotlinnlp/simplednn/core/optimizer/ParamsOptimizer;", "Lcom/kotlinnlp/simplednn/deeplearning/attention/han/HANParameters;", "startTime", "", "tokensEncodersPool", "Lcom/kotlinnlp/hanclassifier/helpers/TokensEncodersPool;", "validationHelper", "Lcom/kotlinnlp/hanclassifier/helpers/Validator;", "formatElapsedTime", "", "learnFromExample", "example", 
"Lcom/kotlinnlp/hanclassifier/dataset/Example;", "newBatch", "newEpoch", "newExample", "startTiming", "train", "trainingSet", "", "epochs", "", "batchSize", "shuffler", "Lcom/kotlinnlp/simplednn/dataset/Shuffler;", "validationSet", "modelFilename", "trainEpoch", "update", "validateAndSaveModel", "validateEpoch", "hanclassifier"})
/* loaded from: input_file:com/kotlinnlp/hanclassifier/helpers/Trainer.class */
/**
 * Trains a {@link HANClassifier} with mini-batch updates and, when a {@link TokensEncoderOptimizer}
 * is given, also trains the tokens encoder that feeds it.
 *
 * <p>After each epoch the classifier can be evaluated on a validation set. When the accuracy improves
 * on the best one seen so far, the model is serialized to a file and a callback is invoked.
 *
 * <p>This class is a Java rendering of the Kotlin {@code Trainer.kt}. The synthetic
 * {@code train$default} method and the {@link DefaultConstructorMarker} constructor emulate Kotlin
 * default arguments and are kept so that existing compiled callers keep working.
 *
 * <p>Instances hold mutable training state ({@code startTime}, {@code bestAccuracy}) and perform no
 * synchronization.
 */
public final class Trainer {
    /** Start of the current epoch as a {@link System#nanoTime()} reading; set by {@link #startTiming()}. */
    private long startTime;

    /**
     * Highest validation accuracy for which a model has been saved. It starts at 0.0 (field default),
     * so a model is saved only once an epoch reaches an accuracy strictly greater than 0.
     */
    private double bestAccuracy;

    /** Evaluates the classifier on the validation set at the end of each epoch. */
    private final Validator validationHelper;

    /** Supplies the tokens encoders used to encode the sentences of each training example. */
    private final TokensEncodersPool tokensEncodersPool;

    /** Accumulates the HAN parameter errors and applies them in {@link #update()}. */
    private final ParamsOptimizer<HANParameters> classifierOptimizer;

    /** The classifier being trained. */
    private final HANClassifier classifier;

    /** Optimizer of the tokens encoder parameters; when {@code null} the tokens encoder is not trained. */
    private final TokensEncoderOptimizer tokensEncoderOptimizer;

    /** Callback invoked each time a new best model has been written to disk. */
    private final Function0<Unit> onSaveModel;

    /**
     * Trains the classifier for the given number of epochs.
     *
     * <p>Each epoch prints its number and its elapsed time. If a validation set is given, the
     * classifier is then validated. If the model filename is also given and the accuracy beats the
     * best one so far, the model is saved to that file.
     *
     * @param list the training set
     * @param i the number of epochs (a value {@code <= 0} trains nothing)
     * @param i2 the number of examples per mini-batch; must be at least 1
     * @param shuffler the shuffler handed to {@code ExamplesIndices} to order the examples of each epoch
     * @param list2 the validation set, or {@code null} to skip validation
     * @param str the path of the file to which the best model is saved, or {@code null} to never save
     * @throws IllegalArgumentException if {@code i2} is less than 1
     * @throws IOException if the model file cannot be written
     */
    public final void train(@NotNull List<Example> list, int i, int i2, @Nullable Shuffler shuffler, @Nullable List<Example> list2, @Nullable String str) throws IOException {
        Intrinsics.checkParameterIsNotNull(list, "trainingSet");
        // Fail fast: a zero batch size would otherwise surface as an ArithmeticException (modulo by
        // zero) in trainEpoch, after the epoch header has already been printed.
        if (i2 < 1) {
            throw new IllegalArgumentException("batchSize must be at least 1, got " + i2);
        }
        for (int epoch = 0; epoch < i; epoch++) {
            System.out.println("\nEpoch " + (epoch + 1) + " of " + i);
            startTiming();
            newEpoch();
            trainEpoch(list, i2, shuffler);
            System.out.println(String.format("Elapsed time: %s", formatElapsedTime()));
            if (list2 != null) {
                validateAndSaveModel(list2, str);
            }
        }
    }

    /**
     * Synthetic helper emulating the Kotlin default arguments of {@link #train}.
     *
     * <p>Each set bit of the mask {@code i3} replaces one argument with its default:
     * <ul>
     *   <li>4: batch size 1</li>
     *   <li>8: {@code new Shuffler(true, 743L)}</li>
     *   <li>16: no validation set</li>
     *   <li>32: no model filename</li>
     * </ul>
     * {@code obj} is unused.
     */
    public static /* bridge */ /* synthetic */ void train$default(Trainer trainer, List list, int i, int i2, Shuffler shuffler, List list2, String str, int i3, Object obj) throws IOException {
        if ((i3 & 4) != 0) {
            i2 = 1;
        }
        if ((i3 & 8) != 0) {
            shuffler = new Shuffler(true, 743L);
        }
        if ((i3 & 16) != 0) {
            list2 = null;
        }
        if ((i3 & 32) != 0) {
            str = null;
        }
        trainer.train(list, i, i2, shuffler, list2, str);
    }

    /**
     * Runs one epoch over the training set, in the order given by {@code ExamplesIndices}.
     *
     * <p>A new batch is started every {@code i} examples. The optimizers are updated at the end of
     * each full batch and after the last example, so a final partial batch is not lost.
     *
     * @param list the training set
     * @param i the batch size (validated to be at least 1 by {@link #train})
     * @param shuffler the shuffler passed to {@code ExamplesIndices}
     */
    private final void trainEpoch(List<Example> list, int i, Shuffler shuffler) {
        int examplesCount = list.size();
        // Synthetic Kotlin default-arguments constructor; mask 6 presumably selects the defaults of
        // the output-stream and int parameters.
        ProgressIndicatorBar progress = new ProgressIndicatorBar(examplesCount, (OutputStream) null, 0, 6, (DefaultConstructorMarker) null);
        int seen = 0;
        Iterator<?> indices = new ExamplesIndices(examplesCount, shuffler).iterator();
        while (indices.hasNext()) {
            int exampleIndex = ((Number) indices.next()).intValue();
            seen++;
            ProgressIndicator.tick$default(progress, 0, 1, (Object) null);
            if ((seen - 1) % i == 0) {
                newBatch();
            }
            newExample();
            learnFromExample(list.get(exampleIndex));
            if (seen % i == 0 || seen == examplesCount) {
                update();
            }
        }
    }

    /**
     * Runs a forward and a backward pass on one example and accumulates the parameter errors into
     * the optimizers. No parameters are changed here; that happens in {@link #update()}.
     *
     * <p>The steps are:
     * <ol>
     *   <li>Each sentence is encoded by its own encoder taken from the pool.</li>
     *   <li>The classifier runs on the encoded sentences.</li>
     *   <li>The output errors are computed as the output with 1 subtracted at the gold index.</li>
     *   <li>When a tokens encoder optimizer is present, the classifier input errors are propagated
     *       back through the matching encoders.</li>
     * </ol>
     */
    private final void learnFromExample(Example example) {
        List<Sentence<FormToken>> sentences = example.getSentences();

        // One encoder per sentence, so backward() can later be called on the same encoder that
        // encoded that sentence.
        // NOTE(review): encoders obtained with getItem() are never given back to the pool in this
        // class. Confirm that TokensEncodersPool recycles them; otherwise it grows with every example.
        List<TokensEncoder> encoders = new ArrayList<TokensEncoder>(sentences.size());
        for (int k = 0; k < sentences.size(); k++) {
            encoders.add(this.tokensEncodersPool.getItem());
        }

        // Pair sentences with encoders (stopping at the shorter list, like Kotlin's zip).
        List<EncodedSentence> encodedSentences = new ArrayList<EncodedSentence>(Math.min(sentences.size(), encoders.size()));
        Iterator<Sentence<FormToken>> sentenceIt = sentences.iterator();
        Iterator<TokensEncoder> encoderIt = encoders.iterator();
        while (sentenceIt.hasNext() && encoderIt.hasNext()) {
            TokensEncoder encoder = encoderIt.next();
            encodedSentences.add(new EncodedSentence((List) encoder.forward(sentenceIt.next())));
        }

        // copy() so the classifier's own output array is not modified.
        DenseNDArray errors = this.classifier.forward(encodedSentences).copy();
        int gold = example.getOutputGold();
        // Output minus one-hot(gold). This presumably assumes a softmax output trained with
        // cross-entropy; confirm against HANClassifier.
        errors.set(gold, Double.valueOf(errors.get(gold).doubleValue() - 1));

        this.classifier.backward(errors);
        // m1getParamsErrors / m0getInputErrors are decompiler-renamed accessors; the `false`
        // argument presumably means "do not copy". Confirm against HANClassifier.
        Optimizer.accumulate$default(this.classifierOptimizer, this.classifier.m1getParamsErrors(false), false, 2, (Object) null);

        TokensEncoderOptimizer encoderOptimizer = this.tokensEncoderOptimizer;
        if (encoderOptimizer == null) {
            return; // The tokens encoder is not being trained.
        }
        List<EncodedSentence> inputErrors = this.classifier.m0getInputErrors(false);
        Iterator<EncodedSentence> inputErrorsIt = inputErrors.iterator();
        Iterator<TokensEncoder> backwardEncoderIt = encoders.iterator();
        while (inputErrorsIt.hasNext() && backwardEncoderIt.hasNext()) {
            EncodedSentence sentenceErrors = inputErrorsIt.next();
            TokensEncoder encoder = backwardEncoderIt.next();
            encoder.backward(sentenceErrors.getTokens());
            Optimizer.accumulate$default(encoderOptimizer, NeuralProcessor.DefaultImpls.getParamsErrors$default(encoder, false, 1, (Object) null), false, 2, (Object) null);
        }
    }

    /** Notifies the optimizers (the tokens encoder one only if present) that a new epoch starts. */
    private final void newEpoch() {
        this.classifierOptimizer.newEpoch();
        TokensEncoderOptimizer encoderOptimizer = this.tokensEncoderOptimizer;
        if (encoderOptimizer != null) {
            encoderOptimizer.newEpoch();
        }
    }

    /** Notifies the optimizers (the tokens encoder one only if present) that a new batch starts. */
    private final void newBatch() {
        this.classifierOptimizer.newBatch();
        TokensEncoderOptimizer encoderOptimizer = this.tokensEncoderOptimizer;
        if (encoderOptimizer != null) {
            encoderOptimizer.newBatch();
        }
    }

    /** Notifies the optimizers (the tokens encoder one only if present) that a new example starts. */
    private final void newExample() {
        this.classifierOptimizer.newExample();
        TokensEncoderOptimizer encoderOptimizer = this.tokensEncoderOptimizer;
        if (encoderOptimizer != null) {
            encoderOptimizer.newExample();
        }
    }

    /** Applies the accumulated errors through the optimizers (the tokens encoder one only if present). */
    private final void update() {
        this.classifierOptimizer.update();
        TokensEncoderOptimizer encoderOptimizer = this.tokensEncoderOptimizer;
        if (encoderOptimizer != null) {
            encoderOptimizer.update();
        }
    }

    /**
     * Validates the classifier and, if the accuracy is a new best, saves the model.
     *
     * <p>The model is saved only when {@code str} is not {@code null} and the accuracy is strictly
     * greater than {@link #bestAccuracy}. After a successful save, {@link #onSaveModel} is invoked.
     *
     * @param list the validation set
     * @param str the path of the model file, or {@code null} to never save
     * @throws IOException if the model file cannot be opened or written
     */
    private final void validateAndSaveModel(List<Example> list, String str) throws IOException {
        double accuracy = validateEpoch(list);
        System.out.println(String.format("Accuracy: %.2f%%", 100.0d * accuracy));
        if (str == null || accuracy <= this.bestAccuracy) {
            return;
        }
        this.bestAccuracy = accuracy;
        // try-with-resources: the stream was previously never closed, leaking a file handle on every
        // new best model. Closing twice is harmless if dump() already closes it.
        try (FileOutputStream out = new FileOutputStream(new File(str))) {
            this.classifier.getModel().dump(out);
        }
        System.out.println("NEW BEST ACCURACY! Model saved to \"" + str + '"');
        this.onSaveModel.invoke();
    }

    /**
     * Prints the size of the validation set and returns the accuracy computed by the validator.
     *
     * @param list the validation set
     * @return the value returned by {@code Validator.validate}, printed as a percentage by the caller
     */
    private final double validateEpoch(List<Example> list) {
        System.out.println(String.format("Epoch validation on %d sentences", list.size()));
        return this.validationHelper.validate(list);
    }

    /** Records the start of the current epoch using the monotonic {@link System#nanoTime()} clock. */
    private final void startTiming() {
        this.startTime = System.nanoTime();
    }

    /**
     * Formats the time elapsed since {@link #startTiming()} as {@code "<seconds> s (<minutes> min)"}.
     *
     * <p>{@code nanoTime} is used instead of {@code currentTimeMillis} because wall-clock time can
     * jump (NTP, manual changes) and distort the measured duration.
     */
    private final String formatElapsedTime() {
        double seconds = (System.nanoTime() - this.startTime) / 1.0e9d;
        return String.format("%.3f s (%.1f min)", seconds, seconds / 60.0d);
    }

    /**
     * Creates a trainer.
     *
     * @param hANClassifier the classifier to train
     * @param tokensEncoder the tokens encoder whose model is given to the validator and to the pool
     *     of training encoders (its dropout flag is forwarded to the pool)
     * @param tokensEncoderOptimizer the optimizer of the tokens encoder, or {@code null} to not train it
     * @param updateMethod the update method used to optimize the HAN parameters
     * @param function0 callback invoked after each new best model is saved
     */
    public Trainer(@NotNull HANClassifier hANClassifier, @NotNull TokensEncoder<FormToken, Sentence<FormToken>> tokensEncoder, @Nullable TokensEncoderOptimizer tokensEncoderOptimizer, @NotNull UpdateMethod<?> updateMethod, @NotNull Function0<Unit> function0) {
        Intrinsics.checkParameterIsNotNull(hANClassifier, "classifier");
        Intrinsics.checkParameterIsNotNull(tokensEncoder, "tokensEncoder");
        Intrinsics.checkParameterIsNotNull(updateMethod, "updateMethod");
        Intrinsics.checkParameterIsNotNull(function0, "onSaveModel");
        this.classifier = hANClassifier;
        this.tokensEncoderOptimizer = tokensEncoderOptimizer;
        this.onSaveModel = function0;
        this.validationHelper = new Validator(this.classifier.getModel(), tokensEncoder.getModel());
        this.tokensEncodersPool = new TokensEncodersPool(tokensEncoder.getModel(), tokensEncoder.getUseDropout());
        this.classifierOptimizer = new ParamsOptimizer<>(this.classifier.getModel().getHan().getParams(), updateMethod);
    }

    /**
     * Synthetic constructor emulating Kotlin default arguments.
     *
     * <p>When {@code (i & 4) != 0} the tokens encoder optimizer defaults to {@code null}. When
     * {@code (i & 16) != 0} the save callback defaults to a no-op.
     */
    public /* synthetic */ Trainer(HANClassifier hANClassifier, TokensEncoder tokensEncoder, TokensEncoderOptimizer tokensEncoderOptimizer, UpdateMethod updateMethod, Function0 function0, int i, DefaultConstructorMarker defaultConstructorMarker) {
        this(hANClassifier, tokensEncoder, (i & 4) != 0 ? (TokensEncoderOptimizer) null : tokensEncoderOptimizer, updateMethod, (i & 16) != 0 ? new Function0<Unit>() { // from class: com.kotlinnlp.hanclassifier.helpers.Trainer.1
            /** No-op default for the {@code onSaveModel} callback. */
            public Unit invoke() {
                return Unit.INSTANCE;
            }
        } : function0);
    }
}
