package org.deeplearning4j;

import java.io.IOException;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.deeplearning4j.cli.subcommands.BaseSubCommand;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.spark.impl.computationgraph.SparkComputationGraph;
import org.deeplearning4j.util.ModelSerializer;
import org.kohsuke.args4j.Option;
import org.nd4j.linalg.dataset.DataSet;

/* loaded from: input_file:org/deeplearning4j/SparkTrain.class */
/**
 * CLI sub-command that trains a {@link ComputationGraph} on Spark via
 * {@link SparkComputationGraph#fitDataSet} and writes the trained graph to {@code --output}.
 *
 * <p>The network is built from exactly one of {@code --conf} (parsed with
 * {@link ComputationGraphConfiguration#fromJson(String)}) or {@code --model} (restored with
 * {@link ModelSerializer#restoreComputationGraph(String)}). Option fields are presumably bound
 * from the constructor's arguments by {@link BaseSubCommand} — NOTE(review): confirm.
 */
public class SparkTrain extends BaseSubCommand {

    /** Serialized graph to resume training from; mutually exclusive with {@code --conf}. */
    @Option(name = "--model", usage = "model file (json,yaml,..) to resume training", aliases = {"-mo"})
    private String modelInput;

    /** Graph configuration for a fresh network; mutually exclusive with {@code --model}. */
    @Option(name = "--conf", usage = "computation graph configuration", aliases = {"-c"})
    private String confInput;

    /** Spark master URI used to create the {@link SparkContext}. */
    @Option(name = "--masterUri", usage = "spark master uri", aliases = {"-ma"}, required = true)
    private String masterUri;

    /** Location of the training data. */
    @Option(name = "--input", usage = "input data", aliases = {"-i"}, required = true)
    private String masterInputUri;

    /** Format of the training data: {@code "binary"} or {@code "text"}. */
    @Option(name = "--type", usage = "input data type", aliases = {"-t"}, required = true)
    private String inputType;

    /** Passed through to {@link SparkComputationGraph#fitDataSet} as examples per fit. */
    @Option(name = "--examplesPerFit", usage = "examples per fit", aliases = {"-b"}, required = true)
    private int examplesPerFit;

    /** Passed through to {@link SparkComputationGraph#fitDataSet} as the total example count. */
    @Option(name = "--totalExamples", usage = "total number of examples", aliases = {"-n"}, required = true)
    private int totalExamples;

    /** Passed through to {@link SparkComputationGraph#fitDataSet} as the partition count. */
    @Option(name = "--numPartitions", usage = "number of partitions", aliases = {"-p"}, required = true)
    private int numPartitions;

    /** Path the trained graph is written to. */
    @Option(name = "--output", usage = "output path", aliases = {"-o"}, required = true)
    private String outputPath;

    /** Lazily created Spark context; see {@link #getContext()}. */
    private SparkContext sc;

    public SparkTrain(String[] strArr) {
        super(strArr);
    }

    /**
     * Returns the Spark context, creating it from {@code --masterUri} on first use.
     *
     * @return a non-null Spark context
     */
    private SparkContext getContext() {
        if (this.sc == null) {
            SparkConf conf = new SparkConf().setMaster(this.masterUri).setAppName("SparkTrain");
            this.sc = new SparkContext(conf);
        }
        return this.sc;
    }

    /**
     * Loads the training data described by {@code --input} and {@code --type}.
     *
     * @return the training data set
     * @throws IllegalArgumentException if the input type is neither "binary" nor "text"
     * @throws UnsupportedOperationException for valid types, because loading is not implemented
     *     yet; previously this returned {@code null}, which only failed later inside Spark
     */
    private JavaRDD<DataSet> getDataSet() {
        if (!"binary".equals(this.inputType) && !"text".equals(this.inputType)) {
            throw new IllegalArgumentException("Input type must be either binary or text.");
        }
        // TODO: read this.masterInputUri with getContext() according to this.inputType.
        throw new UnsupportedOperationException(
                "Loading '" + this.inputType + "' input from " + this.masterInputUri
                        + " is not implemented");
    }

    /**
     * Builds the graph to train from exactly one of {@code --conf} or {@code --model}.
     *
     * @return an initialized graph
     * @throws IllegalArgumentException if both or neither of the options are set
     * @throws IOException if restoring the model file fails
     */
    private ComputationGraph getComputationGraph() throws IOException {
        boolean hasConf = this.confInput != null;
        boolean hasModel = this.modelInput != null;
        if (hasConf && hasModel) {
            throw new IllegalArgumentException("Conf and model input both can't be defined");
        }
        if (!hasConf && !hasModel) {
            throw new IllegalArgumentException("Either --conf or --model must be defined");
        }
        if (hasConf) {
            // NOTE(review): the option value itself is parsed as JSON; if --conf is meant to be
            // a file path, its contents must be read first — confirm intended usage.
            ComputationGraph computationGraph =
                    new ComputationGraph(ComputationGraphConfiguration.fromJson(this.confInput));
            computationGraph.init();
            return computationGraph;
        }
        return ModelSerializer.restoreComputationGraph(this.modelInput);
    }

    /**
     * Writes the trained graph to {@code --output}, including updater state so the file can be
     * passed back through {@code --model} to resume training.
     *
     * @param computationGraph the trained graph
     * @throws IOException if writing fails
     */
    private void saveGraph(ComputationGraph computationGraph) throws IOException {
        ModelSerializer.writeModel(computationGraph, this.outputPath, true);
    }

    /**
     * Loads the graph, trains it on Spark, and saves the result. The graph is loaded before the
     * Spark context is started, so invalid model/config input fails without starting Spark. The
     * context is always stopped afterwards.
     *
     * @throws IllegalStateException wrapping any {@link IOException} from loading or saving
     */
    public void execute() {
        SparkContext context = null;
        try {
            ComputationGraph graph = getComputationGraph();
            context = getContext();
            JavaRDD<DataSet> data = getDataSet();
            ComputationGraph trained = new SparkComputationGraph(context, graph)
                    .fitDataSet(data, this.examplesPerFit, this.totalExamples, this.numPartitions);
            saveGraph(trained);
        } catch (IOException e) {
            throw new IllegalStateException("SparkTrain failed: " + e.getMessage(), e);
        } finally {
            if (context != null) {
                context.stop();
                this.sc = null;
            }
        }
    }
}
