package org.deeplearning4j.rl4j.util;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.beans.ConstructorProperties;
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.nio.file.StandardOpenOption;
import java.util.HashSet;
import java.util.Objects;
import java.util.Set;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;
import java.util.zip.ZipOutputStream;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.network.dqn.DQN;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.util.ModelSerializer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/rl4j/util/DataManager.class */
public class DataManager {
    /** Shared mapper for the static (de)serialization helpers, instead of a new mapper per call. */
    private static final ObjectMapper JSON_MAPPER = new ObjectMapper();
    /** Logger used by the static save/load helpers (same name the original load used). */
    private static final Logger SERIALIZER_LOG = LoggerFactory.getLogger("Serializer");

    /** Zip entry holding the learning configuration as JSON. */
    private static final String CONFIGURATION_ENTRY = "configuration.json";
    /** Zip entry holding the serialized neural network. */
    private static final String DQN_ENTRY = "dqn.bin";
    /** Zip entry written only when the learning has a history processor. */
    private static final String HISTORY_PROCESSOR_ENTRY = "hpconf.bin";

    private final Logger log = LoggerFactory.getLogger("DataManager");
    private final String home = System.getProperty("user.home");
    private final ObjectMapper mapper = new ObjectMapper();
    private String dataRoot = this.home + "/" + Constants.DATA_DIR;
    private boolean saveData;
    private String currentDir;

    /**
     * Snapshot of a training run, serialized as JSON into the info file by
     * {@link DataManager#writeInfo(ILearning)}.
     */
    public static final class Info {
        private final String trainingName;
        private final String mdpName;
        private final ILearning.LConfiguration conf;
        private final int stepCounter;
        private final long millisTime;

        /**
         * @param trainingName simple class name of the learning algorithm
         * @param mdpName      simple class name of the MDP
         * @param conf         learning configuration
         * @param stepCounter  step counter at the time of the snapshot
         * @param millisTime   wall-clock time of the snapshot, in milliseconds
         */
        @ConstructorProperties({"trainingName", "mdpName", "conf", "stepCounter", "millisTime"})
        public Info(String trainingName, String mdpName, ILearning.LConfiguration conf, int stepCounter,
                        long millisTime) {
            this.trainingName = trainingName;
            this.mdpName = mdpName;
            this.conf = conf;
            this.stepCounter = stepCounter;
            this.millisTime = millisTime;
        }

        public String getTrainingName() {
            return this.trainingName;
        }

        public String getMdpName() {
            return this.mdpName;
        }

        public ILearning.LConfiguration getConf() {
            return this.conf;
        }

        public int getStepCounter() {
            return this.stepCounter;
        }

        public long getMillisTime() {
            return this.millisTime;
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof Info)) {
                return false;
            }
            Info other = (Info) obj;
            return stepCounter == other.stepCounter
                            && millisTime == other.millisTime
                            && Objects.equals(trainingName, other.trainingName)
                            && Objects.equals(mdpName, other.mdpName)
                            && Objects.equals(conf, other.conf);
        }

        @Override
        public int hashCode() {
            return Objects.hash(trainingName, mdpName, conf, stepCounter, millisTime);
        }

        @Override
        public String toString() {
            return "DataManager.Info(trainingName=" + getTrainingName() + ", mdpName=" + getMdpName() + ", conf="
                            + getConf() + ", stepCounter=" + getStepCounter() + ", millisTime=" + getMillisTime()
                            + ")";
        }
    }

    /** One statistics record; each is appended as a JSON line by {@link DataManager#appendStat(StatEntry)}. */
    public interface StatEntry {
        int getEpochCounter();

        int getStepCounter();

        double getReward();
    }

    /** Creates a manager rooted at the default data directory with saving disabled. */
    public DataManager() {
        create(this.dataRoot, false);
    }

    /**
     * Creates a manager rooted at the default data directory.
     *
     * @param z whether training data should be written to disk
     */
    public DataManager(boolean z) {
        create(this.dataRoot, z);
    }

    /**
     * Creates a manager rooted at a custom data directory.
     *
     * @param str root directory under which numbered run directories are created
     * @param z   whether training data should be written to disk
     */
    public DataManager(String str, boolean z) {
        create(str, z);
    }

    /**
     * Saves a learning (configuration + network) as a zip archive at the given path.
     * Failures are logged, not thrown (best-effort, as before).
     *
     * @param str      destination file path
     * @param learning the learning to save
     */
    public static void save(String str, Learning learning) {
        try (OutputStream out = new BufferedOutputStream(new FileOutputStream(str))) {
            save(out, learning);
        } catch (IOException e) {
            SERIALIZER_LOG.error("Failed to save learning to " + str, e);
        }
    }

    /**
     * Writes a learning as a zip archive to the stream and closes the stream.
     *
     * <p>The archive contains {@code configuration.json} (UTF-8 JSON of the learning configuration),
     * {@code dqn.bin} (the serialized network) and, when a history processor is present, {@code hpconf.bin}.
     * Failures are logged, not thrown.
     *
     * @param outputStream destination stream; it is closed by this method
     * @param learning     the learning to save
     */
    public static void save(OutputStream outputStream, Learning learning) {
        try (ZipOutputStream zip = new ZipOutputStream(outputStream)) {
            zip.putNextEntry(new ZipEntry(CONFIGURATION_ENTRY));
            String json = JSON_MAPPER.writeValueAsString(learning.getConfiguration());
            zip.write(json.getBytes(StandardCharsets.UTF_8));

            zip.putNextEntry(new ZipEntry(DQN_ENTRY));
            zip.write(serializeNetwork(learning));

            if (learning.getHistoryProcessor() != null) {
                // NOTE(review): despite its name this entry stores the network bytes again, not the
                // history-processor configuration. The on-disk format is kept unchanged; load() never
                // reads this entry. Confirm the intended content before changing it.
                zip.putNextEntry(new ZipEntry(HISTORY_PROCESSOR_ENTRY));
                zip.write(serializeNetwork(learning));
            }
        } catch (IOException e) {
            SERIALIZER_LOG.error("Failed to save learning", e);
        }
    }

    /** Serializes the learning's neural network into an in-memory byte array. */
    private static byte[] serializeNetwork(Learning learning) throws IOException {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        learning.getNeuralNet().save(buffer);
        return buffer.toByteArray();
    }

    /**
     * Loads a learning archive produced by {@link #save(OutputStream, Learning)}.
     *
     * <p>On I/O failure (including a missing entry) the error is logged and whatever was loaded so far
     * is returned; the missing parts are {@code null}.
     *
     * @param file archive to read
     * @param cls  class to deserialize {@code configuration.json} into
     * @return pair of (network, configuration); either element may be {@code null} on failure
     */
    public static <C> Pair<IDQN, C> load(File file, Class<C> cls) {
        SERIALIZER_LOG.info("Deserializing: {}", file.getName());
        C conf = null;
        IDQN dqn = null;
        try (ZipFile zipFile = new ZipFile(file)) {
            try (InputStream configStream = zipFile.getInputStream(requireEntry(zipFile, CONFIGURATION_ENTRY))) {
                conf = JSON_MAPPER.readValue(configStream, cls);
            }
            dqn = restoreDqn(zipFile);
        } catch (IOException e) {
            SERIALIZER_LOG.error("Failed to load learning from " + file.getAbsolutePath(), e);
        }
        return new Pair<IDQN, C>(dqn, conf);
    }

    /**
     * Restores the network stored in the {@code dqn.bin} entry. The entry is copied to a temporary file
     * because {@link ModelSerializer#restoreMultiLayerNetwork(File)} reads from a file.
     * The temporary file is removed afterwards.
     */
    private static IDQN restoreDqn(ZipFile zipFile) throws IOException {
        File tmp = File.createTempFile("restore", "dqn");
        try {
            try (InputStream in = zipFile.getInputStream(requireEntry(zipFile, DQN_ENTRY))) {
                Files.copy(in, tmp.toPath(), StandardCopyOption.REPLACE_EXISTING);
            }
            return new DQN(ModelSerializer.restoreMultiLayerNetwork(tmp));
        } finally {
            deleteTempFile(tmp);
        }
    }

    /** Returns the named zip entry, or throws an {@link IOException} naming the missing entry. */
    private static ZipEntry requireEntry(ZipFile zipFile, String name) throws IOException {
        ZipEntry entry = zipFile.getEntry(name);
        if (entry == null) {
            throw new IOException("Missing entry '" + name + "' in " + zipFile.getName());
        }
        return entry;
    }

    /** Deletes a temp file now, or at JVM exit if it cannot be deleted immediately. */
    private static void deleteTempFile(File file) {
        if (!file.delete()) {
            file.deleteOnExit();
        }
    }

    /**
     * Loads a learning archive from the file at the given path.
     *
     * @see #load(File, Class)
     */
    public static <C> Pair<IDQN, C> load(String str, Class<C> cls) {
        return load(new File(str), cls);
    }

    /**
     * Loads a learning archive from a stream by buffering it to a temporary file first.
     *
     * @return pair of (network, configuration), or {@code null} if the stream could not be buffered
     * @see #load(File, Class)
     */
    public static <C> Pair<IDQN, C> load(InputStream inputStream, Class<C> cls) {
        File tmp = null;
        try {
            tmp = File.createTempFile("restore", "learning");
            Files.copy(inputStream, tmp.toPath(), StandardCopyOption.REPLACE_EXISTING);
            return load(tmp, cls);
        } catch (IOException e) {
            SERIALIZER_LOG.error("Failed to buffer learning stream", e);
            return null;
        } finally {
            // load(File, Class) has closed the archive by now, so the buffer can go.
            if (tmp != null) {
                deleteTempFile(tmp);
            }
        }
    }

    private void create(String str, boolean z) {
        this.saveData = z;
        this.dataRoot = str;
        createSubdir();
    }

    /**
     * Creates the next numbered run directory (1, 2, ...) under the data root, along with its video
     * and model sub-directories and empty stat/info files. Does nothing when saving is disabled.
     *
     * @return absolute path of the created run directory, or {@code ""} when saving is disabled
     * @throws IllegalStateException if the data root cannot be created or listed
     */
    public String createSubdir() {
        if (!this.saveData) {
            return "";
        }
        File root = new File(this.dataRoot);
        root.mkdirs();
        File[] children = root.listFiles();
        if (children == null) {
            // listFiles() returns null when the path is not a directory or cannot be read.
            throw new IllegalStateException("Unable to create or list data directory: " + root.getAbsolutePath());
        }

        // List the root once and pick the smallest positive integer not already used as a name.
        Set<String> taken = new HashSet<String>();
        for (File child : children) {
            taken.add(child.getName());
        }
        int index = 1;
        while (taken.contains(Integer.toString(index))) {
            index++;
        }

        File runDir = new File(this.dataRoot + "/" + index);
        runDir.mkdirs();
        this.currentDir = runDir.getAbsolutePath();
        this.log.info("Created training data directory: {}", this.currentDir);

        new File(getVideoDir()).mkdirs();
        new File(getModelDir()).mkdirs();
        try {
            new File(getStat()).createNewFile();
            new File(getInfo()).createNewFile();
        } catch (IOException e) {
            this.log.error("Unable to create data files in " + this.currentDir, e);
        }
        return this.currentDir;
    }

    public String getVideoDir() {
        return this.currentDir + "/" + Constants.VIDEO_DIR;
    }

    public String getModelDir() {
        return this.currentDir + "/" + Constants.MODEL_DIR;
    }

    public String getInfo() {
        return this.currentDir + "/" + Constants.INFO_FILENAME;
    }

    public String getStat() {
        return this.currentDir + "/" + Constants.STATISTIC_FILENAME;
    }

    /**
     * Appends one statistics record as a UTF-8 JSON line to the stat file (when saving is enabled).
     *
     * @param statEntry record to append
     */
    public void appendStat(StatEntry statEntry) {
        if (!this.saveData) {
            return;
        }
        try {
            Files.write(Paths.get(getStat()), toJson(statEntry).getBytes(StandardCharsets.UTF_8),
                            StandardOpenOption.CREATE, StandardOpenOption.APPEND);
        } catch (IOException e) {
            this.log.error("Unable to append statistics to " + getStat(), e);
        }
    }

    /** Serializes {@code obj} to JSON followed by a newline; returns {@code ""} if serialization fails. */
    private String toJson(Object obj) {
        try {
            return this.mapper.writeValueAsString(obj) + "\n";
        } catch (JsonProcessingException e) {
            this.log.error("Unable to serialize " + obj + " to JSON", e);
            return "";
        }
    }

    /**
     * Overwrites the info file with a UTF-8 JSON {@link Info} snapshot of the learning (when saving is
     * enabled).
     *
     * @param iLearning learning whose metadata is recorded
     */
    public void writeInfo(ILearning iLearning) {
        if (!this.saveData) {
            return;
        }
        Info info = new Info(iLearning.getClass().getSimpleName(), iLearning.getMdp().getClass().getSimpleName(),
                        iLearning.getConfiguration(), iLearning.getStepCounter(), System.currentTimeMillis());
        try {
            Files.write(Paths.get(getInfo()), toJson(info).getBytes(StandardCharsets.UTF_8),
                            StandardOpenOption.CREATE, StandardOpenOption.WRITE,
                            StandardOpenOption.TRUNCATE_EXISTING);
        } catch (IOException e) {
            this.log.error("Unable to write info to " + getInfo(), e);
        }
    }

    /**
     * Saves a {@code <step>.training} archive and a {@code <step>.model} network file into the model
     * directory (when saving is enabled).
     *
     * @param learning the learning to save
     */
    public void save(Learning learning) {
        if (this.saveData) {
            save(getModelDir() + "/" + learning.getStepCounter() + ".training", learning);
            learning.getNeuralNet().save(getModelDir() + "/" + learning.getStepCounter() + ".model");
        }
    }

    public boolean isSaveData() {
        return this.saveData;
    }
}
