package org.deeplearning4j.datasets.iterator;

import java.util.NoSuchElementException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/datasets/iterator/MultiDataSetIteratorSplitter.class */
/**
 * Exposes one resettable {@link MultiDataSetIterator} as two consecutive views: a "train" view
 * covering the first {@code numTrain} elements and a "test" view covering the elements after it.
 *
 * <p>Both views share the same underlying iterator, the same position counter and the same
 * pending-reset flag, so they are not independent: the test view simply continues from wherever
 * the shared position currently is. The position is advanced once per {@code next()} call, so
 * {@code totalExamples} is effectively measured in elements returned by the underlying iterator.
 *
 * <p>The split only makes sense when the underlying iterator yields elements in the same order on
 * every pass. The first element of the first train pass is remembered and compared against the
 * first element of every later pass to detect shuffling.
 */
public class MultiDataSetIteratorSplitter {
    private static final Logger log = LoggerFactory.getLogger(MultiDataSetIteratorSplitter.class);

    /** Tolerance used when comparing the first element of a pass against the remembered one. */
    private static final double FIRST_BATCH_EPS = 1.0E-5d;

    protected MultiDataSetIterator backedIterator;
    protected final long totalExamples;
    protected final double ratio;
    protected final long numTrain;
    protected final long numTest;

    /** Number of elements handed out (by either view) since the last applied reset. */
    protected AtomicLong counter = new AtomicLong(0);

    /** Set by {@code reset()} on either view; consumed by the train view's {@code hasNext()}. */
    protected AtomicBoolean resetPending = new AtomicBoolean(false);

    /** Detached copy of the first element of the first train pass; {@code null} until then. */
    protected MultiDataSet firstTrain = null;

    /**
     * Creates a splitter on top of the given iterator.
     *
     * @param baseIterator  iterator to split; must support {@code reset()}
     * @param totalExamples total number of elements to distribute between the two views; must not
     *                      be negative (zero is accepted and yields two empty views)
     * @param ratio         fraction of {@code totalExamples} assigned to the train view, strictly
     *                      between 0 and 1; the train size is truncated towards zero
     * @throws NullPointerException      if {@code baseIterator} is {@code null}
     * @throws ND4JIllegalStateException if {@code ratio} is outside (0, 1) or NaN, if
     *                                   {@code totalExamples} is negative, or if the iterator
     *                                   does not support reset
     */
    public MultiDataSetIteratorSplitter(@NonNull MultiDataSetIterator baseIterator, long totalExamples, double ratio) {
        if (baseIterator == null) {
            throw new NullPointerException("baseIterator is marked @NonNull but is null");
        }
        // Written as a negated conjunction so that NaN is rejected as well.
        if (!(ratio > 0.0d && ratio < 1.0d)) {
            throw new ND4JIllegalStateException("Ratio value should be in range of 0.0 < X < 1.0");
        }
        if (totalExamples < 0) {
            throw new ND4JIllegalStateException("totalExamples number should be positive value");
        }
        if (!baseIterator.resetSupported()) {
            throw new ND4JIllegalStateException("Underlying iterator doesn't support reset, so it can't be used for runtime-split");
        }
        this.backedIterator = baseIterator;
        this.totalExamples = totalExamples;
        this.ratio = ratio;
        this.numTrain = (long) (this.totalExamples * ratio);
        this.numTest = this.totalExamples - this.numTrain;
        log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!");
    }

    /**
     * Returns a view over the first {@code numTrain} elements of the underlying iterator.
     *
     * <p>A reset requested through either view is applied lazily, the next time this view's
     * {@code hasNext()} (or {@code next()}) is called.
     *
     * @return a new train view backed by this splitter's shared state
     */
    public MultiDataSetIterator getTrainIterator() {
        return new SplitView() {
            @Override
            public boolean hasNext() {
                if (resetPending.get()) {
                    if (!resetSupported()) {
                        throw new UnsupportedOperationException("Reset isn't supported by underlying iterator");
                    }
                    backedIterator.reset();
                    counter.set(0L);
                    resetPending.set(false);
                }
                return backedIterator.hasNext() && counter.get() < numTrain;
            }

            @Override
            public org.nd4j.linalg.dataset.api.MultiDataSet next() {
                // Going through hasNext() applies any pending reset and enforces the train
                // boundary, so elements of the test partition can never be returned from here.
                if (!hasNext()) {
                    throw new NoSuchElementException("Train partition is exhausted");
                }
                org.nd4j.linalg.dataset.api.MultiDataSet batch = backedIterator.next();
                // Advance only after a successful fetch so a failure leaves the position intact.
                if (counter.incrementAndGet() == 1) {
                    verifyFirstTrainBatch(batch);
                }
                return batch;
            }
        };
    }

    /**
     * Returns a view that continues from the current shared position up to
     * {@code numTrain + numTest} elements.
     *
     * <p>This view does not skip train elements and does not apply a pending reset by itself; it
     * is meant to be used after the train view has been consumed.
     *
     * @return a new test view backed by this splitter's shared state
     */
    public MultiDataSetIterator getTestIterator() {
        return new SplitView() {
            @Override
            public boolean hasNext() {
                return backedIterator.hasNext() && counter.get() < numTrain + numTest;
            }

            @Override
            public org.nd4j.linalg.dataset.api.MultiDataSet next() {
                if (!hasNext()) {
                    throw new NoSuchElementException("Test partition is exhausted");
                }
                org.nd4j.linalg.dataset.api.MultiDataSet batch = backedIterator.next();
                counter.incrementAndGet();
                return batch;
            }
        };
    }

    /**
     * Remembers the very first train element, or checks a later pass's first element against it.
     *
     * @param batch first element of the current train pass
     * @throws ND4JIllegalStateException if the features differ from the remembered ones
     */
    private void verifyFirstTrainBatch(org.nd4j.linalg.dataset.api.MultiDataSet batch) {
        if (firstTrain == null) {
            firstTrain = (MultiDataSet) batch.copy();
            firstTrain.detach();
            return;
        }
        INDArray[] expected = firstTrain.getFeatures();
        INDArray[] actual = batch.getFeatures();
        if (expected == null || actual == null) {
            // Two feature-less elements are considered equal; anything else is a mismatch.
            if (expected != actual) {
                throw new ND4JIllegalStateException("First examples do not match. Randomization was used?");
            }
            return;
        }
        if (expected.length != actual.length) {
            throw new ND4JIllegalStateException("First examples do not match. Randomization was used?");
        }
        for (int idx = 0; idx < actual.length; idx++) {
            if (!actual[idx].equalsWithEps(expected[idx], FIRST_BATCH_EPS)) {
                throw new ND4JIllegalStateException("First examples do not match. Randomization was used?");
            }
        }
    }

    /**
     * Behaviour shared by the train and test views: everything except {@code hasNext()} and
     * {@code next()} is delegated to the underlying iterator or to the splitter's shared state.
     */
    private abstract class SplitView implements MultiDataSetIterator {
        @Override
        public org.nd4j.linalg.dataset.api.MultiDataSet next(int num) {
            throw new UnsupportedOperationException("To be implemented yet");
        }

        @Override
        public void setPreProcessor(MultiDataSetPreProcessor preProcessor) {
            backedIterator.setPreProcessor(preProcessor);
        }

        @Override
        public MultiDataSetPreProcessor getPreProcessor() {
            return backedIterator.getPreProcessor();
        }

        @Override
        public boolean resetSupported() {
            return backedIterator.resetSupported();
        }

        @Override
        public boolean asyncSupported() {
            return backedIterator.asyncSupported();
        }

        /** Only flags the reset; it is carried out by the train view's {@code hasNext()}. */
        @Override
        public void reset() {
            resetPending.set(true);
        }

        @Override
        public void remove() {
            throw new UnsupportedOperationException();
        }
    }
}
