package org.nd4j.linalg.dataset;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.EOFException;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSetUtil;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.primitives.Pair;

/* loaded from: input_file:org/nd4j/linalg/dataset/MultiDataSet.class */
/**
 * Default {@link org.nd4j.linalg.dataset.api.MultiDataSet} implementation: holds any number of
 * feature (input) arrays and label (output) arrays, optional mask arrays for each, and optional
 * per-example metadata.
 *
 * <p>Arrays are stored by reference: getters return the internal arrays and setters store the
 * given arrays without copying. Use {@link #copy()} for an independent instance. Individual mask
 * entries may be {@code null}, meaning "no mask for that array".
 */
public class MultiDataSet implements org.nd4j.linalg.dataset.api.MultiDataSet {

    /**
     * Per-thread sentinel written by {@link #save(OutputStream)} in place of a {@code null} mask,
     * so the number of serialized mask arrays always matches the header counts. On load, any mask
     * equal to it is turned back into {@code null}; a genuine mask equal to the sentinel (a
     * single value of -1) therefore cannot be round-tripped. Always access it through
     * {@link #emptyMaskPlaceholder()} so it is initialised on every thread.
     */
    private static final ThreadLocal<INDArray> EMPTY_MASK_ARRAY_PLACEHOLDER = new ThreadLocal<>();

    private INDArray[] features;
    private INDArray[] labels;
    private INDArray[] featuresMaskArrays;
    private INDArray[] labelsMaskArrays;
    private List<Serializable> exampleMetaData;

    /** Creates an empty instance; all arrays and the metadata are {@code null}. */
    public MultiDataSet() {
    }

    /**
     * Creates a single-input, single-output instance without masks.
     *
     * @param features the feature array, or {@code null}
     * @param labels   the label array, or {@code null}
     */
    public MultiDataSet(INDArray features, INDArray labels) {
        this(features, labels, (INDArray) null, (INDArray) null);
    }

    /**
     * Creates a single-input, single-output instance. Each non-null argument is wrapped into a
     * one-element array; {@code null} arguments stay {@code null}.
     *
     * @param features     the feature array, or {@code null}
     * @param labels       the label array, or {@code null}
     * @param featuresMask the feature mask, or {@code null}
     * @param labelsMask   the label mask, or {@code null}
     */
    public MultiDataSet(INDArray features, INDArray labels, INDArray featuresMask, INDArray labelsMask) {
        this(singletonOrNull(features), singletonOrNull(labels), singletonOrNull(featuresMask),
                singletonOrNull(labelsMask));
    }

    /**
     * Creates an instance with the given feature and label arrays and no masks.
     *
     * @param features the feature arrays, or {@code null}
     * @param labels   the label arrays, or {@code null}
     */
    public MultiDataSet(INDArray[] features, INDArray[] labels) {
        this(features, labels, (INDArray[]) null, (INDArray[]) null);
    }

    /**
     * Creates an instance from the given arrays (stored by reference).
     *
     * @param features           the feature arrays, or {@code null}
     * @param labels             the label arrays, or {@code null}
     * @param featuresMaskArrays feature masks; when both are non-null, must have the same length as
     *                           {@code features}
     * @param labelsMaskArrays   label masks; when both are non-null, must have the same length as
     *                           {@code labels}
     * @throws IllegalArgumentException if an array and its mask array differ in length
     */
    public MultiDataSet(INDArray[] features, INDArray[] labels, INDArray[] featuresMaskArrays,
                        INDArray[] labelsMaskArrays) {
        if (features != null && featuresMaskArrays != null && features.length != featuresMaskArrays.length) {
            throw new IllegalArgumentException("Invalid features / features mask arrays combination: features and features mask arrays must not be different lengths");
        }
        if (labels != null && labelsMaskArrays != null && labels.length != labelsMaskArrays.length) {
            throw new IllegalArgumentException("Invalid labels / labels mask arrays combination: labels and labels mask arrays must not be different lengths");
        }
        this.features = features;
        this.labels = labels;
        this.featuresMaskArrays = featuresMaskArrays;
        this.labelsMaskArrays = labelsMaskArrays;
        // NOTE(review): presumably flushes any queued ops of the current executioner - confirm.
        Nd4j.getExecutioner().commit();
    }

    /** Wraps a single array into a one-element array, or returns {@code null} for {@code null}. */
    private static INDArray[] singletonOrNull(INDArray array) {
        return array == null ? null : new INDArray[]{array};
    }

    @Override
    public List<Serializable> getExampleMetaData() {
        return this.exampleMetaData;
    }

    /**
     * Returns the metadata list cast to the requested element type. The cast is unchecked:
     * a mismatched {@code cls} only fails later, when an element is read.
     */
    @Override
    @SuppressWarnings("unchecked")
    public <T extends Serializable> List<T> getExampleMetaData(Class<T> cls) {
        return (List<T>) this.exampleMetaData;
    }

    /** Stores the given metadata list by reference. */
    @Override
    @SuppressWarnings("unchecked")
    public void setExampleMetaData(List<? extends Serializable> list) {
        // Safe as long as the list is only read through this class's getters.
        this.exampleMetaData = (List<Serializable>) list;
    }

    @Override
    public int numFeatureArrays() {
        return lengthOf(this.features);
    }

    @Override
    public int numLabelsArrays() {
        return lengthOf(this.labels);
    }

    @Override
    public INDArray[] getFeatures() {
        return this.features;
    }

    @Override
    public INDArray getFeatures(int i) {
        return this.features[i];
    }

    @Override
    public void setFeatures(INDArray[] features) {
        this.features = features;
    }

    @Override
    public void setFeatures(int i, INDArray array) {
        this.features[i] = array;
    }

    @Override
    public INDArray[] getLabels() {
        return this.labels;
    }

    @Override
    public INDArray getLabels(int i) {
        return this.labels[i];
    }

    @Override
    public void setLabels(INDArray[] labels) {
        this.labels = labels;
    }

    @Override
    public void setLabels(int i, INDArray array) {
        this.labels[i] = array;
    }

    /** Returns {@code true} if at least one feature or label mask entry is non-null. */
    @Override
    public boolean hasMaskArrays() {
        return !nullOrEmpty(this.featuresMaskArrays) || !nullOrEmpty(this.labelsMaskArrays);
    }

    @Override
    public INDArray[] getFeaturesMaskArrays() {
        return this.featuresMaskArrays;
    }

    /** Returns feature mask {@code i}, or {@code null} if no feature mask arrays are set. */
    @Override
    public INDArray getFeaturesMaskArray(int i) {
        return this.featuresMaskArrays == null ? null : this.featuresMaskArrays[i];
    }

    @Override
    public void setFeaturesMaskArrays(INDArray[] masks) {
        this.featuresMaskArrays = masks;
    }

    @Override
    public void setFeaturesMaskArray(int i, INDArray mask) {
        this.featuresMaskArrays[i] = mask;
    }

    @Override
    public INDArray[] getLabelsMaskArrays() {
        return this.labelsMaskArrays;
    }

    /** Returns label mask {@code i}, or {@code null} if no label mask arrays are set. */
    @Override
    public INDArray getLabelsMaskArray(int i) {
        return this.labelsMaskArrays == null ? null : this.labelsMaskArrays[i];
    }

    @Override
    public void setLabelsMaskArray(INDArray[] masks) {
        this.labelsMaskArrays = masks;
    }

    @Override
    public void setLabelsMaskArray(int i, INDArray mask) {
        this.labelsMaskArrays[i] = mask;
    }

    /**
     * Serializes this dataset to {@code outputStream}.
     *
     * <p>Format: four ints (feature, label, feature-mask and label-mask array counts), then the
     * arrays in that order via {@link Nd4j#write}. Null mask entries are written as the
     * placeholder array. If metadata is present and non-empty, an int {@code 1} follows, then the
     * Java-serialized metadata list; otherwise the stream ends after the arrays.
     *
     * <p>The given stream is closed when this method returns.
     *
     * @throws IOException on write failure
     */
    @Override
    public void save(OutputStream outputStream) throws IOException {
        try (DataOutputStream dos = new DataOutputStream(new BufferedOutputStream(outputStream))) {
            dos.writeInt(lengthOf(this.features));
            dos.writeInt(lengthOf(this.labels));
            dos.writeInt(lengthOf(this.featuresMaskArrays));
            dos.writeInt(lengthOf(this.labelsMaskArrays));
            saveINDArrays(this.features, dos, false);
            saveINDArrays(this.labels, dos, false);
            saveINDArrays(this.featuresMaskArrays, dos, true);
            saveINDArrays(this.labelsMaskArrays, dos, true);
            if (this.exampleMetaData != null && !this.exampleMetaData.isEmpty()) {
                dos.writeInt(1);
                ObjectOutputStream oos = new ObjectOutputStream(dos);
                oos.writeObject(this.exampleMetaData);
                oos.flush();
            }
        }
    }

    /** Writes every array; for masks, {@code null} entries are replaced by the placeholder. */
    private void saveINDArrays(INDArray[] arrays, DataOutputStream dos, boolean isMask) throws IOException {
        if (arrays == null) {
            return;
        }
        for (INDArray array : arrays) {
            Nd4j.write(isMask && array == null ? emptyMaskPlaceholder() : array, dos);
        }
    }

    /** Returns this thread's placeholder array, creating it on first use on the thread. */
    private static INDArray emptyMaskPlaceholder() {
        INDArray placeholder = EMPTY_MASK_ARRAY_PLACEHOLDER.get();
        if (placeholder == null) {
            placeholder = Nd4j.create(new float[]{-1.0f});
            EMPTY_MASK_ARRAY_PLACEHOLDER.set(placeholder);
        }
        return placeholder;
    }

    /** Saves to {@code file}, creating or truncating it. */
    @Override
    public void save(File file) throws IOException {
        try (FileOutputStream fos = new FileOutputStream(file)) {
            save(fos);
        }
    }

    /**
     * Replaces this instance's arrays and metadata with those read from {@code inputStream}, in the
     * format written by {@link #save(OutputStream)}. If the stream carries no metadata, the
     * metadata is reset to {@code null}. The given stream is closed when this method returns.
     *
     * @throws IOException on read failure, including a stream truncated inside the metadata
     * @throws RuntimeException if a metadata class cannot be found
     */
    @Override
    public void load(InputStream inputStream) throws IOException {
        try (DataInputStream dis = new DataInputStream(new BufferedInputStream(inputStream))) {
            int numFeatures = dis.readInt();
            int numLabels = dis.readInt();
            int numFeatureMasks = dis.readInt();
            int numLabelMasks = dis.readInt();
            this.features = loadINDArrays(numFeatures, dis, false);
            this.labels = loadINDArrays(numLabels, dis, false);
            this.featuresMaskArrays = loadINDArrays(numFeatureMasks, dis, true);
            this.labelsMaskArrays = loadINDArrays(numLabelMasks, dis, true);
            this.exampleMetaData = readMetaData(dis);
        }
    }

    /**
     * Reads the optional metadata trailer; returns {@code null} when the stream ends before the
     * marker or the marker is not {@code 1}.
     */
    private static List<Serializable> readMetaData(DataInputStream dis) throws IOException {
        int marker;
        try {
            marker = dis.readInt();
        } catch (EOFException noMetaData) {
            // save() writes nothing after the arrays when there is no metadata.
            return null;
        }
        if (marker != 1) {
            return null;
        }
        // SECURITY: Java native deserialization - only load MultiDataSets from trusted sources.
        try {
            @SuppressWarnings("unchecked")
            List<Serializable> metaData = (List<Serializable>) new ObjectInputStream(dis).readObject();
            return metaData;
        } catch (ClassNotFoundException e) {
            throw new RuntimeException("Error reading metadata from serialized MultiDataSet", e);
        }
    }

    /**
     * Reads {@code count} arrays. For masks, arrays equal to the placeholder become {@code null}.
     * Returns {@code null} when {@code count <= 0}.
     */
    private INDArray[] loadINDArrays(int count, DataInputStream dis, boolean isMask) throws IOException {
        if (count <= 0) {
            return null;
        }
        // Initialise the placeholder here too: load may run on a thread (or JVM) that never saved.
        INDArray placeholder = isMask ? emptyMaskPlaceholder() : null;
        INDArray[] result = new INDArray[count];
        for (int i = 0; i < count; i++) {
            INDArray array = Nd4j.read(dis);
            result[i] = placeholder != null && array.equals(placeholder) ? null : array;
        }
        return result;
    }

    /** Loads from {@code file}; see {@link #load(InputStream)}. */
    @Override
    public void load(File file) throws IOException {
        try (FileInputStream fis = new FileInputStream(file)) {
            load(fis);
        }
    }

    /**
     * Splits this dataset into one MultiDataSet per example (index along dimension 0). The returned
     * arrays are views produced by {@code INDArray.get}; null arrays and null entries are kept as
     * {@code null}. The example count comes from the first non-null feature array, or else the
     * first non-null label array; with neither, an empty list is returned. Metadata is not copied.
     *
     * @throws IllegalStateException if an array has a rank other than 2, 3 or 4
     */
    @Override
    public List<org.nd4j.linalg.dataset.api.MultiDataSet> asList() {
        long numExamples = numExamples();
        List<org.nd4j.linalg.dataset.api.MultiDataSet> result = new ArrayList<>();
        for (int i = 0; i < numExamples; i++) {
            result.add(new MultiDataSet(
                    subsetForExample(this.features, i),
                    subsetForExample(this.labels, i),
                    subsetForExample(this.featuresMaskArrays, i),
                    subsetForExample(this.labelsMaskArrays, i)));
        }
        return result;
    }

    /** Size of dimension 0 of the first non-null feature (else label) array, or 0 if none. */
    private long numExamples() {
        INDArray reference = firstNonNull(this.features);
        if (reference == null) {
            reference = firstNonNull(this.labels);
        }
        return reference == null ? 0L : reference.size(0);
    }

    /** Returns the first non-null element, or {@code null} if the array is null or all-null. */
    private static INDArray firstNonNull(INDArray[] arrays) {
        if (arrays != null) {
            for (INDArray array : arrays) {
                if (array != null) {
                    return array;
                }
            }
        }
        return null;
    }

    /** Applies {@link #getSubsetForExample} element-wise, preserving null arrays and entries. */
    private static INDArray[] subsetForExample(INDArray[] arrays, int example) {
        if (arrays == null) {
            return null;
        }
        INDArray[] subset = new INDArray[arrays.length];
        for (int k = 0; k < arrays.length; k++) {
            subset[k] = arrays[k] == null ? null : getSubsetForExample(arrays[k], example);
        }
        return subset;
    }

    /** Returns row {@code i} of a rank 2-4 array as a view that keeps dimension 0 (size 1). */
    private static INDArray getSubsetForExample(INDArray array, int i) {
        switch (array.rank()) {
            case 2:
                return array.get(NDArrayIndex.interval(i, i, true), NDArrayIndex.all());
            case 3:
                return array.get(NDArrayIndex.interval(i, i, true), NDArrayIndex.all(), NDArrayIndex.all());
            case 4:
                return array.get(NDArrayIndex.interval(i, i, true), NDArrayIndex.all(), NDArrayIndex.all(),
                        NDArrayIndex.all());
            default:
                throw new IllegalStateException("Cannot get subset for rank " + array.rank() + " array");
        }
    }

    /**
     * Returns a deep copy: every non-null array is {@code dup()}-ed, null arrays and entries stay
     * {@code null}, and the metadata list (if any) is copied shallowly into a new list.
     */
    @Override
    public MultiDataSet copy() {
        MultiDataSet result = new MultiDataSet(copy(this.features), copy(this.labels));
        if (this.labelsMaskArrays != null) {
            result.setLabelsMaskArray(copy(this.labelsMaskArrays));
        }
        if (this.featuresMaskArrays != null) {
            result.setFeaturesMaskArrays(copy(this.featuresMaskArrays));
        }
        if (this.exampleMetaData != null) {
            result.exampleMetaData = new ArrayList<>(this.exampleMetaData);
        }
        return result;
    }

    /** Duplicates each non-null element; returns {@code null} for a {@code null} input. */
    private static INDArray[] copy(INDArray[] arrays) {
        if (arrays == null) {
            return null;
        }
        INDArray[] result = new INDArray[arrays.length];
        for (int i = 0; i < arrays.length; i++) {
            result[i] = arrays[i] == null ? null : arrays[i].dup();
        }
        return result;
    }

    /**
     * Merges datasets along the example dimension using {@link DataSetUtil#mergeFeatures} and
     * {@link DataSetUtil#mergeLabels}. A single-element collection is returned as-is (or wrapped,
     * if not this class). Empty datasets are skipped; if all are empty, an empty instance is
     * returned. Metadata is not merged.
     *
     * @throws IllegalArgumentException if {@code collection} is empty
     * @throws IllegalStateException    if the non-empty datasets disagree on array counts
     */
    public static MultiDataSet merge(Collection<? extends org.nd4j.linalg.dataset.api.MultiDataSet> collection) {
        if (collection.isEmpty()) {
            throw new IllegalArgumentException("Cannot merge an empty collection of MultiDataSets");
        }
        if (collection.size() == 1) {
            org.nd4j.linalg.dataset.api.MultiDataSet only = collection.iterator().next();
            if (only instanceof MultiDataSet) {
                return (MultiDataSet) only;
            }
            return new MultiDataSet(only.getFeatures(), only.getLabels(), only.getFeaturesMaskArrays(),
                    only.getLabelsMaskArrays());
        }

        List<org.nd4j.linalg.dataset.api.MultiDataSet> nonEmpty = new ArrayList<>(collection.size());
        for (org.nd4j.linalg.dataset.api.MultiDataSet mds : collection) {
            if (!mds.isEmpty()) {
                nonEmpty.add(mds);
            }
        }
        if (nonEmpty.isEmpty()) {
            return new MultiDataSet();
        }

        // The reference counts come from the first NON-empty dataset.
        int numFeatureArrays = nonEmpty.get(0).numFeatureArrays();
        int numLabelsArrays = nonEmpty.get(0).numLabelsArrays();
        int n = nonEmpty.size();
        INDArray[][] allFeatures = new INDArray[n][];
        INDArray[][] allLabels = new INDArray[n][];
        INDArray[][] allFeatureMasks = new INDArray[n][];
        INDArray[][] allLabelMasks = new INDArray[n][];
        for (int i = 0; i < n; i++) {
            org.nd4j.linalg.dataset.api.MultiDataSet mds = nonEmpty.get(i);
            allFeatures[i] = mds.getFeatures();
            allLabels[i] = mds.getLabels();
            allFeatureMasks[i] = mds.getFeaturesMaskArrays();
            allLabelMasks[i] = mds.getLabelsMaskArrays();
            if (allFeatures[i] == null || allFeatures[i].length != numFeatureArrays) {
                throw new IllegalStateException("Cannot merge MultiDataSets with different number of input arrays: toMerge[0] has " + numFeatureArrays + " input arrays; toMerge[" + i + "] has " + (allFeatures[i] == null ? "null" : String.valueOf(allFeatures[i].length)) + " arrays");
            }
            if (allLabels[i] == null || allLabels[i].length != numLabelsArrays) {
                throw new IllegalStateException("Cannot merge MultiDataSets with different number of output arrays: toMerge[0] has " + numLabelsArrays + " output arrays; toMerge[" + i + "] has " + (allLabels[i] == null ? "null" : String.valueOf(allLabels[i].length)) + " arrays");
            }
        }

        INDArray[] mergedFeatures = new INDArray[numFeatureArrays];
        INDArray[] mergedFeatureMasks = new INDArray[numFeatureArrays];
        boolean anyFeatureMask = false;
        for (int i = 0; i < numFeatureArrays; i++) {
            Pair<INDArray, INDArray> merged = DataSetUtil.mergeFeatures(allFeatures, allFeatureMasks, i);
            mergedFeatures[i] = merged.getFirst();
            mergedFeatureMasks[i] = merged.getSecond();
            anyFeatureMask |= mergedFeatureMasks[i] != null;
        }

        INDArray[] mergedLabels = new INDArray[numLabelsArrays];
        INDArray[] mergedLabelMasks = new INDArray[numLabelsArrays];
        boolean anyLabelMask = false;
        for (int i = 0; i < numLabelsArrays; i++) {
            Pair<INDArray, INDArray> merged = DataSetUtil.mergeLabels(allLabels, allLabelMasks, i);
            mergedLabels[i] = merged.getFirst();
            mergedLabelMasks[i] = merged.getSecond();
            anyLabelMask |= mergedLabelMasks[i] != null;
        }

        return new MultiDataSet(mergedFeatures, mergedLabels, anyFeatureMask ? mergedFeatureMasks : null,
                anyLabelMask ? mergedLabelMasks : null);
    }

    /** Human-readable dump of all arrays and masks; {@code ;} in array text becomes a newline. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("MultiDataSet: ").append(numFeatureArrays()).append(" input arrays, ")
                .append(numLabelsArrays()).append(" label arrays, ")
                .append(countNonNull(this.featuresMaskArrays)).append(" input masks, ")
                .append(countNonNull(this.labelsMaskArrays)).append(" label masks");
        for (int i = 0; i < numFeatureArrays(); i++) {
            sb.append("\n=== INPUT ").append(i).append(" ===\n").append(getFeatures(i).toString().replace(";", "\n"));
            if (getFeaturesMaskArray(i) != null) {
                sb.append("\n--- INPUT MASK ---\n").append(getFeaturesMaskArray(i).toString().replace(";", "\n"));
            }
        }
        for (int i = 0; i < numLabelsArrays(); i++) {
            sb.append("\n=== LABEL ").append(i).append(" ===\n").append(getLabels(i).toString().replace(";", "\n"));
            if (getLabelsMaskArray(i) != null) {
                sb.append("\n--- LABEL MASK ---\n").append(getLabelsMaskArray(i).toString().replace(";", "\n"));
            }
        }
        return sb.toString();
    }

    /** Number of non-null entries; 0 for a {@code null} array. */
    private static int countNonNull(INDArray[] arrays) {
        int count = 0;
        if (arrays != null) {
            for (INDArray array : arrays) {
                if (array != null) {
                    count++;
                }
            }
        }
        return count;
    }

    /** Equal when all four array groups are element-wise equal; metadata is not compared. */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof MultiDataSet)) {
            return false;
        }
        MultiDataSet other = (MultiDataSet) obj;
        return bothNullOrEqual(this.features, other.features)
                && bothNullOrEqual(this.labels, other.labels)
                && bothNullOrEqual(this.featuresMaskArrays, other.featuresMaskArrays)
                && bothNullOrEqual(this.labelsMaskArrays, other.labelsMaskArrays);
    }

    /** Null-safe, element-wise array equality using {@link Objects#equals}. */
    private boolean bothNullOrEqual(INDArray[] first, INDArray[] second) {
        if (first == null && second == null) {
            return true;
        }
        if (first == null || second == null || first.length != second.length) {
            return false;
        }
        for (int i = 0; i < first.length; i++) {
            if (!Objects.equals(first[i], second[i])) {
                return false;
            }
        }
        return true;
    }

    /** Consistent with {@link #equals}; null entries (e.g. absent masks) hash to 0. */
    @Override
    public int hashCode() {
        int result = 0;
        result = hashInto(result, this.features);
        result = hashInto(result, this.labels);
        result = hashInto(result, this.featuresMaskArrays);
        result = hashInto(result, this.labelsMaskArrays);
        return result;
    }

    /** Folds each element's hash into {@code seed} with multiplier 31. */
    private static int hashInto(int seed, INDArray[] arrays) {
        int result = seed;
        if (arrays != null) {
            for (INDArray array : arrays) {
                result = result * 31 + Objects.hashCode(array);
            }
        }
        return result;
    }

    /** Sum of {@code length * element size} in bytes over all non-null arrays and masks. */
    @Override
    public long getMemoryFootprint() {
        return footprintOf(this.features) + footprintOf(this.featuresMaskArrays)
                + footprintOf(this.labelsMaskArrays) + footprintOf(this.labels);
    }

    private static long footprintOf(INDArray[] arrays) {
        long bytes = 0L;
        if (arrays != null) {
            for (INDArray array : arrays) {
                if (array != null) {
                    // Widen before multiplying so large arrays cannot overflow int arithmetic.
                    bytes += array.length() * (long) Nd4j.sizeOfDataType(array.dataType());
                }
            }
        }
        return bytes;
    }

    /**
     * Replaces every non-null array with {@code array.migrate()}. Does nothing when no workspace
     * is currently active.
     */
    @Override
    public void migrate() {
        if (Nd4j.getMemoryManager().getCurrentWorkspace() != null) {
            migrateAll(this.features);
            migrateAll(this.labels);
            migrateAll(this.featuresMaskArrays);
            migrateAll(this.labelsMaskArrays);
        }
    }

    private static void migrateAll(INDArray[] arrays) {
        if (arrays != null) {
            for (int i = 0; i < arrays.length; i++) {
                if (arrays[i] != null) {
                    arrays[i] = arrays[i].migrate();
                }
            }
        }
    }

    /** Replaces every non-null array with {@code array.detach()}. */
    @Override
    public void detach() {
        detachAll(this.features);
        detachAll(this.labels);
        detachAll(this.featuresMaskArrays);
        detachAll(this.labelsMaskArrays);
    }

    private static void detachAll(INDArray[] arrays) {
        if (arrays != null) {
            for (int i = 0; i < arrays.length; i++) {
                if (arrays[i] != null) {
                    arrays[i] = arrays[i].detach();
                }
            }
        }
    }

    /** True when every array group is {@code null} or contains only {@code null} entries. */
    @Override
    public boolean isEmpty() {
        return nullOrEmpty(this.features) && nullOrEmpty(this.labels)
                && nullOrEmpty(this.featuresMaskArrays) && nullOrEmpty(this.labelsMaskArrays);
    }

    /**
     * Randomly permutes the examples in place (via {@link #asList()} and {@link #merge}). Metadata
     * is permuted in the same order when it has one entry per example; otherwise it is cleared.
     * Does nothing when there are no examples.
     */
    @Override
    public void shuffle() {
        List<org.nd4j.linalg.dataset.api.MultiDataSet> examples = asList();
        int n = examples.size();
        if (n == 0) {
            return;
        }
        List<Integer> order = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            order.add(i);
        }
        Collections.shuffle(order);

        List<org.nd4j.linalg.dataset.api.MultiDataSet> shuffled = new ArrayList<>(n);
        for (int idx : order) {
            shuffled.add(examples.get(idx));
        }
        MultiDataSet merged = merge(shuffled);
        this.features = merged.features;
        this.labels = merged.labels;
        this.featuresMaskArrays = merged.featuresMaskArrays;
        this.labelsMaskArrays = merged.labelsMaskArrays;

        List<Serializable> permutedMetaData = null;
        if (this.exampleMetaData != null && this.exampleMetaData.size() == n) {
            permutedMetaData = new ArrayList<>(n);
            for (int idx : order) {
                permutedMetaData.add(this.exampleMetaData.get(idx));
            }
        }
        this.exampleMetaData = permutedMetaData;
    }

    /** Length of the array, or 0 for {@code null}. */
    private static int lengthOf(INDArray[] arrays) {
        return arrays == null ? 0 : arrays.length;
    }

    /** True if the array is {@code null} or all of its entries are {@code null}. */
    private static boolean nullOrEmpty(INDArray[] arrays) {
        if (arrays == null) {
            return true;
        }
        for (INDArray array : arrays) {
            if (array != null) {
                return false;
            }
        }
        return true;
    }
}
