package org.apache.spark.ml.classification;

import org.apache.spark.SparkException;
import org.apache.spark.annotation.Experimental;
import org.apache.spark.ml.classification.LogisticRegressionParams;
import org.apache.spark.ml.param.BooleanParam;
import org.apache.spark.ml.param.DoubleParam;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.ParamValidators$;
import org.apache.spark.ml.param.shared.HasElasticNetParam;
import org.apache.spark.ml.param.shared.HasFitIntercept;
import org.apache.spark.ml.param.shared.HasMaxIter;
import org.apache.spark.ml.param.shared.HasRegParam;
import org.apache.spark.ml.param.shared.HasStandardization;
import org.apache.spark.ml.param.shared.HasThreshold;
import org.apache.spark.ml.param.shared.HasTol;
import org.apache.spark.ml.param.shared.HasWeightCol;
import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.linalg.SparseVector;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors$;
import org.apache.spark.sql.DataFrame;
import scala.Function1;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Some;
import scala.collection.Seq;
import scala.math.package$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: LogisticRegression.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005\re\u0001B\u0001\u0003\u00015\u0011q\u0003T8hSN$\u0018n\u0019*fOJ,7o]5p]6{G-\u001a7\u000b\u0005\r!\u0011AD2mCN\u001c\u0018NZ5dCRLwN\u001c\u0006\u0003\u000b\u0019\t!!\u001c7\u000b\u0005\u001dA\u0011!B:qCJ\\'BA\u0005\u000b\u0003\u0019\t\u0007/Y2iK*\t1\"A\u0002pe\u001e\u001c\u0001aE\u0002\u0001\u001dm\u0001Ba\u0004\t\u001355\t!!\u0003\u0002\u0012\u0005\t\u0001\u0003K]8cC\nLG.[:uS\u000e\u001cE.Y:tS\u001aL7-\u0019;j_:lu\u000eZ3m!\t\u0019\u0002$D\u0001\u0015\u0015\t)b#\u0001\u0004mS:\fGn\u001a\u0006\u0003/\u0019\tQ!\u001c7mS\nL!!\u0007\u000b\u0003\rY+7\r^8s!\ty\u0001\u0001\u0005\u0002\u00109%\u0011QD\u0001\u0002\u0019\u0019><\u0017n\u001d;jGJ+wM]3tg&|g\u000eU1sC6\u001c\b\u0002C\u0010\u0001\u0005\u000b\u0007I\u0011\t\u0011\u0002\u0007ULG-F\u0001\"!\t\u0011\u0003F\u0004\u0002$M5\tAEC\u0001&\u0003\u0015\u00198-\u00197b\u0013\t9C%\u0001\u0004Qe\u0016$WMZ\u0005\u0003S)\u0012aa\u0015;sS:<'BA\u0014%\u0011!a\u0003A!A!\u0002\u0013\t\u0013\u0001B;jI\u0002B\u0001B\f\u0001\u0003\u0006\u0004%\taL\u0001\bo\u0016Lw\r\u001b;t+\u0005\u0011\u0002\u0002C\u0019\u0001\u0005\u0003\u0005\u000b\u0011\u0002\n\u0002\u0011],\u0017n\u001a5ug\u0002B\u0001b\r\u0001\u0003\u0006\u0004%\t\u0001N\u0001\nS:$XM]2faR,\u0012!\u000e\t\u0003GYJ!a\u000e\u0013\u0003\r\u0011{WO\u00197f\u0011!I\u0004A!A!\u0002\u0013)\u0014AC5oi\u0016\u00148-\u001a9uA!11\b\u0001C\u0001\tq\na\u0001P5oSRtD\u0003\u0002\u000e>}}BQa\b\u001eA\u0002\u0005BQA\f\u001eA\u0002IAQa\r\u001eA\u0002UBQ!\u0011\u0001\u0005B\t\u000bAb]3u)\"\u0014Xm\u001d5pY\u0012$\"a\u0011#\u000e\u0003\u0001AQ!\u0012!A\u0002U\nQA^1mk\u0016DQa\u0012\u0001\u0005BQ\nAbZ3u)\"\u0014Xm\u001d5pY\u0012DQ!\u0013\u0001\u0005B)\u000bQb]3u)\"\u0014Xm\u001d5pY\u0012\u001cHCA\"L\u0011\u0015)\u0005\n1\u0001M!\r\u0019S*N\u0005\u0003\u001d\u0012\u0012Q!\u0011:sCfDQ\u0001\u0015\u0001\u0005BE\u000bQbZ3u)\"\u0014Xm\u001d5pY\u0012\u001cX#\u0001'\t\u000fM\u0003!\u0019!C\u0005)\u00061Q.\u0019:hS:,\u0012!\u0016\t\u0005GY\u0013R'\u0003\u0002XI\tIa)\
u001e8di&|g.\r\u0005\u00073\u0002\u0001\u000b\u0011B+\u0002\u000f5\f'oZ5oA!91\f\u0001b\u0001\n\u0013!\u0016!B:d_J,\u0007BB/\u0001A\u0003%Q+\u0001\u0004tG>\u0014X\r\t\u0005\b?\u0002\u0011\r\u0011\"\u0011a\u0003-qW/\u001c$fCR,(/Z:\u0016\u0003\u0005\u0004\"a\t2\n\u0005\r$#aA%oi\"1Q\r\u0001Q\u0001\n\u0005\fAB\\;n\r\u0016\fG/\u001e:fg\u0002Bqa\u001a\u0001C\u0002\u0013\u0005\u0003-\u0001\u0006ok6\u001cE.Y:tKNDa!\u001b\u0001!\u0002\u0013\t\u0017a\u00038v[\u000ec\u0017m]:fg\u0002Bqa\u001b\u0001A\u0002\u0013%A.A\bue\u0006Lg.\u001b8h'VlW.\u0019:z+\u0005i\u0007cA\u0012oa&\u0011q\u000e\n\u0002\u0007\u001fB$\u0018n\u001c8\u0011\u0005=\t\u0018B\u0001:\u0003\u0005\u0005bunZ5ti&\u001c'+Z4sKN\u001c\u0018n\u001c8Ue\u0006Lg.\u001b8h'VlW.\u0019:z\u0011\u001d!\b\u00011A\u0005\nU\f1\u0003\u001e:bS:LgnZ*v[6\f'/_0%KF$\"A^=\u0011\u0005\r:\u0018B\u0001=%\u0005\u0011)f.\u001b;\t\u000fi\u001c\u0018\u0011!a\u0001[\u0006\u0019\u0001\u0010J\u0019\t\rq\u0004\u0001\u0015)\u0003n\u0003A!(/Y5oS:<7+^7nCJL\b\u0005C\u0003\u007f\u0001\u0011\u0005q0A\u0004tk6l\u0017M]=\u0016\u0003AD\u0001\"a\u0001\u0001\t\u0003\u0011\u0011QA\u0001\u000bg\u0016$8+^7nCJLHcA\"\u0002\b!1a0!\u0001A\u0002ADq!a\u0003\u0001\t\u0003\ti!\u0001\u0006iCN\u001cV/\\7bef,\"!a\u0004\u0011\u0007\r\n\t\"C\u0002\u0002\u0014\u0011\u0012qAQ8pY\u0016\fg\u000e\u0003\u0005\u0002\u0018\u0001!\tAAA\r\u0003!)g/\u00197vCR,G\u0003BA\u000e\u0003C\u00012aDA\u000f\u0013\r\tyB\u0001\u0002\u001a\u0019><\u0017n\u001d;jGJ+wM]3tg&|gnU;n[\u0006\u0014\u0018\u0010\u0003\u0005\u0002$\u0005U\u0001\u0019AA\u0013\u0003\u001d!\u0017\r^1tKR\u0004B!a\n\u0002.5\u0011\u0011\u0011\u0006\u0006\u0004\u0003W1\u0011aA:rY&!\u0011qFA\u0015\u0005%!\u0015\r^1Ge\u0006lW\rC\u0004\u00024\u0001!\t&!\u000e\u0002\u000fA\u0014X\rZ5diR\u0019Q'a\u000e\t\u000f\u0005e\u0012\u0011\u0007a\u0001%\u0005Aa-Z1ukJ,7\u000fC\u0004\u0002>\u0001!\t&a\u0010\u0002-I\fwO\r9s_\n\f'-\u001b7jifLe\u000e\u00157bG\u0016$2AEA!\u0011\u001d\t\u0019%a\u000fA\u0002I\tQB]1x!J,G-[2uS>t\u0007bBA$\u0001\u0011E\u0013\u
0011J\u0001\u000baJ,G-[2u%\u0006<Hc\u0001\n\u0002L!9\u0011\u0011HA#\u0001\u0004\u0011\u0002bBA(\u0001\u0011\u0005\u0013\u0011K\u0001\u0005G>\u0004\u0018\u0010F\u0002\u001b\u0003'B\u0001\"!\u0016\u0002N\u0001\u0007\u0011qK\u0001\u0006Kb$(/\u0019\t\u0005\u00033\ny&\u0004\u0002\u0002\\)\u0019\u0011Q\f\u0003\u0002\u000bA\f'/Y7\n\t\u0005\u0005\u00141\f\u0002\t!\u0006\u0014\u0018-\\'ba\"9\u0011Q\r\u0001\u0005R\u0005\u001d\u0014A\u0004:boJ\u0002(/\u001a3jGRLwN\u001c\u000b\u0004k\u0005%\u0004bBA\"\u0003G\u0002\rA\u0005\u0005\b\u0003[\u0002A\u0011KA8\u0003Y\u0001(o\u001c2bE&d\u0017\u000e^=3aJ,G-[2uS>tGcA\u001b\u0002r!9\u00111OA6\u0001\u0004\u0011\u0012a\u00039s_\n\f'-\u001b7jifD3\u0001AA<!\u0011\tI(a \u000e\u0005\u0005m$bAA?\r\u0005Q\u0011M\u001c8pi\u0006$\u0018n\u001c8\n\t\u0005\u0005\u00151\u0010\u0002\r\u000bb\u0004XM]5nK:$\u0018\r\u001c")
@Experimental
/* loaded from: input_file:org/apache/spark/ml/classification/LogisticRegressionModel.class */
public class LogisticRegressionModel extends ProbabilisticClassificationModel<Vector, LogisticRegressionModel> implements LogisticRegressionParams {
    private final String uid;
    private final Vector weights;
    private final double intercept;
    // Name kept in its Scala-mangled form because the generated anonymous-function classes
    // (LogisticRegressionModel$$anonfun$8/9) presumably reach it through the public accessor below.
    private final Function1<Vector, Object> org$apache$spark$ml$classification$LogisticRegressionModel$$margin;
    private final Function1<Vector, Object> score;
    private final int numFeatures;
    private final int numClasses;
    private Option<LogisticRegressionTrainingSummary> trainingSummary;

    // The shared-param fields below are written through the trait "_setter_" callbacks, not by a
    // direct assignment in the constructor body. They therefore cannot be final: javac rejects any
    // assignment to a final field made outside a constructor or instance initializer.
    private DoubleParam threshold;
    private Param<String> weightCol;
    private BooleanParam standardization;
    private DoubleParam tol;
    private BooleanParam fitIntercept;
    private IntParam maxIter;
    private DoubleParam elasticNetParam;
    private DoubleParam regParam;

    /** Delegates to the {@code LogisticRegressionParams} trait implementation. */
    @Override
    public void checkThresholdConsistency() {
        LogisticRegressionParams.Cclass.checkThresholdConsistency(this);
    }

    /** Delegates to the {@code LogisticRegressionParams} trait implementation. */
    @Override
    public void validateParams() {
        LogisticRegressionParams.Cclass.validateParams(this);
    }

    /** @return the {@code threshold} param object. */
    @Override
    public final DoubleParam threshold() {
        return this.threshold;
    }

    /** Trait field setter for {@code threshold}; invoked during construction (presumably from the trait's {@code $init$}). */
    @Override
    public final void org$apache$spark$ml$param$shared$HasThreshold$_setter_$threshold_$eq(DoubleParam doubleParam) {
        this.threshold = doubleParam;
    }

    /** @return the {@code weightCol} param object. */
    @Override
    public final Param<String> weightCol() {
        return this.weightCol;
    }

    /** Trait field setter for {@code weightCol}; called from the constructor. */
    @Override
    public final void org$apache$spark$ml$param$shared$HasWeightCol$_setter_$weightCol_$eq(Param param) {
        this.weightCol = param;
    }

    /** @return the current value of {@code weightCol}, via the {@code HasWeightCol} trait implementation. */
    @Override
    public final String getWeightCol() {
        return HasWeightCol.Cclass.getWeightCol(this);
    }

    /** @return the {@code standardization} param object. */
    @Override
    public final BooleanParam standardization() {
        return this.standardization;
    }

    /** Trait field setter for {@code standardization}; invoked during construction (presumably from the trait's {@code $init$}). */
    @Override
    public final void org$apache$spark$ml$param$shared$HasStandardization$_setter_$standardization_$eq(BooleanParam booleanParam) {
        this.standardization = booleanParam;
    }

    /** @return the current value of {@code standardization}, via the trait implementation. */
    @Override
    public final boolean getStandardization() {
        return HasStandardization.Cclass.getStandardization(this);
    }

    /** @return the {@code tol} param object. */
    @Override
    public final DoubleParam tol() {
        return this.tol;
    }

    /** Trait field setter for {@code tol}; called from the constructor. */
    @Override
    public final void org$apache$spark$ml$param$shared$HasTol$_setter_$tol_$eq(DoubleParam doubleParam) {
        this.tol = doubleParam;
    }

    /** @return the current value of {@code tol}, via the trait implementation. */
    @Override
    public final double getTol() {
        return HasTol.Cclass.getTol(this);
    }

    /** @return the {@code fitIntercept} param object. */
    @Override
    public final BooleanParam fitIntercept() {
        return this.fitIntercept;
    }

    /** Trait field setter for {@code fitIntercept}; invoked during construction (presumably from the trait's {@code $init$}). */
    @Override
    public final void org$apache$spark$ml$param$shared$HasFitIntercept$_setter_$fitIntercept_$eq(BooleanParam booleanParam) {
        this.fitIntercept = booleanParam;
    }

    /** @return the current value of {@code fitIntercept}, via the trait implementation. */
    @Override
    public final boolean getFitIntercept() {
        return HasFitIntercept.Cclass.getFitIntercept(this);
    }

    /** @return the {@code maxIter} param object. */
    @Override
    public final IntParam maxIter() {
        return this.maxIter;
    }

    /** Trait field setter for {@code maxIter}; called from the constructor. */
    @Override
    public final void org$apache$spark$ml$param$shared$HasMaxIter$_setter_$maxIter_$eq(IntParam intParam) {
        this.maxIter = intParam;
    }

    /** @return the current value of {@code maxIter}, via the trait implementation. */
    @Override
    public final int getMaxIter() {
        return HasMaxIter.Cclass.getMaxIter(this);
    }

    /** @return the {@code elasticNetParam} param object. */
    @Override
    public final DoubleParam elasticNetParam() {
        return this.elasticNetParam;
    }

    /** Trait field setter for {@code elasticNetParam}; called from the constructor. */
    @Override
    public final void org$apache$spark$ml$param$shared$HasElasticNetParam$_setter_$elasticNetParam_$eq(DoubleParam doubleParam) {
        this.elasticNetParam = doubleParam;
    }

    /** @return the current value of {@code elasticNetParam}, via the trait implementation. */
    @Override
    public final double getElasticNetParam() {
        return HasElasticNetParam.Cclass.getElasticNetParam(this);
    }

    /** @return the {@code regParam} param object. */
    @Override
    public final DoubleParam regParam() {
        return this.regParam;
    }

    /** Trait field setter for {@code regParam}; called from the constructor. */
    @Override
    public final void org$apache$spark$ml$param$shared$HasRegParam$_setter_$regParam_$eq(DoubleParam doubleParam) {
        this.regParam = doubleParam;
    }

    /** @return the current value of {@code regParam}, via the trait implementation. */
    @Override
    public final double getRegParam() {
        return HasRegParam.Cclass.getRegParam(this);
    }

    /** @return the unique identifier passed to the constructor. */
    @Override
    public String uid() {
        return this.uid;
    }

    /** @return the weight vector passed to the constructor. */
    public Vector weights() {
        return this.weights;
    }

    /** @return the intercept passed to the constructor. */
    public double intercept() {
        return this.intercept;
    }

    /** Sets the threshold via the {@code LogisticRegressionParams} trait implementation and returns this model. */
    @Override
    public LogisticRegressionModel setThreshold(double d) {
        return (LogisticRegressionModel) LogisticRegressionParams.Cclass.setThreshold(this, d);
    }

    /** @return the threshold as computed by the {@code LogisticRegressionParams} trait implementation. */
    @Override
    public double getThreshold() {
        return LogisticRegressionParams.Cclass.getThreshold(this);
    }

    /** Sets the thresholds via the {@code LogisticRegressionParams} trait implementation and returns this model. */
    @Override
    public LogisticRegressionModel setThresholds(double[] dArr) {
        return (LogisticRegressionModel) LogisticRegressionParams.Cclass.setThresholds(this, dArr);
    }

    /** @return the thresholds as computed by the {@code LogisticRegressionParams} trait implementation. */
    @Override
    public double[] getThresholds() {
        return LogisticRegressionParams.Cclass.getThresholds(this);
    }

    /** @return the function mapping a feature vector to its (boxed double) margin. */
    public Function1<Vector, Object> org$apache$spark$ml$classification$LogisticRegressionModel$$margin() {
        return this.org$apache$spark$ml$classification$LogisticRegressionModel$$margin;
    }

    /** @return the function mapping a feature vector to its (boxed double) score, compared against the threshold in {@link #predict}. */
    private Function1<Vector, Object> score() {
        return this.score;
    }

    /** @return the number of features, i.e. the size of the weight vector. */
    @Override
    public int numFeatures() {
        return this.numFeatures;
    }

    /** @return the number of classes; always 2 for this binary model. */
    @Override
    public int numClasses() {
        return this.numClasses;
    }

    private Option<LogisticRegressionTrainingSummary> trainingSummary() {
        return this.trainingSummary;
    }

    private void trainingSummary_$eq(Option<LogisticRegressionTrainingSummary> option) {
        this.trainingSummary = option;
    }

    /**
     * Returns the training summary attached to this model.
     *
     * @return the summary previously attached with {@link #setSummary}
     * @throws SparkException if no summary has been attached (see {@link #hasSummary()})
     */
    public LogisticRegressionTrainingSummary summary() {
        Option<LogisticRegressionTrainingSummary> current = trainingSummary();
        if (current.isDefined()) {
            return current.get();
        }
        // No fabricated cause: an absent summary is a usage error, not a null dereference.
        throw new SparkException("No training summary available for this LogisticRegressionModel");
    }

    /**
     * Attaches a training summary to this model.
     *
     * @return this model, for chaining
     */
    public LogisticRegressionModel setSummary(LogisticRegressionTrainingSummary logisticRegressionTrainingSummary) {
        trainingSummary_$eq(new Some<LogisticRegressionTrainingSummary>(logisticRegressionTrainingSummary));
        return this;
    }

    /** @return {@code true} if a training summary has been attached. */
    public boolean hasSummary() {
        return trainingSummary().isDefined();
    }

    /**
     * Transforms {@code dataFrame} with this model and wraps the result in a binary summary
     * keyed on the configured probability and label columns.
     */
    public LogisticRegressionSummary evaluate(DataFrame dataFrame) {
        return new BinaryLogisticRegressionSummary(transform(dataFrame), (String) $(probabilityCol()), (String) $(labelCol()));
    }

    /** Predicts 1.0 when the score of {@code vector} exceeds {@link #getThreshold()}, else 0.0. */
    @Override
    public double predict(Vector vector) {
        double s = BoxesRunTime.unboxToDouble(score().apply(vector));
        return s > getThreshold() ? 1.0d : 0.0d;
    }

    /**
     * Applies the logistic function {@code 1 / (1 + e^-x)} to every entry of a dense raw-prediction
     * vector, overwriting it in place.
     *
     * @return the same (mutated) vector
     * @throws RuntimeException if {@code vector} is sparse
     * @throws MatchError if {@code vector} is neither dense nor sparse
     */
    @Override
    public Vector raw2probabilityInPlace(Vector vector) {
        if (vector instanceof DenseVector) {
            DenseVector dense = (DenseVector) vector;
            double[] values = dense.values();
            int size = dense.size();
            for (int i = 0; i < size; i++) {
                values[i] = 1.0d / (1.0d + Math.exp(-values[i]));
            }
            return dense;
        }
        if (vector instanceof SparseVector) {
            throw new RuntimeException("Unexpected error in LogisticRegressionModel: raw2probabilityInPlace encountered SparseVector");
        }
        throw new MatchError(vector);
    }

    /** @return the two-element raw prediction {@code [-margin, margin]} for {@code vector}. */
    @Override
    public Vector predictRaw(Vector vector) {
        double m = BoxesRunTime.unboxToDouble(org$apache$spark$ml$classification$LogisticRegressionModel$$margin().apply(vector));
        return Vectors$.MODULE$.dense(-m, (Seq<Object>) Predef$.MODULE$.wrapDoubleArray(new double[]{m}));
    }

    /**
     * Copies this model (uid, weights, intercept and params overridden by {@code paramMap}),
     * carrying over any attached training summary and the parent estimator.
     */
    @Override
    public LogisticRegressionModel copy(ParamMap paramMap) {
        LogisticRegressionModel copied = (LogisticRegressionModel) copyValues(new LogisticRegressionModel(uid(), weights(), intercept()), paramMap);
        if (trainingSummary().isDefined()) {
            copied.setSummary(trainingSummary().get());
        }
        return (LogisticRegressionModel) copied.setParent(parent());
    }

    /**
     * Predicts from the raw vector by comparing its index-1 entry with the log-odds of the threshold.
     * A threshold of exactly 0.0 maps to -Infinity (the comparison holds for any finite raw value)
     * and exactly 1.0 maps to +Infinity (it never holds).
     */
    @Override
    public double raw2prediction(Vector vector) {
        double t = getThreshold();
        double rawThreshold;
        if (t == 0.0d) {
            rawThreshold = Double.NEGATIVE_INFINITY;
        } else if (t == 1.0d) {
            rawThreshold = Double.POSITIVE_INFINITY;
        } else {
            rawThreshold = Math.log(t / (1.0d - t));
        }
        return vector.apply(1) > rawThreshold ? 1.0d : 0.0d;
    }

    /** Predicts 1.0 when the index-1 probability exceeds {@link #getThreshold()}, else 0.0. */
    @Override
    public double probability2prediction(Vector vector) {
        return vector.apply(1) > getThreshold() ? 1.0d : 0.0d;
    }

    /**
     * Creates a model from its uid, weight vector and intercept.
     * The param/trait initialization order below mirrors the compiled Scala mixin order and must be kept.
     */
    public LogisticRegressionModel(String str, Vector vector, double d) {
        this.uid = str;
        this.weights = vector;
        this.intercept = d;
        org$apache$spark$ml$param$shared$HasRegParam$_setter_$regParam_$eq(new DoubleParam(this, "regParam", "regularization parameter (>= 0)", (Function1<Object, Object>) ParamValidators$.MODULE$.gtEq(0.0d)));
        org$apache$spark$ml$param$shared$HasElasticNetParam$_setter_$elasticNetParam_$eq(new DoubleParam(this, "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty", (Function1<Object, Object>) ParamValidators$.MODULE$.inRange(0.0d, 1.0d)));
        org$apache$spark$ml$param$shared$HasMaxIter$_setter_$maxIter_$eq(new IntParam(this, "maxIter", "maximum number of iterations (>= 0)", (Function1<Object, Object>) ParamValidators$.MODULE$.gtEq(0.0d)));
        HasFitIntercept.Cclass.$init$(this);
        org$apache$spark$ml$param$shared$HasTol$_setter_$tol_$eq(new DoubleParam(this, "tol", "the convergence tolerance for iterative algorithms"));
        HasStandardization.Cclass.$init$(this);
        org$apache$spark$ml$param$shared$HasWeightCol$_setter_$weightCol_$eq(new Param<String>(this, "weightCol", "weight column name. If this is not set or empty, we treat all instance weights as 1.0."));
        HasThreshold.Cclass.$init$(this);
        LogisticRegressionParams.Cclass.$init$(this);
        this.org$apache$spark$ml$classification$LogisticRegressionModel$$margin = new LogisticRegressionModel$$anonfun$8(this);
        this.score = new LogisticRegressionModel$$anonfun$9(this);
        this.numFeatures = vector.size();
        this.numClasses = 2;
        // Option.empty() returns the None singleton, typed for this field (None$.MODULE$ is not assignable in Java).
        this.trainingSummary = Option.empty();
    }
}
