package edu.iu.dsc.tws.examples.ml.svm.job;

import edu.iu.dsc.tws.api.compute.graph.OperationMode;
import edu.iu.dsc.tws.api.resource.Twister2Worker;
import edu.iu.dsc.tws.api.resource.WorkerEnvironment;
import edu.iu.dsc.tws.examples.ml.svm.constant.Constants;
import edu.iu.dsc.tws.examples.ml.svm.exceptions.MatrixMultiplicationException;
import edu.iu.dsc.tws.examples.ml.svm.math.Matrix;
import edu.iu.dsc.tws.examples.ml.svm.tset.AccuracyAverager;
import edu.iu.dsc.tws.examples.ml.svm.tset.DataLoadingTask;
import edu.iu.dsc.tws.examples.ml.svm.tset.SvmTestMap;
import edu.iu.dsc.tws.examples.ml.svm.tset.SvmTrainMap;
import edu.iu.dsc.tws.examples.ml.svm.tset.WeightVectorAverager;
import edu.iu.dsc.tws.examples.ml.svm.tset.WeightVectorLoad;
import edu.iu.dsc.tws.examples.ml.svm.util.BinaryBatchModel;
import edu.iu.dsc.tws.examples.ml.svm.util.DataUtils;
import edu.iu.dsc.tws.examples.ml.svm.util.ResultsSaver;
import edu.iu.dsc.tws.examples.ml.svm.util.SVMJobParameters;
import edu.iu.dsc.tws.tset.env.BatchEnvironment;
import edu.iu.dsc.tws.tset.env.TSetEnvironment;
import edu.iu.dsc.tws.tset.sets.batch.CachedTSet;
import edu.iu.dsc.tws.tset.sets.batch.ComputeTSet;
import java.io.IOException;
import java.io.Serializable;
import java.lang.invoke.SerializedLambda;
import java.util.logging.Level;
import java.util.logging.Logger;

/* loaded from: input_file:edu/iu/dsc/tws/examples/ml/svm/job/SvmSgdTsetRunner.class */
public class SvmSgdTsetRunner implements Twister2Worker, Serializable {
    private static final Logger LOG;
    private OperationMode operationMode;
    private SVMJobParameters svmJobParameters;
    private BinaryBatchModel binaryBatchModel;
    private CachedTSet<double[]> trainedWeightVector;
    private CachedTSet<double[][]> trainingData;
    private CachedTSet<double[][]> testingData;
    static final /* synthetic */ boolean $assertionsDisabled;
    private final int reduceParallelism = 1;
    private int dataStreamerParallelism = 4;
    private int svmComputeParallelism = 4;
    private int features = 10;
    private long dataLoadingTime = 0;
    private long initializingTime = 0;
    private double initializingDTime = 0.0d;
    private long trainingTime = 0;
    private long testingTime = 0;
    private double dataLoadingDTime = 0.0d;
    private double trainingDTime = 0.0d;
    private double testingDTime = 0.0d;
    private double totalTime = 0.0d;
    private double accuracy = 0.0d;
    private boolean debug = false;
    private String experimentName = "";
    private boolean testStatus = false;

    private void executeAll(BatchEnvironment batchEnvironment) {
        initialize(batchEnvironment).loadData(batchEnvironment).train(batchEnvironment).predict(batchEnvironment).summary(batchEnvironment).save(batchEnvironment);
    }

    public void execute(WorkerEnvironment workerEnvironment) {
        executeAll(TSetEnvironment.initBatch(workerEnvironment));
    }

    private void initializeParameters(BatchEnvironment batchEnvironment) {
        this.svmJobParameters = SVMJobParameters.build(batchEnvironment.getConfig());
        this.binaryBatchModel = new BinaryBatchModel();
        this.dataStreamerParallelism = this.svmJobParameters.getParallelism();
        this.experimentName = this.svmJobParameters.getExperimentName();
        this.svmComputeParallelism = this.dataStreamerParallelism;
        this.features = this.svmJobParameters.getFeatures();
        this.binaryBatchModel.setIterations(this.svmJobParameters.getIterations());
        this.binaryBatchModel.setAlpha(this.svmJobParameters.getAlpha());
        this.binaryBatchModel.setFeatures(this.svmJobParameters.getFeatures());
        this.binaryBatchModel.setSamples(this.svmJobParameters.getSamples());
        this.binaryBatchModel.setW(DataUtils.seedDoubleArray(this.svmJobParameters.getFeatures()));
        LOG.info(this.binaryBatchModel.toString());
    }

    private CachedTSet<double[][]> loadTrainingData(BatchEnvironment batchEnvironment) {
        return batchEnvironment.createSource(new DataLoadingTask(this.binaryBatchModel, this.svmJobParameters, "train"), this.dataStreamerParallelism).setName("trainingDataSource").cache();
    }

    private CachedTSet<double[][]> loadTestingData(BatchEnvironment batchEnvironment) {
        return batchEnvironment.createSource(new DataLoadingTask(this.binaryBatchModel, this.svmJobParameters, "test"), this.dataStreamerParallelism).setName("testingDataSource").cache();
    }

    private CachedTSet<double[]> loadWeightVector(BatchEnvironment batchEnvironment) {
        return batchEnvironment.createSource(new WeightVectorLoad(this.binaryBatchModel, this.svmJobParameters), this.dataStreamerParallelism).setName("weightVectorSource").cache();
    }

    private void executeTraining(BatchEnvironment batchEnvironment) {
        this.binaryBatchModel.setW((double[]) this.trainedWeightVector.getData().get(0));
        for (int i = 0; i < this.svmJobParameters.getIterations(); i++) {
            LOG.info(String.format("Iteration %d", Integer.valueOf(i)));
            ComputeTSet map = this.trainingData.direct().map(new SvmTrainMap(this.binaryBatchModel, this.svmJobParameters));
            map.addInput(Constants.SimpleGraphConfig.INPUT_WEIGHT_VECTOR, this.trainedWeightVector);
            this.trainedWeightVector = map.allReduce((dArr, dArr2) -> {
                double[] dArr = new double[dArr.length];
                try {
                    dArr = Matrix.add(dArr, dArr2);
                } catch (MatrixMultiplicationException e) {
                    e.printStackTrace();
                }
                return dArr;
            }).map(new WeightVectorAverager(this.dataStreamerParallelism)).cache();
        }
    }

    private void executeSummary(BatchEnvironment batchEnvironment) {
        if (batchEnvironment.getWorkerID() == 0) {
            generateSummary();
        }
    }

    private void executePredict(BatchEnvironment batchEnvironment) {
        if (!$assertionsDisabled && this.trainedWeightVector.getStoredSourceTSet() == null) {
            throw new AssertionError("Partition is null");
        }
        this.binaryBatchModel.setW((double[]) this.trainedWeightVector.getData().get(0));
        this.accuracy = ((Double) this.testingData.direct().map(new SvmTestMap(this.binaryBatchModel, this.svmJobParameters)).reduce((v0, v1) -> {
            return Double.sum(v0, v1);
        }).map(new AccuracyAverager(this.svmJobParameters.getParallelism())).cache().getData().get(0)).doubleValue();
        LOG.info(String.format("Training Accuracy : %f ", Double.valueOf(this.accuracy)));
    }

    private SvmSgdTsetRunner initialize(BatchEnvironment batchEnvironment) {
        long nanoTime = System.nanoTime();
        initializeParameters(batchEnvironment);
        this.initializingTime = System.nanoTime() - nanoTime;
        return this;
    }

    private SvmSgdTsetRunner train(BatchEnvironment batchEnvironment) {
        long nanoTime = System.nanoTime();
        executeTraining(batchEnvironment);
        this.trainingTime = System.nanoTime() - nanoTime;
        return this;
    }

    private SvmSgdTsetRunner predict(BatchEnvironment batchEnvironment) {
        long nanoTime = System.nanoTime();
        executePredict(batchEnvironment);
        this.testingTime = System.nanoTime() - nanoTime;
        return this;
    }

    private SvmSgdTsetRunner summary(BatchEnvironment batchEnvironment) {
        executeSummary(batchEnvironment);
        return this;
    }

    private SvmSgdTsetRunner loadData(BatchEnvironment batchEnvironment) {
        long nanoTime = System.nanoTime();
        this.trainingData = loadTrainingData(batchEnvironment);
        this.testingData = loadTestingData(batchEnvironment);
        this.trainedWeightVector = loadWeightVector(batchEnvironment);
        this.dataLoadingTime = System.nanoTime() - nanoTime;
        return this;
    }

    private SvmSgdTsetRunner save(BatchEnvironment batchEnvironment) {
        try {
            saveResults(batchEnvironment);
        } catch (IOException e) {
            e.printStackTrace();
        }
        return this;
    }

    private void saveResults(BatchEnvironment batchEnvironment) throws IOException {
        new ResultsSaver(this.trainingTime, this.testingTime, this.dataLoadingTime, this.dataLoadingTime + this.trainingTime + this.testingTime, this.svmJobParameters, "itr-tset").save();
    }

    private void generateSummary() {
        convert2Seconds();
        this.totalTime = this.initializingDTime + this.dataLoadingDTime + this.trainingDTime + this.testingDTime;
        LOG.info(String.format((((((((((((("\n\n======================================================================================\n") + "\t\t\tIterative SVM Task Summary : [" + this.experimentName + "]\n") + "======================================================================================\n") + "Training Dataset [" + this.svmJobParameters.getTrainingDataDir() + "] \n") + "Testing  Dataset [" + this.svmJobParameters.getTestingDataDir() + "] \n") + "Total Memory [ " + (Runtime.getRuntime().totalMemory() / 1048576.0d) + " MB] \n") + "Maximum Memory [ " + (Runtime.getRuntime().totalMemory() / 1048576.0d) + " MB] \n") + "Data Loading Time (Training + Testing) \t\t\t\t= " + String.format("%3.9f", Double.valueOf(this.dataLoadingDTime)) + "  s \n") + "Training Time \t\t\t\t\t\t\t= " + String.format("%3.9f", Double.valueOf(this.trainingDTime)) + "  s \n") + "Testing Time  \t\t\t\t\t\t\t= " + String.format("%3.9f", Double.valueOf(this.testingDTime)) + "  s \n") + "Total Time (Data Loading Time + Training Time + Testing Time) \t=" + String.format(" %.9f", Double.valueOf(this.totalTime)) + "  s \n") + String.format("Accuracy of the Trained Model \t\t\t\t\t= %2.9f", Double.valueOf(this.accuracy)) + " %%\n") + "======================================================================================\n", new Object[0]));
    }

    private void convert2Seconds() {
        this.initializingDTime = this.initializingTime / 1.0E9d;
        this.dataLoadingDTime = this.dataLoadingTime / 1.0E9d;
        this.trainingDTime = this.trainingTime / 1.0E9d;
        this.testingDTime = this.testingTime / 1.0E9d;
    }

    private static /* synthetic */ Object $deserializeLambda$(SerializedLambda serializedLambda) {
        String implMethodName = serializedLambda.getImplMethodName();
        boolean z = -1;
        switch (implMethodName.hashCode()) {
            case -732742003:
                if (implMethodName.equals("lambda$executeTraining$efd8c19d$1")) {
                    z = true;
                    break;
                }
                break;
            case 114251:
                if (implMethodName.equals("sum")) {
                    z = false;
                    break;
                }
                break;
        }
        switch (z) {
            case false:
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("edu/iu/dsc/tws/api/tset/fn/ReduceFunc") && serializedLambda.getFunctionalInterfaceMethodName().equals("reduce") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("java/lang/Double") && serializedLambda.getImplMethodSignature().equals("(DD)D")) {
                    return (v0, v1) -> {
                        return Double.sum(v0, v1);
                    };
                }
                break;
            case true:
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("edu/iu/dsc/tws/api/tset/fn/ReduceFunc") && serializedLambda.getFunctionalInterfaceMethodName().equals("reduce") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("edu/iu/dsc/tws/examples/ml/svm/job/SvmSgdTsetRunner") && serializedLambda.getImplMethodSignature().equals("([D[D)[D")) {
                    return (dArr, dArr2) -> {
                        double[] dArr = new double[dArr.length];
                        try {
                            dArr = Matrix.add(dArr, dArr2);
                        } catch (MatrixMultiplicationException e) {
                            e.printStackTrace();
                        }
                        return dArr;
                    };
                }
                break;
        }
        throw new IllegalArgumentException("Invalid lambda deserialization");
    }

    static {
        $assertionsDisabled = !SvmSgdTsetRunner.class.desiredAssertionStatus();
        LOG = Logger.getLogger(SvmSgdTsetRunner.class.getName());
    }
}
