package com.kotlinnlp.neuraltokenizer.helpers;

import com.kotlinnlp.neuraltokenizer.NeuralTokenizer;
import com.kotlinnlp.neuraltokenizer.NeuralTokenizerOptimizer;
import com.kotlinnlp.neuraltokenizer.helpers.ValidationHelper;
import com.kotlinnlp.neuraltokenizer.utils.DatasetUtilsKt;
import com.kotlinnlp.simplednn.dataset.Shuffler;
import com.kotlinnlp.simplednn.deeplearning.embeddings.EmbeddingsMap;
import com.kotlinnlp.simplednn.deeplearning.sequenceencoder.SequenceFeedforwardEncoder;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin.collections.IntIterator;
import kotlin.coroutines.experimental.SequenceBuilderKt;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import kotlin.ranges.RangesKt;
import kotlin.sequences.Sequence;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: TrainingHelper.kt */
@Metadata(mv = {1, 1, 8}, bv = {1, 0, 2}, k = 1, d1 = {"��~\n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u0006\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\u0010\b\n\u0002\u0018\u0002\n��\n\u0002\u0010\t\n\u0002\b\u0003\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0002\n��\n\u0002\u0010\u0011\n\u0002\u0018\u0002\n\u0002\b\u0006\n\u0002\u0010\u000e\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0005\n\u0002\u0018\u0002\n\u0002\b\b\n\u0002\u0018\u0002\n\u0002\b\u0005\u0018��2\u00020\u0001B\u0017\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\b\b\u0002\u0010\u0004\u001a\u00020\u0005¢\u0006\u0002\u0010\u0006J\u001b\u0010\u0015\u001a\u00020\u00162\f\u0010\u0017\u001a\b\u0012\u0004\u0012\u00020\u00190\u0018H\u0002¢\u0006\u0002\u0010\u001aJ3\u0010\u001b\u001a\u00020\u00162\f\u0010\u0017\u001a\b\u0012\u0004\u0012\u00020\u00190\u00182\u0016\u0010\u001c\u001a\u0012\u0012\u0004\u0012\u00020\r0\fj\b\u0012\u0004\u0012\u00020\r`\u000eH\u0002¢\u0006\u0002\u0010\u001dJ\b\u0010\u001e\u001a\u00020\u0016H\u0002J\b\u0010\u001f\u001a\u00020 H\u0002J\u0010\u0010!\u001a\u00020\b2\u0006\u0010\"\u001a\u00020#H\u0002J\b\u0010$\u001a\u00020\rH\u0002J\b\u0010%\u001a\u00020\rH\u0002J0\u0010&\u001a\u00020\u00162&\u0010'\u001a\"\u0012\u001a\u0012\u0018\u0012\u0004\u0012\u00020 \u0012\u000e\u0012\f\u0012\u0004\u0012\u00020\r0\fj\u0002`)0(0\fj\u0002`*H\u0002J \u0010+\u001a\u00020\u00162\u0006\u0010,\u001a\u00020 2\u0006\u0010-\u001a\u00020\r2\u0006\u0010.\u001a\u00020\rH\u0002J:\u0010/\u001a\u0014\u0012\u0010\u0012\u000e\u0012\u0004\u0012\u00020\r\u0012\u0004\u0012\u00020\r0(002\u0006\u0010,\u001a\u00020 
2\u0016\u00101\u001a\u0012\u0012\u0004\u0012\u00020\r0\fj\b\u0012\u0004\u0012\u00020\r`\u000eH\u0002J\b\u00102\u001a\u00020\u0016H\u0002J\b\u00103\u001a\u00020\u0016H\u0002J\u0088\u0001\u00104\u001a\u00020\u00162&\u0010'\u001a\"\u0012\u001a\u0012\u0018\u0012\u0004\u0012\u00020 \u0012\u000e\u0012\f\u0012\u0004\u0012\u00020\r0\fj\u0002`)0(0\fj\u0002`*2\b\b\u0002\u00105\u001a\u00020\r2\b\b\u0002\u00106\u001a\u00020\r2,\b\u0002\u00107\u001a&\u0012\u001a\u0012\u0018\u0012\u0004\u0012\u00020 \u0012\u000e\u0012\f\u0012\u0004\u0012\u00020\r0\fj\u0002`)0(\u0018\u00010\fj\u0004\u0018\u0001`*2\n\b\u0002\u00108\u001a\u0004\u0018\u0001092\n\b\u0002\u0010:\u001a\u0004\u0018\u00010 J0\u0010;\u001a\u00020\u00162\u0006\u0010,\u001a\u00020 2\u0016\u00101\u001a\u0012\u0012\u0004\u0012\u00020\r0\fj\b\u0012\u0004\u0012\u00020\r`\u000e2\u0006\u00105\u001a\u00020\rH\u0002J:\u0010<\u001a\u00020\u00162&\u00107\u001a\"\u0012\u001a\u0012\u0018\u0012\u0004\u0012\u00020 \u0012\u000e\u0012\f\u0012\u0004\u0012\u00020\r0\fj\u0002`)0(0\fj\u0002`*2\b\u0010:\u001a\u0004\u0018\u00010 H\u0002J0\u0010=\u001a\u00020\b2&\u00107\u001a\"\u0012\u001a\u0012\u0018\u0012\u0004\u0012\u00020 \u0012\u000e\u0012\f\u0012\u0004\u0012\u00020\r0\fj\u0002`)0(0\fj\u0002`*H\u0002R\u000e\u0010\u0007\u001a\u00020\bX\u0082\u000e¢\u0006\u0002\n��R\u0011\u0010\u0004\u001a\u00020\u0005¢\u0006\b\n��\u001a\u0004\b\t\u0010\nR\u001e\u0010\u000b\u001a\u0012\u0012\u0004\u0012\u00020\r0\fj\b\u0012\u0004\u0012\u00020\r`\u000eX\u0082.¢\u0006\u0002\n��R\u000e\u0010\u000f\u001a\u00020\u0010X\u0082\u000e¢\u0006\u0002\n��R\u0011\u0010\u0002\u001a\u00020\u0003¢\u0006\b\n��\u001a\u0004\b\u0011\u0010\u0012R\u000e\u0010\u0013\u001a\u00020\u0014X\u0082\u0004¢\u0006\u0002\n��¨\u0006>"}, d2 = {"Lcom/kotlinnlp/neuraltokenizer/helpers/TrainingHelper;", "", "tokenizer", "Lcom/kotlinnlp/neuraltokenizer/NeuralTokenizer;", "optimizer", "Lcom/kotlinnlp/neuraltokenizer/NeuralTokenizerOptimizer;", 
"(Lcom/kotlinnlp/neuraltokenizer/NeuralTokenizer;Lcom/kotlinnlp/neuraltokenizer/NeuralTokenizerOptimizer;)V", "bestAccuracy", "", "getOptimizer", "()Lcom/kotlinnlp/neuraltokenizer/NeuralTokenizerOptimizer;", "segmentGoldClassification", "Ljava/util/ArrayList;", "", "Lkotlin/collections/ArrayList;", "startTime", "", "getTokenizer", "()Lcom/kotlinnlp/neuraltokenizer/NeuralTokenizer;", "validationHelper", "Lcom/kotlinnlp/neuraltokenizer/helpers/ValidationHelper;", "backward", "", "segmentClassification", "", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "([Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;)V", "backwardBoundariesClassifier", "goldSegmentClassification", "([Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;Ljava/util/ArrayList;)V", "endOfBatch", "formatElapsedTime", "", "getAccuracy", "stats", "Lcom/kotlinnlp/neuraltokenizer/helpers/ValidationHelper$EvaluationStats;", "getMiddleTokenBoundary", "getShiftCharIndex", "initEmbeddings", "trainingSet", "Lkotlin/Pair;", "Lcom/kotlinnlp/neuraltokenizer/utils/CharsClassification;", "Lcom/kotlinnlp/neuraltokenizer/utils/Dataset;", "learnFromExample", "text", "start", "length", "loopSegments", "Lkotlin/sequences/Sequence;", "goldClassifications", "resetValidationStats", "startTiming", "train", "batchSize", "epochs", "validationSet", "shuffler", "Lcom/kotlinnlp/simplednn/dataset/Shuffler;", "modelFilename", "trainEpoch", "validateAndSaveModel", "validateEpoch", "neuraltokenizer"})
/* loaded from: input_file:com/kotlinnlp/neuraltokenizer/helpers/TrainingHelper.class */
public final class TrainingHelper {
    /*
     * Drives the training of a NeuralTokenizer: per-epoch training over a (optionally shuffled)
     * dataset, mini-batch updates through a NeuralTokenizerOptimizer, optional per-epoch
     * validation, and saving of the best model seen so far.
     */

    /** Wall-clock time (ms) recorded by {@link #startTiming()} at the start of each epoch. */
    private long startTime;

    /** Evaluates the tokenizer on the validation set after each epoch. */
    private final ValidationHelper validationHelper;

    /** Best validation accuracy (see {@link #getAccuracy}) reached in the current training run. */
    private double bestAccuracy;

    /**
     * Gold classification of the segment currently being learnt. Emulates a Kotlin
     * {@code lateinit} property: {@code null} until assigned, and reading it before that throws.
     * NOTE(review): it is not assigned anywhere in this file; presumably the generated
     * {@code TrainingHelper$loopSegments$1} sequence builder sets it per segment — confirm.
     */
    private ArrayList<Integer> segmentGoldClassification;

    @NotNull
    private final NeuralTokenizer tokenizer;

    @NotNull
    private final NeuralTokenizerOptimizer optimizer;

    /**
     * Trains the tokenizer for the given number of epochs.
     *
     * <p>Before the first epoch, every character of the training texts that has no embedding yet
     * is added to the model's embeddings map. At each epoch the training set is (optionally)
     * shuffled, merged into a single text with its gold classification, and learnt segment by
     * segment. If a validation set is given, the model is validated after every epoch. If
     * {@code modelFilename} is also given, the model is saved whenever the accuracy improves on the
     * best value of this run.
     *
     * @param trainingSet pairs of (text, per-character gold classification)
     * @param batchSize number of segments learnt between two parameter updates; must be positive
     * @param epochs number of training epochs
     * @param validationSet optional validation set, in the same format as {@code trainingSet}
     * @param shuffler optional shuffler used to reorder the training set at each epoch
     * @param modelFilename optional path of the file where the best model is saved
     * @throws IllegalArgumentException if {@code batchSize} is not positive
     * @throws UncheckedIOException if the model file cannot be written
     */
    public final void train(@NotNull ArrayList<Pair<String, ArrayList<Integer>>> trainingSet, int batchSize, int epochs, @Nullable ArrayList<Pair<String, ArrayList<Integer>>> validationSet, @Nullable Shuffler shuffler, @Nullable String modelFilename) {
        Intrinsics.checkParameterIsNotNull(trainingSet, "trainingSet");
        if (batchSize <= 0) {
            // A zero batch size would otherwise fail later with an ArithmeticException (x % 0),
            // and a negative one would silently behave like its absolute value.
            throw new IllegalArgumentException("The batch size must be positive, got " + batchSize);
        }
        System.out.println(String.format("-- START TRAINING ON %d SENTENCES", trainingSet.size()));
        resetValidationStats();
        initEmbeddings(trainingSet);
        for (int epoch = 0; epoch < epochs; epoch++) {
            System.out.println("\nEpoch " + (epoch + 1) + " of " + epochs);
            startTiming();
            Pair<String, ArrayList<Integer>> mergeDataset = DatasetUtilsKt.mergeDataset(shuffler != null ? DatasetUtilsKt.shuffleDataset(trainingSet, shuffler) : trainingSet);
            trainEpoch(mergeDataset.getFirst(), mergeDataset.getSecond(), batchSize);
            System.out.println(String.format("Elapsed time: %s", formatElapsedTime()));
            if (validationSet != null) {
                validateAndSaveModel(validationSet, modelFilename);
            }
        }
    }

    /**
     * Kotlin default-arguments bridge for {@link #train}: each bit of {@code i3} that is set
     * selects the default for the matching parameter (batchSize = 1, epochs = 3, others null).
     */
    public static /* bridge */ /* synthetic */ void train$default(TrainingHelper trainingHelper, ArrayList arrayList, int i, int i2, ArrayList arrayList2, Shuffler shuffler, String str, int i3, Object obj) {
        if ((i3 & 2) != 0) {
            i = 1;
        }
        if ((i3 & 4) != 0) {
            i2 = 3;
        }
        if ((i3 & 8) != 0) {
            arrayList2 = (ArrayList) null;
        }
        if ((i3 & 16) != 0) {
            shuffler = (Shuffler) null;
        }
        if ((i3 & 32) != 0) {
            str = (String) null;
        }
        trainingHelper.train(arrayList, i, i2, arrayList2, shuffler, str);
    }

    /**
     * Adds an embedding, created with the map's default initialization, for every character of
     * the training texts that is not yet in the model's embeddings map.
     *
     * @param trainingSet pairs of (text, gold classification); only the texts are read
     */
    private final void initEmbeddings(ArrayList<Pair<String, ArrayList<Integer>>> trainingSet) {
        for (Pair<String, ArrayList<Integer>> example : trainingSet) {
            String text = example.component1();
            for (int i = 0; i < text.length(); i++) {
                Character c = Character.valueOf(text.charAt(i));
                if (!this.tokenizer.getModel().getEmbeddings().contains(c)) {
                    EmbeddingsMap.set$default(this.tokenizer.getModel().getEmbeddings(), c, null, 2, null);
                }
            }
        }
    }

    /**
     * Trains the tokenizer for one epoch over a merged text. The parameters are updated after
     * every {@code batchSize} segments, plus once more for a trailing partial batch.
     *
     * @param text the whole training text of the epoch
     * @param goldClassifications one gold label per character of {@code text}
     * @param batchSize number of segments per parameter update (validated positive by the caller)
     * @throws IllegalArgumentException if {@code text} and {@code goldClassifications} differ in length
     */
    private final void trainEpoch(String text, ArrayList<Integer> goldClassifications, int batchSize) {
        if (text.length() != goldClassifications.size()) {
            throw new IllegalArgumentException("Failed requirement.");
        }
        this.optimizer.newEpoch();
        int examplesCount = 0;
        // Kotlin's Sequence is not a java.lang.Iterable, so iterate it explicitly.
        Iterator<Pair<Integer, Integer>> segments = loopSegments(text, goldClassifications).iterator();
        while (segments.hasNext()) {
            Pair<Integer, Integer> segment = segments.next();
            int start = segment.component1().intValue();
            examplesCount++;
            // The pair holds (start, end); learnFromExample wants (start, length).
            learnFromExample(text, start, segment.component2().intValue() - start);
            if (examplesCount % batchSize == 0) {
                endOfBatch();
            }
        }
        if (examplesCount % batchSize > 0) {
            endOfBatch();
        }
    }

    /**
     * Lazily yields the (start, end) character ranges of the segments to learn from.
     * The generating logic lives in the compiler-generated {@code TrainingHelper$loopSegments$1}.
     */
    private final Sequence<Pair<Integer, Integer>> loopSegments(String str, ArrayList<Integer> arrayList) {
        return SequenceBuilderKt.buildSequence(new TrainingHelper$loopSegments$1(this, str, arrayList, null));
    }

    /**
     * Returns the index at which the current segment should be shifted. This is the last index
     * with label 1 if any. Otherwise it is {@link #getMiddleTokenBoundary()} if that is
     * non-negative, and otherwise half the segment size.
     *
     * <p>Public only because the generated {@code loopSegments} coroutine class calls it.
     */
    /* JADX INFO: Access modifiers changed from: private */
    public final int getShiftCharIndex() {
        ArrayList<Integer> gold = requireSegmentGoldClassification();
        int size = gold.size();
        for (int index = size - 1; index >= 0; index--) {
            if (hasLabel(gold.get(index), 1)) {
                return index;
            }
        }
        int middleTokenBoundary = getMiddleTokenBoundary();
        return middleTokenBoundary >= 0 ? middleTokenBoundary : size / 2;
    }

    /**
     * Finds a label-0 position near the middle of the current segment. It first searches
     * forward from the middle, and if nothing is found, searches backward from the middle.
     *
     * @return the index found, or -1 if the resulting index is out of range
     */
    private final int getMiddleTokenBoundary() {
        ArrayList<Integer> gold = requireSegmentGoldClassification();
        int size = gold.size();
        int middle = size / 2;
        int index = middle;
        while (index < size && !hasLabel(gold.get(index), 0)) {
            index++;
        }
        if (index >= size) {
            // NOTE(review): this backward scan stops at index 1, so index 0 is returned without
            // being checked and -1 is only ever returned for an empty segment. Kept as-is because
            // its interaction with loopSegments (generated elsewhere) cannot be verified here.
            index = middle;
            while (index > 0 && !hasLabel(gold.get(index), 0)) {
                index--;
            }
        }
        return (index >= 0 && index < size) ? index : -1;
    }

    /**
     * Runs a forward pass on {@code text[start, start + length)}, back-propagates the errors
     * against the current segment's gold classification, and accumulates them in the optimizer.
     */
    private final void learnFromExample(String text, int start, int length) {
        this.optimizer.newExample();
        backward(this.tokenizer.classifyChars(text, start, length));
        this.optimizer.accumulateErrors(text.subSequence(start, start + length));
    }

    /** Closes the current batch and applies the accumulated parameter updates. */
    private final void endOfBatch() {
        this.optimizer.newBatch();
        this.optimizer.update();
    }

    /**
     * Back-propagates through the boundaries classifier and then through the chars encoder.
     *
     * @param segmentClassification the classifier outputs for the current segment; modified in place
     */
    private final void backward(DenseNDArray[] segmentClassification) {
        backwardBoundariesClassifier(segmentClassification, requireSegmentGoldClassification());
        this.tokenizer.getCharsEncoder().backward(this.tokenizer.getBoundariesClassifier().getInputSequenceErrors(false), true);
    }

    /**
     * Computes the output errors of the boundaries classifier by subtracting 1 at the gold label
     * of each position, then back-propagates them. NOTE(review): if the outputs are softmax
     * probabilities, this is the usual softmax + cross-entropy gradient — confirm.
     *
     * <p>The arrays in {@code segmentClassification} are modified in place and reused as errors.
     *
     * @param segmentClassification classifier outputs, at least as many as the gold labels
     * @param goldSegmentClassification the gold label index of each position
     */
    private final void backwardBoundariesClassifier(DenseNDArray[] segmentClassification, ArrayList<Integer> goldSegmentClassification) {
        SequenceFeedforwardEncoder<DenseNDArray> boundariesClassifier = this.tokenizer.getBoundariesClassifier();
        DenseNDArray[] errors = new DenseNDArray[goldSegmentClassification.size()];
        for (int i = 0; i < errors.length; i++) {
            DenseNDArray prediction = segmentClassification[i];
            Integer goldLabel = goldSegmentClassification.get(i);
            Intrinsics.checkExpressionValueIsNotNull(goldLabel, "goldSegmentClassification[i]");
            int label = goldLabel.intValue();
            prediction.set(label, Double.valueOf(prediction.get(label).doubleValue() - 1));
            errors[i] = prediction;
        }
        boundariesClassifier.backward(errors, true);
    }

    /**
     * Validates the model and, if {@code modelFilename} is given and the accuracy beats the best
     * value so far, records it and dumps the model to that file.
     *
     * @throws UncheckedIOException if the model file cannot be opened, written or closed
     */
    private final void validateAndSaveModel(ArrayList<Pair<String, ArrayList<Integer>>> validationSet, String modelFilename) {
        double accuracy = validateEpoch(validationSet);
        if (modelFilename == null || accuracy <= this.bestAccuracy) {
            return;
        }
        this.bestAccuracy = accuracy;
        // try-with-resources guarantees the file handle is released even if dump() throws.
        try (FileOutputStream outputStream = new FileOutputStream(new File(modelFilename))) {
            this.tokenizer.getModel().dump(outputStream);
        } catch (IOException e) {
            throw new UncheckedIOException("Cannot save the model to \"" + modelFilename + "\"", e);
        }
        System.out.println("NEW BEST ACCURACY! Model saved to \"" + modelFilename + '\"');
    }

    /**
     * Evaluates the tokenizer on the validation set and prints the token-level and
     * sentence-level precision, recall and F1 score as percentages.
     *
     * @return the combined accuracy computed by {@link #getAccuracy}
     */
    private final double validateEpoch(ArrayList<Pair<String, ArrayList<Integer>>> validationSet) {
        System.out.println(String.format("Epoch validation on %d sentences", validationSet.size()));
        ValidationHelper.EvaluationStats validate = this.validationHelper.validate(validationSet);
        System.out.println(String.format("Tokens accuracy     ->   Precision: %.2f%%  |  Recall: %.2f%%  |  F1 Score: %.2f%%",
                100.0d * validate.getTokens().getPrecision(),
                100.0d * validate.getTokens().getRecall(),
                100.0d * validate.getTokens().getF1Score()));
        System.out.println(String.format("Sentences accuracy  ->   Precision: %.2f%%  |  Recall: %.2f%%  |  F1 Score: %.2f%%",
                100.0d * validate.getSentences().getPrecision(),
                100.0d * validate.getSentences().getRecall(),
                100.0d * validate.getSentences().getF1Score()));
        return getAccuracy(validate);
    }

    /** Resets the best accuracy so that a new training run starts from scratch. */
    private final void resetValidationStats() {
        this.bestAccuracy = 0.0d;
    }

    /** Combines the scores as tokens F1 multiplied by the square root of sentences F1. */
    private final double getAccuracy(ValidationHelper.EvaluationStats evaluationStats) {
        return evaluationStats.getTokens().getF1Score() * Math.pow(evaluationStats.getSentences().getF1Score(), 0.5d);
    }

    /** Records the current wall-clock time as the start of the timed interval. */
    private final void startTiming() {
        this.startTime = System.currentTimeMillis();
    }

    /** Formats the time elapsed since {@link #startTiming()} in seconds and minutes. */
    private final String formatElapsedTime() {
        double seconds = (System.currentTimeMillis() - this.startTime) / 1000.0d;
        return String.format("%.3f s (%.1f min)", seconds, seconds / 60.0d);
    }

    /** Returns the tokenizer being trained. */
    @NotNull
    public final NeuralTokenizer getTokenizer() {
        return this.tokenizer;
    }

    /** Returns the optimizer that updates the tokenizer's parameters. */
    @NotNull
    public final NeuralTokenizerOptimizer getOptimizer() {
        return this.optimizer;
    }

    /**
     * @param tokenizer the tokenizer to train
     * @param optimizer the optimizer used to update the tokenizer's parameters
     */
    public TrainingHelper(@NotNull NeuralTokenizer tokenizer, @NotNull NeuralTokenizerOptimizer optimizer) {
        Intrinsics.checkParameterIsNotNull(tokenizer, "tokenizer");
        Intrinsics.checkParameterIsNotNull(optimizer, "optimizer");
        this.tokenizer = tokenizer;
        this.optimizer = optimizer;
        this.validationHelper = new ValidationHelper(this.tokenizer);
    }

    /** Kotlin default-arguments constructor: bit 2 of {@code i} selects a default optimizer. */
    public /* synthetic */ TrainingHelper(NeuralTokenizer neuralTokenizer, NeuralTokenizerOptimizer neuralTokenizerOptimizer, int i, DefaultConstructorMarker defaultConstructorMarker) {
        this(neuralTokenizer, (i & 2) != 0 ? new NeuralTokenizerOptimizer(neuralTokenizer, null, null, null, 14, null) : neuralTokenizerOptimizer);
    }

    /** Synthetic accessor used by the generated {@code loopSegments} coroutine class. */
    @NotNull
    public static final /* synthetic */ ArrayList access$getSegmentGoldClassification$p(TrainingHelper trainingHelper) {
        return trainingHelper.requireSegmentGoldClassification();
    }

    /**
     * Returns {@link #segmentGoldClassification}, throwing the Kotlin lateinit exception if it
     * has not been assigned yet.
     */
    private ArrayList<Integer> requireSegmentGoldClassification() {
        ArrayList<Integer> classification = this.segmentGoldClassification;
        if (classification == null) {
            Intrinsics.throwUninitializedPropertyAccessException("segmentGoldClassification");
        }
        return classification;
    }

    /** Returns true when the boxed label is non-null and equal to {@code label}. */
    private static boolean hasLabel(Integer value, int label) {
        return value != null && value.intValue() == label;
    }
}
