package edu.iu.dsc.tws.executor.core;

import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import edu.iu.dsc.tws.api.checkpointing.CheckpointingClient;
import edu.iu.dsc.tws.api.comms.Communicator;
import edu.iu.dsc.tws.api.compute.executor.ExecutionPlan;
import edu.iu.dsc.tws.api.compute.executor.IExecutionPlanBuilder;
import edu.iu.dsc.tws.api.compute.executor.INodeInstance;
import edu.iu.dsc.tws.api.compute.executor.IParallelOperation;
import edu.iu.dsc.tws.api.compute.graph.ComputeGraph;
import edu.iu.dsc.tws.api.compute.graph.Edge;
import edu.iu.dsc.tws.api.compute.graph.OperationMode;
import edu.iu.dsc.tws.api.compute.graph.Vertex;
import edu.iu.dsc.tws.api.compute.nodes.ICompute;
import edu.iu.dsc.tws.api.compute.nodes.INode;
import edu.iu.dsc.tws.api.compute.nodes.ISource;
import edu.iu.dsc.tws.api.compute.schedule.elements.TaskInstancePlan;
import edu.iu.dsc.tws.api.compute.schedule.elements.TaskSchedulePlan;
import edu.iu.dsc.tws.api.compute.schedule.elements.WorkerSchedulePlan;
import edu.iu.dsc.tws.api.config.Config;
import edu.iu.dsc.tws.api.exceptions.net.BlockingSendException;
import edu.iu.dsc.tws.checkpointing.task.CheckpointableTask;
import edu.iu.dsc.tws.checkpointing.task.CheckpointingSGatherSink;
import edu.iu.dsc.tws.checkpointing.util.CheckpointingContext;
import edu.iu.dsc.tws.executor.core.batch.SourceBatchInstance;
import edu.iu.dsc.tws.executor.core.batch.TaskBatchInstance;
import edu.iu.dsc.tws.executor.core.streaming.SourceStreamingInstance;
import edu.iu.dsc.tws.executor.core.streaming.TaskStreamingInstance;
import edu.iu.dsc.tws.executor.util.Utils;
import edu.iu.dsc.tws.proto.jobmaster.JobMasterAPI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.Collectors;

/**
 * Builds the {@link ExecutionPlan} of a single worker from a {@link ComputeGraph} and the
 * {@link TaskSchedulePlan} that places task instances on workers.
 *
 * <p>For every task instance scheduled on this worker, the builder creates a node instance
 * (batch or streaming, compute or source) and records the communications implied by the graph
 * edges touching that task. Each communication is then turned into an {@link IParallelOperation}
 * by {@link ParallelOperationFactory} and wired to the node instances of this worker it connects.
 *
 * <p>Communications and node instances are accumulated in fields, so anything created by one
 * {@link #build} call is still present during a later call on the same builder.
 */
public class ExecutionPlanBuilder implements IExecutionPlanBuilder {
    private static final Logger LOG = Logger.getLogger(ExecutionPlanBuilder.class.getName());

    private final int workerId;
    private final CheckpointingClient checkpointingClient;
    private final Communicator network;
    private final List<JobMasterAPI.WorkerInfo> workerInfoList;

    /** Communications keyed by (source task name, edge name or target-edge group name). */
    private final Table<String, String, Communication> parOpTable = HashBasedTable.create();
    /** Batch compute-task instances keyed by (task name, global task id). */
    private final Table<String, Integer, TaskBatchInstance> batchTaskInstances = HashBasedTable.create();
    /** Batch source instances keyed by (task name, global task id). */
    private final Table<String, Integer, SourceBatchInstance> batchSourceInstances = HashBasedTable.create();
    /** Streaming compute-task instances keyed by (task name, global task id). */
    private final Table<String, Integer, TaskStreamingInstance> streamingTaskInstances = HashBasedTable.create();
    /** Streaming source instances keyed by (task name, global task id). */
    private final Table<String, Integer, SourceStreamingInstance> streamingSourceInstances = HashBasedTable.create();
    /** Grouped communications keyed by target-edge name, so every member edge joins the same one. */
    private final Map<String, Communication> targetParOpTable = new HashMap<>();
    private final TaskIdGenerator taskIdGenerator = new TaskIdGenerator();

    /**
     * A parallel operation between one or more source tasks and a single target task.
     *
     * <p>Edges that share a target-edge group are merged into one instance. They are registered
     * by edge index through {@link #addEdge} / {@link #addSourceTask} and laid out in index order
     * by {@link #build()}.
     */
    public static class Communication {
        private final Set<Integer> sourceTasks;
        private final Set<Integer> targetTasks;
        private final String targetTask;
        private final int numberOfEdges;
        private final Map<Integer, Integer> srcGlobalToIndex;
        private final Map<Integer, Integer> tarGlobalToIndex;
        /** Edges in edge-index order; filled by {@link #build()}. */
        private final List<Edge> edge = new ArrayList<>();
        /** Source task names in edge-index order; filled by {@link #build()}. */
        private final List<String> sourceTask = new ArrayList<>();
        private final Map<Integer, Edge> edgeMap = new HashMap<>();
        private final Map<Integer, String> sourceTaskMap = new HashMap<>();

        /**
         * @param targetTask       name of the target task
         * @param sourceTasks      global ids of the source task instances; copied, because
         *                         {@link #addSourceTasks} grows this set and the caller's set
         *                         must not be modified
         * @param targetTasks      global ids of the target task instances
         * @param numberOfEdges    {@code getNumberOfEdges()} of the edge that created this object
         * @param srcGlobalToIndex source-side map from {@code TaskIdGenerator#getGlobalTaskToIndex}
         * @param tarGlobalToIndex target-side map from {@code TaskIdGenerator#getGlobalTaskToIndex}
         */
        Communication(String targetTask, Set<Integer> sourceTasks, Set<Integer> targetTasks,
                      int numberOfEdges, Map<Integer, Integer> srcGlobalToIndex,
                      Map<Integer, Integer> tarGlobalToIndex) {
            this.targetTasks = targetTasks;
            this.sourceTasks = new HashSet<>(sourceTasks);
            this.targetTask = targetTask;
            this.numberOfEdges = numberOfEdges;
            this.srcGlobalToIndex = srcGlobalToIndex;
            this.tarGlobalToIndex = tarGlobalToIndex;
        }

        /**
         * Lays out the registered edges and source task names by edge index 0..n-1. Safe to call
         * more than once: previous results are discarded first.
         */
        void build() {
            edge.clear();
            sourceTask.clear();
            for (int i = 0; i < edgeMap.size(); i++) {
                edge.add(edgeMap.get(i));
                sourceTask.add(sourceTaskMap.get(i));
            }
        }

        /** Adds more source task ids (used when another edge joins a grouped communication). */
        public void addSourceTasks(Set<Integer> set) {
            sourceTasks.addAll(set);
        }

        void addEdge(int index, Edge e) {
            edgeMap.put(index, e);
        }

        Set<Integer> getSourceTasks() {
            return sourceTasks;
        }

        Set<Integer> getTargetTasks() {
            return targetTasks;
        }

        /** Returns the edge at {@code index}; only valid after {@link #build()}. */
        Edge getEdge(int index) {
            return edge.get(index);
        }

        void addSourceTask(int index, String name) {
            sourceTaskMap.put(index, name);
        }

        /** Returns source task names in edge-index order; only populated after {@link #build()}. */
        List<String> getSourceTask() {
            return sourceTask;
        }

        String getTargetTask() {
            return targetTask;
        }

        /** Returns edges in edge-index order; only populated after {@link #build()}. */
        List<Edge> getEdge() {
            return edge;
        }

        int getNumberOfEdges() {
            return numberOfEdges;
        }
    }

    /**
     * @param workerId            id of the worker this builder creates plans for
     * @param workerInfoList      worker information handed to {@code TaskPlanBuilder.build}
     * @param network             communicator given to the {@link ParallelOperationFactory}
     * @param checkpointingClient client used for checkpoint registration and by node instances
     */
    public ExecutionPlanBuilder(int workerId, List<JobMasterAPI.WorkerInfo> workerInfoList,
                                Communicator network, CheckpointingClient checkpointingClient) {
        this.workerId = workerId;
        this.checkpointingClient = checkpointingClient;
        this.workerInfoList = workerInfoList;
        this.network = network;
    }

    /**
     * Builds the execution plan of this worker.
     *
     * @param config           job configuration
     * @param computeGraph     graph whose vertices and edges are instantiated
     * @param taskSchedulePlan placement of task instances on workers
     * @return the plan, or {@code null} when this worker is not in the schedule's containers map
     * @throws RuntimeException if a scheduled task has no vertex, a task has no schedule plan,
     *                          checkpoint registration fails, or a node instance expected by a
     *                          communication is missing on this worker
     */
    @Override
    public ExecutionPlan build(Config config, ComputeGraph computeGraph, TaskSchedulePlan taskSchedulePlan) {
        ParallelOperationFactory opFactory = new ParallelOperationFactory(config, network,
                TaskPlanBuilder.build(workerId, workerInfoList, taskSchedulePlan, taskIdGenerator));
        Map<Integer, WorkerSchedulePlan> containersMap = taskSchedulePlan.getContainersMap();
        WorkerSchedulePlan workerSchedulePlan = containersMap.get(workerId);
        if (workerSchedulePlan == null) {
            LOG.log(Level.INFO, "Cannot find worker in the task plan: " + workerId);
            return null;
        }

        ExecutionPlan executionPlan = new ExecutionPlan();
        long taskVersion = initCheckpointing(config, computeGraph, containersMap);

        // Phase 1: create node instances and collect communications for every local task.
        Set<TaskInstancePlan> taskInstances = workerSchedulePlan.getTaskInstances();
        for (TaskInstancePlan instancePlan : taskInstances) {
            addNodeInstance(config, computeGraph, taskSchedulePlan, instancePlan, executionPlan, taskVersion);
        }

        // Phase 2: turn each communication into a parallel operation and wire local instances.
        OperationMode operationMode = computeGraph.getOperationMode();
        for (Communication communication : parOpTable.values()) {
            assert communication != null;
            communication.build();
            IParallelOperation op = createParallelOperation(opFactory, communication, operationMode);

            Set<Integer> localSources = intersectionOfTasks(workerSchedulePlan, communication.getSourceTasks());
            Set<Integer> localTargets = intersectionOfTasks(workerSchedulePlan, communication.getTargetTasks());
            // Grouped communications are registered on the target under the group name.
            String targetEdge = communication.getEdge().size() > 1
                    ? communication.getEdge(0).getTargetEdge()
                    : communication.getEdge(0).getName();

            if (operationMode == OperationMode.STREAMING) {
                wireStreaming(communication, op, localSources, localTargets, targetEdge);
                executionPlan.addOps(op);
            }
            if (operationMode == OperationMode.BATCH) {
                wireBatch(communication, op, localSources, localTargets, targetEdge);
                executionPlan.addOps(op);
            }
        }
        return executionPlan;
    }

    /**
     * Registers this worker with the checkpoint manager when checkpointing is enabled.
     * Worker 0 also reports the ids of all checkpointable tasks (excluding
     * {@link CheckpointingSGatherSink}).
     *
     * @return the version tasks should start from, or 0 when checkpointing is disabled
     */
    private long initCheckpointing(Config config, ComputeGraph computeGraph,
                                   Map<Integer, WorkerSchedulePlan> containersMap) {
        if (!CheckpointingContext.isCheckpointingEnabled(config)) {
            return 0;
        }
        Set<Integer> checkpointableTasks = Collections.emptySet();
        if (workerId == 0) {
            checkpointableTasks = containersMap.values().stream()
                    .flatMap(worker -> worker.getTaskInstances().stream())
                    .filter(instance -> {
                        INode node = computeGraph.vertex(instance.getTaskName()).getTask();
                        return node instanceof CheckpointableTask && !(node instanceof CheckpointingSGatherSink);
                    })
                    .map(TaskInstancePlan::getTaskId)
                    .collect(Collectors.toSet());
        }
        try {
            long version = checkpointingClient.initFamily(workerId, containersMap.size(),
                    computeGraph.getGraphName(), checkpointableTasks).getVersion();
            LOG.info("Tasks will start with version " + version);
            return version;
        } catch (BlockingSendException e) {
            throw new RuntimeException("Failed to register tasks with Checkpoint Manager", e);
        }
    }

    /**
     * Registers the communications on the edges of one locally scheduled task instance and adds
     * its node instance to {@code executionPlan}.
     */
    private void addNodeInstance(Config config, ComputeGraph computeGraph, TaskSchedulePlan taskSchedulePlan,
                                 TaskInstancePlan instancePlan, ExecutionPlan executionPlan, long taskVersion) {
        Vertex vertex = computeGraph.vertex(instancePlan.getTaskName());
        if (vertex == null) {
            throw new RuntimeException("Non-existing task scheduled: " + instancePlan.getTaskName());
        }
        int taskId = instancePlan.getTaskId();
        INode task = vertex.getTask();
        // out-edge name -> child task name
        Map<String, String> outEdges = new HashMap<>();
        // in-edge name (or target-edge group name) -> set of names
        Map<String, Set<String>> inEdges = new HashMap<>();

        if (task instanceof ICompute || task instanceof ISource) {
            for (Edge edge : computeGraph.outEdges(vertex)) {
                Vertex child = computeGraph.childOfTask(vertex, edge.getName());
                int childTaskId = getTaskIdOfTask(child.getName(), taskSchedulePlan);
                createCommunication(child, edge, vertex,
                        taskIdGenerator.getTaskIds(vertex, taskId),
                        taskIdGenerator.getTaskIds(child, childTaskId),
                        taskIdGenerator.getGlobalTaskToIndex(vertex, taskId),
                        taskIdGenerator.getGlobalTaskToIndex(child, childTaskId));
                outEdges.put(edge.getName(), child.getName());
            }
        }
        if (task instanceof ICompute) {
            for (Edge edge : computeGraph.inEdges(vertex)) {
                Vertex parent = computeGraph.getParentOfTask(vertex, edge.getName());
                int parentTaskId = getTaskIdOfTask(parent.getName(), taskSchedulePlan);
                createCommunication(vertex, edge, parent,
                        taskIdGenerator.getTaskIds(parent, parentTaskId),
                        taskIdGenerator.getTaskIds(vertex, taskId),
                        taskIdGenerator.getGlobalTaskToIndex(parent, parentTaskId),
                        taskIdGenerator.getGlobalTaskToIndex(vertex, taskId));
                String inEdgeName = edge.getTargetEdge() == null ? edge.getName() : edge.getTargetEdge();
                // NOTE(review): the set only ever receives its own key; confirm whether the edge
                // or parent name was intended here before relying on its contents.
                inEdges.computeIfAbsent(inEdgeName, k -> new HashSet<>()).add(inEdgeName);
            }
        }
        int globalTaskId = taskIdGenerator.generateGlobalTaskId(taskId, instancePlan.getTaskIndex());
        executionPlan.addNodes(vertex.getName(), globalTaskId,
                createInstances(config, computeGraph.getGraphName(), instancePlan, vertex,
                        computeGraph.getOperationMode(), inEdges, outEdges, taskSchedulePlan, taskVersion));
    }

    /**
     * Creates the parallel operation for a built communication. With two or more edges, only
     * edges 0 and 1 are used. The source tasks are split into two groups: those in the same
     * {@code TaskIdGenerator.TASK_OFFSET} block as the smallest source id, and all others.
     */
    private static IParallelOperation createParallelOperation(ParallelOperationFactory opFactory,
                                                              Communication communication,
                                                              OperationMode operationMode) {
        List<Edge> edges = communication.getEdge();
        if (edges.isEmpty()) {
            throw new RuntimeException("Cannot have communication with 0 edges");
        }
        if (edges.size() == 1) {
            return opFactory.build(communication.getEdge(0), communication.getSourceTasks(),
                    communication.getTargetTasks(), operationMode,
                    communication.srcGlobalToIndex, communication.tarGlobalToIndex);
        }

        Set<Integer> sourceTasks = communication.getSourceTasks();
        Set<Integer> firstGroup = new HashSet<>();
        Set<Integer> secondGroup = new HashSet<>();
        if (!sourceTasks.isEmpty()) {
            int firstBlock = taskBlock(Collections.min(sourceTasks));
            for (Integer id : sourceTasks) {
                if (taskBlock(id) == firstBlock) {
                    firstGroup.add(id);
                } else {
                    secondGroup.add(id);
                }
            }
        }
        return opFactory.build(communication.getEdge(0), communication.getEdge(1), firstGroup, secondGroup,
                communication.getTargetTasks(), operationMode,
                communication.srcGlobalToIndex, communication.tarGlobalToIndex);
    }

    /** Rounds a global task id down to a multiple of {@code TaskIdGenerator.TASK_OFFSET}. */
    private static int taskBlock(int globalTaskId) {
        return (globalTaskId / TaskIdGenerator.TASK_OFFSET) * TaskIdGenerator.TASK_OFFSET;
    }

    /** Connects a streaming parallel operation to the local source and target node instances. */
    private void wireStreaming(Communication communication, IParallelOperation op,
                               Set<Integer> localSources, Set<Integer> localTargets, String targetEdge) {
        List<String> sourceNames = communication.getSourceTask();
        for (int globalId : localSources) {
            boolean found = false;
            for (int i = 0; i < sourceNames.size(); i++) {
                String sourceName = sourceNames.get(i);
                TaskStreamingInstance taskInstance = streamingTaskInstances.get(sourceName, globalId);
                if (taskInstance != null) {
                    taskInstance.registerOutParallelOperation(communication.getEdge(i).getName(), op);
                    op.registerSync(globalId, taskInstance);
                    found = true;
                } else {
                    SourceStreamingInstance sourceInstance = streamingSourceInstances.get(sourceName, globalId);
                    if (sourceInstance != null) {
                        sourceInstance.registerOutParallelOperation(communication.getEdge(i).getName(), op);
                        found = true;
                    }
                }
            }
            // Fail only after every source task has been tried: with grouped edges, a local id
            // may belong to any of them, not just the first (matches the batch wiring).
            if (!found) {
                throw new RuntimeException("Not found: " + communication.getSourceTask());
            }
        }
        for (int globalId : localTargets) {
            TaskStreamingInstance target = streamingTaskInstances.get(communication.getTargetTask(), globalId);
            if (target == null) {
                throw new RuntimeException("Not found: " + communication.getTargetTask());
            }
            op.register(globalId, target.getInQueue());
            target.registerInParallelOperation(targetEdge, op);
            op.registerSync(globalId, target);
        }
    }

    /** Connects a batch parallel operation to the local source and target node instances. */
    private void wireBatch(Communication communication, IParallelOperation op,
                           Set<Integer> localSources, Set<Integer> localTargets, String targetEdge) {
        List<String> sourceNames = communication.getSourceTask();
        for (int globalId : localSources) {
            boolean found = false;
            for (int i = 0; i < sourceNames.size(); i++) {
                String sourceName = sourceNames.get(i);
                // NOTE(review): unlike streaming, batch source-side instances get no registerSync;
                // confirm this asymmetry is intended.
                TaskBatchInstance taskInstance = batchTaskInstances.get(sourceName, globalId);
                if (taskInstance != null) {
                    taskInstance.registerOutParallelOperation(communication.getEdge(i).getName(), op);
                    found = true;
                } else {
                    SourceBatchInstance sourceInstance = batchSourceInstances.get(sourceName, globalId);
                    if (sourceInstance != null) {
                        sourceInstance.registerOutParallelOperation(communication.getEdge(i).getName(), op);
                        found = true;
                    }
                }
            }
            if (!found) {
                throw new RuntimeException("Not found: " + communication.getSourceTask());
            }
        }
        for (int globalId : localTargets) {
            TaskBatchInstance target = batchTaskInstances.get(communication.getTargetTask(), globalId);
            if (target == null) {
                throw new RuntimeException("Not found: " + communication.getTargetTask());
            }
            op.register(globalId, target.getInQueue());
            target.registerInParallelOperation(targetEdge, op);
            op.registerSync(globalId, target);
        }
    }

    /**
     * Records that {@code edge} carries data from {@code source} to {@code target}.
     *
     * <p>An edge without a target-edge group gets its own communication, keyed by
     * (source name, edge name), and is created only once. An edge with a group joins the
     * existing communication for that group if there is one, contributing its source tasks.
     * Otherwise it creates the group's communication.
     */
    private void createCommunication(Vertex target, Edge edge, Vertex source,
                                     Set<Integer> srcTasks, Set<Integer> tarTasks,
                                     Map<Integer, Integer> srcGlobalToIndex,
                                     Map<Integer, Integer> tarGlobalToIndex) {
        String group = edge.getTargetEdge();
        if (group == null) {
            if (!parOpTable.contains(source.getName(), edge.getName())) {
                Communication communication = new Communication(target.getName(), srcTasks, tarTasks,
                        edge.getNumberOfEdges(), srcGlobalToIndex, tarGlobalToIndex);
                communication.addEdge(edge.getEdgeIndex(), edge);
                communication.addSourceTask(edge.getEdgeIndex(), source.getName());
                parOpTable.put(source.getName(), edge.getName(), communication);
            }
            return;
        }

        Communication existing = parOpTable.get(source.getName(), group);
        if (existing == null) {
            existing = targetParOpTable.get(group);
        }
        if (existing != null) {
            existing.addEdge(edge.getEdgeIndex(), edge);
            existing.addSourceTasks(srcTasks);
            existing.addSourceTask(edge.getEdgeIndex(), source.getName());
            return;
        }

        Communication communication = new Communication(target.getName(), srcTasks, tarTasks,
                edge.getNumberOfEdges(), srcGlobalToIndex, tarGlobalToIndex);
        communication.addEdge(edge.getEdgeIndex(), edge);
        communication.addSourceTask(edge.getEdgeIndex(), source.getName());
        parOpTable.put(source.getName(), group, communication);
        targetParOpTable.put(group, communication);
    }

    /**
     * Returns the ids in {@code tasks} that belong to {@code workerSchedulePlan}, as a new set
     * (the generator's result is copied rather than modified in place).
     */
    private Set<Integer> intersectionOfTasks(WorkerSchedulePlan workerSchedulePlan, Set<Integer> tasks) {
        Set<Integer> local = new HashSet<>(taskIdGenerator.getTaskIdsOfContainer(workerSchedulePlan));
        local.retainAll(tasks);
        return local;
    }

    /**
     * Creates the node instance for one task instance and records it in the matching table.
     *
     * @return the created instance, or {@code null} for an operation mode that is neither
     *         BATCH nor STREAMING
     * @throws RuntimeException if the task is neither an {@link ICompute} nor an {@link ISource}
     */
    private INodeInstance createInstances(Config config, String graphName, TaskInstancePlan taskInstancePlan,
                                          Vertex vertex, OperationMode operationMode,
                                          Map<String, Set<String>> inEdges, Map<String, String> outEdges,
                                          TaskSchedulePlan taskSchedulePlan, long taskVersion) {
        // Serialize/deserialize round trip, so each instance gets its own task object
        // (presumably so instances do not share mutable task state).
        INode node = (INode) Utils.deserialize(Utils.serialize(vertex.getTask()));
        int globalTaskId = taskIdGenerator.generateGlobalTaskId(taskInstancePlan.getTaskId(),
                taskInstancePlan.getTaskIndex());

        if (operationMode.equals(OperationMode.BATCH)) {
            if (node instanceof ICompute) {
                TaskBatchInstance instance = new TaskBatchInstance((ICompute) node,
                        new LinkedBlockingQueue<>(), new LinkedBlockingQueue<>(), config, vertex.getName(),
                        taskInstancePlan.getTaskId(), globalTaskId, taskInstancePlan.getTaskIndex(),
                        vertex.getParallelism(), workerId, vertex.getConfig().toMap(), inEdges, outEdges,
                        taskSchedulePlan, checkpointingClient, graphName, taskVersion);
                batchTaskInstances.put(vertex.getName(), globalTaskId, instance);
                return instance;
            }
            if (!(node instanceof ISource)) {
                throw new RuntimeException("Un-known type");
            }
            SourceBatchInstance instance = new SourceBatchInstance((ISource) node,
                    new LinkedBlockingQueue<>(), config, vertex.getName(),
                    taskInstancePlan.getTaskId(), globalTaskId, taskInstancePlan.getTaskIndex(),
                    vertex.getParallelism(), workerId, vertex.getConfig().toMap(), outEdges,
                    taskSchedulePlan, checkpointingClient, graphName, taskVersion);
            batchSourceInstances.put(vertex.getName(), globalTaskId, instance);
            return instance;
        }

        if (!operationMode.equals(OperationMode.STREAMING)) {
            return null;
        }
        if (node instanceof ICompute) {
            TaskStreamingInstance instance = new TaskStreamingInstance((ICompute) node,
                    new LinkedBlockingQueue<>(), new LinkedBlockingQueue<>(), config, vertex.getName(),
                    taskInstancePlan.getTaskId(), globalTaskId, taskInstancePlan.getTaskIndex(),
                    vertex.getParallelism(), workerId, vertex.getConfig().toMap(), inEdges, outEdges,
                    taskSchedulePlan, checkpointingClient, graphName, taskVersion);
            streamingTaskInstances.put(vertex.getName(), globalTaskId, instance);
            return instance;
        }
        if (!(node instanceof ISource)) {
            throw new RuntimeException("Un-known type");
        }
        SourceStreamingInstance instance = new SourceStreamingInstance((ISource) node,
                new LinkedBlockingQueue<>(), config, vertex.getName(),
                taskInstancePlan.getTaskId(), globalTaskId, taskInstancePlan.getTaskIndex(),
                vertex.getParallelism(), workerId, vertex.getConfig().toMap(), outEdges,
                taskSchedulePlan, checkpointingClient, graphName, taskVersion);
        streamingSourceInstances.put(vertex.getName(), globalTaskId, instance);
        return instance;
    }

    /**
     * Returns the task id of the first scheduled instance named {@code taskName}, searching all
     * containers of the schedule.
     *
     * @throws RuntimeException if no container schedules a task with that name
     */
    private int getTaskIdOfTask(String taskName, TaskSchedulePlan taskSchedulePlan) {
        Set<WorkerSchedulePlan> containers = taskSchedulePlan.getContainers();
        for (WorkerSchedulePlan container : containers) {
            for (TaskInstancePlan instance : container.getTaskInstances()) {
                if (taskName.equals(instance.getTaskName())) {
                    return instance.getTaskId();
                }
            }
        }
        throw new RuntimeException("Task without a schedule plan: " + taskName);
    }
}
