package org.apache.spark.examples.ml;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.feature.VectorIndexer;
import org.apache.spark.ml.regression.RandomForestRegressionModel;
import org.apache.spark.ml.regression.RandomForestRegressor;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;

/**
 * Example of training a random forest regression model with the spark.ml Pipeline API.
 *
 * <p>Loads a LibSVM-format dataset, indexes its feature vector with a {@link VectorIndexer},
 * trains a {@link RandomForestRegressor} on a random 70% split, then prints sample predictions,
 * the RMSE on the remaining 30%, and the learned forest.
 */
public class JavaRandomForestRegressorExample {

    /** LibSVM-format input used when no path is given on the command line. */
    private static final String DEFAULT_DATA_PATH = "data/mllib/sample_libsvm_data.txt";

    /**
     * Runs the example.
     *
     * @param strArr command-line arguments; {@code strArr[0]}, if present, overrides the path of the
     *     LibSVM input file (otherwise {@value #DEFAULT_DATA_PATH} is used)
     */
    public static void main(String[] strArr) {
        String dataPath = (strArr != null && strArr.length > 0) ? strArr[0] : DEFAULT_DATA_PATH;

        SparkConf conf = new SparkConf().setAppName("JavaRandomForestRegressorExample");
        JavaSparkContext jsc = new JavaSparkContext(conf);
        try {
            SQLContext sqlContext = new SQLContext(jsc);

            // Load the LibSVM file as a DataFrame; the stages below read its "label" and
            // "features" columns.
            DataFrame data = sqlContext.read().format("libsvm").load(dataPath);

            // Index the feature vector into "indexedFeatures" (maxCategories = 4, see VectorIndexer
            // docs). Fitted on the full dataset rather than the training split, presumably so that
            // category values appearing only in the test split are still indexed.
            PipelineStage featureIndexer = new VectorIndexer()
                .setInputCol("features")
                .setOutputCol("indexedFeatures")
                .setMaxCategories(4)
                .fit(data);

            // Random 70/30 train/test split (no fixed seed, so results vary between runs).
            DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3});
            DataFrame trainingData = splits[0];
            DataFrame testData = splits[1];

            RandomForestRegressor rf = new RandomForestRegressor()
                .setLabelCol("label")
                .setFeaturesCol("indexedFeatures");

            // Stage order matters: index 1 of the fitted model's stages() is the forest (used below).
            Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] {featureIndexer, rf});
            PipelineModel model = pipeline.fit(trainingData);

            DataFrame predictions = model.transform(testData);
            predictions.select("prediction", "label", "features").show(5);

            RegressionEvaluator evaluator = new RegressionEvaluator()
                .setLabelCol("label")
                .setPredictionCol("prediction")
                .setMetricName("rmse");
            double rmse = evaluator.evaluate(predictions);
            System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse);

            // stages() is typed as Transformer[], which has no toDebugString(); the cast to the
            // concrete model type is required for this to compile.
            RandomForestRegressionModel rfModel = (RandomForestRegressionModel) model.stages()[1];
            System.out.println("Learned regression forest model:\n" + rfModel.toDebugString());
        } finally {
            // Release the Spark context even if loading, training, or evaluation fails.
            jsc.stop();
        }
    }
}
