package org.deeplearning4j.datasets.iterator;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/datasets/iterator/DataSetIteratorSplitter.class */
/**
 * Splits a single, resettable {@link DataSetIterator} into consecutive parts (for example
 * train/test) by counting how many times {@code next()} has been called on the shared
 * underlying iterator.
 *
 * <p>All views created by one splitter share {@link #counter} and {@link #resetPending}.
 * Because parts are defined purely by position, the underlying iterator must return batches
 * in the same order on every pass. A warning is logged on construction, and the deprecated
 * train view throws if the first batch changes between passes.
 */
public class DataSetIteratorSplitter {
    private static final Logger log = LoggerFactory.getLogger(DataSetIteratorSplitter.class);

    /** Tolerance used when checking that the first training batch is identical across passes. */
    private static final double FIRST_BATCH_EPS = 1.0E-5d;

    /** The underlying iterator that all parts read from. */
    protected DataSetIterator backedIterator;
    /** Total size, in {@code next()} calls (batches), as supplied or summed from the splits. */
    protected final long totalExamples;
    /** Train fraction for the two-way constructor; 0.0 otherwise. */
    protected final double ratio;
    /** Per-part fractions for the ratio-array constructor; {@code null} otherwise. */
    protected final double[] ratios;
    /** Number of batches in the train part (two-way constructor only; 0 otherwise). */
    protected final long numTrain;
    /** Number of batches in the test part (two-way constructor only; 0 otherwise). */
    protected final long numTest;
    /** Number of parts this splitter produces. */
    protected final long numArbitrarySets;
    /** Batches per part for the multi-part constructors; {@code null} for the two-way constructor. */
    protected final int[] splits;
    /** Shared count of {@code next()} calls made through the views. */
    protected AtomicLong counter = new AtomicLong(0);
    /** Set by a view's {@code reset()}; the deprecated train view applies it lazily in {@code hasNext()}. */
    protected AtomicBoolean resetPending = new AtomicBoolean(false);
    /** Detached copy of the first training batch, used to detect reordering between passes. */
    protected DataSet firstTrain = null;
    protected int partNumber = 0;

    /**
     * Creates a two-way (train/test) splitter.
     *
     * @param baseIterator underlying iterator; must support reset
     * @param totalBatches total number of batches the underlying iterator yields; must be {@code >= 0}
     * @param ratio fraction of batches assigned to the train part, strictly between 0 and 1
     * @throws NullPointerException if {@code baseIterator} is null
     * @throws ND4JIllegalStateException if the ratio or total is out of range, or reset is unsupported
     */
    public DataSetIteratorSplitter(@NonNull DataSetIterator baseIterator, long totalBatches, double ratio) {
        if (baseIterator == null) {
            throw new NullPointerException("baseIterator is marked @NonNull but is null");
        }
        checkRatio(ratio);
        if (totalBatches < 0) {
            throw new ND4JIllegalStateException("totalExamples number should be positive value");
        }
        requireResettable(baseIterator);
        this.backedIterator = baseIterator;
        this.totalExamples = totalBatches;
        this.ratio = ratio;
        this.ratios = null;
        this.numTrain = (long) (totalBatches * ratio);
        this.numTest = totalBatches - this.numTrain;
        this.numArbitrarySets = 2L;
        this.splits = null;
        log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!");
    }

    /**
     * Creates a multi-part splitter whose part sizes are fractions of {@code totalBatches}.
     *
     * <p>Each part gets {@code (int) (totalBatches * ratios[i])} batches. The fractions are not
     * required to sum to 1.
     *
     * @param baseIterator underlying iterator; must support reset
     * @param totalBatches total number of batches the underlying iterator yields; must be {@code >= 0}
     * @param ratios per-part fractions, each strictly between 0 and 1; the array is copied
     * @throws NullPointerException if {@code baseIterator} or {@code ratios} is null
     * @throws ND4JIllegalStateException if any ratio or the total is out of range, or reset is unsupported
     */
    public DataSetIteratorSplitter(@NonNull DataSetIterator baseIterator, long totalBatches, double[] ratios) {
        if (baseIterator == null) {
            throw new NullPointerException("baseIterator is marked @NonNull but is null");
        }
        if (ratios == null) {
            throw new NullPointerException("ratios must not be null");
        }
        for (double r : ratios) {
            checkRatio(r);
        }
        if (totalBatches < 0) {
            throw new ND4JIllegalStateException("totalExamples number should be positive value");
        }
        requireResettable(baseIterator);
        this.backedIterator = baseIterator;
        this.totalExamples = totalBatches;
        this.ratio = 0.0d;
        // Defensive copy: later mutation by the caller must not change the split layout.
        this.ratios = ratios.clone();
        this.numTrain = 0L;
        this.numTest = 0L;
        this.numArbitrarySets = this.ratios.length;
        this.splits = new int[this.ratios.length];
        for (int i = 0; i < this.splits.length; i++) {
            this.splits[i] = (int) (totalBatches * this.ratios[i]);
        }
        log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!");
    }

    /**
     * Creates a multi-part splitter with explicit part sizes.
     *
     * @param baseIterator underlying iterator; must support reset
     * @param splitSizes number of batches in each part, each {@code >= 0}; the array is copied
     * @throws NullPointerException if {@code baseIterator} or {@code splitSizes} is null
     * @throws ND4JIllegalStateException if any size is negative, or reset is unsupported
     */
    public DataSetIteratorSplitter(@NonNull DataSetIterator baseIterator, int[] splitSizes) {
        if (baseIterator == null) {
            throw new NullPointerException("baseIterator is marked @NonNull but is null");
        }
        if (splitSizes == null) {
            throw new NullPointerException("splits must not be null");
        }
        // Sum in a long so large part sizes cannot overflow into a bogus total.
        long total = 0L;
        for (int size : splitSizes) {
            if (size < 0) {
                throw new ND4JIllegalStateException("Split sizes should be non-negative values, got " + size);
            }
            total += size;
        }
        requireResettable(baseIterator);
        this.backedIterator = baseIterator;
        this.totalExamples = total;
        this.ratio = 0.0d;
        this.ratios = null;
        this.numTrain = 0L;
        this.numTest = 0L;
        this.splits = splitSizes.clone();
        this.numArbitrarySets = this.splits.length;
        log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!");
    }

    /**
     * Rejects ratios outside the open interval (0, 1).
     *
     * <p>The test is written as a negated range check so that NaN is rejected too.
     */
    private static void checkRatio(double r) {
        if (!(r > 0.0d && r < 1.0d)) {
            throw new ND4JIllegalStateException("Ratio value should be in range of 0.0 > X < 1.0");
        }
    }

    /** Throws if {@code it} cannot be reset; runtime splitting rewinds the iterator between passes. */
    private static void requireResettable(DataSetIterator it) {
        if (!it.resetSupported()) {
            throw new ND4JIllegalStateException("Underlying iterator doesn't support reset, so it can't be used for runtime-split");
        }
    }

    /**
     * Returns one iterator per part, in order.
     *
     * <p>Part {@code i} covers the batch positions {@code [offset_i, offset_i + size_i)}, where
     * {@code offset_i} is the sum of the sizes of the preceding parts. For the two-way constructor
     * the sizes are {@code [numTrain, numTest]}.
     *
     * @return a new mutable list of {@link #numArbitrarySets} iterators
     * @throws ArithmeticException if a two-way part size does not fit in an {@code int}
     */
    public List<DataSetIterator> getIterators() {
        // The two-way constructor leaves splits null; derive the sizes from numTrain/numTest
        // instead of failing with a NullPointerException.
        int[] partSizes = (this.splits != null)
                ? this.splits
                : new int[] {Math.toIntExact(this.numTrain), Math.toIntExact(this.numTest)};
        List<DataSetIterator> iterators = new ArrayList<>(partSizes.length);
        int offset = 0;
        for (int part = 0; part < partSizes.length; part++) {
            int[] range = {offset, partSizes[part]};
            offset += partSizes[part];
            iterators.add(new ScrollableDataSetIterator(part, this.backedIterator, this.counter,
                    this.resetPending, this.firstTrain, range));
        }
        return iterators;
    }

    /**
     * Returns a view over the first {@link #numTrain} batches of the underlying iterator.
     *
     * <p>A pending reset is applied in {@code hasNext()}: the underlying iterator is reset and the
     * shared counter is zeroed. On every pass the first batch is compared with the one seen on the
     * first pass, and an {@link ND4JIllegalStateException} is thrown if they differ.
     *
     * @deprecated prefer {@link #getIterators()}
     */
    @Deprecated
    public DataSetIterator getTrainIterator() {
        return new SplitView() {
            @Override
            public boolean hasNext() {
                if (resetPending.get()) {
                    if (!resetSupported()) {
                        throw new UnsupportedOperationException("Reset isn't supported by underlying iterator");
                    }
                    backedIterator.reset();
                    counter.set(0L);
                    resetPending.set(false);
                }
                return backedIterator.hasNext() && counter.get() < numTrain;
            }

            @Override
            public DataSet next() {
                DataSet batch = backedIterator.next();
                // Count only after a successful fetch, and reuse the returned value
                // rather than re-reading the shared counter.
                long position = counter.incrementAndGet();
                if (position == 1) {
                    if (firstTrain == null) {
                        firstTrain = batch.copy();
                        firstTrain.detach();
                    } else if (!batch.getFeatures().equalsWithEps(firstTrain.getFeatures(), FIRST_BATCH_EPS)) {
                        throw new ND4JIllegalStateException("First examples do not match. Randomization was used?");
                    }
                }
                return batch;
            }
        };
    }

    /**
     * Returns a view that continues from the shared counter up to {@code numTrain + numTest} batches.
     *
     * <p>This view does not apply pending resets itself. Its {@code reset()} only sets
     * {@link #resetPending}, which takes effect the next time the train view's {@code hasNext()}
     * runs. It is therefore meant to be consumed after the train view.
     *
     * @deprecated prefer {@link #getIterators()}
     */
    @Deprecated
    public DataSetIterator getTestIterator() {
        return new SplitView() {
            @Override
            public boolean hasNext() {
                return backedIterator.hasNext() && counter.get() < numTrain + numTest;
            }

            @Override
            public DataSet next() {
                DataSet batch = backedIterator.next();
                counter.incrementAndGet();
                return batch;
            }
        };
    }

    /**
     * Base class for the deprecated train/test views.
     *
     * <p>Metadata and pre-processor calls are delegated to the underlying iterator, and
     * {@code reset()} only marks a reset as pending. Subclasses supply {@code hasNext()} and
     * {@code next()}.
     */
    private abstract class SplitView implements DataSetIterator {
        @Override
        public DataSet next(int num) {
            throw new UnsupportedOperationException();
        }

        @Override
        public List<String> getLabels() {
            return backedIterator.getLabels();
        }

        @Override
        public int inputColumns() {
            return backedIterator.inputColumns();
        }

        @Override
        public void remove() {
            throw new UnsupportedOperationException();
        }

        @Override
        public int totalOutcomes() {
            return backedIterator.totalOutcomes();
        }

        @Override
        public boolean resetSupported() {
            return backedIterator.resetSupported();
        }

        @Override
        public boolean asyncSupported() {
            return backedIterator.asyncSupported();
        }

        @Override
        public void reset() {
            resetPending.set(true);
        }

        @Override
        public int batch() {
            return backedIterator.batch();
        }

        @Override
        public void setPreProcessor(DataSetPreProcessor preProcessor) {
            backedIterator.setPreProcessor(preProcessor);
        }

        @Override
        public DataSetPreProcessor getPreProcessor() {
            return backedIterator.getPreProcessor();
        }
    }
}
