package edu.iu.dsc.tws.comms.table.channel;

import edu.iu.dsc.tws.api.config.Config;
import edu.iu.dsc.tws.api.config.MPIContext;
import edu.iu.dsc.tws.api.exceptions.Twister2RuntimeException;
import edu.iu.dsc.tws.api.resource.IWorkerController;
import java.nio.IntBuffer;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.logging.Level;
import java.util.logging.Logger;
import mpi.Intracomm;
import mpi.MPI;
import mpi.MPIException;
import mpi.Request;
import mpi.Status;

/**
 * Point-to-point channel that moves byte payloads between MPI ranks using non-blocking
 * send/receive on a single tag ({@code edge}).
 *
 * <p>Every data message is transferred as two MPI messages on the same tag:
 * <ol>
 *   <li>an int header {@code [length, flag, userHeader...]} of at most
 *       {@code TWISTERX_CHANNEL_HEADER_SIZE} (8) ints, then</li>
 *   <li>{@code length} bytes of payload.</li>
 * </ol>
 * A header whose flag equals {@code TWISTERX_MSG_FIN} is a finish marker: it carries no payload
 * and tells the receiver that the sender is done.
 *
 * <p>The channel is driven by polling. Callers repeatedly invoke {@link #progressSends()} and
 * {@link #progressReceives()}, which advance a per-peer state machine and fire the send/receive
 * callbacks. This class performs no synchronization. NOTE(review): it presumably expects all calls
 * to come from a single thread; confirm with callers.
 */
public class MPIChannel {
    private static final Logger LOG = Logger.getLogger(MPIChannel.class.getName());

    /** Number of header ints available for user data (header size minus the reserved slots). */
    public static final int TWISTERX_CHANNEL_USER_HEADER = 6;
    /** Total number of ints in a channel header. */
    private static final int TWISTERX_CHANNEL_HEADER_SIZE = 8;
    /** Header flag value that marks a finish message. */
    private static final int TWISTERX_MSG_FIN = 1;

    /** Header slot holding the payload length in bytes. */
    private static final int HEADER_LENGTH_INDEX = 0;
    /** Header slot holding the message flag (0 for data, TWISTERX_MSG_FIN for finish). */
    private static final int HEADER_FLAG_INDEX = 1;
    /** Number of reserved header ints (length + flag); user header ints start at this index. */
    private static final int HEADER_RESERVED_INTS = 2;
    /** Flag value written for an ordinary data message. */
    private static final int DATA_FLAG = 0;

    private final Intracomm comm;
    // NOTE(review): stored but not read anywhere in this class.
    private final Config cfg;
    /** MPI tag used for every send and receive on this channel. */
    private final int edge;
    /** Send state, keyed by target rank. */
    private final Map<Integer, PendingSend> sends = new HashMap<>();
    /** Receive state, keyed by source rank. */
    private final Map<Integer, PendingReceive> pendingReceives = new HashMap<>();
    /** Finish requests registered via {@link #sendFin(TRequest)}, keyed by target rank. */
    private final Map<Integer, TRequest> finishRequests = new HashMap<>();
    private final ChannelReceiveCallback receiveCallback;
    private final ChannelSendCallback sendCallback;
    // NOTE(review): rank of this process in comm; not read anywhere in this class.
    private final int rank;
    private final Allocator allocator;

    /** Receive-side state for one source rank. */
    private static final class PendingReceive {
        /** Buffer the next header is received into. */
        private final IntBuffer headerBuf = MPI.newIntBuffer(TWISTERX_CHANNEL_HEADER_SIZE);
        private int receiveId;
        private ChannelBuffer data;
        private int length;
        private ReceiveStatus status = ReceiveStatus.RECEIVE_INIT;
        private Request request;
        // At most TWISTERX_CHANNEL_USER_HEADER slots are ever filled. The array is handed to
        // the receive callback, so its original size is kept unchanged.
        private final int[] userHeader = new int[9];
    }

    /** Send-side state for one target rank. */
    public static class PendingSend {
        /** Buffer the outgoing header is staged in. */
        private final IntBuffer headerBuf = MPI.newIntBuffer(TWISTERX_CHANNEL_HEADER_SIZE);
        /** Requests queued for this target, in send order. */
        private final Queue<TRequest> pendingData = new LinkedList<>();
        private SendStatus status = SendStatus.SEND_INIT;
        /** Request whose payload (or finish header) is currently in flight. */
        private TRequest currentSend;
        private Request request;

        private PendingSend() {
        }
    }

    private enum ReceiveStatus {
        RECEIVE_INIT,
        RECEIVE_LENGTH_POSTED,
        RECEIVE_POSTED,
        RECEIVED_FIN
    }

    public enum SendStatus {
        SEND_INIT,
        SEND_LENGTH_POSTED,
        SEND_POSTED,
        SEND_FINISH,
        SEND_DONE
    }

    /**
     * Creates the channel and immediately posts a header receive for every source rank.
     *
     * <p>The communicator is the runtime object registered under {@code "comm"} in
     * {@link MPIContext}. If none is registered, {@code MPI.COMM_WORLD} is used.
     *
     * @param config          configuration (stored, not otherwise used here)
     * @param workerController unused by this class
     * @param edge            MPI tag used for all messages on this channel
     * @param receiveSources  ranks this channel receives from
     * @param sendTargets     ranks this channel may send to
     * @param receiveCallback notified of received headers and payloads
     * @param sendCallback    notified when sends and finish messages complete
     * @param allocator       allocates buffers for incoming payloads
     * @throws Twister2RuntimeException if an MPI call fails
     */
    public MPIChannel(Config config, IWorkerController workerController, int edge,
                      List<Integer> receiveSources, List<Integer> sendTargets,
                      ChannelReceiveCallback receiveCallback, ChannelSendCallback sendCallback,
                      Allocator allocator) {
        this.cfg = config;
        Object runtimeComm = MPIContext.getRuntimeObject("comm");
        this.comm = runtimeComm == null ? MPI.COMM_WORLD : (Intracomm) runtimeComm;
        this.edge = edge;
        this.receiveCallback = receiveCallback;
        this.sendCallback = sendCallback;
        this.allocator = allocator;

        for (int source : receiveSources) {
            PendingReceive pendingReceive = new PendingReceive();
            pendingReceive.receiveId = source;
            pendingReceives.put(source, pendingReceive);
            try {
                pendingReceive.request = comm.iRecv(pendingReceive.headerBuf,
                    TWISTERX_CHANNEL_HEADER_SIZE, MPI.INT, source, edge);
                pendingReceive.status = ReceiveStatus.RECEIVE_LENGTH_POSTED;
            } catch (MPIException e) {
                LOG.log(Level.SEVERE, "Failed to request", e);
                throw new Twister2RuntimeException(e);
            }
        }
        for (int target : sendTargets) {
            sends.put(target, new PendingSend());
        }
        try {
            this.rank = comm.getRank();
        } catch (MPIException e) {
            LOG.log(Level.SEVERE, "Failed to get mpi processes", e);
            throw new Twister2RuntimeException(e);
        }
    }

    /**
     * Queues a request for sending. The transfer itself happens in {@link #progressSends()}.
     *
     * @param request request to send; its user header may hold at most
     *                {@code TWISTERX_CHANNEL_USER_HEADER} ints
     * @return 1 if the request was queued, -1 otherwise
     * @throws IllegalArgumentException if the target was not configured as a send target, or the
     *                                  user header does not fit in the channel header
     */
    public int send(TRequest request) {
        PendingSend pendingSend = sends.get(request.target);
        if (pendingSend == null) {
            throw new IllegalArgumentException(
                "Target " + request.target + " is not a send target of this channel");
        }
        if (request.headerLength > TWISTERX_CHANNEL_USER_HEADER) {
            throw new IllegalArgumentException("User header of " + request.headerLength
                + " ints exceeds the channel limit of " + TWISTERX_CHANNEL_USER_HEADER);
        }
        return pendingSend.pendingData.offer(request) ? 1 : -1;
    }

    /**
     * Registers a finish request for a target. The finish header is sent by
     * {@link #progressSends()} once all queued data for that target has been sent.
     *
     * @param request finish request; only its {@code target} is used to address the message
     * @return 1 if registered, -1 if a finish was already registered for the target
     */
    public int sendFin(TRequest request) {
        if (finishRequests.containsKey(request.target)) {
            LOG.log(Level.WARNING, "Sending finish to target twice " + request.target);
            return -1;
        }
        finishRequests.put(request.target, request);
        return 1;
    }

    /**
     * Advances the send state machine of every target by at most one step.
     *
     * <ul>
     *   <li>SEND_INIT: post the next queued header, or the finish header if one is registered.</li>
     *   <li>SEND_LENGTH_POSTED: once the header completes, post the payload of the head request.</li>
     *   <li>SEND_POSTED: once the payload completes, report it. Then post the next header, or the
     *       finish header, or return to SEND_INIT.</li>
     *   <li>SEND_FINISH: once the finish header completes, report it and move to SEND_DONE.</li>
     * </ul>
     *
     * @throws Twister2RuntimeException if an MPI call fails
     */
    public void progressSends() {
        try {
            for (Map.Entry<Integer, PendingSend> entry : sends.entrySet()) {
                int target = entry.getKey();
                PendingSend pendingSend = entry.getValue();
                switch (pendingSend.status) {
                    case SEND_INIT:
                        pendingSend.request = null;
                        if (!pendingSend.pendingData.isEmpty()) {
                            sendHeader(pendingSend);
                        } else if (finishRequests.containsKey(target)) {
                            pendingSend.currentSend = finishRequests.get(target);
                            sendFinishHeader(pendingSend);
                        }
                        break;
                    case SEND_LENGTH_POSTED:
                        if (pendingSend.request.test()) {
                            pendingSend.request = null;
                            TRequest request = pendingSend.pendingData.peek();
                            assert request != null;
                            pendingSend.request = comm.iSend(request.buffer, request.length,
                                MPI.BYTE, request.target, edge);
                            pendingSend.status = SendStatus.SEND_POSTED;
                            // Dequeue only after the payload send has been posted successfully.
                            pendingSend.pendingData.poll();
                            pendingSend.currentSend = request;
                        }
                        break;
                    case SEND_POSTED:
                        if (pendingSend.request.test()) {
                            pendingSend.request = null;
                            if (pendingSend.pendingData.isEmpty()) {
                                sendCallback.sendComplete(pendingSend.currentSend);
                                pendingSend.currentSend = null;
                                if (finishRequests.containsKey(target)) {
                                    pendingSend.currentSend = finishRequests.get(target);
                                    sendFinishHeader(pendingSend);
                                } else {
                                    pendingSend.status = SendStatus.SEND_INIT;
                                }
                            } else {
                                // Start the next header before reporting the finished payload.
                                sendHeader(pendingSend);
                                sendCallback.sendComplete(pendingSend.currentSend);
                                pendingSend.currentSend = null;
                            }
                        }
                        break;
                    case SEND_FINISH:
                        if (pendingSend.request.test()) {
                            pendingSend.request = null;
                            sendCallback.sendFinishComplete(finishRequests.get(target));
                            pendingSend.status = SendStatus.SEND_DONE;
                        }
                        break;
                    default:
                        // SEND_DONE: nothing more to do for this target.
                        break;
                }
            }
        } catch (MPIException e) {
            LOG.log(Level.SEVERE, "Exception in MPI", e);
            throw new Twister2RuntimeException(e);
        }
    }

    /**
     * Advances the receive state machine of every source by at most one step.
     *
     * <p>When a header arrives, it is either recorded as a finish or a payload receive is posted.
     * In both cases {@code receivedHeader} is invoked. When a payload arrives, the next header
     * receive is posted and {@code receivedData} is invoked.
     *
     * @throws Twister2RuntimeException if an MPI call fails
     */
    public void progressReceives() {
        try {
            for (Map.Entry<Integer, PendingReceive> entry : pendingReceives.entrySet()) {
                int source = entry.getKey();
                PendingReceive pendingReceive = entry.getValue();
                if (pendingReceive.status == ReceiveStatus.RECEIVE_LENGTH_POSTED) {
                    Status status = pendingReceive.request.testStatus();
                    if (status != null) {
                        pendingReceive.request = null;
                        handleHeader(source, pendingReceive, status.getCount(MPI.INT));
                    }
                } else if (pendingReceive.status == ReceiveStatus.RECEIVE_POSTED
                    && pendingReceive.request.test()) {
                    pendingReceive.request = null;
                    pendingReceive.headerBuf.clear();
                    // Re-arm the header receive before handing the payload to the callback.
                    pendingReceive.request = comm.iRecv(pendingReceive.headerBuf,
                        TWISTERX_CHANNEL_HEADER_SIZE, MPI.INT, pendingReceive.receiveId, edge);
                    pendingReceive.status = ReceiveStatus.RECEIVE_LENGTH_POSTED;
                    receiveCallback.receivedData(source, pendingReceive.data,
                        pendingReceive.length);
                }
            }
        } catch (MPIException e) {
            LOG.log(Level.SEVERE, "Error in MPI", e);
            throw new Twister2RuntimeException(e);
        }
    }

    /**
     * Handles a completed header receive for {@code source}. For a finish header, it marks the
     * peer finished. Otherwise it allocates a buffer, posts the payload receive and copies out
     * the user header. In both cases it then notifies the receive callback.
     *
     * @param count number of ints actually received in the header
     */
    private void handleHeader(int source, PendingReceive pendingReceive, int count)
        throws MPIException {
        int length = pendingReceive.headerBuf.get(HEADER_LENGTH_INDEX);
        int flag = pendingReceive.headerBuf.get(HEADER_FLAG_INDEX);
        if (flag == TWISTERX_MSG_FIN) {
            if (count != HEADER_RESERVED_INTS) {
                LOG.log(Level.SEVERE, "Un-expected number of bytes expected: 2 received: " + count);
            }
            pendingReceive.status = ReceiveStatus.RECEIVED_FIN;
            receiveCallback.receivedHeader(source, flag, null, 0);
            return;
        }

        if (count > TWISTERX_CHANNEL_HEADER_SIZE) {
            LOG.log(Level.SEVERE,
                "Un-expected number of bytes expected: 8 or less received: " + count);
        }
        pendingReceive.data = allocator.allocate(length);
        pendingReceive.length = length;
        pendingReceive.request = comm.iRecv(pendingReceive.data.getByteBuffer(), length, MPI.BYTE,
            pendingReceive.receiveId, edge);
        pendingReceive.status = ReceiveStatus.RECEIVE_POSTED;

        // Guard against a malformed short header producing a negative user-header length.
        int userHeaderLength = Math.max(0, count - HEADER_RESERVED_INTS);
        for (int i = 0; i < userHeaderLength; i++) {
            pendingReceive.userHeader[i] = pendingReceive.headerBuf.get(HEADER_RESERVED_INTS + i);
        }
        receiveCallback.receivedHeader(source, flag, pendingReceive.userHeader, userHeaderLength);
    }

    /**
     * Posts the data header for the request at the head of the target's queue. The request stays
     * queued until its payload send is posted.
     */
    private void sendHeader(PendingSend pendingSend) throws MPIException {
        assert !pendingSend.pendingData.isEmpty();
        TRequest request = pendingSend.pendingData.peek();
        pendingSend.headerBuf.put(HEADER_LENGTH_INDEX, request.length);
        pendingSend.headerBuf.put(HEADER_FLAG_INDEX, DATA_FLAG);
        for (int i = 0; i < request.headerLength; i++) {
            pendingSend.headerBuf.put(HEADER_RESERVED_INTS + i, request.header[i]);
        }
        pendingSend.request = comm.iSend(pendingSend.headerBuf,
            HEADER_RESERVED_INTS + request.headerLength, MPI.INT, request.target, edge);
        pendingSend.status = SendStatus.SEND_LENGTH_POSTED;
    }

    /**
     * Posts the finish header (length 0, flag {@code TWISTERX_MSG_FIN}) to the target of
     * {@code pendingSend.currentSend}.
     */
    private void sendFinishHeader(PendingSend pendingSend) throws MPIException {
        pendingSend.headerBuf.put(HEADER_LENGTH_INDEX, 0);
        pendingSend.headerBuf.put(HEADER_FLAG_INDEX, TWISTERX_MSG_FIN);
        pendingSend.request = comm.iSend(pendingSend.headerBuf, HEADER_RESERVED_INTS, MPI.INT,
            pendingSend.currentSend.target, edge);
        pendingSend.status = SendStatus.SEND_FINISH;
    }

    /** Currently a no-op: outstanding MPI requests are neither cancelled nor waited on. */
    public void close() {
    }
}
