package org.apache.spark.ml.classification;

import org.apache.spark.annotation.Experimental;
import org.apache.spark.ml.Estimator;
import org.apache.spark.ml.Predictor;
import org.apache.spark.ml.PredictorParams;
import org.apache.spark.ml.attribute.Attribute;
import org.apache.spark.ml.attribute.Attribute$;
import org.apache.spark.ml.attribute.NominalAttribute$;
import org.apache.spark.ml.attribute.NumericAttribute;
import org.apache.spark.ml.attribute.UnresolvedAttribute$;
import org.apache.spark.ml.classification.OneVsRestParams;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.shared.HasFeaturesCol;
import org.apache.spark.ml.param.shared.HasLabelCol;
import org.apache.spark.ml.param.shared.HasPredictionCol;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.ml.util.MetadataUtils$;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.storage.StorageLevel;
import org.apache.spark.storage.StorageLevel$;
import scala.MatchError;
import scala.Predef$;
import scala.collection.parallel.ParIterableLike;
import scala.collection.parallel.immutable.ParSeq$;
import scala.package$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: OneVsRest.scala */
@ScalaSignature(bytes = "\u0006\u0001y4A!\u0001\u0002\u0003\u001b\tIqJ\\3WgJ+7\u000f\u001e\u0006\u0003\u0007\u0011\tab\u00197bgNLg-[2bi&|gN\u0003\u0002\u0006\r\u0005\u0011Q\u000e\u001c\u0006\u0003\u000f!\tQa\u001d9be.T!!\u0003\u0006\u0002\r\u0005\u0004\u0018m\u00195f\u0015\u0005Y\u0011aA8sO\u000e\u00011c\u0001\u0001\u000f-A\u0019q\u0002\u0005\n\u000e\u0003\u0011I!!\u0005\u0003\u0003\u0013\u0015\u001bH/[7bi>\u0014\bCA\n\u0015\u001b\u0005\u0011\u0011BA\u000b\u0003\u00059ye.\u001a,t%\u0016\u001cH/T8eK2\u0004\"aE\f\n\u0005a\u0011!aD(oKZ\u001b(+Z:u!\u0006\u0014\u0018-\\:\t\u0011i\u0001!Q1A\u0005Bm\t1!^5e+\u0005a\u0002CA\u000f$\u001d\tq\u0012%D\u0001 \u0015\u0005\u0001\u0013!B:dC2\f\u0017B\u0001\u0012 \u0003\u0019\u0001&/\u001a3fM&\u0011A%\n\u0002\u0007'R\u0014\u0018N\\4\u000b\u0005\tz\u0002\u0002C\u0014\u0001\u0005\u0003\u0005\u000b\u0011\u0002\u000f\u0002\tULG\r\t\u0005\u0006S\u0001!\tAK\u0001\u0007y%t\u0017\u000e\u001e \u0015\u0005-b\u0003CA\n\u0001\u0011\u0015Q\u0002\u00061\u0001\u001d\u0011\u0015I\u0003\u0001\"\u0001/)\u0005Y\u0003\"\u0002\u0019\u0001\t\u0003\t\u0014!D:fi\u000ec\u0017m]:jM&,'\u000f\u0006\u00023g5\t\u0001\u0001C\u00035_\u0001\u0007Q'A\u0003wC2,X\r\r\u00037w\u0015C\u0005#B\n8s\u0011;\u0015B\u0001\u001d\u0003\u0005)\u0019E.Y:tS\u001aLWM\u001d\t\u0003umb\u0001\u0001B\u0005=g\u0005\u0005\t\u0011!B\u0001{\t\u0019q\f\n\u001c\u0012\u0005y\n\u0005C\u0001\u0010@\u0013\t\u0001uDA\u0004O_RD\u0017N\\4\u0011\u0005y\u0011\u0015BA\" 
\u0005\r\te.\u001f\t\u0003u\u0015#\u0011BR\u001a\u0002\u0002\u0003\u0005)\u0011A\u001f\u0003\u0007}#s\u0007\u0005\u0002;\u0011\u0012I\u0011jMA\u0001\u0002\u0003\u0015\t!\u0010\u0002\u0004?\u0012B\u0004\"B&\u0001\t\u0003a\u0015aC:fi2\u000b'-\u001a7D_2$\"AM'\t\u000bQR\u0005\u0019\u0001\u000f\t\u000b=\u0003A\u0011\u0001)\u0002\u001dM,GOR3biV\u0014Xm]\"pYR\u0011!'\u0015\u0005\u0006i9\u0003\r\u0001\b\u0005\u0006'\u0002!\t\u0001V\u0001\u0011g\u0016$\bK]3eS\u000e$\u0018n\u001c8D_2$\"AM+\t\u000bQ\u0012\u0006\u0019\u0001\u000f\t\u000b]\u0003A\u0011\t-\u0002\u001fQ\u0014\u0018M\\:g_Jl7k\u00195f[\u0006$\"!W1\u0011\u0005i{V\"A.\u000b\u0005qk\u0016!\u0002;za\u0016\u001c(B\u00010\u0007\u0003\r\u0019\u0018\u000f\\\u0005\u0003An\u0013!b\u0015;sk\u000e$H+\u001f9f\u0011\u0015\u0011g\u000b1\u0001Z\u0003\u0019\u00198\r[3nC\")A\r\u0001C!K\u0006\u0019a-\u001b;\u0015\u0005I1\u0007\"B4d\u0001\u0004A\u0017a\u00023bi\u0006\u001cX\r\u001e\t\u0003S*l\u0011!X\u0005\u0003Wv\u0013\u0011\u0002R1uC\u001a\u0013\u0018-\\3\t\u000b5\u0004A\u0011\t8\u0002\t\r|\u0007/\u001f\u000b\u0003W=DQ\u0001\u001d7A\u0002E\fQ!\u001a=ue\u0006\u0004\"A];\u000e\u0003MT!\u0001\u001e\u0003\u0002\u000bA\f'/Y7\n\u0005Y\u001c(\u0001\u0003)be\u0006lW*\u00199)\u0005\u0001A\bCA=}\u001b\u0005Q(BA>\u0007\u0003)\tgN\\8uCRLwN\\\u0005\u0003{j\u0014A\"\u0012=qKJLW.\u001a8uC2\u0004")
@Experimental
/* loaded from: input_file:org/apache/spark/ml/classification/OneVsRest.class */
public final class OneVsRest extends Estimator<OneVsRestModel> implements OneVsRestParams {
    /*
     * Estimator that reduces a multiclass problem to a set of models built from a single
     * "base binary classifier" (see the classifier Param created in the constructor): fit()
     * produces one ClassificationModel per class index in [0, numClasses) and wraps them,
     * together with the label metadata, in a OneVsRestModel.
     *
     * This is a decompiled view of Scala bytecode: the Param fields are assigned through the
     * mangled *_setter_* methods (Scala trait encoding) rather than directly in the constructor,
     * and several code paths delegate to synthetic OneVsRest$$anonfun$N classes whose bodies are
     * not visible here.
     */

    /** Unique identifier of this pipeline stage, fixed at construction. */
    private final String uid;
    // Param fields; populated via the *_setter_* methods below during construction.
    private final Param<Classifier<?, ? extends Classifier<Object, Classifier, ClassificationModel>, ? extends ClassificationModel<Object, ClassificationModel>>> classifier;
    private final Param<String> predictionCol;
    private final Param<String> featuresCol;
    private final Param<String> labelCol;

    /** @return the Param holding the base binary classifier. */
    @Override // org.apache.spark.ml.classification.OneVsRestParams
    public Param<Classifier<?, ? extends Classifier<Object, Classifier, ClassificationModel>, ? extends ClassificationModel<Object, ClassificationModel>>> classifier() {
        return this.classifier;
    }

    // Mangled trait setter: the name is part of the Scala trait contract and must not change.
    @Override // org.apache.spark.ml.classification.OneVsRestParams
    public void org$apache$spark$ml$classification$OneVsRestParams$_setter_$classifier_$eq(Param param) {
        this.classifier = param;
    }

    /** @return the configured base classifier (delegates to the OneVsRestParams trait implementation). */
    @Override // org.apache.spark.ml.classification.OneVsRestParams
    public Classifier<?, ? extends Classifier<Object, Classifier, ClassificationModel>, ? extends ClassificationModel<Object, ClassificationModel>> getClassifier() {
        return OneVsRestParams.Cclass.getClassifier(this);
    }

    /** Delegates schema validation/transformation to the PredictorParams trait implementation. */
    @Override // org.apache.spark.ml.PredictorParams
    public StructType validateAndTransformSchema(StructType structType, boolean z, DataType dataType) {
        return PredictorParams.Cclass.validateAndTransformSchema(this, structType, z, dataType);
    }

    /** @return the Param naming the prediction output column. */
    @Override // org.apache.spark.ml.param.shared.HasPredictionCol
    public final Param<String> predictionCol() {
        return this.predictionCol;
    }

    // Mangled trait setter: the name is part of the Scala trait contract and must not change.
    @Override // org.apache.spark.ml.param.shared.HasPredictionCol
    public final void org$apache$spark$ml$param$shared$HasPredictionCol$_setter_$predictionCol_$eq(Param param) {
        this.predictionCol = param;
    }

    /** @return the current prediction column name. */
    @Override // org.apache.spark.ml.param.shared.HasPredictionCol
    public final String getPredictionCol() {
        return HasPredictionCol.Cclass.getPredictionCol(this);
    }

    /** @return the Param naming the features input column. */
    @Override // org.apache.spark.ml.param.shared.HasFeaturesCol
    public final Param<String> featuresCol() {
        return this.featuresCol;
    }

    // Mangled trait setter: the name is part of the Scala trait contract and must not change.
    @Override // org.apache.spark.ml.param.shared.HasFeaturesCol
    public final void org$apache$spark$ml$param$shared$HasFeaturesCol$_setter_$featuresCol_$eq(Param param) {
        this.featuresCol = param;
    }

    /** @return the current features column name. */
    @Override // org.apache.spark.ml.param.shared.HasFeaturesCol
    public final String getFeaturesCol() {
        return HasFeaturesCol.Cclass.getFeaturesCol(this);
    }

    /** @return the Param naming the label input column. */
    @Override // org.apache.spark.ml.param.shared.HasLabelCol
    public final Param<String> labelCol() {
        return this.labelCol;
    }

    // Mangled trait setter: the name is part of the Scala trait contract and must not change.
    @Override // org.apache.spark.ml.param.shared.HasLabelCol
    public final void org$apache$spark$ml$param$shared$HasLabelCol$_setter_$labelCol_$eq(Param param) {
        this.labelCol = param;
    }

    /** @return the current label column name. */
    @Override // org.apache.spark.ml.param.shared.HasLabelCol
    public final String getLabelCol() {
        return HasLabelCol.Cclass.getLabelCol(this);
    }

    /** @return the unique identifier supplied at construction. */
    @Override // org.apache.spark.ml.util.Identifiable
    public String uid() {
        return this.uid;
    }

    /**
     * Sets the base binary classifier.
     *
     * @param classifier the classifier to train once per class
     * @return this estimator (as returned by {@code set}), for chaining
     */
    public OneVsRest setClassifier(Classifier<?, ?, ?> classifier) {
        return (OneVsRest) set((Param<Param<Classifier<?, ? extends Classifier<Object, Classifier, ClassificationModel>, ? extends ClassificationModel<Object, ClassificationModel>>>>) classifier(), (Param<Classifier<?, ? extends Classifier<Object, Classifier, ClassificationModel>, ? extends ClassificationModel<Object, ClassificationModel>>>) classifier);
    }

    /**
     * Sets the label column name.
     *
     * @param str label column name
     * @return this estimator (as returned by {@code set}), for chaining
     */
    public OneVsRest setLabelCol(String str) {
        return (OneVsRest) set((Param<Param<String>>) labelCol(), (Param<String>) str);
    }

    /**
     * Sets the features column name.
     *
     * @param str features column name
     * @return this estimator (as returned by {@code set}), for chaining
     */
    public OneVsRest setFeaturesCol(String str) {
        return (OneVsRest) set((Param<Param<String>>) featuresCol(), (Param<String>) str);
    }

    /**
     * Sets the prediction output column name.
     *
     * @param str prediction column name
     * @return this estimator (as returned by {@code set}), for chaining
     */
    public OneVsRest setPredictionCol(String str) {
        return (OneVsRest) set((Param<Param<String>>) predictionCol(), (Param<String>) str);
    }

    /**
     * Validates the input schema and derives the output schema.
     *
     * <p>This uses the fitting variant ({@code true}) and the base classifier's features data
     * type, so the classifier Param must already be set.
     */
    @Override // org.apache.spark.ml.PipelineStage
    public StructType transformSchema(StructType structType) {
        return validateAndTransformSchema(structType, true, getClassifier().featuresDataType());
    }

    /**
     * Fits one model per class and wraps them in a {@link OneVsRestModel}.
     *
     * <p>Steps:
     * <ol>
     *   <li>Resolve the number of classes.</li>
     *   <li>Project the label and features columns.</li>
     *   <li>Optionally cache the projection.</li>
     *   <li>Build the per-class models in parallel.</li>
     *   <li>Attach label metadata to the result and copy this estimator's params onto it.</li>
     * </ol>
     *
     * @param dataFrame training data containing the label and features columns
     * @return the fitted model, with this estimator as its parent
     */
    /* JADX WARN: Can't rename method to resolve collision */
    @Override // org.apache.spark.ml.Estimator
    public OneVsRestModel fit(DataFrame dataFrame) {
        StructField labelField = dataFrame.schema().apply((String) $(labelCol()));
        // Number of classes comes from the label column metadata when available. Otherwise the
        // fallback OneVsRest$$anonfun$1 (body not visible here) is given the DataFrame,
        // presumably to derive the count from the data. TODO confirm.
        int numClasses = BoxesRunTime.unboxToInt(MetadataUtils$.MODULE$.getNumClasses(labelField).fold(new OneVsRest$$anonfun$1(this, dataFrame), new OneVsRest$$anonfun$2(this)));
        // Training only needs the label and features columns.
        DataFrame trainingData = dataFrame.select((String) $(labelCol()), Predef$.MODULE$.wrapRefArray(new String[]{(String) $(featuresCol())}));
        // Cache the projection only when the input reports no storage level of its own.
        // NOTE(review): a DataFrame cached via DataFrame.cache() may still report NONE from
        // rdd().getStorageLevel(), so an already-cached input could be cached twice.
        // Verify for this Spark version.
        StorageLevel storageLevel = dataFrame.rdd().getStorageLevel();
        StorageLevel NONE = StorageLevel$.MODULE$.NONE();
        boolean handlePersistence = storageLevel != null ? storageLevel.equals(NONE) : NONE == null;
        if (handlePersistence) {
            trainingData.persist(StorageLevel$.MODULE$.MEMORY_AND_DISK());
        }
        ClassificationModel[] models;
        try {
            // One model per class index in [0, numClasses), built over a Scala parallel
            // collection by OneVsRest$$anonfun$9 (body not visible here).
            models = (ClassificationModel[]) ((ParIterableLike) package$.MODULE$.Range().apply(0, numClasses).par().map(new OneVsRest$$anonfun$9(this, trainingData), ParSeq$.MODULE$.canBuildFrom())).toArray(ClassTag$.MODULE$.apply(ClassificationModel.class));
        } finally {
            // Release the cached projection even if building one of the models fails;
            // previously a failure here left the data persisted.
            if (handlePersistence) {
                trainingData.unpersist();
            }
        }
        // Numeric or unresolved label metadata is replaced by a nominal attribute named
        // "label" carrying the class count. Any other attribute is reused as-is.
        Attribute labelMeta = Attribute$.MODULE$.fromStructField(labelField);
        boolean needsNominalAttr;
        if (labelMeta instanceof NumericAttribute) {
            needsNominalAttr = true;
        } else {
            UnresolvedAttribute$ unresolvedAttribute$ = UnresolvedAttribute$.MODULE$;
            needsNominalAttr = unresolvedAttribute$ != null ? unresolvedAttribute$.equals(labelMeta) : labelMeta == null;
        }
        Attribute attribute;
        if (needsNominalAttr) {
            attribute = NominalAttribute$.MODULE$.defaultAttr().withName("label").withNumValues(numClasses);
        } else {
            // Mirrors the Scala pattern match: a null attribute matches no case.
            if (labelMeta == null) {
                throw new MatchError(labelMeta);
            }
            attribute = labelMeta;
        }
        return (OneVsRestModel) copyValues(new OneVsRestModel(uid(), attribute.toMetadata(), models).setParent(this), copyValues$default$2());
    }

    /**
     * Creates a copy of this estimator with extra params applied.
     *
     * <p>If the classifier Param is defined, the base classifier is copied too, using the same
     * extra params, so the copy does not share it with the original.
     *
     * @param paramMap extra params to apply to the copy
     * @return the copied estimator
     */
    @Override // org.apache.spark.ml.Estimator, org.apache.spark.ml.PipelineStage, org.apache.spark.ml.param.Params
    public OneVsRest copy(ParamMap paramMap) {
        OneVsRest copied = (OneVsRest) defaultCopy(paramMap);
        if (isDefined(classifier())) {
            copied.setClassifier((Classifier) ((Predictor) $(classifier())).copy(paramMap));
        }
        return copied;
    }

    /**
     * Creates the estimator with the given uid.
     *
     * @param str unique identifier for this stage
     */
    public OneVsRest(String str) {
        this.uid = str;
        // Scala trait initializers; they presumably populate the shared Param fields via the
        // *_setter_* methods. Keep this order, since it mirrors the compiled trait linearization.
        HasLabelCol.Cclass.$init$(this);
        HasFeaturesCol.Cclass.$init$(this);
        HasPredictionCol.Cclass.$init$(this);
        PredictorParams.Cclass.$init$(this);
        org$apache$spark$ml$classification$OneVsRestParams$_setter_$classifier_$eq(new Param(this, "classifier", "base binary classifier"));
    }

    /** Creates the estimator with a random uid prefixed "oneVsRest". */
    public OneVsRest() {
        this(Identifiable$.MODULE$.randomUID("oneVsRest"));
    }
}
