package org.deeplearning4j.nn.modelimport.keras.trainedmodels;

import java.io.File;
import java.io.IOException;
import java.net.URL;
import org.apache.commons.io.FileUtils;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.nn.modelimport.keras.UnsupportedKerasConfigurationException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/modelimport/keras/trainedmodels/TrainedModelHelper.class */
/**
 * Resolves the Keras JSON configuration and H5 weights files for a {@link TrainedModels} entry,
 * downloading them into {@code <user.home>/.dl4j/trainedmodels/<modelDir>} when they are not
 * already present, and imports them as a {@link ComputationGraph}.
 *
 * <p>Either file location can be overridden with {@link #setPathToH5(String)} or
 * {@link #setPathToJSON(String)}; an overridden file is never downloaded.
 */
public class TrainedModelHelper {
    protected static final Logger logger = LoggerFactory.getLogger(TrainedModelHelper.class);

    private static final String BASE_DIR = ".dl4j/trainedmodels/";
    private static final File HOME_DIR = new File(System.getProperty("user.home"));

    /** Maximum time to wait for the connection to the download URL to be established. */
    private static final int CONNECT_TIMEOUT_MILLIS = 30000;
    /** Maximum time to wait for a single read from the download stream (not the whole transfer). */
    private static final int READ_TIMEOUT_MILLIS = 60000;
    /** Suffix of the temporary file a download is streamed into before being moved into place. */
    private static final String PARTIAL_SUFFIX = ".part";

    private final String h5URL;
    private final String jsonURL;
    private final File modelDir;
    private final String h5FileName;
    private final String jsonFileName;
    private File h5File;
    private File jsonFile;
    private boolean userProvidedH5 = false;
    private boolean userProvidedJSON = false;
    // Never read or written anywhere in this class.
    private String[] decodeMap;

    /**
     * Creates a helper whose default file locations live under the per-model cache directory.
     *
     * @param trainedModels descriptor supplying the model directory name, the download URLs and
     *                      the file names of the H5 weights and JSON configuration
     */
    public TrainedModelHelper(TrainedModels trainedModels) {
        this.modelDir = new File(HOME_DIR, BASE_DIR + trainedModels.getModelDir());
        this.h5URL = trainedModels.getH5URL();
        this.h5FileName = trainedModels.getH5FileName();
        this.h5File = new File(this.modelDir, this.h5FileName);
        this.jsonURL = trainedModels.getJSONURL();
        this.jsonFileName = trainedModels.getJSONFileName();
        this.jsonFile = new File(this.modelDir, this.jsonFileName);
    }

    /**
     * Uses the given file as the H5 weights instead of the cached/downloaded one.
     *
     * @param str path to an existing H5 weights file
     */
    public void setPathToH5(String str) {
        this.h5File = new File(str);
        this.userProvidedH5 = true;
        logger.info("Helper will use path given to H5 file");
    }

    /**
     * Uses the given file as the JSON configuration instead of the cached/downloaded one.
     *
     * @param str path to an existing JSON configuration file
     */
    public void setPathToJSON(String str) {
        this.jsonFile = new File(str);
        this.userProvidedJSON = true;
        logger.info("Helper will use path given to JSON file");
    }

    /**
     * Downloads any missing default file and imports the model.
     *
     * @return the imported model
     * @throws IOException if the cache directory cannot be created, a download fails, or one of
     *                     the model files does not exist when the import is about to start
     * @throws InvalidKerasConfigurationException propagated from the Keras importer
     * @throws UnsupportedKerasConfigurationException propagated from the Keras importer
     */
    public ComputationGraph loadModel() throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        download();
        // Fail with a message naming the path rather than handing a missing file to the importer;
        // this is mainly reachable through a wrong user-provided path.
        requireFile(this.h5File, "H5 weights");
        requireFile(this.jsonFile, "JSON config");
        return KerasModelImport.importKerasModelAndWeights(this.jsonFile.getAbsolutePath(), this.h5File.getAbsolutePath(), false);
    }

    /** Fetches the H5 and JSON files that are neither user-provided nor already on disk. */
    private void download() throws IOException {
        if (!this.userProvidedH5) {
            fetchIfMissing(this.h5URL, this.h5File, "H5 weights");
        }
        if (!this.userProvidedJSON) {
            fetchIfMissing(this.jsonURL, this.jsonFile, "JSON config");
        }
    }

    /**
     * Downloads {@code url} to {@code target} unless {@code target} already exists.
     *
     * <p>The data is streamed into a sibling {@code .part} file and only moved to {@code target}
     * after the transfer completed, so an interrupted download never leaves a truncated file that
     * a later run would mistake for a complete one.
     */
    private void fetchIfMissing(String url, File target, String description) throws IOException {
        if (target.exists()) {
            return;
        }
        if (!this.modelDir.isDirectory() && !this.modelDir.mkdirs()) {
            throw new IOException("Could not mkdir " + this.modelDir);
        }
        logger.info("{} not found in default location. Copying from URL {}\n\tto location {}",
                description, url, target.getAbsolutePath());

        File partial = new File(target.getParentFile(), target.getName() + PARTIAL_SUFFIX);
        try {
            FileUtils.copyURLToFile(new URL(url), partial, CONNECT_TIMEOUT_MILLIS, READ_TIMEOUT_MILLIS);
            FileUtils.moveFile(partial, target);
        } finally {
            // Removes the leftover after a failed transfer; does nothing once the move succeeded.
            FileUtils.deleteQuietly(partial);
        }
    }

    /** Throws an {@link IOException} naming the path when {@code file} is not an existing regular file. */
    private static void requireFile(File file, String description) throws IOException {
        if (!file.isFile()) {
            throw new IOException(description + " file not found: " + file.getAbsolutePath());
        }
    }
}
