package edu.iu.dsc.tws.comms.tcp;

import edu.iu.dsc.tws.api.comms.CommunicationContext;
import edu.iu.dsc.tws.api.comms.channel.ChannelListener;
import edu.iu.dsc.tws.api.comms.channel.TWSChannel;
import edu.iu.dsc.tws.api.comms.messaging.ChannelMessage;
import edu.iu.dsc.tws.api.comms.packing.DataBuffer;
import edu.iu.dsc.tws.api.config.Config;
import edu.iu.dsc.tws.api.exceptions.TimeoutException;
import edu.iu.dsc.tws.api.exceptions.Twister2RuntimeException;
import edu.iu.dsc.tws.api.resource.IWorkerController;
import edu.iu.dsc.tws.common.net.NetworkInfo;
import edu.iu.dsc.tws.common.net.tcp.TCPChannel;
import edu.iu.dsc.tws.common.net.tcp.TCPMessage;
import edu.iu.dsc.tws.common.net.tcp.TCPStatus;
import edu.iu.dsc.tws.common.util.IterativeLinkedList;
import edu.iu.dsc.tws.proto.jobmaster.JobMasterAPI;
import it.unimi.dsi.fastutil.ints.Int2ObjectArrayMap;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;

/**
 * {@link TWSChannel} implementation that moves {@link DataBuffer}s between workers over a
 * {@link TCPChannel}.
 *
 * <p>Sends accepted by {@link #sendMessage} are only handed to the TCP layer when
 * {@link #progressSends()} (or {@link #progress()}) runs. Receives are registered once with
 * {@link #receiveMessage} and a new receive is posted for a registration whenever its buffer queue
 * has a free buffer during progress. Apart from the {@link ArrayBlockingQueue} of pending sends,
 * the bookkeeping collections are not synchronized; NOTE(review): the class appears to assume
 * registration and progress happen on one thread — confirm with callers.
 */
public class TWSTCPChannel implements TWSChannel {
    private static final Logger LOG = Logger.getLogger(TWSTCPChannel.class.getName());

    /** ID of the local worker, as reported by the worker controller. */
    private final int executor;

    /** Bounded queue of sends accepted by {@link #sendMessage} but not yet posted to TCP. */
    private final ArrayBlockingQueue<TCPSendRequests> pendingSends;

    /** Every live receive registration, in registration order. */
    private final List<TCPReceiveRequests> registeredReceives;

    /** The same registrations as {@link #registeredReceives}, keyed by the caller-supplied group. */
    private final Int2ObjectArrayMap<List<TCPReceiveRequests>> groupedRegisteredReceives;

    /** Sends already posted to TCP that still have at least one buffer in flight. */
    private final IterativeLinkedList<TCPSendRequests> waitForCompletionSends;

    /** Underlying TCP transport. */
    private final TCPChannel comm;

    /** (rank, edge) pairs whose receive registrations are torn down on the next receive progress. */
    private final List<Pair<Integer, Integer>> pendingCloseRequests = new ArrayList<>();

    /** One posted TCP operation together with the buffer it reads into or writes from. */
    private static final class Request {
        final TCPMessage request;
        final DataBuffer buffer;

        Request(TCPMessage request, DataBuffer buffer) {
            this.request = request;
            this.buffer = buffer;
        }
    }

    /** A receive registration for one (rank, edge) pair within a group. */
    private static final class TCPReceiveRequests {
        /** Receives posted to TCP that have not yet been delivered to {@link #callback}. */
        final IterativeLinkedList<Request> pendingRequests = new IterativeLinkedList<>();
        /** Group key under which this registration is stored in the grouped map. */
        final int group;
        final int rank;
        final int edge;
        final ChannelListener callback;
        /** Buffers offered by the owner for new receives; one is polled per post. */
        final Queue<DataBuffer> availableBuffers;

        TCPReceiveRequests(int group, int rank, int edge, ChannelListener callback,
                           Queue<DataBuffer> availableBuffers) {
            this.group = group;
            this.rank = rank;
            this.edge = edge;
            this.callback = callback;
            this.availableBuffers = availableBuffers;
        }
    }

    /** A queued or in-flight send of one {@link ChannelMessage}. */
    private static final class TCPSendRequests {
        /** One TCP send per normal buffer of {@link #message} that has not completed yet. */
        final IterativeLinkedList<Request> pendingSends = new IterativeLinkedList<>();
        final int rank;
        final int edge;
        final ChannelMessage message;
        final ChannelListener callback;

        TCPSendRequests(int rank, int edge, ChannelMessage message, ChannelListener callback) {
            this.rank = rank;
            this.edge = edge;
            this.message = message;
            this.callback = callback;
        }
    }

    /**
     * Creates the channel, starts listening on this worker's advertised address and connects to
     * every worker returned by {@link IWorkerController#getJoinedWorkers()}.
     *
     * <p>Blocks on the worker barrier and then until all connections are established.
     *
     * @param config job configuration; also supplies the pending-send queue capacity
     * @param workerController source of this worker's address and of the joined worker list
     * @throws Twister2RuntimeException if a {@link TimeoutException} is raised while waiting on the
     *     barrier or for connections; the listening channel is stopped before rethrowing
     */
    public TWSTCPChannel(Config config, IWorkerController workerController) {
        JobMasterAPI.WorkerInfo localInfo = workerController.getWorkerInfo();
        int workerId = localInfo.getWorkerID();

        this.executor = workerId;
        this.pendingSends = new ArrayBlockingQueue<>(CommunicationContext.networkChannelPendingSize(config));
        this.registeredReceives = new ArrayList<>(1024);
        this.groupedRegisteredReceives = new Int2ObjectArrayMap<>();
        this.waitForCompletionSends = new IterativeLinkedList<>();

        TCPChannel channel = createChannel(config, localInfo.getWorkerIP(), localInfo.getPort(), workerId);
        channel.startListening();
        try {
            workerController.waitOnBarrier();
            List<NetworkInfo> peers = new ArrayList<>();
            for (JobMasterAPI.WorkerInfo worker : workerController.getJoinedWorkers()) {
                NetworkInfo peer = new NetworkInfo(worker.getWorkerID());
                peer.addProperty("twister2.tcp.port", Integer.valueOf(worker.getPort()));
                peer.addProperty("twister2.tcp.hostname", worker.getWorkerIP());
                peers.add(peer);
            }
            channel.startConnections(peers);
            channel.waitForConnections();
        } catch (TimeoutException e) {
            LOG.log(Level.SEVERE, e.getMessage(), e);
            // Don't leave the listening socket open when construction fails.
            try {
                channel.stop();
            } catch (RuntimeException stopFailure) {
                e.addSuppressed(stopFailure);
            }
            throw new Twister2RuntimeException(e);
        }
        this.comm = channel;
    }

    /** Builds a {@link TCPChannel} bound to the given host/port for worker {@code workerId}. */
    private static TCPChannel createChannel(Config config, String host, int port, int workerId) {
        NetworkInfo localNetwork = new NetworkInfo(workerId);
        localNetwork.addProperty("twister2.tcp.hostname", host);
        localNetwork.addProperty("twister2.tcp.port", Integer.valueOf(port));
        return new TCPChannel(config, localNetwork);
    }

    /**
     * Queues {@code message} for sending to worker {@code id}; nothing is written to the network
     * until the next {@link #progressSends()}.
     *
     * @return {@code false} if the bounded pending-send queue is full and the message was rejected
     */
    public boolean sendMessage(int id, ChannelMessage message, ChannelListener callback) {
        return pendingSends.offer(new TCPSendRequests(id, message.getHeader().getEdge(), message, callback));
    }

    /**
     * Registers a standing receive for ({@code id}, {@code edge}) under {@code group}. A receive is
     * posted during progress whenever {@code receiveBuffers} has a buffer available.
     *
     * @return always {@code true}
     */
    public boolean receiveMessage(int group, int id, int edge, ChannelListener callback,
                                  Queue<DataBuffer> receiveBuffers) {
        TCPReceiveRequests receive = new TCPReceiveRequests(group, id, edge, callback, receiveBuffers);
        registeredReceives.add(receive);
        List<TCPReceiveRequests> groupReceives = groupedRegisteredReceives.get(group);
        if (groupReceives == null) {
            groupReceives = new ArrayList<>();
            groupedRegisteredReceives.put(group, groupReceives);
        }
        groupReceives.add(receive);
        return true;
    }

    /**
     * Repeatedly calls {@link #progress()} until no sends or close requests are outstanding, then
     * stops the TCP channel. This busy-waits; it does not sleep between rounds.
     */
    public void close() {
        while (hasPendingWork()) {
            progress();
        }
        comm.stop();
    }

    /**
     * @return {@code true} when no sends are queued or in flight and no close requests are pending;
     *     outstanding receives are not considered
     */
    public boolean isComplete() {
        return !hasPendingWork();
    }

    /** True while sends are queued/in flight or {@link #releaseBuffers} requests are unprocessed. */
    private boolean hasPendingWork() {
        return !pendingCloseRequests.isEmpty()
            || !pendingSends.isEmpty()
            || !waitForCompletionSends.isEmpty();
    }

    /** Posts one TCP send per normal buffer of the message and tracks each as a pending request. */
    private void postMessage(TCPSendRequests send) {
        List<DataBuffer> buffers = send.message.getNormalBuffers();
        int edge = send.message.getHeader().getEdge();
        for (DataBuffer buffer : buffers) {
            TCPMessage posted = comm.iSend(buffer.getByteBuffer(), buffer.getSize(), send.rank, edge);
            send.pendingSends.add(new Request(posted, buffer));
        }
    }

    /** Posts a receive into the next available buffer of the registration, if one can be polled. */
    private void postReceive(TCPReceiveRequests receive) {
        DataBuffer buffer = receive.availableBuffers.poll();
        if (buffer == null) {
            return;
        }
        TCPMessage posted = comm.iRecv(buffer.getByteBuffer(), buffer.getCapacity(), receive.rank, receive.edge);
        receive.pendingRequests.add(new Request(posted, buffer));
    }

    /**
     * For each registration: posts a new receive if a buffer is free, then delivers every completed
     * receive to the listener. Finally applies any queued close requests.
     */
    private void internalProgressReceives(List<TCPReceiveRequests> receives) {
        // Index loop on purpose: if a listener registers a new receive while we are iterating, a
        // for-each iterator would throw ConcurrentModificationException.
        for (int i = 0; i < receives.size(); i++) {
            TCPReceiveRequests receive = receives.get(i);
            if (!receive.availableBuffers.isEmpty()) {
                postReceive(receive);
            }
            IterativeLinkedList.ILLIterator requests = receive.pendingRequests.iterator();
            while (requests.hasNext()) {
                Request request = (Request) requests.next();
                if (request == null || request.request == null
                        || request.request.testStatus() != TCPStatus.COMPLETE) {
                    continue;
                }
                // Record the buffer's size as its ByteBuffer limit before handing it over.
                request.buffer.setSize(request.buffer.getByteBuffer().limit());
                receive.callback.onReceiveComplete(receive.rank, receive.edge, request.buffer);
                requests.remove();
            }
        }
        handlePendingCloseRequests();
    }

    /** Progresses sends, then all registered receives, then the underlying TCP channel. */
    public void progress() {
        progressSends();
        internalProgressReceives(registeredReceives);
        comm.progress();
    }

    /**
     * Posts every queued send to TCP, then fires {@code onSendComplete} for (and forgets) each send
     * whose buffers have all completed.
     */
    public void progressSends() {
        TCPSendRequests queued;
        while ((queued = pendingSends.poll()) != null) {
            postMessage(queued);
            waitForCompletionSends.add(queued);
        }

        IterativeLinkedList.ILLIterator sends = waitForCompletionSends.iterator();
        while (sends.hasNext()) {
            TCPSendRequests send = (TCPSendRequests) sends.next();
            IterativeLinkedList.ILLIterator requests = send.pendingSends.iterator();
            while (requests.hasNext()) {
                Request request = (Request) requests.next();
                if (request.request.testStatus() == TCPStatus.COMPLETE) {
                    requests.remove();
                }
            }
            if (send.pendingSends.isEmpty()) {
                send.callback.onSendComplete(send.rank, send.edge, send.message);
                sends.remove();
            }
        }
    }

    /**
     * Progresses only the receives registered under {@code group}. A group with no registrations
     * is a no-op apart from applying queued close requests.
     */
    public void progressReceives(int group) {
        List<TCPReceiveRequests> groupReceives = groupedRegisteredReceives.get(group);
        if (groupReceives == null) {
            handlePendingCloseRequests();
            return;
        }
        internalProgressReceives(groupReceives);
    }

    /**
     * Applies every queued {@link #releaseBuffers} request. Each matching (rank, edge) registration
     * is removed from both the flat and the grouped receive lists, and its pending receive handles
     * are dropped after calling {@code isComplete()} on each. NOTE(review): the posted TCP
     * receives are not cancelled here — confirm {@link TCPChannel} tolerates that.
     */
    private void handlePendingCloseRequests() {
        while (!pendingCloseRequests.isEmpty()) {
            Pair<Integer, Integer> released = pendingCloseRequests.remove(0);
            int rank = released.getLeft();
            int edge = released.getRight();
            Iterator<TCPReceiveRequests> receives = registeredReceives.iterator();
            while (receives.hasNext()) {
                TCPReceiveRequests receive = receives.next();
                if (receive.edge != edge || receive.rank != rank) {
                    continue;
                }
                IterativeLinkedList.ILLIterator requests = receive.pendingRequests.iterator();
                while (requests.hasNext()) {
                    Request request = (Request) requests.next();
                    if (request != null && request.request != null) {
                        request.request.isComplete();
                    }
                    requests.remove();
                }
                receives.remove();
                // Keep the grouped view consistent so progressReceives(group) stops serving it.
                List<TCPReceiveRequests> groupReceives = groupedRegisteredReceives.get(receive.group);
                if (groupReceives != null) {
                    groupReceives.remove(receive);
                }
            }
        }
    }

    /** Allocates a heap {@link ByteBuffer} with the given capacity. */
    public ByteBuffer createBuffer(int capacity) {
        return ByteBuffer.allocate(capacity);
    }

    /**
     * Queues removal of the receive registration(s) for ({@code rank}, {@code edge}); the removal
     * happens during the next receive progress.
     */
    public void releaseBuffers(int rank, int edge) {
        pendingCloseRequests.add(new ImmutablePair<>(rank, edge));
    }
}
