package edu.iu.dsc.tws.comms.dfw.io;

import edu.iu.dsc.tws.api.comms.CommunicationContext;
import edu.iu.dsc.tws.api.comms.DataFlowOperation;
import edu.iu.dsc.tws.api.comms.messaging.ChannelMessage;
import edu.iu.dsc.tws.api.comms.messaging.MessageReceiver;
import edu.iu.dsc.tws.api.comms.structs.Tuple;
import edu.iu.dsc.tws.api.config.Config;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.logging.Logger;

/* loaded from: input_file:edu/iu/dsc/tws/comms/dfw/io/SourceReceiver.class */
/**
 * Base receiver that buffers incoming messages per (target, source) pair and drives a
 * sync protocol on top of them.
 *
 * <p>For every target registered in {@link #init}, one bounded queue per expected source
 * is kept. {@link #progress()} repeatedly hands the buffered data to the subclass through
 * {@link #aggregate}, {@link #isFilledToSend} and {@link #sendToTarget}. Once a target has
 * received a sync from every source and all of its queues are drained, the sync is
 * forwarded via {@link #sendSyncForward} and {@link #onSyncEvent} is fired.
 *
 * <p>NOTE(review): no synchronization is used here; presumably onMessage/progress are
 * driven from a single thread -- confirm against callers.
 */
public abstract class SourceReceiver implements MessageReceiver {
    private static final Logger LOG = Logger.getLogger(SourceReceiver.class.getName());

    /**
     * Flag bit (1 << 26) of a plain sync: the source is marked as synced and no payload is
     * queued. Presumably mirrors a MessageFlags constant -- verify.
     */
    private static final int SYNC_FLAG = 1 << 26;

    /**
     * Flag bit (1 << 25) of a barrier sync: the source is marked as synced and the payload
     * (a byte[], possibly wrapped in a Tuple) is stored as the target's barrier.
     */
    private static final int BARRIER_FLAG = 1 << 25;

    /**
     * Flag bit (1 << 30) on a data message: the payload is queued and the source is then
     * also marked as synced.
     */
    private static final int SYNC_WITH_MESSAGE_FLAG = 1 << 30;

    protected int workerId;
    protected DataFlowOperation operation;
    /** Capacity of each per-source queue; also the back-pressure threshold in onMessage. */
    protected int sendPendingMax;
    protected int destination;
    /** True when the last progress() call found every target in SYNCED state. */
    private boolean completed;
    /** target -> (source -> bounded queue of pending messages from that source). */
    protected Map<Integer, Map<Integer, Queue<Object>>> messages = new HashMap<>();
    /** target -> sources that have already sent their sync for the current round. */
    protected Map<Integer, Set<Integer>> syncReceived = new HashMap<>();
    /** target -> number of sources expected to send to it. */
    protected Map<Integer, Integer> sourcesOfTarget = new HashMap<>();
    protected Map<Integer, Boolean> isSyncSent = new HashMap<>();
    protected Map<Integer, ReceiverState> targetStates = new HashMap<>();
    /** target -> barrier payload received in the current round (absent for plain syncs). */
    protected Map<Integer, byte[]> barriers = new HashMap<>();
    protected SyncState syncState = SyncState.SYNC;

    /**
     * Allocates per-target bookkeeping.
     *
     * @param config configuration used to read the send-pending limit
     * @param dataFlowOperation the owning operation
     * @param map target -> list of source ids expected to send to that target
     */
    public void init(Config config, DataFlowOperation dataFlowOperation, Map<Integer, List<Integer>> map) {
        this.workerId = dataFlowOperation.getLogicalPlan().getThisWorker();
        this.sendPendingMax = CommunicationContext.sendPendingMax(config);
        this.operation = dataFlowOperation;
        for (Map.Entry<Integer, List<Integer>> entry : map.entrySet()) {
            Integer target = entry.getKey();
            List<Integer> sources = entry.getValue();

            Map<Integer, Queue<Object>> queuesBySource = new HashMap<>();
            for (int source : sources) {
                queuesBySource.put(source, new ArrayBlockingQueue<>(this.sendPendingMax));
            }

            this.syncReceived.put(target, new HashSet<>());
            this.messages.put(target, queuesBySource);
            this.isSyncSent.put(target, false);
            this.sourcesOfTarget.put(target, sources.size());
            this.targetStates.put(target, ReceiverState.INIT);
        }
    }

    /**
     * Accepts a message (or sync) from {@code source} addressed to {@code target}.
     *
     * @param source sending source id
     * @param path unused here
     * @param target destination target id; must have been registered in init
     * @param flags message flag bits
     * @param object payload
     * @return false if the message was not accepted (target already has all syncs or is
     *     synced, the source has already synced, or its queue is full); true otherwise
     * @throws RuntimeException if a plain sync arrives while barrier syncs are in progress
     */
    public boolean onMessage(int source, int path, int target, int flags, Object object) {
        Set<Integer> received = this.syncReceived.get(target);

        if (hasFlag(flags, SYNC_FLAG)) {
            if (this.syncState == SyncState.BARRIER_SYNC) {
                throw new RuntimeException("We are receiving barriers syncs and received a normal sycn");
            }
            recordSync(source, target, received);
            return true;
        }

        if (hasFlag(flags, BARRIER_FLAG)) {
            recordSync(source, target, received);
            this.syncState = SyncState.BARRIER_SYNC;
            // The barrier may arrive bare or as the value of a Tuple.
            Object payload = (object instanceof Tuple) ? ((Tuple) object).getValue() : object;
            this.barriers.put(target, (byte[]) payload);
            return true;
        }

        ReceiverState state = this.targetStates.get(target);
        if (state == ReceiverState.ALL_SYNCS_RECEIVED || state == ReceiverState.SYNCED) {
            return false;
        }
        if (state == ReceiverState.INIT) {
            this.targetStates.put(target, ReceiverState.RECEIVING);
        }
        // Data from a source that has already synced this round is rejected.
        if (received.contains(source)) {
            return false;
        }

        Queue<Object> queue = this.messages.get(target).get(source);
        if (queue.size() >= this.sendPendingMax) {
            return false; // back-pressure: caller must retry later
        }
        if (object instanceof ChannelMessage) {
            // The queued reference keeps the channel message alive until consumed.
            ((ChannelMessage) object).incrementRefCount();
        }
        queue.add(object);

        if (hasFlag(flags, SYNC_WITH_MESSAGE_FLAG)) {
            recordSync(source, target, received);
        }
        return true;
    }

    /**
     * Drives buffered data through the subclass hooks and completes syncs.
     *
     * @return true if any target was not yet SYNCED at the start of this call
     */
    public boolean progress() {
        boolean needsFurtherProgress = false;
        boolean allTargetsSynced = true;

        for (Integer targetKey : this.messages.keySet()) {
            int target = targetKey;
            if (this.targetStates.get(target) == ReceiverState.SYNCED) {
                continue;
            }
            allTargetsSynced = false;

            Map<Integer, Queue<Object>> queuesBySource = this.messages.get(target);
            Set<Integer> finishedSources = this.syncReceived.get(target);

            boolean canProgress = true;
            while (canProgress) {
                boolean allHaveMessages = true;
                boolean allFinished = true;
                boolean anyHasMessages = false;
                for (Map.Entry<Integer, Queue<Object>> entry : queuesBySource.entrySet()) {
                    if (entry.getValue().isEmpty()) {
                        allHaveMessages = false;
                        canProgress = false;
                    } else {
                        anyHasMessages = true;
                    }
                    if (!finishedSources.contains(entry.getKey())) {
                        allFinished = false;
                    }
                }
                // Without any source queue nothing can ever empty out to stop the loop;
                // end the pass instead of spinning forever.
                if (queuesBySource.isEmpty()) {
                    canProgress = false;
                }

                if (anyHasMessages) {
                    aggregate(target, allFinished, allHaveMessages);
                }
                if (isFilledToSend(target, allFinished) && !sendToTarget(target, allFinished)) {
                    canProgress = false;
                }

                if (this.targetStates.get(target) == ReceiverState.ALL_SYNCS_RECEIVED
                        && allQueuesEmpty(queuesBySource)
                        && isAllEmpty(target)) {
                    if (sendSyncForward(target)) {
                        // Forwarding could not finish; retry on a later call.
                        needsFurtherProgress = true;
                    } else {
                        this.targetStates.put(target, ReceiverState.SYNCED);
                        onSyncEvent(target, this.barriers.get(target));
                        clearTarget(target);
                    }
                }

                if (!canProgress) {
                    needsFurtherProgress = true;
                }
            }
        }

        this.completed = allTargetsSynced;
        return needsFurtherProgress;
    }

    /** @return whether the most recent progress() call found every target SYNCED */
    public boolean isComplete() {
        return this.completed;
    }

    /**
     * Subclass check run before forwarding a sync.
     *
     * @param target target id
     * @return true if nothing more is buffered for {@code target} on the subclass side
     *     (assumed from usage -- confirm in implementations)
     */
    protected abstract boolean isAllEmpty(int target);

    /**
     * Forwards the sync for {@code target}.
     *
     * @param target target id
     * @return true if forwarding has not finished and must be retried later; false once done
     */
    protected abstract boolean sendSyncForward(int target);

    /**
     * Resets the per-round state of a single target: drops queued messages, clears its
     * received syncs and consumed barrier, and resets the sync mode.
     *
     * <p>Only {@code target}'s sync set is cleared. Other targets may still be collecting
     * syncs for this round. clean() resets every target by calling this for each one.
     */
    private void clearTarget(int target) {
        for (Queue<Object> queue : this.messages.get(target).values()) {
            queue.clear();
        }
        this.isSyncSent.put(target, false);
        this.syncState = SyncState.SYNC;
        this.syncReceived.get(target).clear();
        // The barrier was already delivered via onSyncEvent; drop it so a later plain-sync
        // round does not re-deliver a stale barrier.
        this.barriers.remove(target);
    }

    /**
     * Sends aggregated data for {@code target}.
     *
     * @param target target id
     * @param allFinished whether every source of the target has synced
     * @return false to stop the current drain pass for this target
     */
    protected abstract boolean sendToTarget(int target, boolean allFinished);

    /**
     * Consumes buffered messages for {@code target}. The return value is ignored here.
     *
     * @param target target id
     * @param allFinished whether every source of the target has synced
     * @param allHaveMessages whether every source queue currently holds at least one message
     */
    protected abstract boolean aggregate(int target, boolean allFinished, boolean allHaveMessages);

    /**
     * @param target target id
     * @param allFinished whether every source of the target has synced
     * @return whether sendToTarget should be attempted now
     */
    protected abstract boolean isFilledToSend(int target, boolean allFinished);

    /** @return true if every queue in {@code map} is empty (true for an empty map) */
    public boolean allQueuesEmpty(Map<Integer, Queue<Object>> map) {
        for (Queue<Object> queue : map.values()) {
            if (!queue.isEmpty()) {
                return false;
            }
        }
        return true;
    }

    /** @return true once the number of synced sources equals the expected source count */
    protected boolean allSyncsPresent(int target) {
        return this.sourcesOfTarget.get(target).intValue() == this.syncReceived.get(target).size();
    }

    /** Marks {@code source} as synced for every registered target. */
    public void onFinish(int source) {
        for (Set<Integer> received : this.syncReceived.values()) {
            received.add(source);
        }
    }

    /** Resets every target to INIT and discards all buffered state. */
    public void clean() {
        for (Integer target : this.messages.keySet()) {
            clearTarget(target);
            this.targetStates.put(target, ReceiverState.INIT);
        }
        this.completed = false;
    }

    /**
     * Called once a target has completed its sync.
     *
     * @param target target id
     * @param value barrier payload for this round, or null if none was received
     */
    protected abstract void onSyncEvent(int target, byte[] value);

    /** @return true if every bit of {@code flag} is set in {@code flags} */
    private static boolean hasFlag(int flags, int flag) {
        return (flags & flag) == flag;
    }

    /** Records a sync from {@code source} and promotes the target once all syncs are in. */
    private void recordSync(int source, int target, Set<Integer> received) {
        received.add(source);
        if (allSyncsPresent(target)) {
            this.targetStates.put(target, ReceiverState.ALL_SYNCS_RECEIVED);
        }
    }
}
