package ml.dmlc.xgboost4j.scala.spark;

import ml.dmlc.xgboost4j.LabeledPoint;
import ml.dmlc.xgboost4j.java.IRabitTracker;
import ml.dmlc.xgboost4j.java.XGBoostError;
import ml.dmlc.xgboost4j.scala.Booster;
import ml.dmlc.xgboost4j.scala.EvalTrait;
import ml.dmlc.xgboost4j.scala.ObjectiveTrait;
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker;
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker$;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.spark.SparkContext;
import org.apache.spark.rdd.RDD;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.Serializable;
import scala.Some;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.Iterator;
import scala.collection.Seq$;
import scala.collection.TraversableLike;
import scala.collection.immutable.Map;
import scala.collection.immutable.Nil$;
import scala.collection.immutable.StringOps;
import scala.collection.mutable.ArrayBuilder;
import scala.collection.mutable.StringBuilder;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.ObjectRef;

/* compiled from: WowXGBoost.scala */
/* loaded from: input_file:ml/dmlc/xgboost4j/scala/spark/WowXGBoost$.class */
/**
 * Java view of the Scala singleton object {@code WowXGBoost}: helpers that prepare training
 * partitions, start a Rabit tracker and train XGBoost boosters across Spark partitions.
 *
 * <p>The {@code WowXGBoost$$anonfun$...} classes referenced below are Scala closures compiled from
 * the same source file; their bodies (and therefore their exact messages/behaviour) are not
 * visible here.
 */
public final class WowXGBoost$ implements Serializable {
    /** The singleton instance of this Scala object. */
    public static final WowXGBoost$ MODULE$ = new WowXGBoost$();

    private final Log ml$dmlc$xgboost4j$scala$spark$WowXGBoost$$logger;

    /** Returns the commons-logging logger named {@code "XGBoostSpark"}. */
    public Log ml$dmlc$xgboost4j$scala$spark$WowXGBoost$$logger() {
        return this.ml$dmlc$xgboost4j$scala$spark$WowXGBoost$$logger;
    }

    /**
     * Lazily rewrites points for a non-NaN missing-value marker.
     *
     * @param iterator the labeled points of one partition
     * @param f the missing-value marker; when it is NaN the iterator is returned untouched
     * @return {@code iterator} itself when {@code f} is NaN, otherwise a lazily mapped iterator
     *     (NOTE(review): the mapping closure presumably turns values equal to {@code f} into NaN —
     *     confirm against WowXGBoost.scala)
     */
    public Iterator<LabeledPoint> removeMissingValues(Iterator<LabeledPoint> iterator, float f) {
        if (Float.isNaN(f)) {
            return iterator;
        }
        return iterator.map(new WowXGBoost$$anonfun$removeMissingValues$1(f));
    }

    /**
     * Collects one partition's base-margin values into a float array.
     *
     * <p>The iterator is fully consumed.
     *
     * @param iterator boxed float values, one per row of the partition
     * @return {@code None} when every value is NaN (including an empty partition), or
     *     {@code Some(array)} of all values when none is NaN
     * @throws IllegalArgumentException when the partition mixes NaN and non-NaN values
     */
    public Option<float[]> ml$dmlc$xgboost4j$scala$spark$WowXGBoost$$fromBaseMarginsToArray(Iterator<Object> iterator) {
        ArrayBuilder.ofFloat margins = new ArrayBuilder.ofFloat();
        int total = 0;
        int nanCount = 0;
        while (iterator.hasNext()) {
            total++;
            float margin = BoxesRunTime.unboxToFloat(iterator.next());
            if (Float.isNaN(margin)) {
                nanCount++;
            } else {
                margins.$plus$eq(margin);
            }
        }
        if (nanCount == total) {
            // No usable margin in this partition (all NaN, or no rows at all).
            return None$.MODULE$;
        }
        if (nanCount == 0) {
            return new Some<>(margins.result());
        }
        // A partial set of margins cannot be aligned with the rows, so reject it.
        throw new IllegalArgumentException("Encountered a partition with " + nanCount
            + " NaN base margin values. "
            + "If you want to specify base margin, ensure all values are non-NaN.");
    }

    /**
     * Trains one booster per partition and returns the cached per-partition results.
     *
     * <p>Each partition of {@code rdd} is zipped with the matching partition of floats derived
     * from it by {@code WowXGBoost$$anonfun$3} (presumably each point's base margin — TODO
     * confirm). The pair is then handed to {@code WowXGBoost$$anonfun$buildDistributedBoosters$1},
     * together with the remaining arguments.
     *
     * @return the cached RDD of (booster, map) tuples; NOTE(review): the map presumably holds
     *     evaluation metric history — confirm
     */
    public RDD<Tuple2<Booster, Map<String, float[]>>> buildDistributedBoosters(RDD<LabeledPoint> rdd, Map<String, Object> map, java.util.Map<String, String> map2, int i, ObjectiveTrait objectiveTrait, EvalTrait evalTrait, boolean z, float f, Booster booster) {
        return rdd.zipPartitions(rdd.map(new WowXGBoost$$anonfun$3(), ClassTag$.MODULE$.Float()), new WowXGBoost$$anonfun$buildDistributedBoosters$1(map, map2, i, objectiveTrait, evalTrait, z, f, booster), ClassTag$.MODULE$.Float(), ClassTag$.MODULE$.apply(Tuple2.class)).cache();
    }

    /**
     * Aligns the {@code "nthread"} parameter with Spark's {@code spark.task.cpus}.
     *
     * <p>{@code spark.task.cpus} defaults to 1.
     *
     * @return {@code map} unchanged when it already sets {@code "nthread"} to a value no larger
     *     than {@code spark.task.cpus}; otherwise a copy with {@code "nthread"} set to
     *     {@code spark.task.cpus}
     * @throws IllegalArgumentException via {@code require} when {@code "nthread"} exceeds
     *     {@code spark.task.cpus}
     * @throws NumberFormatException when {@code "nthread"} does not parse as an int
     */
    public Map<String, Object> ml$dmlc$xgboost4j$scala$spark$WowXGBoost$$overrideParamsAccordingToTaskCPUs(Map<String, Object> map, SparkContext sparkContext) {
        int taskCpus = sparkContext.getConf().getInt("spark.task.cpus", 1);
        if (map.contains("nthread")) {
            // Same parsing as Scala's StringOps.toInt.
            int nThread = Integer.parseInt(map.apply("nthread").toString());
            Predef$.MODULE$.require(nThread <= taskCpus, new WowXGBoost$$anonfun$ml$dmlc$xgboost4j$scala$spark$WowXGBoost$$overrideParamsAccordingToTaskCPUs$1(taskCpus, nThread));
            return map;
        }
        return map.$plus(new Tuple2<String, Object>("nthread", BoxesRunTime.boxToInteger(taskCpus)));
    }

    /**
     * Creates and starts a Rabit tracker for {@code i} workers.
     *
     * <p>{@code trackerImpl() == "scala"} selects the Scala {@code RabitTracker}, built with its
     * default constructor arguments. Any other value, including {@code "python"}, selects
     * {@code ml.dmlc.xgboost4j.java.RabitTracker}.
     *
     * @throws IllegalArgumentException via {@code require} when {@code start} returns false
     */
    public IRabitTracker ml$dmlc$xgboost4j$scala$spark$WowXGBoost$$startTracker(int i, TrackerConf trackerConf) {
        IRabitTracker tracker;
        if ("scala".equals(trackerConf.trackerImpl())) {
            tracker = new RabitTracker(i, RabitTracker$.MODULE$.$lessinit$greater$default$2(), RabitTracker$.MODULE$.$lessinit$greater$default$3());
        } else {
            tracker = new ml.dmlc.xgboost4j.java.RabitTracker(i);
        }
        Predef$.MODULE$.require(tracker.start(trackerConf.workerConnectionTimeout()), new WowXGBoost$$anonfun$ml$dmlc$xgboost4j$scala$spark$WowXGBoost$$startTracker$1());
        return tracker;
    }

    /**
     * Trains a booster across Spark partitions, resuming from checkpoints when available.
     *
     * @param rdd training points
     * @param map training parameters; {@code "tracker_conf"} must be a {@code TrackerConf} and
     *     {@code "timeout_request_workers"} a {@code Long} when present
     * @param i total round count (passed to the checkpoint manager)
     * @param i2 number of workers / training partitions; must be positive
     * @param objectiveTrait optional custom objective; when non-null, {@code "obj_type"} must be set
     * @param evalTrait optional custom evaluation
     * @param z passed through to the training closure (NOTE(review): meaning not visible here)
     * @param f passed through to the training closure (presumably the missing-value marker)
     * @return the (booster, map) produced for the last checkpoint round
     * @throws IllegalArgumentException on invalid parameters
     */
    public Tuple2<Booster, Map<String, float[]>> trainDistributed(RDD<LabeledPoint> rdd, Map<String, Object> map, int i, int i2, ObjectiveTrait objectiveTrait, EvalTrait evalTrait, boolean z, float f) throws XGBoostError {
        if (map.contains("tree_method")) {
            Predef$.MODULE$.require(!"hist".equals(map.apply("tree_method")), new WowXGBoost$$anonfun$trainDistributed$1());
        }
        Predef$.MODULE$.require(i2 > 0, new WowXGBoost$$anonfun$trainDistributed$2());
        if (objectiveTrait != null) {
            Predef$.MODULE$.require(map.get("obj_type").isDefined(), new WowXGBoost$$anonfun$trainDistributed$3());
        }
        TrackerConf trackerConf = resolveTrackerConf(map);
        long timeoutRequestWorkers = resolveTimeoutRequestWorkers(map);

        Tuple2 checkpointParams = CheckpointManager$.MODULE$.extractParams(map);
        if (checkpointParams == null) {
            throw new MatchError(checkpointParams);
        }
        // Presumably (checkpoint path, checkpoint interval) — confirm against CheckpointManager.
        String checkpointPath = (String) checkpointParams._1();
        int checkpointInterval = checkpointParams._2$mcI$sp();

        RDD<LabeledPoint> trainingRdd = repartitionForTraining(rdd, i2);
        SparkContext sparkContext = rdd.sparkContext();
        CheckpointManager checkpointManager = new CheckpointManager(sparkContext, checkpointPath);
        checkpointManager.cleanUpHigherVersions(i);
        // Mutable cell shared with the per-round closure: starts at the loaded checkpoint.
        ObjectRef previousBooster = ObjectRef.create(checkpointManager.loadCheckpointAsBooster());
        return (Tuple2) ((TraversableLike) checkpointManager.getCheckpointRounds(checkpointInterval, i).map(new WowXGBoost$$anonfun$trainDistributed$4(map, i, i2, objectiveTrait, evalTrait, z, f, trackerConf, timeoutRequestWorkers, trainingRdd, sparkContext, checkpointManager, previousBooster), Seq$.MODULE$.canBuildFrom())).last();
    }

    /**
     * Reads {@code "tracker_conf"}: the default {@code TrackerConf} when absent, the value when it
     * is a {@code TrackerConf}, otherwise an error.
     */
    private TrackerConf resolveTrackerConf(Map<String, Object> map) {
        Option<Object> conf = map.get("tracker_conf");
        if (conf.isEmpty()) {
            return TrackerConf$.MODULE$.apply();
        }
        Object value = conf.get();
        if (value instanceof TrackerConf) {
            return (TrackerConf) value;
        }
        throw new IllegalArgumentException("parameter \"tracker_conf\" must be an instance of TrackerConf.");
    }

    /**
     * Reads {@code "timeout_request_workers"}: 0 when absent, the value when it is a boxed Long,
     * otherwise an error.
     */
    private long resolveTimeoutRequestWorkers(Map<String, Object> map) {
        Option<Object> timeout = map.get("timeout_request_workers");
        if (timeout.isEmpty()) {
            return 0L;
        }
        Object value = timeout.get();
        if (value instanceof Long) {
            return BoxesRunTime.unboxToLong(value);
        }
        throw new IllegalArgumentException("parameter \"timeout_request_workers\" must be an instance of Long.");
    }

    /** Default for {@code trainDistributed}'s objective: none. */
    public ObjectiveTrait trainDistributed$default$5() {
        return null;
    }

    /** Default for {@code trainDistributed}'s evaluation: none. */
    public EvalTrait trainDistributed$default$6() {
        return null;
    }

    /** Default for {@code trainDistributed}'s boolean flag: false. */
    public boolean trainDistributed$default$7() {
        return false;
    }

    /** Default for {@code trainDistributed}'s float argument: NaN. */
    public float trainDistributed$default$8() {
        return Float.NaN;
    }

    /**
     * Returns {@code rdd} as-is when it already has {@code i} partitions; otherwise logs and
     * repartitions it to {@code i} partitions.
     */
    public RDD<LabeledPoint> repartitionForTraining(RDD<LabeledPoint> rdd, int i) {
        if (rdd.getNumPartitions() == i) {
            return rdd;
        }
        ml$dmlc$xgboost4j$scala$spark$WowXGBoost$$logger().info("repartitioning training set to " + i + " partitions");
        return rdd.repartition(i, rdd.repartition$default$2(i));
    }

    /**
     * Turns the tracker's exit status into a training result.
     *
     * @param i the tracker return value; 0 means success
     * @param rdd the cached per-partition results; unpersisted (non-blocking) on success
     * @param thread the thread running the Spark job; interrupted on failure if still alive
     * @return the first (booster, map) tuple of {@code rdd}
     * @throws XGBoostError when {@code i != 0}
     */
    public Tuple2<Booster, Map<String, float[]>> ml$dmlc$xgboost4j$scala$spark$WowXGBoost$$postTrackerReturnProcessing(int i, RDD<Tuple2<Booster, Map<String, float[]>>> rdd, Thread thread) throws XGBoostError {
        if (i != 0) {
            // Thread.interrupt() only sets the flag and never throws InterruptedException,
            // so no catch is needed here.
            if (thread.isAlive()) {
                thread.interrupt();
            }
            throw new XGBoostError("XGBoostModel training failed");
        }
        Tuple2<Booster, Map<String, float[]>> first = rdd.first();
        if (first == null) {
            throw new MatchError(first);
        }
        rdd.unpersist(false);
        return new Tuple2<>(first._1(), first._2());
    }

    /** Preserves the singleton across Java serialization. */
    private Object readResolve() {
        return MODULE$;
    }

    private WowXGBoost$() {
        this.ml$dmlc$xgboost4j$scala$spark$WowXGBoost$$logger = LogFactory.getLog("XGBoostSpark");
    }
}
