package nodes.learning;

import breeze.generic.UFunc$;
import breeze.linalg.Axis$_0$;
import breeze.linalg.DenseMatrix;
import breeze.linalg.DenseMatrix$;
import breeze.linalg.DenseVector;
import breeze.linalg.DenseVector$;
import breeze.linalg.ImmutableNumericOps;
import breeze.linalg.NumericOps;
import breeze.linalg.sum$;
import breeze.storage.Zero$DoubleZero$;
import nodes.learning.internal.ReWeightedLeastSquaresSolver$;
import nodes.util.VectorSplitter;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.rdd.OrderedRDDFunctions;
import org.apache.spark.rdd.RDD;
import org.apache.spark.rdd.RDD$;
import org.slf4j.Logger;
import pipelines.Logging;
import scala.Array$;
import scala.Function0;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Serializable;
import scala.Some;
import scala.Tuple2;
import scala.collection.IterableLike;
import scala.collection.Map;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.SeqLike;
import scala.collection.immutable.IndexedSeq;
import scala.collection.immutable.IndexedSeq$;
import scala.math.Ordering$Int$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichInt$;
import utils.MatrixUtils$;

/* compiled from: PerClassWeightedLeastSquares.scala */
/* loaded from: input_file:nodes/learning/PerClassWeightedLeastSquaresEstimator$.class */
/**
 * Decompiled singleton ("companion object") for the Scala object
 * {@code PerClassWeightedLeastSquaresEstimator}: holds the driver-side routines that
 * train a per-class weighted least-squares model on Spark RDDs of Breeze vectors.
 *
 * <p>NOTE(review): this is decompiler output of Scala bytecode, not hand-written Java.
 * Some constructs (assigning the {@code static final} {@code MODULE$} from the
 * constructor, {@code new DenseVector.mcD.sp(...)}) are not valid Java source and will
 * not compile as-is. The actual closure bodies live in the synthetic
 * {@code PerClassWeightedLeastSquaresEstimator$$anonfun$N} classes, which are not part
 * of this file, so their semantics below are inferred and marked as such.
 */
public final class PerClassWeightedLeastSquaresEstimator$ implements Logging, Serializable {
    /** The single instance; assigned by the private constructor during static initialization. */
    public static final PerClassWeightedLeastSquaresEstimator$ MODULE$ = null;
    /** Logger field backing the {@code pipelines.Logging} mixin accessors below; transient, so it is not serialized. */
    private transient Logger pipelines$Logging$$log_;

    // Class initialization constructs the singleton; the constructor publishes it via MODULE$.
    static {
        new PerClassWeightedLeastSquaresEstimator$();
    }

    // ---- pipelines.Logging mixin -------------------------------------------------
    // Compiler-generated forwarders: the field accessor pair plus logging methods that
    // all delegate to the trait's implementation class Logging.Cclass.

    /** Returns the mixin's logger field (may be null until initialized by the trait). */
    @Override // pipelines.Logging
    public Logger pipelines$Logging$$log_() {
        return this.pipelines$Logging$$log_;
    }

    /** Sets the mixin's logger field. */
    @Override // pipelines.Logging
    public void pipelines$Logging$$log__$eq(Logger logger) {
        this.pipelines$Logging$$log_ = logger;
    }

    @Override // pipelines.Logging
    public Logger log() {
        return Logging.Cclass.log(this);
    }

    @Override // pipelines.Logging
    public void logInfo(Function0<String> function0) {
        Logging.Cclass.logInfo(this, function0);
    }

    @Override // pipelines.Logging
    public void logDebug(Function0<String> function0) {
        Logging.Cclass.logDebug(this, function0);
    }

    @Override // pipelines.Logging
    public void logTrace(Function0<String> function0) {
        Logging.Cclass.logTrace(this, function0);
    }

    @Override // pipelines.Logging
    public void logWarning(Function0<String> function0) {
        Logging.Cclass.logWarning(this, function0);
    }

    @Override // pipelines.Logging
    public void logError(Function0<String> function0) {
        Logging.Cclass.logError(this, function0);
    }

    @Override // pipelines.Logging
    public void logInfo(Function0<String> function0, Throwable th) {
        Logging.Cclass.logInfo(this, function0, th);
    }

    @Override // pipelines.Logging
    public void logDebug(Function0<String> function0, Throwable th) {
        Logging.Cclass.logDebug(this, function0, th);
    }

    @Override // pipelines.Logging
    public void logTrace(Function0<String> function0, Throwable th) {
        Logging.Cclass.logTrace(this, function0, th);
    }

    @Override // pipelines.Logging
    public void logWarning(Function0<String> function0, Throwable th) {
        Logging.Cclass.logWarning(this, function0, th);
    }

    @Override // pipelines.Logging
    public void logError(Function0<String> function0, Throwable th) {
        Logging.Cclass.logError(this, function0, th);
    }

    /**
     * Trains a per-class weighted least-squares model and returns it as a
     * {@link BlockLinearMapper} carrying an intercept.
     *
     * <p>Pipeline (closure bodies are in synthetic classes, so the "presumably" parts
     * should be confirmed against the Scala source):
     * <ol>
     *   <li>fold all feature vectors into a zero vector with {@code $$anonfun$5}
     *       (presumably vector addition) and divide in place by the label count,
     *       giving the population feature mean;</li>
     *   <li>compute per-class joint feature means and per-class counts
     *       ({@link #computeJointFeatureMean});</li>
     *   <li>compute per-example weights ({@link #computeWeights});</li>
     *   <li>sort the joint means by key and stack them as rows of a matrix;</li>
     *   <li>for every label index {@code 0 until numClasses}, run the per-class solver
     *       closure {@code $$anonfun$8} over the feature blocks from {@link VectorSplitter};</li>
     *   <li>assemble one weight matrix per feature block and compute the intercept
     *       {@code jointLabelMean - colSum(jointMeans.t :* vertcat(blocks))}.</li>
     * </ol>
     *
     * @param rdd    feature vectors, one per example (zipped with {@code rdd2} downstream, so presumably aligned row-for-row)
     * @param rdd2   label vectors; the length of the first one is used as the number of classes
     * @param i      block size: passed to {@link VectorSplitter} and to the returned {@link BlockLinearMapper}
     * @param i2     forwarded unchanged to the per-class solver closure — presumably the iteration count; TODO confirm
     * @param d      forwarded unchanged to the per-class solver closure — presumably the L2 regularization strength; TODO confirm
     * @param d2     mixture weight, passed as {@code d} to computeJointFeatureMean, computeWeights and computeJointLabelMean
     * @param option explicit feature dimension; when empty it is derived from {@code rdd} by {@code $$anonfun$1}
     * @return block linear mapper with one weight matrix per feature block and the computed intercept
     * @throws scala.MatchError if computeJointFeatureMean returns null (compiler-generated destructuring guard)
     */
    public BlockLinearMapper trainWithL2(RDD<DenseVector<Object>> rdd, RDD<DenseVector<Object>> rdd2, int i, int i2, double d, double d2, Option<Object> option) {
        // Number of classes = length of the first label vector.
        // NOTE(review): RDD.first() throws on an empty RDD, so empty training data fails here.
        int length = ((DenseVector) rdd2.first()).length();
        // Feature dimension: the explicit value if present, otherwise computed from rdd by $$anonfun$1.
        int unboxToInt = BoxesRunTime.unboxToInt(option.getOrElse(new PerClassWeightedLeastSquaresEstimator$$anonfun$1(rdd)));
        // Example count, taken from the labels RDD; used as the divisor for every mean below.
        long count = rdd2.count();
        // Population feature mean (fold into zeros, then in-place /= count) is passed straight into computeJointFeatureMean.
        Tuple2<RDD<Tuple2<Object, DenseVector<Object>>>, int[]> computeJointFeatureMean = computeJointFeatureMean(rdd, rdd2, (DenseVector) ((NumericOps) rdd.fold(DenseVector$.MODULE$.zeros$mDc$sp(unboxToInt, ClassTag$.MODULE$.Double(), Zero$DoubleZero$.MODULE$), new PerClassWeightedLeastSquaresEstimator$$anonfun$5())).$div$eq(BoxesRunTime.boxToDouble(count), DenseVector$.MODULE$.dv_s_UpdateOp_Double_OpDiv()), d2, count, unboxToInt, length);
        // Compiler-generated guard from destructuring the returned tuple in Scala.
        if (computeJointFeatureMean == null) {
            throw new MatchError(computeJointFeatureMean);
        }
        // Unpack (keyed joint feature means, per-class counts).
        Tuple2 tuple2 = new Tuple2((RDD) computeJointFeatureMean._1(), (int[]) computeJointFeatureMean._2());
        RDD rdd3 = (RDD) tuple2._1();
        int[] iArr = (int[]) tuple2._2();
        RDD<DenseVector<Object>> computeWeights = computeWeights(rdd2, count, d2, iArr);
        MatrixUtils$ matrixUtils$ = MatrixUtils$.MODULE$;
        // Sort the joint means by their Int key (ascending) and stack the vectors as matrix rows.
        // NOTE(review): a class with no examples produces no key, so the row count can fall below
        // the number of classes and the shape-dependent intercept product below would then mismatch — confirm.
        OrderedRDDFunctions rddToOrderedRDDFunctions = RDD$.MODULE$.rddToOrderedRDDFunctions(rdd3, Ordering$Int$.MODULE$, ClassTag$.MODULE$.Int(), ClassTag$.MODULE$.apply(DenseVector.class));
        DenseMatrix rowsToMatrix = matrixUtils$.rowsToMatrix((DenseVector[]) rddToOrderedRDDFunctions.sortByKey(rddToOrderedRDDFunctions.sortByKey$default$1(), rddToOrderedRDDFunctions.sortByKey$default$2()).map(new PerClassWeightedLeastSquaresEstimator$$anonfun$6(), ClassTag$.MODULE$.apply(DenseVector.class)).collect(), ClassTag$.MODULE$.Double());
        DenseVector<Object> computeJointLabelMean = computeJointLabelMean(iArr, d2, count);
        // One solver invocation per class index 0 until length. Labels are first mapped by $$anonfun$7
        // using the joint label mean (presumably centering), and features are split into blocks of size i.
        // Result: IndexedSeq with one entry per class, each a Seq of per-block results.
        IndexedSeq indexedSeq = (IndexedSeq) RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), length).map(new PerClassWeightedLeastSquaresEstimator$$anonfun$8(i, i2, d, unboxToInt, computeWeights, rowsToMatrix, rdd2.map(new PerClassWeightedLeastSquaresEstimator$$anonfun$7(computeJointLabelMean), ClassTag$.MODULE$.apply(DenseVector.class)), new VectorSplitter(i, new Some(BoxesRunTime.boxToInteger(unboxToInt))).apply(rdd)), IndexedSeq$.MODULE$.canBuildFrom());
        // One slot per feature block, sized from the first class's result.
        // NOTE(review): indexedSeq.apply(0) fails when there are zero classes.
        DenseMatrix[] denseMatrixArr = new DenseMatrix[((SeqLike) indexedSeq.apply(0)).size()];
        // Scatter the (class, block) results into denseMatrixArr via $$anonfun$trainWithL2$1
        // (presumably one column per class in each block's matrix).
        ((IterableLike) indexedSeq.zipWithIndex(IndexedSeq$.MODULE$.canBuildFrom())).foreach(new PerClassWeightedLeastSquaresEstimator$$anonfun$trainWithL2$1(length, denseMatrixArr));
        // Intercept = jointLabelMean - column sums of (jointMeans^T :* vertcat(blocks)); when the shapes line up,
        // each column sum is class k's joint feature mean dotted with class k's weight column.
        return new BlockLinearMapper(Predef$.MODULE$.wrapRefArray(denseMatrixArr), i, new Some((DenseVector) computeJointLabelMean.$minus(((DenseMatrix) sum$.MODULE$.apply(((ImmutableNumericOps) rowsToMatrix.t(DenseMatrix$.MODULE$.canTranspose())).$colon$times(DenseMatrix$.MODULE$.vertcat(Predef$.MODULE$.wrapRefArray(denseMatrixArr), DenseMatrix$.MODULE$.dm_dm_UpdateOp_Double_OpSet(), ClassTag$.MODULE$.Double(), Zero$DoubleZero$.MODULE$), DenseMatrix$.MODULE$.op_DM_DM_Double_OpMulScalar()), Axis$_0$.MODULE$, UFunc$.MODULE$.collapseUred(DenseMatrix$.MODULE$.handholdCanMapRows(), sum$.MODULE$.reduce_Double(DenseVector$.MODULE$.canIterateValues()), DenseMatrix$.MODULE$.canCollapseRows(ClassTag$.MODULE$.Double(), Zero$DoubleZero$.MODULE$)))).toDenseVector$mcD$sp(), DenseVector$.MODULE$.canSubD())), BlockLinearMapper$.MODULE$.$lessinit$greater$default$4());
    }

    /**
     * Computes per-class feature means (mixed with the population mean) and per-class example counts.
     *
     * @param rdd         feature vectors
     * @param rdd2        label vectors, zipped element-wise with {@code rdd}
     * @param denseVector population feature mean; broadcast so {@code $$anonfun$15} can use it on executors
     * @param d           mixture weight, captured by {@code $$anonfun$15}
     * @param j           total example count. NOTE(review): not referenced in this body
     * @param i           feature dimension (size of the zero vector seeding foldByKey)
     * @param i2          number of classes (length of the returned counts array)
     * @return tuple of (RDD of (Int key, joint feature mean vector), per-class counts array)
     */
    public Tuple2<RDD<Tuple2<Object, DenseVector<Object>>>, int[]> computeJointFeatureMean(RDD<DenseVector<Object>> rdd, RDD<DenseVector<Object>> rdd2, DenseVector<Object> denseVector, double d, long j, int i, int i2) {
        Broadcast broadcast = rdd.context().broadcast(denseVector, ClassTag$.MODULE$.apply(DenseVector.class));
        // Key each label row with $$anonfun$11 (presumably (class index, 1)), reduce by key with $$anonfun$2
        // (presumably +), and collect the per-key totals to the driver.
        Map collectAsMap = RDD$.MODULE$.rddToPairRDDFunctions(RDD$.MODULE$.rddToPairRDDFunctions(rdd2.map(new PerClassWeightedLeastSquaresEstimator$$anonfun$11(), ClassTag$.MODULE$.apply(Tuple2.class)), ClassTag$.MODULE$.Int(), ClassTag$.MODULE$.Int(), Ordering$Int$.MODULE$).reduceByKey(new PerClassWeightedLeastSquaresEstimator$$anonfun$2()), ClassTag$.MODULE$.Int(), ClassTag$.MODULE$.Int(), Ordering$Int$.MODULE$).collectAsMap();
        // Dense per-class totals; entries the copying closure never writes stay at Java's default 0.
        int[] iArr = new int[i2];
        collectAsMap.foreach(new PerClassWeightedLeastSquaresEstimator$$anonfun$computeJointFeatureMean$1(iArr));
        // zip features with labels -> key each pair ($$anonfun$12) -> per-key fold from zeros ($$anonfun$13, presumably
        // addition) -> normalize using the collected totals ($$anonfun$14) -> mix with the broadcast population mean
        // using d ($$anonfun$15).
        // NOTE(review): RDD.zip requires both RDDs to have the same number of partitions and elements per partition.
        return new Tuple2<>(RDD$.MODULE$.rddToPairRDDFunctions(rdd.zip(rdd2, ClassTag$.MODULE$.apply(DenseVector.class)).map(new PerClassWeightedLeastSquaresEstimator$$anonfun$12(), ClassTag$.MODULE$.apply(Tuple2.class)), ClassTag$.MODULE$.Int(), ClassTag$.MODULE$.apply(DenseVector.class), Ordering$Int$.MODULE$).foldByKey(DenseVector$.MODULE$.zeros$mDc$sp(i, ClassTag$.MODULE$.Double(), Zero$DoubleZero$.MODULE$), new PerClassWeightedLeastSquaresEstimator$$anonfun$13()).map(new PerClassWeightedLeastSquaresEstimator$$anonfun$14(collectAsMap), ClassTag$.MODULE$.apply(Tuple2.class)).map(new PerClassWeightedLeastSquaresEstimator$$anonfun$15(d, broadcast), ClassTag$.MODULE$.apply(Tuple2.class)), iArr);
    }

    /**
     * Maps every label vector to a per-example weight vector using {@code $$anonfun$16}.
     *
     * @param rdd  label vectors
     * @param j    total example count, captured by the closure
     * @param d    mixture weight, captured by the closure
     * @param iArr per-class example counts, captured by the closure
     * @return RDD with one weight vector per input label vector
     */
    public RDD<DenseVector<Object>> computeWeights(RDD<DenseVector<Object>> rdd, long j, double d, int[] iArr) {
        return rdd.map(new PerClassWeightedLeastSquaresEstimator$$anonfun$16(j, d, iArr), ClassTag$.MODULE$.apply(DenseVector.class));
    }

    /**
     * Computes the per-class joint label mean.
     *
     * <p>Let {@code v[k]} be the value {@code $$anonfun$4} produces for {@code iArr[k]}
     * (presumably {@code iArr[k] / j}, the class frequency — TODO confirm). The result is
     * {@code 2(1-d)*v[k] - 1 + 2d}, which is algebraically {@code (1-d)*(2v[k] - 1) + d}.
     * If labels are encoded as ±1 (assumption), this blends the class's ±1 label mean with +1 using weight {@code d}.
     *
     * @param iArr per-class example counts
     * @param d    mixture weight
     * @param j    total example count, captured by {@code $$anonfun$4}
     * @return dense vector with one entry per element of {@code iArr}
     */
    public DenseVector<Object> computeJointLabelMean(int[] iArr, double d, long j) {
        return (DenseVector) ((DenseVector) ((ImmutableNumericOps) new DenseVector.mcD.sp((double[]) Predef$.MODULE$.intArrayOps(iArr).map(new PerClassWeightedLeastSquaresEstimator$$anonfun$4(j), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Double()))).$colon$times(BoxesRunTime.boxToDouble(2.0d * (1.0d - d)), DenseVector$.MODULE$.canScaleD())).$colon$minus(BoxesRunTime.boxToDouble(1.0d), DenseVector$.MODULE$.dv_s_Op_Double_OpSub())).$colon$plus(BoxesRunTime.boxToDouble(2.0d * d), DenseVector$.MODULE$.dv_s_Op_Double_OpAdd());
    }

    /**
     * Solves one weighted least-squares problem by delegating to
     * {@code ReWeightedLeastSquaresSolver.trainWithL2}, then converts the resulting model
     * matrices to vectors.
     *
     * @param i           forwarded unchanged to the solver (TODO confirm meaning against ReWeightedLeastSquaresSolver)
     * @param i2          forwarded unchanged to the solver
     * @param i3          forwarded unchanged to the solver
     * @param d           forwarded unchanged to the solver
     * @param i4          forwarded unchanged to the solver
     * @param seq         feature blocks, forwarded to the solver
     * @param rdd         per-example values, each mapped to a DenseVector by {@code $$anonfun$17} before being passed on
     * @param rdd2        forwarded unchanged to the solver
     * @param denseVector forwarded unchanged to the solver
     * @return one vector per matrix in the solver's first result, converted by {@code $$anonfun$trainSingleClassWeightedLS$1}
     * @throws scala.MatchError if the solver returns null (compiler-generated destructuring guard)
     */
    public Seq<DenseVector<Object>> trainSingleClassWeightedLS(int i, int i2, int i3, double d, int i4, Seq<RDD<DenseVector<Object>>> seq, RDD<Object> rdd, RDD<Object> rdd2, DenseVector<Object> denseVector) {
        // The constant 1 is the solver's sixth argument — presumably a single output column; TODO confirm.
        Tuple2<Seq<DenseMatrix<Object>>, RDD<DenseMatrix<Object>>> trainWithL2 = ReWeightedLeastSquaresSolver$.MODULE$.trainWithL2(i, i2, i3, d, i4, 1, seq, rdd.map(new PerClassWeightedLeastSquaresEstimator$$anonfun$17(), ClassTag$.MODULE$.apply(DenseVector.class)), rdd2, denseVector);
        if (trainWithL2 == null) {
            throw new MatchError(trainWithL2);
        }
        // Only the first tuple element (the model matrices) is used; the RDD second element is discarded.
        Tuple2 tuple2 = new Tuple2((Seq) trainWithL2._1(), (RDD) trainWithL2._2());
        Seq seq2 = (Seq) tuple2._1();
        return (Seq) seq2.map(new PerClassWeightedLeastSquaresEstimator$$anonfun$trainSingleClassWeightedLS$1(), Seq$.MODULE$.canBuildFrom());
    }

    /** Scala default-argument accessor: the fifth constructor parameter of the class defaults to {@code None}. */
    public Option<Object> $lessinit$greater$default$5() {
        return None$.MODULE$;
    }

    /** Keeps the singleton unique under Java deserialization by returning {@code MODULE$}. */
    private Object readResolve() {
        return MODULE$;
    }

    /** Publishes this instance as {@code MODULE$} and initializes the Logging mixin's logger field to null. */
    private PerClassWeightedLeastSquaresEstimator$() {
        MODULE$ = this;
        pipelines$Logging$$log__$eq(null);
    }
}
