package pipelines.images.cifar;

import breeze.linalg.$times$;
import breeze.linalg.BroadcastedColumns$;
import breeze.linalg.BroadcastedRows$;
import breeze.linalg.Broadcaster$;
import breeze.linalg.DenseMatrix;
import breeze.linalg.DenseMatrix$;
import breeze.linalg.DenseVector;
import breeze.linalg.DenseVector$;
import breeze.linalg.ImmutableNumericOps;
import breeze.linalg.sum$;
import breeze.numerics.package$pow$;
import breeze.numerics.package$pow$powDoubleDoubleImpl$;
import breeze.numerics.package$sqrt$;
import breeze.numerics.package$sqrt$sqrtDoubleImpl$;
import breeze.storage.Zero$DoubleZero$;
import evaluation.AugmentedExamplesEvaluator$;
import loaders.CifarLoader$;
import nodes.images.CenterCornerPatcher;
import nodes.images.Convolver;
import nodes.images.Convolver$;
import nodes.images.ImageExtractor$;
import nodes.images.ImageVectorizer$;
import nodes.images.LabelExtractor$;
import nodes.images.Pooler;
import nodes.images.RandomImageTransformer;
import nodes.images.RandomImageTransformer$;
import nodes.images.RandomPatcher;
import nodes.images.RandomPatcher$;
import nodes.images.SymmetricRectifier;
import nodes.images.SymmetricRectifier$;
import nodes.images.Windower;
import nodes.learning.BlockLeastSquaresEstimator;
import nodes.learning.BlockLeastSquaresEstimator$;
import nodes.learning.ZCAWhitener;
import nodes.learning.ZCAWhitenerEstimator;
import nodes.stats.Sampler;
import nodes.stats.Sampler$;
import nodes.stats.StandardScaler;
import nodes.stats.StandardScaler$;
import nodes.util.Cacher;
import nodes.util.Cacher$;
import nodes.util.ClassLabelIndicatorsFromIntLabels;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.rdd.RDD;
import org.slf4j.Logger;
import pipelines.Logging;
import pipelines.images.cifar.RandomPatchCifarAugmented;
import scala.Function0;
import scala.MatchError;
import scala.None$;
import scala.Predef$;
import scala.Serializable;
import scala.Some;
import scala.Tuple2;
import scala.package$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;
import scala.runtime.TraitSetter;
import scopt.OptionParser;
import scopt.Read$;
import utils.Image;
import utils.LabeledImage;
import utils.MatrixUtils$;
import utils.Stats$;
import workflow.Pipeline;

/* compiled from: RandomPatchCifarAugmented.scala */
/* loaded from: input_file:pipelines/images/cifar/RandomPatchCifarAugmented$.class */
/**
 * Decompiled form of the Scala {@code object RandomPatchCifarAugmented} (see the "compiled from"
 * header above). It is a command-line Spark application that:
 * <ul>
 *   <li>fits a random-patch convolutional featurizer and a block least-squares classifier on CIFAR
 *       training data, using random-crop augmentation;</li>
 *   <li>logs an evaluation on the CIFAR test data.</li>
 * </ul>
 *
 * <p>NOTE(review): this is decompiler output, not hand-written Java, and the Scala source is
 * authoritative. Some constructs are not legal Java source:
 * <ul>
 *   <li>{@code MODULE$} is declared {@code static final ... = null} and then reassigned in the
 *       constructor;</li>
 *   <li>presumably, the argument-less anonymous {@code OptionParser} in {@link #parse} as well.</li>
 * </ul>
 */
public final class RandomPatchCifarAugmented$ implements Serializable, Logging {
    /** Singleton instance; assigned by the private constructor, which the static initializer invokes. */
    public static final RandomPatchCifarAugmented$ MODULE$ = null;
    /** Application name used in the scopt usage header and as the SparkConf app name. */
    private final String appName;
    /** Backing field for the {@link Logging} trait's logger; transient, set to null by the constructor. */
    private transient Logger pipelines$Logging$$log_;

    static {
        // Creating the instance runs the constructor, which publishes it via MODULE$.
        new RandomPatchCifarAugmented$();
    }

    // ---- pipelines.Logging trait forwarders ---------------------------------------------------
    // The accessor pair below exposes the transient logger field; every log method delegates to
    // the trait's static implementation class (Logging.Cclass), passing this object.

    @Override // pipelines.Logging
    public Logger pipelines$Logging$$log_() {
        return this.pipelines$Logging$$log_;
    }

    @Override // pipelines.Logging
    @TraitSetter
    public void pipelines$Logging$$log__$eq(Logger logger) {
        this.pipelines$Logging$$log_ = logger;
    }

    @Override // pipelines.Logging
    public Logger log() {
        return Logging.Cclass.log(this);
    }

    @Override // pipelines.Logging
    public void logInfo(Function0<String> function0) {
        Logging.Cclass.logInfo(this, function0);
    }

    @Override // pipelines.Logging
    public void logDebug(Function0<String> function0) {
        Logging.Cclass.logDebug(this, function0);
    }

    @Override // pipelines.Logging
    public void logTrace(Function0<String> function0) {
        Logging.Cclass.logTrace(this, function0);
    }

    @Override // pipelines.Logging
    public void logWarning(Function0<String> function0) {
        Logging.Cclass.logWarning(this, function0);
    }

    @Override // pipelines.Logging
    public void logError(Function0<String> function0) {
        Logging.Cclass.logError(this, function0);
    }

    @Override // pipelines.Logging
    public void logInfo(Function0<String> function0, Throwable th) {
        Logging.Cclass.logInfo(this, function0, th);
    }

    @Override // pipelines.Logging
    public void logDebug(Function0<String> function0, Throwable th) {
        Logging.Cclass.logDebug(this, function0, th);
    }

    @Override // pipelines.Logging
    public void logTrace(Function0<String> function0, Throwable th) {
        Logging.Cclass.logTrace(this, function0, th);
    }

    @Override // pipelines.Logging
    public void logWarning(Function0<String> function0, Throwable th) {
        Logging.Cclass.logWarning(this, function0, th);
    }

    @Override // pipelines.Logging
    public void logError(Function0<String> function0, Throwable th) {
        Logging.Cclass.logError(this, function0, th);
    }

    /** Returns the application name ({@code "RandomPatchCifarAugmented"}). */
    public String appName() {
        return this.appName;
    }

    /**
     * Fits the augmented random-patch CIFAR pipeline, logs its test-set evaluation, and returns the
     * fitted pipeline.
     *
     * <p>Stages, as visible in the calls below:
     * <ol>
     *   <li>Load the training set from {@code trainLocation}. When {@code sampleFrac} is set,
     *       sample that fraction without replacement. Cache the result.</li>
     *   <li>Window the training images with ({@code patchSteps}, {@code patchSize}), take a
     *       {@code Sampler(100000)} sample of the patch vectors, and stack them into a matrix.
     *       Row-normalize it with {@code Stats.normalizeRows(..., 10.0)}. Fit a ZCA whitener using
     *       {@code whiteningEpsilon}.</li>
     *   <li>Whiten {@code numFilters} sampled rows. Divide each row by its L2 norm (+1e-10) and
     *       multiply by the transposed whitening matrix to obtain the convolution filters.</li>
     *   <li>Chain Convolver, SymmetricRectifier, Pooler, ImageVectorizer, Cacher, StandardScaler,
     *       BlockLeastSquaresEstimator and Cacher. The scaler and the solver are fitted on randomly
     *       patched 24x24 training images.</li>
     *   <li>Apply the pipeline to 24x24 center/corner patches of the test set. Log the result of
     *       {@code AugmentedExamplesEvaluator}, whose per-image ids and labels are replicated 10
     *       times.</li>
     * </ol>
     *
     * @param sparkContext the Spark context used to load data and run jobs
     * @param randomCifarFeaturizerConfig parsed command-line configuration
     * @return the fitted image-to-prediction-vector pipeline
     * @throws MatchError if {@code sampleFrac()} is neither a {@code Some} nor {@code None}
     */
    public Pipeline<Image, DenseVector<Object>> run(SparkContext sparkContext, RandomPatchCifarAugmented.RandomCifarFeaturizerConfig randomCifarFeaturizerConfig) {
        // Decompiled `sampleFrac match { case Some(f) => ...; case None => ... }`.
        // NOTE(review): the declared type `Some` looks like a decompiler inference artifact; the
        // value is an Option (it is also compared against None below).
        RDD<LabeledImage> trainData;
        Some sampleFrac = randomCifarFeaturizerConfig.sampleFrac();
        if (sampleFrac instanceof Some) {
            double fraction = BoxesRunTime.unboxToDouble(sampleFrac.x());
            RDD<LabeledImage> fullTrain = CifarLoader$.MODULE$.apply(sparkContext, randomCifarFeaturizerConfig.trainLocation());
            // Sample without replacement, using RDD.sample's default seed argument.
            trainData = fullTrain.sample(false, fraction, fullTrain.sample$default$3()).cache();
        } else {
            None$ none$ = None$.MODULE$;
            if (none$ != null ? !none$.equals(sampleFrac) : sampleFrac != null) {
                throw new MatchError(sampleFrac);
            }
            trainData = CifarLoader$.MODULE$.apply(sparkContext, randomCifarFeaturizerConfig.trainLocation()).cache();
        }
        RDD<Image> trainImages = ImageExtractor$.MODULE$.apply((RDD) trainData);

        // Patch matrix used to fit the whitener. NOTE(review): $$anonfun$3 is an opaque compiled
        // closure applied to each window (presumably a window -> vector conversion); confirm in Scala.
        DenseMatrix<Object> patchMatrix = Stats$.MODULE$.normalizeRows(MatrixUtils$.MODULE$.rowsToMatrix((DenseVector[]) new Windower(randomCifarFeaturizerConfig.patchSteps(), randomCifarFeaturizerConfig.patchSize()).andThen(new RandomPatchCifarAugmented$$anonfun$3()).andThen(new Sampler(100000, Sampler$.MODULE$.$lessinit$greater$default$2())).apply(trainImages), ClassTag$.MODULE$.Double()), 10.0d);
        ZCAWhitener zca = new ZCAWhitenerEstimator(randomCifarFeaturizerConfig.whiteningEpsilon()).fitSingle(patchMatrix);
        DenseMatrix<Object> whitenedSamples = zca.apply(MatrixUtils$.MODULE$.sampleRows(patchMatrix, randomCifarFeaturizerConfig.numFilters()));

        // filters = (every row of whitenedSamples divided by (its L2 norm + 1e-10)) * zca.whitener()^T.
        // sqrt(sum(pow(m, 2)(*, ::))) yields one norm per row, and m(::, *) / norms divides each
        // column elementwise by that vector, so row i is scaled by 1 / norm_i. The filters are
        // paired with zca for the tuple destructuring below.
        Tuple2 tuple2 = new Tuple2(((ImmutableNumericOps) ((ImmutableNumericOps) whitenedSamples.apply(package$.MODULE$.$colon$colon(), $times$.MODULE$, Broadcaster$.MODULE$.canBroadcastColumns(DenseMatrix$.MODULE$.handholdCanMapRows()))).$div(((DenseVector) package$sqrt$.MODULE$.apply(sum$.MODULE$.apply(((DenseMatrix) package$pow$.MODULE$.apply(whitenedSamples, BoxesRunTime.boxToDouble(2.0d), package$pow$.MODULE$.canMapV1DV(DenseMatrix$.MODULE$.handholdCMV(), package$pow$powDoubleDoubleImpl$.MODULE$, DenseMatrix$.MODULE$.canMapValues(ClassTag$.MODULE$.Double())))).apply($times$.MODULE$, package$.MODULE$.$colon$colon(), Broadcaster$.MODULE$.canBroadcastRows(DenseMatrix$.MODULE$.handholdCanMapCols())), BroadcastedRows$.MODULE$.broadcastOp(DenseMatrix$.MODULE$.handholdCanMapCols(), sum$.MODULE$.reduce_Double(DenseVector$.MODULE$.canIterateValues()), DenseMatrix$.MODULE$.canCollapseCols(ClassTag$.MODULE$.Double(), Zero$DoubleZero$.MODULE$))), package$sqrt$.MODULE$.fromLowOrderCanMapValues(DenseVector$.MODULE$.handholdCMV(), package$sqrt$sqrtDoubleImpl$.MODULE$, DenseVector$.MODULE$.canMapValues(ClassTag$.MODULE$.Double())))).$plus(BoxesRunTime.boxToDouble(1.0E-10d), DenseVector$.MODULE$.dv_s_Op_Double_OpAdd()), BroadcastedColumns$.MODULE$.broadcastOp2(DenseMatrix$.MODULE$.handholdCanMapRows(), DenseVector$.MODULE$.dv_dv_Op_Double_OpDiv(), DenseMatrix$.MODULE$.canMapRows(ClassTag$.MODULE$.Double(), Zero$DoubleZero$.MODULE$)))).$times(zca.whitener().t(DenseMatrix$.MODULE$.canTranspose()), DenseMatrix$.MODULE$.implOpMulMatrix_DMD_DMD_eq_DMD()), zca);
        // Decompiled tuple pattern match; tuple2 was just constructed, so the null check cannot fire.
        if (tuple2 == null) {
            throw new MatchError(tuple2);
        }
        Tuple2 tuple22 = new Tuple2((DenseMatrix) tuple2._1(), (ZCAWhitener) tuple2._2());
        DenseMatrix filters = (DenseMatrix) tuple22._1();
        ZCAWhitener whitener = (ZCAWhitener) tuple22._2();

        // Training-time augmentation: numRandomImagesAugment random 24x24 patches per training
        // image (per RandomPatcher's arguments), then RandomImageTransformer with 0.5 and the opaque
        // closure $$anonfun$4. NOTE(review): 0.5 is presumably an application probability; confirm.
        RDD<Image> augmentedTrainImages = new RandomImageTransformer(0.5d, new RandomPatchCifarAugmented$$anonfun$4(), RandomImageTransformer$.MODULE$.apply$default$3()).apply((RDD) new RandomPatcher(randomCifarFeaturizerConfig.numRandomImagesAugment(), 24, 24, RandomPatcher$.MODULE$.apply$default$4()).apply(trainImages));

        // NOTE(review): poolStride() is read here, but parse() defines no --poolStride option, so it
        // always takes the config default.
        Pipeline<Image, DenseVector<Object>> predictor = new Convolver(filters, 24, 24, 3, new Some(whitener), true, Convolver$.MODULE$.$lessinit$greater$default$7())
                .andThen(new SymmetricRectifier(SymmetricRectifier$.MODULE$.apply$default$1(), randomCifarFeaturizerConfig.alpha()))
                .andThen(new Pooler(randomCifarFeaturizerConfig.poolStride(), randomCifarFeaturizerConfig.poolSize(), new RandomPatchCifarAugmented$$anonfun$1(), new RandomPatchCifarAugmented$$anonfun$5()))
                .andThen(ImageVectorizer$.MODULE$)
                .andThen(new Cacher(new Some("features"), ClassTag$.MODULE$.apply(DenseVector.class)))
                // The scaler is fitted on the augmented training images.
                .andThen(new StandardScaler(StandardScaler$.MODULE$.$lessinit$greater$default$1(), StandardScaler$.MODULE$.$lessinit$greater$default$2()), augmentedTrainImages)
                // The solver is fitted on the augmented images against ClassLabelIndicatorsFromIntLabels(10)
                // label vectors passed through LabelAugmenter(numRandomImagesAugment), presumably so
                // labels line up with the random patches. The lambda default comes from opaque $$anonfun$2.
                .andThen(new BlockLeastSquaresEstimator(4096, 1, BoxesRunTime.unboxToDouble(randomCifarFeaturizerConfig.lambda().getOrElse(new RandomPatchCifarAugmented$$anonfun$2())), BlockLeastSquaresEstimator$.MODULE$.$lessinit$greater$default$4()), augmentedTrainImages, new RandomPatchCifarAugmented.LabelAugmenter(randomCifarFeaturizerConfig.numRandomImagesAugment(), ClassTag$.MODULE$.apply(DenseVector.class)).apply(LabelExtractor$.MODULE$.andThen(new ClassLabelIndicatorsFromIntLabels(10)).apply(trainData)))
                .andThen(new Cacher(Cacher$.MODULE$.$lessinit$greater$default$1(), ClassTag$.MODULE$.apply(DenseVector.class)));

        // Test-time evaluation. NOTE(review): the hard-coded 10 in both LabelAugmenter calls must
        // equal the number of patches CenterCornerPatcher(24, 24, true) emits per image; confirm.
        RDD<LabeledImage> testData = CifarLoader$.MODULE$.apply(sparkContext, randomCifarFeaturizerConfig.testLocation());
        RDD<Image> testImages = ImageExtractor$.MODULE$.apply((RDD) testData);
        RDD<Image> testPatches = new CenterCornerPatcher(24, 24, true).apply(testImages);
        RDD testIds = new RandomPatchCifarAugmented.LabelAugmenter(10, ClassTag$.MODULE$.Long()).apply(testImages.zipWithUniqueId().map(new RandomPatchCifarAugmented$$anonfun$6(), ClassTag$.MODULE$.Long()));
        RDD<Object> testLabels = new RandomPatchCifarAugmented.LabelAugmenter(10, ClassTag$.MODULE$.Int()).apply((RDD) LabelExtractor$.MODULE$.apply((RDD) testData));
        logInfo(new RandomPatchCifarAugmented$$anonfun$run$1(AugmentedExamplesEvaluator$.MODULE$.apply(testIds, predictor.apply(testPatches), testLabels, 10, AugmentedExamplesEvaluator$.MODULE$.apply$default$5(), ClassTag$.MODULE$.Long())));
        return predictor;
    }

    /**
     * Parses command-line arguments into a {@code RandomCifarFeaturizerConfig}, starting from the
     * config's 12 default values.
     *
     * <p>Options {@code trainLocation}, {@code testLocation} and {@code whiteningEpsilon} are
     * marked {@code required()}.
     *
     * <p>NOTE(review): the result of scopt's {@code parse} (presumably a {@code scala.Option}) is
     * unwrapped with {@code .get()}, so a failed parse, such as a missing required option, surfaces
     * as an exception from {@code get} rather than a clean exit.
     *
     * <p>NOTE(review): there is no {@code poolStride} option even though {@link #run} reads
     * {@code poolStride()}.
     *
     * @param strArr raw command-line arguments
     * @return the parsed configuration
     */
    /* JADX WARN: Type inference failed for: r0v0, types: [pipelines.images.cifar.RandomPatchCifarAugmented$$anon$1] */
    public RandomPatchCifarAugmented.RandomCifarFeaturizerConfig parse(String[] strArr) {
        return (RandomPatchCifarAugmented.RandomCifarFeaturizerConfig) new OptionParser<RandomPatchCifarAugmented.RandomCifarFeaturizerConfig>() { // from class: pipelines.images.cifar.RandomPatchCifarAugmented$$anon$1
            {
                // NOTE(review): this discarded call looks like the decompiled remnant of the
                // OptionParser super-constructor argument (program name); confirm against Scala.
                RandomPatchCifarAugmented$.MODULE$.appName();
                head(Predef$.MODULE$.wrapRefArray(new String[]{RandomPatchCifarAugmented$.MODULE$.appName(), "0.1"}));
                help("help").text("prints this usage text");
                // Each action is a compiled closure that copies the parsed value into the config.
                opt("trainLocation", Read$.MODULE$.stringRead()).required().action(new RandomPatchCifarAugmented$$anon$1$$anonfun$7(this));
                opt("testLocation", Read$.MODULE$.stringRead()).required().action(new RandomPatchCifarAugmented$$anon$1$$anonfun$8(this));
                opt("numFilters", Read$.MODULE$.intRead()).action(new RandomPatchCifarAugmented$$anon$1$$anonfun$9(this));
                opt("whiteningEpsilon", Read$.MODULE$.doubleRead()).required().action(new RandomPatchCifarAugmented$$anon$1$$anonfun$10(this));
                opt("patchSize", Read$.MODULE$.intRead()).action(new RandomPatchCifarAugmented$$anon$1$$anonfun$11(this));
                opt("patchSteps", Read$.MODULE$.intRead()).action(new RandomPatchCifarAugmented$$anon$1$$anonfun$12(this));
                opt("poolSize", Read$.MODULE$.intRead()).action(new RandomPatchCifarAugmented$$anon$1$$anonfun$13(this));
                opt("numRandomImagesAugment", Read$.MODULE$.intRead()).action(new RandomPatchCifarAugmented$$anon$1$$anonfun$14(this));
                opt("alpha", Read$.MODULE$.doubleRead()).action(new RandomPatchCifarAugmented$$anon$1$$anonfun$15(this));
                opt("lambda", Read$.MODULE$.doubleRead()).action(new RandomPatchCifarAugmented$$anon$1$$anonfun$16(this));
                opt("sampleFrac", Read$.MODULE$.doubleRead()).action(new RandomPatchCifarAugmented$$anon$1$$anonfun$17(this));
            }
        }.parse(Predef$.MODULE$.wrapRefArray(strArr), new RandomPatchCifarAugmented.RandomCifarFeaturizerConfig(RandomPatchCifarAugmented$RandomCifarFeaturizerConfig$.MODULE$.apply$default$1(), RandomPatchCifarAugmented$RandomCifarFeaturizerConfig$.MODULE$.apply$default$2(), RandomPatchCifarAugmented$RandomCifarFeaturizerConfig$.MODULE$.apply$default$3(), RandomPatchCifarAugmented$RandomCifarFeaturizerConfig$.MODULE$.apply$default$4(), RandomPatchCifarAugmented$RandomCifarFeaturizerConfig$.MODULE$.apply$default$5(), RandomPatchCifarAugmented$RandomCifarFeaturizerConfig$.MODULE$.apply$default$6(), RandomPatchCifarAugmented$RandomCifarFeaturizerConfig$.MODULE$.apply$default$7(), RandomPatchCifarAugmented$RandomCifarFeaturizerConfig$.MODULE$.apply$default$8(), RandomPatchCifarAugmented$RandomCifarFeaturizerConfig$.MODULE$.apply$default$9(), RandomPatchCifarAugmented$RandomCifarFeaturizerConfig$.MODULE$.apply$default$10(), RandomPatchCifarAugmented$RandomCifarFeaturizerConfig$.MODULE$.apply$default$11(), RandomPatchCifarAugmented$RandomCifarFeaturizerConfig$.MODULE$.apply$default$12())).get();
    }

    /**
     * Application entry logic. It does the following:
     * <ul>
     *   <li>parses {@code strArr};</li>
     *   <li>builds a SparkConf named {@link #appName()}, defaulting {@code spark.master} to
     *       {@code local[2]} when unset and removing {@code spark.jars};</li>
     *   <li>runs the pipeline and then stops the SparkContext.</li>
     * </ul>
     *
     * <p>The context is stopped in a {@code finally} block, so a failure inside {@link #run} no
     * longer leaves it running.
     *
     * <p>NOTE(review): this is an instance method on the singleton, so it is presumably reached
     * through a static forwarder on the companion class rather than being a JVM entry point itself.
     *
     * @param strArr raw command-line arguments
     */
    public void main(String[] strArr) {
        RandomPatchCifarAugmented.RandomCifarFeaturizerConfig config = parse(strArr);
        SparkConf sparkConf = new SparkConf().setAppName(appName());
        sparkConf.setIfMissing("spark.master", "local[2]");
        sparkConf.remove("spark.jars");
        SparkContext sparkContext = new SparkContext(sparkConf);
        try {
            run(sparkContext, config);
        } finally {
            // Always release the context, even when fitting or evaluation throws.
            sparkContext.stop();
        }
    }

    /** Java-serialization hook: deserializing this object yields the existing singleton. */
    private Object readResolve() {
        return MODULE$;
    }

    /**
     * Private singleton constructor, run once from the static initializer. It publishes
     * {@code MODULE$}, clears the Logging trait's logger slot, and sets {@code appName}.
     *
     * <p>NOTE(review): {@code MODULE$} is declared {@code static final ... = null}, so the
     * assignment below is not legal Java source; it mirrors the Scala-compiled bytecode.
     */
    private RandomPatchCifarAugmented$() {
        MODULE$ = this;
        pipelines$Logging$$log__$eq(null);
        this.appName = "RandomPatchCifarAugmented";
    }
}
