package ai.catboost.spark;

import ai.catboost.CatBoostError;
import ai.catboost.spark.params.Helpers$;
import ai.catboost.spark.params.PoolLoadParams;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.UUID;
import org.apache.spark.ml.attribute.Attribute;
import org.apache.spark.ml.attribute.AttributeGroup;
import org.apache.spark.ml.attribute.AttributeGroup$;
import org.apache.spark.ml.attribute.NominalAttribute;
import org.apache.spark.ml.linalg.SparseVector;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors$;
import org.apache.spark.ml.param.ParamPair;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions$;
import ru.yandex.catboost.spark.catboost4j_spark.core.src.native_impl.QuantizedFeaturesInfoPtr;
import ru.yandex.catboost.spark.catboost4j_spark.core.src.native_impl.TFeaturesLayoutPtr;
import scala.Array$;
import scala.MatchError;
import scala.Option;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.ArrayOps;
import scala.collection.mutable.Map;
import scala.collection.mutable.Map$;
import scala.math.Ordering$Int$;
import scala.reflect.ClassTag$;
import scala.reflect.api.Mirror;
import scala.reflect.api.TypeCreator;
import scala.reflect.api.Types;
import scala.reflect.api.Universe;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.IntRef;

/* compiled from: Pool.scala */
/* loaded from: input_file:ai/catboost/spark/Pool$.class */
public final class Pool$ implements Serializable {
    /** The single instance of this class; assigned by the private constructor. */
    public static Pool$ MODULE$;

    static {
        // The constructor stores `this` into MODULE$, so the result of `new` is intentionally discarded.
        new Pool$();
    }

    /**
     * Compiler-generated default-value accessor; always returns {@code null}.
     *
     * NOTE(review): by the Scala compiler's naming scheme this is the default for the 2nd
     * constructor parameter of the companion class (presumably {@code Pool}) - confirm in Pool.scala.
     */
    public Dataset<Row> $lessinit$greater$default$2() {
        return null;
    }

    /**
     * Compiler-generated default-value accessor; always returns {@code null}.
     *
     * NOTE(review): by the Scala compiler's naming scheme this is the default for the 3rd
     * constructor parameter of the companion class (presumably {@code Pool}) - confirm in Pool.scala.
     */
    public TFeaturesLayoutPtr $lessinit$greater$default$3() {
        return null;
    }

    /**
     * Compiler-generated default-value accessor; always returns {@code null}.
     *
     * NOTE(review): by the Scala compiler's naming scheme this is the default for the 4th
     * constructor parameter of the companion class (presumably {@code Pool}) - confirm in Pool.scala.
     */
    public QuantizedFeaturesInfoPtr $lessinit$greater$default$4() {
        return null;
    }

    /**
     * Compiler-generated default-value accessor; always returns {@code null}.
     *
     * NOTE(review): by the Scala compiler's naming scheme this is the default for the 5th
     * constructor parameter of the companion class (presumably {@code Pool}) - confirm in Pool.scala.
     */
    public Dataset<Row> $lessinit$greater$default$5() {
        return null;
    }

    /**
     * Compiler-generated default-value accessor; always returns {@code false}.
     *
     * NOTE(review): by the Scala compiler's naming scheme this is the default for the 6th
     * constructor parameter of the companion class (presumably {@code Pool}) - confirm in Pool.scala.
     */
    public boolean $lessinit$greater$default$6() {
        return false;
    }

    /**
     * Rewrites the "features" column so that every sparse vector has the same declared size:
     * the largest vector size found anywhere in {@code dataset}.
     *
     * <p>Steps: (1) compute the maximum vector size per partition, (2) collect those values and
     * take the overall maximum on the driver, (3) pad the existing feature-name list with empty
     * strings up to that size, (4) replace "features" with a sparse vector that keeps the
     * original indices and values but declares the new size, attaching metadata built from the
     * padded names.
     *
     * @param dataset data whose rows hold a {@link SparseVector}; the size scan reads column 0
     * @return {@code dataset} with the "features" column replaced
     */
    private Dataset<Row> updateSparseFeaturesSize(Dataset<Row> dataset) {
        SparkSession sparkSession = dataset.sparkSession();
        // One int per partition: the largest SparseVector size seen in that partition.
        // NOTE(review): the scan reads the vector at column index 0, while the rewrite below
        // addresses the column by the name "features" - assumes they are the same column; confirm.
        Dataset mapPartitions = dataset.mapPartitions(iterator -> {
            IntRef create = IntRef.create(0);
            iterator.foreach(row -> {
                $anonfun$updateSparseFeaturesSize$2(create, row);
                return BoxedUnit.UNIT;
            });
            return scala.package$.MODULE$.Iterator().apply(Predef$.MODULE$.wrapIntArray(new int[]{create.elem}));
        }, sparkSession.implicits().newIntEncoder());
        // Reduce the per-partition maxima to a single maximum on the driver (stays 0 if nothing was collected).
        IntRef create = IntRef.create(0);
        new ArrayOps.ofInt(Predef$.MODULE$.intArrayOps((int[]) mapPartitions.collect())).foreach(i -> {
            if (i > create.elem) {
                create.elem = i;
            }
        });
        String[] featureNames = getFeatureNames(dataset, "features");
        // Resize the name list to the maximum size and give the added slots an empty name.
        // NOTE(review): if featureNames.length exceeds create.elem, copyOf truncates and the fill
        // range is inverted, which java.util.Arrays.fill rejects with IllegalArgumentException.
        String[] strArr = (String[]) Arrays.copyOf(featureNames, create.elem);
        Arrays.fill(strArr, featureNames.length, create.elem, "");
        // The UDF rebuilds each vector with identical indices/values but size = create.elem;
        // the two TypeCreator instances only supply the Vector type tags that `udf` requires.
        return dataset.withColumn("features", functions$.MODULE$.udf(vector -> {
            SparseVector sparseVector = (SparseVector) vector;
            return Vectors$.MODULE$.sparse(create.elem, sparseVector.indices(), sparseVector.values());
        }, scala.reflect.runtime.package$.MODULE$.universe().TypeTag().apply(scala.reflect.runtime.package$.MODULE$.universe().runtimeMirror(getClass().getClassLoader()), new TypeCreator() { // from class: ai.catboost.spark.Pool$$typecreator1$1
            public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                mirror.universe();
                return mirror.staticClass("org.apache.spark.ml.linalg.Vector").asType().toTypeConstructor();
            }
        }), scala.reflect.runtime.package$.MODULE$.universe().TypeTag().apply(scala.reflect.runtime.package$.MODULE$.universe().runtimeMirror(getClass().getClassLoader()), new TypeCreator() { // from class: ai.catboost.spark.Pool$$typecreator2$1
            public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                mirror.universe();
                return mirror.staticClass("org.apache.spark.ml.linalg.Vector").asType().toTypeConstructor();
            }
        })).apply(Predef$.MODULE$.wrapRefArray(new Column[]{sparkSession.implicits().StringToColumn(new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"features"}))).$(Nil$.MODULE$)})).as("_", DataHelpers$.MODULE$.makeFeaturesMetadata(strArr)));
    }

    /**
     * Loads a {@link Pool} from a dataset location of the form {@code [scheme://]path}.
     *
     * <p>A location without a {@code ://} separator is read with the {@code dsv} scheme. Only
     * {@code dsv} and {@code libsvm} are accepted. Data read with the {@code libsvm} scheme is
     * passed through {@code updateSparseFeaturesSize} before the pool is built. When a pairs
     * location is given, pairs are loaded first, the reader is asked to add sample ids, and the
     * loaded data is passed through {@code DataHelpers.mapSampleIdxToPerGroupSampleIdx}.
     *
     * @param sparkSession session used to read the main data and the pairs data
     * @param str dataset location, optionally prefixed with {@code dsv://} or {@code libsvm://}
     * @param path column description file, or {@code null} for none; forwarded to the reader as
     *     the {@code columnDescription} option
     * @param poolLoadParams load parameters; every set parameter is forwarded as a reader option
     *     and the whole set is additionally forwarded in JSON form as {@code catboostJsonParams}
     * @param str2 pairs dataset location, or {@code null} when there are no pairs
     * @return the loaded pool, with column params set from the loaded column names
     * @throws CatBoostError if the scheme is neither {@code dsv} nor {@code libsvm}
     */
    public Pool load(SparkSession sparkSession, String str, Path path, PoolLoadParams poolLoadParams, String str2) {
        // Limit 2 keeps any later "://" occurrences inside the path part.
        String[] schemeAndPath = str.split("://", 2);
        final String dataScheme;
        final String dataPath;
        if (schemeAndPath.length == 1) {
            dataScheme = "dsv";
            dataPath = schemeAndPath[0];
        } else {
            dataScheme = schemeAndPath[0];
            dataPath = schemeAndPath[1];
        }

        final boolean isLibsvm = "libsvm".equals(dataScheme);
        if (!isLibsvm && !"dsv".equals(dataScheme)) {
            throw new CatBoostError("Loading pool from scheme " + dataScheme + " is not supported");
        }

        // Reader options are inserted in a fixed order: later updates overwrite earlier keys,
        // so the entries written after the poolLoadParams loop take precedence over same-named params.
        final Map options = Map$.MODULE$.apply(Nil$.MODULE$);
        Dataset<Row> pairs = null;
        if (str2 != null) {
            options.update("addSampleId", "true");
            pairs = CatBoostPairsDataLoader$.MODULE$.load(sparkSession, str2);
        }
        options.update("dataScheme", dataScheme);
        poolLoadParams.extractParamMap().toSeq().foreach(paramPair -> {
            $anonfun$load$1(options, paramPair);
            return BoxedUnit.UNIT;
        });
        if (path != null) {
            options.update("columnDescription", path.toString());
        }
        options.update("catboostJsonParams", Helpers$.MODULE$.sparkMlParamsToCatBoostJsonParamsString(poolLoadParams));
        options.update("uuid", UUID.randomUUID().toString());

        Dataset<Row> data = sparkSession.read().format("ai.catboost.spark.CatBoostTextFileFormat").options(options).load(dataPath);
        if (pairs != null) {
            data = DataHelpers$.MODULE$.mapSampleIdxToPerGroupSampleIdx(data);
        }

        // Only libsvm data is normalised to a common sparse vector size.
        Pool pool = new Pool(isLibsvm ? updateSparseFeaturesSize(data) : data, pairs);
        setColumnParamsFromLoadedData(pool);
        return pool;
    }

    /**
     * Compiler-generated default-value accessor; always returns {@code null}.
     *
     * NOTE(review): by the Scala compiler's naming scheme this is the default for the 3rd
     * parameter of {@code load} (the column description {@link Path}) - confirm in Pool.scala.
     */
    public Path load$default$3() {
        return null;
    }

    /**
     * Compiler-generated default-value accessor; returns a fresh {@link PoolLoadParams} on every call.
     *
     * NOTE(review): by the Scala compiler's naming scheme this is the default for the 4th
     * parameter of {@code load} - confirm in Pool.scala.
     */
    public PoolLoadParams load$default$4() {
        return new PoolLoadParams();
    }

    /**
     * Compiler-generated default-value accessor; always returns {@code null}.
     *
     * NOTE(review): by the Scala compiler's naming scheme this is the default for the 5th
     * parameter of {@code load} (the pairs location) - confirm in Pool.scala.
     */
    public String load$default$5() {
        return null;
    }

    /**
     * Creates a new {@link PoolReader} bound to the given session.
     *
     * @param sparkSession session handed to the reader's constructor
     * @return a new reader instance
     */
    public PoolReader read(SparkSession sparkSession) {
        return new PoolReader(sparkSession);
    }

    /**
     * For every column of the pool's data, sets the pool param named {@code <column>Col} to the
     * column name itself.
     *
     * @param pool pool whose params are updated in place
     */
    public void setColumnParamsFromLoadedData(Pool pool) {
        String[] columnNames = pool.data().columns();
        for (int idx = 0; idx < columnNames.length; idx++) {
            String columnName = columnNames[idx];
            pool.set(columnName + "Col", columnName);
        }
    }

    /**
     * Determines how many features the vector column {@code str} holds.
     *
     * <p>Sources are tried in order: the attribute group's declared attribute count, then the
     * length of its attribute array, and finally the size of the vector in the first row.
     *
     * @param dataset data containing the features column
     * @param str name of the features column
     * @return the feature count
     * @throws CatBoostError if the column has no attribute information and the dataset is empty
     */
    public int getFeatureCount(Dataset<Row> dataset, String str) {
        AttributeGroup group = AttributeGroup$.MODULE$.fromStructField(dataset.schema().apply(str));

        Option declaredCount = group.numAttributes();
        if (!declaredCount.isEmpty()) {
            return BoxesRunTime.unboxToInt(declaredCount.get());
        }

        Option declaredAttributes = group.attributes();
        if (!declaredAttributes.isEmpty()) {
            return ((Object[]) declaredAttributes.get()).length;
        }

        // No schema metadata: fall back to inspecting the data itself.
        if (dataset.count() != 0) {
            Row firstRow = (Row) dataset.first();
            return ((Vector) firstRow.getAs(str)).size();
        }
        throw new CatBoostError("Cannot get feature count from empty DataFrame without attributes");
    }

    /**
     * Returns one name per feature of the vector column {@code str}.
     *
     * <p>If the column carries no attribute array, every name is the empty string. Otherwise each
     * attribute contributes its name, or the empty string when it has none.
     *
     * @param dataset data containing the features column
     * @param str name of the features column
     * @return array of feature names, of length {@code getFeatureCount(dataset, str)}
     * @throws CatBoostError if the attribute array length differs from the feature count
     */
    public String[] getFeatureNames(Dataset<Row> dataset, String str) {
        final int featureCount = getFeatureCount(dataset, str);
        Option maybeAttributes = AttributeGroup$.MODULE$.fromStructField(dataset.schema().apply(str)).attributes();

        if (!maybeAttributes.isDefined()) {
            String[] blanks = new String[featureCount];
            Arrays.fill(blanks, "");
            return blanks;
        }

        Object[] attributes = (Object[]) maybeAttributes.get();
        if (attributes.length != featureCount) {
            throw new CatBoostError("number of attributes (" + attributes.length + ") is not equal to featureCount (" + featureCount + ")");
        }

        String[] names = new String[attributes.length];
        for (int idx = 0; idx < attributes.length; idx++) {
            Option maybeName = ((Attribute) attributes[idx]).name();
            names[idx] = maybeName.isDefined() ? (String) maybeName.get() : "";
        }
        return names;
    }

    /**
     * Returns, for each feature of the vector column {@code str}, the number of distinct values
     * declared for it; features whose attribute is not nominal get 0.
     *
     * <p>If the column carries no attribute array, the result is all zeros.
     *
     * @param dataset data containing the features column
     * @param str name of the features column
     * @return per-feature value counts, of length {@code getFeatureCount(dataset, str)}
     * @throws CatBoostError if the attribute array length differs from the feature count, or a
     *     nominal attribute defines neither a value count nor a value list
     */
    public int[] getCatFeaturesUniqValueCounts(Dataset<Row> dataset, String str) {
        final int featureCount = getFeatureCount(dataset, str);
        Option maybeAttributes = AttributeGroup$.MODULE$.fromStructField(dataset.schema().apply(str)).attributes();

        if (!maybeAttributes.isDefined()) {
            // A freshly allocated int[] is already zero-filled.
            return new int[featureCount];
        }

        Object[] attributes = (Object[]) maybeAttributes.get();
        if (attributes.length != featureCount) {
            throw new CatBoostError("number of attributes (" + attributes.length + ") is not equal to featureCount (" + featureCount + ")");
        }

        int[] counts = new int[attributes.length];
        for (int idx = 0; idx < attributes.length; idx++) {
            counts[idx] = $anonfun$getCatFeaturesUniqValueCounts$1((Attribute) attributes[idx]);
        }
        return counts;
    }

    /**
     * Returns the largest entry of {@code getCatFeaturesUniqValueCounts(dataset, str)}.
     *
     * <p>NOTE(review): the maximum is taken with the Scala collection {@code max}; confirm how
     * it behaves for a column with zero features before relying on that case.
     *
     * @param dataset data containing the features column
     * @param str name of the features column
     * @return the maximum per-feature value count
     */
    public int getCatFeaturesMaxUniqValueCount(Dataset<Row> dataset, String str) {
        int[] perFeatureCounts = getCatFeaturesUniqValueCounts(dataset, str);
        ArrayOps.ofInt countsOps = new ArrayOps.ofInt(Predef$.MODULE$.intArrayOps(perFeatureCounts));
        Object largest = countsOps.max(Ordering$Int$.MODULE$);
        return BoxesRunTime.unboxToInt(largest);
    }

    /**
     * Java serialization hook: substitutes any deserialized instance with {@code MODULE$},
     * so deserialization does not yield a second instance of this class.
     */
    private Object readResolve() {
        return MODULE$;
    }

    /**
     * Per-row step of the size scan in {@code updateSparseFeaturesSize}: raises the running
     * maximum in {@code intRef} to the size of the sparse vector stored at column 0 of {@code row}.
     *
     * @param intRef running maximum, updated in place
     * @param row row whose first column holds a {@link SparseVector}
     */
    public static final /* synthetic */ void $anonfun$updateSparseFeaturesSize$2(IntRef intRef, Row row) {
        SparseVector features = (SparseVector) row.getAs(0);
        intRef.elem = Math.max(intRef.elem, features.size());
    }

    /**
     * Per-parameter step of {@code load}: stores the parameter's name and the string form of its
     * value into the reader options map.
     *
     * @param map options map, updated in place
     * @param paramPair parameter/value pair to copy
     * @throws MatchError if {@code paramPair} is {@code null}
     */
    public static final /* synthetic */ void $anonfun$load$1(Map map, ParamPair paramPair) {
        if (paramPair != null) {
            String optionName = paramPair.param().name();
            String optionValue = paramPair.value().toString();
            map.update(optionName, optionValue);
            return;
        }
        throw new MatchError(paramPair);
    }

    /**
     * Per-attribute step of {@code getCatFeaturesUniqValueCounts}: number of distinct values
     * declared by a nominal attribute, or 0 for any other attribute kind.
     *
     * <p>The declared value count is preferred; the length of the value list is the fallback.
     *
     * @param attribute feature attribute to inspect
     * @return the declared number of values, or 0 if the attribute is not nominal
     * @throws CatBoostError if a nominal attribute defines neither a value count nor a value list
     */
    public static final /* synthetic */ int $anonfun$getCatFeaturesUniqValueCounts$1(Attribute attribute) {
        if (!(attribute instanceof NominalAttribute)) {
            return 0;
        }
        NominalAttribute nominal = (NominalAttribute) attribute;
        if (nominal.numValues().isDefined()) {
            return BoxesRunTime.unboxToInt(nominal.numValues().get());
        }
        if (nominal.values().isEmpty()) {
            throw new CatBoostError("Neither numValues nor values is defined for categorical feature attribute");
        }
        return ((String[]) nominal.values().get()).length;
    }

    /** Private so the class cannot be instantiated from outside; stores the new instance in {@code MODULE$}. */
    private Pool$() {
        MODULE$ = this;
    }
}
