package edu.iu.dsc.tws.examples.ml.svm.streamer;

import edu.iu.dsc.tws.api.compute.TaskContext;
import edu.iu.dsc.tws.api.compute.graph.OperationMode;
import edu.iu.dsc.tws.api.compute.nodes.BaseSource;
import edu.iu.dsc.tws.api.dataset.DataObject;
import edu.iu.dsc.tws.examples.ml.svm.constant.Constants;
import edu.iu.dsc.tws.examples.ml.svm.exceptions.InputDataFormatException;
import edu.iu.dsc.tws.examples.ml.svm.integration.test.IReceptor;
import edu.iu.dsc.tws.examples.ml.svm.util.BinaryBatchModel;
import edu.iu.dsc.tws.examples.ml.svm.util.DataUtils;
import java.util.Random;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Source task that feeds SVM training data into the task graph.
 *
 * <p>Two modes are supported:
 * <ul>
 *   <li><b>Real data</b> ({@code isDummy == false}): data points and a weight vector are received
 *   through {@link #add(String, DataObject)}. On the first {@link #execute()} call this task's
 *   partition of each is loaded. In {@link OperationMode#STREAMING} mode, each call then writes
 *   one data row to {@code Constants.SimpleGraphConfig.STREAMING_EDGE}. Other modes only log that
 *   they are unsupported.</li>
 *   <li><b>Dummy data</b> ({@code isDummy == true}): samples are generated from the supplied
 *   {@link BinaryBatchModel} and written to {@code Constants.SimpleGraphConfig.DATA_EDGE}. In
 *   streaming mode, each call writes one labelled sample. In batch mode, each call writes a whole
 *   generated batch and then ends the edge.</li>
 * </ul>
 */
public class IterativeStreamingDataStreamer extends BaseSource implements IReceptor<double[][]> {
    private static final long serialVersionUID = -5845248990437663713L;
    private static final Logger LOG = Logger.getLogger(IterativeStreamingDataStreamer.class.getName());

    /** Shared generator for picking dummy labels; static so it is not part of serialized state. */
    private static final Random RANDOM = new Random();

    /** Candidate class labels attached to dummy streaming samples. */
    private final double[] labels = {-1.0d, 1.0d};
    private int features = 10;
    private final OperationMode operationMode;
    private boolean isDummy = false;
    private BinaryBatchModel binaryBatchModel;
    private DataObject<double[][]> dataPointsObject = null;
    private DataObject<double[]> weightVectorObject = null;
    private double[][] datapoints = null;
    private double[] weightVector = null;
    private boolean debug = false;
    /** Index of the next row of {@link #datapoints} to stream. */
    private int count = 0;
    /** Set once this task's partitions have been pulled from the received data objects. */
    private boolean isDataLoaded = false;

    /**
     * Creates a real-data streamer with the default feature count (10).
     *
     * @param operationMode streaming or batch execution mode of the graph
     */
    public IterativeStreamingDataStreamer(OperationMode operationMode) {
        this.operationMode = operationMode;
    }

    /**
     * Creates a real-data streamer with an explicit feature count.
     *
     * @param features      number of features per data point
     * @param operationMode streaming or batch execution mode of the graph
     */
    public IterativeStreamingDataStreamer(int features, OperationMode operationMode) {
        this(operationMode);
        this.features = features;
    }

    /**
     * Creates a streamer that can optionally emit generated dummy data.
     *
     * @param features         number of features per data point
     * @param operationMode    streaming or batch execution mode of the graph
     * @param dummy            {@code true} to emit generated data instead of received input
     * @param binaryBatchModel supplies feature and sample counts for dummy data generation
     */
    public IterativeStreamingDataStreamer(int features, OperationMode operationMode, boolean dummy,
                                          BinaryBatchModel binaryBatchModel) {
        this(features, operationMode);
        this.isDummy = dummy;
        this.binaryBatchModel = binaryBatchModel;
    }

    /**
     * Emits the next unit of data according to the configured mode.
     *
     * <p>A malformed dummy data point is logged and skipped rather than propagated, which keeps
     * the original best-effort behaviour.
     */
    public void execute() {
        if (this.isDummy) {
            try {
                dummyDataStreamer();
            } catch (InputDataFormatException e) {
                LOG.log(Level.SEVERE, "Failed to generate a dummy data point", e);
            }
        } else {
            loadData();
            realDataStreamer();
        }
    }

    /**
     * Receives an input data object from an upstream task.
     *
     * @param name       edge/input name; only {@code INPUT_DATA} and {@code INPUT_WEIGHT_VECTOR}
     *                   are consumed, all other names are ignored
     * @param dataObject the received data object
     */
    @Override
    @SuppressWarnings("unchecked") // payload types are fixed by the graph wiring, not checkable here
    public void add(String name, DataObject<?> dataObject) {
        if (this.debug) {
            LOG.log(Level.INFO, String.format("Received input: %s ", name));
        }
        if (Constants.SimpleGraphConfig.INPUT_DATA.equals(name)) {
            // NOTE(review): assumes the INPUT_DATA producer emits double[][] partitions.
            this.dataPointsObject = (DataObject<double[][]>) dataObject;
        }
        if (Constants.SimpleGraphConfig.INPUT_WEIGHT_VECTOR.equals(name)) {
            // NOTE(review): assumes the INPUT_WEIGHT_VECTOR producer emits double[] partitions.
            this.weightVectorObject = (DataObject<double[]>) dataObject;
        }
    }

    /** Pulls this task's partition of the received data points into {@link #datapoints}. */
    private void prepareDataPoints() {
        if (this.dataPointsObject == null) {
            throw new IllegalStateException(
                "No data object received for input " + Constants.SimpleGraphConfig.INPUT_DATA);
        }
        final int taskIndex = this.context.taskIndex();
        this.datapoints = (double[][]) this.dataPointsObject.getPartition(taskIndex).getConsumer().next();
        if (this.debug) {
            LOG.info(String.format("Recieved Input Data : %s ", this.datapoints.getClass().getName()));
        }
        LOG.info(String.format("Data Point TaskIndex[%d], Size : %d ", taskIndex, this.datapoints.length));
    }

    /** Pulls this task's partition of the received weight vector into {@link #weightVector}. */
    private void prepareWeightVector() {
        if (this.weightVectorObject == null) {
            throw new IllegalStateException(
                "No data object received for input " + Constants.SimpleGraphConfig.INPUT_WEIGHT_VECTOR);
        }
        final int taskIndex = this.context.taskIndex();
        this.weightVector = (double[]) this.weightVectorObject.getPartition(taskIndex).getConsumer().next();
        LOG.info(String.format("Weight Vector TaskIndex[%d], Size : %d ", taskIndex, this.weightVector.length));
    }

    /**
     * Writes generated data to {@code DATA_EDGE}.
     *
     * <p>In streaming mode, one randomly labelled sample of {@code features + 1} values is written.
     * In batch mode, a full generated batch is written and the edge is ended.
     *
     * @throws InputDataFormatException if the generated sample does not have
     *                                  {@code features + 1} values
     */
    public void dummyDataStreamer() throws InputDataFormatException {
        final int featureCount = this.binaryBatchModel.getFeatures();
        if (this.operationMode == OperationMode.STREAMING) {
            double label = this.labels[RANDOM.nextInt(this.labels.length)];
            double[] sample = DataUtils.combineLabelAndData(DataUtils.seedDoubleArray(featureCount), label);
            if (sample.length != featureCount + 1) {
                throw new InputDataFormatException(String.format(
                    "Input Data Format Exception : [data length : %d, feature length +1 : %d]",
                    sample.length, featureCount + 1));
            }
            this.context.write(Constants.SimpleGraphConfig.DATA_EDGE, sample);
        } else if (this.operationMode == OperationMode.BATCH) {
            this.context.write(Constants.SimpleGraphConfig.DATA_EDGE,
                DataUtils.generateDummyDataPoints(this.binaryBatchModel.getSamples(), featureCount));
            this.context.end(Constants.SimpleGraphConfig.DATA_EDGE);
        }
    }

    /** Streams received data in streaming mode; other modes are unsupported and only logged. */
    public void realDataStreamer() {
        if (this.operationMode == OperationMode.STREAMING) {
            streamData();
        } else {
            LOG.info("This Data Source only supports for streaming tasks");
        }
    }

    /**
     * Loads this task's data partitions exactly once.
     *
     * <p>Guarding on {@link #isDataLoaded} rather than {@code count == 0} avoids re-consuming the
     * partition on every call when nothing has been streamed yet. This happens in batch mode or
     * with an empty partition.
     */
    private void loadData() {
        if (!this.isDataLoaded) {
            prepareDataPoints();
            prepareWeightVector();
            this.isDataLoaded = true;
        }
    }

    /**
     * Writes the next data row to {@code STREAMING_EDGE} and ends the edge once every row is sent.
     *
     * <p>NOTE(review): once all rows are written, every later call invokes {@code end} again.
     * Confirm that repeated {@code end} calls are harmless for the runtime.
     */
    public void streamData() {
        if (this.count < this.datapoints.length) {
            this.context.write(Constants.SimpleGraphConfig.STREAMING_EDGE, this.datapoints[this.count++]);
        }
        if (this.count == this.datapoints.length) {
            this.context.end(Constants.SimpleGraphConfig.STREAMING_EDGE);
        }
    }
}
