package org.deeplearning4j.datasets.iterator;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.NoSuchElementException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.indexing.NDArrayIndex;

/* loaded from: input_file:org/deeplearning4j/datasets/iterator/IteratorMultiDataSetIterator.class */
public class IteratorMultiDataSetIterator implements MultiDataSetIterator {
    private final Iterator<MultiDataSet> iterator;
    private final int batchSize;
    private final LinkedList<MultiDataSet> queued = new LinkedList<>();
    private MultiDataSetPreProcessor preProcessor;

    public IteratorMultiDataSetIterator(Iterator<MultiDataSet> it2, int i) {
        this.iterator = it2;
        this.batchSize = i;
    }

    @Override // java.util.Iterator
    public boolean hasNext() {
        return !this.queued.isEmpty() || this.iterator.hasNext();
    }

    /* JADX WARN: Can't rename method to resolve collision */
    @Override // java.util.Iterator
    public MultiDataSet next() {
        return next(this.batchSize);
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v106, types: [org.nd4j.linalg.dataset.api.MultiDataSet] */
    @Override // org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator
    public MultiDataSet next(int i) {
        if (!hasNext()) {
            throw new NoSuchElementException();
        }
        ArrayList arrayList = new ArrayList();
        int i2 = 0;
        while (true) {
            int i3 = i2;
            if ((!this.queued.isEmpty() || this.iterator.hasNext()) && i3 < this.batchSize) {
                MultiDataSet removeFirst = !this.queued.isEmpty() ? this.queued.removeFirst() : this.iterator.next();
                int size = (int) removeFirst.getFeatures(0).size(0);
                if (i3 + size <= this.batchSize) {
                    arrayList.add(removeFirst);
                } else {
                    int numFeatureArrays = removeFirst.numFeatureArrays();
                    int numLabelsArrays = removeFirst.numLabelsArrays();
                    INDArray[] iNDArrayArr = new INDArray[numFeatureArrays];
                    INDArray[] iNDArrayArr2 = new INDArray[numLabelsArrays];
                    INDArray[] iNDArrayArr3 = new INDArray[numFeatureArrays];
                    INDArray[] iNDArrayArr4 = new INDArray[numLabelsArrays];
                    INDArray[] iNDArrayArr5 = removeFirst.getFeaturesMaskArrays() != null ? new INDArray[numFeatureArrays] : null;
                    INDArray[] iNDArrayArr6 = removeFirst.getLabelsMaskArrays() != null ? new INDArray[numLabelsArrays] : null;
                    INDArray[] iNDArrayArr7 = removeFirst.getFeaturesMaskArrays() != null ? new INDArray[numFeatureArrays] : null;
                    INDArray[] iNDArrayArr8 = removeFirst.getLabelsMaskArrays() != null ? new INDArray[numLabelsArrays] : null;
                    for (int i4 = 0; i4 < numFeatureArrays; i4++) {
                        INDArray features = removeFirst.getFeatures(i4);
                        iNDArrayArr[i4] = getRange(features, 0, this.batchSize - i3);
                        iNDArrayArr3[i4] = getRange(features, this.batchSize - i3, size);
                        if (iNDArrayArr5 != null) {
                            INDArray featuresMaskArray = removeFirst.getFeaturesMaskArray(i4);
                            iNDArrayArr5[i4] = getRange(featuresMaskArray, 0, this.batchSize - i3);
                            iNDArrayArr7[i4] = getRange(featuresMaskArray, this.batchSize - i3, size);
                        }
                    }
                    for (int i5 = 0; i5 < numLabelsArrays; i5++) {
                        INDArray labels = removeFirst.getLabels(i5);
                        iNDArrayArr2[i5] = getRange(labels, 0, this.batchSize - i3);
                        iNDArrayArr4[i5] = getRange(labels, this.batchSize - i3, size);
                        if (iNDArrayArr6 != null) {
                            INDArray labelsMaskArray = removeFirst.getLabelsMaskArray(i5);
                            iNDArrayArr6[i5] = getRange(labelsMaskArray, 0, this.batchSize - i3);
                            iNDArrayArr8[i5] = getRange(labelsMaskArray, this.batchSize - i3, size);
                        }
                    }
                    org.nd4j.linalg.dataset.MultiDataSet multiDataSet = new org.nd4j.linalg.dataset.MultiDataSet(iNDArrayArr, iNDArrayArr2, iNDArrayArr5, iNDArrayArr6);
                    org.nd4j.linalg.dataset.MultiDataSet multiDataSet2 = new org.nd4j.linalg.dataset.MultiDataSet(iNDArrayArr3, iNDArrayArr4, iNDArrayArr7, iNDArrayArr8);
                    arrayList.add(multiDataSet);
                    this.queued.add(multiDataSet2);
                }
                i2 = i3 + size;
            }
        }
        org.nd4j.linalg.dataset.MultiDataSet merge = arrayList.size() == 1 ? (MultiDataSet) arrayList.get(0) : org.nd4j.linalg.dataset.MultiDataSet.merge(arrayList);
        if (this.preProcessor != null) {
            this.preProcessor.preProcess(merge);
        }
        return merge;
    }

    private static INDArray getRange(INDArray iNDArray, int i, int i2) {
        if (iNDArray == null) {
            return null;
        }
        int rank = iNDArray.rank();
        switch (rank) {
            case 2:
                return iNDArray.get(NDArrayIndex.interval(i, i2), NDArrayIndex.all());
            case 3:
                return iNDArray.get(NDArrayIndex.interval(i, i2), NDArrayIndex.all(), NDArrayIndex.all());
            case 4:
                return iNDArray.get(NDArrayIndex.interval(i, i2), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all());
            default:
                throw new RuntimeException("Invalid rank: " + rank);
        }
    }

    @Override // org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator
    public boolean resetSupported() {
        return false;
    }

    @Override // org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator
    public boolean asyncSupported() {
        return false;
    }

    @Override // org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator
    public void reset() {
        throw new UnsupportedOperationException("Reset not supported");
    }

    @Override // org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator
    public void setPreProcessor(MultiDataSetPreProcessor multiDataSetPreProcessor) {
        this.preProcessor = multiDataSetPreProcessor;
    }

    @Override // org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator
    public MultiDataSetPreProcessor getPreProcessor() {
        return this.preProcessor;
    }

    @Override // java.util.Iterator
    public void remove() {
        throw new UnsupportedOperationException("Not supported");
    }
}
