package org.apache.spark.examples.ml;

import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Option;
import org.apache.commons.cli.OptionBuilder;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.apache.commons.cli.PosixParser;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.OneVsRest;
import org.apache.spark.ml.util.MetadataUtils;
import org.apache.spark.mllib.evaluation.MulticlassMetrics;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;

/* loaded from: input_file:org/apache/spark/examples/ml/JavaOneVsRestExample.class */
/**
 * Command-line example that trains a {@link OneVsRest} multiclass classifier backed by
 * {@link LogisticRegression} on libsvm-formatted data and prints the confusion matrix and the
 * per-label false positive rate computed on a test set.
 *
 * <p>The test set is either loaded from {@code -testInput} or, when that option is absent,
 * carved out of the input data using the {@code -fracTest} fraction.
 */
public class JavaOneVsRestExample {

    /**
     * Holder for the parsed command-line settings.
     *
     * <p>{@code regParam} and {@code elasticNetParam} are boxed and default to {@code null},
     * meaning "not given on the command line"; in that case the corresponding setter is never
     * called on the classifier.
     */
    public static class Params {
        /** Path to the labeled examples (required option {@code -input}). */
        String input;
        /** Optional path to a separate test set; {@code null} means split {@link #input}. */
        String testInput;
        /** Maximum number of iterations for logistic regression. */
        Integer maxIter;
        /** Convergence tolerance for logistic regression. */
        double tol;
        /** Whether logistic regression fits an intercept term. */
        boolean fitIntercept;
        /** Regularization parameter, or {@code null} when not specified. */
        Double regParam;
        /** ElasticNet mixing parameter, or {@code null} when not specified. */
        Double elasticNetParam;
        /** Fraction of the input held out for testing when {@link #testInput} is not given. */
        double fracTest;

        private Params() {
            this.testInput = null;
            this.maxIter = 100;
            this.tol = 1.0E-6d;
            this.fitIntercept = true;
            this.regParam = null;
            this.elasticNetParam = null;
            this.fracTest = 0.2d;
        }
    }

    /**
     * Entry point: parses the arguments, runs the example and always stops the Spark context.
     *
     * @param strArr command-line arguments, see {@link #generateCommandlineOptions()}
     */
    public static void main(String[] strArr) {
        Params params = parse(strArr);
        JavaSparkContext sparkContext =
                new JavaSparkContext(new SparkConf().setAppName("JavaOneVsRestExample"));
        try {
            run(params, new SQLContext(sparkContext));
        } finally {
            // Stop the context even when loading, training or evaluation fails.
            sparkContext.stop();
        }
    }

    /**
     * Trains the one-vs-rest model and prints the evaluation results to standard output.
     *
     * @param params parsed command-line settings
     * @param sqlContext SQL context used to load the libsvm data
     */
    private static void run(Params params, SQLContext sqlContext) {
        LogisticRegression baseClassifier = new LogisticRegression()
                .setMaxIter(params.maxIter.intValue())
                .setTol(params.tol)
                .setFitIntercept(params.fitIntercept);
        if (params.regParam != null) {
            baseClassifier.setRegParam(params.regParam.doubleValue());
        }
        if (params.elasticNetParam != null) {
            baseClassifier.setElasticNetParam(params.elasticNetParam.doubleValue());
        }
        OneVsRest oneVsRest = new OneVsRest().setClassifier(baseClassifier);

        DataFrame inputData = sqlContext.read().format("libsvm").load(params.input);
        DataFrame train;
        DataFrame test;
        String testInput = params.testInput;
        if (testInput != null) {
            train = inputData;
            // Load the test set with the same feature dimension as the training set; the
            // dimension is read from the vector in column 1 of the first training row.
            int numFeatures = ((Vector) inputData.first().getAs(1)).size();
            test = sqlContext.read().format("libsvm")
                    .option("numFeatures", String.valueOf(numFeatures))
                    .load(testInput);
        } else {
            double fracTest = params.fracTest;
            // Fixed seed keeps the train/test split reproducible across runs.
            DataFrame[] splits =
                    inputData.randomSplit(new double[]{1.0d - fracTest, fracTest}, 12345L);
            train = splits[0];
            test = splits[1];
        }

        DataFrame predictions = oneVsRest.fit(train.cache())
                .transform(test.cache())
                .select("prediction", "label");
        MulticlassMetrics metrics = new MulticlassMetrics(predictions);
        int numClasses = ((Integer) MetadataUtils.getNumClasses(
                predictions.schema().apply("prediction")).get()).intValue();
        String fprReport = formatFalsePositiveRates(metrics, numClasses);

        Matrix confusionMatrix = metrics.confusionMatrix();
        System.out.println("Confusion Matrix");
        System.out.println(confusionMatrix);
        System.out.println();
        System.out.println(fprReport);
    }

    /**
     * Builds a tab-separated table with one row per label index in {@code [0, numClasses)}
     * holding that label's false positive rate.
     */
    private static String formatFalsePositiveRates(MulticlassMetrics metrics, int numClasses) {
        StringBuilder report = new StringBuilder("label\tfpr\n");
        for (int label = 0; label < numClasses; label++) {
            report.append(label)
                    .append("\t")
                    .append(metrics.falsePositiveRate((double) label))
                    .append("\n");
        }
        return report.toString();
    }

    /**
     * Parses the command line into a {@link Params} instance.
     *
     * <p>On an unknown option, a missing required option or a malformed numeric value, the
     * error and the usage text are printed and the JVM exits via
     * {@link #printHelpAndQuit(Options)}.
     *
     * @param strArr raw command-line arguments
     * @return the parsed settings, with defaults for options that were not supplied
     */
    private static Params parse(String[] strArr) {
        Options options = generateCommandlineOptions();
        Params params = new Params();
        try {
            CommandLine cmd = new PosixParser().parse(options, strArr);
            if (cmd.hasOption("input")) {
                params.input = cmd.getOptionValue("input");
            }
            if (cmd.hasOption("maxIter")) {
                params.maxIter = Integer.valueOf(cmd.getOptionValue("maxIter"));
            }
            if (cmd.hasOption("tol")) {
                params.tol = Double.parseDouble(cmd.getOptionValue("tol"));
            }
            if (cmd.hasOption("fitIntercept")) {
                params.fitIntercept = Boolean.parseBoolean(cmd.getOptionValue("fitIntercept"));
            }
            if (cmd.hasOption("regParam")) {
                params.regParam = Double.valueOf(cmd.getOptionValue("regParam"));
            }
            if (cmd.hasOption("elasticNetParam")) {
                params.elasticNetParam = Double.valueOf(cmd.getOptionValue("elasticNetParam"));
            }
            if (cmd.hasOption("testInput")) {
                params.testInput = cmd.getOptionValue("testInput");
            }
            if (cmd.hasOption("fracTest")) {
                params.fracTest = Double.parseDouble(cmd.getOptionValue("fracTest"));
            }
        } catch (ParseException e) {
            // Tell the user which option was missing or unknown before showing the usage text.
            System.err.println(e.getMessage());
            printHelpAndQuit(options);
        } catch (NumberFormatException e) {
            // A malformed number is a usage error too; report it like any other bad argument.
            System.err.println("Invalid numeric option value: " + e.getMessage());
            printHelpAndQuit(options);
        }
        return params;
    }

    /**
     * Declares the supported command-line options. Every option takes exactly one argument and
     * only {@code input} is required.
     */
    private static Options generateCommandlineOptions() {
        return new Options()
                .addOption(buildOption("input", true,
                        "input path to labeled examples. This path must be specified"))
                .addOption(buildOption("testInput", false,
                        "input path to test examples"))
                .addOption(buildOption("fracTest", false,
                        "fraction of data to hold out for testing. If given option testInput, this option is ignored. default: 0.2"))
                .addOption(buildOption("maxIter", false,
                        "maximum number of iterations for Logistic Regression. default:100"))
                .addOption(buildOption("tol", false,
                        "the convergence tolerance of iterations for Logistic Regression. default: 1E-6"))
                .addOption(buildOption("fitIntercept", false,
                        "fit intercept for logistic regression. default true"))
                .addOption(buildOption("regParam", false,
                        "the regularization parameter for Logistic Regression."))
                .addOption(buildOption("elasticNetParam", false,
                        "the ElasticNet mixing parameter for Logistic Regression."));
    }

    /**
     * Creates a single-argument option whose argument placeholder in the help text equals the
     * option name.
     *
     * @param name option name, also used as the argument name
     * @param required whether the option must be present on the command line
     * @param description help text for the option
     */
    private static Option buildOption(String name, boolean required, String description) {
        // OptionBuilder accumulates settings in static state until create(...) is called, so
        // each option is configured and created in one uninterrupted sequence.
        OptionBuilder.withArgName(name);
        OptionBuilder.hasArg();
        if (required) {
            OptionBuilder.isRequired();
        }
        OptionBuilder.withDescription(description);
        return OptionBuilder.create(name);
    }

    /** Prints the usage text for {@code options} and terminates the JVM with exit code -1. */
    private static void printHelpAndQuit(Options options) {
        new HelpFormatter().printHelp("JavaOneVsRestExample", options);
        System.exit(-1);
    }
}
