package edu.iu.dsc.tws.executor.threading;

import com.google.common.util.concurrent.ThreadFactoryBuilder;
import edu.iu.dsc.tws.api.comms.channel.TWSChannel;
import edu.iu.dsc.tws.api.compute.executor.ExecutionPlan;
import edu.iu.dsc.tws.api.compute.executor.ExecutionState;
import edu.iu.dsc.tws.api.compute.executor.ExecutorContext;
import edu.iu.dsc.tws.api.compute.executor.IExecution;
import edu.iu.dsc.tws.api.compute.executor.IExecutor;
import edu.iu.dsc.tws.api.compute.executor.INodeInstance;
import edu.iu.dsc.tws.api.compute.executor.IParallelOperation;
import edu.iu.dsc.tws.api.config.Config;
import edu.iu.dsc.tws.executor.core.ExecutionRuntime;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Batch executor that shares the node instances of an {@link ExecutionPlan} among
 * {@code numThreads} workers. The calling thread acts as worker 0 and also drives
 * {@code channel.progress()}, while {@code numThreads - 1} helper workers run on {@link #threads}.
 *
 * <p>Execution happens in two phases:
 * <ol>
 *   <li>Task phase: instances are executed round-robin until every instance's
 *       {@code execute()} has returned {@code false}.</li>
 *   <li>Completion phase: the channel is progressed until every instance reports
 *       {@code isComplete()}, after which nodes and parallel operations are reset and closed.</li>
 * </ol>
 * The blocking loops in {@link #runExecution(ExecutionPlan)} and {@link #waitFor(ExecutionPlan)}
 * also exit once {@link BatchExecution#stop()} has been called.
 */
public class BatchSharingExecutor2 implements IExecutor {
    private static final Logger LOG = Logger.getLogger(BatchSharingExecutor2.class.getName());

    /** Total number of workers, including the calling thread. */
    protected int numThreads;

    /** Pool for the {@code numThreads - 1} helper workers; {@code null} when numThreads <= 1. */
    protected ExecutorService threads;

    protected TWSChannel channel;
    protected Config config;
    private int workerId;

    /** Counted down by every helper worker when it exits; awaited before cleanup/close. */
    private CountDownLatch doneSignal;

    /** Number of instances that have finished the current phase. */
    private final AtomicInteger finishedInstances = new AtomicInteger(0);

    /**
     * Cleared by {@link BatchExecution#stop()}. It is read in the loops of helper workers
     * running on pool threads, so it must be volatile for the stop request to reach them.
     */
    private volatile boolean notStopped = true;

    /** Set once cleanup/close has run; used by {@link BatchExecution#close()} to reject a double close. */
    private boolean cleanUpCalled = false;

    /**
     * Non-blocking handle returned by {@link #iExecute(ExecutionPlan)}. Each call to
     * {@link #progress()} performs one step of the task phase and, once that phase is done,
     * one step of the completion phase.
     */
    public class BatchExecution implements IExecution {
        private final Map<Integer, INodeInstance> nodeMap;
        private final ExecutionPlan executionPlan;
        /** Worker 0 of the task phase, driven by the thread calling progress/waitForCompletion. */
        private final BatchWorker mainWorker;
        /** True while in the task phase; false once the completion phase has been scheduled. */
        private boolean taskExecution = true;
        /** Worker 0 of the completion phase; created when the task phase ends. */
        private CommunicationWorker worker;

        BatchExecution(ExecutionPlan executionPlan, Map<Integer, INodeInstance> map, BatchWorker batchWorker) {
            this.nodeMap = map;
            this.executionPlan = executionPlan;
            this.mainWorker = batchWorker;
        }

        /**
         * Runs the task phase on the calling thread until all instances finished (or the
         * execution is stopped), then runs the completion phase via {@code waitFor}.
         *
         * @return always {@code true}
         */
        public boolean waitForCompletion() {
            while (notStopped && finishedInstances.get() != nodeMap.size()) {
                channel.progress();
                mainWorker.runExecution();
            }
            executionPlan.setExecutionState(ExecutionState.EXECUTED);
            cleanUp(executionPlan, nodeMap);
            waitFor(executionPlan);
            return true;
        }

        /**
         * Performs one step of the current phase.
         *
         * @return {@code true} if more work remains; {@code false} once the completion phase
         *     has finished or the execution was stopped during it
         */
        public boolean progress() {
            if (taskExecution) {
                if (finishedInstances.get() != nodeMap.size()) {
                    channel.progress();
                    mainWorker.runExecution();
                    return true;
                }
                // Task phase is done: wait for the helpers, then switch to the completion phase.
                executionPlan.setExecutionState(ExecutionState.EXECUTED);
                cleanUp(executionPlan, nodeMap);
                // cleanUp() sets cleanUpCalled, but close() has not run yet for this execution.
                cleanUpCalled = false;
                worker = scheduleWaitFor(nodeMap)[0];
                taskExecution = false;
            }
            if (!notStopped || finishedInstances.get() == nodeMap.size()) {
                return false;
            }
            channel.progress();
            worker.runChannelComplete();
            return true;
        }

        /**
         * Resets and closes the nodes and parallel operations of this execution.
         *
         * @throws RuntimeException if {@link #stop()} was not called first, or if the
         *     execution has already been closed
         */
        public void close() {
            if (notStopped) {
                throw new RuntimeException("We need to stop the execution before close");
            }
            if (cleanUpCalled) {
                throw new RuntimeException("Close is called on a already closed execution");
            }
            // Qualified: the unqualified name would resolve to this inner close().
            BatchSharingExecutor2.this.close(executionPlan, nodeMap);
            cleanUpCalled = true;
        }

        /** Requests all worker loops of the enclosing executor to stop. */
        public void stop() {
            notStopped = false;
        }
    }

    /**
     * Task-phase worker. It claims an unclaimed instance through the shared
     * {@code ignoreIndex} flags and executes it. When {@code execute()} returns {@code true}
     * the claim is released so the instance can run again. Otherwise the instance stays
     * claimed and is counted as finished.
     */
    public class BatchWorker implements Runnable {
        private final List<INodeInstance> tasks;
        /** Shared claim flags; {@code ignoreIndex[i]} is true while task i is claimed or retired. */
        private final AtomicBoolean[] ignoreIndex;
        /** Round-robin cursor, private to this worker. */
        private int lastIndex;

        public BatchWorker(List<INodeInstance> list, AtomicBoolean[] atomicBooleanArr) {
            this.tasks = list;
            this.ignoreIndex = atomicBooleanArr;
        }

        /** Advances the cursor and returns the claimed index, or -1 if that slot was already claimed. */
        private int getNext() {
            if (lastIndex == tasks.size()) {
                lastIndex = 0;
            }
            int candidate = lastIndex++;
            return ignoreIndex[candidate].compareAndSet(false, true) ? candidate : -1;
        }

        @Override // java.lang.Runnable
        public void run() {
            try {
                while (notStopped && finishedInstances.get() != tasks.size()) {
                    runExecution();
                }
            } finally {
                // Always release the latch; otherwise a task failure on this thread would
                // leave cleanUp()/close() blocked forever in doneSignal.await().
                doneSignal.countDown();
            }
        }

        /**
         * Executes at most one claimed instance.
         *
         * @throws RuntimeException wrapping any failure raised by the instance
         */
        public void runExecution() {
            try {
                int next = getNext();
                if (next != -1) {
                    if (tasks.get(next).execute()) {
                        ignoreIndex[next].set(false);
                    } else {
                        finishedInstances.incrementAndGet();
                    }
                }
            } catch (Throwable th) {
                LOG.log(Level.SEVERE, String.format("%d Error in executor", Integer.valueOf(workerId)), th);
                throw new RuntimeException("Error occurred in execution of task", th);
            }
        }
    }

    /**
     * Completion-phase worker. It cycles through a shared queue of instances, counts
     * those that report {@code isComplete()} and re-queues the rest.
     */
    public class CommunicationWorker implements Runnable {
        private final BlockingQueue<INodeInstance> tasks;

        public CommunicationWorker(BlockingQueue<INodeInstance> blockingQueue) {
            this.tasks = blockingQueue;
        }

        @Override // java.lang.Runnable
        public void run() {
            try {
                while (notStopped && runChannelComplete()) {
                    // busy loop: all the work is done inside runChannelComplete()
                }
            } finally {
                // Always release the latch so close() cannot block forever after a failure.
                doneSignal.countDown();
            }
        }

        /**
         * Checks one queued instance for completion.
         *
         * @return {@code false} if the queue was empty at poll time (other workers share the
         *     queue, so this can be transient); {@code true} otherwise
         * @throws RuntimeException wrapping any failure raised by the instance
         */
        public boolean runChannelComplete() {
            try {
                INodeInstance instance = tasks.poll();
                if (instance == null) {
                    return false;
                }
                if (instance.isComplete()) {
                    finishedInstances.incrementAndGet();
                } else {
                    tasks.offer(instance);
                }
                return true;
            } catch (Throwable th) {
                LOG.log(Level.SEVERE, String.format("%d Error in executor", Integer.valueOf(workerId)), th);
                throw new RuntimeException("Error occurred in execution of task", th);
            }
        }
    }

    /**
     * @param config job configuration; also supplies the number of threads per container
     * @param i id of this worker, used in log messages
     * @param tWSChannel channel progressed by the calling thread during execution
     */
    public BatchSharingExecutor2(Config config, int i, TWSChannel tWSChannel) {
        this.workerId = i;
        this.config = config;
        this.channel = tWSChannel;
        this.numThreads = ExecutorContext.threadsPerContainer(this.config);
        if (this.numThreads > 1) {
            this.threads = Executors.newFixedThreadPool(this.numThreads - 1,
                new ThreadFactoryBuilder().setNameFormat("executor-%d").setDaemon(true).build());
        }
    }

    /**
     * Runs the task phase of {@code executionPlan} to completion on the calling thread.
     *
     * @return always {@code true}
     */
    public boolean execute(ExecutionPlan executionPlan) {
        prepareRuntime(executionPlan);
        return runExecution(executionPlan);
    }

    /** Starts {@code executionPlan} and returns a handle to drive it incrementally. */
    public IExecution iExecute(ExecutionPlan executionPlan) {
        prepareRuntime(executionPlan);
        return runIExecution(executionPlan);
    }

    /**
     * Adds an {@link ExecutionRuntime} for {@code executionPlan} to the config and resets the
     * plan's nodes/operations if the plan was already executed once.
     */
    private void prepareRuntime(ExecutionPlan executionPlan) {
        this.config = Config.newBuilder().putAll(this.config).put("_twister2.runtime_",
            new ExecutionRuntime(ExecutorContext.jobName(this.config), executionPlan, this.channel)).build();
        if (executionPlan.getExecutionState() == ExecutionState.EXECUTED) {
            resetNodes(executionPlan.getNodes(), executionPlan.getParallelOperations());
        }
    }

    /** Shuts down the helper thread pool, if one was created. */
    public void close() {
        if (this.threads != null) {
            this.threads.shutdown();
        }
    }

    /**
     * Runs the task phase on the calling thread (as worker 0) until every instance finished
     * or the execution is stopped, then waits for the helper workers.
     *
     * @return always {@code true}
     */
    public boolean runExecution(ExecutionPlan executionPlan) {
        Map<Integer, INodeInstance> nodes = executionPlan.getNodes();
        if (nodes.isEmpty()) {
            warnNoTasks();
            return true;
        }
        BatchWorker batchWorker = scheduleExecution(nodes)[0];
        while (this.notStopped && this.finishedInstances.get() != nodes.size()) {
            this.channel.progress();
            batchWorker.runExecution();
        }
        cleanUp(executionPlan, nodes);
        return true;
    }

    /** Logs that this worker received no tasks. */
    private void warnNoTasks() {
        LOG.warning(String.format("Worker %d has zero assigned tasks, you may have more workers than tasks",
            Integer.valueOf(this.workerId)));
    }

    /**
     * Prepares all instances and submits {@code numThreads - 1} task-phase helpers.
     *
     * @return all workers; index 0 is not submitted and must be driven by the caller
     */
    private BatchWorker[] scheduleExecution(Map<Integer, INodeInstance> map) {
        List<INodeInstance> instances = new ArrayList<>(map.values());
        for (INodeInstance instance : instances) {
            instance.prepare(this.config);
        }
        BatchWorker[] workers = new BatchWorker[this.numThreads];
        AtomicBoolean[] claimed = new AtomicBoolean[instances.size()];
        for (int idx = 0; idx < claimed.length; idx++) {
            claimed[idx] = new AtomicBoolean(false);
        }
        // Latch must be in place before any helper can count down.
        this.doneSignal = new CountDownLatch(this.numThreads - 1);
        workers[0] = new BatchWorker(instances, claimed);
        for (int idx = 1; idx < this.numThreads; idx++) {
            workers[idx] = new BatchWorker(instances, claimed);
            this.threads.submit(workers[idx]);
        }
        return workers;
    }

    /**
     * Waits for the task-phase helpers, marks the plan executed and resets the finished counter.
     *
     * @throws RuntimeException if interrupted while waiting (the interrupt flag is restored)
     */
    public void cleanUp(ExecutionPlan executionPlan, Map<Integer, INodeInstance> map) {
        try {
            this.doneSignal.await();
            executionPlan.setExecutionState(ExecutionState.EXECUTED);
            this.finishedInstances.set(0);
            this.cleanUpCalled = true;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException("Interrupted", e);
        }
    }

    /**
     * Starts the task phase and returns a handle to drive it, or a {@code NullExecutor}
     * when the plan has no nodes.
     */
    public IExecution runIExecution(ExecutionPlan executionPlan) {
        Map<Integer, INodeInstance> nodes = executionPlan.getNodes();
        if (nodes.isEmpty()) {
            warnNoTasks();
            return new NullExecutor();
        }
        return new BatchExecution(executionPlan, nodes, scheduleExecution(nodes)[0]);
    }

    /**
     * Runs the completion phase on the calling thread until every instance reports
     * {@code isComplete()} or the execution is stopped, then resets and closes the plan.
     *
     * @return always {@code true}
     */
    public boolean waitFor(ExecutionPlan executionPlan) {
        Map<Integer, INodeInstance> nodes = executionPlan.getNodes();
        if (nodes.isEmpty()) {
            warnNoTasks();
            return true;
        }
        CommunicationWorker communicationWorker = scheduleWaitFor(nodes)[0];
        while (this.notStopped && this.finishedInstances.get() != nodes.size()) {
            this.channel.progress();
            communicationWorker.runChannelComplete();
        }
        close(executionPlan, nodes);
        return true;
    }

    /**
     * Queues all instances and submits {@code numThreads - 1} completion-phase helpers.
     *
     * @return all workers; index 0 is not submitted and must be driven by the caller
     */
    public CommunicationWorker[] scheduleWaitFor(Map<Integer, INodeInstance> map) {
        // Capacity 2x keeps re-offers from ever being rejected.
        BlockingQueue<INodeInstance> queue = new ArrayBlockingQueue<>(map.size() * 2);
        queue.addAll(map.values());
        CommunicationWorker[] workers = new CommunicationWorker[this.numThreads];
        workers[0] = new CommunicationWorker(queue);
        this.doneSignal = new CountDownLatch(this.numThreads - 1);
        for (int idx = 1; idx < this.numThreads; idx++) {
            workers[idx] = new CommunicationWorker(queue);
            this.threads.submit(workers[idx]);
        }
        return workers;
    }

    /**
     * Waits for the completion-phase helpers, then resets and closes all nodes and parallel
     * operations of the plan.
     *
     * @throws RuntimeException if interrupted while waiting (the interrupt flag is restored)
     */
    public void close(ExecutionPlan executionPlan, Map<Integer, INodeInstance> map) {
        try {
            this.doneSignal.await();
            List<IParallelOperation> parallelOperations = executionPlan.getParallelOperations();
            resetNodes(map, parallelOperations);
            for (INodeInstance node : map.values()) {
                node.close();
            }
            for (IParallelOperation op : parallelOperations) {
                op.close();
            }
            this.finishedInstances.set(0);
            this.cleanUpCalled = true;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException("Interrupted", e);
        }
    }

    /** Calls {@code reset()} on every node instance and then on every parallel operation. */
    private void resetNodes(Map<Integer, INodeInstance> map, List<IParallelOperation> list) {
        for (INodeInstance node : map.values()) {
            node.reset();
        }
        for (IParallelOperation op : list) {
            op.reset();
        }
    }
}
