package edu.iu.dsc.tws.comms.dfw;

import edu.iu.dsc.tws.api.comms.CommunicationContext;
import edu.iu.dsc.tws.api.comms.DataFlowOperation;
import edu.iu.dsc.tws.api.comms.LogicalPlan;
import edu.iu.dsc.tws.api.comms.channel.ChannelReceiver;
import edu.iu.dsc.tws.api.comms.channel.TWSChannel;
import edu.iu.dsc.tws.api.comms.messaging.ChannelMessage;
import edu.iu.dsc.tws.api.comms.messaging.MessageHeader;
import edu.iu.dsc.tws.api.comms.messaging.MessageReceiver;
import edu.iu.dsc.tws.api.comms.messaging.types.MessageType;
import edu.iu.dsc.tws.api.comms.packing.MessageSchema;
import edu.iu.dsc.tws.api.config.Config;
import edu.iu.dsc.tws.comms.dfw.OutMessage;
import edu.iu.dsc.tws.comms.dfw.io.Deserializers;
import edu.iu.dsc.tws.comms.dfw.io.Serializers;
import edu.iu.dsc.tws.comms.routing.BinaryTreeRouter;
import edu.iu.dsc.tws.comms.utils.TaskPlanUtils;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;

/* loaded from: input_file:edu/iu/dsc/tws/comms/dfw/TreeBroadcast.class */
/**
 * Broadcast operation that sends messages from a single source task to a set of destination
 * tasks. Routing is computed by a {@link BinaryTreeRouter}; serialization and channel transfer are
 * delegated to a {@link ChannelDataFlowOperation}.
 *
 * <p>Messages that arrive from the channel are either re-queued for sending
 * ({@link #handleReceivedChannelMessage(ChannelMessage)}) or buffered
 * ({@link #receiveMessage(MessageHeader, Object)}) and handed to the final receiver once per local
 * receive task during {@link #progress()}.
 *
 * <p>{@link #progress()}, {@link #isComplete()} and {@link #finish(int)} coordinate through one
 * {@link ReentrantLock}; {@code progress} and {@code isComplete} return {@code true} immediately
 * without doing any work when another thread holds it.
 */
public class TreeBroadcast implements DataFlowOperation, ChannelReceiver {
    private static final Logger LOG = Logger.getLogger(TreeBroadcast.class.getName());

    /**
     * Flags attached to the one-byte message sent for each finished source (previously the bare
     * literal 67108864). NOTE(review): presumably this corresponds to a named constant in the
     * project's message-flags class — confirm and reference that constant instead.
     */
    private static final int FINISH_MESSAGE_FLAGS = 1 << 26;

    /** The single task that originates broadcast messages. */
    private final int source;
    /** Destination tasks, as passed to the constructor (not copied). */
    private final Set<Integer> destinations;
    /** Router built in init from the source and destinations. */
    private BinaryTreeRouter router;
    /** Receiver that messages are ultimately delivered to. */
    private final MessageReceiver finalReceiver;
    /** Copy of the destinations, exposed through {@link #getTargets()}. */
    private final Set<Integer> targetSet;
    /** Sources for which {@link #finish(int)} has been called. */
    private final Set<Integer> pendingFinishSources;
    /** Sources whose finish message has already been accepted by the delegate. */
    private final Set<Integer> finishedSources;
    /** Members of {@link #sourceSet} hosted on this worker; computed in init. */
    private Set<Integer> thisSources;
    private final ChannelDataFlowOperation delegate;
    private final MessageSchema messageSchema;
    private Config config;
    private LogicalPlan instancePlan;
    /** Id of this worker, taken from the logical plan in init. */
    private int executor;
    /** Graph edge id; also used as this operation's unique id. */
    private int edge;
    private MessageType dataType;
    private MessageType recvDataType;
    private MessageType keyType;
    private MessageType recvKeyType;
    /** Received messages waiting to be delivered to every entry of {@link #receiveTasks}. */
    private Queue<Pair<MessageHeader, Object>> currentReceiveMessage;
    /** Local tasks that each buffered received message is delivered to. */
    private final List<Integer> receiveTasks = new ArrayList<>();
    private final Map<Integer, ArrayBlockingQueue<OutMessage>> pendingSendMessagesPerSource =
        new HashMap<>();
    private final Lock lock = new ReentrantLock();
    private final Map<Integer, RoutingParameters> routingParametersCache = new HashMap<>();
    /** Position in {@link #receiveTasks} of the next task still owed the head message. */
    private int receiveIndex = 0;
    /** Singleton set holding {@link #source}. */
    private final Set<Integer> sourceSet = new HashSet<>();

    /**
     * Creates a broadcast with explicit key and data types.
     *
     * @param channel  channel handed to the underlying {@link ChannelDataFlowOperation}
     * @param src      the single source task
     * @param dests    destination tasks; kept by reference and also copied into the target set
     * @param receiver receiver messages are finally delivered to; init fails if this is null
     * @param kType    key type (used by the 5-argument init overload)
     * @param dType    data type
     * @param schema   schema handed to the serializers and deserializers
     */
    public TreeBroadcast(TWSChannel channel, int src, Set<Integer> dests, MessageReceiver receiver,
                         MessageType kType, MessageType dType, MessageSchema schema) {
        this.source = src;
        this.destinations = dests;
        this.finalReceiver = receiver;
        this.keyType = kType;
        this.dataType = dType;
        this.delegate = new ChannelDataFlowOperation(channel);
        this.messageSchema = schema;
        this.sourceSet.add(src);
        this.targetSet = new HashSet<>(dests);
        this.pendingFinishSources = new HashSet<>();
        this.finishedSources = new HashSet<>();
    }

    /**
     * Creates a broadcast whose key and data types are left unset until init.
     *
     * @param channel  channel handed to the underlying {@link ChannelDataFlowOperation}
     * @param src      the single source task
     * @param dests    destination tasks
     * @param receiver receiver messages are finally delivered to
     * @param schema   schema handed to the serializers and deserializers
     */
    public TreeBroadcast(TWSChannel channel, int src, Set<Integer> dests, MessageReceiver receiver,
                         MessageSchema schema) {
        this(channel, src, dests, receiver, null, null, schema);
    }

    /** Closes the final receiver (if present) and then the delegate. */
    public void close() {
        if (this.finalReceiver != null) {
            this.finalReceiver.close();
        }
        this.delegate.close();
    }

    /** Cleans the final receiver (if present) and clears all recorded finish state. */
    public void reset() {
        if (this.finalReceiver != null) {
            this.finalReceiver.clean();
        }
        this.pendingFinishSources.clear();
        this.finishedSources.clear();
    }

    /**
     * Marks {@code src} as finished; the finish message itself is sent later by progress or
     * isComplete.
     *
     * @param src the source task that has finished sending
     * @throws RuntimeException if {@code src} is not a source hosted on this worker
     */
    public void finish(int src) {
        // Validate and record the task the caller reports, not the field: previously the argument
        // was ignored, so an invalid value was silently accepted.
        if (!this.thisSources.contains(src)) {
            throw new RuntimeException("Invalid source completion: " + src);
        }
        this.lock.lock();
        try {
            this.pendingFinishSources.add(src);
        } finally {
            this.lock.unlock();
        }
    }

    /** @return the logical plan supplied to init */
    public LogicalPlan getLogicalPlan() {
        return this.instancePlan;
    }

    /** @return the graph edge id as a string */
    public String getUniqueId() {
        return String.valueOf(this.edge);
    }

    /**
     * Buffers a message received from the channel for later delivery to the local receive tasks.
     *
     * @return {@code false} if the bounded buffer is full and the message was not accepted
     */
    public boolean receiveMessage(MessageHeader messageHeader, Object obj) {
        return this.currentReceiveMessage.offer(new ImmutablePair<>(messageHeader, obj));
    }

    /**
     * Delivers the head buffered message to every remaining task in {@link #receiveTasks}.
     * Delivery resumes at {@link #receiveIndex} so a partially delivered message is not re-sent
     * to tasks that already accepted it.
     *
     * @return {@code true} if the head message reached every task and was removed; {@code false}
     *     if the buffer was empty or a task refused the message
     */
    private boolean receiveProgressMessage() {
        Pair<MessageHeader, Object> head = this.currentReceiveMessage.peek();
        if (head == null) {
            return false;
        }
        MessageHeader header = head.getLeft();
        Object payload = head.getRight();
        while (this.receiveIndex < this.receiveTasks.size()) {
            int target = this.receiveTasks.get(this.receiveIndex);
            if (!this.finalReceiver.onMessage(header.getSourceId(), 0, target, header.getFlags(),
                payload)) {
                return false;
            }
            this.receiveIndex++;
        }
        this.currentReceiveMessage.poll();
        this.receiveIndex = 0;
        return true;
    }

    /**
     * Initializes routing, buffers, serializers and the delegate.
     *
     * <p>The four per-source maps handed to the delegate are raw because their value types are
     * declared by {@link ChannelDataFlowOperation} and are not imported here.
     *
     * @param cfg       configuration
     * @param dType     data type
     * @param rcvDType  receive-side data type
     * @param kType     key type; when non-null the delegate is initialized in keyed mode
     * @param rcvKType  receive-side key type
     * @param plan      logical plan
     * @param graphEdge edge id
     * @throws RuntimeException if no final receiver was supplied
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    public void init(Config cfg, MessageType dType, MessageType rcvDType, MessageType kType,
                     MessageType rcvKType, LogicalPlan plan, int graphEdge) {
        this.config = cfg;
        this.instancePlan = plan;
        this.dataType = dType;
        this.recvDataType = rcvDType;
        this.keyType = kType;
        this.recvKeyType = rcvKType;
        this.edge = graphEdge;
        this.executor = plan.getThisWorker();
        this.currentReceiveMessage =
            new ArrayBlockingQueue<>(CommunicationContext.sendPendingMax(cfg));
        this.router = new BinaryTreeRouter(cfg, plan, this.source, this.destinations);
        if (this.finalReceiver == null) {
            throw new RuntimeException("Final receiver is required");
        }
        this.finalReceiver.init(cfg, this, receiveExpectedTaskIds());
        if (LOG.isLoggable(Level.FINE)) {
            LOG.log(Level.FINE, String.format("%d bast sources %d dest %s send tasks: %s",
                this.executor, this.source, this.destinations, this.router.sendQueueIds()));
        }
        this.thisSources = TaskPlanUtils.getTasksOfThisWorker(plan, this.sourceSet);

        Map pendingReceiveMessages = new HashMap();
        Map pendingReceiveDeserializations = new HashMap();
        Map serializers = new HashMap();
        Map deserializers = new HashMap();
        boolean keyed = kType != null;

        // Send side: one bounded queue and one serializer per send queue id.
        Set<Integer> sendQueueIds = this.router.sendQueueIds();
        int pendingMax = CommunicationContext.sendPendingMax(cfg);
        for (int queueId : sendQueueIds) {
            this.pendingSendMessagesPerSource.put(queueId, new ArrayBlockingQueue<>(pendingMax));
            serializers.put(queueId, Serializers.get(keyed, this.messageSchema));
        }

        // Receive side: capacity scales with the number of receiving executors (at least one).
        int receiveBufferCount = CommunicationContext.receiveBufferCount(cfg);
        int receivingExecutorCount = Math.max(1, receivingExecutors().size());
        // The receive structures are keyed by this.source, not by individual receive sources, so
        // a single set is created when there is any receive source. NOTE(review): the original
        // looped over getReceiveSources() without using the loop variable — confirm this keying
        // is intended.
        if (this.router.getReceiveSources().iterator().hasNext()) {
            int capacity = receiveBufferCount * 2 * receivingExecutorCount;
            pendingReceiveMessages.put(this.source, new ArrayBlockingQueue<>(capacity));
            pendingReceiveDeserializations.put(this.source, new ArrayBlockingQueue<>(capacity));
            deserializers.put(this.source, Deserializers.get(keyed, this.messageSchema));
        }

        calculateRoutingParameters();
        for (Integer queueId : sendQueueIds) {
            this.routingParametersCache.put(queueId, sendRoutingParameters(queueId, 0));
        }

        if (keyed) {
            this.delegate.init(cfg, dType, this.recvDataType, kType, rcvKType, plan, graphEdge,
                this.router.receivingExecutors(), this, this.pendingSendMessagesPerSource,
                pendingReceiveMessages, pendingReceiveDeserializations, serializers, deserializers,
                true);
        } else {
            this.delegate.init(cfg, dType, this.recvDataType, plan, graphEdge,
                this.router.receivingExecutors(), this, this.pendingSendMessagesPerSource,
                pendingReceiveMessages, pendingReceiveDeserializations, serializers, deserializers,
                false);
        }
    }

    /** Initializes with the key type given at construction for both send and receive keys. */
    public void init(Config config, MessageType messageType, MessageType messageType2,
                     LogicalPlan logicalPlan, int i) {
        init(config, messageType, messageType2, this.keyType, this.keyType, logicalPlan, i);
    }

    /** Initializes using the same data type for sending and receiving. */
    public void init(Config config, MessageType messageType, LogicalPlan logicalPlan, int i) {
        init(config, messageType, messageType, logicalPlan, i);
    }

    /**
     * Not supported by this operation.
     *
     * @throws UnsupportedOperationException always
     */
    public boolean sendPartial(int i, Object obj, int i2) {
        throw new UnsupportedOperationException("Not supported method");
    }

    /**
     * Reports whether the operation has completed.
     *
     * <p>If another thread holds the lock this returns {@code true} immediately. Otherwise it is
     * {@code true} only when the delegate and the final receiver both report completion and all
     * pending finish messages were sent; sending them is attempted only after the first two
     * checks pass. NOTE(review): answering {@code true} while the lock is busy may report
     * completion prematurely — confirm callers tolerate this.
     */
    public boolean isComplete() {
        if (!this.lock.tryLock()) {
            return true;
        }
        try {
            return this.delegate.isComplete() && this.finalReceiver.isComplete()
                && !handleFinish();
        } finally {
            this.lock.unlock();
        }
    }

    /** Sends {@code obj} from {@code src} with the given flags and target 0. */
    public boolean send(int src, Object obj, int flags) {
        return this.delegate.sendMessage(src, obj, 0, flags, sendRoutingParameters(src, 0));
    }

    /** Sends {@code obj} from {@code src} with the given flags and target. */
    public boolean send(int src, Object obj, int flags, int target) {
        return this.delegate.sendMessage(src, obj, target, flags, sendRoutingParameters(src, 0));
    }

    /**
     * Caches routing for the source. On a worker that does not host the source, it also records
     * this worker's destination tasks as the receive tasks.
     */
    private void calculateRoutingParameters() {
        this.routingParametersCache.put(this.source, sendRoutingParameters(this.source, 0));
        Set<Integer> localDestinations =
            TaskPlanUtils.getTasksOfThisWorker(this.instancePlan, this.destinations);
        if (TaskPlanUtils.getThisWorkerTasks(this.instancePlan).contains(this.source)) {
            return;
        }
        this.receiveTasks.addAll(localDestinations);
    }

    /** Partial sends are not performed; always returns {@code false}. */
    public boolean sendPartial(int i, Object obj, int i2, int i3) {
        return false;
    }

    /**
     * Advances the operation: sends pending finish messages, delivers buffered received messages,
     * and progresses the delegate and the final receiver.
     *
     * @return {@code true} if the lock was busy, a finish message could not yet be sent, a
     *     buffered message was fully delivered, or the final receiver's progress returned true
     * @throws RuntimeException wrapping any error raised while progressing (after logging it)
     */
    public boolean progress() {
        if (!this.lock.tryLock()) {
            return true;
        }
        try {
            boolean finishPending = handleFinish();
            boolean delivered = receiveProgressMessage();
            this.delegate.progress();
            boolean receiverProgress = this.finalReceiver.progress();
            return receiverProgress || finishPending || delivered;
        } catch (Throwable t) {
            LOG.log(Level.SEVERE, "un-expected error", t);
            throw new RuntimeException(String.format("%d exception", this.executor), t);
        } finally {
            this.lock.unlock();
        }
    }

    /**
     * Sends a one-byte finish message for each pending-finish source that has not been handled
     * yet. Callers in this class hold {@link #lock}.
     *
     * @return {@code true} if a send was refused (finishing must be retried), {@code false} once
     *     every pending source has been handled
     */
    private boolean handleFinish() {
        for (int src : this.pendingFinishSources) {
            if (this.finishedSources.contains(src)) {
                continue;
            }
            if (!send(src, new byte[1], FINISH_MESSAGE_FLAGS, 0)) {
                return true;
            }
            this.finishedSources.add(src);
        }
        return false;
    }

    /** @return whether the underlying delegate reports completion */
    public boolean isDelegateComplete() {
        return this.delegate.isComplete();
    }

    /**
     * Wraps an already-serialized channel message in an {@link OutMessage} and queues it on the
     * send queue of this worker's main task.
     *
     * @return {@code false} if that send queue is full
     */
    public boolean handleReceivedChannelMessage(ChannelMessage channelMessage) {
        int mainTask = this.router.mainTaskOfExecutor(this.instancePlan.getThisWorker(), 0);
        RoutingParameters routing = sendRoutingParameters(mainTask, 0);
        ArrayBlockingQueue<OutMessage> sendQueue = this.pendingSendMessagesPerSource.get(mainTask);
        int destination = routing.getExternalRoutes().size() > 0 ? routing.getDestinationId() : -1;
        OutMessage outMessage = new OutMessage(mainTask, channelMessage.getHeader().getEdge(),
            destination, 0, channelMessage.getHeader().getFlags(), routing.getInternalRoutes(),
            routing.getExternalRoutes(), this.dataType, this.keyType, this.delegate,
            CommunicationContext.EMPTY_OBJECT);
        outMessage.getChannelMessages().offer(channelMessage);
        // Count each external route only once per channel message.
        if (!channelMessage.isOutCountUpdated()) {
            channelMessage.incrementRefCount(routing.getExternalRoutes().size());
            channelMessage.setOutCountUpdated(true);
        }
        outMessage.setSendState(OutMessage.SendState.SERIALIZED);
        return sendQueue.offer(outMessage);
    }

    /**
     * Returns cached routing for {@code src}, or computes it from the router without caching.
     * Internal routes are looked up in the router's table for {@link #source}; external routes in
     * the table for {@code src}.
     *
     * @param src  task to route from
     * @param path unused
     * @throws RuntimeException if the router has no table for the lookup
     */
    private RoutingParameters sendRoutingParameters(int src, int path) {
        RoutingParameters cached = this.routingParametersCache.get(src);
        if (cached != null) {
            return cached;
        }
        RoutingParameters routing = new RoutingParameters();
        Map<Integer, Set<Integer>> internalSendTasks = this.router.getInternalSendTasks(this.source);
        if (internalSendTasks == null) {
            throw new RuntimeException("Un-expected message from source: " + src);
        }
        Set<Integer> internalRoutes = internalSendTasks.get(src);
        if (internalRoutes != null) {
            routing.addInternalRoutes(internalRoutes);
        }
        Map<Integer, Set<Integer>> externalSendTasks = this.router.getExternalSendTasks(src);
        if (externalSendTasks == null) {
            throw new RuntimeException("Un-expected message from source: " + src);
        }
        Set<Integer> externalRoutes = externalSendTasks.get(src);
        if (externalRoutes != null) {
            routing.addExternalRoutes(externalRoutes);
        }
        return routing;
    }

    /** Hands a locally routed message straight to the final receiver. */
    public boolean receiveSendInternally(int src, int target, int path, int flags, Object message) {
        return this.finalReceiver.onMessage(src, path, target, flags, message);
    }

    /** @return executors that receive from this operation, per the router */
    protected Set<Integer> receivingExecutors() {
        return this.router.receivingExecutors();
    }

    /** @return expected task ids per the router, passed to the final receiver in init */
    public Map<Integer, List<Integer>> receiveExpectedTaskIds() {
        return this.router.receiveExpectedTaskIds();
    }

    /** @return the singleton source set (live internal set) */
    public Set<Integer> getSources() {
        return this.sourceSet;
    }

    /** @return the destination tasks (live internal copy) */
    public Set<Integer> getTargets() {
        return this.targetSet;
    }
}
