package com.kotlinnlp.neuraltokenizer.helpers;

import com.kotlinnlp.neuraltokenizer.NeuralTokenizer;
import com.kotlinnlp.neuraltokenizer.NeuralTokenizerOptimizer;
import com.kotlinnlp.neuraltokenizer.helpers.ValidationHelper;
import com.kotlinnlp.neuraltokenizer.utils.DatasetUtilsKt;
import com.kotlinnlp.simplednn.core.embeddings.Embedding;
import com.kotlinnlp.simplednn.core.embeddings.EmbeddingsMap;
import com.kotlinnlp.simplednn.core.neuralprocessor.batchfeedforward.BatchFeedforwardProcessor;
import com.kotlinnlp.simplednn.core.optimizer.Optimizer;
import com.kotlinnlp.simplednn.dataset.Shuffler;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import com.kotlinnlp.utils.progressindicator.ProgressIndicatorBar;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin.Unit;
import kotlin.collections.IntIterator;
import kotlin.jvm.functions.Function1;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import kotlin.jvm.internal.Ref;
import kotlin.ranges.RangesKt;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: TrainingHelper.kt */
@Metadata(mv = {1, 1, 10}, bv = {1, 0, 2}, k = 1, d1 = {"��\u0082\u0001\n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u0006\n\u0002\b\u0003\n\u0002\u0010 \n\u0002\u0010\b\n��\n\u0002\u0010\t\n\u0002\b\u0003\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0002\n��\n\u0002\u0010\r\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0005\n\u0002\u0010\u000e\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\n\n\u0002\u0018\u0002\n\u0002\b\u0005\u0018��2\u00020\u0001B\u0017\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\b\b\u0002\u0010\u0004\u001a\u00020\u0005¢\u0006\u0002\u0010\u0006J\u0010\u0010\u0014\u001a\u00020\u00152\u0006\u0010\u0016\u001a\u00020\u0017H\u0002J\u0016\u0010\u0018\u001a\u00020\u00152\f\u0010\u0019\u001a\b\u0012\u0004\u0012\u00020\u001a0\fH\u0002J$\u0010\u001b\u001a\u00020\u00152\f\u0010\u0019\u001a\b\u0012\u0004\u0012\u00020\u001a0\f2\f\u0010\u001c\u001a\b\u0012\u0004\u0012\u00020\r0\fH\u0002J\b\u0010\u001d\u001a\u00020\u0015H\u0002JB\u0010\u001e\u001a\u00020\u00152\u0006\u0010\u001f\u001a\u00020 2\u0010\u0010!\u001a\f\u0012\u0004\u0012\u00020\r0\fj\u0002`\"2\u001e\u0010#\u001a\u001a\u0012\u0010\u0012\u000e\u0012\u0004\u0012\u00020\r\u0012\u0004\u0012\u00020\r0%\u0012\u0004\u0012\u00020\u00150$H\u0002J\b\u0010&\u001a\u00020 H\u0002J\u0010\u0010'\u001a\u00020\b2\u0006\u0010(\u001a\u00020)H\u0002J\b\u0010*\u001a\u00020\rH\u0002J\b\u0010+\u001a\u00020\rH\u0002J4\u0010,\u001a\u00020\u00152*\u0010-\u001a&\u0012\u001e\u0012\u001c\u0012\u0004\u0012\u00020 \u0012\u000e\u0012\f\u0012\u0004\u0012\u00020\r0\fj\u0002`\"0%j\u0002`.0\fj\u0002`/H\u0002J \u00100\u001a\u00020\u00152\u0006\u0010\u001f\u001a\u00020 
2\u0006\u00101\u001a\u00020\r2\u0006\u00102\u001a\u00020\rH\u0002J\b\u00103\u001a\u00020\u0015H\u0002J\b\u00104\u001a\u00020\u0015H\u0002J\u0090\u0001\u00105\u001a\u00020\u00152*\u0010-\u001a&\u0012\u001e\u0012\u001c\u0012\u0004\u0012\u00020 \u0012\u000e\u0012\f\u0012\u0004\u0012\u00020\r0\fj\u0002`\"0%j\u0002`.0\fj\u0002`/2\b\b\u0002\u00106\u001a\u00020\r2\b\b\u0002\u00107\u001a\u00020\r20\b\u0002\u00108\u001a*\u0012\u001e\u0012\u001c\u0012\u0004\u0012\u00020 \u0012\u000e\u0012\f\u0012\u0004\u0012\u00020\r0\fj\u0002`\"0%j\u0002`.\u0018\u00010\fj\u0004\u0018\u0001`/2\n\b\u0002\u00109\u001a\u0004\u0018\u00010:2\n\b\u0002\u0010;\u001a\u0004\u0018\u00010 J*\u0010<\u001a\u00020\u00152\u0006\u0010\u001f\u001a\u00020 2\u0010\u0010!\u001a\f\u0012\u0004\u0012\u00020\r0\fj\u0002`\"2\u0006\u00106\u001a\u00020\rH\u0002J>\u0010=\u001a\u00020\u00152*\u00108\u001a&\u0012\u001e\u0012\u001c\u0012\u0004\u0012\u00020 \u0012\u000e\u0012\f\u0012\u0004\u0012\u00020\r0\fj\u0002`\"0%j\u0002`.0\fj\u0002`/2\b\u0010;\u001a\u0004\u0018\u00010 H\u0002J4\u0010>\u001a\u00020\b2*\u00108\u001a&\u0012\u001e\u0012\u001c\u0012\u0004\u0012\u00020 \u0012\u000e\u0012\f\u0012\u0004\u0012\u00020\r0\fj\u0002`\"0%j\u0002`.0\fj\u0002`/H\u0002R\u000e\u0010\u0007\u001a\u00020\bX\u0082\u000e¢\u0006\u0002\n��R\u0011\u0010\u0004\u001a\u00020\u0005¢\u0006\b\n��\u001a\u0004\b\t\u0010\nR\u0014\u0010\u000b\u001a\b\u0012\u0004\u0012\u00020\r0\fX\u0082.¢\u0006\u0002\n��R\u000e\u0010\u000e\u001a\u00020\u000fX\u0082\u000e¢\u0006\u0002\n��R\u0011\u0010\u0002\u001a\u00020\u0003¢\u0006\b\n��\u001a\u0004\b\u0010\u0010\u0011R\u000e\u0010\u0012\u001a\u00020\u0013X\u0082\u0004¢\u0006\u0002\n��¨\u0006?"}, d2 = {"Lcom/kotlinnlp/neuraltokenizer/helpers/TrainingHelper;", "", "tokenizer", "Lcom/kotlinnlp/neuraltokenizer/NeuralTokenizer;", "optimizer", "Lcom/kotlinnlp/neuraltokenizer/NeuralTokenizerOptimizer;", "(Lcom/kotlinnlp/neuraltokenizer/NeuralTokenizer;Lcom/kotlinnlp/neuraltokenizer/NeuralTokenizerOptimizer;)V", 
"bestAccuracy", "", "getOptimizer", "()Lcom/kotlinnlp/neuraltokenizer/NeuralTokenizerOptimizer;", "segmentGoldClassification", "", "", "startTime", "", "getTokenizer", "()Lcom/kotlinnlp/neuraltokenizer/NeuralTokenizer;", "validationHelper", "Lcom/kotlinnlp/neuraltokenizer/helpers/ValidationHelper;", "accumulateErrors", "", "segment", "", "backward", "segmentClassification", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "backwardBoundariesClassifier", "goldSegmentClassification", "endOfBatch", "forEachSegment", "text", "", "goldClassifications", "Lcom/kotlinnlp/neuraltokenizer/utils/CharsClassification;", "callback", "Lkotlin/Function1;", "Lkotlin/Pair;", "formatElapsedTime", "getAccuracy", "stats", "Lcom/kotlinnlp/neuraltokenizer/helpers/ValidationHelper$EvaluationStats;", "getMiddleTokenBoundary", "getShiftCharIndex", "initEmbeddings", "trainingSet", "Lcom/kotlinnlp/neuraltokenizer/utils/AnnotatedSentence;", "Lcom/kotlinnlp/neuraltokenizer/utils/Dataset;", "learnFromExample", "start", "length", "resetValidationStats", "startTiming", "train", "batchSize", "epochs", "validationSet", "shuffler", "Lcom/kotlinnlp/simplednn/dataset/Shuffler;", "modelFilename", "trainEpoch", "validateAndSaveModel", "validateEpoch", "neuraltokenizer"})
/* loaded from: input_file:com/kotlinnlp/neuraltokenizer/helpers/TrainingHelper.class */
public final class TrainingHelper {
    /*
     * Trains a NeuralTokenizer: each epoch merges the (optionally shuffled) training set into one
     * text, walks it segment by segment, backpropagates the per-char classification errors and
     * lets the NeuralTokenizerOptimizer update the parameters once per batch of segments.
     * Optionally validates after every epoch and dumps the best model to a file.
     */

    /** Wall-clock time (ms) at which the current epoch started; set by {@link #startTiming()}. */
    private long startTime;

    /** Evaluates the tokenizer model on the validation set after each epoch. */
    private final ValidationHelper validationHelper;

    /** Best validation accuracy reached so far; reset at the start of every {@link #train} call. */
    private double bestAccuracy;

    /**
     * Gold classifications of the segment currently being processed (a sub-list view of the
     * epoch's merged gold classifications). Assigned by {@link #forEachSegment}; reading it before
     * the first assignment raises an uninitialized-property error, Kotlin {@code lateinit} style.
     */
    private List<Integer> segmentGoldClassification;

    /** The tokenizer being trained. */
    @NotNull
    private final NeuralTokenizer tokenizer;

    /** The optimizer that updates the tokenizer parameters and the char embeddings. */
    @NotNull
    private final NeuralTokenizerOptimizer optimizer;

    /**
     * Trains the tokenizer.
     *
     * <p>Before the first epoch, an embedding is created for every char of the training set that
     * does not have one yet. Each epoch then shuffles the training set (when a shuffler is given),
     * merges it into a single text with aligned gold classifications and trains on it segment by
     * segment. When a validation set is given the model is evaluated after every epoch and, if a
     * model file name is also given, dumped to that file whenever the accuracy improves.
     *
     * @param list the training set: pairs of (text, per-char gold classifications)
     * @param i the batch size, i.e. the number of segments learned between two optimizer updates
     * @param i2 the number of epochs (no epoch is run if it is not positive)
     * @param list2 the optional validation set
     * @param shuffler the optional shuffler applied to the training set at each epoch
     * @param str the optional path of the file to which the best model is saved
     * @throws IllegalArgumentException if the batch size {@code i} is not positive
     * @throws UncheckedIOException if the best model cannot be written to {@code str}
     */
    public final void train(@NotNull List<? extends Pair<String, ? extends List<Integer>>> list, int i, int i2, @Nullable List<? extends Pair<String, ? extends List<Integer>>> list2, @Nullable Shuffler shuffler, @Nullable String str) {
        Intrinsics.checkParameterIsNotNull(list, "trainingSet");
        if (i <= 0) {
            // Fail fast: a zero batch size would otherwise crash with "/ by zero" mid-epoch.
            throw new IllegalArgumentException("The batch size must be positive, got " + i);
        }

        System.out.println(String.format("-- START TRAINING ON %d SENTENCES", list.size()));
        resetValidationStats();
        initEmbeddings(list);

        for (int epoch = 0; epoch < i2; epoch++) {
            System.out.println("\nEpoch " + (epoch + 1) + " of " + i2);
            startTiming();

            Pair<String, List<Integer>> merged = DatasetUtilsKt.mergeDataset(
                shuffler != null ? DatasetUtilsKt.shuffleDataset(list, shuffler) : list);
            trainEpoch(merged.getFirst(), merged.getSecond(), i);

            System.out.println(String.format("Elapsed time: %s", formatElapsedTime()));

            if (list2 != null) {
                validateAndSaveModel(list2, str);
            }
        }
    }

    /**
     * Kotlin default-arguments bridge for {@link #train}. Each bit of {@code i3} that is set
     * replaces the matching argument with its default: batch size 1 (bit 2), 3 epochs (bit 4),
     * and {@code null} validation set, shuffler and model file name (bits 8, 16, 32).
     */
    public static /* bridge */ /* synthetic */ void train$default(TrainingHelper trainingHelper, List list, int i, int i2, List list2, Shuffler shuffler, String str, int i3, Object obj) {
        if ((i3 & 2) != 0) {
            i = 1;
        }
        if ((i3 & 4) != 0) {
            i2 = 3;
        }
        if ((i3 & 8) != 0) {
            list2 = (List) null;
        }
        if ((i3 & 16) != 0) {
            shuffler = (Shuffler) null;
        }
        if ((i3 & 32) != 0) {
            str = (String) null;
        }
        trainingHelper.train(list, i, i2, list2, shuffler, str);
    }

    /**
     * Adds an embedding (created through the map's default) for every char of the training texts
     * that the model's embeddings map does not contain yet.
     *
     * @param trainingSet pairs of (text, gold classifications); only the texts are read
     */
    private final void initEmbeddings(List<? extends Pair<String, ? extends List<Integer>>> trainingSet) {
        for (Pair<String, ? extends List<Integer>> example : trainingSet) {
            String text = example.getFirst();
            for (int k = 0; k < text.length(); k++) {
                Character c = Character.valueOf(text.charAt(k));
                if (!this.tokenizer.getModel().getEmbeddings().contains(c)) {
                    EmbeddingsMap.set$default(this.tokenizer.getModel().getEmbeddings(), c, (Embedding) null, 2, (Object) null);
                }
            }
        }
    }

    /**
     * Trains on one epoch: learns every segment of the merged text and updates the parameters
     * every {@code batchSize} segments, plus once more for a trailing partial batch.
     *
     * @param text the merged training text of the epoch
     * @param goldClassifications the gold classification of each char of {@code text}
     * @param batchSize the number of segments per batch (positive)
     * @throws IllegalArgumentException if {@code text} and {@code goldClassifications} differ in length
     */
    private final void trainEpoch(final String text, List<Integer> goldClassifications, final int batchSize) {
        if (text.length() != goldClassifications.size()) {
            throw new IllegalArgumentException("Failed requirement.");
        }

        final Ref.IntRef learnedSegments = new Ref.IntRef();
        this.optimizer.newEpoch();

        forEachSegment(text, goldClassifications, segment -> {
            int start = segment.getFirst();
            int end = segment.getSecond();
            learnedSegments.element++;
            learnFromExample(text, start, end - start);
            if (learnedSegments.element % batchSize == 0) {
                endOfBatch();
            }
            return Unit.INSTANCE;
        });

        // Flush the last, incomplete batch (if any).
        if (learnedSegments.element % batchSize > 0) {
            endOfBatch();
        }
    }

    /**
     * Splits {@code text} into consecutive, possibly overlapping segments of at most
     * {@code maxSegmentSize} chars and calls {@code callback} with the (start, end) char range
     * (end exclusive) of each one. Before each call, {@link #segmentGoldClassification} is set to
     * the matching gold sub-list. After each segment the window moves to the char after the
     * index chosen by {@link #getShiftCharIndex()}, so it always advances by at least one char.
     *
     * @param text the text to split
     * @param goldClassifications the gold classification of each char of {@code text}
     * @param callback invoked with the (start, end) range of each segment
     */
    private final void forEachSegment(String text, List<Integer> goldClassifications, Function1<? super Pair<Integer, Integer>, Unit> callback) {
        ProgressIndicatorBar progress = new ProgressIndicatorBar(text.length(), (OutputStream) null, 0, 6, (DefaultConstructorMarker) null);

        int start = 0;
        while (start < text.length()) {
            int end = Math.min(start + this.tokenizer.getModel().getMaxSegmentSize(), text.length());
            this.segmentGoldClassification = goldClassifications.subList(start, end);

            callback.invoke(new Pair<Integer, Integer>(start, end));

            int shift = getShiftCharIndex() + 1;
            progress.tick(shift);
            start += shift;
        }
    }

    /**
     * Returns the current segment's gold classifications. If they have not been assigned yet,
     * fails the same way a Kotlin {@code lateinit} read does.
     */
    private List<Integer> requireSegmentGold() {
        List<Integer> gold = this.segmentGoldClassification;
        if (gold == null) {
            Intrinsics.throwUninitializedPropertyAccessException("segmentGoldClassification");
        }
        return gold;
    }

    /**
     * Chooses the index of the current segment after which the next segment starts.
     *
     * <p>Preference order:
     * <ol>
     *   <li>the last char with gold class 1 (presumably a sentence boundary — TODO confirm);</li>
     *   <li>a class-0 char near the middle ({@link #getMiddleTokenBoundary()});</li>
     *   <li>the middle index {@code size / 2}.</li>
     * </ol>
     *
     * @return an index in {@code [0, size)} for a non-empty segment, 0 for an empty one
     */
    private final int getShiftCharIndex() {
        List<Integer> gold = requireSegmentGold();

        int lastClassOne = gold.lastIndexOf(1);
        if (lastClassOne >= 0) {
            return lastClassOne;
        }

        int middleBoundary = getMiddleTokenBoundary();
        return middleBoundary >= 0 ? middleBoundary : gold.size() / 2;
    }

    /**
     * Looks for a char with gold class 0 close to the middle of the current segment. It first
     * scans forward from the middle to the end, then backward from the middle down to index 0
     * (inclusive).
     *
     * <p>Previously the backward scan stopped before index 0 and returned 0 even when no class-0
     * char existed. That made the {@code size / 2} fallback of {@link #getShiftCharIndex()}
     * unreachable, and boundary-less segments advanced by a single char.
     *
     * @return the index found, or -1 if the segment contains no class-0 char
     */
    private final int getMiddleTokenBoundary() {
        List<Integer> gold = requireSegmentGold();
        int size = gold.size();
        int middle = size / 2;

        for (int k = middle; k < size; k++) {
            if (gold.get(k).intValue() == 0) {
                return k;
            }
        }

        // The middle itself was already checked by the forward scan.
        for (int k = middle - 1; k >= 0; k--) {
            if (gold.get(k).intValue() == 0) {
                return k;
            }
        }

        return -1;
    }

    /**
     * Learns from the segment {@code [i, i + i2)} of {@code str}. It classifies the segment's
     * chars, backpropagates the errors against {@link #segmentGoldClassification} and
     * accumulates the resulting parameter and embedding errors in the optimizer.
     * Public only because of the decompiled accessor; intended for internal use.
     *
     * @param str the whole training text
     * @param i the index of the first char of the segment
     * @param i2 the length of the segment
     */
    /* JADX INFO: Access modifiers changed from: private */
    public final void learnFromExample(String str, int i, int i2) {
        this.optimizer.newExample();
        backward(this.tokenizer.classifyChars(str, i, i2));
        accumulateErrors(str.subSequence(i, i + i2));
    }

    /**
     * Accumulates the errors of the last backward pass into the optimizers:
     * <ul>
     *   <li>the params errors of the chars encoder;</li>
     *   <li>the params errors of the boundaries classifier;</li>
     *   <li>the embedding error of each char of the segment.</li>
     * </ul>
     * The embedding error is the chars encoder input error without its last
     * {@code addingFeaturesSize} components.
     *
     * @param segment the chars of the segment, aligned with the chars encoder input errors
     */
    private final void accumulateErrors(CharSequence segment) {
        Optimizer.accumulate$default(this.optimizer.getCharsEncoderOptimizer(), this.tokenizer.getCharsEncoder().getParamsErrors(false), false, 2, (Object) null);
        Optimizer.accumulate$default(this.optimizer.getBoundariesClassifierOptimizer(), this.tokenizer.getBoundariesClassifier().getParamsErrors(false), false, 2, (Object) null);

        int addingFeaturesSize = this.tokenizer.getModel().getAddingFeaturesSize();
        int charIndex = 0;
        for (Object errors : this.tokenizer.getCharsEncoder().getInputErrors(false)) {
            DenseNDArray inputError = (DenseNDArray) errors;
            // Only the leading part of the input error refers to the char embedding.
            DenseNDArray embeddingError = inputError.getRange(0, inputError.getLength() - addingFeaturesSize);
            this.optimizer.getEmbeddingsOptimizer().accumulate(Character.valueOf(segment.charAt(charIndex)), embeddingError);
            charIndex++;
        }
    }

    /**
     * Closes the current batch: signals a new batch to the optimizer, then applies the
     * accumulated updates. Public only because of the decompiled accessor; intended for internal use.
     */
    /* JADX INFO: Access modifiers changed from: private */
    public final void endOfBatch() {
        this.optimizer.newBatch();
        this.optimizer.update();
    }

    /**
     * Backpropagates through the boundaries classifier and then through the chars encoder.
     *
     * @param segmentClassification the classification output for each char of the current segment
     */
    private final void backward(List<DenseNDArray> segmentClassification) {
        backwardBoundariesClassifier(segmentClassification, requireSegmentGold());
        this.tokenizer.getCharsEncoder().backward(this.tokenizer.getBoundariesClassifier().getInputErrors(false));
    }

    /**
     * Backpropagates the boundaries classifier. For each char, the output error is the
     * prediction with 1 subtracted at the gold class index. This is the usual softmax
     * cross-entropy gradient, assuming the classifier ends with a softmax (not verified here).
     * NOTE: the prediction arrays are modified in place and reused as the error arrays.
     *
     * @param segmentClassification the predicted class distribution of each char
     * @param goldSegmentClassification the gold class index of each char
     */
    private final void backwardBoundariesClassifier(List<DenseNDArray> segmentClassification, List<Integer> goldSegmentClassification) {
        BatchFeedforwardProcessor<DenseNDArray> boundariesClassifier = this.tokenizer.getBoundariesClassifier();
        int size = goldSegmentClassification.size();
        List<DenseNDArray> outputErrors = new ArrayList<DenseNDArray>(size);

        for (int k = 0; k < size; k++) {
            DenseNDArray prediction = segmentClassification.get(k);
            int goldClass = goldSegmentClassification.get(k).intValue();
            prediction.set(goldClass, Double.valueOf(prediction.get(goldClass).doubleValue() - 1));
            outputErrors.add(prediction);
        }

        boundariesClassifier.backward(outputErrors);
    }

    /**
     * Validates the model. If a model file name is given and the accuracy beats the best one
     * seen so far, records the new best and dumps the model to that file. The output stream is
     * always closed.
     *
     * @param validationSet the validation set
     * @param modelFilename the file to which the model is saved, or {@code null} to never save
     * @throws UncheckedIOException if the model file cannot be opened, written or closed
     */
    private final void validateAndSaveModel(List<? extends Pair<String, ? extends List<Integer>>> validationSet, String modelFilename) {
        double accuracy = validateEpoch(validationSet);
        if (modelFilename == null || accuracy <= this.bestAccuracy) {
            return;
        }

        this.bestAccuracy = accuracy;
        try (FileOutputStream out = new FileOutputStream(new File(modelFilename))) {
            this.tokenizer.getModel().dump(out);
        } catch (IOException e) {
            throw new UncheckedIOException("Cannot save the model to \"" + modelFilename + "\"", e);
        }
        System.out.println("NEW BEST ACCURACY! Model saved to \"" + modelFilename + '"');
    }

    /**
     * Evaluates the model on the validation set and prints token-level and sentence-level
     * precision, recall and F1 score as percentages.
     *
     * @param validationSet the validation set
     * @return the combined accuracy computed by {@link #getAccuracy}
     */
    private final double validateEpoch(List<? extends Pair<String, ? extends List<Integer>>> validationSet) {
        System.out.println(String.format("Epoch validation on %d sentences", validationSet.size()));

        ValidationHelper.EvaluationStats stats = this.validationHelper.validate(validationSet);

        System.out.println(String.format(
            "Tokens accuracy     ->   Precision: %.2f%%  |  Recall: %.2f%%  |  F1 Score: %.2f%%",
            100.0d * stats.getTokens().getPrecision(),
            100.0d * stats.getTokens().getRecall(),
            100.0d * stats.getTokens().getF1Score()));
        System.out.println(String.format(
            "Sentences accuracy  ->   Precision: %.2f%%  |  Recall: %.2f%%  |  F1 Score: %.2f%%",
            100.0d * stats.getSentences().getPrecision(),
            100.0d * stats.getSentences().getRecall(),
            100.0d * stats.getSentences().getF1Score()));

        return getAccuracy(stats);
    }

    /** Forgets the best accuracy of any previous training run. */
    private final void resetValidationStats() {
        this.bestAccuracy = 0.0d;
    }

    /**
     * Combines the evaluation stats into one score: tokens F1 multiplied by the square root of
     * sentences F1.
     */
    private final double getAccuracy(ValidationHelper.EvaluationStats evaluationStats) {
        return evaluationStats.getTokens().getF1Score() * Math.pow(evaluationStats.getSentences().getF1Score(), 0.5d);
    }

    /** Records the current time as the start of the epoch. */
    private final void startTiming() {
        this.startTime = System.currentTimeMillis();
    }

    /**
     * Formats the time elapsed since {@link #startTiming()} as seconds and minutes.
     *
     * @return a string like {@code "12.345 s (0.2 min)"} (decimal separator per default locale)
     */
    private final String formatElapsedTime() {
        double seconds = (System.currentTimeMillis() - this.startTime) / 1000.0d;
        return String.format("%.3f s (%.1f min)", seconds, seconds / 60.0d);
    }

    /** @return the tokenizer being trained */
    @NotNull
    public final NeuralTokenizer getTokenizer() {
        return this.tokenizer;
    }

    /** @return the optimizer used for training */
    @NotNull
    public final NeuralTokenizerOptimizer getOptimizer() {
        return this.optimizer;
    }

    /**
     * @param neuralTokenizer the tokenizer to train
     * @param neuralTokenizerOptimizer the optimizer of the tokenizer parameters
     */
    public TrainingHelper(@NotNull NeuralTokenizer neuralTokenizer, @NotNull NeuralTokenizerOptimizer neuralTokenizerOptimizer) {
        Intrinsics.checkParameterIsNotNull(neuralTokenizer, "tokenizer");
        Intrinsics.checkParameterIsNotNull(neuralTokenizerOptimizer, "optimizer");
        this.tokenizer = neuralTokenizer;
        this.optimizer = neuralTokenizerOptimizer;
        this.validationHelper = new ValidationHelper(this.tokenizer.getModel());
    }

    /**
     * Kotlin default-arguments constructor. When bit 2 of {@code i} is set, the optimizer is
     * created from the tokenizer's model with the optimizer's own defaults.
     */
    public /* synthetic */ TrainingHelper(NeuralTokenizer neuralTokenizer, NeuralTokenizerOptimizer neuralTokenizerOptimizer, int i, DefaultConstructorMarker defaultConstructorMarker) {
        this(neuralTokenizer, (i & 2) != 0 ? new NeuralTokenizerOptimizer(neuralTokenizer.getModel(), null, null, null, 14, null) : neuralTokenizerOptimizer);
    }
}
