package org.deeplearning4j.base;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.datasets.mnist.MnistManager;
import org.deeplearning4j.util.ArrayUtil;
import org.deeplearning4j.util.MathUtils;
import org.deeplearning4j.util.MatrixUtil;
import org.jblas.DoubleMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/base/DeepLearningTest.class */
/**
 * Base class with helpers that load the Iris, LFW faces and MNIST data sets as
 * {@code Pair<DoubleMatrix, DoubleMatrix>} values (first element: examples,
 * second element: outcomes/labels as produced by the underlying loaders).
 *
 * <p>The MNIST helpers read from fixed locations under {@code /tmp/MNIST} and
 * invoke {@code MnistFetcher.downloadAndUntar()} when the directory or either
 * of the two training files is missing.
 */
public abstract class DeepLearningTest {
    private static final Logger log = LoggerFactory.getLogger(DeepLearningTest.class);

    /** Directory the MNIST files are expected in. */
    private static final String MNIST_DIR = "/tmp/MNIST";
    /** Path of the MNIST training images file handed to {@code MnistManager}. */
    private static final String MNIST_IMAGES = "/tmp/MNIST/train-images-idx1-ubyte";
    /** Path of the MNIST training labels file handed to {@code MnistManager}. */
    private static final String MNIST_LABELS = "/tmp/MNIST/train-labels-idx1-ubyte";
    /** Number of outcome classes used when encoding MNIST labels. */
    private static final int MNIST_OUTCOMES = 10;

    /**
     * Loads the Iris data set by delegating to {@code IrisUtils.loadIris()}.
     *
     * @return the pair returned by {@code IrisUtils.loadIris()}
     * @throws IOException if the underlying loader fails
     */
    public static Pair<DoubleMatrix, DoubleMatrix> getIris() throws IOException {
        return IrisUtils.loadIris();
    }

    /**
     * Loads the Iris data set by delegating to {@code IrisUtils.loadIris(int)}.
     *
     * @param i forwarded unchanged to {@code IrisUtils.loadIris(int)}
     *          (presumably the number of examples -- confirm against IrisUtils)
     * @return the pair returned by {@code IrisUtils.loadIris(int)}
     * @throws IOException if the underlying loader fails
     */
    public static Pair<DoubleMatrix, DoubleMatrix> getIris(int i) throws IOException {
        return IrisUtils.loadIris(i);
    }

    /**
     * Returns {@code LFWLoader.getAllImagesAsMatrix(int)} from a freshly prepared loader.
     *
     * @param i forwarded unchanged to {@code getAllImagesAsMatrix(int)}
     * @return the pair produced by the loader
     * @throws Exception if preparing the loader or reading the images fails
     */
    public static Pair<DoubleMatrix, DoubleMatrix> getFaces(int i) throws Exception {
        return preparedFacesLoader().getAllImagesAsMatrix(i);
    }

    /**
     * Returns {@code LFWLoader.getAllImagesAsMatrix()} from a freshly prepared loader.
     *
     * @return the pair produced by the loader
     * @throws Exception if preparing the loader or reading the images fails
     */
    public static Pair<DoubleMatrix, DoubleMatrix> getFacesMatrix() throws Exception {
        return preparedFacesLoader().getAllImagesAsMatrix();
    }

    /**
     * Returns {@code LFWLoader.getFirst(int)} from a freshly prepared loader.
     *
     * @param i forwarded unchanged to {@code getFirst(int)}
     * @return the list produced by the loader
     * @throws Exception if preparing the loader or reading the images fails
     */
    public static List<Pair<DoubleMatrix, DoubleMatrix>> getFirstFaces(int i) throws Exception {
        return preparedFacesLoader().getFirst(i);
    }

    /**
     * Returns {@code LFWLoader.getImagesAsList()} from a freshly prepared loader.
     *
     * @return the list produced by the loader
     * @throws Exception if preparing the loader or reading the images fails
     */
    public List<Pair<DoubleMatrix, DoubleMatrix>> getFaces() throws Exception {
        return preparedFacesLoader().getImagesAsList();
    }

    /**
     * Reads a single MNIST example.
     *
     * @param i record index passed to {@code MnistManager.setCurrent(int)}
     * @return the flattened image converted with {@code MatrixUtil.toMatrix} and
     *         transposed, paired with the label encoded by
     *         {@code MatrixUtil.toOutcomeVector(label, 10)}
     * @throws IOException if downloading or reading the MNIST files fails
     */
    public static Pair<DoubleMatrix, DoubleMatrix> getMnistExample(int i) throws IOException {
        MnistManager mnist = openMnist();
        mnist.setCurrent(i);
        // Image is read before the label, matching the original call order.
        DoubleMatrix image = MatrixUtil.toMatrix(ArrayUtil.flatten(mnist.readImage())).transpose();
        DoubleMatrix outcome = MatrixUtil.toOutcomeVector(mnist.readLabel(), MNIST_OUTCOMES);
        return new Pair<>(image, outcome);
    }

    /**
     * Reads several MNIST batches of raw (un-normalized) pixel values.
     *
     * <p>Batch {@code b} (0-based) contains records {@code b + 1} through
     * {@code b + i}. NOTE(review): consecutive batches therefore overlap in all
     * but one record; this mirrors the existing behaviour and may be unintended
     * -- confirm before relying on disjoint batches.
     *
     * @param i  number of examples per batch
     * @param i2 number of batches to build
     * @return one pair per batch: an {@code i}-row example matrix and the matrix
     *         built from the {@code ArrayUtil.toOutcomeArray(label, 10)} rows
     * @throws IOException if downloading or reading the MNIST files fails
     */
    public List<Pair<DoubleMatrix, DoubleMatrix>> getMnistExampleBatches(int i, int i2) throws IOException {
        List<Pair<DoubleMatrix, DoubleMatrix>> batches = new ArrayList<>();
        MnistManager mnist = openMnist();
        // One image is read up front only to learn the flattened row width.
        int rowWidth = ArrayUtil.flatten(mnist.readImage()).length;
        for (int batch = 0; batch < i2; batch++) {
            double[][] examples = new double[i][rowWidth];
            int[][] outcomes = new int[i][MNIST_OUTCOMES];
            for (int row = 0; row < i; row++) {
                // Record indices handed to setCurrent start at 1, shifted by the batch number.
                mnist.setCurrent(row + 1 + batch);
                examples[row] = ArrayUtil.flatten(ArrayUtil.toDouble(mnist.readImage()));
                outcomes[row] = ArrayUtil.toOutcomeArray(mnist.readLabel(), MNIST_OUTCOMES);
            }
            DoubleMatrix exampleMatrix = new DoubleMatrix(examples);
            DoubleMatrix outcomeMatrix = MatrixUtil.toMatrix(outcomes);
            batches.add(new Pair<>(exampleMatrix, outcomeMatrix));
        }
        return batches;
    }

    /**
     * Reads records {@code 1..i} of MNIST as a single batch, passing every pixel
     * through {@code MathUtils.normalize(value, 0, 255)}.
     *
     * @param i number of examples to read
     * @return an {@code i}-row example matrix paired with the matrix built from
     *         the {@code ArrayUtil.toOutcomeArray(label, 10)} rows
     * @throws IOException if downloading or reading the MNIST files fails
     */
    public static Pair<DoubleMatrix, DoubleMatrix> getMnistExampleBatch(int i) throws IOException {
        MnistManager mnist = openMnist();
        // One image is read up front only to learn the flattened row width.
        int rowWidth = ArrayUtil.flatten(mnist.readImage()).length;
        double[][] examples = new double[i][rowWidth];
        int[][] outcomes = new int[i][MNIST_OUTCOMES];
        for (int row = 0; row < i; row++) {
            // Record indices handed to setCurrent start at 1.
            mnist.setCurrent(row + 1);
            double[] pixels = ArrayUtil.flatten(ArrayUtil.toDouble(mnist.readImage()));
            for (int p = 0; p < pixels.length; p++) {
                pixels[p] = MathUtils.normalize(pixels[p], 0.0d, 255.0d);
            }
            examples[row] = pixels;
            outcomes[row] = ArrayUtil.toOutcomeArray(mnist.readLabel(), MNIST_OUTCOMES);
        }
        return new Pair<>(new DoubleMatrix(examples), MatrixUtil.toMatrix(outcomes));
    }

    /** Creates an {@code LFWLoader} and calls {@code getIfNotExists()} on it before handing it out. */
    private static LFWLoader preparedFacesLoader() throws Exception {
        LFWLoader loader = new LFWLoader();
        loader.getIfNotExists();
        return loader;
    }

    /**
     * Opens a {@code MnistManager} on the training files, first running
     * {@code MnistFetcher.downloadAndUntar()} if the directory or either file is
     * missing. Checking the files as well as the directory keeps a partially
     * populated {@code /tmp/MNIST} from being treated as a complete download.
     */
    private static MnistManager openMnist() throws IOException {
        boolean complete = new File(MNIST_DIR).exists()
                && new File(MNIST_IMAGES).exists()
                && new File(MNIST_LABELS).exists();
        if (!complete) {
            new MnistFetcher().downloadAndUntar();
        }
        return new MnistManager(MNIST_IMAGES, MNIST_LABELS);
    }
}
