package org.platanios.tensorflow.examples;

import com.typesafe.scalalogging.Logger;
import com.typesafe.scalalogging.Logger$;
import java.nio.file.Path;
import java.nio.file.Paths;
import org.platanios.tensorflow.api.implicits.helpers.DataTypeAuxToDataType$;
import org.platanios.tensorflow.api.implicits.helpers.DataTypeToOutput$;
import org.platanios.tensorflow.api.implicits.helpers.OutputToTensor$;
import org.platanios.tensorflow.api.learn.SupervisedTrainableModel;
import org.platanios.tensorflow.api.learn.estimators.InMemoryEstimator;
import org.platanios.tensorflow.api.learn.hooks.Hook;
import org.platanios.tensorflow.api.ops.Function$ArgType$;
import org.platanios.tensorflow.api.ops.NN$SamePadding$;
import org.platanios.tensorflow.api.ops.io.data.BatchDataset;
import org.platanios.tensorflow.api.ops.io.data.Dataset;
import org.platanios.tensorflow.api.ops.io.data.PrefetchDataset;
import org.platanios.tensorflow.api.ops.io.data.RepeatDataset;
import org.platanios.tensorflow.api.ops.io.data.ShuffleDataset;
import org.platanios.tensorflow.api.ops.io.data.TensorSlicesDataset;
import org.platanios.tensorflow.api.ops.io.data.ZipDataset;
import org.platanios.tensorflow.api.package$;
import org.platanios.tensorflow.api.package$tf$;
import org.platanios.tensorflow.api.package$tf$data$;
import org.platanios.tensorflow.api.package$tf$learn$;
import org.platanios.tensorflow.api.tensors.Tensor;
import org.platanios.tensorflow.api.tensors.ops.Basic;
import org.platanios.tensorflow.data.image.STL10Dataset;
import org.platanios.tensorflow.data.image.STL10Loader$;
import org.slf4j.LoggerFactory;
import scala.Predef$;
import scala.Some;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.FloatRef;
import scala.runtime.RichInt$;

/* compiled from: STL10.scala */
/* loaded from: input_file:org/platanios/tensorflow/examples/STL10$.class */
/**
 * Java view of the Scala {@code object STL10}. The example trains a convolutional network on the
 * STL-10 training split loaded from {@code datasets/STL10}. When INFO logging is enabled, it then
 * logs the model's accuracy on the training and test splits.
 *
 * <p>NOTE(review): this file is decompiler output. It relies on the compiler-generated closure
 * classes {@code STL10$$anonfun$main$1} and {@code STL10$$anonfun$accuracy$1$1}, which are defined
 * elsewhere.
 */
public final class STL10$ {
    /**
     * Singleton instance, created once during class initialization.
     *
     * <p>The decompiled form assigned this {@code static final} field from the constructor, which
     * javac rejects. The constructor never reads this field, so eager initialization is equivalent.
     */
    public static final STL10$ MODULE$ = new STL10$();

    private final Logger logger;

    /**
     * Entry point. It loads STL-10, builds the training input pipeline and the network, and trains
     * the network with an in-memory estimator. Afterwards it logs train/test accuracy, which is
     * only computed when INFO logging is enabled.
     *
     * @param strArr command-line arguments (not used)
     */
    public void main(String[] strArr) {
        // NOTE(review): the meaning of the third load(...) argument (false) is not visible here.
        STL10Dataset dataset = STL10Loader$.MODULE$.load(Paths.get("datasets/STL10", new String[0]), STL10Loader$.MODULE$.load$default$2(), false);

        // Training input pipeline: zip(images, labels) -> repeat -> shuffle(10000) -> batch(256)
        // -> prefetch(10).
        TensorSlicesDataset trainImages = package$tf$data$.MODULE$.TensorSlicesDataset().apply(dataset.trainImages(), package$tf$data$.MODULE$.TensorSlicesDataset().apply$default$2(), package$.MODULE$.tensorDataHelper(), OutputToTensor$.MODULE$.outputToTensor(), Function$ArgType$.MODULE$.outputArgType());
        TensorSlicesDataset trainLabels = package$tf$data$.MODULE$.TensorSlicesDataset().apply(dataset.trainLabels(), package$tf$data$.MODULE$.TensorSlicesDataset().apply$default$2(), package$.MODULE$.tensorDataHelper(), OutputToTensor$.MODULE$.outputToTensor(), Function$ArgType$.MODULE$.outputArgType());
        ZipDataset.ZipDatasetOps zipOps = package$.MODULE$.datasetToZipDatasetOps(trainImages);
        RepeatDataset.RepeatDatasetOps repeatOps = package$.MODULE$.datasetToRepeatDatasetOps(zipOps.zip(trainLabels, zipOps.zip$default$2()));
        ShuffleDataset.ShuffleDatasetOps shuffleOps = package$.MODULE$.datasetToShuffleDatasetOps(repeatOps.repeat(repeatOps.repeat$default$1(), repeatOps.repeat$default$2()));
        BatchDataset.BatchDatasetOps batchOps = package$.MODULE$.datasetToBatchDatasetOps(shuffleOps.shuffle(10000L, shuffleOps.shuffle$default$2(), shuffleOps.shuffle$default$3()));
        PrefetchDataset.PrefetchDatasetOps prefetchOps = package$.MODULE$.datasetToPrefetchDatasetOps(batchOps.batch(256L, batchOps.batch$default$2()));
        Dataset trainData = prefetchOps.prefetch(10L, prefetchOps.prefetch$default$2());

        if (this.logger.underlying().isInfoEnabled()) {
            this.logger.underlying().info("Building the convolutional network model.");
        }

        // Network layers:
        //   Cast(FLOAT32)
        //   -> [Conv2D -> AddBias -> ReLU -> MaxPool] x 2
        //   -> Flatten -> Linear(256) -> ReLU -> Linear(10)
        // Training labels: UINT8 [-1] -> Cast(INT64).
        // Loss: SparseSoftmaxCrossEntropy -> Mean -> ScalarSummary("Loss").
        // Optimizer: AdaGrad(0.1).
        // NOTE(review): the second block's bias layer is named "Bias_1" rather than "Layer_1/Bias".
        // It is kept as-is because renaming it would presumably change the variable names written
        // to existing checkpoints under temp/cnn-stl10.
        // NOTE(review): the two 1L arguments passed to each MaxPool look like strides. If so, these
        // layers do not downsample; confirm this is intended.
        SupervisedTrainableModel model = package$tf$learn$.MODULE$.Model().apply(
                // Image input: UINT8 with shape [-1, d1, d2, d3], where d1..d3 are dimensions 1-3
                // of the training-images tensor.
                package$tf$learn$.MODULE$.Input().apply(
                        package$.MODULE$.UINT8(),
                        package$.MODULE$.Shape().apply(Predef$.MODULE$.wrapIntArray(new int[]{-1, dataset.trainImages().shape().apply(1), dataset.trainImages().shape().apply(2), dataset.trainImages().shape().apply(3)})),
                        package$tf$learn$.MODULE$.Input().apply$default$3(),
                        DataTypeAuxToDataType$.MODULE$.dataTypeAuxToDataType(),
                        DataTypeToOutput$.MODULE$.dataTypeToOutput(),
                        OutputToTensor$.MODULE$.outputToTensor(),
                        package$.MODULE$.tensorDataHelper()),
                package$tf$learn$.MODULE$.Cast().apply("Input/Cast", package$.MODULE$.FLOAT32())
                        .$greater$greater(package$tf$learn$.MODULE$.Conv2D().apply("Layer_0/Conv2D", package$.MODULE$.Shape().apply(Predef$.MODULE$.wrapIntArray(new int[]{5, 5, 3, 32})), 1L, 1L, NN$SamePadding$.MODULE$, package$tf$learn$.MODULE$.Conv2D().apply$default$6(), package$tf$learn$.MODULE$.Conv2D().apply$default$7(), package$tf$learn$.MODULE$.Conv2D().apply$default$8()))
                        .$greater$greater(package$tf$learn$.MODULE$.AddBias().apply("Layer_0/Bias", package$tf$learn$.MODULE$.AddBias().apply$default$2()))
                        .$greater$greater(package$tf$learn$.MODULE$.ReLU().apply("Layer_0/ReLU", 0.1f))
                        .$greater$greater(package$tf$learn$.MODULE$.MaxPool().apply("Layer_0/MaxPool", Seq$.MODULE$.apply(Predef$.MODULE$.wrapLongArray(new long[]{1, 2, 2, 1})), 1L, 1L, NN$SamePadding$.MODULE$, package$tf$learn$.MODULE$.MaxPool().apply$default$6()))
                        .$greater$greater(package$tf$learn$.MODULE$.Conv2D().apply("Layer_1/Conv2D", package$.MODULE$.Shape().apply(Predef$.MODULE$.wrapIntArray(new int[]{5, 5, 32, 64})), 1L, 1L, NN$SamePadding$.MODULE$, package$tf$learn$.MODULE$.Conv2D().apply$default$6(), package$tf$learn$.MODULE$.Conv2D().apply$default$7(), package$tf$learn$.MODULE$.Conv2D().apply$default$8()))
                        .$greater$greater(package$tf$learn$.MODULE$.AddBias().apply("Bias_1", package$tf$learn$.MODULE$.AddBias().apply$default$2()))
                        .$greater$greater(package$tf$learn$.MODULE$.ReLU().apply("Layer_1/ReLU", 0.1f))
                        .$greater$greater(package$tf$learn$.MODULE$.MaxPool().apply("Layer_1/MaxPool", Seq$.MODULE$.apply(Predef$.MODULE$.wrapLongArray(new long[]{1, 2, 2, 1})), 1L, 1L, NN$SamePadding$.MODULE$, package$tf$learn$.MODULE$.MaxPool().apply$default$6()))
                        .$greater$greater(package$tf$learn$.MODULE$.Flatten().apply("Layer_2/Flatten"))
                        .$greater$greater(package$tf$learn$.MODULE$.Linear().apply("Layer_2/Linear", 256, package$tf$learn$.MODULE$.Linear().apply$default$3(), package$tf$learn$.MODULE$.Linear().apply$default$4(), package$tf$learn$.MODULE$.Linear().apply$default$5()))
                        .$greater$greater(package$tf$learn$.MODULE$.ReLU().apply("Layer_2/ReLU", 0.1f))
                        .$greater$greater(package$tf$learn$.MODULE$.Linear().apply("OutputLayer/Linear", 10, package$tf$learn$.MODULE$.Linear().apply$default$3(), package$tf$learn$.MODULE$.Linear().apply$default$4(), package$tf$learn$.MODULE$.Linear().apply$default$5())),
                // Training-label input: UINT8 with shape [-1].
                package$tf$learn$.MODULE$.Input().apply(
                        package$.MODULE$.UINT8(),
                        package$.MODULE$.Shape().apply(Predef$.MODULE$.wrapIntArray(new int[]{-1})),
                        package$tf$learn$.MODULE$.Input().apply$default$3(),
                        DataTypeAuxToDataType$.MODULE$.dataTypeAuxToDataType(),
                        DataTypeToOutput$.MODULE$.dataTypeToOutput(),
                        OutputToTensor$.MODULE$.outputToTensor(),
                        package$.MODULE$.tensorDataHelper()),
                package$tf$learn$.MODULE$.Cast().apply("TrainInput/Cast", package$.MODULE$.INT64()),
                package$tf$learn$.MODULE$.SparseSoftmaxCrossEntropy().apply("Loss/CrossEntropy")
                        .$greater$greater(package$tf$learn$.MODULE$.Mean().apply("Loss/Mean"))
                        .$greater$greater(package$tf$learn$.MODULE$.ScalarSummary().apply("Loss/Summary", "Loss", package$tf$learn$.MODULE$.ScalarSummary().apply$default$3(), package$tf$learn$.MODULE$.ScalarSummary().apply$default$4())),
                package$tf$.MODULE$.train().AdaGrad().apply(
                        0.1d,
                        package$tf$.MODULE$.train().AdaGrad().apply$default$2(),
                        package$tf$.MODULE$.train().AdaGrad().apply$default$3(),
                        package$tf$.MODULE$.train().AdaGrad().apply$default$4(),
                        package$tf$.MODULE$.train().AdaGrad().apply$default$5(),
                        package$tf$.MODULE$.train().AdaGrad().apply$default$6()));

        if (this.logger.underlying().isInfoEnabled()) {
            this.logger.underlying().info("Training the convolutional network model.");
        }

        // This directory is passed to Configuration, StepRateLogger, CheckpointSaver and
        // TensorBoardConfig alike.
        Path outputDir = Paths.get("temp/cnn-stl10", new String[0]);
        // Hooks: a step-rate logger triggered every 100 steps and a checkpoint saver triggered
        // every 1000 steps. The estimator-level StopCriteria uses 100000L and the train() call
        // below uses 10000L, both as the second StopCriteria argument (presumably a maximum step
        // count; confirm).
        InMemoryEstimator estimator = package$tf$learn$.MODULE$.InMemoryEstimator().apply(
                package$.MODULE$.supervisedTrainableModelToModelFunction(model),
                package$tf$learn$.MODULE$.Configuration().apply(new Some(outputDir), package$tf$learn$.MODULE$.Configuration().apply$default$2(), package$tf$learn$.MODULE$.Configuration().apply$default$3(), package$tf$learn$.MODULE$.Configuration().apply$default$4()),
                package$tf$learn$.MODULE$.StopCriteria().apply(package$tf$learn$.MODULE$.StopCriteria().apply$default$1(), new Some(BoxesRunTime.boxToLong(100000L)), package$tf$learn$.MODULE$.StopCriteria().apply$default$3(), package$tf$learn$.MODULE$.StopCriteria().apply$default$4(), package$tf$learn$.MODULE$.StopCriteria().apply$default$5(), package$tf$learn$.MODULE$.StopCriteria().apply$default$6(), package$tf$learn$.MODULE$.StopCriteria().apply$default$7()),
                Predef$.MODULE$.Set().apply(Predef$.MODULE$.wrapRefArray(new Hook[]{
                        package$tf$learn$.MODULE$.StepRateLogger().apply(false, outputDir, package$tf$learn$.MODULE$.StepHookTrigger().apply(100), package$tf$learn$.MODULE$.StepRateLogger().apply$default$4(), package$tf$learn$.MODULE$.StepRateLogger().apply$default$5()),
                        package$tf$learn$.MODULE$.CheckpointSaver().apply(outputDir, package$tf$learn$.MODULE$.StepHookTrigger().apply(1000), package$tf$learn$.MODULE$.CheckpointSaver().apply$default$3(), package$tf$learn$.MODULE$.CheckpointSaver().apply$default$4())})),
                package$tf$learn$.MODULE$.InMemoryEstimator().apply$default$5(),
                package$tf$learn$.MODULE$.InMemoryEstimator().apply$default$6(),
                package$tf$learn$.MODULE$.InMemoryEstimator().apply$default$7(),
                package$tf$learn$.MODULE$.TensorBoardConfig().apply(outputDir, package$tf$learn$.MODULE$.TensorBoardConfig().apply$default$2(), package$tf$learn$.MODULE$.TensorBoardConfig().apply$default$3(), 1),
                package$tf$learn$.MODULE$.InMemoryEstimator().apply$default$9());
        estimator.train(
                new STL10$$anonfun$main$1(trainData),
                package$tf$learn$.MODULE$.StopCriteria().apply(package$tf$learn$.MODULE$.StopCriteria().apply$default$1(), new Some(BoxesRunTime.boxToLong(10000L)), package$tf$learn$.MODULE$.StopCriteria().apply$default$3(), package$tf$learn$.MODULE$.StopCriteria().apply$default$4(), package$tf$learn$.MODULE$.StopCriteria().apply$default$5(), package$tf$learn$.MODULE$.StopCriteria().apply$default$6(), package$tf$learn$.MODULE$.StopCriteria().apply$default$7()));

        // The accuracy passes only run when INFO logging is enabled, because they are evaluated
        // inside the level checks.
        if (this.logger.underlying().isInfoEnabled()) {
            this.logger.underlying().info("Train accuracy = {}", new Object[]{BoxesRunTime.boxToFloat(accuracy$1(dataset.trainImages(), dataset.trainLabels(), estimator))});
        }
        if (this.logger.underlying().isInfoEnabled()) {
            this.logger.underlying().info("Test accuracy = {}", new Object[]{BoxesRunTime.boxToFloat(accuracy$1(dataset.testImages(), dataset.testLabels(), estimator))});
        }
    }

    /**
     * Evaluates {@code estimator} on {@code images}/{@code labels} in a fixed number of batches.
     * The batches are produced by {@code splitEvenly}.
     *
     * <p>A per-batch closure adds its result into a shared float accumulator. The method returns
     * that total divided by {@code images.shape()(0)}. The closure lives in
     * {@code STL10$$anonfun$accuracy$1$1}, which is not shown here; presumably it counts correct
     * predictions, making the result a fraction in [0, 1] (confirm).
     *
     * @param images    image tensor to evaluate on
     * @param labels    label tensor aligned with {@code images}
     * @param estimator trained estimator used for inference
     * @return accumulated per-batch total divided by the number of examples
     */
    private final float accuracy$1(Tensor images, Tensor labels, InMemoryEstimator estimator) {
        // A single constant drives both the split count and the loop bound, so they cannot diverge.
        // NOTE(review): splitEvenly presumably requires dimension 0 to be divisible by this value.
        final int numBatches = 100;
        Basic.BasicOps imageOps = package$.MODULE$.tensorToBasicOps(images);
        Seq imageBatches = imageOps.splitEvenly(numBatches, imageOps.splitEvenly$default$2());
        Basic.BasicOps labelOps = package$.MODULE$.tensorToBasicOps(labels);
        Seq labelBatches = labelOps.splitEvenly(numBatches, labelOps.splitEvenly$default$2());
        FloatRef total = FloatRef.create(0.0f);
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), numBatches).foreach$mVc$sp(new STL10$$anonfun$accuracy$1$1(estimator, imageBatches, labelBatches, total));
        return total.elem / images.shape().apply(0);
    }

    /** Creates the singleton and its logger. Only called by the {@code MODULE$} initializer. */
    private STL10$() {
        this.logger = Logger$.MODULE$.apply(LoggerFactory.getLogger("Examples / STL10"));
    }
}
