package org.deeplearning4j.datasets.datavec;

import java.beans.ConstructorProperties;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.datavec.api.records.Record;
import org.datavec.api.records.SequenceRecord;
import org.datavec.api.records.metadata.RecordMetaData;
import org.datavec.api.records.metadata.RecordMetaDataComposableMap;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.writable.Writable;
import org.datavec.common.data.NDArrayWritable;
import org.deeplearning4j.berkeley.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/* loaded from: input_file:org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIterator.class */
public class RecordReaderMultiDataSetIterator implements MultiDataSetIterator {
    private int batchSize;
    private AlignmentMode alignmentMode;
    private Map<String, RecordReader> recordReaders;
    private Map<String, SequenceRecordReader> sequenceRecordReaders;
    private List<SubsetDetails> inputs;
    private List<SubsetDetails> outputs;
    private boolean collectMetaData;
    private MultiDataSetPreProcessor preProcessor;

    /* loaded from: input_file:org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIterator$AlignmentMode.class */
    /**
     * Controls how sequences are positioned along the time axis when sequence
     * readers are converted to arrays. The builder defaults to {@link #EQUAL_LENGTH}.
     */
    public enum AlignmentMode {
        /**
         * No longest-sequence scan is performed: the time-axis length is taken from the
         * first sequence of each reader, values are written starting at time step 0 and
         * no mask entries are cleared.
         */
        EQUAL_LENGTH,
        /**
         * Values are written starting at time step 0; when a mask is needed, the steps
         * after a sequence's last value are set to 0 in the mask.
         */
        ALIGN_START,
        /**
         * Values are shifted right by (longest sequence for that example across all
         * sequence readers - this sequence's length); when a mask is needed, the steps
         * before the shifted start are set to 0 in the mask.
         */
        ALIGN_END
    }

    /* loaded from: input_file:org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIterator$Builder.class */
    /**
     * Fluent builder for {@link RecordReaderMultiDataSetIterator}.
     *
     * <p>Readers are registered under a name; inputs and outputs then refer to a reader by
     * that name and select either all of its columns, an inclusive column range, or a
     * single column that is expanded to a one-hot vector.
     */
    public static class Builder {
        private int batchSize;
        private AlignmentMode alignmentMode = AlignmentMode.EQUAL_LENGTH;
        private Map<String, RecordReader> recordReaders = new HashMap<>();
        private Map<String, SequenceRecordReader> sequenceRecordReaders = new HashMap<>();
        private List<SubsetDetails> inputs = new ArrayList<>();
        private List<SubsetDetails> outputs = new ArrayList<>();

        /**
         * @param batchSize number of examples requested per call to {@code next()};
         *                  must be positive or {@link #build()} fails
         */
        public Builder(int batchSize) {
            this.batchSize = batchSize;
        }

        /**
         * Registers a (non-sequence) record reader under the given name, replacing any
         * record reader previously registered with that name.
         */
        public Builder addReader(String readerName, RecordReader recordReader) {
            this.recordReaders.put(readerName, recordReader);
            return this;
        }

        /**
         * Registers a sequence record reader under the given name, replacing any
         * sequence reader previously registered with that name.
         */
        public Builder addSequenceReader(String readerName, SequenceRecordReader sequenceRecordReader) {
            this.sequenceRecordReaders.put(readerName, sequenceRecordReader);
            return this;
        }

        /** Sets how sequences are aligned along the time axis (default: EQUAL_LENGTH). */
        public Builder sequenceAlignmentMode(AlignmentMode alignmentMode) {
            this.alignmentMode = alignmentMode;
            return this;
        }

        /** Adds an input made of every column of the named reader. */
        public Builder addInput(String readerName) {
            this.inputs.add(new SubsetDetails(readerName, true, false, -1, -1, -1));
            return this;
        }

        /**
         * Adds an input made of columns {@code columnFirst..columnLast} (both inclusive)
         * of the named reader.
         */
        public Builder addInput(String readerName, int columnFirst, int columnLast) {
            this.inputs.add(new SubsetDetails(readerName, false, false, -1, columnFirst, columnLast));
            return this;
        }

        /**
         * Adds an input that one-hot encodes the integer value found in {@code column}
         * of the named reader into a vector of {@code numClasses} entries.
         */
        public Builder addInputOneHot(String readerName, int column, int numClasses) {
            this.inputs.add(new SubsetDetails(readerName, false, true, numClasses, column, -1));
            return this;
        }

        /** Adds an output made of every column of the named reader. */
        public Builder addOutput(String readerName) {
            this.outputs.add(new SubsetDetails(readerName, true, false, -1, -1, -1));
            return this;
        }

        /**
         * Adds an output made of columns {@code columnFirst..columnLast} (both inclusive)
         * of the named reader.
         */
        public Builder addOutput(String readerName, int columnFirst, int columnLast) {
            this.outputs.add(new SubsetDetails(readerName, false, false, -1, columnFirst, columnLast));
            return this;
        }

        /**
         * Adds an output that one-hot encodes the integer value found in {@code column}
         * of the named reader into a vector of {@code numClasses} entries.
         */
        public Builder addOutputOneHot(String readerName, int column, int numClasses) {
            this.outputs.add(new SubsetDetails(readerName, false, true, numClasses, column, -1));
            return this;
        }

        /**
         * Validates the configuration and creates the iterator.
         *
         * @throws IllegalStateException if no reader was added, the batch size is not
         *         positive, no input/output was added, or an input/output refers to a
         *         reader name that was never registered
         */
        public RecordReaderMultiDataSetIterator build() {
            if (this.recordReaders.isEmpty() && this.sequenceRecordReaders.isEmpty()) {
                throw new IllegalStateException("Cannot construct RecordReaderMultiDataSetIterator with no readers");
            }
            if (this.batchSize <= 0) {
                throw new IllegalStateException("Cannot construct RecordReaderMultiDataSetIterator with batch size <= 0");
            }
            if (this.inputs.isEmpty() && this.outputs.isEmpty()) {
                throw new IllegalStateException("Cannot construct RecordReaderMultiDataSetIterator with no inputs/outputs");
            }
            requireKnownReaders(this.inputs, "input");
            requireKnownReaders(this.outputs, "output");
            return new RecordReaderMultiDataSetIterator(this);
        }

        /** Fails if any subset names a reader that is in neither reader map. */
        private void requireKnownReaders(List<SubsetDetails> subsets, String kind) {
            for (SubsetDetails subset : subsets) {
                String name = subset.readerName;
                if (!this.recordReaders.containsKey(name) && !this.sequenceRecordReaders.containsKey(name)) {
                    throw new IllegalStateException("Invalid " + kind + " name: \"" + name + "\" - no reader found with this name");
                }
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIterator$SubsetDetails.class */
    /**
     * Immutable description of which part of a named reader feeds one input or output.
     * The builder passes {@code -1} for the integer values that do not apply to a
     * given kind of subset.
     */
    public static class SubsetDetails {
        /** Name under which the source reader was registered. */
        private final String readerName;
        /** When true, every column of the reader is used. */
        private final boolean entireReader;
        /** When true, the column at {@code subsetStart} is expanded to a one-hot vector. */
        private final boolean oneHot;
        /** Length of the one-hot vector. */
        private final int oneHotNumClasses;
        /** First column of the subset (or the one-hot source column). */
        private final int subsetStart;
        /** Last column of the subset, inclusive. */
        private final int subsetEndInclusive;

        @ConstructorProperties({"readerName", "entireReader", "oneHot", "oneHotNumClasses", "subsetStart", "subsetEndInclusive"})
        public SubsetDetails(String readerName, boolean entireReader, boolean oneHot,
                             int oneHotNumClasses, int subsetStart, int subsetEndInclusive) {
            this.readerName = readerName;
            this.entireReader = entireReader;
            this.oneHot = oneHot;
            this.oneHotNumClasses = oneHotNumClasses;
            this.subsetStart = subsetStart;
            this.subsetEndInclusive = subsetEndInclusive;
        }
    }

    /**
     * Creates an iterator from a validated builder. Metadata collection starts disabled.
     *
     * <p>The reader maps and the input/output lists are copied, so changing the builder
     * after {@code build()} does not affect this iterator. The reader instances
     * themselves are shared, not copied.
     *
     * @param builder source of the batch size, alignment mode, readers and subsets
     */
    private RecordReaderMultiDataSetIterator(Builder builder) {
        this.batchSize = builder.batchSize;
        this.alignmentMode = builder.alignmentMode;
        this.recordReaders = new HashMap<>(builder.recordReaders);
        this.sequenceRecordReaders = new HashMap<>(builder.sequenceRecordReaders);
        this.inputs = new ArrayList<>(builder.inputs);
        this.outputs = new ArrayList<>(builder.outputs);
        this.collectMetaData = false;
    }

    /* JADX WARN: Can't rename method to resolve collision */
    /**
     * Returns the next multi data set, requesting the batch size configured on the builder.
     */
    @Override // java.util.Iterator
    public MultiDataSet next() {
        final int requested = this.batchSize;
        return next(requested);
    }

    /**
     * Not supported by this iterator.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // java.util.Iterator
    public void remove() {
        final String message = "Remove not supported";
        throw new UnsupportedOperationException(message);
    }

    /**
     * Reads up to {@code i} examples from every registered reader and converts them
     * into a single multi data set.
     *
     * <p>When metadata collection is enabled, one composable metadata map is kept per
     * example position and each reader stores its record's metadata in it under the
     * reader's name.
     *
     * @param i maximum number of examples to pull from each reader
     * @return the converted multi data set
     * @throws NoSuchElementException if {@link #hasNext()} is false
     */
    @Override // org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator
    public MultiDataSet next(int i) {
        if (!hasNext()) {
            throw new NoSuchElementException("No next elements");
        }
        final boolean withMeta = this.collectMetaData;
        Map<String, List<List<Writable>>> recordsByReader = new HashMap<>();
        Map<String, List<List<List<Writable>>>> sequencesByReader = new HashMap<>();
        List<RecordMetaDataComposableMap> exampleMeta = null;
        if (withMeta) {
            exampleMeta = new ArrayList<>();
        }

        // Plain record readers: one List<Writable> per example.
        for (Map.Entry<String, RecordReader> readerEntry : this.recordReaders.entrySet()) {
            final String readerName = readerEntry.getKey();
            final RecordReader reader = readerEntry.getValue();
            List<List<Writable>> batch = new ArrayList<>(i);
            int example = 0;
            while (example < i && reader.hasNext()) {
                List<Writable> values;
                if (withMeta) {
                    Record record = reader.nextRecord();
                    values = record.getRecord();
                    // First reader to reach this example position creates its metadata holder.
                    if (exampleMeta.size() <= example) {
                        exampleMeta.add(new RecordMetaDataComposableMap(new HashMap<String, RecordMetaData>()));
                    }
                    exampleMeta.get(example).getMeta().put(readerName, record.getMetaData());
                } else {
                    values = reader.next();
                }
                batch.add(values);
                example++;
            }
            recordsByReader.put(readerName, batch);
        }

        // Sequence readers: one List<List<Writable>> (time steps) per example.
        for (Map.Entry<String, SequenceRecordReader> readerEntry : this.sequenceRecordReaders.entrySet()) {
            final String readerName = readerEntry.getKey();
            final SequenceRecordReader reader = readerEntry.getValue();
            List<List<List<Writable>>> batch = new ArrayList<>(i);
            int example = 0;
            while (example < i && reader.hasNext()) {
                List<List<Writable>> steps;
                if (withMeta) {
                    SequenceRecord sequence = reader.nextSequence();
                    steps = sequence.getSequenceRecord();
                    if (exampleMeta.size() <= example) {
                        exampleMeta.add(new RecordMetaDataComposableMap(new HashMap<String, RecordMetaData>()));
                    }
                    exampleMeta.get(example).getMeta().put(readerName, sequence.getMetaData());
                } else {
                    steps = reader.sequenceRecord();
                }
                batch.add(steps);
                example++;
            }
            sequencesByReader.put(readerName, batch);
        }

        return nextMultiDataSet(recordsByReader, sequencesByReader, exampleMeta);
    }

    /**
     * Converts already-loaded reader values into a multi data set.
     *
     * <p>The number of examples is the smallest list size over all readers. Mask arrays
     * are only attached for features/labels when at least one sequence conversion
     * produced a mask; otherwise the corresponding mask array is {@code null}.
     *
     * @param map  record values per reader name (examples x columns)
     * @param map2 sequence values per reader name (examples x time steps x columns)
     * @param list per-example metadata, attached to the result only when metadata
     *             collection is enabled
     * @return the assembled (and, if a pre-processor is set, pre-processed) data set
     * @throws RuntimeException if both maps are empty
     */
    private MultiDataSet nextMultiDataSet(Map<String, List<List<Writable>>> map, Map<String, List<List<List<Writable>>>> map2, List<RecordMetaDataComposableMap> list) {
        // Every reader must contribute to every example, so use the smallest count.
        int exampleCount = Integer.MAX_VALUE;
        for (List<List<Writable>> records : map.values()) {
            exampleCount = Math.min(exampleCount, records.size());
        }
        for (List<List<List<Writable>>> sequences : map2.values()) {
            exampleCount = Math.min(exampleCount, sequences.size());
        }
        if (exampleCount == Integer.MAX_VALUE) {
            throw new RuntimeException("Error occurred during data set generation: no readers?");
        }

        // ALIGN_END needs, per example, the longest sequence across all sequence readers.
        int[] longestPerExample = null;
        if (this.alignmentMode == AlignmentMode.ALIGN_END) {
            longestPerExample = new int[exampleCount];
            for (List<List<List<Writable>>> sequences : map2.values()) {
                final int limit = Math.min(sequences.size(), exampleCount);
                for (int ex = 0; ex < limit; ex++) {
                    longestPerExample[ex] = Math.max(longestPerExample[ex], sequences.get(ex).size());
                }
            }
        }

        // Overall longest sequence; stays -1 for EQUAL_LENGTH (and when there are no sequences).
        int longestOverall = -1;
        if (this.alignmentMode != AlignmentMode.EQUAL_LENGTH) {
            for (List<List<List<Writable>>> sequences : map2.values()) {
                for (List<List<Writable>> sequence : sequences) {
                    longestOverall = Math.max(longestOverall, sequence.size());
                }
            }
        }

        final int inputCount = this.inputs.size();
        INDArray[] features = new INDArray[inputCount];
        INDArray[] featureMasks = new INDArray[inputCount];
        boolean anyFeatureMask = false;
        for (int idx = 0; idx < inputCount; idx++) {
            SubsetDetails details = this.inputs.get(idx);
            if (map.containsKey(details.readerName)) {
                features[idx] = convertWritables(map.get(details.readerName), exampleCount, details);
                continue;
            }
            Pair<INDArray, INDArray> converted = convertWritablesSequence(map2.get(details.readerName), exampleCount, longestOverall, details, longestPerExample);
            features[idx] = converted.getFirst();
            featureMasks[idx] = converted.getSecond();
            if (featureMasks[idx] != null) {
                anyFeatureMask = true;
            }
        }
        if (!anyFeatureMask) {
            featureMasks = null;
        }

        final int outputCount = this.outputs.size();
        INDArray[] labels = new INDArray[outputCount];
        INDArray[] labelMasks = new INDArray[outputCount];
        boolean anyLabelMask = false;
        for (int idx = 0; idx < outputCount; idx++) {
            SubsetDetails details = this.outputs.get(idx);
            if (map.containsKey(details.readerName)) {
                labels[idx] = convertWritables(map.get(details.readerName), exampleCount, details);
                continue;
            }
            Pair<INDArray, INDArray> converted = convertWritablesSequence(map2.get(details.readerName), exampleCount, longestOverall, details, longestPerExample);
            labels[idx] = converted.getFirst();
            labelMasks[idx] = converted.getSecond();
            if (labelMasks[idx] != null) {
                anyLabelMask = true;
            }
        }
        if (!anyLabelMask) {
            labelMasks = null;
        }

        org.nd4j.linalg.dataset.MultiDataSet result = new org.nd4j.linalg.dataset.MultiDataSet(features, labels, featureMasks, labelMasks);
        if (this.collectMetaData) {
            result.setExampleMetaData(list);
        }
        if (this.preProcessor != null) {
            this.preProcessor.preProcess(result);
        }
        return result;
    }

    /**
     * Converts the first {@code i} (non-sequence) records into one array with one
     * example per index along dimension 0.
     *
     * <p>The output shape is chosen from the first record: a lone NDArrayWritable
     * (entire reader, or a single-column subset) contributes its own shape with
     * dimension 0 replaced by {@code i}; otherwise the array is
     * {@code i x columns}, or {@code i x numClasses} for one-hot subsets.
     *
     * @param list          records, one {@code List<Writable>} per example
     * @param i             number of examples to convert
     * @param subsetDetails which columns to use and how
     * @return the populated array
     */
    private INDArray convertWritables(List<List<Writable>> list, int i, SubsetDetails subsetDetails) {
        final int firstCol = subsetDetails.subsetStart;
        final int lastCol = subsetDetails.subsetEndInclusive;

        // --- allocate ---
        INDArray out;
        if (subsetDetails.entireReader) {
            if (list.get(0).size() == 1 && (list.get(0).get(0) instanceof NDArrayWritable)) {
                int[] shape = ArrayUtils.clone(((NDArrayWritable) list.get(0).get(0)).get().shape());
                shape[0] = i;
                out = Nd4j.create(shape);
            } else {
                out = Nd4j.create(i, list.get(0).size());
            }
        } else if (subsetDetails.oneHot) {
            out = Nd4j.zeros(i, subsetDetails.oneHotNumClasses);
        } else if (firstCol == lastCol && (list.get(0).get(firstCol) instanceof NDArrayWritable)) {
            int[] shape = ArrayUtils.clone(((NDArrayWritable) list.get(0).get(firstCol)).get().shape());
            shape[0] = i;
            out = Nd4j.create(shape);
        } else {
            out = Nd4j.create(i, (lastCol - firstCol) + 1);
        }

        // --- fill ---
        for (int row = 0; row < i; row++) {
            List<Writable> record = list.get(row);
            if (subsetDetails.entireReader) {
                int col = 0;
                for (Writable value : record) {
                    writeCell(out, row, col, value);
                    col++;
                }
                continue;
            }
            if (subsetDetails.oneHot) {
                out.putScalar(row, record.get(firstCol).toInt(), 1.0d);
                continue;
            }
            if (firstCol == lastCol && (record.get(firstCol) instanceof NDArrayWritable)) {
                putExample(out, ((NDArrayWritable) record.get(firstCol)).get(), row);
                continue;
            }
            // Column range: skip to the first selected column, then copy through the last one.
            Iterator<Writable> values = record.iterator();
            for (int skipped = 0; skipped < firstCol; skipped++) {
                values.next();
            }
            int col = 0;
            for (int source = firstCol; source <= lastCol; source++) {
                writeCell(out, row, col, values.next());
                col++;
            }
        }
        return out;
    }

    /**
     * Writes one writable into {@code target[row, column]} as a double; a writable that
     * rejects {@code toDouble()} is written as a whole example via {@link #putExample}
     * if it is an NDArrayWritable, otherwise the exception is rethrown.
     */
    private void writeCell(INDArray target, int row, int column, Writable value) {
        try {
            target.putScalar(row, column, value.toDouble());
        } catch (UnsupportedOperationException e) {
            if (!(value instanceof NDArrayWritable)) {
                throw e;
            }
            putExample(target, ((NDArrayWritable) value).get(), row);
        }
    }

    /**
     * Copies {@code iNDArray2} into example slot {@code i} of {@code iNDArray}, i.e. the
     * slice selected by a point index on dimension 0 and "all" on every other dimension.
     *
     * @throws RuntimeException if the target's rank is not 2, 3 or 4
     */
    private void putExample(INDArray iNDArray, INDArray iNDArray2, int i) {
        final int rank = iNDArray.rank();
        if (rank < 2 || rank > 4) {
            throw new RuntimeException("Unexpected rank: " + iNDArray.rank());
        }
        INDArrayIndex[] slot = new INDArrayIndex[rank];
        slot[0] = NDArrayIndex.point(i);
        for (int dim = 1; dim < rank; dim++) {
            slot[dim] = NDArrayIndex.all();
        }
        iNDArray.put(slot, iNDArray2);
    }

    /**
     * Converts the first {@code i} sequences into a 3-d array of shape
     * {@code [i, size, timeSteps]} ('f' order) plus an optional mask of shape
     * {@code [i, timeSteps]}.
     *
     * <p>A mask (initialised to ones) is created only when some sequence in
     * {@code list} is shorter than the time-axis length. Mask entries are cleared to 0
     * for the padded steps under ALIGN_START (after the sequence) and ALIGN_END (before
     * the shifted start); no entries are cleared for EQUAL_LENGTH.
     *
     * @param list          sequences, indexed example / time step / column
     * @param i             number of examples to convert
     * @param i2            time-axis length, or -1 to use the length of the first sequence
     * @param subsetDetails which columns to use and how
     * @param iArr          per-example alignment length, read only when the alignment
     *                      mode is neither ALIGN_START nor EQUAL_LENGTH
     * @return pair of (values, mask); the mask is {@code null} when none is needed
     */
    private Pair<INDArray, INDArray> convertWritablesSequence(List<List<List<Writable>>> list, int i, int i2, SubsetDetails subsetDetails, int[] iArr) {
        if (i2 == -1) {
            i2 = list.get(0).size();
        }

        // Size of dimension 1 depends on the kind of subset.
        final int size;
        if (subsetDetails.entireReader) {
            size = list.get(0).iterator().next().size();
        } else if (subsetDetails.oneHot) {
            size = subsetDetails.oneHotNumClasses;
        } else {
            size = (subsetDetails.subsetEndInclusive - subsetDetails.subsetStart) + 1;
        }
        INDArray values = Nd4j.create(new int[]{i, size, i2}, 'f');

        boolean needMask = false;
        for (List<List<Writable>> sequence : list) {
            if (sequence.size() < i2) {
                needMask = true;
            }
        }
        INDArray mask = needMask ? Nd4j.ones(i, i2) : null;

        for (int example = 0; example < i; example++) {
            List<List<Writable>> sequence = list.get(example);
            // NOTE(review): under ALIGN_END the data ends at iArr[example]; steps from there up
            // to i2 are not cleared in the mask - confirm this is intended when iArr[example] < i2.
            int startOffset = (this.alignmentMode == AlignmentMode.ALIGN_START || this.alignmentMode == AlignmentMode.EQUAL_LENGTH) ? 0 : iArr[example] - sequence.size();
            int stepsSeen = 0;
            for (List<Writable> step : sequence) {
                int t = startOffset + stepsSeen;
                stepsSeen++;
                if (subsetDetails.entireReader) {
                    int col = 0;
                    for (Writable writable : step) {
                        try {
                            values.putScalar(example, col, t, writable.toDouble());
                        } catch (UnsupportedOperationException e) {
                            if (!(writable instanceof NDArrayWritable)) {
                                throw e;
                            }
                            values.get(NDArrayIndex.point(example), NDArrayIndex.all(), NDArrayIndex.point(t)).putRow(0, ((NDArrayWritable) writable).get());
                        }
                        col++;
                    }
                } else if (subsetDetails.oneHot) {
                    Writable classIndex = step.get(subsetDetails.subsetStart);
                    values.putScalar(example, classIndex.toInt(), t, 1.0d);
                } else {
                    Iterator<Writable> columns = step.iterator();
                    for (int skipped = 0; skipped < subsetDetails.subsetStart; skipped++) {
                        columns.next();
                    }
                    int col = 0;
                    for (int source = subsetDetails.subsetStart; source <= subsetDetails.subsetEndInclusive; source++) {
                        Writable next = columns.next();
                        try {
                            // col advances before toDouble() is evaluated, matching the original ordering.
                            values.putScalar(example, col++, t, next.toDouble());
                        } catch (UnsupportedOperationException e2) {
                            if (!(next instanceof NDArrayWritable)) {
                                throw e2;
                            }
                            values.get(NDArrayIndex.point(example), NDArrayIndex.all(), NDArrayIndex.point(t)).putRow(0, ((NDArrayWritable) next).get().get(NDArrayIndex.all(), NDArrayIndex.interval(subsetDetails.subsetStart, subsetDetails.subsetEndInclusive + 1)));
                        }
                    }
                }
            }
            if (needMask) {
                if (this.alignmentMode == AlignmentMode.ALIGN_END) {
                    // Leading padding before the shifted sequence.
                    for (int t = 0; t < startOffset; t++) {
                        mask.putScalar(example, t, 0.0d);
                    }
                }
                if (this.alignmentMode == AlignmentMode.ALIGN_START) {
                    // Trailing padding after the sequence's last step.
                    for (int t = stepsSeen; t < i2; t++) {
                        mask.putScalar(example, t, 0.0d);
                    }
                }
            }
        }
        return new Pair<>(values, mask);
    }

    /**
     * Sets the pre-processor applied to every multi data set this iterator builds;
     * {@code null} disables pre-processing.
     */
    @Override // org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator
    public void setPreProcessor(MultiDataSetPreProcessor preProcessor) {
        this.preProcessor = preProcessor;
    }

    /**
     * Always reports {@code true}; {@link #reset()} resets every underlying reader.
     */
    @Override // org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator
    public boolean resetSupported() {
        return true;
    }

    /**
     * Always reports {@code true}.
     * NOTE(review): presumably signals that this iterator may be wrapped for
     * asynchronous prefetch - confirm against the MultiDataSetIterator contract.
     */
    @Override // org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator
    public boolean asyncSupported() {
        return true;
    }

    /**
     * Resets every record reader, then every sequence record reader.
     */
    @Override // org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator
    public void reset() {
        for (RecordReader reader : this.recordReaders.values()) {
            reader.reset();
        }
        for (SequenceRecordReader reader : this.sequenceRecordReaders.values()) {
            reader.reset();
        }
    }

    /**
     * Reports whether another example can be read from all readers.
     *
     * @return {@code false} as soon as any record or sequence reader is exhausted,
     *         {@code true} otherwise
     */
    @Override // java.util.Iterator
    public boolean hasNext() {
        for (RecordReader reader : this.recordReaders.values()) {
            if (!reader.hasNext()) {
                return false;
            }
        }
        for (SequenceRecordReader reader : this.sequenceRecordReaders.values()) {
            if (!reader.hasNext()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Loads the single example described by the given metadata.
     *
     * @param recordMetaData metadata of the example to load
     * @return a multi data set built from that one example
     * @throws IOException if the list-based overload fails to load it
     */
    public MultiDataSet loadFromMetaData(RecordMetaData recordMetaData) throws IOException {
        List<RecordMetaData> single = Collections.singletonList(recordMetaData);
        return loadFromMetaData(single);
    }

    /**
     * Loads the examples described by the given metadata and converts them into a
     * multi data set.
     *
     * <p>Every element of {@code list} must be a {@link RecordMetaDataComposableMap};
     * for each registered reader, the entry stored under the reader's name is handed
     * to that reader to reload the record or sequence.
     *
     * <p>When metadata collection is enabled, the requested metadata entries are
     * attached to the result as its example metadata (previously an empty list was
     * attached).
     *
     * @param list composable metadata, one element per example to load
     * @return the converted multi data set
     * @throws IOException        if a reader fails to load from its metadata
     * @throws ClassCastException if an element is not a RecordMetaDataComposableMap
     */
    public MultiDataSet loadFromMetaData(List<RecordMetaData> list) throws IOException {
        Map<String, List<List<Writable>>> recordsByReader = new HashMap<>();
        Map<String, List<List<List<Writable>>>> sequencesByReader = new HashMap<>();

        List<RecordMetaDataComposableMap> exampleMeta = null;
        if (this.collectMetaData) {
            exampleMeta = new ArrayList<>(list.size());
            for (RecordMetaData meta : list) {
                exampleMeta.add((RecordMetaDataComposableMap) meta);
            }
        }

        for (Map.Entry<String, RecordReader> entry : this.recordReaders.entrySet()) {
            // Pick out this reader's part of every composable metadata entry.
            List<RecordMetaData> readerMeta = new ArrayList<>(list.size());
            for (RecordMetaData meta : list) {
                readerMeta.add(((RecordMetaDataComposableMap) meta).getMeta().get(entry.getKey()));
            }
            List<Record> loaded = entry.getValue().loadFromMetaData(readerMeta);
            List<List<Writable>> records = new ArrayList<>(list.size());
            for (Record record : loaded) {
                records.add(record.getRecord());
            }
            recordsByReader.put(entry.getKey(), records);
        }

        for (Map.Entry<String, SequenceRecordReader> entry : this.sequenceRecordReaders.entrySet()) {
            List<RecordMetaData> readerMeta = new ArrayList<>(list.size());
            for (RecordMetaData meta : list) {
                readerMeta.add(((RecordMetaDataComposableMap) meta).getMeta().get(entry.getKey()));
            }
            List<SequenceRecord> loaded = entry.getValue().loadSequenceFromMetaData(readerMeta);
            List<List<List<Writable>>> sequences = new ArrayList<>(list.size());
            for (SequenceRecord sequence : loaded) {
                sequences.add(sequence.getSequenceRecord());
            }
            sequencesByReader.put(entry.getKey(), sequences);
        }

        return nextMultiDataSet(recordsByReader, sequencesByReader, exampleMeta);
    }

    /**
     * @return whether per-example metadata is collected while reading and attached to
     *         the multi data sets this iterator builds
     */
    public boolean isCollectMetaData() {
        return this.collectMetaData;
    }

    /**
     * Enables or disables collection of per-example metadata.
     *
     * @param collectMetaData {@code true} to collect metadata from the readers and
     *                        attach it to each multi data set
     */
    public void setCollectMetaData(boolean collectMetaData) {
        this.collectMetaData = collectMetaData;
    }
}
