package edu.iu.dsc.tws.common.net.tcp.request;

import com.google.common.collect.HashBiMap;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import edu.iu.dsc.tws.api.config.Config;
import edu.iu.dsc.tws.api.net.StatusCode;
import edu.iu.dsc.tws.api.net.request.ConnectHandler;
import edu.iu.dsc.tws.api.net.request.MessageHandler;
import edu.iu.dsc.tws.api.net.request.RequestID;
import edu.iu.dsc.tws.common.net.tcp.ChannelHandler;
import edu.iu.dsc.tws.common.net.tcp.Progress;
import edu.iu.dsc.tws.common.net.tcp.Server;
import edu.iu.dsc.tws.common.net.tcp.TCPMessage;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.SocketChannel;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;

/* loaded from: input_file:edu/iu/dsc/tws/common/net/tcp/request/RRServer.class */
/**
 * Request/response server built on top of the TCP {@link Server}.
 *
 * <p>Wire format of every message, in both directions: the request ID bytes, the full protobuf
 * type name of the payload (written with {@code ByteUtils.packString}), the sender's ID as an
 * {@code int}, and finally the serialized protobuf payload.
 *
 * <p>Incoming payloads are decoded with the builder registered for their type name and dispatched
 * to the matching {@link MessageHandler}. The channel each request arrived on is remembered, so
 * that {@link #sendResponse(RequestID, Message)} can route the reply back to it. The sender ID
 * {@link #CLIENT_ID} identifies the submitting client; any other sender ID is treated as a
 * worker ID.
 *
 * <p>NOTE(review): none of the mutable state is synchronized. It appears to assume that all calls
 * happen on the thread driving {@link Progress#loop()} — confirm.
 */
public class RRServer {
    private static final Logger LOG = Logger.getLogger(RRServer.class.getName());

    /** Sender ID used by the submitting client, as opposed to a worker. */
    public static final int CLIENT_ID = -100;

    /** Number of request-ID bytes at the head of every incoming message. */
    private static final int REQUEST_ID_LENGTH = 32;

    /**
     * Header bytes besides the request ID and the type-name characters. This is 4 bytes for the
     * sender ID plus, presumably, a 4-byte length prefix written by {@code ByteUtils.packString}
     * (confirm against ByteUtils).
     */
    private static final int FIXED_HEADER_BYTES = 8;

    protected Server server;
    /** Channel of the submitting client, or {@code null} while it is not connected. */
    protected SocketChannel clientChannel;
    /** ID written as the sender into every outgoing message. */
    protected int serverID;
    protected ConnectHandler connectHandler;
    protected Progress loop;
    /** Every currently connected channel, whether or not it has sent a message yet. */
    protected List<SocketChannel> connectedChannels = new ArrayList<>();
    /** Worker channels mapped to their worker IDs; filled on the first message from a worker. */
    protected HashBiMap<SocketChannel, Integer> workerChannels = HashBiMap.create();
    /** Request handlers keyed by the full protobuf type name they accept. */
    protected Map<String, MessageHandler> requestHandlers = new HashMap<>();
    /** Builders used to decode incoming payloads, keyed by full protobuf type name. */
    protected Map<String, Message.Builder> messageBuilders = new HashMap<>();
    /** The channel on which each not-yet-answered request arrived. */
    protected Map<RequestID, SocketChannel> requestChannels = new HashMap<>();
    /** Messages handed to the server whose send has not completed yet. */
    protected int pendingSendCount = 0;

    /** Callbacks from the TCP {@link Server} that keep this server's channel bookkeeping in sync. */
    private class Handler implements ChannelHandler {
        private Handler() {
        }

        @Override
        public void onError(SocketChannel socketChannel) {
            workerChannels.remove(socketChannel);
            connectedChannels.remove(socketChannel);
            connectHandler.onError(socketChannel);
            // Forget the client channel as well, as onClose does. Otherwise later sends to
            // CLIENT_ID would target the channel that is closed just below.
            if (clientChannel != null && clientChannel.equals(socketChannel)) {
                clientChannel = null;
            }
            loop.removeAllInterest(socketChannel);
            try {
                socketChannel.close();
                LOG.log(Level.FINEST, "Closed the channel: " + socketChannel);
            } catch (IOException e) {
                LOG.log(Level.SEVERE, "Channel closed error: " + socketChannel, e);
            }
        }

        @Override
        public void onConnect(SocketChannel socketChannel, StatusCode statusCode) {
            connectedChannels.add(socketChannel);
            connectHandler.onConnect(socketChannel, statusCode);
        }

        @Override
        public void onClose(SocketChannel socketChannel) {
            workerChannels.remove(socketChannel);
            connectedChannels.remove(socketChannel);
            connectHandler.onClose(socketChannel);
            if (socketChannel.equals(clientChannel)) {
                clientChannel = null;
            }
        }

        /**
         * Decodes a request, records which channel and sender it came from, and dispatches it to
         * the handler registered for its protobuf type.
         *
         * @throws RuntimeException if no builder is registered for the message's type name
         */
        @Override
        public void onReceiveComplete(SocketChannel socketChannel, TCPMessage tCPMessage) {
            if (socketChannel == null) {
                // Without a channel we can neither record the sender nor route a response back.
                // Storing null into the channel maps would only cause failures later.
                LOG.log(Level.SEVERE, "Chanel on receive is NULL");
                return;
            }
            ByteBuffer buffer = tCPMessage.getByteBuffer();
            byte[] idBytes = new byte[REQUEST_ID_LENGTH];
            buffer.get(idBytes);
            RequestID requestID = RequestID.fromBytes(idBytes);
            String typeName = ByteUtils.unPackString(buffer);
            int senderID = buffer.getInt();

            Message.Builder builder = messageBuilders.get(typeName);
            if (builder == null) {
                throw new RuntimeException("Received response without a registered response");
            }
            try {
                builder.clear();
                // Everything after the header is the serialized protobuf payload.
                int payloadLength = tCPMessage.getLength()
                        - (FIXED_HEADER_BYTES + idBytes.length + typeName.getBytes().length);
                byte[] payload = new byte[payloadLength];
                buffer.get(payload);
                builder.mergeFrom(payload);
                Message message = builder.build();
                saveChannel(socketChannel, senderID, message);
                LOG.log(Level.FINEST, String.format("Adding channel %s", new String(idBytes)));
                requestChannels.put(requestID, socketChannel);
                requestHandlers.get(typeName).onMessage(requestID, senderID, message);
            } catch (InvalidProtocolBufferException e) {
                LOG.log(Level.SEVERE, "Failed to build a message", e);
            }
        }

        @Override
        public void onSendComplete(SocketChannel socketChannel, TCPMessage tCPMessage) {
            pendingSendCount--;
        }
    }

    /**
     * Creates the server. It does not start listening until {@link #start()} is called.
     *
     * @param config configuration passed to the underlying {@link Server}
     * @param host address passed to the underlying {@link Server} (presumably its bind host)
     * @param port port passed to the underlying {@link Server}
     * @param progress event loop driving the server's I/O
     * @param serverID ID written as the sender into every outgoing message
     * @param connectHandler notified of connect, close and error events
     */
    public RRServer(Config config, String host, int port, Progress progress, int serverID,
                    ConnectHandler connectHandler) {
        this.connectHandler = connectHandler;
        this.loop = progress;
        this.serverID = serverID;
        this.server = new Server(config, host, port, this.loop, new Handler(), false);
    }

    /**
     * Registers a handler for requests whose payload type matches {@code builder}'s descriptor.
     * The builder itself is reused to decode every incoming payload of that type.
     */
    public void registerRequestHandler(Message.Builder builder, MessageHandler messageHandler) {
        String typeName = builder.getDescriptorForType().getFullName();
        requestHandlers.put(typeName, messageHandler);
        messageBuilders.put(typeName, builder);
    }

    /** Starts the underlying TCP server. */
    public void start() {
        server.start();
    }

    /** Stops the underlying TCP server immediately. */
    public void stop() {
        server.stop();
    }

    /**
     * Keeps driving the event loop until no sends are pending and every channel has disconnected,
     * or until {@code timeLimit} milliseconds have passed. It then stops the server. The loop
     * always runs at least once.
     *
     * @param timeLimit maximum time to wait, in milliseconds
     */
    public void stopGraceFully(long timeLimit) {
        long startTime = System.currentTimeMillis();
        long elapsed;
        do {
            loop.loop();
            boolean hasPending = server.hasPending();
            elapsed = System.currentTimeMillis() - startTime;
            if (!hasPending && pendingSendCount == 0 && connectedChannels.isEmpty()) {
                break;
            }
        } while (elapsed < timeLimit);
        stop();
    }

    /** Returns a live view of the IDs of workers that currently have a registered channel. */
    public Set<Integer> getConnectedWorkers() {
        return workerChannels.values();
    }

    /**
     * Sends {@code message} back on the channel on which request {@code requestID} arrived.
     *
     * @return {@code true} if the message was handed to the TCP server; {@code false} if the
     *     request is unknown, its channel is no longer registered, or the send failed
     */
    public boolean sendResponse(RequestID requestID, Message message) {
        if (!requestChannels.containsKey(requestID)) {
            LOG.log(Level.SEVERE, "Trying to send a response to non-existing request");
            return false;
        }
        SocketChannel channel = requestChannels.get(requestID);
        if (channel == null) {
            LOG.log(Level.SEVERE, "Channel is NULL for response");
            // Without this return, channel.equals(...) below would throw a NullPointerException.
            requestChannels.remove(requestID);
            return false;
        }
        if (!workerChannels.containsKey(channel) && !channel.equals(clientChannel)) {
            LOG.log(Level.WARNING, "Failed to send response on disconnected socket");
            // The channel has been unregistered (see onClose/onError), so this request can no
            // longer be answered. Drop it so that requestChannels does not grow without bound.
            requestChannels.remove(requestID);
            return false;
        }
        if (sendMessage(message, requestID, channel) == null) {
            // The entry is kept; presumably so that the caller can retry the response.
            return false;
        }
        requestChannels.remove(requestID);
        return true;
    }

    /**
     * Sends an unsolicited message, tagged with {@code RequestID.DUMMY_REQUEST_ID}, to the client
     * (when {@code targetID == CLIENT_ID}) or to the worker with the given ID.
     *
     * @return {@code true} if the message was handed to the TCP server
     */
    public boolean sendMessage(Message message, int targetID) {
        SocketChannel channel;
        if (targetID == CLIENT_ID) {
            if (clientChannel == null) {
                LOG.severe("Trying to send a message to the client, but it has not connected yet.");
                return false;
            }
            channel = clientChannel;
        } else {
            if (!workerChannels.containsValue(targetID)) {
                LOG.severe("Trying to send a message to a worker that has not connected yet. workerID: "
                        + targetID);
                return false;
            }
            channel = workerChannels.inverse().get(targetID);
        }
        if (channel == null) {
            LOG.log(Level.SEVERE, "Channel is NULL for response");
            return false;
        }
        return sendMessage(message, RequestID.DUMMY_REQUEST_ID, channel) != null;
    }

    /**
     * Serializes {@code message} in the wire format described on this class and hands it to the
     * TCP server. On success it counts the send as pending until {@code onSendComplete} fires.
     *
     * @return the TCP message handle, or {@code null} if the server did not accept the send
     */
    protected TCPMessage sendMessage(Message message, RequestID requestID, SocketChannel socketChannel) {
        byte[] payload = message.toByteArray();
        String typeName = message.getDescriptorForType().getFullName();
        byte[] idBytes = requestID.getId();
        int length = idBytes.length + payload.length + typeName.getBytes().length + FIXED_HEADER_BYTES;
        ByteBuffer buffer = ByteBuffer.allocate(length);
        buffer.put(idBytes);
        ByteUtils.packString(typeName, buffer);
        buffer.putInt(serverID);
        buffer.put(payload);
        TCPMessage sent = server.send(socketChannel, buffer, length, 0);
        if (sent != null) {
            pendingSendCount++;
        }
        return sent;
    }

    /**
     * Unregisters the channel of worker {@code workerID} and closes it. Does nothing if that
     * worker has no registered channel.
     */
    public void removeWorkerChannel(int workerID) {
        SocketChannel channel = workerChannels.inverse().remove(workerID);
        if (channel == null) {
            return;
        }
        try {
            channel.close();
        } catch (IOException e) {
            LOG.log(Level.WARNING, "Exception when closing the channel: ", e);
        }
    }

    /**
     * Records {@code socketChannel} as the channel of sender {@code senderID}, unless it is
     * already registered as a worker channel. A newer channel for an already-known worker ID
     * replaces the older one. {@code message} is not used.
     *
     * <p>NOTE(review): the decompiler reports this was originally private; it is left public for
     * compatibility.
     */
    public void saveChannel(SocketChannel socketChannel, int senderID, Message message) {
        if (workerChannels.containsKey(socketChannel)) {
            return;
        }
        if (senderID == CLIENT_ID) {
            clientChannel = socketChannel;
            LOG.info("Message received from submitting client. Channel set.");
            return;
        }
        if (workerChannels.inverse().containsKey(senderID)) {
            LOG.warning(String.format("While there is a channel for workerID[%d], another channel connected from the same worker. Replacing older one. ", senderID));
        }
        workerChannels.forcePut(socketChannel, senderID);
    }
}
