package org.deeplearning4j.datasets.iterator.callbacks;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.memory.enums.ResetPolicy;
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/datasets/iterator/callbacks/InterleavedDataSetCallback.class */
/**
 * {@link DataSetCallback} that migrates every incoming {@link DataSet} / {@link MultiDataSet} into one
 * of a set of workspaces (one per device), cycling through them in round-robin order.
 *
 * <p>The workspaces are created lazily on the first {@code call}. Their initial size is the memory
 * footprint of that first dataset, and the constructor's {@code bufferSize} is used as the
 * workspaces' overallocation limit.
 *
 * <p>NOTE(review): lazy initialization is not synchronized. If {@code call} can be invoked from
 * several threads concurrently, initialization needs a guard. Confirm against the callers.
 */
public class InterleavedDataSetCallback implements DataSetCallback {
    private static final Logger log = LoggerFactory.getLogger((Class<?>) InterleavedDataSetCallback.class);
    private final int bufferSize;
    private int numWorkspaces;
    private final List<MemoryWorkspace> workspaces = new ArrayList<>();
    private boolean isInitialized = false;
    /** Counts migrated datasets; used to pick the next workspace round-robin. */
    private final AtomicLong counterInput = new AtomicLong(0);

    /**
     * Creates the callback.
     *
     * @param bufferSize value passed as {@code overallocationLimit} to every workspace this callback
     *     creates
     */
    public InterleavedDataSetCallback(int bufferSize) {
        this.bufferSize = bufferSize;
    }

    /**
     * Creates one workspace per device, named {@code "IDSC-<deviceIndex>"}.
     *
     * <p>The calling thread's device is temporarily switched to each device while its workspace is
     * created. The original device is always restored, even if creation fails. The workspace list is
     * published only when every workspace was created, so a failed attempt can be retried without
     * leaving partial entries behind.
     *
     * @param size initial size for each workspace (the first dataset's memory footprint)
     */
    protected void initializeWorkspaces(long size) {
        // ENDOFBUFFER_REACHED + OVERALLOCATE: presumably each workspace acts as a ring buffer of
        // roughly size * bufferSize bytes, with overflow spilled externally. Confirm with ND4J docs.
        WorkspaceConfiguration configuration = WorkspaceConfiguration.builder()
                .initialSize(size)
                .overallocationLimit(this.bufferSize)
                .policyReset(ResetPolicy.ENDOFBUFFER_REACHED)
                .policyAllocation(AllocationPolicy.OVERALLOCATE)
                .policySpill(SpillPolicy.EXTERNAL)
                .policyLearning(LearningPolicy.NONE)
                .build();

        int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
        int originalDevice = Nd4j.getAffinityManager().getDeviceForCurrentThread().intValue();
        List<MemoryWorkspace> created = new ArrayList<>(numDevices);
        try {
            for (int device = 0; device < numDevices; device++) {
                // The workspace is created while the thread is bound to the target device.
                Nd4j.getAffinityManager().unsafeSetDevice(Integer.valueOf(device));
                created.add(Nd4j.getWorkspaceManager()
                        .createNewWorkspace(configuration, "IDSC-" + device, Integer.valueOf(device)));
            }
        } finally {
            // Never leave the caller's thread bound to a different device.
            Nd4j.getAffinityManager().unsafeSetDevice(Integer.valueOf(originalDevice));
        }

        this.workspaces.addAll(created);
        this.numWorkspaces = numDevices;
        this.isInitialized = true;
    }

    /**
     * Migrates {@code dataSet} into the next workspace in round-robin order, initializing the
     * workspaces from its memory footprint on the first call.
     */
    @Override
    public void call(DataSet dataSet) {
        if (!this.isInitialized) {
            initializeWorkspaces(dataSet.getMemoryFootprint());
        }
        migrateIntoNextWorkspace(dataSet::migrate);
    }

    /**
     * Migrates {@code multiDataSet} into the next workspace in round-robin order, initializing the
     * workspaces from its memory footprint on the first call.
     */
    @Override
    public void call(MultiDataSet multiDataSet) {
        if (!this.isInitialized) {
            initializeWorkspaces(multiDataSet.getMemoryFootprint());
        }
        migrateIntoNextWorkspace(multiDataSet::migrate);
    }

    /** Restarts the round-robin cycle so the next dataset goes to the first workspace. */
    @Override
    public void reset() {
        this.counterInput.set(0L);
    }

    /**
     * Runs {@code migration} with the next workspace (round-robin) set as the current workspace.
     * The previously current workspace is always restored, even if the migration throws.
     */
    private void migrateIntoNextWorkspace(Runnable migration) {
        // Presumably flushes any queued ops before arrays are moved. Confirm with ND4J executioner docs.
        Nd4j.getExecutioner().commit();
        // floorMod keeps the index non-negative even if the counter ever wrapped around.
        int index = (int) Math.floorMod(this.counterInput.getAndIncrement(), (long) this.numWorkspaces);
        MemoryWorkspace previous = Nd4j.getMemoryManager().getCurrentWorkspace();
        Nd4j.getMemoryManager().setCurrentWorkspace(this.workspaces.get(index));
        try {
            migration.run();
        } finally {
            Nd4j.getMemoryManager().setCurrentWorkspace(previous);
        }
    }
}
