package edu.iu.dsc.tws.comms.ucx;

import edu.iu.dsc.tws.api.comms.channel.ChannelListener;
import edu.iu.dsc.tws.api.comms.channel.TWSChannel;
import edu.iu.dsc.tws.api.comms.messaging.ChannelMessage;
import edu.iu.dsc.tws.api.comms.packing.DataBuffer;
import edu.iu.dsc.tws.api.config.Config;
import edu.iu.dsc.tws.api.exceptions.TimeoutException;
import edu.iu.dsc.tws.api.exceptions.Twister2RuntimeException;
import edu.iu.dsc.tws.api.resource.IWorkerController;
import edu.iu.dsc.tws.proto.jobmaster.JobMasterAPI;
import java.io.Closeable;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.Stack;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.openucx.jucx.UcxCallback;
import org.openucx.jucx.ucp.UcpContext;
import org.openucx.jucx.ucp.UcpEndpoint;
import org.openucx.jucx.ucp.UcpEndpointParams;
import org.openucx.jucx.ucp.UcpListenerParams;
import org.openucx.jucx.ucp.UcpParams;
import org.openucx.jucx.ucp.UcpRequest;
import org.openucx.jucx.ucp.UcpWorker;
import org.openucx.jucx.ucp.UcpWorkerParams;

/* loaded from: input_file:edu/iu/dsc/tws/comms/ucx/TWSUCXChannel.class */
/**
 * {@link TWSChannel} implementation that sends and receives tagged messages through the
 * JUCX bindings of UCX.
 *
 * <p>On construction the channel creates a UCX context, a worker and a listener bound to this
 * worker's address, waits on the job barrier and then opens one endpoint to every other joined
 * worker. Messages are matched by a tag computed as
 * {@code senderWorkerId * TAG_WORKER_ID_OFFSET + edge}.
 */
public class TWSUCXChannel implements TWSChannel {
    private static final Logger LOG = Logger.getLogger(TWSUCXChannel.class.getName());

    /** Multiplier applied to the sending worker's id when building a message tag. */
    private static final int TAG_WORKER_ID_OFFSET = 100000;

    /**
     * Tag mask handed to {@code recvTaggedNonBlocking}.
     *
     * <p>NOTE(review): tags are built as {@code id * 100000 + edge}, so this mask does not cover
     * the whole tag value; confirm that distinct (worker, edge) pairs cannot be confused under
     * this mask before relying on it for large worker ids or edge numbers.
     */
    private static final long RECV_TAG_MASK = 65535L;

    /**
     * Status value for which a receive error callback is ignored instead of raised.
     * NOTE(review): presumably the UCX "request cancelled" status - confirm against the UCX headers.
     */
    private static final int IGNORED_RECV_ERROR_STATUS = -16;

    private UcpWorker ucpWorker;
    private final int workerId;

    /** Native resources in creation order; popped (closed) in reverse order by {@link #close()}. */
    private final Stack<Closeable> closeables = new Stack<>();

    /** Endpoint per remote worker id; this worker itself has no entry. */
    private final Map<Integer, UcpEndpoint> endpoints = new HashMap<>();

    /** Number of send requests that were issued and have not yet reported success. */
    private final AtomicLong pendingSendRequests = new AtomicLong();

    /** Every registered receive that has not been pruned yet; driven by {@link #progress()}. */
    private final List<ReceiveProgress> receiveProgresses = new ArrayList<>();

    /** Registered receives indexed by source worker id, then by edge. */
    private final Map<Integer, Map<Integer, Set<ReceiveProgress>>> groupReceives = new HashMap<>();

    /**
     * Keeps receive requests posted for one (source worker id, edge) pair, using the buffers
     * offered through the supplied queue.
     */
    public class ReceiveProgress implements Closeable {
        private final int group;
        private final int id;
        private final int edge;
        private final ChannelListener callback;
        private final Queue<DataBuffer> receiveBuffers;
        private final AtomicLong requestIdCounter = new AtomicLong();

        /** Receive requests that were posted and have not completed yet, keyed by a local id. */
        private final Map<Long, UcpRequest> requestsMap = new ConcurrentHashMap<>();
        private boolean closed = false;

        ReceiveProgress(int group, int id, int edge, ChannelListener callback,
                        Queue<DataBuffer> receiveBuffers) {
            this.group = group;
            this.id = id;
            this.edge = edge;
            this.callback = callback;
            this.receiveBuffers = receiveBuffers;
        }

        /**
         * Stops posting new receives and asks the worker to cancel every request that is still
         * tracked. Calling it more than once has no further effect.
         */
        @Override
        public void close() {
            if (this.closed) {
                return;
            }
            this.closed = true;
            this.requestsMap.values().forEach(request -> {
                try {
                    TWSUCXChannel.this.ucpWorker.cancelRequest(request);
                } catch (NullPointerException e) {
                    // Best effort, as before: a request that can no longer be cancelled must not
                    // fail the release. Logged instead of being dropped silently.
                    LOG.log(Level.FINE, "Ignoring failure while cancelling a UCX receive request", e);
                }
            });
            // Nothing is tracked after close; drop the references to the request objects.
            this.requestsMap.clear();
        }

        /**
         * Posts one non-blocking tagged receive for every buffer currently available in the
         * queue, unless this instance has been closed.
         */
        public void progress() {
            while (!this.receiveBuffers.isEmpty() && !this.closed) {
                final DataBuffer recvBuffer = this.receiveBuffers.poll();
                final int tag = (this.id * TAG_WORKER_ID_OFFSET) + this.edge;
                LOG.log(Level.FINE, () -> String.format("EXPECTING from TAG: %d, Buffer : %s",
                        tag, recvBuffer.getByteBuffer()));
                final long requestId = this.requestIdCounter.incrementAndGet();

                final UcpRequest request = TWSUCXChannel.this.ucpWorker.recvTaggedNonBlocking(
                        recvBuffer.getByteBuffer(), tag, RECV_TAG_MASK, new UcxCallback() {
                    public void onSuccess(UcpRequest ucpRequest) {
                        LOG.log(Level.FINE, () -> String.format(
                                "Recv Buff from %d[%d] : %s, TAG[%d], Size : %d",
                                ReceiveProgress.this.id, ReceiveProgress.this.edge,
                                recvBuffer.getByteBuffer(), tag,
                                recvBuffer.getByteBuffer().getInt(0)));
                        recvBuffer.setSize((int) ucpRequest.getRecvSize());
                        ReceiveProgress.this.requestsMap.remove(requestId);
                        ReceiveProgress.this.callback.onReceiveComplete(
                                ReceiveProgress.this.id, ReceiveProgress.this.edge, recvBuffer);
                    }

                    public void onError(int status, String errorMsg) {
                        if (status != IGNORED_RECV_ERROR_STATUS) {
                            String message = "Failed to receive from " + ReceiveProgress.this.id
                                    + " with status " + status + ". Error : " + errorMsg;
                            LOG.severe(message);
                            ReceiveProgress.this.requestsMap.remove(requestId);
                            throw new Twister2RuntimeException(message);
                        }
                    }
                });

                this.requestsMap.put(requestId, request);
                // The callback may already have run (and tried to remove the id) before the put
                // above; without this check a finished request would stay tracked forever.
                if (request.isCompleted()) {
                    this.requestsMap.remove(requestId);
                }
            }
        }
    }

    /**
     * Creates the channel for the worker described by {@code iWorkerController} and connects it
     * to all other joined workers.
     *
     * @param config not used by this implementation
     * @param iWorkerController source of this worker's id/address, the barrier and the peer list
     */
    public TWSUCXChannel(Config config, IWorkerController iWorkerController) {
        this.workerId = iWorkerController.getWorkerInfo().getWorkerID();
        createUXCWorker(iWorkerController);
    }

    /**
     * Creates the UCX context, worker and listener, waits on the barrier and opens an endpoint
     * to every joined worker other than this one. Each created resource is registered for
     * {@link #close()}.
     */
    private void createUXCWorker(IWorkerController iWorkerController) {
        UcpContext ucpContext = new UcpContext(
                new UcpParams().requestTagFeature().setMtWorkersShared(false));
        this.closeables.push(ucpContext);

        this.ucpWorker = ucpContext.newWorker(new UcpWorkerParams().requestThreadSafety());
        this.closeables.push(this.ucpWorker);

        InetSocketAddress listenAddress = new InetSocketAddress(
                iWorkerController.getWorkerInfo().getWorkerIP(),
                iWorkerController.getWorkerInfo().getPort());
        this.closeables.push(
                this.ucpWorker.newListener(new UcpListenerParams().setSockAddr(listenAddress)));

        try {
            iWorkerController.waitOnBarrier();
        } catch (TimeoutException e) {
            // NOTE(review): the failure is only logged and endpoint creation continues below;
            // confirm that connecting after a barrier timeout is intended.
            LOG.log(Level.SEVERE, "Failed to wait on barrier", e);
        }

        for (JobMasterAPI.WorkerInfo workerInfo : iWorkerController.getJoinedWorkers()) {
            if (workerInfo.getWorkerID() != this.workerId) {
                UcpEndpoint endpoint = this.ucpWorker.newEndpoint(new UcpEndpointParams()
                        .setSocketAddress(new InetSocketAddress(
                                workerInfo.getWorkerIP(), workerInfo.getPort())));
                this.endpoints.put(workerInfo.getWorkerID(), endpoint);
                this.closeables.push(endpoint);
            }
        }
    }

    /**
     * Issues one non-blocking tagged send per buffer of {@code message} to worker {@code id}.
     * {@code callback.onSendComplete} is invoked once every buffer has reported success.
     *
     * @param id target worker id
     * @param message message whose buffers are sent; its header supplies the edge
     * @param callback listener notified when the whole message has been sent
     * @return always {@code true}
     * @throws Twister2RuntimeException if no endpoint exists for {@code id}
     */
    public boolean sendMessage(final int id, final ChannelMessage message,
                               final ChannelListener callback) {
        final AtomicInteger buffersLeft = new AtomicInteger(message.getBuffers().size());
        for (DataBuffer buffer : message.getBuffers()) {
            final UcpEndpoint endpoint = this.endpoints.get(id);
            if (endpoint == null) {
                // Previously surfaced as a bare NullPointerException.
                throw new Twister2RuntimeException(
                        "No UCX endpoint available for worker : " + id);
            }

            // Expose exactly the filled part of the buffer to UCX.
            buffer.getByteBuffer().limit(buffer.getSize());
            buffer.getByteBuffer().position(0);

            final int tag = (this.workerId * TAG_WORKER_ID_OFFSET) + message.getHeader().getEdge();
            LOG.log(Level.FINE, () -> String.format("SENDING to %d[%d] : %s, TAG[%d]",
                    id, message.getHeader().getEdge(), buffer.getByteBuffer(), tag));

            endpoint.sendTaggedNonBlocking(buffer.getByteBuffer(), tag, new UcxCallback() {
                public void onSuccess(UcpRequest ucpRequest) {
                    TWSUCXChannel.this.pendingSendRequests.decrementAndGet();
                    if (buffersLeft.decrementAndGet() == 0) {
                        callback.onSendComplete(id, message.getHeader().getEdge(), message);
                    }
                }

                public void onError(int status, String errorMsg) {
                    LOG.severe("UCX send request failed to worker " + id + " with status "
                            + status + ". Error : " + errorMsg);
                    throw new Twister2RuntimeException(
                            "Send request to worker : " + id + " failed. " + errorMsg);
                }
            });
            this.pendingSendRequests.incrementAndGet();
        }
        return true;
    }

    /**
     * Registers a receive for messages from worker {@code id} on {@code edge} and immediately
     * posts receives for the buffers already present in {@code receiveBuffers}.
     *
     * @param group stored with the registration; not otherwise used by this class
     * @param id source worker id
     * @param edge edge the messages belong to
     * @param callback listener notified for every completed receive
     * @param receiveBuffers queue from which receive buffers are taken
     * @return always {@code true}
     */
    public boolean receiveMessage(int group, int id, int edge, ChannelListener callback,
                                  Queue<DataBuffer> receiveBuffers) {
        ReceiveProgress receiveProgress =
                new ReceiveProgress(group, id, edge, callback, receiveBuffers);
        receiveProgress.progress();
        this.receiveProgresses.add(receiveProgress);
        this.groupReceives
                .computeIfAbsent(id, key -> new HashMap<>())
                .computeIfAbsent(edge, key -> new HashSet<>())
                .add(receiveProgress);
        return true;
    }

    /**
     * Posts receives for newly available buffers of every open registration and then progresses
     * the UCX worker.
     */
    public void progress() {
        // Registrations closed through releaseBuffers are dropped here rather than at release
        // time, so a release never modifies this list while it is being iterated.
        this.receiveProgresses.removeIf(receiveProgress -> receiveProgress.closed);
        for (ReceiveProgress receiveProgress : this.receiveProgresses) {
            receiveProgress.progress();
        }
        this.ucpWorker.progress();
    }

    /** Delegates to {@link #progress()}. */
    public void progressSends() {
        progress();
    }

    /**
     * Delegates to {@link #progress()}.
     *
     * @param i not used by this implementation
     */
    public void progressReceives(int i) {
        progress();
    }

    /**
     * @return {@code true} when no issued send request is still waiting for its success callback
     */
    public boolean isComplete() {
        return this.pendingSendRequests.get() == 0;
    }

    /**
     * Allocates a direct buffer.
     *
     * @param capacity buffer capacity in bytes
     * @return a new direct {@link ByteBuffer} of the requested capacity
     */
    public ByteBuffer createBuffer(int capacity) {
        return ByteBuffer.allocateDirect(capacity);
    }

    /**
     * Closes all registered native resources in reverse creation order. Every resource is
     * attempted even if an earlier one fails.
     *
     * @throws Twister2RuntimeException for the first component that failed to close; failures of
     *     later components are attached as suppressed exceptions
     */
    public void close() {
        Twister2RuntimeException firstFailure = null;
        while (!this.closeables.isEmpty()) {
            Closeable component = this.closeables.pop();
            try {
                component.close();
            } catch (IOException e) {
                Twister2RuntimeException failure = new Twister2RuntimeException(
                        "Failed to close UCX channel component : " + component, e);
                if (firstFailure == null) {
                    firstFailure = failure;
                } else {
                    firstFailure.addSuppressed(failure);
                }
            }
        }
        if (firstFailure != null) {
            throw firstFailure;
        }
    }

    /**
     * Closes and forgets every receive registered for worker {@code id} on {@code edge}.
     * Does nothing when no such registration exists.
     *
     * @param id source worker id used when the receive was registered
     * @param edge edge used when the receive was registered
     */
    public void releaseBuffers(int id, int edge) {
        Map<Integer, Set<ReceiveProgress>> receivesByEdge = this.groupReceives.get(id);
        if (receivesByEdge == null) {
            return;
        }
        Set<ReceiveProgress> released = receivesByEdge.remove(edge);
        if (receivesByEdge.isEmpty()) {
            this.groupReceives.remove(id);
        }
        if (released == null) {
            return;
        }
        for (ReceiveProgress receiveProgress : released) {
            // Closed registrations are pruned from receiveProgresses on the next progress().
            receiveProgress.close();
        }
    }

    /** No-op in this implementation. */
    public void reInit(List<JobMasterAPI.WorkerInfo> list) {
    }
}
