package org.deeplearning4j.optimize.listeners;

import com.google.common.io.Files;
import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.Serializable;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.TimeUnit;
import lombok.NonNull;
import org.apache.commons.io.IOUtils;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.api.BaseTrainingListener;
import org.deeplearning4j.util.ModelSerializer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A training listener that saves model checkpoints to a root directory during training.
 *
 * <p>Checkpoints can be triggered every N epochs, every N iterations and/or every T time periods
 * (configured via {@link Builder}). Each model is written to
 * {@code checkpoint_<num>_<modelType>.zip} inside the root directory. A line describing each
 * checkpoint is appended to {@code checkpointInfo.txt}, whose first line is a header. After each
 * save, older model files may be deleted according to the configured {@link KeepMode}.
 */
public class CheckpointListener extends BaseTrainingListener implements Serializable {
    private static final Logger log = LoggerFactory.getLogger(CheckpointListener.class);

    /** Model type names used in checkpoint file names; see {@link #getModelType(Model)}. */
    private static final String[] MODEL_TYPES = {"MultiLayerNetwork", "ComputationGraph", "Model"};

    /** Name of the record file (inside the root directory) listing all saved checkpoints. */
    private static final String CHECKPOINT_RECORD_FILE_NAME = "checkpointInfo.txt";

    /** Prefix shared by all checkpoint model file names. */
    private static final String CHECKPOINT_FILE_PREFIX = "checkpoint_";

    private File rootDir;
    private KeepMode keepMode;
    private int keepLast;
    private int keepEvery;
    private boolean logSaving;
    private boolean deleteExisting;
    private Integer saveEveryNEpochs;
    private Integer saveEveryNIterations;
    private boolean saveEveryNIterSinceLast;
    private Long saveEveryAmount;
    private TimeUnit saveEveryUnit;
    private Long saveEveryMs;
    private boolean saveEverySinceLast;
    /** Number of the most recently saved checkpoint; -1 before the first save. */
    private int lastCheckpointNum = -1;
    private File checkpointRecordFile;
    private Checkpoint lastCheckpoint;
    /** Wall-clock time (ms) of the first iterationDone call; -1 until then. */
    private long startTime = -1L;
    /** Iteration number seen on the first iterationDone call; -1 until then. */
    private int startIter = -1;
    /** Time (ms) of the last time-triggered save when not using "since last" semantics. */
    private Long lastSaveEveryMsNoSinceLast;

    /**
     * Builder for {@link CheckpointListener}.
     *
     * <p>At least one saving trigger (epochs, iterations or time) must be configured before
     * calling {@link #build()}.
     */
    public static class Builder {
        private File rootDir;
        private KeepMode keepMode;
        private int keepLast;
        private int keepEvery;
        private boolean logSaving = true;
        private boolean deleteExisting = false;
        private Integer saveEveryNEpochs;
        private Integer saveEveryNIterations;
        private boolean saveEveryNIterSinceLast;
        private Long saveEveryAmount;
        private TimeUnit saveEveryUnit;
        private boolean saveEverySinceLast;

        /**
         * @param rootDir path of the directory to save checkpoints in
         * @throws NullPointerException if {@code rootDir} is null
         */
        public Builder(@NonNull String rootDir) {
            // The null check must happen inside the this(...) call: nothing may precede it.
            this(new File(requireRootDir(rootDir)));
        }

        /**
         * @param rootDir directory to save checkpoints in
         * @throws NullPointerException if {@code rootDir} is null
         */
        public Builder(@NonNull File rootDir) {
            if (rootDir == null) {
                throw new NullPointerException("rootDir is marked @NonNull but is null");
            }
            this.rootDir = rootDir;
        }

        /** Null check for the String constructor; returns its argument unchanged. */
        private static String requireRootDir(String rootDir) {
            if (rootDir == null) {
                throw new NullPointerException("rootDir is marked @NonNull but is null");
            }
            return rootDir;
        }

        /** Saves a checkpoint at the end of every epoch. */
        public Builder saveEveryEpoch() {
            return saveEveryNEpochs(1);
        }

        /** Saves a checkpoint at the end of every {@code n}-th epoch. */
        public Builder saveEveryNEpochs(int n) {
            this.saveEveryNEpochs = n;
            return this;
        }

        /** Saves a checkpoint whenever the iteration count is a positive multiple of {@code n}. */
        public Builder saveEveryNIterations(int n) {
            return saveEveryNIterations(n, false);
        }

        /**
         * @param n          iteration interval
         * @param sinceLast  if true, count iterations since the last saved checkpoint (of any
         *                   trigger); otherwise save when the iteration count is a positive
         *                   multiple of {@code n}
         */
        public Builder saveEveryNIterations(int n, boolean sinceLast) {
            this.saveEveryNIterations = n;
            this.saveEveryNIterSinceLast = sinceLast;
            return this;
        }

        /** Saves a checkpoint periodically, every {@code amount} of {@code timeUnit}. */
        public Builder saveEvery(long amount, TimeUnit timeUnit) {
            return saveEvery(amount, timeUnit, false);
        }

        /**
         * @param amount     time interval amount
         * @param timeUnit   unit of {@code amount}
         * @param sinceLast  if true, measure time since the last saved checkpoint (of any
         *                   trigger); otherwise measure since the last time-triggered save
         */
        public Builder saveEvery(long amount, TimeUnit timeUnit, boolean sinceLast) {
            this.saveEveryAmount = amount;
            this.saveEveryUnit = timeUnit;
            this.saveEverySinceLast = sinceLast;
            return this;
        }

        /** Keeps every checkpoint file (the default when no keep mode is set). */
        public Builder keepAll() {
            this.keepMode = KeepMode.ALL;
            return this;
        }

        /**
         * Keeps only the {@code n} most recent checkpoint files.
         *
         * @throws IllegalArgumentException if {@code n <= 0}
         */
        public Builder keepLast(int n) {
            if (n <= 0) {
                throw new IllegalArgumentException("Number of model files to keep should be > 0 (got: " + n + ")");
            }
            this.keepMode = KeepMode.LAST;
            this.keepLast = n;
            return this;
        }

        /**
         * Keeps the {@code nLast} most recent checkpoint files plus every {@code everyN}-th one.
         *
         * @throws IllegalArgumentException if either argument is {@code <= 0}
         */
        public Builder keepLastAndEvery(int nLast, int everyN) {
            if (nLast <= 0) {
                throw new IllegalArgumentException("Most recent number of model files to keep should be > 0 (got: " + nLast + ")");
            }
            if (everyN <= 0) {
                throw new IllegalArgumentException("Every n model files to keep should be > 0 (got: " + everyN + ")");
            }
            this.keepMode = KeepMode.LAST_AND_EVERY;
            this.keepLast = nLast;
            this.keepEvery = everyN;
            return this;
        }

        /** Whether to log (at INFO) each checkpoint save. Default: true. */
        public Builder logSaving(boolean logSaving) {
            this.logSaving = logSaving;
            return this;
        }

        /**
         * Whether to delete existing checkpoint files in the root directory when the listener is
         * built. If false (the default) and checkpoints exist, building fails.
         */
        public Builder deleteExisting(boolean deleteExisting) {
            this.deleteExisting = deleteExisting;
            return this;
        }

        /**
         * @return the configured listener
         * @throws IllegalStateException if no saving trigger is configured, or if existing
         *                               checkpoints are found and {@code deleteExisting} is false
         */
        public CheckpointListener build() {
            if (this.saveEveryNEpochs == null && this.saveEveryAmount == null && this.saveEveryNIterations == null) {
                throw new IllegalStateException("Cannot construct listener: no models will be saved (must use at least one of: save every N epochs, every N iterations, or every T time periods)");
            }
            return new CheckpointListener(this);
        }
    }

    /** Policy for which checkpoint model files are retained after each save. */
    public enum KeepMode {
        ALL,
        LAST,
        LAST_AND_EVERY
    }

    private CheckpointListener(Builder builder) {
        this.rootDir = builder.rootDir;
        this.keepMode = builder.keepMode;
        this.keepLast = builder.keepLast;
        this.keepEvery = builder.keepEvery;
        this.logSaving = builder.logSaving;
        this.deleteExisting = builder.deleteExisting;
        this.saveEveryNEpochs = builder.saveEveryNEpochs;
        this.saveEveryNIterations = builder.saveEveryNIterations;
        this.saveEveryNIterSinceLast = builder.saveEveryNIterSinceLast;
        this.saveEveryAmount = builder.saveEveryAmount;
        this.saveEveryUnit = builder.saveEveryUnit;
        this.saveEverySinceLast = builder.saveEverySinceLast;
        if (this.saveEveryAmount != null) {
            this.saveEveryMs = TimeUnit.MILLISECONDS.convert(this.saveEveryAmount, this.saveEveryUnit);
        }

        this.checkpointRecordFile = new File(this.rootDir, CHECKPOINT_RECORD_FILE_NAME);
        if (!this.checkpointRecordFile.exists() || this.checkpointRecordFile.length() <= 0) {
            return;
        }
        if (!this.deleteExisting) {
            throw new IllegalStateException("Detected existing checkpoint files at directory " + this.rootDir.getAbsolutePath() + ". Use deleteExisting(true) to delete existing checkpoint files when present.");
        }
        this.checkpointRecordFile.delete();
        File[] existing = this.rootDir.listFiles();
        if (existing == null) {
            return;
        }
        for (File file : existing) {
            if (isCheckpointModelFile(file.getName())) {
                file.delete();
            }
        }
    }

    /**
     * Returns true if {@code name} looks like a file produced by {@link #getFileName(int, String)}
     * for any entry of {@link #MODEL_TYPES}.
     */
    private static boolean isCheckpointModelFile(String name) {
        if (!name.startsWith(CHECKPOINT_FILE_PREFIX)) {
            return false;
        }
        for (String type : MODEL_TYPES) {
            if (name.endsWith("_" + type + ".zip")) {
                return true;
            }
        }
        return false;
    }

    /** Saves a checkpoint when the (1-based) number of completed epochs is a multiple of N. */
    @Override
    public void onEpochEnd(Model model) {
        // getEpoch() is read before the model's own epoch counter is advanced; +1 yields the count
        // of completed epochs. NOTE(review): assumes DL4J increments the epoch count after this call.
        int completedEpochs = getEpoch(model) + 1;
        if (this.saveEveryNEpochs == null || completedEpochs <= 0 || completedEpochs % this.saveEveryNEpochs != 0) {
            return;
        }
        saveCheckpoint(model);
    }

    /**
     * Checks the iteration-based and then the time-based saving conditions. At most one
     * checkpoint is saved per call when an iteration-based save occurs. The first call only
     * records the starting time and iteration.
     */
    @Override
    public void iterationDone(Model model, int iteration, int epoch) {
        if (this.startTime < 0) {
            this.startTime = System.currentTimeMillis();
            this.startIter = iteration;
            return;
        }

        if (this.saveEveryNIterations != null) {
            if (this.saveEveryNIterSinceLast) {
                // Count iterations since the last checkpoint of any kind (or since start).
                int lastSaveIter = this.lastCheckpoint != null ? this.lastCheckpoint.getIteration() : this.startIter;
                if (iteration - lastSaveIter >= this.saveEveryNIterations) {
                    saveCheckpoint(model);
                    return;
                }
            } else if (iteration > 0 && iteration % this.saveEveryNIterations == 0) {
                saveCheckpoint(model);
                return;
            }
        }

        if (this.saveEveryUnit != null) {
            long now = System.currentTimeMillis();
            if (this.saveEverySinceLast) {
                // Measure time since the last checkpoint of any kind (or since start).
                long lastSaveTime = this.lastCheckpoint != null ? this.lastCheckpoint.getTimestamp() : this.startTime;
                if (now - lastSaveTime >= this.saveEveryMs) {
                    saveCheckpoint(model);
                }
            } else {
                // Periodic: measure only from the last time-triggered save (or since start).
                long lastSaveTime = this.lastSaveEveryMsNoSinceLast != null ? this.lastSaveEveryMsNoSinceLast : this.startTime;
                if (now - lastSaveTime > this.saveEveryMs) {
                    saveCheckpoint(model);
                    this.lastSaveEveryMsNoSinceLast = now;
                }
            }
        }
    }

    /** Saves a checkpoint, wrapping any failure in a RuntimeException. */
    private void saveCheckpoint(Model model) {
        try {
            saveCheckpointHelper(model);
        } catch (Exception e) {
            throw new RuntimeException("Error saving checkpoint", e);
        }
    }

    /**
     * Writes the model file, appends its record line, then applies the keep policy.
     */
    private void saveCheckpointHelper(Model model) throws Exception {
        if (!this.rootDir.exists()) {
            // Otherwise creating the record file below fails with an IOException.
            this.rootDir.mkdirs();
        }
        // Also write the header when the file exists but is empty (e.g. a leftover empty file the
        // constructor accepted); otherwise the first record would be mistaken for the header.
        if (!this.checkpointRecordFile.exists() || this.checkpointRecordFile.length() == 0) {
            write(Checkpoint.getFileHeader() + "\n", this.checkpointRecordFile);
        }

        this.lastCheckpointNum++;
        Checkpoint checkpoint = new Checkpoint(this.lastCheckpointNum, System.currentTimeMillis(), getIter(model), getEpoch(model), getModelType(model), null);
        setFileName(checkpoint);
        File modelFile = new File(this.rootDir, checkpoint.getFilename());
        ModelSerializer.writeModel(model, modelFile, true);
        write(checkpoint.toFileString() + "\n", this.checkpointRecordFile);

        if (this.logSaving) {
            log.info("Model checkpoint saved: epoch {}, iteration {}, path: {}", checkpoint.getEpoch(), checkpoint.getIteration(), modelFile.getPath());
        }
        this.lastCheckpoint = checkpoint;

        if (this.keepMode == null || this.keepMode == KeepMode.ALL) {
            return;
        }

        List<Checkpoint> available = availableCheckpoints();
        if (this.keepMode == KeepMode.LAST) {
            // Records are in save order: delete the oldest files beyond the last N.
            int excess = available.size() - this.keepLast;
            for (int k = 0; k < excess; k++) {
                getFileForCheckpoint(available.get(k)).delete();
            }
            return;
        }

        // LAST_AND_EVERY: keep every keepEvery-th checkpoint (1-based count, excluding #0) and the
        // most recent keepLast ones; delete everything else.
        for (Checkpoint c : available) {
            int num = c.getCheckpointNum();
            boolean keptAsEvery = num > 0 && (num + 1) % this.keepEvery == 0;
            boolean keptAsRecent = num > this.lastCheckpointNum - this.keepLast;
            if (!keptAsEvery && !keptAsRecent) {
                getFileForCheckpoint(c).delete();
            }
        }
    }

    private static void setFileName(Checkpoint checkpoint) {
        checkpoint.setFilename(getFileName(checkpoint.getCheckpointNum(), checkpoint.getModelType()));
    }

    /** Returns the model file name for checkpoint {@code num} of the given model type. */
    private static String getFileName(int num, String modelType) {
        return CHECKPOINT_FILE_PREFIX + num + "_" + modelType + ".zip";
    }

    /**
     * Appends {@code str} to {@code file} using the platform default charset, creating the file if
     * needed.
     *
     * @return {@code str}
     * @throws RuntimeException wrapping any IOException
     */
    private static String write(String str, File file) {
        try {
            if (!file.exists()) {
                file.createNewFile();
            }
            Files.append(str, file, Charset.defaultCharset());
            return str;
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Returns the iteration count stored in the model's configuration. */
    protected static int getIter(Model model) {
        if (model instanceof MultiLayerNetwork) {
            return ((MultiLayerNetwork) model).getLayerWiseConfigurations().getIterationCount();
        }
        if (model instanceof ComputationGraph) {
            return ((ComputationGraph) model).getConfiguration().getIterationCount();
        }
        return model.conf().getIterationCount();
    }

    /** Returns the epoch count stored in the model's configuration. */
    protected static int getEpoch(Model model) {
        if (model instanceof MultiLayerNetwork) {
            return ((MultiLayerNetwork) model).getLayerWiseConfigurations().getEpochCount();
        }
        if (model instanceof ComputationGraph) {
            return ((ComputationGraph) model).getConfiguration().getEpochCount();
        }
        return model.conf().getEpochCount();
    }

    /**
     * Returns the model type name used in file names. Matches exact classes only, so subclasses of
     * MultiLayerNetwork/ComputationGraph are reported as "Model".
     */
    protected static String getModelType(Model model) {
        if (model.getClass() == MultiLayerNetwork.class) {
            return "MultiLayerNetwork";
        }
        if (model.getClass() == ComputationGraph.class) {
            return "ComputationGraph";
        }
        return "Model";
    }

    /**
     * Lists the recorded checkpoints whose model files still exist, in record order.
     *
     * @return the available checkpoints; empty if no record file exists
     * @throws RuntimeException if the record file cannot be read
     */
    public List<Checkpoint> availableCheckpoints() {
        if (!this.checkpointRecordFile.exists()) {
            return Collections.emptyList();
        }
        try (BufferedInputStream in = new BufferedInputStream(new FileInputStream(this.checkpointRecordFile))) {
            // Same charset as write() uses when appending records.
            List<String> lines = IOUtils.readLines(in, Charset.defaultCharset());
            // An empty record file has no header line; avoid a negative initial capacity.
            List<Checkpoint> result = new ArrayList<>(Math.max(0, lines.size() - 1));
            for (int lineIdx = 1; lineIdx < lines.size(); lineIdx++) {
                String line = lines.get(lineIdx);
                if (line.trim().isEmpty()) {
                    continue;
                }
                Checkpoint checkpoint = Checkpoint.fromFileString(line);
                if (new File(this.rootDir, checkpoint.getFilename()).exists()) {
                    result.add(checkpoint);
                }
            }
            return result;
        } catch (IOException e) {
            throw new RuntimeException("Error loading checkpoint data from file: " + this.checkpointRecordFile.getAbsolutePath(), e);
        }
    }

    /** Returns the most recent available checkpoint, or null if there is none. */
    public Checkpoint lastCheckpoint() {
        List<Checkpoint> available = availableCheckpoints();
        if (available.isEmpty()) {
            return null;
        }
        return available.get(available.size() - 1);
    }

    /** Returns the model file for the given checkpoint. */
    public File getFileForCheckpoint(Checkpoint checkpoint) {
        return getFileForCheckpoint(checkpoint.getCheckpointNum());
    }

    /**
     * Returns the model file for checkpoint {@code checkpointNum}, trying each model type name.
     *
     * @throws IllegalArgumentException if {@code checkpointNum < 0}
     * @throws IllegalStateException    if no model file exists for that checkpoint
     */
    public File getFileForCheckpoint(int checkpointNum) {
        if (checkpointNum < 0) {
            throw new IllegalArgumentException("Invalid checkpoint number: " + checkpointNum);
        }
        for (String type : MODEL_TYPES) {
            File candidate = new File(this.rootDir, getFileName(checkpointNum, type));
            if (candidate.exists()) {
                return candidate;
            }
        }
        throw new IllegalStateException("Model file for checkpoint " + checkpointNum + " does not exist");
    }

    /** Loads the given checkpoint as a MultiLayerNetwork (including updater state). */
    public MultiLayerNetwork loadCheckpointMLN(Checkpoint checkpoint) {
        return loadCheckpointMLN(checkpoint.getCheckpointNum());
    }

    /** Loads checkpoint {@code checkpointNum} as a MultiLayerNetwork (including updater state). */
    public MultiLayerNetwork loadCheckpointMLN(int checkpointNum) {
        try {
            return ModelSerializer.restoreMultiLayerNetwork(getFileForCheckpoint(checkpointNum), true);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Loads the given checkpoint as a ComputationGraph (including updater state). */
    public ComputationGraph loadCheckpointCG(Checkpoint checkpoint) {
        return loadCheckpointCG(checkpoint.getCheckpointNum());
    }

    /** Loads checkpoint {@code checkpointNum} as a ComputationGraph (including updater state). */
    public ComputationGraph loadCheckpointCG(int checkpointNum) {
        try {
            return ModelSerializer.restoreComputationGraph(getFileForCheckpoint(checkpointNum), true);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }
}
