package org.deeplearning4j.datasets.datavec;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.datavec.api.io.WritableConverter;
import org.datavec.api.io.converters.SelfWritableConverter;
import org.datavec.api.io.converters.WritableConverterException;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.writable.Writable;
import org.datavec.common.data.NDArrayWritable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.FeatureUtil;

/* loaded from: input_file:org/deeplearning4j/datasets/datavec/RecordReaderDataSetIterator.class */
/**
 * {@link DataSetIterator} that turns the records of a DataVec {@link RecordReader} into
 * {@link DataSet} mini-batches.
 *
 * <p>Each record (a {@code List<Writable>}) becomes one example:
 * <ul>
 *   <li>Classification (default): the column at {@code labelIndex} is passed through the
 *       {@link WritableConverter} (if any) and one-hot encoded over {@code numPossibleLabels}
 *       classes; the other non-empty columns form the feature row.</li>
 *   <li>Regression: columns {@code labelIndex..labelIndexTo} (inclusive) form the label row.</li>
 *   <li>No label index: the features are also used as the labels.</li>
 *   <li>A two-column record whose first column is an {@link NDArrayWritable} uses that array as
 *       the features and the second column as the label.</li>
 * </ul>
 * If {@code numPossibleLabels >= 1} but no label index was given, the last column of the record
 * is used as the label. For a {@link SequenceRecordReader}, every time step of every sequence is
 * emitted as a separate example.
 */
public class RecordReaderDataSetIterator implements DataSetIterator {
    /** Source of the records. */
    protected RecordReader recordReader;
    /** Applied to the label column before encoding; may be {@code null}. */
    protected WritableConverter converter;
    /** Number of examples per batch returned by {@link #next()}. */
    protected int batchSize;
    /** Maximum number of batches per pass; negative means unlimited. */
    protected int maxNumBatches;
    /** Number of batches read from the record reader since construction or the last {@link #reset()}. */
    protected int batchNum;
    /** Index of the (first) label column; negative means "no label column". */
    protected int labelIndex;
    /** Index of the last label column (inclusive); used in regression mode. */
    protected int labelIndexTo;
    /** Number of classes for classification; values below 1 mean "not set". */
    protected int numPossibleLabels;
    /** Remaining time steps of the current sequence when reading from a {@link SequenceRecordReader}. */
    protected Iterator<List<Writable>> sequenceIter;
    /** Most recently built batch; also caches the batch consumed by inputColumns()/totalOutcomes(). */
    protected DataSet last;
    /** When {@code true}, the next call to {@link #next(int)} returns {@link #last} again. */
    protected boolean useCurrent;
    /** {@code true} for regression labels, {@code false} for one-hot classification labels. */
    protected boolean regression;
    /** Optional pre-processor applied to every returned batch. */
    protected DataSetPreProcessor preProcessor;
    /** Whether {@link #last} has already been passed through {@link #preProcessor}. */
    private boolean lastPreProcessed;

    /**
     * Creates a classification iterator with a batch size of 10.
     *
     * @param recordReader      source of the records
     * @param labelIndex        column holding the class index
     * @param numPossibleLabels number of classes
     * @deprecated uses a hard-coded batch size of 10; use a constructor taking a batch size
     */
    @Deprecated
    public RecordReaderDataSetIterator(RecordReader recordReader, int labelIndex, int numPossibleLabels) {
        this(recordReader, new SelfWritableConverter(), 10, labelIndex, numPossibleLabels);
    }

    /**
     * Creates an iterator without a label column (features double as labels) and a batch size of 10.
     *
     * @param recordReader source of the records
     * @param converter    label converter
     * @deprecated uses a hard-coded batch size of 10; use a constructor taking a batch size
     */
    @Deprecated
    public RecordReaderDataSetIterator(RecordReader recordReader, WritableConverter converter) {
        this(recordReader, converter, 10, -1, -1);
    }

    /**
     * Creates a classification iterator with a batch size of 10.
     *
     * @param recordReader      source of the records
     * @param converter         label converter
     * @param labelIndex        column holding the class index
     * @param numPossibleLabels number of classes
     * @deprecated uses a hard-coded batch size of 10; use a constructor taking a batch size
     */
    @Deprecated
    public RecordReaderDataSetIterator(RecordReader recordReader, WritableConverter converter, int labelIndex,
                    int numPossibleLabels) {
        this(recordReader, converter, 10, labelIndex, numPossibleLabels);
    }

    /**
     * Creates an iterator without a label column and a batch size of 10.
     *
     * @param recordReader source of the records
     * @deprecated uses a hard-coded batch size of 10; use a constructor taking a batch size
     */
    @Deprecated
    public RecordReaderDataSetIterator(RecordReader recordReader) {
        this(recordReader, new SelfWritableConverter());
    }

    /**
     * Creates an iterator whose number of classes is taken from {@code recordReader.getLabels()}
     * (or -1 if the reader reports none). No label index is given, so when the reader reports labels
     * the last column of each record becomes the label.
     *
     * @param recordReader source of the records
     * @param converter    label converter
     * @param batchSize    examples per batch
     */
    public RecordReaderDataSetIterator(RecordReader recordReader, WritableConverter converter, int batchSize) {
        this(recordReader, converter, batchSize, -1, labelCountOf(recordReader));
    }

    /**
     * Same as {@link #RecordReaderDataSetIterator(RecordReader, WritableConverter, int)} with a
     * {@link SelfWritableConverter}.
     *
     * @param recordReader source of the records
     * @param batchSize    examples per batch
     */
    public RecordReaderDataSetIterator(RecordReader recordReader, int batchSize) {
        this(recordReader, new SelfWritableConverter(), batchSize, -1, labelCountOf(recordReader));
    }

    /**
     * Creates a classification iterator using a {@link SelfWritableConverter}.
     *
     * @param recordReader      source of the records
     * @param batchSize         examples per batch
     * @param labelIndex        column holding the class index
     * @param numPossibleLabels number of classes
     */
    public RecordReaderDataSetIterator(RecordReader recordReader, int batchSize, int labelIndex,
                    int numPossibleLabels) {
        this(recordReader, new SelfWritableConverter(), batchSize, labelIndex, numPossibleLabels);
    }

    /**
     * Creates an iterator with an unlimited number of batches.
     *
     * @param recordReader      source of the records
     * @param converter         label converter
     * @param batchSize         examples per batch
     * @param labelIndex        label column
     * @param numPossibleLabels number of classes
     * @param regression        {@code true} for regression labels
     */
    public RecordReaderDataSetIterator(RecordReader recordReader, WritableConverter converter, int batchSize,
                    int labelIndex, int numPossibleLabels, boolean regression) {
        this(recordReader, converter, batchSize, labelIndex, numPossibleLabels, -1, regression);
    }

    /**
     * Creates a classification iterator with an unlimited number of batches.
     *
     * @param recordReader      source of the records
     * @param converter         label converter
     * @param batchSize         examples per batch
     * @param labelIndex        column holding the class index
     * @param numPossibleLabels number of classes
     */
    public RecordReaderDataSetIterator(RecordReader recordReader, WritableConverter converter, int batchSize,
                    int labelIndex, int numPossibleLabels) {
        this(recordReader, converter, batchSize, labelIndex, numPossibleLabels, -1, false);
    }

    /**
     * Creates a classification iterator using a {@link SelfWritableConverter} and a batch limit.
     *
     * @param recordReader      source of the records
     * @param batchSize         examples per batch
     * @param labelIndex        column holding the class index
     * @param numPossibleLabels number of classes
     * @param maxNumBatches     maximum batches per pass; negative for unlimited
     */
    public RecordReaderDataSetIterator(RecordReader recordReader, int batchSize, int labelIndex,
                    int numPossibleLabels, int maxNumBatches) {
        this(recordReader, new SelfWritableConverter(), batchSize, labelIndex, numPossibleLabels, maxNumBatches,
                        false);
    }

    /**
     * Creates an iterator whose labels span the columns {@code labelIndexFrom..labelIndexTo}
     * (inclusive), using a {@link SelfWritableConverter}.
     *
     * @param recordReader   source of the records
     * @param batchSize      examples per batch
     * @param labelIndexFrom first label column
     * @param labelIndexTo   last label column (inclusive)
     * @param regression     {@code true} for regression labels
     */
    public RecordReaderDataSetIterator(RecordReader recordReader, int batchSize, int labelIndexFrom,
                    int labelIndexTo, boolean regression) {
        this(recordReader, new SelfWritableConverter(), batchSize, labelIndexFrom, labelIndexTo, -1, -1,
                        regression);
    }

    /**
     * Creates an iterator with a single label column.
     *
     * @param recordReader      source of the records
     * @param converter         label converter
     * @param batchSize         examples per batch
     * @param labelIndex        label column
     * @param numPossibleLabels number of classes
     * @param maxNumBatches     maximum batches per pass; negative for unlimited
     * @param regression        {@code true} for regression labels
     */
    public RecordReaderDataSetIterator(RecordReader recordReader, WritableConverter converter, int batchSize,
                    int labelIndex, int numPossibleLabels, int maxNumBatches, boolean regression) {
        this(recordReader, converter, batchSize, labelIndex, labelIndex, numPossibleLabels, maxNumBatches,
                        regression);
    }

    /**
     * Primary constructor; every other constructor delegates here.
     *
     * @param recordReader      source of the records
     * @param converter         label converter; may be {@code null}
     * @param batchSize         examples per batch
     * @param labelIndexFrom    first label column; negative for "no label column"
     * @param labelIndexTo      last label column (inclusive), used in regression mode
     * @param numPossibleLabels number of classes; below 1 for "not set"
     * @param maxNumBatches     maximum batches per pass; negative for unlimited
     * @param regression        {@code true} for regression labels
     */
    public RecordReaderDataSetIterator(RecordReader recordReader, WritableConverter converter, int batchSize,
                    int labelIndexFrom, int labelIndexTo, int numPossibleLabels, int maxNumBatches,
                    boolean regression) {
        this.recordReader = recordReader;
        this.converter = converter;
        this.batchSize = batchSize;
        this.labelIndex = labelIndexFrom;
        this.labelIndexTo = labelIndexTo;
        this.numPossibleLabels = numPossibleLabels;
        this.maxNumBatches = maxNumBatches;
        this.regression = regression;
    }

    /** Returns the number of label names the reader reports, or -1 when it reports none. */
    private static int labelCountOf(RecordReader recordReader) {
        List<String> labels = recordReader.getLabels();
        return labels == null ? -1 : labels.size();
    }

    /**
     * Builds the next batch of at most {@code num} examples.
     *
     * <p>A batch consumed earlier by {@link #inputColumns()} or {@link #totalOutcomes()} is returned
     * first, without reading new records.
     *
     * @param num maximum number of examples in the batch
     * @return the batch, or an empty {@link DataSet} if no records were available
     */
    @Override
    public DataSet next(int num) {
        if (useCurrent) {
            // Replay the peeked batch. It was pre-processed when it was built, unless the
            // pre-processor was set afterwards; never pre-process it twice.
            useCurrent = false;
            if (preProcessor != null && !lastPreProcessed) {
                preProcessor.preProcess(last);
                lastPreProcessed = true;
            }
            return last;
        }

        List<DataSet> examples = new ArrayList<>();
        for (int i = 0; i < num && hasNext(); i++) {
            examples.add(getDataSet(nextRecord()));
        }
        batchNum++;
        if (examples.isEmpty()) {
            return new DataSet();
        }

        DataSet batch = DataSet.merge(examples);
        last = batch;
        lastPreProcessed = preProcessor != null;
        if (preProcessor != null) {
            preProcessor.preProcess(batch);
        }
        List<String> labelNames = recordReader.getLabels();
        if (labelNames != null) {
            batch.setLabelNames(labelNames);
        }
        return batch;
    }

    /**
     * Reads the next record. For a {@link SequenceRecordReader}, this is the next time step,
     * pulling a new sequence once the current one is exhausted.
     */
    private List<Writable> nextRecord() {
        if (recordReader instanceof SequenceRecordReader) {
            if (sequenceIter == null || !sequenceIter.hasNext()) {
                sequenceIter = ((SequenceRecordReader) recordReader).sequenceRecord().iterator();
            }
            return sequenceIter.next();
        }
        return recordReader.next();
    }

    /**
     * Converts one record into a single-example {@link DataSet}.
     *
     * <p>Side effect: if {@code numPossibleLabels >= 1} and no label index is set, {@link #labelIndex}
     * is set to the last column of this record.
     *
     * @param record the record's columns
     * @return features and labels for one example; without a label column the labels are the features
     * @throws IllegalStateException if the label converter fails, if a label column is reached with
     *         {@code numPossibleLabels < 1}, or if a class index is outside {@code [0, numPossibleLabels)}
     */
    private DataSet getDataSet(List<Writable> record) {
        if (numPossibleLabels >= 1 && labelIndex < 0) {
            labelIndex = record.size() - 1;
        }

        if (record.size() == 2 && record.get(0) instanceof NDArrayWritable) {
            NDArrayWritable arrayColumn = (NDArrayWritable) record.get(0);
            Writable second = record.get(1);
            if (second == arrayColumn) {
                // The same array instance in both columns: use it as features and labels.
                return new DataSet(arrayColumn.get(), arrayColumn.get());
            }
            double labelValue = Double.parseDouble(second.toString());
            INDArray label = regression
                            ? Nd4j.scalar(labelValue)
                            : FeatureUtil.toOutcomeVector((int) labelValue, numPossibleLabels);
            return new DataSet(arrayColumn.get(), label);
        }

        INDArray features = null;
        INDArray labels = null;
        int featureCount = 0;
        int labelCount = 0;
        for (int col = 0; col < record.size(); col++) {
            Writable current = record.get(col);
            // Empty values are skipped; NDArrayWritable columns are always processed.
            if (!(current instanceof NDArrayWritable) && current.toString().isEmpty()) {
                continue;
            }

            if (regression && col >= labelIndex && col <= labelIndexTo) {
                // Regression label column: copy into the label row.
                if (labels == null) {
                    labels = Nd4j.create(1, labelIndexTo - labelIndex + 1);
                }
                labels.putScalar(labelCount++, current.toDouble());
            } else if (labelIndex < 0 || col != labelIndex) {
                // Feature column.
                try {
                    double value = current.toDouble();
                    if (features == null) {
                        features = createFeatureRow(record.size());
                    }
                    features.putScalar(featureCount++, value);
                } catch (UnsupportedOperationException e) {
                    // If toDouble() is unsupported for an NDArrayWritable column, that column's
                    // array becomes the whole feature array.
                    if (!(current instanceof NDArrayWritable)) {
                        throw e;
                    }
                    assert features == null : "NDArrayWritable features cannot be mixed with scalar feature columns";
                    features = ((NDArrayWritable) current).get();
                }
            } else {
                labels = toLabel(current);
            }
        }
        // Without a label column the features double as labels.
        return new DataSet(features, labelIndex >= 0 ? labels : features);
    }

    /** Allocates the feature row for a record with {@code numColumns} columns, excluding label column(s). */
    private INDArray createFeatureRow(int numColumns) {
        if (regression && labelIndex >= 0) {
            return Nd4j.create(1, numColumns - (labelIndexTo - labelIndex + 1));
        }
        return Nd4j.create(labelIndex >= 0 ? numColumns - 1 : numColumns);
    }

    /**
     * Encodes the label column. Regression gives a scalar; classification gives a one-hot vector.
     *
     * @throws IllegalStateException if conversion fails, {@code numPossibleLabels < 1}, or the class
     *         index is out of range
     */
    private INDArray toLabel(Writable labelWritable) {
        Writable value = labelWritable;
        if (converter != null) {
            try {
                value = converter.convert(value);
            } catch (WritableConverterException e) {
                throw new IllegalStateException(
                                "Could not convert label value \"" + labelWritable + "\" at column " + labelIndex, e);
            }
        }
        if (numPossibleLabels < 1) {
            throw new IllegalStateException("Number of possible labels invalid, must be >= 1");
        }
        if (regression) {
            return Nd4j.scalar(value.toDouble());
        }
        int classIndex = value.toInt();
        if (classIndex < 0 || classIndex >= numPossibleLabels) {
            // Previously an index equal to numPossibleLabels was silently decremented, merging two classes.
            throw new IllegalStateException("Invalid label at column " + labelIndex + ": expected a class index in [0, "
                            + (numPossibleLabels - 1) + "] but got " + classIndex);
        }
        return FeatureUtil.toOutcomeVector(classIndex, numPossibleLabels);
    }

    /** Not supported. */
    @Override
    public int totalExamples() {
        throw new UnsupportedOperationException();
    }

    /**
     * Returns the number of feature columns. If no batch has been built yet, the first batch is read
     * and is returned again by the next call to {@link #next(int)}.
     */
    @Override
    public int inputColumns() {
        return peekBatch().numInputs();
    }

    /**
     * Returns the number of label columns. If no batch has been built yet, the first batch is read
     * and is returned again by the next call to {@link #next(int)}.
     */
    @Override
    public int totalOutcomes() {
        return peekBatch().numOutcomes();
    }

    /**
     * Returns {@link #last}. If no batch exists yet, this builds one and queues it for replay so no
     * examples are lost.
     */
    private DataSet peekBatch() {
        if (last != null) {
            return last;
        }
        DataSet batch = next();
        // next(int) records only non-empty batches in `last`; an empty result is not replayed.
        useCurrent = batch == last;
        last = batch;
        return batch;
    }

    /** Always {@code true}: {@link #reset()} delegates to the record reader. */
    public boolean resetSupported() {
        return true;
    }

    /** Restarts iteration from the first record and discards any pending peeked batch or partial sequence. */
    @Override
    public void reset() {
        batchNum = 0;
        sequenceIter = null;
        useCurrent = false;
        recordReader.reset();
    }

    /** Returns the configured batch size. */
    @Override
    public int batch() {
        return batchSize;
    }

    /** Not supported. */
    @Override
    public int cursor() {
        throw new UnsupportedOperationException();
    }

    /** Not supported. */
    @Override
    public int numExamples() {
        throw new UnsupportedOperationException();
    }

    /** Sets the pre-processor applied to every subsequently returned batch; may be {@code null}. */
    @Override
    public void setPreProcessor(DataSetPreProcessor dataSetPreProcessor) {
        this.preProcessor = dataSetPreProcessor;
    }

    /**
     * Returns {@code true} if a peeked batch is pending, or if more records remain (in the reader or
     * in the current sequence) and the batch limit has not been reached.
     */
    @Override
    public boolean hasNext() {
        if (useCurrent) {
            return true;
        }
        boolean moreRecords = recordReader.hasNext() || (sequenceIter != null && sequenceIter.hasNext());
        return moreRecords && (maxNumBatches < 0 || batchNum < maxNumBatches);
    }

    /** Returns the next batch of up to {@link #batch()} examples. */
    @Override
    public DataSet next() {
        return next(batchSize);
    }

    /**
     * Alias of {@link #next()}, kept so existing callers of this name keep compiling.
     *
     * @deprecated use {@link #next()}
     */
    @Deprecated
    public DataSet m11next() {
        return next();
    }

    /** Not supported. */
    @Override
    public void remove() {
        throw new UnsupportedOperationException();
    }

    /** Returns the label names reported by the record reader (may be {@code null}). */
    @Override
    public List<String> getLabels() {
        return recordReader.getLabels();
    }

    /** Returns the current pre-processor, or {@code null} if none is set. */
    public DataSetPreProcessor getPreProcessor() {
        return preProcessor;
    }
}
