package com.kotlinnlp.hanclassifier.helpers;

import com.kotlinnlp.hanclassifier.HANClassifier;
import com.kotlinnlp.hanclassifier.dataset.Example;
import com.kotlinnlp.progressindicator.ProgressIndicator;
import com.kotlinnlp.progressindicator.ProgressIndicatorBar;
import com.kotlinnlp.simplednn.core.functionalities.updatemethods.UpdateMethod;
import com.kotlinnlp.simplednn.core.optimizer.ParamsOptimizer;
import com.kotlinnlp.simplednn.dataset.Shuffler;
import com.kotlinnlp.simplednn.deeplearning.attentionnetwork.han.HANEncoder;
import com.kotlinnlp.simplednn.deeplearning.attentionnetwork.han.HANParameters;
import com.kotlinnlp.simplednn.deeplearning.attentionnetwork.han.HierarchyGroup;
import com.kotlinnlp.simplednn.deeplearning.attentionnetwork.han.HierarchyItem;
import com.kotlinnlp.simplednn.deeplearning.attentionnetwork.han.HierarchySequence;
import com.kotlinnlp.simplednn.deeplearning.embeddings.Embedding;
import com.kotlinnlp.simplednn.deeplearning.embeddings.EmbeddingsMap;
import com.kotlinnlp.simplednn.deeplearning.embeddings.EmbeddingsOptimizer;
import com.kotlinnlp.simplednn.helpers.training.utils.ExamplesIndices;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin.TypeCastException;
import kotlin.collections.CollectionsKt;
import kotlin.collections.IntIterator;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import kotlin.ranges.RangesKt;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: TrainingHelper.kt */
@Metadata(mv = {1, 1, 8}, bv = {1, 0, 2}, k = 1, d1 = {"��~\n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0010\u0006\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0010\u000e\n��\n\u0002\u0010\t\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0002\n��\n\u0002\u0010 \n��\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\b\n\u0002\u0010\b\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0007\u0018��2\u00020\u0001B%\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\n\u0010\u0004\u001a\u0006\u0012\u0002\b\u00030\u0005\u0012\n\u0010\u0006\u001a\u0006\u0012\u0002\b\u00030\u0005¢\u0006\u0002\u0010\u0007J$\u0010\u0014\u001a\u00020\u00152\u0012\u0010\u0016\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u000f0\u00170\u00172\u0006\u0010\u0018\u001a\u00020\u0019H\u0002J$\u0010\u001a\u001a\u00020\u00152\f\u0010\u001b\u001a\b\u0012\u0004\u0012\u00020\u000f0\u00172\f\u0010\u001c\u001a\b\u0012\u0004\u0012\u00020\u001e0\u001dH\u0002J\b\u0010\u001f\u001a\u00020\u000fH\u0002J \u0010 
\u001a\u00020\u00152\u0016\u0010!\u001a\u0012\u0012\u0004\u0012\u00020#0\"j\b\u0012\u0004\u0012\u00020#`$H\u0002J\u0010\u0010%\u001a\u00020\u00152\u0006\u0010&\u001a\u00020#H\u0002J\b\u0010'\u001a\u00020\u0015H\u0002J\b\u0010(\u001a\u00020\u0015H\u0002J\b\u0010)\u001a\u00020\u0015H\u0002J\b\u0010*\u001a\u00020\u0015H\u0002Jf\u0010+\u001a\u00020\u00152\u0016\u0010!\u001a\u0012\u0012\u0004\u0012\u00020#0\"j\b\u0012\u0004\u0012\u00020#`$2\u0006\u0010,\u001a\u00020-2\b\b\u0002\u0010.\u001a\u00020-2\n\b\u0002\u0010/\u001a\u0004\u0018\u0001002\u001c\b\u0002\u00101\u001a\u0016\u0012\u0004\u0012\u00020#\u0018\u00010\"j\n\u0012\u0004\u0012\u00020#\u0018\u0001`$2\n\b\u0002\u00102\u001a\u0004\u0018\u00010\u000fJ2\u00103\u001a\u00020\u00152\u0016\u0010!\u001a\u0012\u0012\u0004\u0012\u00020#0\"j\b\u0012\u0004\u0012\u00020#`$2\u0006\u0010.\u001a\u00020-2\b\u0010/\u001a\u0004\u0018\u000100H\u0002J\b\u00104\u001a\u00020\u0015H\u0002J*\u00105\u001a\u00020\u00152\u0016\u00101\u001a\u0012\u0012\u0004\u0012\u00020#0\"j\b\u0012\u0004\u0012\u00020#`$2\b\u00102\u001a\u0004\u0018\u00010\u000fH\u0002J \u00106\u001a\u00020\t2\u0016\u00101\u001a\u0012\u0012\u0004\u0012\u00020#0\"j\b\u0012\u0004\u0012\u00020#`$H\u0002R\u000e\u0010\b\u001a\u00020\tX\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010\u0002\u001a\u00020\u0003X\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\n\u001a\b\u0012\u0004\u0012\u00020\f0\u000bX\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\r\u001a\b\u0012\u0004\u0012\u00020\u000f0\u000eX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0010\u001a\u00020\u0011X\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010\u0012\u001a\u00020\u0013X\u0082\u0004¢\u0006\u0002\n��¨\u00067"}, d2 = {"Lcom/kotlinnlp/hanclassifier/helpers/TrainingHelper;", "", "classifier", "Lcom/kotlinnlp/hanclassifier/HANClassifier;", "classifierUpdateMethod", "Lcom/kotlinnlp/simplednn/core/functionalities/updatemethods/UpdateMethod;", "embeddingsUpdateMethod", 
"(Lcom/kotlinnlp/hanclassifier/HANClassifier;Lcom/kotlinnlp/simplednn/core/functionalities/updatemethods/UpdateMethod;Lcom/kotlinnlp/simplednn/core/functionalities/updatemethods/UpdateMethod;)V", "bestAccuracy", "", "classifierOptimizer", "Lcom/kotlinnlp/simplednn/core/optimizer/ParamsOptimizer;", "Lcom/kotlinnlp/simplednn/deeplearning/attentionnetwork/han/HANParameters;", "embeddingsOptimizer", "Lcom/kotlinnlp/simplednn/deeplearning/embeddings/EmbeddingsOptimizer;", "", "startTime", "", "validationHelper", "Lcom/kotlinnlp/hanclassifier/helpers/ValidationHelper;", "accumulateEmbeddingsErrors", "", "inputText", "", "errorsHierarchy", "Lcom/kotlinnlp/simplednn/deeplearning/attentionnetwork/han/HierarchyGroup;", "accumulateSentenceEmbeddingsErrors", "tokens", "tokensErrors", "Lcom/kotlinnlp/simplednn/deeplearning/attentionnetwork/han/HierarchySequence;", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "formatElapsedTime", "initEmbeddings", "trainingSet", "Ljava/util/ArrayList;", "Lcom/kotlinnlp/hanclassifier/dataset/Example;", "Lkotlin/collections/ArrayList;", "learnFromExample", "example", "newBatch", "newEpoch", "newExample", "startTiming", "train", "epochs", "", "batchSize", "shuffler", "Lcom/kotlinnlp/simplednn/dataset/Shuffler;", "validationSet", "modelFilename", "trainEpoch", "update", "validateAndSaveModel", "validateEpoch", "hanclassifier"})
/* loaded from: input_file:com/kotlinnlp/hanclassifier/helpers/TrainingHelper.class */
public final class TrainingHelper {
    // Wall-clock timestamp (System.currentTimeMillis()) recorded by startTiming().
    private long startTime;
    // Highest validation accuracy for which a model has been saved; only raised in validateAndSaveModel().
    private double bestAccuracy;
    // Computes the accuracy of the classifier on a validation set (see validateEpoch()).
    private final ValidationHelper validationHelper;
    // Accumulates and applies the errors of the HAN parameters.
    private final ParamsOptimizer<HANParameters> classifierOptimizer;
    // Accumulates and applies the errors of the token embeddings, keyed by token string.
    private final EmbeddingsOptimizer<String> embeddingsOptimizer;
    // The classifier being trained.
    private final HANClassifier classifier;

    /**
     * Trains the classifier for a number of epochs, optionally validating after each one.
     *
     * <p>Before the first epoch an embedding is registered for every token of the training set
     * that does not have one yet. Each epoch is timed and the elapsed time is printed to
     * standard output.
     *
     * @param arrayList the training set (must not be null)
     * @param i the number of epochs; a value below 1 runs no epoch
     * @param i2 the batch size, forwarded to {@code trainEpoch}
     * @param shuffler the shuffler handed to the examples index generator (may be null)
     * @param arrayList2 the validation set; when null no validation is done and no model is saved
     * @param str the name of the file where the best model is written (may be null)
     */
    public final void train(@NotNull ArrayList<Example> arrayList, int i, int i2, @Nullable Shuffler shuffler, @Nullable ArrayList<Example> arrayList2, @Nullable String str) {
        Intrinsics.checkParameterIsNotNull(arrayList, "trainingSet");
        initEmbeddings(arrayList);
        for (int epochIndex = 0; epochIndex < i; epochIndex++) {
            int epochNumber = epochIndex + 1;
            System.out.println("\nEpoch " + epochNumber + " of " + i);
            startTiming();
            newEpoch();
            trainEpoch(arrayList, i2, shuffler);
            String elapsedMessage = String.format("Elapsed time: %s", formatElapsedTime());
            Intrinsics.checkExpressionValueIsNotNull(elapsedMessage, "java.lang.String.format(this, *args)");
            System.out.println(elapsedMessage);
            if (arrayList2 == null) {
                continue;
            }
            validateAndSaveModel(arrayList2, str);
        }
    }

    /**
     * Compiler-generated entry point that fills in the default arguments of {@code train}.
     *
     * <p>Each set bit of {@code i3} marks one argument as omitted by the caller:
     * bit 4 replaces {@code i2} with 1, bit 8 replaces {@code shuffler} with
     * {@code new Shuffler(true, 743L)}, bit 16 replaces {@code arrayList2} with null and
     * bit 32 replaces {@code str} with null. {@code obj} is not used.
     */
    public static /* bridge */ /* synthetic */ void train$default(TrainingHelper trainingHelper, ArrayList arrayList, int i, int i2, Shuffler shuffler, ArrayList arrayList2, String str, int i3, Object obj) {
        int effectiveBatchSize = (i3 & 4) != 0 ? 1 : i2;
        Shuffler effectiveShuffler = (i3 & 8) != 0 ? new Shuffler(true, 743L) : shuffler;
        ArrayList effectiveValidationSet = (i3 & 16) != 0 ? null : arrayList2;
        String effectiveModelFilename = (i3 & 32) != 0 ? null : str;
        trainingHelper.train(arrayList, i, effectiveBatchSize, effectiveShuffler, effectiveValidationSet, effectiveModelFilename);
    }

    /**
     * Registers an embedding for every token of the training set that the model does not
     * know yet.
     *
     * <p>Tokens already contained in the embeddings map of the model are left untouched.
     * For the others {@code EmbeddingsMap.set$default} is called with a null embedding,
     * i.e. with the default value of that argument.
     *
     * @param arrayList the training set whose input texts (lists of sentences, each a list
     *     of tokens) are scanned
     */
    private final void initEmbeddings(ArrayList<Example> arrayList) {
        // Typed enhanced-for loops replace the decompiler's `Iterator<T>` declarations,
        // which referenced an undeclared type variable and iterated raw lists.
        for (Example example : arrayList) {
            for (List<String> sentence : example.getInputText()) {
                for (String token : sentence) {
                    if (!this.classifier.getModel().getEmbeddings().contains(token)) {
                        EmbeddingsMap.set$default(this.classifier.getModel().getEmbeddings(), token, (Embedding) null, 2, (Object) null);
                    }
                }
            }
        }
    }

    /**
     * Runs one training epoch over the whole training set.
     *
     * <p>Examples are visited in the order produced by {@code ExamplesIndices}. A new batch
     * is opened before every {@code i}-th example (starting with the first one) and the
     * optimizers are updated after every {@code i}-th example and after the last example
     * of the set. A progress bar is advanced once per example.
     *
     * @param arrayList the training set
     * @param i the batch size
     * @param shuffler the shuffler handed to {@code ExamplesIndices} (may be null)
     */
    private final void trainEpoch(ArrayList<Example> arrayList, int i, Shuffler shuffler) {
        ProgressIndicatorBar progressBar = new ProgressIndicatorBar(arrayList.size(), (OutputStream) null, 0, 6, (DefaultConstructorMarker) null);
        Iterator indices = new ExamplesIndices(arrayList.size(), shuffler).iterator();
        int processed = 0;
        while (indices.hasNext()) {
            int exampleIndex = ((Number) indices.next()).intValue();
            ProgressIndicator.tick$default(progressBar, 0, 1, (Object) null);
            // Evaluated on the count of examples seen BEFORE this one.
            boolean opensBatch = processed % i == 0;
            processed++;
            if (opensBatch) {
                newBatch();
            }
            newExample();
            Example current = arrayList.get(exampleIndex);
            Intrinsics.checkExpressionValueIsNotNull(current, "trainingSet[exampleIndex]");
            learnFromExample(current);
            // A trailing, incomplete batch is flushed together with the last example.
            boolean closesBatch = processed % i == 0 || processed == arrayList.size();
            if (closesBatch) {
                update();
            }
        }
    }

    /**
     * Performs a forward and a backward pass on one example and accumulates the resulting
     * errors in both optimizers.
     *
     * <p>The output errors are a copy of the classifier output in which 1 is subtracted
     * from the entry at the gold class index.
     *
     * @param example the example to learn from
     * @throws TypeCastException if the encoder returns null input sequence errors
     */
    private final void learnFromExample(Example example) {
        DenseNDArray outputErrors = this.classifier.classify(example.getInputText()).copy();
        double goldScore = outputErrors.get(example.getOutputGold()).doubleValue();
        outputErrors.set(example.getOutputGold(), Double.valueOf(goldScore - 1));

        HANEncoder.backward$default(this.classifier.getEncoder(), outputErrors, true, (Double) null, 4, (Object) null);
        ParamsOptimizer.accumulate$default(this.classifierOptimizer, this.classifier.getEncoder().getParamsErrors(false), false, 2, (Object) null);

        List<List<String>> sentences = example.getInputText();
        HierarchyItem inputErrors = this.classifier.getEncoder().getInputSequenceErrors(false);
        if (inputErrors != null) {
            accumulateEmbeddingsErrors(sentences, (HierarchyGroup) inputErrors);
            return;
        }
        throw new TypeCastException("null cannot be cast to non-null type com.kotlinnlp.simplednn.deeplearning.attentionnetwork.han.HierarchyGroup");
    }

    /**
     * Accumulates the embeddings errors of a whole input text, one sentence at a time.
     *
     * <p>Sentences and the items of the errors hierarchy are paired positionally with
     * {@code CollectionsKt.zip}.
     *
     * @param list the input text, as a list of sentences, each a list of tokens
     * @param hierarchyGroup the errors hierarchy, whose items are expected to be sequences
     *     of token errors
     * @throws TypeCastException if an item of the errors hierarchy is null
     */
    private final void accumulateEmbeddingsErrors(List<? extends List<String>> list, HierarchyGroup hierarchyGroup) {
        for (Pair sentenceAndErrors : CollectionsKt.zip(list, (Iterable) hierarchyGroup)) {
            List<String> sentenceTokens = (List) sentenceAndErrors.getFirst();
            HierarchyItem sentenceErrors = (HierarchyItem) sentenceAndErrors.getSecond();
            if (sentenceErrors != null) {
                accumulateSentenceEmbeddingsErrors(sentenceTokens, (HierarchySequence) sentenceErrors);
                continue;
            }
            throw new TypeCastException("null cannot be cast to non-null type com.kotlinnlp.simplednn.deeplearning.attentionnetwork.han.HierarchySequence<com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray>");
        }
    }

    /**
     * Accumulates, in the embeddings optimizer, the errors of each token of a sentence.
     *
     * <p>Tokens and errors are paired positionally with {@code CollectionsKt.zip}.
     *
     * @param list the tokens of the sentence
     * @param hierarchySequence the errors of the tokens, in the same order
     */
    private final void accumulateSentenceEmbeddingsErrors(List<String> list, HierarchySequence<DenseNDArray> hierarchySequence) {
        for (Pair tokenAndErrors : CollectionsKt.zip(list, (Iterable) hierarchySequence)) {
            String token = (String) tokenAndErrors.getFirst();
            DenseNDArray tokenErrors = (DenseNDArray) tokenAndErrors.getSecond();
            this.embeddingsOptimizer.accumulate(token, tokenErrors);
        }
    }

    /**
     * Notifies both optimizers (classifier first, then embeddings) that a new epoch starts.
     */
    private final void newEpoch() {
        classifierOptimizer.newEpoch();
        embeddingsOptimizer.newEpoch();
    }

    /**
     * Notifies both optimizers (classifier first, then embeddings) that a new batch starts.
     */
    private final void newBatch() {
        classifierOptimizer.newBatch();
        embeddingsOptimizer.newBatch();
    }

    /**
     * Notifies both optimizers (classifier first, then embeddings) that a new example starts.
     */
    private final void newExample() {
        classifierOptimizer.newExample();
        embeddingsOptimizer.newExample();
    }

    /**
     * Triggers an update on both optimizers (classifier first, then embeddings).
     *
     * <p>Presumably this applies the errors accumulated since the last update; the
     * optimizers' implementation is not visible here.
     */
    private final void update() {
        classifierOptimizer.update();
        embeddingsOptimizer.update();
    }

    /**
     * Validates the classifier, prints its accuracy and saves the model when the accuracy
     * is a new best.
     *
     * <p>The model is written only when a file name is given and the accuracy is strictly
     * greater than the best one recorded so far. The best accuracy is recorded only after
     * the model has been written successfully, so a failed save does not prevent a later
     * epoch with the same accuracy from being saved.
     *
     * @param arrayList the validation set
     * @param str the name of the file where the model is written; when null nothing is saved
     * @throws UncheckedIOException if the file cannot be opened, written or closed
     */
    private final void validateAndSaveModel(ArrayList<Example> arrayList, String str) {
        double accuracy = validateEpoch(arrayList);
        System.out.println(String.format("Accuracy: %.2f%%", 100.0d * accuracy));
        if (str == null || accuracy <= this.bestAccuracy) {
            return;
        }
        // try-with-resources releases the file handle even when dump() throws. Whether dump()
        // closes the stream itself is not visible here; a second close is a no-op.
        try (FileOutputStream modelStream = new FileOutputStream(new File(str))) {
            this.classifier.getModel().dump(modelStream);
        } catch (IOException e) {
            throw new UncheckedIOException("Unable to save the model to \"" + str + '"', e);
        }
        this.bestAccuracy = accuracy;
        System.out.println("NEW BEST ACCURACY! Model saved to \"" + str + '\"');
    }

    /**
     * Prints the size of the validation set and validates the classifier on it.
     *
     * @param arrayList the validation set
     * @return the value returned by the validation helper for the given set
     */
    private final double validateEpoch(ArrayList<Example> arrayList) {
        String announcement = String.format("Epoch validation on %d sentences", arrayList.size());
        Intrinsics.checkExpressionValueIsNotNull(announcement, "java.lang.String.format(this, *args)");
        System.out.println(announcement);
        double accuracy = this.validationHelper.validate(arrayList);
        return accuracy;
    }

    /**
     * Records the current wall-clock time as the reference for {@code formatElapsedTime()}.
     */
    private final void startTiming() {
        long now = System.currentTimeMillis();
        this.startTime = now;
    }

    /**
     * Formats the time elapsed since the last call to {@code startTiming()}.
     *
     * @return the elapsed time as seconds (3 decimals) followed by minutes (1 decimal),
     *     e.g. {@code "90.000 s (1.5 min)"} in a locale using a dot as decimal separator
     */
    private final String formatElapsedTime() {
        long elapsedMillis = System.currentTimeMillis() - this.startTime;
        double elapsedSeconds = elapsedMillis / 1000.0d;
        double elapsedMinutes = elapsedSeconds / 60.0d;
        String formatted = String.format("%.3f s (%.1f min)", elapsedSeconds, elapsedMinutes);
        Intrinsics.checkExpressionValueIsNotNull(formatted, "java.lang.String.format(this, *args)");
        return formatted;
    }

    /**
     * Creates a training helper for the given classifier.
     *
     * <p>A validation helper is built on the classifier, and two optimizers are created:
     * one on the parameters of the HAN of the model and one on its embeddings map.
     *
     * @param hANClassifier the classifier to train (must not be null)
     * @param updateMethod the update method of the classifier parameters (must not be null)
     * @param updateMethod2 the update method of the embeddings (must not be null)
     */
    public TrainingHelper(@NotNull HANClassifier hANClassifier, @NotNull UpdateMethod<?> updateMethod, @NotNull UpdateMethod<?> updateMethod2) {
        Intrinsics.checkParameterIsNotNull(hANClassifier, "classifier");
        Intrinsics.checkParameterIsNotNull(updateMethod, "classifierUpdateMethod");
        Intrinsics.checkParameterIsNotNull(updateMethod2, "embeddingsUpdateMethod");
        // The validated parameter is used directly: it is the same reference stored in the field.
        this.classifier = hANClassifier;
        this.validationHelper = new ValidationHelper(hANClassifier);
        this.classifierOptimizer = new ParamsOptimizer<>(hANClassifier.getModel().getHan().getParams(), updateMethod);
        this.embeddingsOptimizer = new EmbeddingsOptimizer<>(hANClassifier.getModel().getEmbeddings(), updateMethod2);
    }
}
