package org.deeplearning4j.datasets;

import com.google.common.collect.Lists;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.berkeley.Counter;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.datasets.fetchers.MnistDataFetcher;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator;
import org.deeplearning4j.nn.Persistable;
import org.deeplearning4j.util.MathUtils;
import org.jblas.DoubleMatrix;
import org.jblas.SimpleBlas;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/datasets/DataSet.class */
/**
 * A labelled data set held in two jblas matrices.
 *
 * <p>{@link #getFirst()} holds the input features, one example per row. {@link #getSecond()} holds
 * the outcomes, one row per example and one column per outcome. The constructor enforces that both
 * matrices have the same number of rows.
 *
 * <p>Iterating over a {@code DataSet} yields one single-row {@code DataSet} per example.
 */
public class DataSet extends Pair<DoubleMatrix, DoubleMatrix> implements Persistable, Iterable<DataSet> {
    private static final long serialVersionUID = 1935520764586513365L;
    private static final Logger log = LoggerFactory.getLogger(DataSet.class);

    /** Creates a placeholder data set whose feature and outcome matrices are {@code zeros(1)}. */
    public DataSet() {
        this(DoubleMatrix.zeros(1), DoubleMatrix.zeros(1));
    }

    /**
     * Creates a data set from a (features, outcomes) pair. The matrices are not copied.
     *
     * @param pair feature matrix (first) and outcome matrix (second)
     * @throws IllegalStateException if the two matrices have different row counts
     */
    public DataSet(Pair<DoubleMatrix, DoubleMatrix> pair) {
        this(pair.getFirst(), pair.getSecond());
    }

    /**
     * Creates a data set from a feature matrix and an outcome matrix. The matrices are not copied.
     *
     * @param first  feature matrix, one example per row
     * @param second outcome matrix, one row per example
     * @throws IllegalStateException if the two matrices have different row counts
     */
    public DataSet(DoubleMatrix first, DoubleMatrix second) {
        super(first, second);
        if (first.rows != second.rows) {
            throw new IllegalStateException("Invalid data set; first and second do not have equal rows. First was " + first.rows + " second was " + second.rows);
        }
    }

    /**
     * Returns an iterator over mini-batches of this data set, taken in row order.
     *
     * @param batchSize number of examples per batch (the last batch may be smaller)
     * @return an iterator over the merged batches produced by {@link #dataSetBatches(int)}
     */
    public DataSetIterator iterator(int batchSize) {
        return new ListDataSetIterator(dataSetBatches(batchSize));
    }

    /**
     * Returns a deep copy of this data set.
     *
     * <p>Both matrices are duplicated, so mutating the copy (for example via {@link #addRow} or
     * {@link #sortByLabel}) does not affect this instance.
     */
    public DataSet copy() {
        return new DataSet(getFirst().dup(), getSecond().dup());
    }

    /** Returns a new placeholder data set whose feature and outcome matrices are {@code zeros(1)}. */
    public static DataSet empty() {
        return new DataSet(DoubleMatrix.zeros(1), DoubleMatrix.zeros(1));
    }

    /**
     * Concatenates the rows of several data sets into one, preserving list and row order.
     *
     * <p>The column counts are taken from the first data set. Every other element must have the
     * same feature and outcome widths.
     *
     * @param list the data sets to merge; must not be empty
     * @return a new data set containing every example of every element of {@code list}
     * @throws IllegalArgumentException if {@code list} is empty
     */
    public static DataSet merge(List<DataSet> list) {
        if (list.isEmpty()) {
            throw new IllegalArgumentException("Unable to merge empty dataset");
        }
        DataSet head = list.get(0);
        int total = totalExamples(list);
        DoubleMatrix features = new DoubleMatrix(total, head.getFirst().columns);
        DoubleMatrix outcomes = new DoubleMatrix(total, head.getSecond().columns);
        int target = 0;
        for (DataSet part : list) {
            for (int row = 0; row < part.numExamples(); row++) {
                features.putRow(target, part.getFirst().getRow(row));
                outcomes.putRow(target, part.getSecond().getRow(row));
                target++;
            }
        }
        return new DataSet(features, outcomes);
    }

    /** Sums {@link #numExamples()} over every data set in {@code collection}. */
    private static int totalExamples(Collection<DataSet> collection) {
        int total = 0;
        for (DataSet dataSet : collection) {
            total += dataSet.numExamples();
        }
        return total;
    }

    /** Returns the number of feature columns. */
    public int numInputs() {
        return getFirst().columns;
    }

    /**
     * Checks that the feature and outcome matrices have the same number of rows.
     *
     * @throws IllegalStateException if the row counts differ
     */
    public void validate() {
        if (getFirst().rows != getSecond().rows) {
            throw new IllegalStateException("Invalid dataset");
        }
    }

    /**
     * Returns the outcome of a single-example data set.
     *
     * <p>The outcome is the column index of the largest-magnitude entry in the outcome row, as
     * computed by {@link SimpleBlas#iamax}.
     *
     * @throws IllegalStateException if this data set has more than one example
     */
    public int outcome() {
        if (numExamples() > 1) {
            throw new IllegalStateException("Unable to derive outcome for dataset greater than one row");
        }
        return SimpleBlas.iamax(getSecond());
    }

    /**
     * Returns example {@code i} as a new single-row data set.
     *
     * <p>The rows come from {@link DoubleMatrix#getRow}.
     */
    public DataSet get(int i) {
        return new DataSet(getFirst().getRow(i), getSecond().getRow(i));
    }

    /**
     * Splits the examples into consecutive groups of {@code batchSize}. The last group may be
     * smaller.
     */
    public List<List<DataSet>> batchBy(int batchSize) {
        return Lists.partition(asList(), batchSize);
    }

    /** Counts how many examples have each {@link #outcome()} value. */
    public Counter<Integer> outcomeCounts() {
        Counter<Integer> counter = new Counter<>();
        for (DataSet example : asList()) {
            counter.incrementCount(Integer.valueOf(example.outcome()), 1.0d);
        }
        return counter;
    }

    /**
     * Splits the examples into consecutive batches of {@code batchSize} and merges each batch into a
     * single data set.
     *
     * @param batchSize number of examples per batch (the last batch may be smaller)
     */
    public List<DataSet> dataSetBatches(int batchSize) {
        List<List<DataSet>> partitions = Lists.partition(asList(), batchSize);
        List<DataSet> batches = new ArrayList<>(partitions.size());
        for (List<DataSet> partition : partitions) {
            batches.add(merge(partition));
        }
        return batches;
    }

    /**
     * Reorders the rows in place with {@link #sortByLabel()}, then groups the examples into
     * batches of {@link #numOutcomes()}.
     */
    public List<List<DataSet>> sortAndBatchByNumLabels() {
        sortByLabel();
        return Lists.partition(asList(), numOutcomes());
    }

    /**
     * Groups the examples, in their current order, into consecutive batches of
     * {@link #numOutcomes()}.
     */
    public List<List<DataSet>> batchByNumLabels() {
        return Lists.partition(asList(), numOutcomes());
    }

    /** Returns a new mutable list holding one single-row data set per example, in row order. */
    public List<DataSet> asList() {
        int count = numExamples();
        List<DataSet> examples = new ArrayList<>(count);
        for (int row = 0; row < count; row++) {
            examples.add(new DataSet(getFirst().getRow(row), getSecond().getRow(row)));
        }
        return examples;
    }

    /**
     * Randomly partitions the examples into two data sets.
     *
     * <p>The first element of the returned pair holds {@code splitSize} randomly chosen examples.
     * The second element holds the remaining examples.
     *
     * @param splitSize number of examples in the first part; must be in {@code [1, numExamples())}
     * @throws IllegalArgumentException if {@code splitSize} is out of range
     */
    public Pair<DataSet, DataSet> splitTestAndTrain(int splitSize) {
        if (splitSize >= numExamples()) {
            throw new IllegalArgumentException("Unable to split on size larger than the number of rows");
        }
        if (splitSize <= 0) {
            throw new IllegalArgumentException("Split size must be positive but was " + splitSize);
        }
        List<DataSet> examples = asList();
        Collections.shuffle(examples);
        return new Pair<>(merge(examples.subList(0, splitSize)), merge(examples.subList(splitSize, examples.size())));
    }

    /**
     * Reorders the rows of this data set in place so that examples are interleaved by label.
     *
     * <p>The label of an example is the column index of the largest-magnitude entry in its outcome
     * row. The rows are filled as follows:
     * <ul>
     *   <li>While labels {@code 0 .. numOutcomes()-1} all have examples left, rows are filled
     *       round-robin with one example of each label in turn.</li>
     *   <li>Once any label runs out, the remaining rows are filled by draining whichever per-label
     *       queues still hold examples.</li>
     * </ul>
     * Every example is written exactly once.
     */
    public void sortByLabel() {
        // asList() copies every row, so overwriting the backing matrices below is safe.
        Map<Integer, Queue<DataSet>> byLabel = new HashMap<>();
        for (DataSet example : asList()) {
            int label = getLabel(example);
            Queue<DataSet> queue = byLabel.get(label);
            if (queue == null) {
                queue = new ArrayDeque<>();
                byLabel.put(label, queue);
            }
            queue.add(example);
        }
        for (Map.Entry<Integer, Queue<DataSet>> entry : byLabel.entrySet()) {
            log.info("Label " + entry.getKey() + " has " + entry.getValue().size() + " elements");
        }

        int numLabels = numOutcomes();
        int numExamples = numExamples();
        int row = 0;
        // With zero outcome columns there is nothing to interleave; go straight to draining.
        boolean roundRobin = numLabels > 0;
        while (row < numExamples) {
            if (roundRobin) {
                for (int label = 0; label < numLabels && row < numExamples; label++) {
                    Queue<DataSet> queue = byLabel.get(label);
                    DataSet next = queue == null ? null : queue.poll();
                    if (next == null) {
                        // This label is exhausted (or never occurred): stop interleaving.
                        roundRobin = false;
                        break;
                    }
                    addRow(next, row++);
                }
            } else {
                // The queues hold exactly numExamples - row examples, so one is always available.
                addRow(pollFirstNonEmpty(byLabel.values()), row++);
            }
        }
    }

    /** Removes and returns the head of the first non-empty queue, or {@code null} if all are empty. */
    private static DataSet pollFirstNonEmpty(Collection<Queue<DataSet>> queues) {
        for (Queue<DataSet> queue : queues) {
            if (!queue.isEmpty()) {
                return queue.poll();
            }
        }
        return null;
    }

    /**
     * Overwrites row {@code i} of this data set with the (single-row) example {@code dataSet}.
     *
     * @param dataSet the example whose feature and outcome rows are copied in; must not be null
     * @param i       target row, in {@code [0, numExamples())}
     * @throws IllegalArgumentException if {@code dataSet} is null or {@code i} is out of range
     */
    public void addRow(DataSet dataSet, int i) {
        if (dataSet == null) {
            throw new IllegalArgumentException("Unable to add a null row");
        }
        if (i < 0 || i >= numExamples()) {
            throw new IllegalArgumentException("Invalid index for adding a row");
        }
        getFirst().putRow(i, dataSet.getFirst());
        getSecond().putRow(i, dataSet.getSecond());
    }

    /** Returns the column index of the largest-magnitude outcome entry of {@code dataSet}. */
    private int getLabel(DataSet dataSet) {
        return SimpleBlas.iamax(dataSet.getSecond());
    }

    /** Returns the per-column sums of the feature matrix. */
    public DoubleMatrix exampleSums() {
        return getFirst().columnSums();
    }

    /** Returns the per-column maxima of the feature matrix. */
    public DoubleMatrix exampleMaxs() {
        return getFirst().columnMaxs();
    }

    /** Returns the per-column means of the feature matrix. */
    public DoubleMatrix exampleMeans() {
        return getFirst().columnMeans();
    }

    /**
     * Writes this data set to {@code file}, replacing any existing content.
     *
     * <p>In binary mode both matrices are written with jblas {@code DoubleMatrix.out}, in a form
     * {@link #load(File)} can read back. Otherwise one line is written per example: the feature row
     * and the outcome row, formatted with {@code %.3f} and separated by a tab.
     *
     * @param file   destination file; created or truncated
     * @param binary {@code true} for the binary format, {@code false} for the text format
     * @throws IOException if the file cannot be written
     */
    public void saveTo(File file, boolean binary) throws IOException {
        // FileOutputStream creates or truncates the file, so no delete/create step is needed.
        if (binary) {
            try (DataOutputStream out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(file)))) {
                getFirst().out(out);
                getSecond().out(out);
                // Flush explicitly: on some JDKs close() swallows flush failures.
                out.flush();
            }
            return;
        }
        byte[] tab = "\t".getBytes(StandardCharsets.UTF_8);
        byte[] newline = "\n".getBytes(StandardCharsets.UTF_8);
        try (OutputStream out = new BufferedOutputStream(new FileOutputStream(file))) {
            for (int row = 0; row < numExamples(); row++) {
                out.write(getFirst().getRow(row).toString("%.3f", "[", "]", ", ", ";").getBytes(StandardCharsets.UTF_8));
                out.write(tab);
                out.write(getSecond().getRow(row).toString("%.3f", "[", "]", ", ", ";").getBytes(StandardCharsets.UTF_8));
                out.write(newline);
            }
            out.flush();
        }
    }

    /**
     * Reads a data set written by {@link #saveTo(File, boolean)} in binary mode.
     *
     * @param file source file
     * @return the decoded data set
     * @throws IOException if the file cannot be read or decoded
     */
    public static DataSet load(File file) throws IOException {
        try (DataInputStream in = new DataInputStream(new BufferedInputStream(new FileInputStream(file)))) {
            DoubleMatrix features = new DoubleMatrix(1, 1);
            DoubleMatrix outcomes = new DoubleMatrix(1, 1);
            features.in(in);
            outcomes.in(in);
            return new DataSet(features, outcomes);
        }
    }

    /**
     * Draws {@code numSamples} distinct examples at random, using a time-seeded Mersenne Twister.
     *
     * @see #sample(int, RandomGenerator, boolean)
     */
    public DataSet sample(int numSamples) {
        RandomGenerator rng = new MersenneTwister(System.currentTimeMillis());
        return sample(numSamples, rng);
    }

    /**
     * Draws {@code numSamples} distinct examples at random using {@code rng}.
     *
     * @see #sample(int, RandomGenerator, boolean)
     */
    public DataSet sample(int numSamples, RandomGenerator rng) {
        return sample(numSamples, rng, false);
    }

    /**
     * Draws {@code numSamples} examples at random, using a time-seeded Mersenne Twister.
     *
     * @see #sample(int, RandomGenerator, boolean)
     */
    public DataSet sample(int numSamples, boolean withReplacement) {
        return sample(numSamples, new MersenneTwister(System.currentTimeMillis()), withReplacement);
    }

    /**
     * Returns a new data set of {@code numSamples} randomly drawn examples.
     *
     * <p>If {@code numSamples} is at least {@link #numExamples()}, this instance itself is returned.
     *
     * @param numSamples      number of examples to draw
     * @param rng             source of randomness
     * @param withReplacement if {@code false}, each example is drawn at most once; if {@code true},
     *                        the same example may be drawn repeatedly
     */
    public DataSet sample(int numSamples, RandomGenerator rng, boolean withReplacement) {
        int total = numExamples();
        if (numSamples >= total) {
            return this;
        }
        DoubleMatrix features = new DoubleMatrix(numSamples, getFirst().columns);
        DoubleMatrix outcomes = new DoubleMatrix(numSamples, numOutcomes());
        Set<Integer> drawn = new HashSet<>();
        for (int row = 0; row < numSamples; row++) {
            int index = rng.nextInt(total);
            if (!withReplacement) {
                // Set.add returns false for an index already drawn; redraw until it is new.
                // This terminates because numSamples < total.
                while (!drawn.add(index)) {
                    index = rng.nextInt(total);
                }
            }
            features.putRow(row, getFirst().getRow(index));
            outcomes.putRow(row, getSecond().getRow(index));
        }
        return new DataSet(features, outcomes);
    }

    /**
     * Rounds every feature value in place with {@code MathUtils.roundDouble(value, roundTo)}.
     *
     * @param roundTo second argument to {@code MathUtils.roundDouble} (presumably the number of
     *                decimal places; confirm against MathUtils)
     */
    public void roundToTheNearest(int roundTo) {
        DoubleMatrix features = getFirst();
        for (int k = 0; k < features.length; k++) {
            features.put(k, MathUtils.roundDouble(features.get(k), roundTo));
        }
    }

    /** Returns the number of outcome columns. */
    public int numOutcomes() {
        return getSecond().columns;
    }

    /** Returns the number of examples (rows of the feature matrix). */
    public int numExamples() {
        return getFirst().rows;
    }

    /**
     * Renders both matrices, one matrix row per line, under INPUT and OUTPUT headers.
     */
    @Override
    public String toString() {
        return "===========INPUT===================\n"
                + getFirst().toString().replace(';', '\n')
                + "\n=================OUTPUT==================\n"
                + getSecond().toString().replace(';', '\n');
    }

    /**
     * Fetches 100 MNIST examples and writes them in the text format to the file named by
     * {@code args[0]}.
     */
    public static void main(String[] args) throws IOException {
        if (args.length < 1) {
            throw new IllegalArgumentException("Usage: DataSet <output file>");
        }
        MnistDataFetcher fetcher = new MnistDataFetcher();
        fetcher.fetch(100);
        new DataSet(fetcher.next()).saveTo(new File(args[0]), false);
    }

    /**
     * Writes both matrices to {@code outputStream} in jblas binary form.
     *
     * <p>The stream is flushed but not closed; the caller owns it.
     *
     * @throws RuntimeException wrapping any {@link IOException}
     */
    @Override
    public void write(OutputStream outputStream) {
        DataOutputStream out = new DataOutputStream(outputStream);
        try {
            getFirst().out(out);
            getSecond().out(out);
            out.flush();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Replaces the contents of both matrices, in place, with data read from {@code inputStream}.
     *
     * <p>The stream is read in the format produced by {@link #write(OutputStream)}. It is not
     * closed; the caller owns it.
     *
     * @throws RuntimeException wrapping any {@link IOException}
     */
    @Override
    public void load(InputStream inputStream) {
        DataInputStream in = new DataInputStream(inputStream);
        try {
            getFirst().in(in);
            getSecond().in(in);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Iterates over the examples as single-row data sets (see {@link #asList()}). */
    @Override
    public Iterator<DataSet> iterator() {
        return asList().iterator();
    }
}
