package com.kotlinnlp.hanclassifier.helpers;

import com.kotlinnlp.hanclassifier.EncodedSentence;
import com.kotlinnlp.hanclassifier.HANClassifier;
import com.kotlinnlp.hanclassifier.HANClassifierModel;
import com.kotlinnlp.hanclassifier.MultiLevelHANModel;
import com.kotlinnlp.hanclassifier.dataset.Example;
import com.kotlinnlp.hanclassifier.helpers.Validator;
import com.kotlinnlp.linguisticdescription.sentence.Sentence;
import com.kotlinnlp.linguisticdescription.sentence.token.FormToken;
import com.kotlinnlp.simplednn.core.functionalities.updatemethods.UpdateMethod;
import com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor;
import com.kotlinnlp.simplednn.core.optimizer.Optimizer;
import com.kotlinnlp.simplednn.core.optimizer.ParamsOptimizer;
import com.kotlinnlp.simplednn.deeplearning.attention.han.HANParameters;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import com.kotlinnlp.tokensencoder.TokensEncoder;
import com.kotlinnlp.tokensencoder.TokensEncoderOptimizer;
import com.kotlinnlp.utils.ExamplesIndices;
import com.kotlinnlp.utils.Shuffler;
import com.kotlinnlp.utils.progressindicator.ProgressIndicator;
import com.kotlinnlp.utils.progressindicator.ProgressIndicatorBar;
import com.kotlinnlp.utils.stats.MetricCounter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin.Unit;
import kotlin.collections.CollectionsKt;
import kotlin.collections.IntIterator;
import kotlin.collections.MapsKt;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import kotlin.ranges.RangesKt;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: Trainer.kt */
@Metadata(mv = {1, 1, 13}, bv = {1, 0, 3}, k = 1, d1 = {"��¢\u0001\n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u000b\n\u0002\b\u0003\n\u0002\u0010\u0006\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010!\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\t\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000e\n��\n\u0002\u0010\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0006\n\u0002\u0010 \n��\n\u0002\u0010\b\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0005\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0006\n\u0002\u0018\u0002\n\u0002\b\u0002\u0018��2\u00020\u0001:\u0001BB;\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\u000e\b\u0002\u0010\u0004\u001a\b\u0012\u0002\b\u0003\u0018\u00010\u0005\u0012\n\u0010\u0006\u001a\u0006\u0012\u0002\b\u00030\u0005\u0012\u0006\u0010\u0007\u001a\u00020\b\u0012\b\b\u0002\u0010\t\u001a\u00020\b¢\u0006\u0002\u0010\nJ\u0010\u0010\u001d\u001a\u00020\u001a2\u0006\u0010\u001e\u001a\u00020\u001fH\u0002J\b\u0010 
\u001a\u00020!H\u0002J\u0010\u0010\"\u001a\u00020#2\u0006\u0010$\u001a\u00020%H\u0002J\b\u0010&\u001a\u00020#H\u0002J\b\u0010'\u001a\u00020#H\u0002J\b\u0010(\u001a\u00020#H\u0002J\b\u0010)\u001a\u00020#H\u0002JP\u0010*\u001a\u00020#2\f\u0010+\u001a\b\u0012\u0004\u0012\u00020%0,2\u0006\u0010-\u001a\u00020.2\b\b\u0002\u0010/\u001a\u00020.2\n\b\u0002\u00100\u001a\u0004\u0018\u0001012\u0010\b\u0002\u00102\u001a\n\u0012\u0004\u0012\u00020%\u0018\u00010,2\n\b\u0002\u00103\u001a\u0004\u0018\u00010!J(\u00104\u001a\u00020#2\f\u0010+\u001a\b\u0012\u0004\u0012\u00020%0,2\u0006\u0010/\u001a\u00020.2\b\u00100\u001a\u0004\u0018\u000101H\u0002JL\u00105\u001a\u00020#2\u0006\u00106\u001a\u0002072\u0006\u00108\u001a\u00020\u001a2\f\u00109\u001a\b\u0012\u0004\u0012\u00020:0,2\f\u0010;\u001a\b\u0012\u0004\u0012\u00020:0,2\f\u0010<\u001a\b\u0012\u0004\u0012\u00020.0,2\b\b\u0002\u0010=\u001a\u00020.H\u0002J\b\u0010>\u001a\u00020#H\u0002J \u0010?\u001a\u00020#2\f\u00102\u001a\b\u0012\u0004\u0012\u00020%0,2\b\u00103\u001a\u0004\u0018\u00010!H\u0002J\u001a\u0010@\u001a\u00060AR\u00020\u001c2\f\u00102\u001a\b\u0012\u0004\u0012\u00020%0,H\u0002R\u000e\u0010\u000b\u001a\u00020\fX\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010\r\u001a\u00020\u000eX\u0082\u0004¢\u0006\u0002\n��R\u001a\u0010\u000f\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00120\u00110\u0010X\u0082\u0004¢\u0006\u0002\n��R\u0012\u0010\u0006\u001a\u0006\u0012\u0002\b\u00030\u0005X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0002\u001a\u00020\u0003X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\t\u001a\u00020\bX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0013\u001a\u00020\u0014X\u0082\u000e¢\u0006\u0002\n��R\u0010\u0010\u0015\u001a\u0004\u0018\u00010\u0016X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0017\u001a\u00020\u0018X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0019\u001a\u00020\u001aX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u001b\u001a\u00020\u001cX\u0082\u0004¢\u0006\u0002\n��¨\u0006C"}, d2 = 
{"Lcom/kotlinnlp/hanclassifier/helpers/Trainer;", "", "model", "Lcom/kotlinnlp/hanclassifier/HANClassifierModel;", "tokensEncoderUpdateMethod", "Lcom/kotlinnlp/simplednn/core/functionalities/updatemethods/UpdateMethod;", "classifierUpdateMethod", "useDropout", "", "saveClassifiersOnly", "(Lcom/kotlinnlp/hanclassifier/HANClassifierModel;Lcom/kotlinnlp/simplednn/core/functionalities/updatemethods/UpdateMethod;Lcom/kotlinnlp/simplednn/core/functionalities/updatemethods/UpdateMethod;ZZ)V", "bestAccuracy", "", "classifier", "Lcom/kotlinnlp/hanclassifier/HANClassifier;", "classifierOptimizers", "", "Lcom/kotlinnlp/simplednn/core/optimizer/ParamsOptimizer;", "Lcom/kotlinnlp/simplednn/deeplearning/attention/han/HANParameters;", "startTime", "", "tokensEncoderOptimizer", "Lcom/kotlinnlp/tokensencoder/TokensEncoderOptimizer;", "tokensEncodersPool", "Lcom/kotlinnlp/hanclassifier/helpers/TokensEncodersPool;", "topLevelOptimizer", "Lcom/kotlinnlp/hanclassifier/helpers/Trainer$LevelOptimizer;", "validationHelper", "Lcom/kotlinnlp/hanclassifier/helpers/Validator;", "buildLevelOptimizer", "levelModel", "Lcom/kotlinnlp/hanclassifier/MultiLevelHANModel$LevelModel;", "formatElapsedTime", "", "learnFromExample", "", "example", "Lcom/kotlinnlp/hanclassifier/dataset/Example;", "newBatch", "newEpoch", "newExample", "startTiming", "train", "trainingSet", "", "epochs", "", "batchSize", "shuffler", "Lcom/kotlinnlp/utils/Shuffler;", "validationSet", "modelFilename", "trainEpoch", "trainLevelClassifier", "levelClassifier", "Lcom/kotlinnlp/hanclassifier/HANClassifier$LevelClassifier;", "levelOptimizer", "encodedSentences", "Lcom/kotlinnlp/hanclassifier/EncodedSentence;", "sentencesErrors", "expectedClasses", "levelIndex", "update", "validateAndSaveModel", "validateEpoch", "Lcom/kotlinnlp/hanclassifier/helpers/Validator$ValidationInfo;", "LevelOptimizer", "hanclassifier"})
/* loaded from: input_file:com/kotlinnlp/hanclassifier/helpers/Trainer.class */
/**
 * Trains a {@link HANClassifier} built on a {@link HANClassifierModel}.
 *
 * <p>One {@link ParamsOptimizer} is built for every level of the model's level tree. The tokens encoder
 * is trained as well only when a tokens-encoder update method is given to the constructor. After each
 * epoch the classifier can be validated, and the model is saved whenever the validation f1 average
 * strictly improves.
 */
public final class Trainer {

    /** Time (ms, as returned by {@link System#currentTimeMillis()}) at which the current epoch started. */
    private long startTime;

    /** Best validation f1 average seen so far; the model is saved only when it is strictly exceeded. */
    private double bestAccuracy;

    /** The classifier being trained, built on {@link #model}. */
    private final HANClassifier classifier;

    /** Helper that evaluates the classifier on a validation set. */
    private final Validator validationHelper;

    /** Provides the tokens encoders used to encode the sentences of an example (one per sentence). */
    private final TokensEncodersPool tokensEncodersPool;

    /** Optimizer of the tokens encoder, or {@code null} when the tokens encoder is not trained. */
    private final TokensEncoderOptimizer tokensEncoderOptimizer;

    /** Every level optimizer of the {@link #topLevelOptimizer} tree, flattened to broadcast events to all. */
    private final List<ParamsOptimizer<HANParameters>> classifierOptimizers;

    /** Optimizer of the top level, linked to the optimizers of all its sub-levels. */
    private final LevelOptimizer topLevelOptimizer;

    /** The model to train. */
    private final HANClassifierModel model;

    /** Update method given to every level optimizer. */
    private final UpdateMethod<?> classifierUpdateMethod;

    /** When true, only the multi-level HAN of the model is dumped on save, otherwise the whole model. */
    private final boolean saveClassifiersOnly;

    /**
     * Optimizer of one level of the classifier tree, together with the optimizers of its sub-levels.
     *
     * <p>{@code subLevels} maps a class index of this level to the optimizer of the sub-level reached
     * through that class, or to {@code null} when that class has no sub-level.
     */
    /* JADX INFO: Access modifiers changed from: private */
    /* compiled from: Trainer.kt */
    @Metadata(mv = {1, 1, 13}, bv = {1, 0, 3}, k = 1, d1 = {"��.\n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010$\n\u0002\u0010\b\n\u0002\b\t\n\u0002\u0010\u000b\n\u0002\b\u0003\n\u0002\u0010\u000e\n��\b\u0082\b\u0018��2\u00020\u0001B)\u0012\f\u0010\u0002\u001a\b\u0012\u0004\u0012\u00020\u00040\u0003\u0012\u0014\u0010\u0005\u001a\u0010\u0012\u0004\u0012\u00020\u0007\u0012\u0006\u0012\u0004\u0018\u00010��0\u0006¢\u0006\u0002\u0010\bJ\u000f\u0010\r\u001a\b\u0012\u0004\u0012\u00020\u00040\u0003HÆ\u0003J\u0017\u0010\u000e\u001a\u0010\u0012\u0004\u0012\u00020\u0007\u0012\u0006\u0012\u0004\u0018\u00010��0\u0006HÆ\u0003J1\u0010\u000f\u001a\u00020��2\u000e\b\u0002\u0010\u0002\u001a\b\u0012\u0004\u0012\u00020\u00040\u00032\u0016\b\u0002\u0010\u0005\u001a\u0010\u0012\u0004\u0012\u00020\u0007\u0012\u0006\u0012\u0004\u0018\u00010��0\u0006HÆ\u0001J\u0013\u0010\u0010\u001a\u00020\u00112\b\u0010\u0012\u001a\u0004\u0018\u00010\u0001HÖ\u0003J\t\u0010\u0013\u001a\u00020\u0007HÖ\u0001J\t\u0010\u0014\u001a\u00020\u0015HÖ\u0001R\u0017\u0010\u0002\u001a\b\u0012\u0004\u0012\u00020\u00040\u0003¢\u0006\b\n��\u001a\u0004\b\t\u0010\nR\u001f\u0010\u0005\u001a\u0010\u0012\u0004\u0012\u00020\u0007\u0012\u0006\u0012\u0004\u0018\u00010��0\u0006¢\u0006\b\n��\u001a\u0004\b\u000b\u0010\f¨\u0006\u0016"}, d2 = {"Lcom/kotlinnlp/hanclassifier/helpers/Trainer$LevelOptimizer;", "", "optimizer", "Lcom/kotlinnlp/simplednn/core/optimizer/ParamsOptimizer;", "Lcom/kotlinnlp/simplednn/deeplearning/attention/han/HANParameters;", "subLevels", "", "", "(Lcom/kotlinnlp/simplednn/core/optimizer/ParamsOptimizer;Ljava/util/Map;)V", "getOptimizer", "()Lcom/kotlinnlp/simplednn/core/optimizer/ParamsOptimizer;", "getSubLevels", "()Ljava/util/Map;", "component1", "component2", "copy", "equals", "", "other", "hashCode", "toString", "", "hanclassifier"})
    /* loaded from: input_file:com/kotlinnlp/hanclassifier/helpers/Trainer$LevelOptimizer.class */
    public static final class LevelOptimizer {

        @NotNull
        private final ParamsOptimizer<HANParameters> optimizer;

        @NotNull
        private final Map<Integer, LevelOptimizer> subLevels;

        @NotNull
        public final ParamsOptimizer<HANParameters> getOptimizer() {
            return this.optimizer;
        }

        @NotNull
        public final Map<Integer, LevelOptimizer> getSubLevels() {
            return this.subLevels;
        }

        /**
         * @param optimizer the optimizer of the parameters of this level
         * @param subLevels the optimizers of the sub-levels, keyed by class index of this level
         * @throws IllegalArgumentException if an argument is null (Kotlin non-null check)
         */
        public LevelOptimizer(@NotNull ParamsOptimizer<HANParameters> optimizer, @NotNull Map<Integer, LevelOptimizer> subLevels) {
            Intrinsics.checkParameterIsNotNull(optimizer, "optimizer");
            Intrinsics.checkParameterIsNotNull(subLevels, "subLevels");
            this.optimizer = optimizer;
            this.subLevels = subLevels;
        }

        @NotNull
        public final ParamsOptimizer<HANParameters> component1() {
            return this.optimizer;
        }

        @NotNull
        public final Map<Integer, LevelOptimizer> component2() {
            return this.subLevels;
        }

        @NotNull
        public final LevelOptimizer copy(@NotNull ParamsOptimizer<HANParameters> optimizer, @NotNull Map<Integer, LevelOptimizer> subLevels) {
            Intrinsics.checkParameterIsNotNull(optimizer, "optimizer");
            Intrinsics.checkParameterIsNotNull(subLevels, "subLevels");
            return new LevelOptimizer(optimizer, subLevels);
        }

        /** Kotlin default-arguments bridge for {@link #copy}: bit 1 keeps the optimizer, bit 2 keeps the sub-levels. */
        @NotNull
        @SuppressWarnings({"rawtypes", "unchecked"})
        public static /* synthetic */ LevelOptimizer copy$default(LevelOptimizer levelOptimizer, ParamsOptimizer optimizer, Map subLevels, int mask, Object marker) {
            if ((mask & 1) != 0) {
                optimizer = levelOptimizer.optimizer;
            }
            if ((mask & 2) != 0) {
                subLevels = levelOptimizer.subLevels;
            }
            return levelOptimizer.copy(optimizer, subLevels);
        }

        @NotNull
        @Override
        public String toString() {
            return "LevelOptimizer(optimizer=" + this.optimizer + ", subLevels=" + this.subLevels + ")";
        }

        @Override
        public int hashCode() {
            // Both fields are guaranteed non-null by the constructor checks.
            return this.optimizer.hashCode() * 31 + this.subLevels.hashCode();
        }

        @Override
        public boolean equals(@Nullable Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof LevelOptimizer)) {
                return false;
            }
            LevelOptimizer other = (LevelOptimizer) obj;
            return this.optimizer.equals(other.optimizer) && this.subLevels.equals(other.subLevels);
        }
    }

    /**
     * Trains the classifier for the given number of epochs, optionally validating it after each epoch.
     *
     * @param trainingSet the examples to learn from
     * @param epochs the number of training epochs (nothing is trained when it is not positive)
     * @param batchSize the number of examples whose errors are accumulated before each update; must be >= 1
     * @param shuffler passed to {@link ExamplesIndices} to order the examples of each epoch (may be null;
     *     presumably meaning "no shuffling" — confirm against {@code ExamplesIndices})
     * @param validationSet if not null, the classifier is validated on it after each epoch
     * @param modelFilename if not null (and a validation set is given), the model is saved to this file
     *     whenever the validation f1 average strictly improves
     * @throws IllegalArgumentException if {@code trainingSet} is null or {@code batchSize} is lower than 1
     * @throws IOException if the model file cannot be written
     */
    public final void train(@NotNull List<Example> trainingSet, int epochs, int batchSize, @Nullable Shuffler shuffler,
                            @Nullable List<Example> validationSet, @Nullable String modelFilename) throws IOException {
        Intrinsics.checkParameterIsNotNull(trainingSet, "trainingSet");
        // Fail fast: a non-positive batch size would otherwise surface as a "/ by zero" in trainEpoch.
        if (batchSize < 1) {
            throw new IllegalArgumentException("The batch size must be >= 1 (got " + batchSize + ")");
        }

        for (int epoch = 0; epoch < epochs; epoch++) {
            System.out.println("\nEpoch " + (epoch + 1) + " of " + epochs);

            startTiming();
            newEpoch();
            trainEpoch(trainingSet, batchSize, shuffler);

            System.out.println(String.format("Elapsed time: %s", formatElapsedTime()));

            if (validationSet != null) {
                validateAndSaveModel(validationSet, modelFilename);
            }
        }
    }

    /**
     * Kotlin default-arguments bridge for {@link #train}.
     *
     * <p>Mask bits: 4 = batchSize (default 1), 8 = shuffler (default {@code Shuffler(true, 743L)}),
     * 16 = validationSet (default null), 32 = modelFilename (default null).
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    public static /* synthetic */ void train$default(Trainer trainer, List trainingSet, int epochs, int batchSize, Shuffler shuffler, List validationSet, String modelFilename, int mask, Object marker) throws IOException {
        if ((mask & 4) != 0) {
            batchSize = 1;
        }
        if ((mask & 8) != 0) {
            shuffler = new Shuffler(true, 743L);
        }
        if ((mask & 16) != 0) {
            validationSet = null;
        }
        if ((mask & 32) != 0) {
            modelFilename = null;
        }
        trainer.train(trainingSet, epochs, batchSize, shuffler, validationSet, modelFilename);
    }

    /**
     * Recursively builds the optimizer of a level and of all its sub-levels.
     *
     * <p>Every optimizer built is also appended to {@link #classifierOptimizers} (sub-levels before their
     * parent), so that batch/epoch/example/update events can be broadcast to all of them.
     *
     * @param levelModel the model of the level
     * @return the optimizer of the level, linked to the optimizers of its sub-levels
     */
    private LevelOptimizer buildLevelOptimizer(MultiLevelHANModel.LevelModel levelModel) {
        ParamsOptimizer<HANParameters> optimizer =
            new ParamsOptimizer<>(levelModel.getHan().getParams(), this.classifierUpdateMethod);

        Map<Integer, MultiLevelHANModel.LevelModel> subLevelModels = levelModel.getSubLevels();
        Map<Integer, LevelOptimizer> subLevelOptimizers = new LinkedHashMap<>(MapsKt.mapCapacity(subLevelModels.size()));
        for (Map.Entry<Integer, MultiLevelHANModel.LevelModel> entry : subLevelModels.entrySet()) {
            MultiLevelHANModel.LevelModel subLevelModel = entry.getValue();
            // A null sub-level model (class without sub-levels) maps to a null optimizer.
            subLevelOptimizers.put(entry.getKey(), subLevelModel != null ? buildLevelOptimizer(subLevelModel) : null);
        }

        LevelOptimizer levelOptimizer = new LevelOptimizer(optimizer, subLevelOptimizers);
        this.classifierOptimizers.add(levelOptimizer.getOptimizer());
        return levelOptimizer;
    }

    /**
     * Trains over all the examples of the training set once.
     *
     * <p>Errors are accumulated over {@code batchSize} examples and the parameters are updated at the end
     * of each batch, including the last (possibly partial) one.
     *
     * @param trainingSet the training examples
     * @param batchSize the batch size (>= 1, checked by {@link #train})
     * @param shuffler passed to {@link ExamplesIndices} to order the examples
     */
    private void trainEpoch(List<Example> trainingSet, int batchSize, Shuffler shuffler) {
        ProgressIndicatorBar progress = new ProgressIndicatorBar(trainingSet.size(), (OutputStream) null, 0, 6, (DefaultConstructorMarker) null);
        int examplesSeen = 0;

        Iterator<?> indices = new ExamplesIndices(trainingSet.size(), shuffler).iterator();
        while (indices.hasNext()) {
            int exampleIndex = ((Number) indices.next()).intValue();
            examplesSeen++;
            ProgressIndicator.tick$default(progress, 0, 1, (Object) null);

            // First example of a batch.
            if ((examplesSeen - 1) % batchSize == 0) {
                newBatch();
            }

            newExample();
            learnFromExample(trainingSet.get(exampleIndex));

            // Batch complete, or last example of the epoch (flushes a trailing partial batch).
            if (examplesSeen % batchSize == 0 || examplesSeen == trainingSet.size()) {
                update();
            }
        }
    }

    /**
     * Learns from a single example, accumulating the errors of the parameters (no update is done here).
     *
     * <p>The sentences are encoded by the tokens encoders. Then the level classifiers are trained along the
     * expected classes, starting from the top level. When the tokens encoder is trained, the input errors
     * collected from all the levels are finally propagated back into the encoders.
     *
     * @param example the training example
     */
    private void learnFromExample(Example example) {
        List<Sentence<FormToken>> sentences = example.getSentences();
        List<TokensEncoder<FormToken, Sentence<FormToken>>> encoders =
            this.tokensEncodersPool.getEncoders(sentences.size());

        // Forward: encode each sentence with its own encoder (pairs stop at the shorter list, like zip).
        List<EncodedSentence> encodedSentences = new ArrayList<>(Math.min(sentences.size(), encoders.size()));
        Iterator<Sentence<FormToken>> sentencesIt = sentences.iterator();
        Iterator<TokensEncoder<FormToken, Sentence<FormToken>>> encodersIt = encoders.iterator();
        while (sentencesIt.hasNext() && encodersIt.hasNext()) {
            TokensEncoder<FormToken, Sentence<FormToken>> encoder = encodersIt.next();
            encodedSentences.add(new EncodedSentence((List<DenseNDArray>) encoder.forward(sentencesIt.next())));
        }

        // Zero-initialized accumulators (one per encoded token) where every level sums its input errors.
        List<EncodedSentence> sentencesErrors = new ArrayList<>(encodedSentences.size());
        for (EncodedSentence encodedSentence : encodedSentences) {
            List<DenseNDArray> tokens = encodedSentence.getTokens();
            List<DenseNDArray> zeroErrors = new ArrayList<>(tokens.size());
            for (DenseNDArray token : tokens) {
                zeroErrors.add(token.zerosLike());
            }
            sentencesErrors.add(encodedSentence.copy(zeroErrors));
        }

        // When the gold classes have sub-levels, the model's "no class" index is appended so that the next
        // level is also trained (to predict that index).
        List<Integer> goldClasses = example.getGoldClasses();
        List<Integer> expectedClasses = this.classifier.getModel().hasSubLevels$hanclassifier(goldClasses)
            ? CollectionsKt.plus(goldClasses, Integer.valueOf(this.classifier.getModel().getNoClassIndex$hanclassifier(goldClasses)))
            : goldClasses;

        trainLevelClassifier(this.classifier.getTopLevelClassifier$hanclassifier(), this.topLevelOptimizer,
            encodedSentences, sentencesErrors, expectedClasses, 0);

        TokensEncoderOptimizer encoderOptimizer = this.tokensEncoderOptimizer;
        if (encoderOptimizer != null) {
            Iterator<TokensEncoder<FormToken, Sentence<FormToken>>> backwardEncodersIt = encoders.iterator();
            Iterator<EncodedSentence> errorsIt = sentencesErrors.iterator();
            while (backwardEncodersIt.hasNext() && errorsIt.hasNext()) {
                TokensEncoder<FormToken, Sentence<FormToken>> encoder = backwardEncodersIt.next();
                encoder.backward(errorsIt.next().getTokens());
                Optimizer.accumulate$default(encoderOptimizer,
                    NeuralProcessor.DefaultImpls.getParamsErrors$default(encoder, false, 1, (Object) null),
                    false, 2, (Object) null);
            }
        }
    }

    /**
     * Trains a level classifier on an example, then recurses into the sub-level of the expected class.
     *
     * <p>The output errors are the classifier prediction minus 1 at the expected class index. This is
     * presumably the softmax/cross-entropy gradient, which assumes the output is a probability
     * distribution — TODO confirm.
     *
     * @param levelClassifier the classifier of the current level
     * @param levelOptimizer the optimizer of the current level
     * @param encodedSentences the encoded sentences of the example (classifier input)
     * @param sentencesErrors accumulators of the input errors; this level's input errors are summed into
     *     them only when the tokens encoder is trained
     * @param expectedClasses the expected class index of each level, from the top one
     * @param levelIndex the index of the current level within {@code expectedClasses}
     */
    private void trainLevelClassifier(HANClassifier.LevelClassifier levelClassifier, LevelOptimizer levelOptimizer,
                                      List<EncodedSentence> encodedSentences, List<EncodedSentence> sentencesErrors,
                                      List<Integer> expectedClasses, int levelIndex) {
        int expectedClass = expectedClasses.get(levelIndex).intValue();

        DenseNDArray outputErrors = levelClassifier.getClassifier().forward(encodedSentences).copy();
        outputErrors.set(expectedClass, Double.valueOf(outputErrors.get(expectedClass).doubleValue() - 1));

        levelClassifier.getClassifier().backward(outputErrors);
        Optimizer.accumulate$default(levelOptimizer.getOptimizer(), levelClassifier.getClassifier().m4getParamsErrors(false), false, 2, (Object) null);

        if (this.tokensEncoderOptimizer != null) {
            // Sum this level's input errors token-wise into the shared accumulators (pairs stop at the
            // shorter list, like zip).
            Iterator<EncodedSentence> levelErrorsIt = levelClassifier.getClassifier().m3getInputErrors(false).iterator();
            Iterator<EncodedSentence> accumulatorsIt = sentencesErrors.iterator();
            while (levelErrorsIt.hasNext() && accumulatorsIt.hasNext()) {
                Iterator<DenseNDArray> accumulatorTokens = accumulatorsIt.next().getTokens().iterator();
                Iterator<DenseNDArray> levelTokens = levelErrorsIt.next().getTokens().iterator();
                while (accumulatorTokens.hasNext() && levelTokens.hasNext()) {
                    accumulatorTokens.next().assignSum((NDArray) levelTokens.next());
                }
            }
        }

        if (levelIndex < expectedClasses.size() - 1) {
            Object subClassifier = MapsKt.getValue(levelClassifier.getSubLevels(), Integer.valueOf(expectedClass));
            if (subClassifier == null) {
                Intrinsics.throwNpe();
            }
            Object subOptimizer = MapsKt.getValue(levelOptimizer.getSubLevels(), Integer.valueOf(expectedClass));
            if (subOptimizer == null) {
                Intrinsics.throwNpe();
            }
            trainLevelClassifier((HANClassifier.LevelClassifier) subClassifier, (LevelOptimizer) subOptimizer,
                encodedSentences, sentencesErrors, expectedClasses, levelIndex + 1);
        }
    }

    /** Kotlin default-arguments bridge for {@link #trainLevelClassifier}: mask bit 32 sets levelIndex to 0. */
    @SuppressWarnings({"rawtypes", "unchecked"})
    static /* synthetic */ void trainLevelClassifier$default(Trainer trainer, HANClassifier.LevelClassifier levelClassifier, LevelOptimizer levelOptimizer, List encodedSentences, List sentencesErrors, List expectedClasses, int levelIndex, int mask, Object marker) {
        if ((mask & 32) != 0) {
            levelIndex = 0;
        }
        trainer.trainLevelClassifier(levelClassifier, levelOptimizer, encodedSentences, sentencesErrors, expectedClasses, levelIndex);
    }

    /** Signals the start of a new epoch to every optimizer. */
    private void newEpoch() {
        for (ParamsOptimizer<HANParameters> optimizer : this.classifierOptimizers) {
            optimizer.newEpoch();
        }
        TokensEncoderOptimizer encoderOptimizer = this.tokensEncoderOptimizer;
        if (encoderOptimizer != null) {
            encoderOptimizer.newEpoch();
        }
    }

    /** Signals the start of a new batch to every optimizer. */
    private void newBatch() {
        for (ParamsOptimizer<HANParameters> optimizer : this.classifierOptimizers) {
            optimizer.newBatch();
        }
        TokensEncoderOptimizer encoderOptimizer = this.tokensEncoderOptimizer;
        if (encoderOptimizer != null) {
            encoderOptimizer.newBatch();
        }
    }

    /** Signals the start of a new example to every optimizer. */
    private void newExample() {
        for (ParamsOptimizer<HANParameters> optimizer : this.classifierOptimizers) {
            optimizer.newExample();
        }
        TokensEncoderOptimizer encoderOptimizer = this.tokensEncoderOptimizer;
        if (encoderOptimizer != null) {
            encoderOptimizer.newExample();
        }
    }

    /** Updates the parameters through every optimizer. */
    private void update() {
        for (ParamsOptimizer<HANParameters> optimizer : this.classifierOptimizers) {
            optimizer.update();
        }
        TokensEncoderOptimizer encoderOptimizer = this.tokensEncoderOptimizer;
        if (encoderOptimizer != null) {
            encoderOptimizer.update();
        }
    }

    /**
     * Validates the classifier, prints the statistics and, if the f1 average strictly improves on the best
     * one seen so far, saves the model.
     *
     * @param validationSet the validation examples
     * @param modelFilename the file where to save the model, or null to never save
     * @throws IOException if the model file cannot be written
     */
    private void validateAndSaveModel(List<Example> validationSet, String modelFilename) throws IOException {
        Validator.ValidationInfo info = validateEpoch(validationSet);
        List<MetricCounter> metrics = info.getMetrics();

        List<Double> f1Scores = new ArrayList<>(metrics.size());
        for (MetricCounter metric : metrics) {
            f1Scores.add(Double.valueOf(metric.getF1Score()));
        }
        // NaN when there are no metrics (or any f1 score is NaN).
        double accuracy = CollectionsKt.averageOfDouble(f1Scores);

        System.out.println(String.format("Accuracy (f1 average): %5.2f %%", Double.valueOf(100 * accuracy)));

        int level = 0;
        for (MetricCounter metric : info.getMetrics()) {
            System.out.println("- Level " + level + ": " + metric);
            level++;
        }

        System.out.println((Object) "Level 0 confusion:");
        System.out.println(info.getConfusionMatrix());

        // Written as !(a > b) rather than (a <= b) so that a NaN accuracy never counts as an improvement
        // (otherwise it would be saved and would poison bestAccuracy for all the following epochs).
        if (modelFilename == null || !(accuracy > this.bestAccuracy)) {
            return;
        }

        this.bestAccuracy = accuracy;
        System.out.println("NEW BEST ACCURACY! Saving model to \"" + modelFilename + "\"...");

        // try-with-resources: the stream was previously never closed, leaking a file descriptor per save.
        try (FileOutputStream outputStream = new FileOutputStream(new File(modelFilename))) {
            if (this.saveClassifiersOnly) {
                this.classifier.getModel().getMultiLevelHAN().dump(outputStream);
            } else {
                this.classifier.getModel().dump(outputStream);
            }
        }
    }

    /**
     * Prints a header and validates the classifier on the given examples.
     *
     * @param validationSet the validation examples
     * @return the validation info returned by the validation helper
     */
    private Validator.ValidationInfo validateEpoch(List<Example> validationSet) {
        System.out.println(String.format("Epoch validation on %d sentences", Integer.valueOf(validationSet.size())));
        return this.validationHelper.validate(validationSet);
    }

    /** Records the current time as the start of the timing. */
    private void startTiming() {
        this.startTime = System.currentTimeMillis();
    }

    /** @return the time elapsed since {@link #startTiming()}, formatted in seconds and minutes */
    private String formatElapsedTime() {
        double elapsedSeconds = (System.currentTimeMillis() - this.startTime) / 1000.0d;
        return String.format("%.3f s (%.1f min)", Double.valueOf(elapsedSeconds), Double.valueOf(elapsedSeconds / 60.0d));
    }

    /**
     * @param model the model to train
     * @param tokensEncoderUpdateMethod the update method of the tokens encoder, or null to not train it
     * @param classifierUpdateMethod the update method of the level classifiers
     * @param useDropout whether to apply dropout (passed to the classifier and the tokens encoders pool)
     * @param saveClassifiersOnly whether to dump only the multi-level HAN instead of the whole model
     * @throws IllegalArgumentException if {@code model} or {@code classifierUpdateMethod} is null
     */
    public Trainer(@NotNull HANClassifierModel model, @Nullable UpdateMethod<?> tokensEncoderUpdateMethod,
                   @NotNull UpdateMethod<?> classifierUpdateMethod, boolean useDropout, boolean saveClassifiersOnly) {
        Intrinsics.checkParameterIsNotNull(model, "model");
        Intrinsics.checkParameterIsNotNull(classifierUpdateMethod, "classifierUpdateMethod");
        this.model = model;
        this.classifierUpdateMethod = classifierUpdateMethod;
        this.saveClassifiersOnly = saveClassifiersOnly;
        // NOTE(review): the third argument (true) presumably enables the propagation of input errors, which
        // trainLevelClassifier reads — confirm against HANClassifier.
        this.classifier = new HANClassifier(this.model, useDropout, true);
        this.validationHelper = new Validator(this.classifier.getModel());
        this.tokensEncodersPool = new TokensEncodersPool(this.model.getTokensEncoder(), useDropout);
        this.tokensEncoderOptimizer = tokensEncoderUpdateMethod != null
            ? this.model.getTokensEncoder().buildOptimizer(tokensEncoderUpdateMethod)
            : null;
        // Must be initialized before buildLevelOptimizer, which appends to it.
        this.classifierOptimizers = new ArrayList<>();
        this.topLevelOptimizer = buildLevelOptimizer(this.classifier.getModel().getTopLevelModel$hanclassifier());
    }

    /**
     * Kotlin default-arguments bridge for the constructor: mask bit 2 sets tokensEncoderUpdateMethod to
     * null, bit 16 sets saveClassifiersOnly to false.
     */
    public /* synthetic */ Trainer(HANClassifierModel model, UpdateMethod<?> tokensEncoderUpdateMethod, UpdateMethod<?> classifierUpdateMethod, boolean useDropout, boolean saveClassifiersOnly, int mask, DefaultConstructorMarker marker) {
        this(model, (mask & 2) != 0 ? null : tokensEncoderUpdateMethod, classifierUpdateMethod, useDropout, (mask & 16) != 0 ? false : saveClassifiersOnly);
    }
}
