package org.deeplearning4j.datasets.iterator;

import java.util.ConcurrentModificationException;
import java.util.NoSuchElementException;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/datasets/iterator/AsyncMultiDataSetIterator.class */
/**
 * Wraps a {@link MultiDataSetIterator} and prefetches its elements on a background daemon thread.
 *
 * <p>A producer thread ({@link IteratorRunnable}) pulls elements from the wrapped iterator and puts
 * them on a bounded queue; {@link #hasNext()} and {@link #next()} consume from that queue. Failures
 * raised by the wrapped iterator on the producer thread are stored and rethrown to the consumer.
 *
 * <p>NOTE(review): the consumer-side methods look like they are meant to be called from a single
 * thread; nothing in this class guards against several concurrent consumers.
 */
public class AsyncMultiDataSetIterator implements MultiDataSetIterator {
    /** The wrapped iterator that the producer thread reads from. */
    private final MultiDataSetIterator iterator;
    /** Bounded hand-off buffer between the producer thread and the consumer. */
    private final LinkedBlockingQueue<MultiDataSet> queue;
    /** Current producer task; replaced on every {@link #reset()}. */
    private IteratorRunnable runnable;
    /** Thread running {@link #runnable}; replaced on every {@link #reset()}. */
    private Thread thread;

    /**
     * Producer task: drains the wrapped iterator into the queue until the iterator is exhausted,
     * the task is asked to stop, or an error occurs.
     */
    public class IteratorRunnable implements Runnable {
        /** True while the producer loop may still deliver elements. */
        private volatile boolean isAlive;
        /** Failure captured on the producer thread, or null if none occurred. */
        private volatile RuntimeException exception;
        /** Set by reset()/shutdown() to request that the producer stops. */
        private volatile boolean killRunnable = false;
        /** Released exactly once, when run() finishes for any reason. */
        private final Semaphore runCompletedSemaphore = new Semaphore(0);
        /** Write lock guards iterator.next(); read lock is taken by hasLatch(). */
        private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
        /** Incremented before each fetch by the producer, decremented on each consumed element. */
        private final AtomicLong feeder = new AtomicLong(0);

        /**
         * @param z initial liveness flag; callers pass the wrapped iterator's {@code hasNext()}
         */
        public IteratorRunnable(boolean z) {
            this.isAlive = z;
        }

        /**
         * Reports whether more data is, or is about to become, available.
         *
         * <p>While the producer is alive this busy-waits (holding the read lock) until either data
         * is pending/queued, the wrapped iterator reports more elements, or the producer finishes.
         *
         * @return true if an element is in flight, queued, or still available from the wrapped iterator
         */
        public boolean hasLatch() {
            if (this.feeder.get() > 0 || !AsyncMultiDataSetIterator.this.queue.isEmpty()) {
                return true;
            }
            // Acquire before the try block so the finally never unlocks a lock that was not taken.
            this.lock.readLock().lock();
            try {
                boolean more = AsyncMultiDataSetIterator.this.iterator.hasNext()
                        || this.feeder.get() != 0
                        || !AsyncMultiDataSetIterator.this.queue.isEmpty();
                while (this.isAlive) {
                    more = this.feeder.get() != 0
                            || !AsyncMultiDataSetIterator.this.queue.isEmpty()
                            || AsyncMultiDataSetIterator.this.iterator.hasNext();
                    if (more) {
                        return true;
                    }
                }
                return more;
            } finally {
                this.lock.readLock().unlock();
            }
        }

        /**
         * Producer loop. The try/catch/finally wraps the whole loop so that the liveness flag is
         * cleared and the completion semaphore is released exactly once, when the thread is really
         * done, and so that failures from {@code hasNext()} are captured as well.
         */
        @Override // java.lang.Runnable
        public void run() {
            try {
                while (!this.killRunnable && AsyncMultiDataSetIterator.this.iterator.hasNext()) {
                    // Announce the in-flight element before fetching so consumers see pending work.
                    this.feeder.incrementAndGet();
                    MultiDataSet next;
                    this.lock.writeLock().lock();
                    try {
                        next = AsyncMultiDataSetIterator.this.iterator.next();
                        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
                            ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
                        }
                    } finally {
                        this.lock.writeLock().unlock();
                    }
                    // Blocks while the queue is full; interrupted by reset()/shutdown().
                    AsyncMultiDataSetIterator.this.queue.put(next);
                }
            } catch (InterruptedException e) {
                // An interrupt is expected only when a stop was requested; otherwise report it.
                if (!this.killRunnable) {
                    this.exception = new RuntimeException("Runnable interrupted unexpectedly", e);
                }
            } catch (RuntimeException e) {
                this.exception = e;
            } finally {
                this.isAlive = false;
                this.runCompletedSemaphore.release();
            }
        }
    }

    /**
     * Creates the async wrapper and immediately starts prefetching.
     *
     * <p>If the wrapped iterator supports reset, it is reset first.
     *
     * @param multiDataSetIterator iterator to wrap
     * @param i requested queue size; must be positive, and values below 2 are raised to 2
     * @throws IllegalArgumentException if {@code i <= 0}
     */
    public AsyncMultiDataSetIterator(MultiDataSetIterator multiDataSetIterator, int i) {
        if (i <= 0) {
            throw new IllegalArgumentException("Queue size must be > 0");
        }
        int capacity = Math.max(2, i);
        this.iterator = multiDataSetIterator;
        if (this.iterator.resetSupported()) {
            this.iterator.reset();
        }
        this.queue = new LinkedBlockingQueue<>(capacity);
        startProducer();
    }

    /** Creates a fresh producer task and starts it on a new daemon thread bound to the caller's device. */
    private void startProducer() {
        this.runnable = new IteratorRunnable(this.iterator.hasNext());
        this.thread = new Thread(this.runnable);
        Nd4j.getAffinityManager().attachThreadToDevice(this.thread, Nd4j.getAffinityManager().getDeviceForCurrentThread());
        this.thread.setDaemon(true);
        this.thread.start();
    }

    /**
     * Not supported by this wrapper.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator
    public MultiDataSet next(int i) {
        throw new UnsupportedOperationException("Next(int) not supported for AsyncMultiDataSetIterator");
    }

    /** Delegates the pre-processor to the wrapped iterator. */
    @Override // org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator
    public void setPreProcessor(MultiDataSetPreProcessor multiDataSetPreProcessor) {
        this.iterator.setPreProcessor(multiDataSetPreProcessor);
    }

    /** @return whether the wrapped iterator supports reset */
    @Override // org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator
    public boolean resetSupported() {
        return this.iterator.resetSupported();
    }

    /** @return always false */
    @Override // org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator
    public boolean asyncSupported() {
        return false;
    }

    /**
     * Stops the current producer, discards any queued elements, resets the wrapped iterator and
     * starts a new producer.
     *
     * <p>Waits up to 5 seconds for the old producer to finish; the wait is best-effort and the
     * reset proceeds even if it times out.
     *
     * @throws UnsupportedOperationException if the wrapped iterator does not support reset
     */
    @Override // org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator
    public void reset() {
        if (!resetSupported()) {
            throw new UnsupportedOperationException("Cannot reset Async iterator wrapping iterator that does not support reset");
        }
        IteratorRunnable previous = this.runnable;
        previous.killRunnable = true;
        if (previous.isAlive) {
            this.thread.interrupt();
        }
        try {
            previous.runCompletedSemaphore.tryAcquire(5L, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            // Keep the caller's interrupt visible instead of silently dropping it.
            Thread.currentThread().interrupt();
        }
        this.queue.clear();
        this.iterator.reset();
        startProducer();
    }

    /**
     * @return true if an element is queued or expected from the producer
     * @throws RuntimeException the failure captured on the producer thread, if it died with one
     *     and no stop had been requested
     */
    @Override // java.util.Iterator
    public boolean hasNext() {
        if (!this.queue.isEmpty()) {
            return true;
        }
        IteratorRunnable current = this.runnable;
        if (current.isAlive) {
            return current.hasLatch();
        }
        if (current.killRunnable || current.exception == null) {
            return current.hasLatch();
        }
        throw current.exception;
    }

    /**
     * Returns the next prefetched element, waiting for the producer if the queue is empty.
     *
     * @return the next element
     * @throws NoSuchElementException if no further elements are available
     * @throws ConcurrentModificationException if a stop was requested while waiting for an element
     * @throws IllegalStateException if the producer finished without delivering the expected element
     * @throws RuntimeException if the producer failed, or if the calling thread is interrupted
     *     while waiting (the interrupt flag is restored)
     */
    /* JADX WARN: Can't rename method to resolve collision */
    @Override // java.util.Iterator
    public MultiDataSet next() {
        if (!hasNext()) {
            throw new NoSuchElementException();
        }
        if (this.runnable.exception != null) {
            throw this.runnable.exception;
        }
        if (!this.queue.isEmpty()) {
            this.runnable.feeder.decrementAndGet();
            return this.queue.poll();
        }
        while (this.runnable.exception == null) {
            try {
                MultiDataSet polled = this.queue.poll(5L, TimeUnit.SECONDS);
                if (polled != null) {
                    this.runnable.feeder.decrementAndGet();
                    return polled;
                }
                if (this.runnable.killRunnable) {
                    throw new ConcurrentModificationException("Reset while next() is waiting for element?");
                }
                if (!this.runnable.isAlive && this.queue.isEmpty()) {
                    if (this.runnable.exception != null) {
                        throw new RuntimeException("Exception thrown in base iterator", this.runnable.exception);
                    }
                    throw new IllegalStateException("Unexpected state occurred for AsyncMultiDataSetIterator: runnable died or no data available");
                }
            } catch (InterruptedException e) {
                // Restore the flag so callers further up can still observe the interruption.
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            }
        }
        throw this.runnable.exception;
    }

    /** Intentionally a no-op: removal is silently ignored. */
    @Override // java.util.Iterator
    public void remove() {
    }

    /** Requests the producer thread to stop, interrupting it if it is still running. */
    public void shutdown() {
        if (this.thread.isAlive()) {
            this.runnable.killRunnable = true;
            this.thread.interrupt();
        }
    }
}
