package edu.iu.dsc.tws.examples.internal.batchscheduler;

import edu.iu.dsc.tws.api.JobConfig;
import edu.iu.dsc.tws.api.Twister2Job;
import edu.iu.dsc.tws.api.comms.messaging.types.MessageTypes;
import edu.iu.dsc.tws.api.compute.IFunction;
import edu.iu.dsc.tws.api.compute.IMessage;
import edu.iu.dsc.tws.api.compute.TaskContext;
import edu.iu.dsc.tws.api.compute.executor.ExecutionPlan;
import edu.iu.dsc.tws.api.compute.graph.ComputeGraph;
import edu.iu.dsc.tws.api.compute.graph.OperationMode;
import edu.iu.dsc.tws.api.compute.modifiers.Collector;
import edu.iu.dsc.tws.api.compute.modifiers.IONames;
import edu.iu.dsc.tws.api.compute.modifiers.Receptor;
import edu.iu.dsc.tws.api.compute.nodes.BaseCompute;
import edu.iu.dsc.tws.api.compute.nodes.BaseSource;
import edu.iu.dsc.tws.api.config.Config;
import edu.iu.dsc.tws.api.dataset.DataObject;
import edu.iu.dsc.tws.api.dataset.DataPartition;
import edu.iu.dsc.tws.api.resource.Twister2Worker;
import edu.iu.dsc.tws.api.resource.WorkerEnvironment;
import edu.iu.dsc.tws.dataset.partition.EntityPartition;
import edu.iu.dsc.tws.examples.batch.cdfw.CDFConstants;
import edu.iu.dsc.tws.rsched.core.ResourceAllocator;
import edu.iu.dsc.tws.rsched.job.Twister2Submitter;
import edu.iu.dsc.tws.task.ComputeEnvironment;
import edu.iu.dsc.tws.task.impl.ComputeConnection;
import edu.iu.dsc.tws.task.impl.ComputeGraphBuilder;
import edu.iu.dsc.tws.task.impl.TaskExecutor;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Random;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.DefaultParser;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;

/* loaded from: input_file:edu/iu/dsc/tws/examples/internal/batchscheduler/BatchTaskSchedulerExample.class */
public class BatchTaskSchedulerExample implements Twister2Worker {
    private static final Logger LOG = Logger.getLogger(BatchTaskSchedulerExample.class.getName());

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:edu/iu/dsc/tws/examples/internal/batchscheduler/BatchTaskSchedulerExample$Aggregator.class */
    public static class Aggregator implements IFunction {
        private static final long serialVersionUID = -254264120110286748L;

        private Aggregator() {
        }

        public Object onMessage(Object obj, Object obj2) throws ArrayIndexOutOfBoundsException {
            double[] dArr = (double[]) obj;
            double[] dArr2 = (double[]) obj2;
            BatchTaskSchedulerExample.LOG.fine("object 1 and object 2:" + Arrays.toString(dArr) + "\t\n" + Arrays.toString(dArr2));
            double[] dArr3 = new double[dArr.length];
            for (int i = 0; i < dArr.length; i++) {
                dArr3[i] = dArr[i] + dArr2[i];
            }
            BatchTaskSchedulerExample.LOG.fine("Object 3 len:" + dArr3.length + "\tObject 3:" + Arrays.toString(dArr3));
            return dArr3;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:edu/iu/dsc/tws/examples/internal/batchscheduler/BatchTaskSchedulerExample$FirstComputeTask.class */
    public static class FirstComputeTask extends BaseCompute {
        private static final long serialVersionUID = -254264120110286748L;

        FirstComputeTask() {
        }

        public boolean execute(IMessage iMessage) {
            BatchTaskSchedulerExample.LOG.log(Level.FINE, "Received Points: " + this.context.getWorkerId() + ":" + this.context.globalTaskId());
            if (iMessage.getContent() instanceof Iterator) {
                Iterator it = (Iterator) iMessage.getContent();
                while (it.hasNext()) {
                    this.context.write("direct", it.next());
                }
            }
            this.context.end("direct");
            return true;
        }

        public void prepare(Config config, TaskContext taskContext) {
            super.prepare(config, taskContext);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:edu/iu/dsc/tws/examples/internal/batchscheduler/BatchTaskSchedulerExample$FirstSinkTask.class */
    public static class FirstSinkTask extends BaseCompute implements Collector {
        private static final long serialVersionUID = -5190777711234234L;
        private double[] datapoints;
        private String inputKey;

        FirstSinkTask(String str) {
            this.inputKey = str;
        }

        public boolean execute(IMessage iMessage) {
            ArrayList arrayList = new ArrayList();
            while (((Iterator) iMessage.getContent()).hasNext()) {
                arrayList.add((double[]) ((Iterator) iMessage.getContent()).next());
            }
            this.datapoints = new double[arrayList.size()];
            Iterator it = arrayList.iterator();
            while (it.hasNext()) {
                this.datapoints = (double[]) it.next();
            }
            return true;
        }

        public void prepare(Config config, TaskContext taskContext) {
            super.prepare(config, taskContext);
        }

        public DataPartition<double[]> get() {
            return new EntityPartition(this.datapoints);
        }

        public IONames getCollectibleNames() {
            return IONames.declare(new String[]{this.inputKey});
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:edu/iu/dsc/tws/examples/internal/batchscheduler/BatchTaskSchedulerExample$FirstSourceTask.class */
    public static class FirstSourceTask extends BaseSource {
        private static final long serialVersionUID = -254264120110286748L;
        private double[] datapoints = null;
        private int numPoints = 0;

        FirstSourceTask() {
        }

        public void execute() {
            this.datapoints = new double[this.numPoints];
            Random random = new Random(100);
            for (int i = 0; i < this.numPoints; i++) {
                this.datapoints[i] = random.nextDouble();
            }
            this.context.writeEnd("direct", this.datapoints);
        }

        public void prepare(Config config, TaskContext taskContext) {
            super.prepare(config, taskContext);
            this.numPoints = Integer.parseInt(config.getStringValue(CDFConstants.ARGS_DSIZE));
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:edu/iu/dsc/tws/examples/internal/batchscheduler/BatchTaskSchedulerExample$SecondComputeTask.class */
    public static class SecondComputeTask extends BaseCompute {
        private static final long serialVersionUID = -254264120110286748L;

        SecondComputeTask() {
        }

        public boolean execute(IMessage iMessage) {
            BatchTaskSchedulerExample.LOG.log(Level.FINE, "Received Points: " + this.context.getWorkerId() + ":" + this.context.globalTaskId());
            if (iMessage.getContent() instanceof Iterator) {
                Iterator it = (Iterator) iMessage.getContent();
                while (it.hasNext()) {
                    this.context.write("direct", it.next());
                }
            }
            this.context.end("direct");
            return true;
        }

        public void prepare(Config config, TaskContext taskContext) {
            super.prepare(config, taskContext);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:edu/iu/dsc/tws/examples/internal/batchscheduler/BatchTaskSchedulerExample$SecondSinkTask.class */
    public static class SecondSinkTask extends BaseCompute implements Collector {
        private static final long serialVersionUID = -5190777711234234L;
        private double[] datapoints;
        private String inputKey;

        SecondSinkTask(String str) {
            this.inputKey = str;
        }

        public boolean execute(IMessage iMessage) {
            ArrayList arrayList = new ArrayList();
            while (((Iterator) iMessage.getContent()).hasNext()) {
                arrayList.add((double[]) ((Iterator) iMessage.getContent()).next());
            }
            this.datapoints = new double[arrayList.size()];
            Iterator it = arrayList.iterator();
            while (it.hasNext()) {
                this.datapoints = (double[]) it.next();
            }
            return true;
        }

        public void prepare(Config config, TaskContext taskContext) {
            super.prepare(config, taskContext);
        }

        public DataPartition<double[]> get() {
            return new EntityPartition(this.datapoints);
        }

        public IONames getCollectibleNames() {
            return IONames.declare(new String[]{this.inputKey});
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:edu/iu/dsc/tws/examples/internal/batchscheduler/BatchTaskSchedulerExample$SecondSourceTask.class */
    public static class SecondSourceTask extends BaseSource {
        private static final long serialVersionUID = -254264120110286748L;
        private double[] datapoints = null;
        private int numPoints = 0;

        SecondSourceTask() {
        }

        public void execute() {
            this.datapoints = new double[this.numPoints];
            Random random = new Random(100);
            for (int i = 0; i < this.numPoints; i++) {
                this.datapoints[i] = random.nextDouble();
            }
            this.context.writeEnd("direct", this.datapoints);
        }

        public void prepare(Config config, TaskContext taskContext) {
            super.prepare(config, taskContext);
            this.numPoints = Integer.parseInt(config.getStringValue(CDFConstants.ARGS_DSIZE));
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:edu/iu/dsc/tws/examples/internal/batchscheduler/BatchTaskSchedulerExample$ThirdSinkTask.class */
    public static class ThirdSinkTask extends BaseCompute implements Collector {
        private static final long serialVersionUID = -5190777711234234L;
        private double[] datapoints;
        private String inputKey;

        ThirdSinkTask() {
        }

        ThirdSinkTask(String str) {
            this.inputKey = str;
        }

        public boolean execute(IMessage iMessage) {
            this.datapoints = (double[]) iMessage.getContent();
            return true;
        }

        public DataPartition<double[]> get() {
            return new EntityPartition(this.datapoints);
        }

        public IONames getCollectibleNames() {
            return IONames.declare(new String[]{this.inputKey});
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:edu/iu/dsc/tws/examples/internal/batchscheduler/BatchTaskSchedulerExample$ThirdSourceTask.class */
    public static class ThirdSourceTask extends BaseSource implements Receptor {
        private static final long serialVersionUID = -254264120110286748L;
        private double[] datapoints = null;
        private DataPartition<?> dataPartition = null;

        ThirdSourceTask() {
        }

        public void execute() {
            this.datapoints = (double[]) this.dataPartition.first();
            this.context.writeEnd("all-reduce", this.datapoints);
        }

        public void prepare(Config config, TaskContext taskContext) {
            super.prepare(config, taskContext);
        }

        public void add(String str, DataObject<?> dataObject) {
        }

        public void add(String str, DataPartition<?> dataPartition) {
            if ("secondgraphpoints".equals(str)) {
                this.dataPartition = dataPartition;
            }
        }

        public IONames getReceivableNames() {
            return IONames.declare(new String[]{"secondgraphpoints"});
        }
    }

    /**
     * Builds the first batch graph: firstsource -> firstcompute -> firstsink, with each hop
     * over a {@code "direct"} edge carrying OBJECT messages.
     *
     * @param parallelism number of instances for every vertex
     * @param config      job configuration passed to the graph builder
     * @return the built graph, named {@code "firstTG"}
     */
    private static ComputeGraph buildFirstGraph(int parallelism, Config config) {
        FirstSourceTask source = new FirstSourceTask();
        FirstComputeTask compute = new FirstComputeTask();
        FirstSinkTask sink = new FirstSinkTask("firstgraphpoints");

        ComputeGraphBuilder builder = ComputeGraphBuilder.newBuilder(config);
        builder.addSource("firstsource", source, parallelism);
        ComputeConnection computeInput = builder.addCompute("firstcompute", compute, parallelism);
        ComputeConnection sinkInput = builder.addCompute("firstsink", sink, parallelism);

        computeInput.direct("firstsource")
            .viaEdge("direct")
            .withDataType(MessageTypes.OBJECT);
        sinkInput.direct("firstcompute")
            .viaEdge("direct")
            .withDataType(MessageTypes.OBJECT);

        builder.setMode(OperationMode.BATCH);
        builder.setTaskGraphName("firstTG");
        return builder.build();
    }

    /**
     * Builds the second batch graph: secondsource -> secondcompute -> secondsink, linked by
     * {@code "direct"} edges carrying OBJECT messages. The sink publishes its output under the
     * name {@code "secondgraphpoints"}.
     *
     * @param parallelism number of instances for every vertex
     * @param config      job configuration passed to the graph builder
     * @return the built graph, named {@code "secondTG"}
     */
    private static ComputeGraph buildSecondGraph(int parallelism, Config config) {
        SecondSourceTask source = new SecondSourceTask();
        SecondComputeTask compute = new SecondComputeTask();
        SecondSinkTask sink = new SecondSinkTask("secondgraphpoints");

        ComputeGraphBuilder graph = ComputeGraphBuilder.newBuilder(config);
        graph.addSource("secondsource", source, parallelism);
        ComputeConnection toCompute = graph.addCompute("secondcompute", compute, parallelism);
        ComputeConnection toSink = graph.addCompute("secondsink", sink, parallelism);

        toCompute.direct("secondsource")
            .viaEdge("direct")
            .withDataType(MessageTypes.OBJECT);
        toSink.direct("secondcompute")
            .viaEdge("direct")
            .withDataType(MessageTypes.OBJECT);

        graph.setMode(OperationMode.BATCH);
        graph.setTaskGraphName("secondTG");
        return graph.build();
    }

    /**
     * Builds the third batch graph: thirdsource feeds thirdsink over an {@code "all-reduce"}
     * edge whose messages are combined with {@link Aggregator}.
     *
     * @param parallelism number of instances for each vertex
     * @param config      job configuration passed to the graph builder
     * @return the built graph, named {@code "thirdTG"}
     */
    private static ComputeGraph buildThirdGraph(int parallelism, Config config) {
        ThirdSourceTask source = new ThirdSourceTask();
        ThirdSinkTask sink = new ThirdSinkTask("thirdgraphpoints");

        ComputeGraphBuilder builder = ComputeGraphBuilder.newBuilder(config);
        builder.addSource("thirdsource", source, parallelism);
        ComputeConnection sinkInput = builder.addCompute("thirdsink", sink, parallelism);
        sinkInput.allreduce("thirdsource")
            .viaEdge("all-reduce")
            .withReductionFunction(new Aggregator())
            .withDataType(MessageTypes.OBJECT);

        builder.setMode(OperationMode.BATCH);
        builder.setTaskGraphName("thirdTG");
        return builder.build();
    }

    /**
     * Worker entry point. It builds the three graphs, plans all of them, executes them in order
     * (first, second, third), and then closes the compute environment.
     *
     * @param workerEnvironment environment supplied by the Twister2 runtime
     */
    @Override
    public void execute(WorkerEnvironment workerEnvironment) {
        int workerId = workerEnvironment.getWorkerId();
        Config config = workerEnvironment.getConfig();
        // nanoTime is monotonic and therefore suitable for measuring elapsed time.
        long startNanos = System.nanoTime();
        LOG.log(Level.FINE, "Task worker starting: " + workerId);

        ComputeEnvironment computeEnv = ComputeEnvironment.init(workerEnvironment);
        try {
            TaskExecutor taskExecutor = computeEnv.getTaskExecutor();

            // Parallelism is deliberately different per graph (2, 4, 4).
            ComputeGraph firstGraph = buildFirstGraph(2, config);
            ComputeGraph secondGraph = buildSecondGraph(4, config);
            ComputeGraph thirdGraph = buildThirdGraph(4, config);

            ExecutionPlan firstPlan = taskExecutor.plan(firstGraph);
            ExecutionPlan secondPlan = taskExecutor.plan(secondGraph);
            ExecutionPlan thirdPlan = taskExecutor.plan(thirdGraph);

            // The third graph consumes the second graph's "secondgraphpoints" output, so order matters.
            taskExecutor.execute(firstGraph, firstPlan);
            taskExecutor.execute(secondGraph, secondPlan);
            taskExecutor.execute(thirdGraph, thirdPlan);
        } finally {
            // Release the environment even if planning or execution fails.
            computeEnv.close();
        }
        long elapsedMillis = (System.nanoTime() - startNanos) / 1_000_000L;
        LOG.info("Total Execution Time: " + elapsedMillis);
    }

    /**
     * Parses the command line and submits the BatchScheduler-test job.
     *
     * <p>Required options: {@code -workers} (at least 1), the
     * {@code CDFConstants.ARGS_PARALLELISM_VALUE} option (at least 1), and the
     * {@code CDFConstants.ARGS_DSIZE} option (at least 0).
     *
     * @param strArr command-line arguments
     * @throws ParseException           if the command line cannot be parsed
     * @throws IllegalArgumentException if a required option is missing, non-numeric or too small
     */
    public static void main(String[] strArr) throws ParseException {
        LOG.log(Level.INFO, "Batch Task Graph Example");
        Config loadConfig = ResourceAllocator.loadConfig(new HashMap<>());

        Options options = new Options();
        options.addOption("workers", true, "Workers");
        options.addOption(CDFConstants.ARGS_PARALLELISM_VALUE, true,
            CDFConstants.ARGS_PARALLELISM_VALUE);
        options.addOption(CDFConstants.ARGS_DSIZE, true, CDFConstants.ARGS_DSIZE);
        CommandLine commandLine = new DefaultParser().parse(options, strArr);

        int workers = requiredIntOption(commandLine, "workers", 1);
        int parallelism = requiredIntOption(commandLine, CDFConstants.ARGS_PARALLELISM_VALUE, 1);
        int dataSize = requiredIntOption(commandLine, CDFConstants.ARGS_DSIZE, 0);

        JobConfig jobConfig = new JobConfig();
        jobConfig.put("workers", Integer.toString(workers));
        jobConfig.put(CDFConstants.ARGS_PARALLELISM_VALUE, Integer.toString(parallelism));
        jobConfig.put(CDFConstants.ARGS_DSIZE, Integer.toString(dataSize));

        Twister2Job.Twister2JobBuilder jobBuilder = Twister2Job.newBuilder();
        jobBuilder.setJobName("BatchScheduler-test");
        jobBuilder.setWorkerClass(BatchTaskSchedulerExample.class.getName());
        jobBuilder.addComputeResource(2.0d, 2048, 1.0d, workers);
        jobBuilder.setConfig(jobConfig);
        Twister2Submitter.submitJob(jobBuilder.build(), loadConfig);
    }

    /**
     * Reads a mandatory integer option and checks it against a lower bound. This replaces
     * opaque errors such as {@code Integer.parseInt(null)} with messages that name the option.
     *
     * @throws IllegalArgumentException if the option is absent, not an int, or below {@code min}
     */
    private static int requiredIntOption(CommandLine commandLine, String name, int min) {
        String raw = commandLine.getOptionValue(name);
        if (raw == null) {
            throw new IllegalArgumentException("Missing required option -" + name);
        }
        int value;
        try {
            value = Integer.parseInt(raw.trim());
        } catch (NumberFormatException e) {
            throw new IllegalArgumentException(
                "Option -" + name + " must be an integer but was '" + raw + "'", e);
        }
        if (value < min) {
            throw new IllegalArgumentException(
                "Option -" + name + " must be >= " + min + " but was " + value);
        }
        return value;
    }
}
