package com.kotlinnlp.frameextractor.helpers;

import com.kotlinnlp.frameextractor.FrameExtractor;
import com.kotlinnlp.frameextractor.FrameExtractorModel;
import com.kotlinnlp.frameextractor.FrameExtractorParameters;
import com.kotlinnlp.frameextractor.helpers.dataset.Dataset;
import com.kotlinnlp.frameextractor.helpers.dataset.EncodedDataset;
import com.kotlinnlp.frameextractor.helpers.dataset.IOBTag;
import com.kotlinnlp.frameextractor.objects.Intent;
import com.kotlinnlp.simplednn.core.functionalities.updatemethods.UpdateMethod;
import com.kotlinnlp.simplednn.core.functionalities.updatemethods.adam.ADAMMethod;
import com.kotlinnlp.simplednn.core.optimizer.ParamsOptimizer;
import com.kotlinnlp.simplednn.dataset.Shuffler;
import com.kotlinnlp.simplednn.helpers.training.utils.ExamplesIndices;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArrayFactory;
import com.kotlinnlp.simplednn.utils.scheduling.BatchScheduling;
import com.kotlinnlp.simplednn.utils.scheduling.EpochScheduling;
import com.kotlinnlp.utils.Timer;
import com.kotlinnlp.utils.progressindicator.ProgressIndicator;
import com.kotlinnlp.utils.progressindicator.ProgressIndicatorBar;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin.collections.CollectionsKt;
import kotlin.collections.IntIterator;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import kotlin.ranges.RangesKt;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: Trainer.kt */
@Metadata(mv = {1, 1, 10}, bv = {1, 0, 2}, k = 1, d1 = {"��t\n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000e\n��\n\u0002\u0010\b\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000b\n\u0002\b\u0002\n\u0002\u0010\u0006\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0002\n\u0002\b\b\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0003\u0018��2\u00020\u0001B=\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\u0006\u0010\u0004\u001a\u00020\u0005\u0012\u0006\u0010\u0006\u001a\u00020\u0007\u0012\f\b\u0002\u0010\b\u001a\u0006\u0012\u0002\b\u00030\t\u0012\u0006\u0010\n\u001a\u00020\u000b\u0012\b\b\u0002\u0010\f\u001a\u00020\r¢\u0006\u0002\u0010\u000eJ\b\u0010\u0019\u001a\u00020\u001aH\u0002J\u0010\u0010\u001b\u001a\u00020\u001a2\u0006\u0010\u001c\u001a\u00020\u0007H\u0002J\b\u0010\u001d\u001a\u00020\u001aH\u0002J\b\u0010\u001e\u001a\u00020\u001aH\u0002J\b\u0010\u001f\u001a\u00020\u001aH\u0002J\b\u0010 
\u001a\u00020\u001aH\u0002J\u001a\u0010!\u001a\u00020\u001a2\u0006\u0010\"\u001a\u00020#2\n\b\u0002\u0010$\u001a\u0004\u0018\u00010%J\u001a\u0010&\u001a\u00020\u001a2\u0006\u0010\"\u001a\u00020#2\b\u0010$\u001a\u0004\u0018\u00010%H\u0002J(\u0010'\u001a\u00020\u001a2\u0006\u0010(\u001a\u00020)2\u0006\u0010*\u001a\u00020\u00072\u0006\u0010+\u001a\u00020,2\u0006\u0010-\u001a\u00020\u0007H\u0002J\b\u0010.\u001a\u00020\u001aH\u0002R\u000e\u0010\u000f\u001a\u00020\u0010X\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010\u0011\u001a\u00020\u0007X\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010\u0006\u001a\u00020\u0007X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0012\u001a\u00020\u0013X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0002\u001a\u00020\u0003X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0004\u001a\u00020\u0005X\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u0014\u001a\b\u0012\u0004\u0012\u00020\u00160\u0015X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0017\u001a\u00020\u0018X\u0082\u000e¢\u0006\u0002\n��R\u0012\u0010\b\u001a\u0006\u0012\u0002\b\u00030\tX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\n\u001a\u00020\u000bX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\f\u001a\u00020\rX\u0082\u0004¢\u0006\u0002\n��¨\u0006/"}, d2 = {"Lcom/kotlinnlp/frameextractor/helpers/Trainer;", "", "model", "Lcom/kotlinnlp/frameextractor/FrameExtractorModel;", "modelFilename", "", "epochs", "", "updateMethod", "Lcom/kotlinnlp/simplednn/core/functionalities/updatemethods/UpdateMethod;", "validator", "Lcom/kotlinnlp/frameextractor/helpers/Validator;", "verbose", "", "(Lcom/kotlinnlp/frameextractor/FrameExtractorModel;Ljava/lang/String;ILcom/kotlinnlp/simplednn/core/functionalities/updatemethods/UpdateMethod;Lcom/kotlinnlp/frameextractor/helpers/Validator;Z)V", "bestAccuracy", "", "epochCount", "extractor", "Lcom/kotlinnlp/frameextractor/FrameExtractor;", "optimizer", "Lcom/kotlinnlp/simplednn/core/optimizer/ParamsOptimizer;", "Lcom/kotlinnlp/frameextractor/FrameExtractorParameters;", "timer", 
"Lcom/kotlinnlp/utils/Timer;", "logTrainingEnd", "", "logTrainingStart", "epochIndex", "logValidationEnd", "logValidationStart", "newBatch", "newEpoch", "train", "dataset", "Lcom/kotlinnlp/frameextractor/helpers/dataset/EncodedDataset;", "shuffler", "Lcom/kotlinnlp/simplednn/dataset/Shuffler;", "trainEpoch", "trainExample", "example", "Lcom/kotlinnlp/frameextractor/helpers/dataset/EncodedDataset$Example;", "intentIndex", "intentConfig", "Lcom/kotlinnlp/frameextractor/objects/Intent$Configuration;", "slotsOffset", "validateAndSaveModel", "frameextractor"})
/* loaded from: input_file:com/kotlinnlp/frameextractor/helpers/Trainer.class */
public final class Trainer {
    private int epochCount;
    private Timer timer;
    private double bestAccuracy;
    private final FrameExtractor extractor;
    private final ParamsOptimizer<FrameExtractorParameters> optimizer;
    private final FrameExtractorModel model;
    private final String modelFilename;
    private final int epochs;
    private final UpdateMethod<?> updateMethod;
    private final Validator validator;
    private final boolean verbose;

    public final void train(@NotNull EncodedDataset dataset, @Nullable Shuffler shuffler) {
        Intrinsics.checkParameterIsNotNull(dataset, "dataset");
        Iterator<Integer> it = RangesKt.until(0, this.epochs).iterator();
        while (it.hasNext()) {
            logTrainingStart(((IntIterator) it).nextInt());
            newEpoch();
            trainEpoch(dataset, shuffler);
            logTrainingEnd();
            logValidationStart();
            validateAndSaveModel();
            logValidationEnd();
        }
    }

    public static /* bridge */ /* synthetic */ void train$default(Trainer trainer, EncodedDataset encodedDataset, Shuffler shuffler, int i, Object obj) {
        if ((i & 2) != 0) {
            shuffler = new Shuffler(true, 743L);
        }
        trainer.train(encodedDataset, shuffler);
    }

    private final void trainEpoch(EncodedDataset encodedDataset, Shuffler shuffler) {
        int i;
        ProgressIndicatorBar progressIndicatorBar = new ProgressIndicatorBar(encodedDataset.getExamples().size(), null, 0, 6, null);
        Iterator<Integer> it = new ExamplesIndices(encodedDataset.getExamples().size(), shuffler).iterator();
        while (it.hasNext()) {
            EncodedDataset.Example example = encodedDataset.getExamples().get(it.next().intValue());
            int i2 = 0;
            Iterator<Intent.Configuration> it2 = encodedDataset.getConfiguration().iterator();
            while (true) {
                if (!it2.hasNext()) {
                    i = -1;
                    break;
                } else {
                    if (Intrinsics.areEqual(it2.next().getName(), example.getIntent())) {
                        i = i2;
                        break;
                    }
                    i2++;
                }
            }
            int i3 = i;
            if (this.verbose) {
                ProgressIndicator.tick$default(progressIndicatorBar, 0, 1, null);
            }
            newBatch();
            trainExample(example, i3, encodedDataset.getConfiguration().get(i3), this.extractor.getSlotsOffset(example.getIntent()));
        }
    }

    private final void trainExample(EncodedDataset.Example example, int i, Intent.Configuration configuration, int i2) {
        FrameExtractor frameExtractor = this.extractor;
        List<EncodedDataset.Example.Token> tokens = example.getTokens();
        ArrayList arrayList = new ArrayList(CollectionsKt.collectionSizeOrDefault(tokens, 10));
        Iterator<T> it = tokens.iterator();
        while (it.hasNext()) {
            arrayList.add(((EncodedDataset.Example.Token) it.next()).getEncoding());
        }
        FrameExtractor.Output forward2 = frameExtractor.forward2((List<DenseNDArray>) arrayList);
        DenseNDArray sub = forward2.getIntentsDistribution().sub(DenseNDArrayFactory.INSTANCE.oneHotEncoder(forward2.getIntentsDistribution().getLength(), i));
        List<Pair> zip = CollectionsKt.zip(forward2.getSlotsClassifications(), example.getTokens());
        ArrayList arrayList2 = new ArrayList(CollectionsKt.collectionSizeOrDefault(zip, 10));
        for (Pair pair : zip) {
            DenseNDArray denseNDArray = (DenseNDArray) pair.component1();
            Dataset.Example.Slot slot = ((EncodedDataset.Example.Token) pair.component2()).getSlot();
            arrayList2.add(denseNDArray.sub(DenseNDArrayFactory.INSTANCE.oneHotEncoder(denseNDArray.getLength(), (2 * (i2 + configuration.getSlotIndex(slot.getName()))) + (slot.getIob() == IOBTag.Beginning ? 0 : 1))));
        }
        this.extractor.backward(new FrameExtractor.Output(this.extractor, sub, arrayList2));
        this.optimizer.accumulate((ParamsOptimizer<FrameExtractorParameters>) this.extractor.getParamsErrors2(false), false);
        this.optimizer.update();
    }

    private final void newBatch() {
        if (this.updateMethod instanceof BatchScheduling) {
            ((BatchScheduling) this.updateMethod).newBatch();
        }
    }

    private final void newEpoch() {
        if (this.updateMethod instanceof EpochScheduling) {
            ((EpochScheduling) this.updateMethod).newEpoch();
        }
        this.epochCount++;
    }

    private final void validateAndSaveModel() {
        Statistics evaluate = this.validator.evaluate();
        double f1Score = evaluate.getIntents().getF1Score() * evaluate.getSlots().getF1Score();
        System.out.println((Object) ("\nStatistics\n" + evaluate));
        if (f1Score > this.bestAccuracy) {
            this.model.dump(new FileOutputStream(new File(this.modelFilename)));
            System.out.println((Object) ("\nNEW BEST ACCURACY! Model saved to \"" + this.modelFilename + '\"'));
            this.bestAccuracy = f1Score;
        }
    }

    private final void logTrainingStart(int i) {
        if (this.verbose) {
            this.timer.reset();
            System.out.println((Object) ("\nEpoch " + (i + 1) + " of " + this.epochs));
            System.out.println((Object) "\nStart training...");
        }
    }

    private final void logTrainingEnd() {
        if (this.verbose) {
            Object[] objArr = {this.timer.formatElapsedTime()};
            String format = String.format("Elapsed time: %s", Arrays.copyOf(objArr, objArr.length));
            Intrinsics.checkExpressionValueIsNotNull(format, "java.lang.String.format(this, *args)");
            System.out.println((Object) format);
        }
    }

    private final void logValidationStart() {
        if (this.verbose) {
            this.timer.reset();
            System.out.println((Object) "\nStart validation...");
        }
    }

    private final void logValidationEnd() {
        if (this.verbose) {
            Object[] objArr = {this.timer.formatElapsedTime()};
            String format = String.format("Elapsed time: %s", Arrays.copyOf(objArr, objArr.length));
            Intrinsics.checkExpressionValueIsNotNull(format, "java.lang.String.format(this, *args)");
            System.out.println((Object) format);
        }
    }

    public Trainer(@NotNull FrameExtractorModel model, @NotNull String modelFilename, int i, @NotNull UpdateMethod<?> updateMethod, @NotNull Validator validator, boolean z) {
        Intrinsics.checkParameterIsNotNull(model, "model");
        Intrinsics.checkParameterIsNotNull(modelFilename, "modelFilename");
        Intrinsics.checkParameterIsNotNull(updateMethod, "updateMethod");
        Intrinsics.checkParameterIsNotNull(validator, "validator");
        this.model = model;
        this.modelFilename = modelFilename;
        this.epochs = i;
        this.updateMethod = updateMethod;
        this.validator = validator;
        this.verbose = z;
        this.timer = new Timer();
        this.bestAccuracy = -1.0d;
        this.extractor = new FrameExtractor(this.model, false, 0, 6, null);
        this.optimizer = new ParamsOptimizer<>(this.model.getParams(), this.updateMethod);
        if (!(this.epochs > 0)) {
            throw new IllegalArgumentException("The number of epochs must be > 0".toString());
        }
    }

    public /* synthetic */ Trainer(FrameExtractorModel frameExtractorModel, String str, int i, UpdateMethod updateMethod, Validator validator, boolean z, int i2, DefaultConstructorMarker defaultConstructorMarker) {
        this(frameExtractorModel, str, i, (i2 & 8) != 0 ? new ADAMMethod(0.001d, 0.9d, 0.999d, 0.0d, null, 24, null) : updateMethod, validator, (i2 & 32) != 0 ? true : z);
    }
}
