package org.platanios.tensorflow.data.image;

import com.typesafe.scalalogging.Logger;
import com.typesafe.scalalogging.Logger$;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.util.zip.GZIPInputStream;
import org.apache.commons.compress.archivers.tar.TarArchiveEntry;
import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
import org.platanios.tensorflow.api.core.Indexer;
import org.platanios.tensorflow.api.core.IndexerConstructionWithTwoNumbers$;
import org.platanios.tensorflow.api.package$;
import org.platanios.tensorflow.api.package$tfi$;
import org.platanios.tensorflow.api.tensors.Tensor;
import org.platanios.tensorflow.api.tensors.TensorConvertible$;
import org.platanios.tensorflow.api.types.SupportedType$;
import org.platanios.tensorflow.api.types.package$UINT8$;
import org.platanios.tensorflow.data.Loader;
import org.platanios.tensorflow.data.image.CIFARLoader;
import org.slf4j.LoggerFactory;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.Seq$;
import scala.collection.immutable.StringOps;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.ObjectRef;
import scala.util.matching.Regex;

/* compiled from: CIFARLoader.scala */
/* loaded from: input_file:org/platanios/tensorflow/data/image/CIFARLoader$.class */
/**
 * Java form of the Scala {@code object CIFARLoader}.
 *
 * <p>It downloads a CIFAR archive through the {@link Loader} trait. It then extracts the train and
 * test images and labels from that gzipped tar archive into a {@link CIFARDataset}.
 */
public final class CIFARLoader$ implements Loader {
    /** Singleton instance of this Scala {@code object}; assigned by the private constructor. */
    public static CIFARLoader$ MODULE$;

    /** Buffer size, in bytes, returned by the {@code $default$} accessors below. */
    private static final int DEFAULT_BUFFER_SIZE = 8192;

    private final Logger logger;

    // Deliberately not final. It is assigned through the Loader trait's setter method (invoked from
    // the constructor), and Java only lets a final field be assigned directly in a constructor.
    private Regex googleDriveConfirmTokenRegex;

    static {
        new CIFARLoader$();
    }

    // The Loader overrides below forward to the trait's default implementations. They must use
    // `Loader.super.`: an unqualified call dispatches back to the override itself and recurses
    // until StackOverflowError.

    @Override // org.platanios.tensorflow.data.Loader
    public boolean maybeDownload(Path path, String str, int i) {
        return Loader.super.maybeDownload(path, str, i);
    }

    @Override // org.platanios.tensorflow.data.Loader
    public void download(Path path, String str, int i) {
        Loader.super.download(path, str, i);
    }

    @Override // org.platanios.tensorflow.data.Loader
    public int maybeDownload$default$3() {
        return Loader.super.maybeDownload$default$3();
    }

    @Override // org.platanios.tensorflow.data.Loader
    public int download$default$3() {
        return Loader.super.download$default$3();
    }

    @Override // org.platanios.tensorflow.data.Loader
    public Regex googleDriveConfirmTokenRegex() {
        return this.googleDriveConfirmTokenRegex;
    }

    /** Trait field setter generated by Scala; called once from the constructor. */
    @Override // org.platanios.tensorflow.data.Loader
    public void org$platanios$tensorflow$data$Loader$_setter_$googleDriveConfirmTokenRegex_$eq(Regex regex) {
        this.googleDriveConfirmTokenRegex = regex;
    }

    @Override // org.platanios.tensorflow.data.Loader
    public Logger logger() {
        return this.logger;
    }

    /**
     * Downloads the requested CIFAR archive if it is not already present, then extracts it.
     *
     * @param path        directory into which the archive is downloaded and from which it is read
     * @param datasetType CIFAR variant to load; supplies the base URL, archive name and entry names
     * @param i           buffer size in bytes, used for downloading and for reading archive entries
     * @return the dataset populated from the archive's train and test entries
     * @throws IOException              if the archive cannot be opened or read
     * @throws IllegalArgumentException if {@code i} is not positive; a zero-length read buffer
     *                                  would make the entry read loop spin forever
     */
    public CIFARDataset load(Path path, CIFARLoader.DatasetType datasetType, int i) throws IOException {
        if (i <= 0) {
            throw new IllegalArgumentException("Buffer size must be positive, but was " + i + ".");
        }
        String compressedFilename = datasetType.compressedFilename();
        Path archive = path.resolve(compressedFilename);
        maybeDownload(archive, datasetType.url() + compressedFilename, i);
        CIFARDataset dataset = extractFiles(archive, datasetType, i);
        if (logger().underlying().isInfoEnabled()) {
            logger().underlying().info("Finished loading the {} dataset.", datasetType.name());
        }
        return dataset;
    }

    /** Scala default for {@code load}'s {@code datasetType} argument: CIFAR-10. */
    public CIFARLoader.DatasetType load$default$2() {
        return CIFARLoader$CIFAR_10$.MODULE$;
    }

    /** Scala default for {@code load}'s buffer-size argument. */
    public int load$default$3() {
        return DEFAULT_BUFFER_SIZE;
    }

    /**
     * Reads the gzipped tar archive at {@code path} and collects the train and test batches it holds.
     *
     * <p>An entry whose name ends with any of {@code datasetType.trainFilenames()} is a train entry.
     * Train entries are concatenated along axis 0 in archive order. Otherwise, an entry whose name
     * ends with {@code datasetType.testFilename()} sets the test images and labels; a later match
     * replaces an earlier one. All other entries are skipped.
     *
     * @return the dataset; tensor fields stay {@code null} when no matching entry was found
     * @throws IOException if the archive cannot be opened or read
     */
    private CIFARDataset extractFiles(Path path, CIFARLoader.DatasetType datasetType, int i) throws IOException {
        if (logger().underlying().isInfoEnabled()) {
            logger().underlying().info("Extracting data from file '{}'.", path);
        }
        CIFARDataset dataset = new CIFARDataset(datasetType, null, null, null, null);
        // Declare the resources separately. The file handle is then released even if the gzip
        // header is invalid, or if reading an entry fails part-way through.
        try (InputStream fileStream = Files.newInputStream(path);
             TarArchiveInputStream tarStream = new TarArchiveInputStream(new GZIPInputStream(fileStream))) {
            TarArchiveEntry entry;
            // Advance in the loop header, so skipped entries (directories, metadata files, ...)
            // cannot stall the loop.
            while ((entry = tarStream.getNextTarEntry()) != null) {
                String entryName = entry.getName();
                if (isTrainEntry(entryName, datasetType)) {
                    Tuple2<Tensor<package$UINT8$>, Tensor<package$UINT8$>> batch =
                        readImagesAndLabels(tarStream, entry, datasetType, i);
                    Tensor<package$UINT8$> trainImages = dataset.trainImages() == null
                        ? batch._1()
                        : concatenateAlongAxisZero(dataset.trainImages(), batch._1());
                    Tensor<package$UINT8$> trainLabels = dataset.trainLabels() == null
                        ? batch._2()
                        : concatenateAlongAxisZero(dataset.trainLabels(), batch._2());
                    dataset = dataset.copy(
                        dataset.copy$default$1(), trainImages, trainLabels,
                        dataset.copy$default$4(), dataset.copy$default$5());
                } else if (entryName.endsWith(datasetType.testFilename())) {
                    Tuple2<Tensor<package$UINT8$>, Tensor<package$UINT8$>> batch =
                        readImagesAndLabels(tarStream, entry, datasetType, i);
                    dataset = dataset.copy(
                        dataset.copy$default$1(), dataset.copy$default$2(), dataset.copy$default$3(),
                        batch._1(), batch._2());
                }
            }
        }
        return dataset;
    }

    /** Returns whether {@code entryName} ends with any of the dataset's train file names. */
    private static boolean isTrainEntry(String entryName, CIFARLoader.DatasetType datasetType) {
        return datasetType.trainFilenames().exists(suffix -> BoxesRunTime.boxToBoolean(entryName.endsWith(suffix)));
    }

    /** Concatenates two uint8 tensors along axis 0. */
    private static Tensor<package$UINT8$> concatenateAlongAxisZero(Tensor<package$UINT8$> first, Tensor<package$UINT8$> second) {
        return package$tfi$.MODULE$.concatenate(
            Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Tensor[]{first, second})),
            package$.MODULE$.tensorFromTensorConvertible(
                BoxesRunTime.boxToInteger(0),
                TensorConvertible$.MODULE$.fromSupportedType(SupportedType$.MODULE$.intIsSupported())));
    }

    private CIFARLoader.DatasetType extractFiles$default$2() {
        return CIFARLoader$CIFAR_10$.MODULE$;
    }

    private int extractFiles$default$3() {
        return DEFAULT_BUFFER_SIZE;
    }

    /**
     * Reads the rest of the current tar entry and splits each fixed-size record into image and label
     * parts.
     *
     * <p>The entry is viewed as a {@code [size / entryByteSize, entryByteSize]} uint8 matrix, where
     * {@code size} is the entry's declared size. The columns are used as follows:
     * <ul>
     *   <li>CIFAR-10: the label is column 0 and the image bytes are columns 1 onward.</li>
     *   <li>CIFAR-100: the labels are the column slice {@code 0 :: 2} and the image bytes are
     *       columns 2 onward.</li>
     * </ul>
     * The image bytes are reshaped to {@code [-1, 32, 32, 3]}.
     *
     * @param tarArchiveInputStream archive stream positioned at {@code tarArchiveEntry}'s data
     * @param tarArchiveEntry       entry being read; its declared size fixes the number of rows
     * @param datasetType           CIFAR variant, which selects the record layout
     * @param i                     read-buffer size in bytes; must be positive
     * @return {@code (images, labels)}
     * @throws IOException if reading the entry fails
     * @throws MatchError  if {@code datasetType} is neither CIFAR-10 nor CIFAR-100
     */
    private Tuple2<Tensor<package$UINT8$>, Tensor<package$UINT8$>> readImagesAndLabels(
            TarArchiveInputStream tarArchiveInputStream,
            TarArchiveEntry tarArchiveEntry,
            CIFARLoader.DatasetType datasetType,
            int i) throws IOException {
        // TarArchiveInputStream.read reports -1 at the end of the current entry, not of the archive.
        ByteArrayOutputStream collected = new ByteArrayOutputStream();
        byte[] chunk = new byte[i];
        int bytesRead;
        while ((bytesRead = tarArchiveInputStream.read(chunk)) != -1) {
            collected.write(chunk, 0, bytesRead);
        }
        ByteBuffer data = ByteBuffer.wrap(collected.toByteArray()).order(ByteOrder.BIG_ENDIAN);

        int entrySize = (int) tarArchiveEntry.getSize();
        int recordSize = datasetType.entryByteSize();
        Tensor records = package$.MODULE$.Tensor().fromBuffer(
            package$.MODULE$.UINT8(),
            package$.MODULE$.Shape().apply(Predef$.MODULE$.wrapIntArray(new int[]{entrySize / recordSize, recordSize})),
            entrySize,
            data);

        // NOTE(review): the published CIFAR binary format stores each image channel-major. The red
        // 32x32 plane comes first, then green, then blue, each plane row-major. Reshaping directly
        // to [-1, 32, 32, 3] reads those bytes as channel-last. Confirm whether a [-1, 3, 32, 32]
        // reshape followed by a transpose was intended.
        if (CIFARLoader$CIFAR_10$.MODULE$.equals(datasetType)) {
            return new Tuple2<>(
                package$.MODULE$.BasicOps(
                        records.apply(
                            package$.MODULE$.$colon$colon(),
                            Predef$.MODULE$.wrapRefArray(new Indexer[]{package$.MODULE$.intToIndexerConstruction(1).$colon$colon()})))
                    .reshape(package$.MODULE$.tensorFromTensorConvertible(
                        package$.MODULE$.Shape().apply(Predef$.MODULE$.wrapIntArray(new int[]{-1, 32, 32, 3})),
                        TensorConvertible$.MODULE$.fromShape())),
                records.apply(
                    package$.MODULE$.$colon$colon(),
                    Predef$.MODULE$.wrapRefArray(new Indexer[]{package$.MODULE$.intToIndex(0)})));
        }
        if (CIFARLoader$CIFAR_100$.MODULE$.equals(datasetType)) {
            return new Tuple2<>(
                package$.MODULE$.BasicOps(
                        records.apply(
                            package$.MODULE$.$colon$colon(),
                            Predef$.MODULE$.wrapRefArray(new Indexer[]{package$.MODULE$.intToIndexerConstruction(2).$colon$colon()})))
                    .reshape(package$.MODULE$.tensorFromTensorConvertible(
                        package$.MODULE$.Shape().apply(Predef$.MODULE$.wrapIntArray(new int[]{-1, 32, 32, 3})),
                        TensorConvertible$.MODULE$.fromShape())),
                records.apply(
                    package$.MODULE$.$colon$colon(),
                    Predef$.MODULE$.wrapRefArray(new Indexer[]{
                        IndexerConstructionWithTwoNumbers$.MODULE$.indexerConstructionToIndex(
                            package$.MODULE$.intToIndexerConstruction(2).$colon$colon(package$.MODULE$.intToIndexerConstruction(0)))})));
        }
        throw new MatchError(datasetType);
    }

    private CIFARLoader.DatasetType readImagesAndLabels$default$3() {
        return CIFARLoader$CIFAR_10$.MODULE$;
    }

    private int readImagesAndLabels$default$4() {
        return DEFAULT_BUFFER_SIZE;
    }

    /**
     * Returns whether the name of the tar entry held in {@code objectRef} ends with {@code str}.
     * It is kept because it is public. {@code isTrainEntry} performs the same suffix test for this
     * class.
     */
    public static final /* synthetic */ boolean $anonfun$extractFiles$1(ObjectRef objectRef, String str) {
        return ((TarArchiveEntry) objectRef.elem).getName().endsWith(str);
    }

    private CIFARLoader$() {
        // Publish the singleton before running trait initialization, as the Scala object encoding does.
        MODULE$ = this;
        org$platanios$tensorflow$data$Loader$_setter_$googleDriveConfirmTokenRegex_$eq(new StringOps(Predef$.MODULE$.augmentString("<a id=\"uc-download-link\".*href=\"/uc\\?export=download&amp;(confirm=.*)&amp;id=.*\">Download anyway</a>")).r());
        this.logger = Logger$.MODULE$.apply(LoggerFactory.getLogger("CIFAR Data Loader"));
    }
}
