package org.apache.spark.examples.mllib;

import java.util.HashMap;
import java.util.Map;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.RandomForest;
import org.apache.spark.mllib.tree.model.RandomForestModel;
import org.apache.spark.mllib.util.MLUtils;
import scala.Tuple2;

/* loaded from: input_file:org/apache/spark/examples/mllib/JavaRandomForestRegressionExample.class */
public class JavaRandomForestRegressionExample {
    public static void main(String[] strArr) {
        JavaSparkContext javaSparkContext = new JavaSparkContext(new SparkConf().setAppName("JavaRandomForestRegressionExample"));
        JavaRDD[] randomSplit = MLUtils.loadLibSVMFile(javaSparkContext.sc(), "data/mllib/sample_libsvm_data.txt").toJavaRDD().randomSplit(new double[]{0.7d, 0.3d});
        JavaRDD javaRDD = randomSplit[0];
        JavaRDD javaRDD2 = randomSplit[1];
        Integer num = 3;
        Integer num2 = 4;
        Integer num3 = 32;
        Integer num4 = 12345;
        final RandomForestModel trainRegressor = RandomForest.trainRegressor(javaRDD, new HashMap(), num.intValue(), "auto", "variance", num2.intValue(), num3.intValue(), num4.intValue());
        System.out.println("Test Mean Squared Error: " + Double.valueOf(((Double) javaRDD2.mapToPair(new PairFunction<LabeledPoint, Double, Double>() { // from class: org.apache.spark.examples.mllib.JavaRandomForestRegressionExample.1
            public Tuple2<Double, Double> call(LabeledPoint labeledPoint) {
                return new Tuple2<>(Double.valueOf(trainRegressor.predict(labeledPoint.features())), Double.valueOf(labeledPoint.label()));
            }
        }).map(new Function<Tuple2<Double, Double>, Double>() { // from class: org.apache.spark.examples.mllib.JavaRandomForestRegressionExample.3
            public Double call(Tuple2<Double, Double> tuple2) {
                Double valueOf = Double.valueOf(((Double) tuple2._1()).doubleValue() - ((Double) tuple2._2()).doubleValue());
                return Double.valueOf(valueOf.doubleValue() * valueOf.doubleValue());
            }
        }).reduce(new Function2<Double, Double, Double>() { // from class: org.apache.spark.examples.mllib.JavaRandomForestRegressionExample.2
            public Double call(Double d, Double d2) {
                return Double.valueOf(d.doubleValue() + d2.doubleValue());
            }
        })).doubleValue() / javaRDD2.count()));
        System.out.println("Learned regression forest model:\n" + trainRegressor.toDebugString());
        trainRegressor.save(javaSparkContext.sc(), "target/tmp/myRandomForestRegressionModel");
        RandomForestModel.load(javaSparkContext.sc(), "target/tmp/myRandomForestRegressionModel");
        javaSparkContext.stop();
    }
}
