package org.deeplearning4j.earlystopping.trainer;

import java.io.IOException;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.earlystopping.EarlyStoppingResult;
import org.deeplearning4j.earlystopping.listener.EarlyStoppingListener;
import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator;
import org.deeplearning4j.earlystopping.termination.EpochTerminationCondition;
import org.deeplearning4j.earlystopping.termination.IterationTerminationCondition;
import org.deeplearning4j.nn.api.Model;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/earlystopping/trainer/BaseEarlyStoppingTrainer.class */
public abstract class BaseEarlyStoppingTrainer<T extends Model> implements IEarlyStoppingTrainer<T> {
    private static Logger log = LoggerFactory.getLogger((Class<?>) BaseEarlyStoppingTrainer.class);
    protected T model;
    protected final EarlyStoppingConfiguration<T> esConfig;
    private final DataSetIterator train;
    private final MultiDataSetIterator trainMulti;
    private final Iterator<?> iterator;
    private EarlyStoppingListener<T> listener;
    private double bestModelScore = Double.MAX_VALUE;
    private int bestModelEpoch = -1;

    /* JADX INFO: Access modifiers changed from: protected */
    /* JADX WARN: Multi-variable type inference failed */
    public BaseEarlyStoppingTrainer(EarlyStoppingConfiguration<T> earlyStoppingConfiguration, T t, DataSetIterator dataSetIterator, MultiDataSetIterator multiDataSetIterator, EarlyStoppingListener<T> earlyStoppingListener) {
        this.esConfig = earlyStoppingConfiguration;
        this.model = t;
        this.train = dataSetIterator;
        this.trainMulti = multiDataSetIterator;
        this.iterator = dataSetIterator != null ? dataSetIterator : multiDataSetIterator;
        this.listener = earlyStoppingListener;
    }

    protected abstract void fit(DataSet dataSet);

    protected abstract void fit(MultiDataSet multiDataSet);

    @Override // org.deeplearning4j.earlystopping.trainer.IEarlyStoppingTrainer
    public EarlyStoppingResult<T> fit() {
        log.info("Starting early stopping training");
        if (this.esConfig.getScoreCalculator() == null) {
            log.warn("No score calculator provided for early stopping. Score will be reported as 0.0 to epoch termination conditions");
        }
        if (this.esConfig.getIterationTerminationConditions() != null) {
            Iterator<IterationTerminationCondition> it2 = this.esConfig.getIterationTerminationConditions().iterator();
            while (it2.hasNext()) {
                it2.next().initialize();
            }
        }
        if (this.esConfig.getEpochTerminationConditions() != null) {
            Iterator<EpochTerminationCondition> it3 = this.esConfig.getEpochTerminationConditions().iterator();
            while (it3.hasNext()) {
                it3.next().initialize();
            }
        }
        if (this.listener != null) {
            this.listener.onStart(this.esConfig, this.model);
        }
        LinkedHashMap linkedHashMap = new LinkedHashMap();
        int i = 0;
        while (true) {
            reset();
            boolean z = false;
            IterationTerminationCondition iterationTerminationCondition = null;
            int i2 = 0;
            while (this.iterator.hasNext()) {
                try {
                    if (this.train != null) {
                        fit((DataSet) this.iterator.next());
                    } else {
                        fit(this.trainMulti.next());
                    }
                    double score = this.model.score();
                    Iterator<IterationTerminationCondition> it4 = this.esConfig.getIterationTerminationConditions().iterator();
                    while (true) {
                        if (!it4.hasNext()) {
                            break;
                        }
                        IterationTerminationCondition next = it4.next();
                        if (next.terminate(score)) {
                            z = true;
                            iterationTerminationCondition = next;
                            break;
                        }
                    }
                    if (z) {
                        break;
                    }
                    i2++;
                } catch (Exception e) {
                    log.warn("Early stopping training terminated due to exception at epoch {}, iteration {}", Integer.valueOf(i), Integer.valueOf(i2), e);
                    try {
                        return new EarlyStoppingResult<>(EarlyStoppingResult.TerminationReason.Error, e.toString(), linkedHashMap, this.bestModelEpoch, this.bestModelScore, i, this.esConfig.getModelSaver().getBestModel());
                    } catch (IOException e2) {
                        throw new RuntimeException(e2);
                    }
                }
            }
            if (z) {
                log.info("Hit per iteration epoch termination condition at epoch {}, iteration {}. Reason: {}", Integer.valueOf(i), Integer.valueOf(i2), iterationTerminationCondition);
                if (this.esConfig.isSaveLastModel()) {
                    try {
                        this.esConfig.getModelSaver().saveLatestModel(this.model, CMAESOptimizer.DEFAULT_STOPFITNESS);
                    } catch (IOException e3) {
                        throw new RuntimeException("Error saving most recent model", e3);
                    }
                }
                try {
                    EarlyStoppingResult<T> earlyStoppingResult = new EarlyStoppingResult<>(EarlyStoppingResult.TerminationReason.IterationTerminationCondition, iterationTerminationCondition.toString(), linkedHashMap, this.bestModelEpoch, this.bestModelScore, i, this.esConfig.getModelSaver().getBestModel());
                    if (this.listener != null) {
                        this.listener.onCompletion(earlyStoppingResult);
                    }
                    return earlyStoppingResult;
                } catch (IOException e4) {
                    throw new RuntimeException(e4);
                }
            }
            log.info("Completed training epoch {}", Integer.valueOf(i));
            if ((i == 0 && this.esConfig.getEvaluateEveryNEpochs() == 1) || i % this.esConfig.getEvaluateEveryNEpochs() == 0) {
                ScoreCalculator<T> scoreCalculator = this.esConfig.getScoreCalculator();
                double calculateScore = scoreCalculator == null ? CMAESOptimizer.DEFAULT_STOPFITNESS : this.esConfig.getScoreCalculator().calculateScore(this.model);
                linkedHashMap.put(Integer.valueOf(i - 1), Double.valueOf(calculateScore));
                if (scoreCalculator != null && calculateScore < this.bestModelScore) {
                    if (this.bestModelEpoch == -1) {
                        log.info("Score at epoch {}: {}", Integer.valueOf(i), Double.valueOf(calculateScore));
                    } else {
                        log.info("New best model: score = {}, epoch = {} (previous: score = {}, epoch = {})", Double.valueOf(calculateScore), Integer.valueOf(i), Double.valueOf(this.bestModelScore), Integer.valueOf(this.bestModelEpoch));
                    }
                    this.bestModelScore = calculateScore;
                    this.bestModelEpoch = i;
                    try {
                        this.esConfig.getModelSaver().saveBestModel(this.model, calculateScore);
                    } catch (IOException e5) {
                        throw new RuntimeException("Error saving best model", e5);
                    }
                }
                if (this.esConfig.isSaveLastModel()) {
                    try {
                        this.esConfig.getModelSaver().saveLatestModel(this.model, calculateScore);
                    } catch (IOException e6) {
                        throw new RuntimeException("Error saving most recent model", e6);
                    }
                }
                if (this.listener != null) {
                    this.listener.onEpoch(i, calculateScore, this.esConfig, this.model);
                }
                boolean z2 = false;
                EpochTerminationCondition epochTerminationCondition = null;
                Iterator<EpochTerminationCondition> it5 = this.esConfig.getEpochTerminationConditions().iterator();
                while (true) {
                    if (!it5.hasNext()) {
                        break;
                    }
                    EpochTerminationCondition next2 = it5.next();
                    if (next2.terminate(i, calculateScore)) {
                        z2 = true;
                        epochTerminationCondition = next2;
                        break;
                    }
                }
                if (z2) {
                    log.info("Hit epoch termination condition at epoch {}. Details: {}", Integer.valueOf(i), epochTerminationCondition.toString());
                    try {
                        EarlyStoppingResult<T> earlyStoppingResult2 = new EarlyStoppingResult<>(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, epochTerminationCondition.toString(), linkedHashMap, this.bestModelEpoch, this.bestModelScore, i + 1, this.esConfig.getModelSaver().getBestModel());
                        if (this.listener != null) {
                            this.listener.onCompletion(earlyStoppingResult2);
                        }
                        return earlyStoppingResult2;
                    } catch (IOException e7) {
                        throw new RuntimeException(e7);
                    }
                }
            }
            i++;
        }
    }

    @Override // org.deeplearning4j.earlystopping.trainer.IEarlyStoppingTrainer
    public void setListener(EarlyStoppingListener<T> earlyStoppingListener) {
        this.listener = earlyStoppingListener;
    }

    protected void reset() {
        if (this.train != null) {
            this.train.reset();
        }
        if (this.trainMulti != null) {
            this.trainMulti.reset();
        }
    }
}
