package edu.iu.dsc.tws.examples.ml.svm.streamer;

import edu.iu.dsc.tws.api.compute.graph.OperationMode;
import edu.iu.dsc.tws.api.compute.nodes.BaseSource;
import edu.iu.dsc.tws.api.dataset.DataObject;
import edu.iu.dsc.tws.examples.ml.svm.constant.Constants;
import edu.iu.dsc.tws.examples.ml.svm.exceptions.MatrixMultiplicationException;
import edu.iu.dsc.tws.examples.ml.svm.integration.test.IReceptor;
import edu.iu.dsc.tws.examples.ml.svm.sgd.pegasos.PegasosSgdSvm;
import edu.iu.dsc.tws.examples.ml.svm.test.Predict;
import edu.iu.dsc.tws.examples.ml.svm.util.BinaryBatchModel;
import edu.iu.dsc.tws.examples.ml.svm.util.DataUtils;
import java.util.Arrays;
import java.util.Objects;
import java.util.logging.Level;
import java.util.logging.Logger;

/* loaded from: input_file:edu/iu/dsc/tws/examples/ml/svm/streamer/IterativePredictionDataStreamer.class */
public class IterativePredictionDataStreamer extends BaseSource implements IReceptor<Double> {
    private static final long serialVersionUID = -5619263102396811849L;
    private static final Logger LOG = Logger.getLogger(IterativePredictionDataStreamer.class.getName());
    private final double[] labels;
    private int features;
    private OperationMode operationMode;
    private boolean isDummy;
    private BinaryBatchModel binaryBatchModel;
    private DataObject<double[][]> dataPointsObject;
    private DataObject<double[]> weightVectorObject;
    private double[][] datapoints;
    private double[] weightVector;
    private PegasosSgdSvm pegasosSgdSvm;
    private boolean debug;
    private double finalAccuracy;

    public IterativePredictionDataStreamer(OperationMode operationMode) {
        this.labels = new double[]{-1.0d, 1.0d};
        this.features = 10;
        this.isDummy = false;
        this.dataPointsObject = null;
        this.weightVectorObject = null;
        this.datapoints = null;
        this.weightVector = null;
        this.pegasosSgdSvm = null;
        this.debug = false;
        this.finalAccuracy = 0.0d;
        this.operationMode = operationMode;
    }

    public IterativePredictionDataStreamer(int i, OperationMode operationMode, boolean z, BinaryBatchModel binaryBatchModel) {
        this.labels = new double[]{-1.0d, 1.0d};
        this.features = 10;
        this.isDummy = false;
        this.dataPointsObject = null;
        this.weightVectorObject = null;
        this.datapoints = null;
        this.weightVector = null;
        this.pegasosSgdSvm = null;
        this.debug = false;
        this.finalAccuracy = 0.0d;
        this.features = i;
        this.operationMode = operationMode;
        this.isDummy = z;
        this.binaryBatchModel = binaryBatchModel;
    }

    /* JADX WARN: Multi-variable type inference failed */
    @Override // edu.iu.dsc.tws.examples.ml.svm.integration.test.IReceptor
    public void add(String str, DataObject<?> dataObject) {
        if ("test_data".equals(str)) {
            this.dataPointsObject = dataObject;
        }
        if (Constants.SimpleGraphConfig.INPUT_WEIGHT_VECTOR.equals(str)) {
            this.weightVectorObject = dataObject;
        }
    }

    public void execute() {
        if (this.isDummy) {
            LOG.info(String.format("Dummy Data Training Doesn't support prediction", new Object[0]));
        } else {
            realDataStreamer();
        }
    }

    public void realDataStreamer() {
        if (this.operationMode.equals(OperationMode.BATCH)) {
            getData();
            initializeBatchModel();
            compute();
        }
        if (this.operationMode.equals(OperationMode.STREAMING)) {
        }
    }

    public void getData() {
        this.datapoints = (double[][]) this.dataPointsObject.getPartition(this.context.taskIndex()).getConsumer().next();
        this.weightVector = (double[]) this.weightVectorObject.getPartition(this.context.taskIndex()).getConsumer().next();
        if (this.debug) {
            LOG.info(String.format("Recieved Input Data : %s ", this.datapoints.getClass().getName()));
        }
    }

    public void initializeBatchModel() {
        initializeBinaryModel(this.datapoints);
        this.binaryBatchModel.setW(this.weightVector);
        this.pegasosSgdSvm = new PegasosSgdSvm(this.binaryBatchModel.getW(), this.binaryBatchModel.getX(), this.binaryBatchModel.getY(), this.binaryBatchModel.getAlpha(), this.binaryBatchModel.getIterations(), this.binaryBatchModel.getFeatures());
    }

    public void initializeBinaryModel(double[][] dArr) {
        if (this.binaryBatchModel == null) {
            throw new NullPointerException("Binary Batch Model is Null !!!");
        }
        if (this.debug) {
            LOG.info("Binary Batch Model Before Updated : " + this.binaryBatchModel.toString());
        }
        this.binaryBatchModel = DataUtils.updateModelData(this.binaryBatchModel, dArr);
        if (this.debug) {
            LOG.info("Binary Batch Model After Updated : " + this.binaryBatchModel.toString());
            LOG.info(String.format("Updated Data [%d,%d] ", Integer.valueOf(this.binaryBatchModel.getX().length), Integer.valueOf(this.binaryBatchModel.getX()[0].length)));
        }
    }

    public void compute() {
        double[][] x = this.binaryBatchModel.getX();
        double[] w = this.binaryBatchModel.getW();
        this.binaryBatchModel.getY();
        double d = 0.0d;
        try {
            d = new Predict(this.binaryBatchModel.getX(), this.binaryBatchModel.getY(), w).predict();
        } catch (MatrixMultiplicationException e) {
            e.printStackTrace();
        }
        if (this.debug) {
            LOG.info(String.format("Accuracy : %f, Context Id : %d, Weight Vector : %s, Data Size : %d", Double.valueOf(d), Integer.valueOf(this.context.taskIndex()), Arrays.toString(w), Integer.valueOf(x.length)));
        }
        this.finalAccuracy = d / this.context.getParallelism();
        this.context.write(Constants.SimpleGraphConfig.PREDICTION_EDGE, Double.valueOf(this.finalAccuracy));
        this.context.end(Constants.SimpleGraphConfig.PREDICTION_EDGE);
    }
}
