package org.deeplearning4j.spark.datavec.iterator;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.datavec.api.writable.Writable;
import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import scala.Tuple2;

/* loaded from: input_file:org/deeplearning4j/spark/datavec/iterator/IteratorUtils.class */
public class IteratorUtils {

    /**
     * Merges every {@link DataVecRecord} that shares one join key into a single {@link DataVecRecords}.
     * Each record is placed in the slot given by its reader index. Records with a non-null
     * {@code getRecord()} go into the (non-sequence) record list; all others go into the
     * sequence-record list.
     * <p>
     * Throws {@link IllegalStateException} if any expected slot is still empty after all records for
     * the key have been seen, or if a record refers to a reader index that is out of range for its kind.
     */
    private static class CombineFunction implements Function<Tuple2<Writable, Iterable<DataVecRecord>>, DataVecRecords> {
        private final int expNumRecords;
        private final int expNumSeqRecords;

        public CombineFunction(int expNumRecords, int expNumSeqRecords) {
            this.expNumRecords = expNumRecords;
            this.expNumSeqRecords = expNumSeqRecords;
        }

        @Override
        @SuppressWarnings("unchecked") // generic arrays (List<...>[]) cannot be instantiated directly
        public DataVecRecords call(Tuple2<Writable, Iterable<DataVecRecord>> keyAndRecords) throws Exception {
            List<Writable>[] records =
                    expNumRecords > 0 ? (List<Writable>[]) new List[expNumRecords] : null;
            List<List<Writable>>[] seqRecords =
                    expNumSeqRecords > 0 ? (List<List<Writable>>[]) new List[expNumSeqRecords] : null;

            for (DataVecRecord r : keyAndRecords._2()) {
                int idx = r.getReaderIdx();
                if (r.getRecord() != null) {
                    checkReaderIndex(idx, records, false, keyAndRecords._1());
                    records[idx] = r.getRecord();
                } else {
                    // Anything without a plain record is routed to the sequence slots (as before);
                    // a null sequence record here is reported by the completeness check below.
                    checkReaderIndex(idx, seqRecords, true, keyAndRecords._1());
                    seqRecords[idx] = r.getSeqRecord();
                }
            }

            // Every configured reader must have contributed exactly one record for this key
            if (records != null) {
                for (int i = 0; i < records.length; i++) {
                    if (records[i] == null) {
                        throw new IllegalStateException("Encountered null records for input index " + i);
                    }
                }
            }
            if (seqRecords != null) {
                for (int i = 0; i < seqRecords.length; i++) {
                    if (seqRecords[i] == null) {
                        throw new IllegalStateException("Encountered null sequence records for input index " + i);
                    }
                }
            }
            return new DataVecRecords(records == null ? null : Arrays.asList(records),
                    seqRecords == null ? null : Arrays.asList(seqRecords));
        }

        /**
         * Throws a descriptive {@link IllegalStateException} if {@code idx} is not a valid slot of
         * {@code slots}. A null array means no reader of that kind was configured.
         */
        private static void checkReaderIndex(int idx, Object[] slots, boolean seq, Writable key) {
            int n = slots == null ? 0 : slots.length;
            if (idx < 0 || idx >= n) {
                throw new IllegalStateException("Invalid " + (seq ? "sequence " : "") + "reader index " + idx
                        + " for key " + key + ": expected " + n + (seq ? " sequence" : "")
                        + " record reader(s)");
            }
        }
    }

    /**
     * Spark filter that keeps a grouped key only if it received records from exactly
     * {@code expNumRec} distinct record-reader indices and {@code expNumSeqRec} distinct
     * sequence-reader indices. Used to drop keys that are missing from one or more input RDDs.
     */
    private static class FilterMissingFn implements Function<Tuple2<Writable, Iterable<DataVecRecord>>, Boolean> {
        private final int expNumRec;
        private final int expNumSeqRec;

        private FilterMissingFn(int expNumRec, int expNumSeqRec) {
            this.expNumRec = expNumRec;
            this.expNumSeqRec = expNumSeqRec;
        }

        @Override
        public Boolean call(Tuple2<Writable, Iterable<DataVecRecord>> keyAndRecords) throws Exception {
            // Per-call sets: no shared/thread-local state to initialise, leak or leave dirty on exceptions
            Set<Integer> recIdxs = new HashSet<>();
            Set<Integer> seqRecIdxs = new HashSet<>();
            for (DataVecRecord r : keyAndRecords._2()) {
                if (r.getRecord() != null) {
                    recIdxs.add(r.getReaderIdx());
                } else if (r.getSeqRecord() != null) {
                    seqRecIdxs.add(r.getReaderIdx());
                }
            }
            return recIdxs.size() == expNumRec && seqRecIdxs.size() == expNumSeqRec;
        }
    }

    /**
     * Keys each non-sequence record by the writable at {@code keyIndex}, tagging it with its reader index.
     */
    private static class MapToPairFn implements PairFunction<List<Writable>, Writable, DataVecRecord> {
        private final int readerIdx;
        private final int keyIndex;

        public MapToPairFn(int readerIdx, int keyIndex) {
            this.readerIdx = readerIdx;
            this.keyIndex = keyIndex;
        }

        @Override
        public Tuple2<Writable, DataVecRecord> call(List<Writable> record) throws Exception {
            return new Tuple2<>(record.get(keyIndex), new DataVecRecord(readerIdx, record, null));
        }
    }

    /**
     * Keys each sequence by the writable at {@code keyIndex} of its FIRST time step, tagging it with its
     * reader index. Empty sequences are rejected because they have no key.
     */
    private static class MapToPairSeqFn implements PairFunction<List<List<Writable>>, Writable, DataVecRecord> {
        private final int readerIdx;
        private final int keyIndex;

        public MapToPairSeqFn(int readerIdx, int keyIndex) {
            this.readerIdx = readerIdx;
            this.keyIndex = keyIndex;
        }

        @Override
        public Tuple2<Writable, DataVecRecord> call(List<List<Writable>> sequence) throws Exception {
            if (sequence.isEmpty()) {
                throw new IllegalStateException("Sequence of length 0 encountered");
            }
            return new Tuple2<>(sequence.get(0).get(keyIndex), new DataVecRecord(readerIdx, null, sequence));
        }
    }

    /**
     * Converts an RDD of single (non-sequence) records into MultiDataSets using the given iterator's
     * configuration.
     *
     * @param rdd      RDD of records
     * @param iterator iterator whose readers must all be {@code SparkSourceDummyReader}s, with at most one
     *                 record reader and no sequence readers
     * @return RDD of MultiDataSets, one per input record
     * @throws IllegalStateException if the iterator's readers do not match the above requirements
     */
    public static JavaRDD<MultiDataSet> mapRRMDSI(JavaRDD<List<Writable>> rdd, RecordReaderMultiDataSetIterator iterator) {
        checkIterator(iterator, 1, 0);
        return mapRRMDSIRecords(rdd.map(new Function<List<Writable>, DataVecRecords>() {
            @Override
            public DataVecRecords call(List<Writable> record) throws Exception {
                return new DataVecRecords(Collections.singletonList(record), null);
            }
        }), iterator);
    }

    /**
     * Converts an RDD of sequences into MultiDataSets using the given iterator's configuration.
     *
     * @param rdd      RDD of sequences
     * @param iterator iterator whose readers must all be {@code SparkSourceDummySeqReader}s, with no record
     *                 readers and at most one sequence reader
     * @return RDD of MultiDataSets, one per input sequence
     * @throws IllegalStateException if the iterator's readers do not match the above requirements
     */
    public static JavaRDD<MultiDataSet> mapRRMDSISeq(JavaRDD<List<List<Writable>>> rdd, RecordReaderMultiDataSetIterator iterator) {
        checkIterator(iterator, 0, 1);
        return mapRRMDSIRecords(rdd.map(new Function<List<List<Writable>>, DataVecRecords>() {
            @Override
            public DataVecRecords call(List<List<Writable>> sequence) throws Exception {
                return new DataVecRecords(null, Collections.singletonList(sequence));
            }
        }), iterator);
    }

    /**
     * Joins several record and/or sequence RDDs on a key column, then converts each joined group into a
     * MultiDataSet using the given iterator's configuration. The i-th RDD in {@code rdds} (and
     * {@code seqRdds}) corresponds to reader index i.
     *
     * @param rdds              RDDs of non-sequence records; may be null
     * @param seqRdds           RDDs of sequences; may be null
     * @param rddsKeyColumns    for each entry of {@code rdds}, the index of the writable used as the join key;
     *                          must have the same length as {@code rdds}
     * @param seqRddsKeyColumns for each entry of {@code seqRdds}, the index of the writable (in the first
     *                          time step) used as the join key; must have the same length as {@code seqRdds}
     * @param filterMissing     if true, keys that are not present in every RDD are dropped; if false, such
     *                          keys cause an {@link IllegalStateException} when the groups are combined
     * @param iterator          iterator configured with Spark dummy readers
     * @return RDD of MultiDataSets, one per distinct key
     * @throws IllegalArgumentException if no RDDs at all are provided
     * @throws IllegalStateException    if key-column arrays do not match the RDD lists, or the iterator's
     *                                  readers are not valid for this input
     */
    public static JavaRDD<MultiDataSet> mapRRMDSI(List<JavaRDD<List<Writable>>> rdds, List<JavaRDD<List<List<Writable>>>> seqRdds,
                                                  int[] rddsKeyColumns, int[] seqRddsKeyColumns, boolean filterMissing,
                                                  RecordReaderMultiDataSetIterator iterator) {
        checkIterator(iterator, rdds == null ? 0 : rdds.size(), seqRdds == null ? 0 : seqRdds.size());
        assertNullOrSameLength(rdds, rddsKeyColumns, false);
        assertNullOrSameLength(seqRdds, seqRddsKeyColumns, true);
        if ((rdds == null || rdds.isEmpty()) && (seqRdds == null || seqRdds.isEmpty())) {
            throw new IllegalArgumentException("No input RDDs provided: at least one RDD of List<Writable> or "
                    + "List<List<Writable>> is required");
        }

        // Key every record by its join column, tagged with its reader index, and union everything
        JavaPairRDD<Writable, DataVecRecord> allPairs = null;
        if (rdds != null) {
            for (int i = 0; i < rdds.size(); i++) {
                JavaPairRDD<Writable, DataVecRecord> pairs = rdds.get(i).mapToPair(new MapToPairFn(i, rddsKeyColumns[i]));
                allPairs = allPairs == null ? pairs : allPairs.union(pairs);
            }
        }
        if (seqRdds != null) {
            for (int i = 0; i < seqRdds.size(); i++) {
                JavaPairRDD<Writable, DataVecRecord> pairs = seqRdds.get(i).mapToPair(new MapToPairSeqFn(i, seqRddsKeyColumns[i]));
                allPairs = allPairs == null ? pairs : allPairs.union(pairs);
            }
        }

        int numRecordReaders = rddsKeyColumns == null ? 0 : rddsKeyColumns.length;
        int numSeqReaders = seqRddsKeyColumns == null ? 0 : seqRddsKeyColumns.length;
        JavaPairRDD<Writable, Iterable<DataVecRecord>> grouped = allPairs.groupByKey();
        if (filterMissing) {
            grouped = grouped.filter(new FilterMissingFn(numRecordReaders, numSeqReaders));
        }
        return mapRRMDSIRecords(grouped.map(new CombineFunction(numRecordReaders, numSeqReaders)), iterator);
    }

    /**
     * Checks that a list of RDDs and its key-column array are either both absent or the same length.
     * A null list with a null or empty array is accepted.
     *
     * @throws IllegalStateException if the two do not match
     */
    private static void assertNullOrSameLength(List<?> rdds, int[] keyColumns, boolean isSeq) {
        String what = isSeq ? "sequence RDDs (List<List<Writable>>)" : "RDDs (List<Writable>)";
        if (rdds != null && keyColumns == null) {
            throw new IllegalStateException("Key column indices are required for the " + rdds.size() + " provided " + what);
        }
        if (rdds == null && keyColumns != null && keyColumns.length > 0) {
            throw new IllegalStateException(keyColumns.length + " key column indices provided, but no " + what + " were provided");
        }
        if (rdds != null && rdds.size() != keyColumns.length) {
            throw new IllegalStateException("Number of " + what + " (" + rdds.size()
                    + ") does not match number of key column indices (" + keyColumns.length + ")");
        }
    }

    /**
     * Converts already-grouped records into MultiDataSets. The conversion is done by
     * {@code RRMDSIFunction}, which is constructed from {@code iterator} and defined elsewhere.
     * NOTE(review): presumably it applies the iterator's input/output configuration — confirm in RRMDSIFunction.
     *
     * @param rdd      RDD of grouped records
     * @param iterator iterator passed to {@code RRMDSIFunction}
     * @return RDD of MultiDataSets, one per input element
     */
    public static JavaRDD<MultiDataSet> mapRRMDSIRecords(JavaRDD<DataVecRecords> rdd, RecordReaderMultiDataSetIterator iterator) {
        return rdd.map(new RRMDSIFunction(iterator));
    }

    /**
     * Validates that the iterator does not declare more readers than RDDs were provided, and that every
     * reader is the corresponding Spark dummy reader type.
     *
     * @throws IllegalStateException if the reader counts or reader types are invalid
     */
    private static void checkIterator(RecordReaderMultiDataSetIterator iterator, int numRecordRdds, int numSeqRdds) {
        Map<String, ?> rrs = iterator.getRecordReaders();
        Map<String, ?> srrs = iterator.getSequenceRecordReaders();

        if (rrs != null && rrs.size() > numRecordRdds) {
            throw new IllegalStateException("Invalid state: iterator has " + rrs.size() + " readers but " + numRecordRdds
                    + " RDDs of List<Writable> were provided");
        }
        if (srrs != null && srrs.size() > numSeqRdds) {
            throw new IllegalStateException("Invalid state: iterator has " + srrs.size() + " sequence readers but " + numSeqRdds
                    + " RDDs of sequences - List<List<Writable>> were provided");
        }

        if (rrs != null) {
            for (Map.Entry<String, ?> e : rrs.entrySet()) {
                Object reader = e.getValue();
                if (!(reader instanceof SparkSourceDummyReader)) {
                    // Report the reader's type (previously this printed the key's class, always String)
                    throw new IllegalStateException("Invalid state: expected SparkSourceDummyReader for reader with name \""
                            + e.getKey() + "\", but got reader type: " + (reader == null ? null : reader.getClass()));
                }
            }
        }
        if (srrs != null) {
            for (Map.Entry<String, ?> e : srrs.entrySet()) {
                Object reader = e.getValue();
                if (!(reader instanceof SparkSourceDummySeqReader)) {
                    throw new IllegalStateException("Invalid state: expected SparkSourceDummySeqReader for sequence reader with name \""
                            + e.getKey() + "\", but got reader type: " + (reader == null ? null : reader.getClass()));
                }
            }
        }
    }
}
