package nodes.learning;

import breeze.generic.UFunc$;
import breeze.linalg.Axis$_0$;
import breeze.linalg.DenseMatrix;
import breeze.linalg.DenseMatrix$;
import breeze.linalg.DenseVector;
import breeze.linalg.DenseVector$;
import breeze.linalg.ImmutableNumericOps;
import breeze.linalg.NumericOps;
import breeze.linalg.sum$;
import breeze.storage.Zero$DoubleZero$;
import edu.berkeley.cs.amplab.mlmatrix.util.Utils$;
import org.apache.spark.HashPartitioner;
import org.apache.spark.SparkContext;
import org.apache.spark.rdd.RDD;
import org.apache.spark.rdd.RDD$;
import org.slf4j.Logger;
import pipelines.Logging;
import scala.Array$;
import scala.Function0;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Serializable;
import scala.Some;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableOnce;
import scala.collection.immutable.IndexedSeq$;
import scala.math.Ordering$Int$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;
import scala.runtime.ObjectRef;
import scala.runtime.RichInt$;

/* compiled from: BlockWeightedLeastSquares.scala */
/* loaded from: input_file:nodes/learning/BlockWeightedLeastSquaresEstimator$.class */
/**
 * Decompiled Java view of the Scala singleton object {@code BlockWeightedLeastSquaresEstimator}
 * (see the "compiled from" header above).
 *
 * <p>The singleton instance is exposed through {@code MODULE$}. Logging is provided by the
 * {@code pipelines.Logging} trait, whose implementations are delegated to {@code Logging.Cclass}.
 * Most of the numerical work happens inside {@code BlockWeightedLeastSquaresEstimator$$anonfun$N}
 * closure classes, which are not visible in this file. Comments below describe only the
 * orchestration visible here.
 *
 * <p>NOTE(review): this is compiler/decompiler output. {@code MODULE$} is declared
 * {@code static final ... = null} and is then reassigned in the private constructor, which is not
 * legal Java, so this file is not expected to compile as Java source. Make any change in
 * {@code BlockWeightedLeastSquares.scala} instead.
 */
public final class BlockWeightedLeastSquaresEstimator$ implements Logging, Serializable {
    // Singleton instance; assigned in the private constructor run by the static initializer below.
    public static final BlockWeightedLeastSquaresEstimator$ MODULE$ = null;
    // Backing field for the Logging trait's logger; transient, so it is not serialized.
    private transient Logger pipelines$Logging$$log_;

    // Class initialization constructs the singleton (the constructor stores itself into MODULE$).
    static {
        new BlockWeightedLeastSquaresEstimator$();
    }

    // ---- pipelines.Logging trait plumbing: accessor/mutator for the logger field plus
    // ---- log-level methods that all delegate to Logging.Cclass.

    @Override // pipelines.Logging
    public Logger pipelines$Logging$$log_() {
        return this.pipelines$Logging$$log_;
    }

    @Override // pipelines.Logging
    public void pipelines$Logging$$log__$eq(Logger logger) {
        this.pipelines$Logging$$log_ = logger;
    }

    @Override // pipelines.Logging
    public Logger log() {
        return Logging.Cclass.log(this);
    }

    @Override // pipelines.Logging
    public void logInfo(Function0<String> function0) {
        Logging.Cclass.logInfo(this, function0);
    }

    @Override // pipelines.Logging
    public void logDebug(Function0<String> function0) {
        Logging.Cclass.logDebug(this, function0);
    }

    @Override // pipelines.Logging
    public void logTrace(Function0<String> function0) {
        Logging.Cclass.logTrace(this, function0);
    }

    @Override // pipelines.Logging
    public void logWarning(Function0<String> function0) {
        Logging.Cclass.logWarning(this, function0);
    }

    @Override // pipelines.Logging
    public void logError(Function0<String> function0) {
        Logging.Cclass.logError(this, function0);
    }

    @Override // pipelines.Logging
    public void logInfo(Function0<String> function0, Throwable th) {
        Logging.Cclass.logInfo(this, function0, th);
    }

    @Override // pipelines.Logging
    public void logDebug(Function0<String> function0, Throwable th) {
        Logging.Cclass.logDebug(this, function0, th);
    }

    @Override // pipelines.Logging
    public void logTrace(Function0<String> function0, Throwable th) {
        Logging.Cclass.logTrace(this, function0, th);
    }

    @Override // pipelines.Logging
    public void logWarning(Function0<String> function0, Throwable th) {
        Logging.Cclass.logWarning(this, function0, th);
    }

    @Override // pipelines.Logging
    public void logError(Function0<String> function0, Throwable th) {
        Logging.Cclass.logError(this, function0, th);
    }

    /**
     * Trains a block linear model over feature blocks {@code seq} against labels {@code rdd} and
     * returns a {@link BlockLinearMapper} built from the per-block model matrices plus an
     * intercept vector.
     *
     * <p>Visible orchestration:
     * <ol>
     *   <li>Checks whether the label partitions are already laid out as expected. If not, logs a
     *       warning and reshuffles both features and labels via {@link #groupByClasses}.</li>
     *   <li>Builds cached per-partition label data and accumulates a label-mean vector.</li>
     *   <li>Allocates one zero-initialized model matrix per feature block.</li>
     *   <li>Runs {@code i2} iterations of the update closure.</li>
     *   <li>Assembles the intercept from the final state.</li>
     * </ol>
     *
     * @param seq feature blocks, one RDD per block; the SparkContext is taken from the first one
     * @param rdd label vectors, one per example; presumably one-hot over classes (TODO confirm)
     * @param i   passed to the model-matrix initializer and to {@code BlockLinearMapper};
     *            presumably the block size (verify against the Scala source)
     * @param i2  number of iterations of the training loop
     * @param d   passed only to the iteration closure; presumably the L2 regularization weight
     *            (TODO confirm)
     * @param d2  passed to the per-partition statistics closure and the iteration closure;
     *            presumably a mixture weight (TODO confirm)
     * @return the fitted mapper
     */
    public BlockLinearMapper trainWithL2(Seq<RDD<DenseVector<Object>>> seq, RDD<DenseVector<Object>> rdd, int i, int i2, double d, double d2) {
        // z == true means the label layout must be regrouped by class before training.
        boolean z;
        Tuple2<Seq<RDD<DenseVector<Object>>>, RDD<DenseVector<Object>>> tuple2;
        SparkContext context = ((RDD) seq.head()).context();
        // One Boolean per label partition (anonfun$1), which must all satisfy anonfun$2.
        // NOTE(review): presumably "every partition holds a single class"; the closure bodies are
        // not visible here.
        if (Predef$.MODULE$.booleanArrayOps((boolean[]) rdd.mapPartitions(new BlockWeightedLeastSquaresEstimator$$anonfun$1(), rdd.mapPartitions$default$2(), ClassTag$.MODULE$.Boolean()).collect()).forall(new BlockWeightedLeastSquaresEstimator$$anonfun$2())) {
            // One Int per partition. Regroup if any value repeats across partitions.
            int[] iArr = (int[]) rdd.mapPartitions(new BlockWeightedLeastSquaresEstimator$$anonfun$3(), rdd.mapPartitions$default$2(), ClassTag$.MODULE$.Int()).collect();
            z = Predef$.MODULE$.intArrayOps((int[]) Predef$.MODULE$.intArrayOps(iArr).distinct()).size() != Predef$.MODULE$.intArrayOps(iArr).size();
        } else {
            z = true;
        }
        if (z) {
            logWarning(new BlockWeightedLeastSquaresEstimator$$anonfun$4());
            tuple2 = groupByClasses(seq, rdd);
        } else {
            tuple2 = new Tuple2<>(seq, rdd);
        }
        // Scala tuple-destructuring pattern match, compiled: a null tuple raises MatchError.
        Tuple2<Seq<RDD<DenseVector<Object>>>, RDD<DenseVector<Object>>> tuple22 = tuple2;
        if (tuple22 == null) {
            throw new MatchError(tuple22);
        }
        Tuple2 tuple23 = new Tuple2((Seq) tuple22._1(), (RDD) tuple22._2());
        // seq2/rdd2 are the (possibly regrouped) features and labels used from here on.
        Seq seq2 = (Seq) tuple23._1();
        RDD rdd2 = (RDD) tuple23._2();
        // Per-partition Int RDD, cached and named "classIdxs".
        RDD name = rdd2.mapPartitions(new BlockWeightedLeastSquaresEstimator$$anonfun$5(), rdd2.mapPartitions$default$2(), ClassTag$.MODULE$.Int()).cache().setName("classIdxs");
        // Total label count, and the length of the first label vector (presumably the class count).
        long count = rdd2.count();
        int length = ((DenseVector) rdd2.first()).length();
        // One DenseMatrix per label partition.
        RDD mapPartitions = rdd2.mapPartitions(new BlockWeightedLeastSquaresEstimator$$anonfun$6(), rdd2.mapPartitions$default$2(), ClassTag$.MODULE$.apply(DenseMatrix.class));
        // Pair each partition's matrix with its class index, compute per-partition statistics, and
        // fold them into zeros$mDc$sp on the driver (presumably the joint label mean, TODO confirm).
        Tuple2[] tuple2Arr = (Tuple2[]) mapPartitions.zip(name, ClassTag$.MODULE$.Int()).map(new BlockWeightedLeastSquaresEstimator$$anonfun$7(d2, count), ClassTag$.MODULE$.apply(Tuple2.class)).collect();
        DenseVector zeros$mDc$sp = DenseVector$.MODULE$.zeros$mDc$sp(length, ClassTag$.MODULE$.Double(), Zero$DoubleZero$.MODULE$);
        Predef$.MODULE$.refArrayOps(tuple2Arr).foreach(new BlockWeightedLeastSquaresEstimator$$anonfun$trainWithL2$2(zeros$mDc$sp));
        // One model matrix per feature block, built by anonfun$8 from (i, length).
        // NOTE(review): this maps over the original `seq`, not the regrouped `seq2`. It is
        // presumably fine because groupByClasses preserves the number of blocks.
        DenseMatrix[] denseMatrixArr = (DenseMatrix[]) ((TraversableOnce) seq.map(new BlockWeightedLeastSquaresEstimator$$anonfun$8(i, length), Seq$.MODULE$.canBuildFrom())).toArray(ClassTag$.MODULE$.apply(DenseMatrix.class));
        int length2 = denseMatrixArr.length;
        // Residual RDD (cached, named "residual") and its reduced mean vector, each boxed in an
        // ObjectRef so the loop closure below can replace them between iterations.
        ObjectRef objectRef = new ObjectRef(mapPartitions.map(new BlockWeightedLeastSquaresEstimator$$anonfun$9(zeros$mDc$sp), ClassTag$.MODULE$.apply(DenseMatrix.class)).cache().setName("residual"));
        // Tree-reduce one vector per residual matrix, then divide in place by `length`.
        ObjectRef objectRef2 = new ObjectRef((DenseVector) ((NumericOps) Utils$.MODULE$.treeReduce(((RDD) objectRef.elem).map(new BlockWeightedLeastSquaresEstimator$$anonfun$10(), ClassTag$.MODULE$.apply(DenseVector.class)), new BlockWeightedLeastSquaresEstimator$$anonfun$11(), Utils$.MODULE$.treeReduce$default$3(), ClassTag$.MODULE$.apply(DenseVector.class))).$div$eq(BoxesRunTime.boxToDouble(length), DenseVector$.MODULE$.dv_s_UpdateOp_Double_OpDiv()));
        // One Option slot per block, initialized by anonfun$12 and filled in by the training loop.
        Option[] optionArr = (Option[]) ((TraversableOnce) RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), length2).map(new BlockWeightedLeastSquaresEstimator$$anonfun$12(), IndexedSeq$.MODULE$.canBuildFrom())).toArray(ClassTag$.MODULE$.apply(Option.class));
        // Main loop: i2 iterations; the closure updates models, residual state and optionArr.
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), i2).foreach$mVc$sp(new BlockWeightedLeastSquaresEstimator$$anonfun$trainWithL2$1(i, d, d2, context, seq2, name, count, length, denseMatrixArr, length2, objectRef, objectRef2, optionArr));
        // Intercept = zeros$mDc$sp - sum over axis 0 of
        // (horzcat of optionArr mapped by anonfun$22)^T :* vertcat(models).
        return new BlockLinearMapper(Predef$.MODULE$.wrapRefArray(denseMatrixArr), i, new Some((DenseVector) zeros$mDc$sp.$minus(((DenseMatrix) sum$.MODULE$.apply(((ImmutableNumericOps) DenseMatrix$.MODULE$.horzcat(Predef$.MODULE$.wrapRefArray((Object[]) Predef$.MODULE$.refArrayOps(optionArr).map(new BlockWeightedLeastSquaresEstimator$$anonfun$22(), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(DenseMatrix.class)))), Predef$.MODULE$.conforms(), DenseMatrix$.MODULE$.dm_dm_UpdateOp_Double_OpSet(), ClassTag$.MODULE$.Double(), Zero$DoubleZero$.MODULE$).t(DenseMatrix$.MODULE$.canTranspose())).$colon$times(DenseMatrix$.MODULE$.vertcat(Predef$.MODULE$.wrapRefArray(denseMatrixArr), DenseMatrix$.MODULE$.dm_dm_UpdateOp_Double_OpSet(), ClassTag$.MODULE$.Double(), Zero$DoubleZero$.MODULE$), DenseMatrix$.MODULE$.op_DM_DM_Double_OpMulScalar()), Axis$_0$.MODULE$, UFunc$.MODULE$.collapseUred(DenseMatrix$.MODULE$.handholdCanMapRows(), sum$.MODULE$.reduce_Double(DenseVector$.MODULE$.canIterateValues()), DenseMatrix$.MODULE$.canCollapseRows(ClassTag$.MODULE$.Double(), Zero$DoubleZero$.MODULE$)))).toDenseVector$mcD$sp(), DenseVector$.MODULE$.canSubD())), BlockLinearMapper$.MODULE$.$lessinit$greater$default$4());
    }

    /**
     * Adds the second pair of matrices into the first, element-wise and in place.
     *
     * @param tuple2  accumulator pair; both matrices are mutated
     * @param tuple22 pair to add; left unchanged
     * @return {@code tuple2} itself (not a copy), after the in-place additions
     */
    public Tuple2<DenseMatrix<Object>, DenseMatrix<Object>> addPairMatrices(Tuple2<DenseMatrix<Object>, DenseMatrix<Object>> tuple2, Tuple2<DenseMatrix<Object>, DenseMatrix<Object>> tuple22) {
        ((NumericOps) tuple2._1()).$plus$eq(tuple22._1(), DenseMatrix$.MODULE$.dm_dm_UpdateOp_Double_OpAdd());
        ((NumericOps) tuple2._2()).$plus$eq(tuple22._2(), DenseMatrix$.MODULE$.dm_dm_UpdateOp_Double_OpAdd());
        return tuple2;
    }

    /**
     * Reshuffles labels (and, correspondingly, each feature block) through a
     * {@link HashPartitioner}. Its partition count equals the length of the first label vector,
     * presumably one partition per class (TODO confirm).
     *
     * <p>Labels are keyed per partition by anonfun$23, shuffled with {@code partitionBy}, stripped
     * of keys, and post-processed by anonfun$24. Each feature block is transformed by anonfun$25
     * with the same partitioner, so it can presumably be co-partitioned with the labels.
     *
     * @param seq feature blocks to regroup
     * @param rdd labels to regroup; {@code first()} is called on it, which triggers a Spark job
     * @return the regrouped feature blocks and labels
     */
    public Tuple2<Seq<RDD<DenseVector<Object>>>, RDD<DenseVector<Object>>> groupByClasses(Seq<RDD<DenseVector<Object>>> seq, RDD<DenseVector<Object>> rdd) {
        HashPartitioner hashPartitioner = new HashPartitioner(((DenseVector) rdd.first()).length());
        // Despite its name, `length` is the label RDD's partition count.
        long length = rdd.partitions().length;
        RDD values = RDD$.MODULE$.rddToPairRDDFunctions(RDD$.MODULE$.rddToPairRDDFunctions(rdd.mapPartitionsWithIndex(new BlockWeightedLeastSquaresEstimator$$anonfun$23(length), rdd.mapPartitionsWithIndex$default$2(), ClassTag$.MODULE$.apply(Tuple2.class)), ClassTag$.MODULE$.Int(), ClassTag$.MODULE$.apply(Tuple2.class), Ordering$Int$.MODULE$).partitionBy(hashPartitioner), ClassTag$.MODULE$.Int(), ClassTag$.MODULE$.apply(Tuple2.class), Ordering$Int$.MODULE$).values();
        return new Tuple2<>((Seq) seq.map(new BlockWeightedLeastSquaresEstimator$$anonfun$25(rdd, hashPartitioner, length), Seq$.MODULE$.canBuildFrom()), values.mapPartitions(new BlockWeightedLeastSquaresEstimator$$anonfun$24(), values.mapPartitions$default$2(), ClassTag$.MODULE$.apply(DenseVector.class)));
    }

    // Scala-generated default for the 5th parameter of the BlockWeightedLeastSquaresEstimator
    // constructor ("<init>"): always None.
    public Option<Object> $lessinit$greater$default$5() {
        return None$.MODULE$;
    }

    // Java serialization hook: deserialization yields the existing singleton, not a new copy.
    private Object readResolve() {
        return MODULE$;
    }

    // Private constructor, invoked once from the static initializer. It publishes the singleton
    // and clears the logger field (presumably the Logging trait creates the logger lazily).
    private BlockWeightedLeastSquaresEstimator$() {
        MODULE$ = this;
        pipelines$Logging$$log__$eq(null);
    }
}
