package org.deeplearning4j.arbiter.server;

import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter;
import com.beust.jcommander.ParameterException;
import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.io.FileUtils;
import org.deeplearning4j.arbiter.evaluator.multilayer.ClassificationEvaluator;
import org.deeplearning4j.arbiter.evaluator.multilayer.RegressionDataEvaluator;
import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration;
import org.deeplearning4j.arbiter.optimize.runner.LocalOptimizationRunner;
import org.deeplearning4j.arbiter.scoring.RegressionValue;
import org.deeplearning4j.arbiter.server.cli.NeuralNetTypeValidator;
import org.deeplearning4j.arbiter.server.cli.ProblemTypeValidator;
import org.deeplearning4j.arbiter.task.ComputationGraphTaskCreator;
import org.deeplearning4j.arbiter.task.MultiLayerNetworkTaskCreator;

/* loaded from: input_file:org/deeplearning4j/arbiter/server/ArbiterCliRunner.class */
/**
 * Command-line entry point that runs an Arbiter optimization with a {@link LocalOptimizationRunner}.
 *
 * <p>Arguments are bound to the {@code @Parameter} fields below by JCommander. The optimization
 * configuration is read as JSON from {@code --optimizationConfigPath}. {@code --neuralNetType}
 * selects the task creator and {@code --problemType} selects the evaluator.
 */
public class ArbiterCliRunner {

    public static final String CLASSIFICATION = "classification";
    public static final String REGRESSION = "regression";
    public static final String COMP_GRAPH = "compgraph";
    public static final String MULTI_LAYER_NETWORK = "multilayernetwork";

    /** Map key under which {@link #dataSetIteratorClass} is handed to the regression evaluator. */
    private static final String DATA_SET_FACTORY_KEY = "org.deeplearning4j.arbiter.data.data.factory";

    /** Pause before {@code System.exit} after a usage error. */
    private static final long USAGE_EXIT_DELAY_MS = 500L;

    /**
     * Directory that is (re)created at startup and scheduled for deletion on JVM exit.
     * Defaults to the {@code java.io.tmpdir} system property. Not otherwise used in this class.
     */
    @Parameter(names = {"--modelSavePath"})
    private String modelSavePath = System.getProperty("java.io.tmpdir");

    /** Path to the JSON file parsed with {@link OptimizationConfiguration#fromJson}; must be supplied. */
    @Parameter(names = {"--optimizationConfigPath"})
    private String optimizationConfigPath = null;

    /** {@value #CLASSIFICATION} (default) or {@value #REGRESSION}; checked by {@link ProblemTypeValidator}. */
    @Parameter(names = {"--problemType"}, validateWith = ProblemTypeValidator.class)
    private String problemType = CLASSIFICATION;

    /** Name of a {@link RegressionValue} constant; required when the problem type is {@value #REGRESSION}. */
    @Parameter(names = {"--regressionType"})
    private String regressionType = null;

    /**
     * Data set iterator class name (presumably fully qualified — confirm). In this class it is only
     * forwarded to the regression evaluator, under {@link #DATA_SET_FACTORY_KEY}.
     */
    @Parameter(names = {"--dataSetIteratorClass"}, required = true)
    private String dataSetIteratorClass = null;

    /** {@value #COMP_GRAPH} or {@value #MULTI_LAYER_NETWORK}; checked by {@link NeuralNetTypeValidator}. */
    @Parameter(names = {"--neuralNetType"}, required = true, validateWith = NeuralNetTypeValidator.class)
    private String neuralNetType = null;

    /**
     * Parses {@code args}, prepares the model save directory, and runs the optimization selected by
     * the problem type and network type.
     *
     * <p>On a JCommander parse error, this method prints the error and the usage text, then calls
     * {@code System.exit(1)}.
     *
     * @param args command-line arguments
     * @throws IllegalArgumentException if the config path or regression type is missing or invalid,
     *     or if the problem/network type combination is not supported
     * @throws Exception if reading the configuration file or running the optimization fails
     */
    public void runMain(String... args) throws Exception {
        JCommander jCommander = new JCommander(this);
        try {
            jCommander.parse(args);
        } catch (ParameterException e) {
            System.err.println(e.getMessage());
            jCommander.usage();
            pauseBeforeExit();
            System.exit(1);
        }

        Map<String, Object> dataSourceProperties = new HashMap<>();
        dataSourceProperties.put(DATA_SET_FACTORY_KEY, dataSetIteratorClass);

        prepareModelSaveDirectory();

        OptimizationConfiguration configuration = loadConfiguration();

        if (REGRESSION.equals(problemType)) {
            RegressionValue regressionValue = parseRegressionValue();
            if (COMP_GRAPH.equals(neuralNetType)) {
                new LocalOptimizationRunner(configuration,
                        new ComputationGraphTaskCreator(
                                new RegressionDataEvaluator(regressionValue, dataSourceProperties)))
                        .execute();
                return;
            }
            if (MULTI_LAYER_NETWORK.equals(neuralNetType)) {
                new LocalOptimizationRunner(configuration,
                        new MultiLayerNetworkTaskCreator(
                                new RegressionDataEvaluator(regressionValue, dataSourceProperties)))
                        .execute();
                return;
            }
        } else if (CLASSIFICATION.equals(problemType)) {
            if (COMP_GRAPH.equals(neuralNetType)) {
                new LocalOptimizationRunner(configuration,
                        new ComputationGraphTaskCreator(new ClassificationEvaluator()))
                        .execute();
                return;
            }
            if (MULTI_LAYER_NETWORK.equals(neuralNetType)) {
                new LocalOptimizationRunner(configuration,
                        new MultiLayerNetworkTaskCreator(new ClassificationEvaluator()))
                        .execute();
                return;
            }
        }

        // Previously an unmatched combination silently did nothing; fail loudly instead.
        throw new IllegalArgumentException("Unsupported combination: --problemType=" + problemType
                + ", --neuralNetType=" + neuralNetType);
    }

    /** Sleeps briefly before exiting, presumably so the usage output is flushed first (TODO confirm). */
    private static void pauseBeforeExit() {
        try {
            Thread.sleep(USAGE_EXIT_DELAY_MS);
        } catch (InterruptedException e) {
            // We are about to exit anyway; just preserve the interrupt status.
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Deletes any existing entry at {@link #modelSavePath}, recreates it as a directory, and
     * schedules it for deletion on JVM exit. All steps are best-effort; return values are ignored.
     */
    private void prepareModelSaveDirectory() {
        File dir = new File(modelSavePath);
        // NOTE(review): File.delete() removes only regular files and *empty* directories. With the
        // default (the system temp dir) this is effectively a no-op. An existing empty directory or
        // a regular file at this path, however, is removed. Confirm this is intended.
        if (dir.exists()) {
            dir.delete();
        }
        dir.mkdir();
        dir.deleteOnExit();
    }

    /**
     * Reads {@link #optimizationConfigPath} as UTF-8 JSON and parses it.
     *
     * @return the parsed optimization configuration
     * @throws IllegalArgumentException if no configuration path was supplied
     * @throws IOException if the file cannot be read
     */
    private OptimizationConfiguration loadConfiguration() throws IOException {
        if (optimizationConfigPath == null) {
            throw new IllegalArgumentException("--optimizationConfigPath is required");
        }
        // JSON is UTF-8 by specification; do not depend on the platform default charset.
        String json = FileUtils.readFileToString(new File(optimizationConfigPath), StandardCharsets.UTF_8);
        return OptimizationConfiguration.fromJson(json);
    }

    /**
     * Converts {@link #regressionType} to a {@link RegressionValue}.
     *
     * @return the matching enum constant
     * @throws IllegalArgumentException if the option is missing or names no constant
     */
    private RegressionValue parseRegressionValue() {
        if (regressionType == null) {
            throw new IllegalArgumentException("--regressionType is required when --problemType is "
                    + REGRESSION + "; valid values: " + Arrays.toString(RegressionValue.values()));
        }
        try {
            return RegressionValue.valueOf(regressionType);
        } catch (IllegalArgumentException e) {
            throw new IllegalArgumentException("Unknown --regressionType '" + regressionType
                    + "'; valid values: " + Arrays.toString(RegressionValue.values()), e);
        }
    }

    /**
     * CLI entry point; delegates to {@link #runMain(String...)} on a fresh instance.
     *
     * @param args command-line arguments
     * @throws Exception propagated from {@link #runMain(String...)}
     */
    public static void main(String... args) throws Exception {
        new ArbiterCliRunner().runMain(args);
    }
}
