package edu.iu.dsc.tws.examples.batch.cdfw;

import edu.iu.dsc.tws.api.Twister2Job;
import edu.iu.dsc.tws.api.comms.messaging.types.MessageTypes;
import edu.iu.dsc.tws.api.compute.IFunction;
import edu.iu.dsc.tws.api.compute.IMessage;
import edu.iu.dsc.tws.api.compute.TaskContext;
import edu.iu.dsc.tws.api.compute.graph.ComputeGraph;
import edu.iu.dsc.tws.api.compute.graph.OperationMode;
import edu.iu.dsc.tws.api.compute.modifiers.Collector;
import edu.iu.dsc.tws.api.compute.modifiers.IONames;
import edu.iu.dsc.tws.api.compute.modifiers.Receptor;
import edu.iu.dsc.tws.api.compute.nodes.BaseCompute;
import edu.iu.dsc.tws.api.compute.nodes.BaseSource;
import edu.iu.dsc.tws.api.config.Config;
import edu.iu.dsc.tws.api.dataset.DataObject;
import edu.iu.dsc.tws.api.dataset.DataPartition;
import edu.iu.dsc.tws.dataset.partition.EntityPartition;
import edu.iu.dsc.tws.examples.batch.kmeans.KMeansUtils;
import edu.iu.dsc.tws.rsched.core.ResourceAllocator;
import edu.iu.dsc.tws.rsched.job.Twister2Submitter;
import edu.iu.dsc.tws.task.cdfw.BaseDriver;
import edu.iu.dsc.tws.task.cdfw.CDFWEnv;
import edu.iu.dsc.tws.task.cdfw.DataFlowGraph;
import edu.iu.dsc.tws.task.cdfw.DataFlowJobConfig;
import edu.iu.dsc.tws.task.dataobjects.DataFileReplicatedReadSource;
import edu.iu.dsc.tws.task.dataobjects.DataObjectSource;
import edu.iu.dsc.tws.task.impl.ComputeConnection;
import edu.iu.dsc.tws.task.impl.ComputeGraphBuilder;
import edu.iu.dsc.tws.task.impl.cdfw.CDFWWorker;
import java.util.HashMap;
import java.util.logging.Logger;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.DefaultParser;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;

/* loaded from: input_file:edu/iu/dsc/tws/examples/batch/cdfw/KMeansConnectedDataflowExample.class */
/**
 * K-Means clustering written as a Twister2 connected dataflow (CDFW) job.
 *
 * <p>{@link #main(String[])} parses the command line and submits the job. {@link KMeansDriver}
 * then runs these graphs in order:
 * <ol>
 *   <li>a data-generation graph;</li>
 *   <li>a graph that loads the data points;</li>
 *   <li>a graph that loads the initial centroids;</li>
 *   <li>one K-Means all-reduce graph per iteration.</li>
 * </ol>
 */
public final class KMeansConnectedDataflowExample {
    private static final Logger LOG = Logger.getLogger(KMeansConnectedDataflowExample.class.getName());

    /**
     * Config key for the number of executor threads per worker. The spelling ("exector") is kept
     * exactly as in the original code; TODO confirm it matches the key the runtime reads.
     */
    private static final String THREADS_PER_WORKER = "twister2.exector.worker.threads";

    /** Name of the edge between {@link KMeansSourceTask} and {@link KMeansAllReduceTask}. */
    private static final String ALL_REDUCE_EDGE = "all-reduce";

    /** Name under which the data points are handed from the loading graph to the K-Means graph. */
    private static final String POINTS = "points";

    /** Name under which the centroids are handed to (and published by) the K-Means graph. */
    private static final String CENTROIDS = "centroids";

    /**
     * Reduction function used on the all-reduce edge. It sums two partial centroid matrices
     * element by element.
     */
    public static class CentroidAggregator implements IFunction {
        private static final long serialVersionUID = -254264120110286748L;

        /**
         * Adds two {@code double[][]} partial results element by element.
         *
         * @param obj  first partial result, a {@code double[][]}
         * @param obj2 second partial result, a {@code double[][]} of the same shape
         * @return a new {@code double[][]} holding the element-wise sum
         * @throws IllegalArgumentException if the two matrices differ in row count or row width
         */
        @Override
        public Object onMessage(Object obj, Object obj2) {
            double[][] left = (double[][]) obj;
            double[][] right = (double[][]) obj2;
            // Check shapes before allocating anything, so that an empty input does not fail on
            // left[0] and a size mismatch is reported clearly.
            if (left.length != right.length) {
                throw new IllegalArgumentException(
                        "Center sizes not equal " + left.length + " != " + right.length);
            }
            double[][] sum = new double[left.length][];
            for (int row = 0; row < left.length; row++) {
                double[] leftRow = left[row];
                double[] rightRow = right[row];
                if (leftRow.length != rightRow.length) {
                    throw new IllegalArgumentException("Center " + row + " widths not equal "
                            + leftRow.length + " != " + rightRow.length);
                }
                double[] summedRow = new double[leftRow.length];
                for (int col = 0; col < leftRow.length; col++) {
                    summedRow[col] = leftRow[col] + rightRow[col];
                }
                sum[row] = summedRow;
            }
            return sum;
        }
    }

    /**
     * Final stage of one K-Means iteration. It receives the all-reduced matrix, turns it into
     * new centroids, and publishes them under the name {@code "centroids"}.
     */
    public static class KMeansAllReduceTask extends BaseCompute implements Collector {
        private static final long serialVersionUID = -5190777711234234L;
        private double[][] centroids;
        private double[][] newCentroids;

        /**
         * Divides every entry of each row, except the last, by that row's last column.
         *
         * <p>Assumes each row is laid out as {@code [sum_0, ..., sum_(d-1), count]}, which is
         * what the division implies; TODO confirm against {@code KMeansUtils.findNearestCenter}.
         * A row whose last column is 0 yields NaN/Infinity entries.
         *
         * @param iMessage message whose content is the reduced {@code double[][]}
         * @return always {@code true}
         */
        @Override
        public boolean execute(IMessage iMessage) {
            this.centroids = (double[][]) iMessage.getContent();
            if (this.centroids.length == 0) {
                // Nothing was reduced; avoid indexing centroids[0].
                this.newCentroids = new double[0][];
                return true;
            }
            // The first row's width is used for every row; all rows are expected to share it.
            int countColumn = this.centroids[0].length - 1;
            this.newCentroids = new double[this.centroids.length][countColumn];
            for (int row = 0; row < this.centroids.length; row++) {
                double count = this.centroids[row][countColumn];
                for (int col = 0; col < countColumn; col++) {
                    this.newCentroids[row][col] = this.centroids[row][col] / count;
                }
            }
            return true;
        }

        /**
         * Returns the centroids computed by the most recent {@link #execute(IMessage)} call.
         *
         * @return a partition wrapping the new centroids
         */
        public DataPartition<double[][]> get() {
            return new EntityPartition<>(this.newCentroids);
        }

        /**
         * Declares the single output this task publishes.
         *
         * @return the names this collector exposes ({@code "centroids"})
         */
        public IONames getCollectibleNames() {
            return IONames.declare(new String[]{CENTROIDS});
        }
    }

    /** Driver that runs the whole K-Means pipeline on a {@link CDFWEnv}. */
    public static class KMeansDriver extends BaseDriver {

        /**
         * Reads the job parameters from the environment's config and runs all graphs. It also
         * logs timings and then closes the environment.
         *
         * <p>The environment is closed even if building or executing a graph throws.
         *
         * @param cDFWEnv the connected-dataflow environment to run on
         */
        @Override
        public void execute(CDFWEnv cDFWEnv) {
            try {
                runPipeline(cDFWEnv);
            } finally {
                cDFWEnv.close();
            }
        }

        /**
         * Builds and executes, in order, the data-generation graph, the points and centroid
         * loading graphs, and one K-Means graph per iteration.
         */
        private static void runPipeline(CDFWEnv env) {
            Config config = env.getConfig();
            DataFlowJobConfig jobConfig = new DataFlowJobConfig();
            String dataDirectory = String.valueOf(config.get(CDFConstants.ARGS_DINPUT));
            String centroidDirectory = String.valueOf(config.get(CDFConstants.ARGS_CINPUT));
            int parallelism = intValue(config, CDFConstants.ARGS_PARALLELISM_VALUE);
            int workers = intValue(config, "workers");
            int iterations = intValue(config, CDFConstants.ARGS_ITERATIONS);
            int dimension = intValue(config, "dim");
            int dataSize = intValue(config, CDFConstants.ARGS_DSIZE);
            int centroidSize = intValue(config, CDFConstants.ARGS_CSIZE);

            env.executeDataFlowGraph(generateData(config, dataDirectory, centroidDirectory,
                    dimension, dataSize, centroidSize, workers, parallelism, jobConfig));
            DataFlowGraph pointsJob = generateFirstJob(config, parallelism, dataDirectory,
                    dimension, dataSize, workers, jobConfig);
            DataFlowGraph centroidsJob = generateSecondJob(config, parallelism, centroidDirectory,
                    dimension, centroidSize, workers, jobConfig);

            // Data generation is deliberately excluded from the reported timings.
            long startTime = System.currentTimeMillis();
            env.executeDataFlowGraph(pointsJob);
            env.executeDataFlowGraph(centroidsJob);
            long loadEndTime = System.currentTimeMillis();
            for (int i = 0; i < iterations; i++) {
                // A fresh K-Means graph is built for every iteration and tagged with its number.
                DataFlowGraph kmeansJob = generateThirdJob(config, parallelism, workers, iterations,
                        dimension, jobConfig);
                kmeansJob.setIterationNumber(i);
                env.executeDataFlowGraph(kmeansJob);
            }
            long endTime = System.currentTimeMillis();
            LOG.info("Total K-Means Execution Time: " + (endTime - startTime)
                    + "\tData Load time : " + (loadEndTime - startTime)
                    + "\tCompute Time : " + (endTime - loadEndTime));
        }

        /** Parses the value stored under {@code key} as an int (fails on a missing key). */
        private static int intValue(Config config, String key) {
            return Integer.parseInt(String.valueOf(config.get(key)));
        }
    }

    /**
     * Source of the K-Means graph. It assigns the received points to their nearest centroids and
     * sends the result over the all-reduce edge.
     */
    public static class KMeansSourceTask extends BaseSource implements Receptor {
        private static final long serialVersionUID = -254264120110286748L;
        private DataPartition<?> dataPartition;
        private DataPartition<?> centroidPartition;
        private int dimension;

        /** Creates a source with dimension 0. */
        public KMeansSourceTask() {
            this(0);
        }

        /**
         * Creates a source for points of the given dimension.
         *
         * @param i the dimension passed to {@code KMeansUtils.findNearestCenter}
         */
        public KMeansSourceTask(int i) {
            this.dimension = i;
        }

        /**
         * Runs {@code KMeansUtils.findNearestCenter} on the received points and centroids. The
         * result is written to the all-reduce edge.
         *
         * @throws IllegalStateException if the "points" or "centroids" input has not been received
         */
        @Override
        public void execute() {
            if (this.dataPartition == null || this.centroidPartition == null) {
                throw new IllegalStateException("KMeansSourceTask needs both \"" + POINTS
                        + "\" and \"" + CENTROIDS + "\" inputs before it can execute");
            }
            double[][] points = (double[][]) this.dataPartition.first();
            double[][] centers = (double[][]) this.centroidPartition.first();
            this.context.writeEnd(ALL_REDUCE_EDGE,
                    KMeansUtils.findNearestCenter(this.dimension, points, centers));
        }

        /** No-op: this task only takes its inputs through the {@link DataPartition} overload. */
        public void add(String str, DataObject<?> dataObject) {
        }

        /**
         * Stores the partition received under {@code "points"} or {@code "centroids"}. Any other
         * name is ignored.
         *
         * @param str           input name
         * @param dataPartition received partition
         */
        public void add(String str, DataPartition<?> dataPartition) {
            if (POINTS.equals(str)) {
                this.dataPartition = dataPartition;
            } else if (CENTROIDS.equals(str)) {
                this.centroidPartition = dataPartition;
            }
        }

        /**
         * Declares the inputs this task accepts.
         *
         * @return the receivable names ({@code "points"}, {@code "centroids"})
         */
        public IONames getReceivableNames() {
            return IONames.declare(new String[]{POINTS, CENTROIDS});
        }
    }

    private KMeansConnectedDataflowExample() {
    }

    /**
     * Parses the command line and submits the connected-dataflow K-Means job.
     *
     * <p>All options are required: parallelism, workers, dim, dsize, csize, dinput, cinput and
     * iterations (the option names come from {@code CDFConstants}).
     *
     * @param strArr command-line arguments
     * @throws ParseException if the arguments cannot be parsed, an option is missing, or an
     *                        integer option is not a valid integer
     */
    public static void main(String[] strArr) throws ParseException {
        Config loadConfig = ResourceAllocator.loadConfig(new HashMap<>());

        Options options = new Options();
        options.addOption(CDFConstants.ARGS_PARALLELISM_VALUE, true, "2");
        options.addOption("workers", true, "2");
        options.addOption("dim", true, "2");
        options.addOption(CDFConstants.ARGS_DSIZE, true, "2");
        options.addOption(CDFConstants.ARGS_CSIZE, true, "2");
        options.addOption(CDFConstants.ARGS_DINPUT, true, "2");
        options.addOption(CDFConstants.ARGS_CINPUT, true, "2");
        options.addOption(CDFConstants.ARGS_ITERATIONS, true, "2");
        CommandLine cmd = new DefaultParser().parse(options, strArr);

        String dataDirectory = requiredOption(cmd, CDFConstants.ARGS_DINPUT);
        String centroidDirectory = requiredOption(cmd, CDFConstants.ARGS_CINPUT);
        int workers = requiredIntOption(cmd, "workers");
        int parallelism = requiredIntOption(cmd, CDFConstants.ARGS_PARALLELISM_VALUE);
        int dimension = requiredIntOption(cmd, "dim");
        int dataSize = requiredIntOption(cmd, CDFConstants.ARGS_DSIZE);
        int centroidSize = requiredIntOption(cmd, CDFConstants.ARGS_CSIZE);
        int iterations = requiredIntOption(cmd, CDFConstants.ARGS_ITERATIONS);

        Twister2Job job = Twister2Job.newBuilder()
                .setJobName("kmeans-connected-dataflow")
                .setWorkerClass(CDFWWorker.class)
                .setDriverClass(KMeansDriver.class.getName())
                .addComputeResource(1.0d, 2048, workers, true)
                .build();

        Config jobConfig = Config.newBuilder()
                .putAll(loadConfig)
                // Previously this setting was written into a throw-away map and never applied.
                .put(THREADS_PER_WORKER, 1)
                .put("workers", Integer.toString(workers))
                .put(CDFConstants.ARGS_PARALLELISM_VALUE, Integer.toString(parallelism))
                .put("dim", Integer.toString(dimension))
                .put(CDFConstants.ARGS_CSIZE, Integer.toString(centroidSize))
                .put(CDFConstants.ARGS_DSIZE, Integer.toString(dataSize))
                .put(CDFConstants.ARGS_DINPUT, dataDirectory)
                .put(CDFConstants.ARGS_CINPUT, centroidDirectory)
                .put(CDFConstants.ARGS_ITERATIONS, Integer.toString(iterations))
                .put("twister2.resource.job.driver.class", (Object) null)
                .build();

        Twister2Submitter.submitJob(job, jobConfig);
    }

    /** Returns the value of option {@code name}, failing with a clear message when it is absent. */
    private static String requiredOption(CommandLine cmd, String name) throws ParseException {
        String value = cmd.getOptionValue(name);
        if (value == null) {
            throw new ParseException("Missing required option: -" + name);
        }
        return value;
    }

    /** Returns option {@code name} parsed as an int, failing with a clear message otherwise. */
    private static int requiredIntOption(CommandLine cmd, String name) throws ParseException {
        String value = requiredOption(cmd, name);
        try {
            return Integer.parseInt(value);
        } catch (NumberFormatException e) {
            ParseException parseError = new ParseException(
                    "Option -" + name + " must be an integer but was '" + value + "'");
            parseError.initCause(e);
            throw parseError;
        }
    }

    /**
     * Builds the graph that generates the input data and the initial centroids.
     *
     * @param config            job configuration
     * @param dataDirectory     directory handed to the generator for the data points
     * @param centroidDirectory directory handed to the generator for the centroids
     * @param dimension         point dimension
     * @param dataSize          number of data points to generate
     * @param centroidSize      number of centroids to generate
     * @param workers           number of workers for the sub-graph job
     * @param parallelism       parallelism of every task in the graph
     * @param dataFlowJobConfig shared dataflow job configuration
     * @return the non-iterative sub-graph job "datageneratorTG"
     */
    public static DataFlowGraph generateData(Config config, String dataDirectory,
                                             String centroidDirectory, int dimension, int dataSize,
                                             int centroidSize, int workers, int parallelism,
                                             DataFlowJobConfig dataFlowJobConfig) {
        DataGeneratorSource source = new DataGeneratorSource("direct", dataSize, centroidSize,
                dimension, dataDirectory, centroidDirectory);
        DataGeneratorSink sink = new DataGeneratorSink();
        ComputeGraphBuilder builder = ComputeGraphBuilder.newBuilder(config);
        // Set the name before build(), as the other generate* methods do. Calling
        // setTaskGraphName after build() presumably has no effect on the graph already built.
        builder.setTaskGraphName("datageneratorTG");
        builder.addSource("datageneratorsource", source, parallelism);
        builder.addCompute("datageneratorsink", sink, parallelism)
                .direct("datageneratorsource")
                .viaEdge("direct")
                .withDataType(MessageTypes.OBJECT);
        builder.setMode(OperationMode.BATCH);
        return DataFlowGraph.newSubGraphJob("datageneratorTG", builder.build())
                .setWorkers(workers)
                .addDataFlowJobConfig(dataFlowJobConfig)
                .setGraphType("non-iterative");
    }

    /**
     * Builds the graph that loads the data points and publishes them as {@code "points"}.
     *
     * @param config            job configuration
     * @param parallelism       parallelism of every task in the graph
     * @param dataDirectory     directory to read the points from
     * @param dimension         point dimension
     * @param dataSize          number of data points
     * @param workers           number of workers for the sub-graph job
     * @param dataFlowJobConfig shared dataflow job configuration
     * @return the non-iterative sub-graph job "datapointsTG"
     */
    public static DataFlowGraph generateFirstJob(Config config, int parallelism,
                                                 String dataDirectory, int dimension, int dataSize,
                                                 int workers, DataFlowJobConfig dataFlowJobConfig) {
        DataObjectSource source = new DataObjectSource("direct", dataDirectory);
        KMeansDataObjectCompute compute =
                new KMeansDataObjectCompute("direct", dataSize, parallelism, dimension);
        KMeansDataObjectDirectSink sink = new KMeansDataObjectDirectSink(POINTS);
        ComputeGraphBuilder builder = ComputeGraphBuilder.newBuilder(config);
        builder.addSource("datapointsource", source, parallelism);
        ComputeConnection computeConnection =
                builder.addCompute("datapointcompute", compute, parallelism);
        ComputeConnection sinkConnection = builder.addCompute("datapointsink", sink, parallelism);
        computeConnection.direct("datapointsource").viaEdge("direct")
                .withDataType(MessageTypes.OBJECT);
        sinkConnection.direct("datapointcompute").viaEdge("direct")
                .withDataType(MessageTypes.OBJECT);
        builder.setMode(OperationMode.BATCH);
        builder.setTaskGraphName("datapointsTG");
        return DataFlowGraph.newSubGraphJob("datapointsTG", builder.build())
                .setWorkers(workers)
                .addDataFlowJobConfig(dataFlowJobConfig)
                .setGraphType("non-iterative");
    }

    /**
     * Builds the graph that loads the initial centroids and publishes them as {@code "centroids"}.
     *
     * @param config            job configuration
     * @param parallelism       parallelism of every task in the graph
     * @param centroidDirectory directory to read the centroids from
     * @param dimension         point dimension
     * @param centroidSize      number of centroids
     * @param workers           number of workers for the sub-graph job
     * @param dataFlowJobConfig shared dataflow job configuration
     * @return the non-iterative sub-graph job "centroidTG"
     */
    public static DataFlowGraph generateSecondJob(Config config, int parallelism,
                                                  String centroidDirectory, int dimension,
                                                  int centroidSize, int workers,
                                                  DataFlowJobConfig dataFlowJobConfig) {
        DataFileReplicatedReadSource source =
                new DataFileReplicatedReadSource("direct", centroidDirectory);
        KMeansDataObjectCompute compute =
                new KMeansDataObjectCompute("direct", centroidSize, dimension);
        KMeansDataObjectDirectSink sink = new KMeansDataObjectDirectSink(CENTROIDS);
        ComputeGraphBuilder builder = ComputeGraphBuilder.newBuilder(config);
        builder.addSource("centroidsource", source, parallelism);
        ComputeConnection computeConnection =
                builder.addCompute("centroidcompute", compute, parallelism);
        ComputeConnection sinkConnection = builder.addCompute("centroidsink", sink, parallelism);
        computeConnection.direct("centroidsource").viaEdge("direct")
                .withDataType(MessageTypes.OBJECT);
        sinkConnection.direct("centroidcompute").viaEdge("direct")
                .withDataType(MessageTypes.OBJECT);
        builder.setMode(OperationMode.BATCH);
        builder.setTaskGraphName("centroidTG");
        return DataFlowGraph.newSubGraphJob("centroidTG", builder.build())
                .setWorkers(workers)
                .addDataFlowJobConfig(dataFlowJobConfig)
                .setGraphType("non-iterative");
    }

    /**
     * Builds one K-Means iteration graph: {@link KMeansSourceTask} all-reduces into
     * {@link KMeansAllReduceTask}, using {@link CentroidAggregator} as the reduction function.
     *
     * @param config            job configuration
     * @param parallelism       parallelism of every task in the graph
     * @param workers           number of workers for the sub-graph job
     * @param iterations        value passed to {@code setIterations}
     * @param dimension         point dimension
     * @param dataFlowJobConfig shared dataflow job configuration
     * @return the iterative sub-graph job "kmeansTG"
     */
    public static DataFlowGraph generateThirdJob(Config config, int parallelism, int workers,
                                                 int iterations, int dimension,
                                                 DataFlowJobConfig dataFlowJobConfig) {
        KMeansSourceTask source = new KMeansSourceTask(dimension);
        KMeansAllReduceTask sink = new KMeansAllReduceTask();
        ComputeGraphBuilder builder = ComputeGraphBuilder.newBuilder(config);
        builder.addSource("kmeanssource", source, parallelism);
        builder.addCompute("kmeanssink", sink, parallelism)
                .allreduce("kmeanssource")
                .viaEdge(ALL_REDUCE_EDGE)
                .withReductionFunction(new CentroidAggregator())
                .withDataType(MessageTypes.OBJECT);
        builder.setMode(OperationMode.BATCH);
        builder.setTaskGraphName("kmeansTG");
        return DataFlowGraph.newSubGraphJob("kmeansTG", builder.build())
                .setWorkers(workers)
                .addDataFlowJobConfig(dataFlowJobConfig)
                .setGraphType("iterative")
                .setIterations(iterations);
    }
}
