package edu.iu.dsc.tws.examples.checkpointing;

import edu.iu.dsc.tws.api.JobConfig;
import edu.iu.dsc.tws.api.Twister2Job;
import edu.iu.dsc.tws.api.checkpointing.Snapshot;
import edu.iu.dsc.tws.api.comms.messaging.types.MessageTypes;
import edu.iu.dsc.tws.api.compute.IMessage;
import edu.iu.dsc.tws.api.compute.TaskContext;
import edu.iu.dsc.tws.api.compute.graph.OperationMode;
import edu.iu.dsc.tws.api.compute.nodes.BaseCompute;
import edu.iu.dsc.tws.api.compute.nodes.BaseSource;
import edu.iu.dsc.tws.api.config.Config;
import edu.iu.dsc.tws.api.resource.Twister2Worker;
import edu.iu.dsc.tws.api.resource.WorkerEnvironment;
import edu.iu.dsc.tws.checkpointing.task.CheckpointableTask;
import edu.iu.dsc.tws.examples.batch.cdfw.CDFConstants;
import edu.iu.dsc.tws.rsched.core.ResourceAllocator;
import edu.iu.dsc.tws.rsched.job.Twister2Submitter;
import edu.iu.dsc.tws.task.ComputeEnvironment;
import edu.iu.dsc.tws.task.impl.ComputeGraphBuilder;
import java.util.HashMap;
import java.util.logging.Logger;

/* loaded from: input_file:edu/iu/dsc/tws/examples/checkpointing/CheckpointingTaskExample.class */
public class CheckpointingTaskExample implements Twister2Worker {
    /** Logger named after this class; the nested tasks use it to report the counter value they restore. */
    private static final Logger LOG = Logger.getLogger(CheckpointingTaskExample.class.getName());

    /* loaded from: input_file:edu/iu/dsc/tws/examples/checkpointing/CheckpointingTaskExample$ComputeTask.class */
    /**
     * Middle stage of the example graph: remembers the integer carried by the most recent
     * message and forwards that same value on the {@code "c-si"} edge.
     *
     * <p>The only state it checkpoints is {@code count}, kept in the snapshot under the key
     * {@code "count"} and packed with the integer data packer.
     */
    public static class ComputeTask extends BaseCompute<Integer> implements CheckpointableTask {
        /** Snapshot key under which {@code count} is saved and restored. */
        private static final String COUNT_KEY = "count";

        /** Name of the edge this task writes its output to. */
        private static final String OUT_EDGE = "c-si";

        /** Payload of the last processed message; 0 until a message arrives or a snapshot is restored. */
        private int count;

        /**
         * Stores the message payload as the current count and forwards it downstream.
         *
         * @param iMessage incoming message whose content becomes the new {@code count}
         * @return always {@code true}
         */
        public boolean execute(IMessage<Integer> iMessage) {
            final int received = (Integer) iMessage.getContent();
            count = received;
            context.write(OUT_EDGE, received);
            return true;
        }

        /**
         * Declares how the {@code "count"} entry is packed: with the integer data packer.
         *
         * @param snapshot snapshot on which the packer is registered
         */
        public void initSnapshot(Snapshot snapshot) {
            snapshot.setPacker(COUNT_KEY, MessageTypes.INTEGER.getDataPacker());
        }

        /**
         * Copies the current count into the snapshot.
         *
         * @param snapshot snapshot that receives the {@code "count"} entry
         */
        public void takeSnapshot(Snapshot snapshot) {
            snapshot.setValue(COUNT_KEY, count);
        }

        /**
         * Reloads the count from the snapshot, falling back to 0 when no value is stored,
         * and logs the restored value.
         *
         * @param snapshot snapshot to read the {@code "count"} entry from
         */
        public void restoreSnapshot(Snapshot snapshot) {
            final Object saved = snapshot.getOrDefault(COUNT_KEY, 0);
            count = (Integer) saved;
            CheckpointingTaskExample.LOG.info("Restored compute to  " + count);
        }
    }

    /* loaded from: input_file:edu/iu/dsc/tws/examples/checkpointing/CheckpointingTaskExample$SinkTask.class */
    /**
     * Final stage of the example graph: keeps the integer carried by the most recent message
     * and emits nothing further.
     *
     * <p>The only state it checkpoints is {@code count}, kept in the snapshot under the key
     * {@code "count"} and packed with the integer data packer.
     */
    public static class SinkTask extends BaseCompute<Integer> implements CheckpointableTask {
        /** Snapshot key under which {@code count} is saved and restored. */
        private static final String COUNT_KEY = "count";

        /** Payload of the last processed message; 0 until a message arrives or a snapshot is restored. */
        private int count;

        /**
         * Declares how the {@code "count"} entry is packed: with the integer data packer.
         *
         * @param snapshot snapshot on which the packer is registered
         */
        public void initSnapshot(Snapshot snapshot) {
            snapshot.setPacker(COUNT_KEY, MessageTypes.INTEGER.getDataPacker());
        }

        /**
         * Copies the current count into the snapshot.
         *
         * @param snapshot snapshot that receives the {@code "count"} entry
         */
        public void takeSnapshot(Snapshot snapshot) {
            snapshot.setValue(COUNT_KEY, count);
        }

        /**
         * Reloads the count from the snapshot, falling back to 0 when no value is stored,
         * and logs the restored value.
         *
         * @param snapshot snapshot to read the {@code "count"} entry from
         */
        public void restoreSnapshot(Snapshot snapshot) {
            final Object saved = snapshot.getOrDefault(COUNT_KEY, 0);
            count = (Integer) saved;
            CheckpointingTaskExample.LOG.info("Restored sinks to  " + count);
        }

        /**
         * Stores the message payload as the current count.
         *
         * @param iMessage incoming message whose content becomes the new {@code count}
         * @return always {@code true}
         */
        public boolean execute(IMessage<Integer> iMessage) {
            final int received = (Integer) iMessage.getContent();
            count = received;
            return true;
        }
    }

    /* loaded from: input_file:edu/iu/dsc/tws/examples/checkpointing/CheckpointingTaskExample$SourceTask.class */
    /**
     * First stage of the example graph: each call to {@link #execute()} writes the current
     * counter value on the {@code "so-c"} edge and then advances the counter by one.
     *
     * <p>The only state it checkpoints is {@code count}, kept in the snapshot under the key
     * {@code "count"} and packed with the integer data packer.
     */
    public static class SourceTask extends BaseSource implements CheckpointableTask {
        /** Snapshot key under which {@code count} is saved and restored. */
        private static final String COUNT_KEY = "count";

        /** Name of the edge this task writes its output to. */
        private static final String OUT_EDGE = "so-c";

        /** Next value to emit; starts at 0 unless a snapshot restores a different value. */
        private int count;

        /**
         * Declares how the {@code "count"} entry is packed: with the integer data packer.
         *
         * @param snapshot snapshot on which the packer is registered
         */
        public void initSnapshot(Snapshot snapshot) {
            snapshot.setPacker(COUNT_KEY, MessageTypes.INTEGER.getDataPacker());
        }

        /**
         * Copies the current count into the snapshot.
         *
         * @param snapshot snapshot that receives the {@code "count"} entry
         */
        public void takeSnapshot(Snapshot snapshot) {
            snapshot.setValue(COUNT_KEY, count);
        }

        /**
         * Reloads the count from the snapshot, falling back to 0 when no value is stored,
         * and logs the restored value.
         *
         * @param snapshot snapshot to read the {@code "count"} entry from
         */
        public void restoreSnapshot(Snapshot snapshot) {
            final Object saved = snapshot.getOrDefault(COUNT_KEY, 0);
            count = (Integer) saved;
            CheckpointingTaskExample.LOG.info("Restored source to  " + count);
        }

        /** Emits the current count on the output edge, then increments it. */
        public void execute() {
            // Post-increment: the value written is the count before it is advanced.
            context.write(OUT_EDGE, count++);
        }
    }

    /**
     * Builds and runs the example streaming graph: {@code source -> compute -> sink}.
     *
     * <p>All three nodes use the same parallelism, read from the worker configuration under
     * {@link CDFConstants#ARGS_PARALLELISM_VALUE} (default 1). The nodes are joined by direct
     * links carrying {@link MessageTypes#INTEGER} over the edges {@code "so-c"} and
     * {@code "c-si"}. The compute environment is closed after execution returns.
     *
     * @param workerEnvironment environment used to initialise the compute environment and to
     *                          read the configuration
     */
    public void execute(WorkerEnvironment workerEnvironment) {
        final String sourceNode = "source";
        final String computeNode = "compute";
        final String sinkNode = "sink";
        final String sourceToCompute = "so-c";
        final String computeToSink = "c-si";

        ComputeEnvironment computeEnv = ComputeEnvironment.init(workerEnvironment);
        ComputeGraphBuilder graph = computeEnv.newTaskGraph(OperationMode.STREAMING);

        int parallelism = workerEnvironment.getConfig()
                .getIntegerValue(CDFConstants.ARGS_PARALLELISM_VALUE, 1);

        graph.addSource(sourceNode, new SourceTask(), parallelism);

        graph.addCompute(computeNode, new ComputeTask(), parallelism)
                .direct(sourceNode)
                .viaEdge(sourceToCompute)
                .withDataType(MessageTypes.INTEGER);

        graph.addCompute(sinkNode, new SinkTask(), parallelism)
                .direct(computeNode)
                .viaEdge(computeToSink)
                .withDataType(MessageTypes.INTEGER);

        computeEnv.buildAndExecute(graph);
        computeEnv.close();
    }

    /**
     * Submits the example as a job named {@code "hello-checkpointing-job"}.
     *
     * <p>The parallelism defaults to 4 and is overridden only when exactly one command-line
     * argument is given; that argument is parsed as an integer. The chosen value is stored in
     * the job configuration under {@link CDFConstants#ARGS_PARALLELISM_VALUE} and is also used
     * as the instance count of the requested compute resource (1.0 CPU, 1024 RAM units).
     *
     * @param strArr command-line arguments; optionally a single integer parallelism
     * @throws NumberFormatException if the single argument is not a valid integer
     */
    public static void main(String[] strArr) {
        // Any argument count other than one keeps the default of 4.
        final int parallelism = strArr.length == 1 ? Integer.parseInt(strArr[0]) : 4;

        // The base configuration is loaded before the job is assembled.
        Config config = ResourceAllocator.loadConfig(new HashMap<>());

        JobConfig jobConfig = new JobConfig();
        jobConfig.put(CDFConstants.ARGS_PARALLELISM_VALUE, parallelism);

        Twister2Job job = Twister2Job.newBuilder()
                .setJobName("hello-checkpointing-job")
                .setWorkerClass(CheckpointingTaskExample.class)
                .addComputeResource(1.0, 1024, parallelism)
                .setConfig(jobConfig)
                .build();

        Twister2Submitter.submitJob(job, config);
    }
}
