package keystoneml.pipelines.images.cifar;

import breeze.linalg.$times$;
import breeze.linalg.BroadcastedColumns$;
import breeze.linalg.Broadcaster$;
import breeze.linalg.DenseMatrix;
import breeze.linalg.DenseMatrix$;
import breeze.linalg.DenseVector;
import breeze.linalg.DenseVector$;
import breeze.linalg.ImmutableNumericOps;
import breeze.linalg.sum$;
import breeze.numerics.package$pow$;
import breeze.numerics.package$pow$powDoubleDoubleImpl$;
import breeze.numerics.package$sqrt$;
import breeze.numerics.package$sqrt$sqrtDoubleImpl$;
import breeze.storage.Zero$DoubleZero$;
import keystoneml.evaluation.MulticlassClassifierEvaluator;
import keystoneml.evaluation.MulticlassMetrics;
import keystoneml.loaders.CifarLoader$;
import keystoneml.nodes.images.Convolver;
import keystoneml.nodes.images.Convolver$;
import keystoneml.nodes.images.ImageExtractor$;
import keystoneml.nodes.images.ImageVectorizer$;
import keystoneml.nodes.images.LabelExtractor$;
import keystoneml.nodes.images.Pooler;
import keystoneml.nodes.images.SymmetricRectifier;
import keystoneml.nodes.images.SymmetricRectifier$;
import keystoneml.nodes.images.Windower;
import keystoneml.nodes.learning.BlockLeastSquaresEstimator;
import keystoneml.nodes.learning.BlockLeastSquaresEstimator$;
import keystoneml.nodes.learning.ZCAWhitener;
import keystoneml.nodes.learning.ZCAWhitenerEstimator;
import keystoneml.nodes.stats.Sampler;
import keystoneml.nodes.stats.Sampler$;
import keystoneml.nodes.stats.StandardScaler;
import keystoneml.nodes.stats.StandardScaler$;
import keystoneml.nodes.util.Cacher;
import keystoneml.nodes.util.Cacher$;
import keystoneml.nodes.util.ClassLabelIndicatorsFromIntLabels;
import keystoneml.nodes.util.MaxClassifier$;
import keystoneml.pipelines.Logging;
import keystoneml.pipelines.images.cifar.RandomPatchCifar;
import keystoneml.utils.Image;
import keystoneml.utils.LabeledImage;
import keystoneml.utils.MatrixUtils$;
import keystoneml.utils.Stats$;
import keystoneml.workflow.Pipeline;
import keystoneml.workflow.PipelineDataset;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.rdd.RDD;
import org.slf4j.Logger;
import scala.Function0;
import scala.MatchError;
import scala.None$;
import scala.Predef$;
import scala.Serializable;
import scala.Some;
import scala.Tuple2;
import scala.package$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;
import scala.runtime.TraitSetter;
import scopt.OptionParser;
import scopt.Read$;

/* compiled from: RandomPatchCifar.scala */
/* loaded from: input_file:keystoneml/pipelines/images/cifar/RandomPatchCifar$.class */
/**
 * Decompiled form of the Scala singleton {@code RandomPatchCifar}: a KeystoneML pipeline that
 * builds convolution filters from randomly sampled, ZCA-whitened training-image patches, fits a
 * block least-squares classifier on the rectified and pooled filter responses, and evaluates the
 * result on CIFAR training and test data.
 *
 * <p>Scala compiler artifacts ({@code MODULE$}, {@code $anonfun} classes, {@code Logging.Cclass}
 * forwarders, {@code $default$N} getters) are retained unchanged.
 */
public final class RandomPatchCifar$ implements Serializable, Logging {
    /** Singleton instance; assigned by the constructor invoked from the static initializer. */
    public static final RandomPatchCifar$ MODULE$ = null;
    private final String appName;
    private transient Logger keystoneml$pipelines$Logging$$log_;

    static {
        new RandomPatchCifar$();
    }

    // ---- keystoneml.pipelines.Logging: backing field accessors and forwarders to Logging.Cclass ----

    @Override // keystoneml.pipelines.Logging
    public Logger keystoneml$pipelines$Logging$$log_() {
        return this.keystoneml$pipelines$Logging$$log_;
    }

    @Override // keystoneml.pipelines.Logging
    @TraitSetter
    public void keystoneml$pipelines$Logging$$log__$eq(Logger logger) {
        this.keystoneml$pipelines$Logging$$log_ = logger;
    }

    @Override // keystoneml.pipelines.Logging
    public Logger log() {
        return Logging.Cclass.log(this);
    }

    @Override // keystoneml.pipelines.Logging
    public void logInfo(Function0<String> function0) {
        Logging.Cclass.logInfo(this, function0);
    }

    @Override // keystoneml.pipelines.Logging
    public void logDebug(Function0<String> function0) {
        Logging.Cclass.logDebug(this, function0);
    }

    @Override // keystoneml.pipelines.Logging
    public void logTrace(Function0<String> function0) {
        Logging.Cclass.logTrace(this, function0);
    }

    @Override // keystoneml.pipelines.Logging
    public void logWarning(Function0<String> function0) {
        Logging.Cclass.logWarning(this, function0);
    }

    @Override // keystoneml.pipelines.Logging
    public void logError(Function0<String> function0) {
        Logging.Cclass.logError(this, function0);
    }

    @Override // keystoneml.pipelines.Logging
    public void logInfo(Function0<String> function0, Throwable th) {
        Logging.Cclass.logInfo(this, function0, th);
    }

    @Override // keystoneml.pipelines.Logging
    public void logDebug(Function0<String> function0, Throwable th) {
        Logging.Cclass.logDebug(this, function0, th);
    }

    @Override // keystoneml.pipelines.Logging
    public void logTrace(Function0<String> function0, Throwable th) {
        Logging.Cclass.logTrace(this, function0, th);
    }

    @Override // keystoneml.pipelines.Logging
    public void logWarning(Function0<String> function0, Throwable th) {
        Logging.Cclass.logWarning(this, function0, th);
    }

    @Override // keystoneml.pipelines.Logging
    public void logError(Function0<String> function0, Throwable th) {
        Logging.Cclass.logError(this, function0, th);
    }

    /** Returns the application name ("RandomPatchCifar"), used for the Spark app and CLI header. */
    public String appName() {
        return this.appName;
    }

    /**
     * Fits the random-patch CIFAR pipeline and evaluates it on the training and test sets.
     *
     * <p>Steps: load (and optionally sample) the training data; extract and sample patches,
     * row-normalize them and fit a ZCA whitener; whiten {@code numFilters} sampled patches,
     * divide each row by its L2 norm and multiply by the transpose of the whitening matrix to
     * obtain the convolution filters; then chain convolution, symmetric rectification, pooling,
     * vectorization, standard scaling, a block least-squares estimator and {@code MaxClassifier}.
     *
     * @param sparkContext context used to load the training and test data
     * @param randomCifarConfig parsed configuration (data locations and hyper-parameters)
     * @return the fitted prediction pipeline
     * @throws MatchError if {@code randomCifarConfig.sampleFrac()} is neither {@code Some} nor
     *     {@code None} (i.e. it is {@code null})
     */
    public Pipeline<Image, Object> run(SparkContext sparkContext, RandomPatchCifar.RandomCifarConfig randomCifarConfig) {
        final int numClasses = 10;
        // Number of patches sampled from the training images to fit the ZCA whitener.
        final int whitenerSampleSize = 100000;

        // Load the training data, optionally down-sampled by sampleFrac (a Scala Option).
        Object sampleFrac = randomCifarConfig.sampleFrac();
        RDD trainData;
        if (sampleFrac instanceof Some) {
            double fraction = BoxesRunTime.unboxToDouble(((Some) sampleFrac).x());
            RDD<LabeledImage> fullTrainData = CifarLoader$.MODULE$.apply(sparkContext, randomCifarConfig.trainLocation());
            trainData = fullTrainData.sample(false, fraction, fullTrainData.sample$default$3()).cache();
        } else if (None$.MODULE$.equals(sampleFrac)) {
            trainData = CifarLoader$.MODULE$.apply(sparkContext, randomCifarConfig.trainLocation()).cache();
        } else {
            throw new MatchError(sampleFrac);
        }
        RDD<Image> trainImages = ImageExtractor$.MODULE$.apply(trainData);

        // Integer labels -> per-class indicator vectors, cached.
        Pipeline labelExtractor = LabelExtractor$.MODULE$
                .andThen(new ClassLabelIndicatorsFromIntLabels(numClasses))
                .andThen(new Cacher(Cacher$.MODULE$.$lessinit$greater$default$1(), ClassTag$.MODULE$.apply(DenseVector.class)));
        PipelineDataset trainLabels = labelExtractor.apply(trainData);

        // Window the training images, vectorize patches, sample them, and row-normalize
        // (10.0 is forwarded to Stats.normalizeRows unchanged).
        DenseVector[] patches = (DenseVector[]) new Windower(randomCifarConfig.patchSteps(), randomCifarConfig.patchSize())
                .andThen(new RandomPatchCifar$$anonfun$3())
                .andThen(new Sampler(whitenerSampleSize, Sampler$.MODULE$.$lessinit$greater$default$2()))
                .apply(trainImages);
        DenseMatrix<Object> baseFilterMat = Stats$.MODULE$.normalizeRows(MatrixUtils$.MODULE$.rowsToMatrix(patches, ClassTag$.MODULE$.Double()), 10.0d);
        ZCAWhitener whitener = new ZCAWhitenerEstimator(randomCifarConfig.whiteningEpsilon()).fitSingle(baseFilterMat);

        // Whiten numFilters randomly chosen patch rows: these are the un-normalized filters.
        DenseMatrix<Object> unnormFilters = whitener.apply(MatrixUtils$.MODULE$.sampleRows(baseFilterMat, randomCifarConfig.numFilters()));

        // Breeze: unnormSq = pow(unnormFilters, 2.0)
        DenseMatrix unnormSq = (DenseMatrix) package$pow$.MODULE$.apply(unnormFilters, BoxesRunTime.boxToDouble(2.0d), package$pow$.MODULE$.canMapV1DV(DenseMatrix$.MODULE$.scalarOf(), package$pow$powDoubleDoubleImpl$.MODULE$, DenseMatrix$.MODULE$.canMapValues$mDDc$sp(ClassTag$.MODULE$.Double())));
        // Breeze: twoNorms = sqrt(sum(unnormSq(*, ::)))  -- one L2 norm per filter row.
        DenseVector twoNorms = (DenseVector) package$sqrt$.MODULE$.apply(
                sum$.MODULE$.apply(
                        unnormSq.apply($times$.MODULE$, package$.MODULE$.$colon$colon(), Broadcaster$.MODULE$.canBroadcastRows(DenseMatrix$.MODULE$.handholdCanMapCols())),
                        sum$.MODULE$.vectorizeRows(ClassTag$.MODULE$.Double(), sum$.MODULE$.helper_Double(), DenseVector$.MODULE$.canAddIntoD())),
                package$sqrt$.MODULE$.fromLowOrderCanMapValues(DenseVector$.MODULE$.scalarOf(), package$sqrt$sqrtDoubleImpl$.MODULE$, DenseVector$.MODULE$.canMapValues$mDDc$sp(ClassTag$.MODULE$.Double())));
        // The 1e-10 offset avoids dividing by zero for an all-zero row.
        Object safeNorms = twoNorms.$plus(BoxesRunTime.boxToDouble(1.0E-10d), DenseVector$.MODULE$.dv_s_Op_Double_OpAdd());
        // Breeze: unnormFilters(::, *) / safeNorms  -- divide each row by its norm.
        Object normalizedFilters = ((ImmutableNumericOps) unnormFilters.apply(package$.MODULE$.$colon$colon(), $times$.MODULE$, Broadcaster$.MODULE$.canBroadcastColumns(DenseMatrix$.MODULE$.handholdCanMapRows())))
                .$div(safeNorms, BroadcastedColumns$.MODULE$.broadcastOp2(DenseMatrix$.MODULE$.handholdCanMapRows(), DenseVector$.MODULE$.dv_dv_Op_Double_OpDiv(), DenseMatrix$.MODULE$.canMapRows(ClassTag$.MODULE$.Double(), Zero$DoubleZero$.MODULE$, DenseVector$.MODULE$.dv_dv_UpdateOp_Double_OpSet())));
        // Breeze: normalizedFilters * whitener.whitener.t
        DenseMatrix filters = (DenseMatrix) ((ImmutableNumericOps) normalizedFilters).$times(whitener.whitener().t(DenseMatrix$.MODULE$.canTranspose()), DenseMatrix$.MODULE$.implOpMulMatrix_DMD_DMD_eq_DMD());

        // Featurization + classifier. 32, 32, 3 are forwarded to Convolver unchanged
        // (presumably CIFAR image width, height and channel count -- confirm against Convolver);
        // 4096 and 1 are forwarded to BlockLeastSquaresEstimator unchanged.
        Pipeline<Image, Object> predictionPipeline = new Convolver(filters, 32, 32, 3, new Some(whitener), true, Convolver$.MODULE$.$lessinit$greater$default$7())
                .andThen(new SymmetricRectifier(SymmetricRectifier$.MODULE$.apply$default$1(), randomCifarConfig.alpha()))
                .andThen(new Pooler(randomCifarConfig.poolStride(), randomCifarConfig.poolSize(), new RandomPatchCifar$$anonfun$1(), new RandomPatchCifar$$anonfun$4()))
                .andThen(ImageVectorizer$.MODULE$)
                .andThen(new Cacher(Cacher$.MODULE$.$lessinit$greater$default$1(), ClassTag$.MODULE$.apply(DenseVector.class)))
                .andThen(new StandardScaler(StandardScaler$.MODULE$.$lessinit$greater$default$1(), StandardScaler$.MODULE$.$lessinit$greater$default$2()), trainImages)
                .andThen(new BlockLeastSquaresEstimator(4096, 1, BoxesRunTime.unboxToDouble(randomCifarConfig.lambda().getOrElse(new RandomPatchCifar$$anonfun$2())), BlockLeastSquaresEstimator$.MODULE$.$lessinit$greater$default$4()), trainImages, trainLabels)
                .andThen(MaxClassifier$.MODULE$)
                .andThen(new Cacher(Cacher$.MODULE$.$lessinit$greater$default$1(), ClassTag$.MODULE$.Int()));

        // Evaluate on the training data first, then on the test data.
        MulticlassClassifierEvaluator evaluator = new MulticlassClassifierEvaluator(numClasses);
        MulticlassMetrics trainEval = (MulticlassMetrics) evaluator.evaluate(predictionPipeline.apply(trainImages), LabelExtractor$.MODULE$.apply(trainData));

        RDD<LabeledImage> testData = CifarLoader$.MODULE$.apply(sparkContext, randomCifarConfig.testLocation());
        RDD<Image> testImages = ImageExtractor$.MODULE$.apply((RDD) testData);
        // NOTE(review): the result of this call is discarded; kept only in case applying the
        // pipeline has side effects -- confirm and remove if it does not.
        labelExtractor.apply((RDD) testData);
        MulticlassMetrics testEval = (MulticlassMetrics) evaluator.evaluate(predictionPipeline.apply(testImages), LabelExtractor$.MODULE$.apply((RDD) testData));

        logInfo(new RandomPatchCifar$$anonfun$run$1(trainEval));
        logInfo(new RandomPatchCifar$$anonfun$run$2(testEval));
        return predictionPipeline;
    }

    /**
     * Parses command-line arguments into a {@link RandomPatchCifar.RandomCifarConfig}.
     *
     * <p>{@code --trainLocation}, {@code --testLocation} and {@code --whiteningEpsilon} are
     * required. Every other field starts from the {@code RandomCifarConfig} defaults. No option
     * sets {@code poolStride}, so it always keeps its default.
     *
     * @param strArr command-line arguments
     * @return the parsed configuration
     * @throws java.util.NoSuchElementException presumably, when parsing fails, because
     *     {@code get()} is called on an empty parse result -- confirm against scopt
     */
    public RandomPatchCifar.RandomCifarConfig parse(String[] strArr) {
        // The program name is the OptionParser constructor argument.
        OptionParser<RandomPatchCifar.RandomCifarConfig> parser = new OptionParser<RandomPatchCifar.RandomCifarConfig>(RandomPatchCifar$.MODULE$.appName()) { // from class: keystoneml.pipelines.images.cifar.RandomPatchCifar$$anon$1
            {
                head(Predef$.MODULE$.wrapRefArray(new String[]{RandomPatchCifar$.MODULE$.appName(), "0.1"}));
                help("help").text("prints this usage text");
                opt("trainLocation", Read$.MODULE$.stringRead()).required().action(new RandomPatchCifar$$anon$1$$anonfun$5(this));
                opt("testLocation", Read$.MODULE$.stringRead()).required().action(new RandomPatchCifar$$anon$1$$anonfun$6(this));
                opt("whiteningEpsilon", Read$.MODULE$.doubleRead()).required().action(new RandomPatchCifar$$anon$1$$anonfun$7(this));
                opt("numFilters", Read$.MODULE$.intRead()).action(new RandomPatchCifar$$anon$1$$anonfun$8(this));
                opt("patchSize", Read$.MODULE$.intRead()).action(new RandomPatchCifar$$anon$1$$anonfun$9(this));
                opt("patchSteps", Read$.MODULE$.intRead()).action(new RandomPatchCifar$$anon$1$$anonfun$10(this));
                opt("poolSize", Read$.MODULE$.intRead()).action(new RandomPatchCifar$$anon$1$$anonfun$11(this));
                opt("alpha", Read$.MODULE$.doubleRead()).action(new RandomPatchCifar$$anon$1$$anonfun$12(this));
                opt("lambda", Read$.MODULE$.doubleRead()).action(new RandomPatchCifar$$anon$1$$anonfun$13(this));
                opt("sampleFrac", Read$.MODULE$.doubleRead()).action(new RandomPatchCifar$$anon$1$$anonfun$14(this));
            }
        };
        RandomPatchCifar.RandomCifarConfig defaults = new RandomPatchCifar.RandomCifarConfig(
                RandomPatchCifar$RandomCifarConfig$.MODULE$.apply$default$1(),
                RandomPatchCifar$RandomCifarConfig$.MODULE$.apply$default$2(),
                RandomPatchCifar$RandomCifarConfig$.MODULE$.apply$default$3(),
                RandomPatchCifar$RandomCifarConfig$.MODULE$.apply$default$4(),
                RandomPatchCifar$RandomCifarConfig$.MODULE$.apply$default$5(),
                RandomPatchCifar$RandomCifarConfig$.MODULE$.apply$default$6(),
                RandomPatchCifar$RandomCifarConfig$.MODULE$.apply$default$7(),
                RandomPatchCifar$RandomCifarConfig$.MODULE$.apply$default$8(),
                RandomPatchCifar$RandomCifarConfig$.MODULE$.apply$default$9(),
                RandomPatchCifar$RandomCifarConfig$.MODULE$.apply$default$10(),
                RandomPatchCifar$RandomCifarConfig$.MODULE$.apply$default$11());
        return (RandomPatchCifar.RandomCifarConfig) parser.parse(Predef$.MODULE$.wrapRefArray(strArr), defaults).get();
    }

    /**
     * Entry point: parses arguments, creates a SparkContext (defaulting {@code spark.master} to
     * {@code local[2]} when unset), runs the pipeline, and always stops the context afterwards.
     *
     * @param strArr command-line arguments, see {@link #parse(String[])}
     */
    public void main(String[] strArr) {
        RandomPatchCifar.RandomCifarConfig config = parse(strArr);
        SparkConf sparkConf = new SparkConf().setAppName(appName());
        sparkConf.setIfMissing("spark.master", "local[2]");
        SparkContext sparkContext = new SparkContext(sparkConf);
        try {
            run(sparkContext, config);
        } finally {
            // Release Spark resources even when run(...) fails.
            sparkContext.stop();
        }
    }

    /** Serialization hook: deserialized copies resolve back to the singleton. */
    private Object readResolve() {
        return MODULE$;
    }

    // NOTE(review): assigning the static final MODULE$ here mirrors the Scala bytecode; Java
    // source does not allow it.
    private RandomPatchCifar$() {
        MODULE$ = this;
        keystoneml$pipelines$Logging$$log__$eq(null);
        this.appName = "RandomPatchCifar";
    }
}
