package org.nd4j.linalg.api.ops.aggregates;

import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.aggregates.Aggregate;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/api/ops/aggregates/Batch.class */
/**
 * A bounded group of {@link Aggregate} ops.
 *
 * <p>The first aggregate supplied at construction is kept as the "sample"; {@link #opNum()} reports
 * its op number for the whole batch. Whether all aggregates share that op number is not checked
 * here — presumably callers only group aggregates of a single op type (TODO confirm).
 *
 * <p>New aggregates can be added with {@link #append} until the batch holds at least
 * {@link #getBatchLimit()} elements. No synchronization is performed by this class.
 *
 * @param <T> the aggregate type held by this batch
 */
public class Batch<T extends Aggregate> {
    private static final Logger log = LoggerFactory.getLogger(Batch.class);

    /** Maximum number of aggregates a batch accepts through {@link #append}. */
    private static final int BATCH_LIMIT = 512;

    /** Buffer set by callers via {@link #setParamsSurface}; never populated by this class itself. */
    private DataBuffer paramsSurface;

    /** Backing list of aggregates; {@link #append} adds to it directly. */
    private final List<T> aggregates;

    /** First aggregate of the batch; used by {@link #opNum()}. */
    private final T sample;

    /** Number of aggregates currently in {@link #aggregates}; kept in sync by {@link #append}. */
    private int numAggregates;

    /**
     * Creates a batch backed by the given list. The list is stored as-is (not copied), so
     * {@link #append} mutates it and it must support {@code add} if appending is used.
     *
     * @param list non-empty list of aggregates; its first element becomes the sample
     * @throws IllegalArgumentException if {@code list} is {@code null} or empty
     */
    public Batch(List<T> list) {
        if (list == null || list.isEmpty()) {
            throw new IllegalArgumentException("Batch requires a non-empty list of aggregates");
        }
        this.aggregates = list;
        this.numAggregates = list.size();
        this.sample = list.get(0);
    }

    /**
     * Returns the op number of the sample (first) aggregate.
     *
     * @return {@code getSample().opNum()}
     */
    public int opNum() {
        return this.sample.opNum();
    }

    /**
     * Adds an aggregate to this batch unless the batch is already full.
     *
     * @param t the aggregate to add
     * @return {@code true} if it was added; {@code false} if the batch already holds
     *     {@link #getBatchLimit()} or more aggregates
     */
    public boolean append(T t) {
        if (isFull()) {
            return false;
        }
        this.aggregates.add(t);
        // Keep the counter in sync, otherwise isFull() never trips and getNumAggregates() goes stale.
        this.numAggregates++;
        return true;
    }

    /**
     * Tells whether this batch has reached the batch limit.
     *
     * <p>Uses {@code >=} rather than {@code ==} because a batch can be constructed with more than
     * {@link #getBatchLimit()} elements (e.g. via {@link #getBatches(List, int)} with a larger size).
     *
     * @return {@code true} if no more aggregates may be appended
     */
    public boolean isFull() {
        return this.numAggregates >= BATCH_LIMIT;
    }

    /**
     * Splits {@code list} into batches of at most {@link #getBatchLimit()} aggregates.
     *
     * @param list aggregates to split
     * @param <U> aggregate type
     * @return the batches, in input order
     * @see #getBatches(List, int)
     */
    public static <U extends Aggregate> List<Batch<U>> getBatches(List<U> list) {
        return getBatches(list, BATCH_LIMIT);
    }

    /**
     * Validates that all non-null arguments of all aggregates share one data type, then splits
     * {@code list} into consecutive batches of at most {@code partitionSize} aggregates.
     *
     * <p>Each batch owns an independent copy of its slice, so appending to one batch does not
     * modify {@code list} or any other batch.
     *
     * @param list aggregates to split
     * @param partitionSize maximum number of aggregates per batch
     * @param <U> aggregate type
     * @return the batches, in input order
     * @throws IllegalArgumentException if arguments have differing data types, or (from Guava
     *     {@code Lists.partition}) if {@code partitionSize} is not positive
     * @throws ND4JIllegalStateException if no non-null argument exists to infer a data type from
     */
    public static <U extends Aggregate> List<Batch<U>> getBatches(List<U> list, int partitionSize) {
        DataType dataType = null;
        for (U aggregate : list) {
            for (INDArray argument : aggregate.getArguments()) {
                if (argument == null) {
                    continue;
                }
                // The first non-null argument defines the expected data type.
                if (dataType == null) {
                    dataType = argument.dataType();
                }
                Preconditions.checkArgument(dataType == argument.dataType(), "All arguments must have same data type");
            }
        }
        if (dataType == null) {
            throw new ND4JIllegalStateException("Can't infer data type from arguments");
        }

        List<List<U>> partitions = Lists.partition(list, partitionSize);
        List<Batch<U>> batches = new ArrayList<>(partitions.size());
        for (List<U> partition : partitions) {
            // Lists.partition yields subList views of the input. Copy each one so that append()
            // on a batch cannot insert into the caller's list or invalidate sibling views.
            batches.add(new Batch<>(new ArrayList<>(partition)));
        }
        return batches;
    }

    /**
     * @return the parameter buffer previously set via {@link #setParamsSurface}, or {@code null}
     */
    public DataBuffer getParamsSurface() {
        return this.paramsSurface;
    }

    /**
     * @param dataBuffer the parameter buffer to associate with this batch
     */
    public void setParamsSurface(DataBuffer dataBuffer) {
        this.paramsSurface = dataBuffer;
    }

    /**
     * @return the maximum number of aggregates a batch accepts through {@link #append}
     */
    public static int getBatchLimit() {
        return BATCH_LIMIT;
    }

    /**
     * @return the live backing list of aggregates (not a copy)
     */
    public List<T> getAggregates() {
        return this.aggregates;
    }

    /**
     * @return the first aggregate supplied at construction
     */
    public T getSample() {
        return this.sample;
    }

    /**
     * @return the number of aggregates in this batch, including appended ones
     */
    public int getNumAggregates() {
        return this.numAggregates;
    }
}
