package org.deeplearning4j.models.sequencevectors.transformers.impl.iterables;

import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import lombok.NonNull;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.text.documentiterator.AsyncLabelAwareIterator;
import org.deeplearning4j.text.documentiterator.LabelAwareIterator;
import org.deeplearning4j.text.documentiterator.LabelledDocument;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/models/sequencevectors/transformers/impl/iterables/ParallelTransformerIterator.class */
/**
 * Transformer iterator that converts {@link LabelledDocument}s into {@link Sequence}s of
 * {@link VocabWord}s on a background thread pool.
 *
 * <p>The source iterator is wrapped in an {@link AsyncLabelAwareIterator}. Each call to
 * {@link #hasNext()} submits up to {@value #PREFETCH_SIZE} pending documents to the executor. The
 * resulting futures are kept in a FIFO queue, so sequences come back in the same order as the
 * underlying documents even though they are transformed concurrently.
 *
 * <p>The pool is created with {@link Executors#newFixedThreadPool(int)}, whose threads are
 * non-daemon. Call {@link #shutdown()} when finished so the worker threads can terminate.
 *
 * <p>NOTE(review): this looks intended for a single consumer thread; concurrent
 * {@code next()}/{@code reset()} calls are not coordinated.
 */
public class ParallelTransformerIterator extends BasicTransformerIterator {
    /** Upper bound on the number of queued futures. */
    protected static final int capacity = 1024;
    /** Maximum number of documents submitted ahead of the consumer by {@link #hasNext()}. */
    private static final int PREFETCH_SIZE = 100;
    private static final Logger log = LoggerFactory.getLogger(ParallelTransformerIterator.class);
    /** NOTE(review): not read by this class; kept because it is visible to subclasses. */
    protected static final AtomicInteger count = new AtomicInteger(0);

    /** Submitted transformations, in source-document order. */
    protected BlockingQueue<Future<Sequence<VocabWord>>> buffer = new LinkedBlockingQueue<>(capacity);
    /** NOTE(review): only ever set to {@code true} by this class; kept for subclasses. */
    protected AtomicBoolean underlyingHas = new AtomicBoolean(true);
    /** Number of {@link #next()} calls currently waiting on a future. */
    protected AtomicInteger processing = new AtomicInteger(0);
    private ExecutorService executorService;

    /** Transforms one document into a sequence and attaches its non-empty labels. */
    private static class CallableTransformer implements Callable<Sequence<VocabWord>> {
        private final LabelledDocument document;
        private final SentenceTransformer transformer;

        public CallableTransformer(LabelledDocument document, SentenceTransformer transformer) {
            this.document = document;
            this.transformer = transformer;
        }

        /**
         * Runs the transformation.
         *
         * @return the transformed sequence, or an empty sequence when the document or its
         *     content is {@code null}
         */
        @Override
        public Sequence<VocabWord> call() {
            if (document == null || document.getContent() == null) {
                return new Sequence<>();
            }
            Sequence<VocabWord> sequence = transformer.transformToSequence(document.getContent());
            if (document.getLabels() != null) {
                for (String label : document.getLabels()) {
                    if (label != null && !label.isEmpty()) {
                        // Each label is attached as a VocabWord with a fixed value of 1.0.
                        sequence.addSequenceLabel(new VocabWord(1.0d, label));
                    }
                }
            }
            return sequence;
        }
    }

    /**
     * Creates a multithreaded iterator.
     *
     * @param labelAwareIterator source of documents; must not be {@code null}
     * @param sentenceTransformer converts document text to sequences; must not be {@code null}
     * @throws NullPointerException if either argument is {@code null}
     */
    public ParallelTransformerIterator(@NonNull LabelAwareIterator labelAwareIterator,
                                       @NonNull SentenceTransformer sentenceTransformer) {
        // Argument validation is done by the delegated constructor.
        this(labelAwareIterator, sentenceTransformer, true);
    }

    /**
     * Creates an iterator.
     *
     * @param labelAwareIterator source of documents; must not be {@code null}
     * @param sentenceTransformer converts document text to sequences; must not be {@code null}
     * @param allowMultithreading if {@code true}, uses {@code max(availableProcessors, 2)} worker
     *     threads, otherwise a single worker thread
     * @throws NullPointerException if either argument is {@code null}
     */
    public ParallelTransformerIterator(@NonNull LabelAwareIterator labelAwareIterator,
                                       @NonNull SentenceTransformer sentenceTransformer,
                                       boolean allowMultithreading) {
        // Validate before the AsyncLabelAwareIterator wrapper is built around the source.
        // NOTE(review): 512 is presumably the async wrapper's buffer size -- confirm.
        super(new AsyncLabelAwareIterator(requireArguments(labelAwareIterator, sentenceTransformer), 512),
              sentenceTransformer);
        this.allowMultithreading = allowMultithreading;
        this.executorService = newExecutor();
    }

    /** Null-checks both constructor arguments (iterator first) and returns the iterator. */
    private static LabelAwareIterator requireArguments(LabelAwareIterator labelAwareIterator,
                                                       SentenceTransformer sentenceTransformer) {
        Objects.requireNonNull(labelAwareIterator, "iterator is marked @NonNull but is null");
        Objects.requireNonNull(sentenceTransformer, "transformer is marked @NonNull but is null");
        return labelAwareIterator;
    }

    /** Builds the worker pool sized according to {@code allowMultithreading}. */
    private ExecutorService newExecutor() {
        int threads = allowMultithreading ? Math.max(Runtime.getRuntime().availableProcessors(), 2) : 1;
        return Executors.newFixedThreadPool(threads);
    }

    /**
     * Restarts iteration from the beginning of the underlying iterator.
     *
     * <p>Running transformations are interrupted via {@code shutdownNow()}, buffered futures are
     * discarded, and a fresh worker pool is created.
     */
    @Override
    public void reset() {
        executorService.shutdownNow();
        iterator.reset();
        underlyingHas.set(true);
        buffer.clear();
        executorService = newExecutor();
    }

    /** Initiates an orderly shutdown of the worker pool; no new transformations are accepted. */
    public void shutdown() {
        executorService.shutdown();
    }

    /**
     * Submits pending documents (up to {@value #PREFETCH_SIZE} queued at once) and reports
     * whether a sequence is available.
     *
     * @return {@code true} if a buffered sequence exists or a {@link #next()} call is in progress
     */
    @Override
    public boolean hasNext() {
        fillBuffer();
        return !buffer.isEmpty() || processing.get() > 0;
    }

    /** Submits documents from the underlying iterator until the prefetch window is full. */
    private void fillBuffer() {
        while (buffer.size() < Math.min(PREFETCH_SIZE, capacity) && iterator.hasNextDocument()) {
            Future<Sequence<VocabWord>> future =
                    executorService.submit(new CallableTransformer(iterator.nextDocument(), sentenceTransformer));
            try {
                buffer.put(future);
            } catch (InterruptedException e) {
                // The document cannot be queued any more; do not leave its task running.
                future.cancel(true);
                Thread.currentThread().interrupt();
                log.warn("Interrupted while buffering a document; it has been dropped", e);
                return;
            }
        }
    }

    /**
     * Returns the next transformed sequence, blocking until its transformation completes.
     *
     * @return the next sequence, in source-document order
     * @throws NoSuchElementException if no documents remain
     * @throws RuntimeException wrapping any failure of the transformation or an interruption
     *     (in which case the thread's interrupt flag is restored)
     */
    @Override
    public Sequence<VocabWord> next() {
        // Without this check, take() on an empty buffer would block forever.
        if (buffer.isEmpty() && !hasNext()) {
            throw new NoSuchElementException("No more sequences available");
        }
        processing.incrementAndGet();
        try {
            return buffer.take().get();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (Exception e) {
            throw new RuntimeException(e);
        } finally {
            // Always balance the counter; a leaked increment would make hasNext() true forever.
            processing.decrementAndGet();
        }
    }
}
