package org.deeplearning4j.perf.listener;

import java.io.File;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import lombok.NonNull;
import org.apache.commons.io.FileUtils;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import oshi.json.SystemInfo;

/* loaded from: input_file:org/deeplearning4j/perf/listener/SystemInfoFilePrintListener.class */
/**
 * Training listener that appends a pretty-printed JSON snapshot of the host's
 * system information (as produced by OSHI's {@code SystemInfo}) to a target
 * file whenever one of the enabled training events fires.
 *
 * <p>Each event type (epoch start/end, forward pass, backward pass, gradient
 * calculation) is individually switched on or off via a boolean flag. Instances
 * can be created directly through the constructor or via {@link #builder()}.
 */
public class SystemInfoFilePrintListener implements TrainingListener {
    private static final Logger log = LoggerFactory.getLogger(SystemInfoFilePrintListener.class);

    /** Write a snapshot when an epoch starts. */
    private final boolean printOnEpochStart;
    /** Write a snapshot when an epoch ends. */
    private final boolean printOnEpochEnd;
    /** Write a snapshot on every forward pass (both overloads). */
    private final boolean printOnForwardPass;
    /** Write a snapshot on every backward pass. */
    private final boolean printOnBackwardPass;
    /** Write a snapshot on every gradient calculation. */
    private final boolean printOnGradientCalculation;
    /** File that snapshots are appended to; never null after construction. */
    private final File printFileTarget;

    /**
     * Fluent builder for {@link SystemInfoFilePrintListener}. All flags default
     * to {@code false}; {@code printFileTarget} must be set before calling
     * {@link #build()}, otherwise the listener constructor throws.
     */
    public static class SystemInfoFilePrintListenerBuilder {
        private boolean printOnEpochStart;
        private boolean printOnEpochEnd;
        private boolean printOnForwardPass;
        private boolean printOnBackwardPass;
        private boolean printOnGradientCalculation;
        private File printFileTarget;

        SystemInfoFilePrintListenerBuilder() {
        }

        /** Enables or disables writing on epoch start. */
        public SystemInfoFilePrintListenerBuilder printOnEpochStart(boolean z) {
            this.printOnEpochStart = z;
            return this;
        }

        /** Enables or disables writing on epoch end. */
        public SystemInfoFilePrintListenerBuilder printOnEpochEnd(boolean z) {
            this.printOnEpochEnd = z;
            return this;
        }

        /** Enables or disables writing on forward pass. */
        public SystemInfoFilePrintListenerBuilder printOnForwardPass(boolean z) {
            this.printOnForwardPass = z;
            return this;
        }

        /** Enables or disables writing on backward pass. */
        public SystemInfoFilePrintListenerBuilder printOnBackwardPass(boolean z) {
            this.printOnBackwardPass = z;
            return this;
        }

        /** Enables or disables writing on gradient calculation. */
        public SystemInfoFilePrintListenerBuilder printOnGradientCalculation(boolean z) {
            this.printOnGradientCalculation = z;
            return this;
        }

        /** Sets the file that system info snapshots are appended to. */
        public SystemInfoFilePrintListenerBuilder printFileTarget(File file) {
            this.printFileTarget = file;
            return this;
        }

        /**
         * Creates the listener from the values collected so far.
         *
         * @return a new listener
         * @throws NullPointerException if no target file has been set
         */
        public SystemInfoFilePrintListener build() {
            return new SystemInfoFilePrintListener(this.printOnEpochStart, this.printOnEpochEnd, this.printOnForwardPass, this.printOnBackwardPass, this.printOnGradientCalculation, this.printFileTarget);
        }

        public String toString() {
            return "SystemInfoFilePrintListener.SystemInfoFilePrintListenerBuilder(printOnEpochStart=" + this.printOnEpochStart + ", printOnEpochEnd=" + this.printOnEpochEnd + ", printOnForwardPass=" + this.printOnForwardPass + ", printOnBackwardPass=" + this.printOnBackwardPass + ", printOnGradientCalculation=" + this.printOnGradientCalculation + ", printFileTarget=" + this.printFileTarget + ")";
        }
    }

    /**
     * Creates a listener with the given event switches and output file.
     *
     * @param z    write on epoch start
     * @param z2   write on epoch end
     * @param z3   write on forward pass
     * @param z4   write on backward pass
     * @param z5   write on gradient calculation
     * @param file file to append system info to; must not be null
     * @throws NullPointerException if {@code file} is null
     */
    public SystemInfoFilePrintListener(boolean z, boolean z2, boolean z3, boolean z4, boolean z5, @NonNull File file) {
        if (file == null) {
            throw new NullPointerException("printFileTarget is marked @NonNull but is null");
        }
        this.printOnEpochStart = z;
        this.printOnEpochEnd = z2;
        this.printOnForwardPass = z3;
        this.printOnBackwardPass = z4;
        this.printOnGradientCalculation = z5;
        this.printFileTarget = file;
    }

    /** No-op: this listener does not react to individual iterations. */
    public void iterationDone(Model model, int i, int i2) {
    }

    /** Writes a snapshot at epoch start when {@code printOnEpochStart} is set. */
    public void onEpochStart(Model model) {
        // Previously logged "epoch end" here; the labels were swapped.
        writeIfEnabled(this.printOnEpochStart, "epoch begin");
    }

    /** Writes a snapshot at epoch end when {@code printOnEpochEnd} is set. */
    public void onEpochEnd(Model model) {
        // Previously logged "epoch begin" here; the labels were swapped.
        writeIfEnabled(this.printOnEpochEnd, "epoch end");
    }

    /** Writes a snapshot on forward pass when {@code printOnForwardPass} is set. */
    public void onForwardPass(Model model, List<INDArray> list) {
        // Previously gated on printOnBackwardPass, so this overload ignored
        // the forward-pass switch entirely.
        writeIfEnabled(this.printOnForwardPass, "forward pass");
    }

    /** Writes a snapshot on forward pass when {@code printOnForwardPass} is set. */
    public void onForwardPass(Model model, Map<String, INDArray> map) {
        writeIfEnabled(this.printOnForwardPass, "forward pass");
    }

    /** Writes a snapshot on gradient calculation when {@code printOnGradientCalculation} is set. */
    public void onGradientCalculation(Model model) {
        writeIfEnabled(this.printOnGradientCalculation, "gradient calculation");
    }

    /** Writes a snapshot on backward pass when {@code printOnBackwardPass} is set. */
    public void onBackwardPass(Model model) {
        writeIfEnabled(this.printOnBackwardPass, "backward pass");
    }

    /**
     * Shared guard for all event callbacks: writes only when the event's switch
     * is on and a target file is present.
     *
     * @param enabled whether the calling event is switched on
     * @param event   human-readable event name used in the log message
     */
    private void writeIfEnabled(boolean enabled, String event) {
        if (enabled && this.printFileTarget != null) {
            writeFileWithMessage(event);
        }
    }

    /**
     * Collects current system info and appends it as pretty JSON to the target
     * file. I/O failures are logged rather than propagated so that a failed
     * write does not interrupt training.
     *
     * @param str name of the training event that triggered the write
     */
    private void writeFileWithMessage(String str) {
        if (this.printFileTarget == null) {
            log.warn("File not specified for writing!");
            // Bail out: continuing would dereference the null target below.
            return;
        }
        SystemInfo systemInfo = new SystemInfo();
        log.info("Writing system info to file on {}: {}", str, this.printFileTarget.getAbsolutePath());
        try {
            // Third argument 'true' appends instead of overwriting.
            FileUtils.write(this.printFileTarget, systemInfo.toPrettyJSON(), true);
        } catch (IOException e) {
            log.error("Error writing file for system info", e);
        }
    }

    /**
     * @return a new builder with all switches off and no target file set
     */
    public static SystemInfoFilePrintListenerBuilder builder() {
        return new SystemInfoFilePrintListenerBuilder();
    }
}
