package org.deeplearning4j.util;

import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/util/TestDataSetConsumer.class */
/**
 * Test helper that simulates a slow consumer of {@link DataSet}s.
 *
 * <p>Each call to {@link #consumeOnce(DataSet, boolean)} blocks for roughly {@code delay}
 * milliseconds, either by sleeping or by busy-waiting, and then increments an internal counter.
 * Every {@value #LOG_INTERVAL}th consumed dataset is reported through {@link #logger}.
 *
 * <p>The counter is an {@link AtomicLong}, so concurrent calls to {@code consumeOnce} do not lose
 * increments. The wrapped iterator, if any, is not synchronized by this class.
 */
public class TestDataSetConsumer {
    protected static final Logger logger = LoggerFactory.getLogger(TestDataSetConsumer.class);

    /** Number of consumed datasets between two progress log messages. */
    private static final long LOG_INTERVAL = 100;

    /**
     * Source used by {@link #consumeWhileHasNext(boolean)}.
     * It is {@code null} when the instance was built with {@link #TestDataSetConsumer(long)}.
     */
    private final DataSetIterator iterator;

    /** Simulated per-dataset processing time in milliseconds; values {@code <= 0} mean no wait. */
    private final long delay;

    /** Total number of datasets consumed so far. */
    private final AtomicLong count = new AtomicLong(0);

    /**
     * Creates a consumer without an iterator.
     *
     * <p>Only {@link #consumeOnce(DataSet, boolean)} can be used on such an instance.
     *
     * @param delay simulated processing time per dataset, in milliseconds
     */
    public TestDataSetConsumer(long delay) {
        this.iterator = null;
        this.delay = delay;
    }

    /**
     * Creates a consumer that can drain the given iterator.
     *
     * @param dataSetIterator iterator to drain in {@link #consumeWhileHasNext(boolean)}
     * @param delay simulated processing time per dataset, in milliseconds
     * @throws NullPointerException if {@code dataSetIterator} is {@code null}
     */
    public TestDataSetConsumer(@NonNull DataSetIterator dataSetIterator, long delay) {
        if (dataSetIterator == null) {
            throw new NullPointerException("iterator");
        }
        this.iterator = dataSetIterator;
        this.delay = delay;
    }

    /**
     * Consumes every remaining element of the wrapped iterator.
     *
     * @param useSleep {@code true} to wait with {@link Thread#sleep(long)},
     *     {@code false} to busy-wait
     * @return the total number of datasets consumed by this instance so far
     * @throws IllegalStateException if this instance was created without an iterator
     */
    public long consumeWhileHasNext(boolean useSleep) {
        if (this.iterator == null) {
            throw new IllegalStateException("Can't use consumeWhileHasNext() if iterator isn't set");
        }
        while (this.iterator.hasNext()) {
            consumeOnce((DataSet) this.iterator.next(), useSleep);
        }
        return this.count.get();
    }

    /**
     * Simulates processing of a single dataset and counts it.
     *
     * @param dataSet the dataset being "consumed"; only checked for {@code null}
     * @param useSleep {@code true} to wait with {@link Thread#sleep(long)},
     *     {@code false} to busy-wait
     * @return the counter value right after this dataset was counted
     * @throws NullPointerException if {@code dataSet} is {@code null}
     * @throws RuntimeException wrapping the {@link InterruptedException} if the thread is
     *     interrupted while sleeping; the thread's interrupt status is restored first
     */
    public long consumeOnce(@NonNull DataSet dataSet, boolean useSleep) {
        if (dataSet == null) {
            throw new NullPointerException("dataSet");
        }
        simulateWork(useSleep);
        // Use the value returned by incrementAndGet(): re-reading the counter could observe
        // increments made by other threads in the meantime.
        long consumed = this.count.incrementAndGet();
        if (consumed % LOG_INTERVAL == 0) {
            logger.info("Passed {} datasets...", consumed);
        }
        return consumed;
    }

    /**
     * Blocks until {@code delay} milliseconds have elapsed since the call started.
     *
     * @param useSleep sleep for the remaining time if {@code true}, otherwise spin
     */
    private void simulateWork(boolean useSleep) {
        long deadline = System.currentTimeMillis() + this.delay;
        long remaining;
        while ((remaining = deadline - System.currentTimeMillis()) > 0) {
            if (useSleep) {
                try {
                    // Sleep only for what is left, so an early wake-up does not add a full delay.
                    Thread.sleep(remaining);
                } catch (InterruptedException e) {
                    // Preserve the interrupt so callers up the stack can still observe it.
                    Thread.currentThread().interrupt();
                    throw new RuntimeException(e);
                }
            }
        }
    }

    /**
     * Returns the number of datasets consumed so far.
     *
     * @return the current counter value
     */
    public long getCount() {
        return this.count.get();
    }
}
