package org.deeplearning4j.datasets.datavec;

import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import lombok.NonNull;
import org.datavec.api.io.WritableConverter;
import org.datavec.api.io.converters.SelfWritableConverter;
import org.datavec.api.records.Record;
import org.datavec.api.records.metadata.RecordMetaData;
import org.datavec.api.records.metadata.RecordMetaDataComposableMap;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.ConcatenatingRecordReader;
import org.datavec.api.records.reader.impl.collection.CollectionRecordReader;
import org.datavec.api.writable.Writable;
import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link DataSetIterator} that converts records from a DataVec {@link RecordReader} (or a
 * {@link SequenceRecordReader}) into {@link DataSet} minibatches.
 *
 * <p>Conversion is delegated to a {@link RecordReaderMultiDataSetIterator}. That iterator is built
 * lazily from the first record, because the record size is needed to infer the feature and label
 * column ranges. The label configuration is one of:
 * <ul>
 *   <li>one-hot classification ({@code numPossibleLabels >= 1}, single label column),</li>
 *   <li>regression over columns {@code labelIndex..labelIndexTo} ({@code regression == true}),</li>
 *   <li>no labels ({@code labelIndex == -1 && numPossibleLabels == -1}), in which case the features
 *       are also used as the labels.</li>
 * </ul>
 */
public class RecordReaderDataSetIterator implements DataSetIterator {
    private static final Logger log = LoggerFactory.getLogger(RecordReaderDataSetIterator.class);
    /** Key under which the single record reader is registered with the underlying multi-iterator. */
    private static final String READER_KEY = "reader";

    protected RecordReader recordReader;
    // NOTE(review): stored for API compatibility; not read anywhere in this class.
    protected WritableConverter converter;
    protected int batchSize = 10;
    protected int maxNumBatches = -1;
    protected int batchNum;
    protected int labelIndex = -1;
    protected int labelIndexTo = -1;
    protected int numPossibleLabels = -1;
    // NOTE(review): never assigned in this class; only consulted by hasNext().
    protected Iterator<List<Writable>> sequenceIter;
    /** Batch fetched early by inputColumns()/totalOutcomes(). */
    protected DataSet last;
    /** True when {@link #last} was fetched early and has not yet been returned by next(). */
    protected boolean useCurrent;
    protected boolean regression;
    protected DataSetPreProcessor preProcessor;
    private boolean collectMetaData;
    /** Whether a preprocessor was already applied to {@link #last} when it was fetched. */
    private boolean lastPreProcessed;
    private RecordReaderMultiDataSetIterator underlying;
    /** True when features are read as two column ranges (label columns in the middle). */
    private boolean underlyingIsDisjoint;

    /**
     * Builder for {@link RecordReaderDataSetIterator}.
     */
    public static class Builder {
        protected RecordReader recordReader;
        protected WritableConverter converter;
        protected int batchSize;
        protected DataSetPreProcessor preProcessor;
        protected int maxNumBatches = -1;
        protected int labelIndex = -1;
        protected int labelIndexTo = -1;
        protected int numPossibleLabels = -1;
        protected boolean regression = false;
        private boolean collectMetaData = false;
        private boolean clOrRegCalled = false;

        /**
         * @param recordReader source of records; must not be null
         * @param batchSize    number of examples per minibatch
         * @throws NullPointerException if {@code recordReader} is null
         */
        public Builder(@NonNull RecordReader recordReader, int batchSize) {
            if (recordReader == null) {
                throw new NullPointerException("rr is marked @NonNull but is null");
            }
            this.recordReader = recordReader;
            this.batchSize = batchSize;
        }

        /** Sets the writable converter (stored on the iterator). */
        public Builder writableConverter(WritableConverter converter) {
            this.converter = converter;
            return this;
        }

        /** Limits the number of batches returned; a negative value means unlimited. */
        public Builder maxNumBatches(int maxNumBatches) {
            this.maxNumBatches = maxNumBatches;
            return this;
        }

        /** Single-column regression on column {@code labelIndex}. */
        public Builder regression(int labelIndex) {
            return regression(labelIndex, labelIndex);
        }

        /** Multi-column regression on columns {@code labelIndexFrom..labelIndexTo} (inclusive). */
        public Builder regression(int labelIndexFrom, int labelIndexTo) {
            this.labelIndex = labelIndexFrom;
            this.labelIndexTo = labelIndexTo;
            this.regression = true;
            this.clOrRegCalled = true;
            return this;
        }

        /** One-hot classification: label in column {@code labelIndex} with {@code numClasses} classes. */
        public Builder classification(int labelIndex, int numClasses) {
            this.labelIndex = labelIndex;
            this.labelIndexTo = labelIndex;
            this.numPossibleLabels = numClasses;
            this.regression = false;
            this.clOrRegCalled = true;
            return this;
        }

        /** Preprocessor applied to every returned DataSet. */
        public Builder preProcessor(DataSetPreProcessor preProcessor) {
            this.preProcessor = preProcessor;
            return this;
        }

        /** Whether to attach per-example record metadata to returned DataSets. */
        public Builder collectMetaData(boolean collectMetaData) {
            this.collectMetaData = collectMetaData;
            return this;
        }

        /** Builds the iterator. */
        public RecordReaderDataSetIterator build() {
            return new RecordReaderDataSetIterator(this);
        }
    }

    /**
     * Iterator with no label index set. If the reader reports labels, their count is used as
     * {@code numPossibleLabels}, so the last column becomes the one-hot label.
     */
    public RecordReaderDataSetIterator(RecordReader recordReader, int batchSize) {
        this(recordReader, new SelfWritableConverter(), batchSize, -1, -1,
                recordReader.getLabels() == null ? -1 : recordReader.getLabels().size(), -1, false);
    }

    /** Classification iterator with the label in column {@code labelIndex}. */
    public RecordReaderDataSetIterator(RecordReader recordReader, int batchSize, int labelIndex,
                    int numPossibleLabels) {
        this(recordReader, new SelfWritableConverter(), batchSize, labelIndex, labelIndex, numPossibleLabels, -1,
                false);
    }

    /** Classification iterator limited to {@code maxNumBatches} batches (negative = unlimited). */
    public RecordReaderDataSetIterator(RecordReader recordReader, int batchSize, int labelIndex,
                    int numPossibleLabels, int maxNumBatches) {
        this(recordReader, new SelfWritableConverter(), batchSize, labelIndex, labelIndex, numPossibleLabels,
                maxNumBatches, false);
    }

    /** Iterator with label columns {@code labelIndexFrom..labelIndexTo}; usually used for regression. */
    public RecordReaderDataSetIterator(RecordReader recordReader, int batchSize, int labelIndexFrom,
                    int labelIndexTo, boolean regression) {
        this(recordReader, new SelfWritableConverter(), batchSize, labelIndexFrom, labelIndexTo, -1, -1,
                regression);
    }

    /**
     * Fully specified constructor.
     *
     * @param recordReader      source of records
     * @param converter         writable converter (stored only)
     * @param batchSize         examples per minibatch
     * @param labelIndexFrom    first label column, or -1 for none/auto
     * @param labelIndexTo      last label column (inclusive)
     * @param numPossibleLabels number of classes for one-hot labels, or -1
     * @param maxNumBatches     maximum number of batches, negative for unlimited
     * @param regression        true for regression labels, false for classification
     */
    public RecordReaderDataSetIterator(RecordReader recordReader, WritableConverter converter, int batchSize,
                    int labelIndexFrom, int labelIndexTo, int numPossibleLabels, int maxNumBatches,
                    boolean regression) {
        this.recordReader = recordReader;
        this.converter = converter;
        this.batchSize = batchSize;
        this.maxNumBatches = maxNumBatches;
        this.labelIndex = labelIndexFrom;
        this.labelIndexTo = labelIndexTo;
        this.numPossibleLabels = numPossibleLabels;
        this.regression = regression;
    }

    protected RecordReaderDataSetIterator(Builder builder) {
        this.recordReader = builder.recordReader;
        this.converter = builder.converter;
        this.batchSize = builder.batchSize;
        this.maxNumBatches = builder.maxNumBatches;
        this.labelIndex = builder.labelIndex;
        this.labelIndexTo = builder.labelIndexTo;
        this.numPossibleLabels = builder.numPossibleLabels;
        this.regression = builder.regression;
        this.preProcessor = builder.preProcessor;
        // Previously dropped, which made Builder.collectMetaData(true) a no-op.
        this.collectMetaData = builder.collectMetaData;
    }

    /** Enables or disables per-example metadata collection, also on the underlying iterator if built. */
    public void setCollectMetaData(boolean collectMetaData) {
        if (underlying != null) {
            underlying.setCollectMetaData(collectMetaData);
        }
        this.collectMetaData = collectMetaData;
    }

    /** Builds the underlying iterator from the reader's next record, if not already built. */
    private void initializeUnderlying() {
        if (underlying == null) {
            initializeUnderlying(recordReader.nextRecord());
        }
    }

    /**
     * Builds the underlying multi-iterator, using {@code first} to determine the record size.
     * {@code first} is made available again either by resetting the reader or, if reset is not
     * supported, by prepending it to the reader.
     */
    private void initializeUnderlying(Record first) {
        int totalSize = first.getRecord().size();

        // Classification without an explicit label column: default to the last column.
        if (numPossibleLabels >= 1 && labelIndex < 0) {
            labelIndex = totalSize - 1;
            labelIndexTo = labelIndex;
        }

        if (recordReader.resetSupported()) {
            recordReader.reset();
        } else {
            recordReader = new ConcatenatingRecordReader(
                    new CollectionRecordReader(Collections.singletonList(first.getRecord())), recordReader);
        }

        RecordReaderMultiDataSetIterator.Builder builder = new RecordReaderMultiDataSetIterator.Builder(batchSize);
        if (recordReader instanceof SequenceRecordReader) {
            builder.addSequenceReader(READER_KEY, (SequenceRecordReader) recordReader);
        } else {
            builder.addReader(READER_KEY, recordReader);
        }

        if (regression) {
            builder.addOutput(READER_KEY, labelIndex, labelIndexTo);
        } else if (numPossibleLabels >= 1) {
            builder.addOutputOneHot(READER_KEY, labelIndex, numPossibleLabels);
        }

        if (labelIndex < 0) {
            // No labels: every column is a feature.
            builder.addInput(READER_KEY);
            underlyingIsDisjoint = false;
        } else if (labelIndex == 0 || labelIndexTo == totalSize - 1) {
            // Labels at either end of the record: features form one contiguous range.
            int inputFrom = (labelIndex == 0) ? labelIndexTo + 1 : 0;
            int inputTo = (labelIndex == 0) ? totalSize - 1 : labelIndex - 1;
            builder.addInput(READER_KEY, inputFrom, inputTo);
            underlyingIsDisjoint = false;
        } else {
            // Labels in the middle: features are the columns before and after them.
            Preconditions.checkState(labelIndex < totalSize, "Invalid label (from) index: index must be in range 0 to first record size of (0 to %s inclusive), got %s", totalSize - 1, labelIndex);
            Preconditions.checkState(labelIndexTo < totalSize, "Invalid label (to) index: index must be in range 0 to first record size of (0 to %s inclusive), got %s", totalSize - 1, labelIndexTo);
            builder.addInput(READER_KEY, 0, labelIndex - 1);
            builder.addInput(READER_KEY, labelIndexTo + 1, totalSize - 1);
            underlyingIsDisjoint = true;
        }

        underlying = builder.build();
        if (collectMetaData) {
            underlying.setCollectMetaData(true);
        }
    }

    /**
     * Converts a single-input, single-output MultiDataSet from the underlying iterator into a DataSet.
     * It also attaches metadata (if enabled), uses the features as labels when no labels are
     * configured, and applies the preprocessor.
     */
    private DataSet mdsToDataSet(MultiDataSet mds) {
        INDArray featuresMask = getOrNull(mds.getFeaturesMaskArrays(), 0);
        INDArray features;
        if (underlyingIsDisjoint) {
            // Features were read as two ranges around the label columns; rejoin them.
            features = Nd4j.hstack(getOrNull(mds.getFeatures(), 0), getOrNull(mds.getFeatures(), 1));
        } else {
            features = getOrNull(mds.getFeatures(), 0);
        }

        DataSet ds = new DataSet(features, getOrNull(mds.getLabels(), 0), featuresMask,
                getOrNull(mds.getLabelsMaskArrays(), 0));

        if (collectMetaData) {
            List<Serializable> composite = mds.getExampleMetaData();
            List<Serializable> perReader = new ArrayList<>(composite.size());
            for (Serializable meta : composite) {
                // Unwrap the composite metadata back to this iterator's single reader.
                perReader.add(((RecordMetaDataComposableMap) meta).getMeta().get(READER_KEY));
            }
            ds.setExampleMetaData(perReader);
        }

        if (labelIndex == -1 && numPossibleLabels == -1 && ds.getLabels() == null) {
            ds.setLabels(ds.getFeatures());
        }

        if (preProcessor != null) {
            preProcessor.preProcess(ds);
        }
        return ds;
    }

    @Override
    public DataSet next(int num) {
        if (useCurrent) {
            useCurrent = false;
            // The cached batch already went through mdsToDataSet. Apply the preprocessor only if
            // none was set when it was fetched, so it is never applied twice.
            if (preProcessor != null && !lastPreProcessed) {
                preProcessor.preProcess(last);
            }
            return last;
        }

        if (underlying == null) {
            initializeUnderlying();
        }

        batchNum++;
        return mdsToDataSet(underlying.next(num));
    }

    /**
     * Returns {@code arrays[index]}, or null if {@code arrays} is null or empty.
     */
    public static INDArray getOrNull(INDArray[] arrays, int index) {
        if (arrays == null || arrays.length == 0) {
            return null;
        }
        return arrays[index];
    }

    /**
     * Returns the cached early batch, fetching it on first use. A freshly fetched batch is marked
     * so the next call to next() returns it instead of reading a new one.
     */
    private DataSet peekFirstBatch() {
        if (last == null) {
            last = next();
            lastPreProcessed = preProcessor != null;
            useCurrent = true;
        }
        return last;
    }

    @Override
    public int inputColumns() {
        return peekFirstBatch().numInputs();
    }

    @Override
    public int totalOutcomes() {
        return peekFirstBatch().numOutcomes();
    }

    @Override
    public boolean resetSupported() {
        if (underlying == null) {
            initializeUnderlying();
        }
        return underlying.resetSupported();
    }

    @Override
    public boolean asyncSupported() {
        return true;
    }

    @Override
    public void reset() {
        batchNum = 0;
        if (underlying != null) {
            underlying.reset();
        }
        last = null;
        useCurrent = false;
        lastPreProcessed = false;
    }

    @Override
    public int batch() {
        return batchSize;
    }

    @Override
    public void setPreProcessor(DataSetPreProcessor preProcessor) {
        this.preProcessor = preProcessor;
    }

    @Override
    public boolean hasNext() {
        if (useCurrent) {
            // A batch fetched by inputColumns()/totalOutcomes() is still pending; previously it
            // could be lost when the reader was exhausted or maxNumBatches was reached.
            return true;
        }
        boolean moreData = (sequenceIter != null && sequenceIter.hasNext()) || recordReader.hasNext();
        return moreData && (maxNumBatches < 0 || batchNum < maxNumBatches);
    }

    @Override
    public DataSet next() {
        return next(batchSize);
    }

    @Override
    public void remove() {
        throw new UnsupportedOperationException();
    }

    @Override
    public List<String> getLabels() {
        return recordReader.getLabels();
    }

    /** Loads a single-example DataSet from one record's metadata. */
    public DataSet loadFromMetaData(RecordMetaData recordMetaData) throws IOException {
        return loadFromMetaData(Collections.singletonList(recordMetaData));
    }

    /**
     * Loads a DataSet containing the examples identified by {@code list}.
     *
     * @throws IllegalArgumentException if {@code list} is empty
     * @throws IOException              if the reader fails to load a record
     */
    public DataSet loadFromMetaData(List<RecordMetaData> list) throws IOException {
        if (list.isEmpty()) {
            throw new IllegalArgumentException("Cannot load DataSet from metadata: metadata list is empty");
        }
        if (underlying == null) {
            initializeUnderlying(recordReader.loadFromMetaData(list.get(0)));
        }

        List<RecordMetaData> composite = new ArrayList<>(list.size());
        for (RecordMetaData meta : list) {
            composite.add(new RecordMetaDataComposableMap(Collections.singletonMap(READER_KEY, meta)));
        }
        return mdsToDataSet(underlying.loadFromMetaData(composite));
    }

    public RecordReader getRecordReader() {
        return recordReader;
    }

    @Override
    public DataSetPreProcessor getPreProcessor() {
        return preProcessor;
    }

    public boolean isCollectMetaData() {
        return collectMetaData;
    }
}
