package org.apache.spark.examples.ml;

import java.util.Arrays;
import java.util.List;

import org.apache.spark.ml.feature.BucketedRandomProjectionLSH;
import org.apache.spark.ml.feature.BucketedRandomProjectionLSHModel;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

/* loaded from: input_file:org/apache/spark/examples/ml/JavaBucketedRandomProjectionLSHExample.class */
/**
 * Example of locality-sensitive hashing with {@link BucketedRandomProjectionLSH}.
 *
 * <p>Fits a model on a small 2-D dataset and demonstrates, printing each result to stdout:
 * <ul>
 *   <li>hashing a dataset with the fitted model,</li>
 *   <li>an approximate similarity join between two datasets (raw and pre-hashed inputs),</li>
 *   <li>an approximate self-join with duplicate pairs filtered out,</li>
 *   <li>an approximate nearest-neighbor search for a single key vector (raw and pre-hashed).</li>
 * </ul>
 */
public class JavaBucketedRandomProjectionLSHExample {

    /**
     * Runs the example end to end.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
            .appName("JavaBucketedRandomProjectionLSHExample")
            .getOrCreate();
        // Always stop the session, even if a step below fails.
        try {
            List<Row> dataA = Arrays.asList(
                RowFactory.create(0, Vectors.dense(1.0, 1.0)),
                RowFactory.create(1, Vectors.dense(1.0, -1.0)),
                RowFactory.create(2, Vectors.dense(-1.0, -1.0)),
                RowFactory.create(3, Vectors.dense(-1.0, 1.0)));
            List<Row> dataB = Arrays.asList(
                RowFactory.create(4, Vectors.dense(1.0, 0.0)),
                RowFactory.create(5, Vectors.dense(-1.0, 0.0)),
                RowFactory.create(6, Vectors.dense(0.0, 1.0)),
                RowFactory.create(7, Vectors.dense(0.0, -1.0)));

            StructType schema = new StructType(new StructField[]{
                new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
                new StructField("keys", new VectorUDT(), false, Metadata.empty())
            });
            Dataset<Row> dfA = spark.createDataFrame(dataA, schema);
            Dataset<Row> dfB = spark.createDataFrame(dataB, schema);

            // Query vector for the nearest-neighbor search.
            Vector key = Vectors.dense(1.0, 0.0);

            BucketedRandomProjectionLSHModel model = new BucketedRandomProjectionLSH()
                .setBucketLength(2.0)
                .setNumHashTables(3)
                .setInputCol("keys")
                .setOutputCol("values")
                .fit(dfA);

            // Hash each dataset once and cache it, so it can be displayed and reused
            // below without recomputing the hashes.
            Dataset<Row> transformedA = model.transform(dfA).cache();
            Dataset<Row> transformedB = model.transform(dfB).cache();
            try {
                System.out.println("Dataset A with hashes in column 'values':");
                transformedA.show();

                // The join accepts either raw datasets (hashes are computed on the fly)
                // or datasets that already carry the output column.
                System.out.println("Approximate join of A and B (threshold 1.5), raw inputs:");
                model.approxSimilarityJoin(dfA, dfB, 1.5).show();
                System.out.println("Approximate join of A and B (threshold 1.5), pre-hashed inputs:");
                model.approxSimilarityJoin(transformedA, transformedB, 1.5).show();

                // The id filter removes self-matches and keeps only one of each mirrored pair.
                System.out.println("Approximate self-join of A (threshold 2.5), unique pairs:");
                model.approxSimilarityJoin(dfA, dfA, 2.5)
                    .filter("datasetA.id < datasetB.id")
                    .show();

                System.out.println("Approximate 2 nearest neighbors of the key in A, raw input:");
                model.approxNearestNeighbors(dfA, key, 2).show();
                System.out.println("Approximate 2 nearest neighbors of the key in A, pre-hashed input:");
                model.approxNearestNeighbors(transformedA, key, 2).show();
            } finally {
                transformedA.unpersist();
                transformedB.unpersist();
            }
        } finally {
            spark.stop();
        }
    }
}
