package org.deeplearning4j.spark.impl.paramavg;

import java.util.Collection;
import java.util.Iterator;
import java.util.Random;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.input.PortableDataStream;
import org.apache.spark.storage.StorageLevel;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.spark.api.Repartition;
import org.deeplearning4j.spark.api.RepartitionStrategy;
import org.deeplearning4j.spark.api.TrainingMaster;
import org.deeplearning4j.spark.api.WorkerConfiguration;
import org.deeplearning4j.spark.api.stats.SparkTrainingStats;
import org.deeplearning4j.spark.api.worker.ExecuteWorkerFlatMap;
import org.deeplearning4j.spark.api.worker.ExecuteWorkerMultiDataSetFlatMap;
import org.deeplearning4j.spark.api.worker.ExecuteWorkerPDSFlatMap;
import org.deeplearning4j.spark.api.worker.ExecuteWorkerPDSMDSFlatMap;
import org.deeplearning4j.spark.api.worker.ExecuteWorkerPathFlatMap;
import org.deeplearning4j.spark.api.worker.ExecuteWorkerPathMDSFlatMap;
import org.deeplearning4j.spark.api.worker.NetBroadcastTuple;
import org.deeplearning4j.spark.impl.graph.SparkComputationGraph;
import org.deeplearning4j.spark.impl.graph.dataset.DataSetToMultiDataSetFn;
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer;
import org.deeplearning4j.spark.impl.paramavg.aggregator.ParameterAveragingAggregationTuple;
import org.deeplearning4j.spark.impl.paramavg.aggregator.ParameterAveragingElementAddFunction;
import org.deeplearning4j.spark.impl.paramavg.aggregator.ParameterAveragingElementCombineFunction;
import org.deeplearning4j.spark.impl.paramavg.stats.ParameterAveragingTrainingMasterStats;
import org.deeplearning4j.spark.util.SparkUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingMaster.class */
public class ParameterAveragingTrainingMaster implements TrainingMaster<ParameterAveragingTrainingResult, ParameterAveragingTrainingWorker> {
    private static final Logger log = LoggerFactory.getLogger(ParameterAveragingTrainingMaster.class);
    // Some PortableDataStream-based training paths coalesce their input down to numWorkers partitions
    // when it has at least COALESCE_THRESHOLD * numWorkers partitions.
    private static final int COALESCE_THRESHOLD = 3;
    // Forwarded to every ParameterAveragingTrainingWorker. NOTE(review): presumably controls whether
    // updater state is returned and averaged along with the parameters -- confirm in the worker.
    private boolean saveUpdater;
    // Number of workers used to size splits and repartitioning. May be null; each executeTraining*
    // method then fills it in from the Spark context's defaultParallelism().
    private Integer numWorkers;
    // Number of examples per object of the training RDD; divisor when computing objects per worker.
    private int rddDataSetNumExamples;
    // Minibatch size forwarded to each worker via WorkerConfiguration.
    private int batchSizePerWorker;
    // Averaging frequency forwarded to each worker; together with batchSizePerWorker it determines
    // how many examples each worker receives per split.
    private int averagingFrequency;
    // Prefetch batch count forwarded to each worker via WorkerConfiguration.
    private int prefetchNumBatches;
    // When true, timing statistics are recorded into 'stats' (which must then be non-null).
    private boolean collectTrainingStats;
    // Statistics helper; created when statistics collection is enabled and null otherwise.
    private ParameterAveragingTrainingMasterStats.ParameterAveragingTrainingMasterStatsHelper stats;
    // Listeners notified after every averaged split. NOTE(review): not assigned anywhere in the visible
    // part of this class -- confirm where it is set (null means no notifications are sent).
    private Collection<IterationListener> listeners;
    // Number of splits processed so far; reported to listeners.
    private int iterationCount;
    // Repartitioning policy and strategy passed to SparkUtils.repartition before each split is trained.
    private Repartition repartition;
    private RepartitionStrategy repartitionStrategy;

    /**
     * Fluent builder for {@link ParameterAveragingTrainingMaster}.
     *
     * <p>Defaults: batchSizePerWorker = 16, averagingFrequency = 5, workerPrefetchNumBatches = 0,
     * repartition = {@link Repartition#Always}, repartitionStrategy = {@link RepartitionStrategy#Balanced},
     * saveUpdater = false. Statistics collection is not configured by the builder; it is disabled on
     * the built instance and can be enabled afterwards with {@code setCollectTrainingStats(true)}.
     */
    public static class Builder {
        private boolean saveUpdater;
        private Integer numWorkers;
        private int rddDataSetNumExamples;
        private int batchSizePerWorker;
        private int averagingFrequency;
        private int prefetchNumBatches;
        private Repartition repartition;
        private RepartitionStrategy repartitionStrategy;

        /**
         * Creates a builder whose number of workers is resolved at training time from the Spark
         * context's default parallelism.
         *
         * @param i number of examples in each object of the training RDD (must be >= 1)
         * @throws IllegalArgumentException if {@code i < 1}
         */
        public Builder(int i) {
            this(null, i);
        }

        /**
         * @param num number of workers, or {@code null} to use the Spark context's default
         *            parallelism at training time
         * @param i   number of examples in each object of the training RDD (must be >= 1)
         * @throws IllegalArgumentException if {@code num} is non-null and < 1, or {@code i < 1}
         */
        public Builder(Integer num, int i) {
            this.batchSizePerWorker = 16;
            this.averagingFrequency = 5;
            this.prefetchNumBatches = 0;
            this.repartition = Repartition.Always;
            this.repartitionStrategy = RepartitionStrategy.Balanced;
            if (num != null && num.intValue() <= 0) {
                throw new IllegalArgumentException("Invalid number of workers: " + num + " (must be >= 1)");
            }
            if (i <= 0) {
                throw new IllegalArgumentException("Invalid rdd data set size: " + i + " (must be >= 1)");
            }
            this.numWorkers = num;
            this.rddDataSetNumExamples = i;
        }

        /**
         * Sets the minibatch size used by each worker.
         *
         * @param i minibatch size (must be >= 1)
         * @throws IllegalArgumentException if {@code i < 1}
         */
        public Builder batchSizePerWorker(int i) {
            // A non-positive minibatch size cannot be trained on and would make the per-split
            // object counts computed by the master zero or negative.
            if (i <= 0) {
                throw new IllegalArgumentException("Invalid input: batch size per worker must be >= 1 (got " + i + ")");
            }
            this.batchSizePerWorker = i;
            return this;
        }

        /**
         * Sets the averaging frequency forwarded to each worker.
         *
         * @param i averaging frequency (must be >= 1)
         * @throws IllegalArgumentException if {@code i < 1}
         */
        public Builder averagingFrequency(int i) {
            if (i <= 0) {
                throw new IllegalArgumentException("Invalid input: averaging frequency must be >= 1 (got " + i + ")");
            }
            this.averagingFrequency = i;
            return this;
        }

        /** Sets the number of batches each worker prefetches. */
        public Builder workerPrefetchNumBatches(int i) {
            this.prefetchNumBatches = i;
            return this;
        }

        /** Sets the saveUpdater flag forwarded to each worker. */
        public Builder saveUpdater(boolean z) {
            this.saveUpdater = z;
            return this;
        }

        /**
         * Sets when training splits are repartitioned.
         * The misspelled name is kept for backward compatibility; prefer {@link #repartitionData}.
         */
        public Builder repartionData(Repartition repartition) {
            this.repartition = repartition;
            return this;
        }

        /** Sets when training splits are repartitioned (correctly spelled alias of {@link #repartionData}). */
        public Builder repartitionData(Repartition repartition) {
            return repartionData(repartition);
        }

        /** Sets the strategy used when training splits are repartitioned. */
        public Builder repartitionStrategy(RepartitionStrategy repartitionStrategy) {
            this.repartitionStrategy = repartitionStrategy;
            return this;
        }

        /** Builds a training master from the current builder settings. */
        public ParameterAveragingTrainingMaster build() {
            return new ParameterAveragingTrainingMaster(this);
        }
    }

    /**
     * Builder-based constructor: copies every setting from {@code builder}.
     * Statistics collection is left disabled (no stats helper is created).
     */
    private ParameterAveragingTrainingMaster(Builder builder) {
        this.repartitionStrategy = builder.repartitionStrategy;
        this.repartition = builder.repartition;
        this.prefetchNumBatches = builder.prefetchNumBatches;
        this.averagingFrequency = builder.averagingFrequency;
        this.batchSizePerWorker = builder.batchSizePerWorker;
        this.rddDataSetNumExamples = builder.rddDataSetNumExamples;
        this.numWorkers = builder.numWorkers;
        this.saveUpdater = builder.saveUpdater;
        this.iterationCount = 0;
    }

    /**
     * Convenience constructor that always repartitions using the Balanced strategy and does not
     * collect training statistics.
     *
     * @param z   saveUpdater flag forwarded to each worker
     * @param num number of workers
     * @param i   number of examples in each object of the training RDD
     * @param i2  minibatch size per worker
     * @param i3  averaging frequency
     * @param i4  number of batches each worker prefetches
     */
    public ParameterAveragingTrainingMaster(boolean z, Integer num, int i, int i2, int i3, int i4) {
        this(z, num, i, i2, i3, i4, Repartition.Always, RepartitionStrategy.Balanced, /* collectTrainingStats */ false);
    }

    /**
     * Creates a training master with explicit settings.
     *
     * @param z                   saveUpdater flag forwarded to each worker
     * @param num                 number of workers, or {@code null} to resolve it at training time
     *                            from the Spark context's default parallelism (as the Builder allows)
     * @param i                   number of examples in each object of the training RDD (must be >= 1)
     * @param i2                  minibatch size per worker (must be >= 1)
     * @param i3                  averaging frequency (must be >= 1)
     * @param i4                  number of batches each worker prefetches
     * @param repartition         when training splits are repartitioned
     * @param repartitionStrategy how training splits are repartitioned
     * @param z2                  whether to collect training statistics
     * @throws IllegalArgumentException if any numeric setting is out of range
     */
    public ParameterAveragingTrainingMaster(boolean z, Integer num, int i, int i2, int i3, int i4, Repartition repartition, RepartitionStrategy repartitionStrategy, boolean z2) {
        this.iterationCount = 0;
        // null is a supported "auto" value (resolved lazily in executeTraining*); previously this
        // unboxing threw a NullPointerException for it.
        if (num != null && num.intValue() <= 0) {
            throw new IllegalArgumentException("Invalid number of workers: " + num + " (must be >= 1)");
        }
        if (i <= 0) {
            throw new IllegalArgumentException("Invalid rdd data set size: " + i + " (must be >= 1)");
        }
        if (i2 <= 0) {
            throw new IllegalArgumentException("Invalid batch size per worker: " + i2 + " (must be >= 1)");
        }
        if (i3 <= 0) {
            throw new IllegalArgumentException("Invalid averaging frequency: " + i3 + " (must be >= 1)");
        }
        this.saveUpdater = z;
        this.numWorkers = num;
        this.rddDataSetNumExamples = i;
        this.batchSizePerWorker = i2;
        this.averagingFrequency = i3;
        this.prefetchNumBatches = i4;
        this.collectTrainingStats = z2;
        this.repartition = repartition;
        this.repartitionStrategy = repartitionStrategy;
        if (z2) {
            this.stats = new ParameterAveragingTrainingMasterStats.ParameterAveragingTrainingMasterStatsHelper();
        }
    }

    /**
     * Broadcasts the MultiLayerNetwork's configuration, current parameters and updater state, and
     * returns a new worker bound to that broadcast. Broadcast time is recorded when statistics
     * collection is enabled.
     *
     * @param sparkDl4jMultiLayer wrapper providing the network and the Spark context
     * @return a new worker for the next training split
     */
    @Override // org.deeplearning4j.spark.api.TrainingMaster
    public ParameterAveragingTrainingWorker getWorkerInstance(SparkDl4jMultiLayer sparkDl4jMultiLayer) {
        MultiLayerNetwork net = sparkDl4jMultiLayer.getNetwork();
        NetBroadcastTuple payload = new NetBroadcastTuple(
                net.getLayerWiseConfigurations(), net.params(), net.getUpdater().getStateViewArray());

        if (this.collectTrainingStats) {
            this.stats.logBroadcastStart();
        }
        Broadcast<NetBroadcastTuple> broadcastPayload = sparkDl4jMultiLayer.getSparkContext().broadcast(payload);
        if (this.collectTrainingStats) {
            this.stats.logBroadcastEnd();
        }

        // First argument is false here and true in the ComputationGraph overload.
        // NOTE(review): presumably the "is graph network" flag -- confirm in WorkerConfiguration.
        WorkerConfiguration workerConfig = new WorkerConfiguration(false, this.batchSizePerWorker,
                this.averagingFrequency, this.prefetchNumBatches, this.collectTrainingStats);
        return new ParameterAveragingTrainingWorker(broadcastPayload, this.saveUpdater, workerConfig);
    }

    /**
     * Broadcasts the ComputationGraph's configuration, current parameters and updater state, and
     * returns a new worker bound to that broadcast. Broadcast time is recorded when statistics
     * collection is enabled.
     *
     * @param sparkComputationGraph wrapper providing the graph and the Spark context
     * @return a new worker for the next training split
     */
    @Override // org.deeplearning4j.spark.api.TrainingMaster
    public ParameterAveragingTrainingWorker getWorkerInstance(SparkComputationGraph sparkComputationGraph) {
        ComputationGraph graph = sparkComputationGraph.getNetwork();
        NetBroadcastTuple payload = new NetBroadcastTuple(
                graph.getConfiguration(), graph.params(), graph.getUpdater().getStateViewArray());

        if (this.collectTrainingStats) {
            this.stats.logBroadcastStart();
        }
        Broadcast<NetBroadcastTuple> broadcastPayload = sparkComputationGraph.getSparkContext().broadcast(payload);
        if (this.collectTrainingStats) {
            this.stats.logBroadcastEnd();
        }

        // First argument is true here and false in the MultiLayerNetwork overload.
        // NOTE(review): presumably the "is graph network" flag -- confirm in WorkerConfiguration.
        WorkerConfiguration workerConfig = new WorkerConfiguration(true, this.batchSizePerWorker,
                this.averagingFrequency, this.prefetchNumBatches, this.collectTrainingStats);
        return new ParameterAveragingTrainingWorker(broadcastPayload, this.saveUpdater, workerConfig);
    }

    /**
     * Number of training-RDD objects each worker needs per split: the examples consumed per averaging
     * period divided by the examples held in one RDD object (integer division; may be 0).
     */
    private int numObjectsEachWorker() {
        int examplesPerAveragingPeriod = this.batchSizePerWorker * this.averagingFrequency;
        return examplesPerAveragingPeriod / this.rddDataSetNumExamples;
    }

    /**
     * Number of training-RDD objects that make up one split, i.e. the data consumed by all workers
     * between two averaging steps. When each RDD object holds more than one example, the per-worker
     * count is clamped to at least 1.
     */
    private int getNumDataSetObjectsPerSplit() {
        if (this.rddDataSetNumExamples == 1) {
            return this.numWorkers.intValue() * this.batchSizePerWorker * this.averagingFrequency;
        }
        return Math.max(1, numObjectsEachWorker()) * this.numWorkers.intValue();
    }

    /**
     * Trains the wrapped MultiLayerNetwork on an RDD of DataSet objects using parameter averaging.
     *
     * <p>The RDD is counted and randomly split into splits of roughly
     * {@code getNumDataSetObjectsPerSplit()} objects each. Each split is then trained and averaged in
     * turn. Because the RDD is traversed once for the count and again for training, it is cached in
     * memory for the duration of this call if it has no storage level yet, and unpersisted afterwards.
     * An RDD the caller has already persisted is used as-is and left persisted.
     *
     * @param sparkDl4jMultiLayer network wrapper whose parameters are updated after every split
     * @param javaRDD             training data
     */
    @Override // org.deeplearning4j.spark.api.TrainingMaster
    public void executeTraining(SparkDl4jMultiLayer sparkDl4jMultiLayer, JavaRDD<DataSet> javaRDD) {
        if (this.numWorkers == null) {
            this.numWorkers = sparkDl4jMultiLayer.getSparkContext().defaultParallelism();
        }
        if (this.collectTrainingStats) {
            this.stats.logFitStart();
        }
        // Spark refuses to change an RDD's storage level once one is assigned, so only cache data
        // that is not persisted yet -- and only unpersist what we cached ourselves.
        boolean persistedHere = StorageLevel.NONE().equals(javaRDD.getStorageLevel());
        if (persistedHere) {
            javaRDD.persist(StorageLevel.MEMORY_ONLY());
        }
        try {
            if (this.collectTrainingStats) {
                this.stats.logCountStart();
            }
            long count = javaRDD.count();
            if (this.collectTrainingStats) {
                this.stats.logCountEnd();
            }
            int numDataSetObjectsPerSplit = getNumDataSetObjectsPerSplit();
            if (this.collectTrainingStats) {
                this.stats.logSplitStart();
            }
            JavaRDD<DataSet>[] splits = SparkUtils.balancedRandomSplit((int) count, numDataSetObjectsPerSplit, javaRDD);
            if (this.collectTrainingStats) {
                this.stats.logSplitEnd();
            }
            int splitNum = 1;
            for (JavaRDD<DataSet> split : splits) {
                doIteration(sparkDl4jMultiLayer, split, splitNum++, splits.length);
            }
            if (this.collectTrainingStats) {
                this.stats.logFitEnd((int) count);
            }
        } finally {
            // All split jobs have completed (or failed) by now, so the cache is no longer needed.
            if (persistedHere) {
                javaRDD.unpersist();
            }
        }
    }

    /**
     * Trains the wrapped MultiLayerNetwork on serialized DataSet objects supplied as
     * (path, PortableDataStream) pairs.
     *
     * <p>If the input has at least {@code COALESCE_THRESHOLD * numWorkers} partitions, it is first
     * coalesced to {@code numWorkers} partitions. The pairs are then randomly split, and the stream
     * values of each split are trained and averaged in turn.
     *
     * @param sparkDl4jMultiLayer network wrapper whose parameters are updated after every split
     * @param javaPairRDD         training data as (path, stream) pairs
     */
    @Override // org.deeplearning4j.spark.api.TrainingMaster
    public void executeTraining(SparkDl4jMultiLayer sparkDl4jMultiLayer, JavaPairRDD<String, PortableDataStream> javaPairRDD) {
        if (this.numWorkers == null) {
            this.numWorkers = sparkDl4jMultiLayer.getSparkContext().defaultParallelism();
        }
        // Previously missing: logFitEnd() below was called without a matching logFitStart().
        if (this.collectTrainingStats) {
            this.stats.logFitStart();
        }
        int size = javaPairRDD.partitions().size();
        if (size >= COALESCE_THRESHOLD * this.numWorkers.intValue()) {
            log.info("Coalescing PortableDataStreams from {} to {} partitions", Integer.valueOf(size), this.numWorkers);
            javaPairRDD = javaPairRDD.coalesce(this.numWorkers.intValue());
        }
        if (this.collectTrainingStats) {
            this.stats.logCountStart();
        }
        long count = javaPairRDD.count();
        if (this.collectTrainingStats) {
            this.stats.logCountEnd();
        }
        int numDataSetObjectsPerSplit = getNumDataSetObjectsPerSplit();
        if (this.collectTrainingStats) {
            this.stats.logSplitStart();
        }
        JavaPairRDD<String, PortableDataStream>[] splits = SparkUtils.balancedRandomSplit((int) count, numDataSetObjectsPerSplit, javaPairRDD);
        if (this.collectTrainingStats) {
            this.stats.logSplitEnd();
        }
        int splitNum = 1;
        for (JavaPairRDD<String, PortableDataStream> split : splits) {
            doIterationPDS(sparkDl4jMultiLayer, null, split.values(), splitNum++, splits.length);
        }
        if (this.collectTrainingStats) {
            this.stats.logFitEnd((int) count);
        }
    }

    /**
     * Trains the wrapped MultiLayerNetwork on DataSet objects referenced by path strings.
     * The paths are counted, randomly split, and each split is trained and averaged in turn.
     *
     * @param sparkDl4jMultiLayer network wrapper whose parameters are updated after every split
     * @param javaRDD             paths of the serialized training data
     */
    @Override // org.deeplearning4j.spark.api.TrainingMaster
    public void executeTrainingPaths(SparkDl4jMultiLayer sparkDl4jMultiLayer, JavaRDD<String> javaRDD) {
        if (this.numWorkers == null) {
            this.numWorkers = sparkDl4jMultiLayer.getSparkContext().defaultParallelism();
        }
        if (this.collectTrainingStats) {
            this.stats.logFitStart();
            this.stats.logCountStart();
        }
        long pathCount = javaRDD.count();
        if (this.collectTrainingStats) {
            this.stats.logCountEnd();
        }
        int objectsPerSplit = getNumDataSetObjectsPerSplit();
        if (this.collectTrainingStats) {
            this.stats.logSplitStart();
        }
        JavaRDD<String>[] splits = SparkUtils.balancedRandomSplit((int) pathCount, objectsPerSplit, javaRDD);
        if (this.collectTrainingStats) {
            this.stats.logSplitEnd();
        }
        for (int idx = 0; idx < splits.length; idx++) {
            doIterationPaths(sparkDl4jMultiLayer, null, splits[idx], idx + 1, splits.length);
        }
        if (this.collectTrainingStats) {
            this.stats.logFitEnd((int) pathCount);
        }
    }

    /**
     * Trains the wrapped ComputationGraph on DataSet objects by converting each one to a
     * MultiDataSet and delegating to {@code executeTrainingMDS}.
     *
     * @param sparkComputationGraph graph wrapper whose parameters are updated after every split
     * @param javaRDD               training data
     */
    @Override // org.deeplearning4j.spark.api.TrainingMaster
    public void executeTraining(SparkComputationGraph sparkComputationGraph, JavaRDD<DataSet> javaRDD) {
        if (this.numWorkers == null) {
            this.numWorkers = sparkComputationGraph.getSparkContext().defaultParallelism();
        }
        JavaRDD<MultiDataSet> asMultiDataSets = javaRDD.map(new DataSetToMultiDataSetFn());
        executeTrainingMDS(sparkComputationGraph, asMultiDataSets);
    }

    /**
     * Trains the wrapped ComputationGraph on an RDD of MultiDataSet objects.
     * The RDD is counted, randomly split, and each split is trained and averaged in turn.
     *
     * @param sparkComputationGraph graph wrapper whose parameters are updated after every split
     * @param javaRDD               training data
     */
    @Override // org.deeplearning4j.spark.api.TrainingMaster
    public void executeTrainingMDS(SparkComputationGraph sparkComputationGraph, JavaRDD<MultiDataSet> javaRDD) {
        if (this.numWorkers == null) {
            this.numWorkers = sparkComputationGraph.getSparkContext().defaultParallelism();
        }
        if (this.collectTrainingStats) {
            this.stats.logFitStart();
            this.stats.logCountStart();
        }
        long totalObjects = javaRDD.count();
        if (this.collectTrainingStats) {
            this.stats.logCountEnd();
        }
        int objectsPerSplit = getNumDataSetObjectsPerSplit();
        if (this.collectTrainingStats) {
            this.stats.logSplitStart();
        }
        JavaRDD<MultiDataSet>[] splits = SparkUtils.balancedRandomSplit((int) totalObjects, objectsPerSplit, javaRDD);
        if (this.collectTrainingStats) {
            this.stats.logSplitEnd();
        }
        for (int idx = 0; idx < splits.length; idx++) {
            doIteration(sparkComputationGraph, splits[idx], idx + 1, splits.length);
        }
        if (this.collectTrainingStats) {
            this.stats.logFitEnd((int) totalObjects);
        }
    }

    /**
     * Trains the wrapped ComputationGraph on serialized DataSet objects supplied as
     * (path, PortableDataStream) pairs.
     *
     * <p>If the input has at least {@code COALESCE_THRESHOLD * numWorkers} partitions, it is first
     * coalesced to {@code numWorkers} partitions. The pairs are then randomly split with a random
     * seed, and the stream values of each split are trained and averaged in turn.
     *
     * @param sparkComputationGraph graph wrapper whose parameters are updated after every split
     * @param javaPairRDD           training data as (path, stream) pairs
     */
    @Override // org.deeplearning4j.spark.api.TrainingMaster
    public void executeTraining(SparkComputationGraph sparkComputationGraph, JavaPairRDD<String, PortableDataStream> javaPairRDD) {
        if (this.numWorkers == null) {
            this.numWorkers = sparkComputationGraph.getSparkContext().defaultParallelism();
        }
        if (this.collectTrainingStats) {
            this.stats.logFitStart();
        }
        JavaPairRDD<String, PortableDataStream> input = javaPairRDD;
        int inputPartitions = input.partitions().size();
        if (inputPartitions >= COALESCE_THRESHOLD * this.numWorkers.intValue()) {
            log.info("Coalesing streams from {} to {} partitions", inputPartitions, this.numWorkers);
            input = input.coalesce(this.numWorkers.intValue());
        }
        if (this.collectTrainingStats) {
            this.stats.logCountStart();
        }
        long totalStreams = input.count();
        if (this.collectTrainingStats) {
            this.stats.logCountEnd();
        }
        int objectsPerSplit = getNumDataSetObjectsPerSplit();
        if (this.collectTrainingStats) {
            this.stats.logSplitStart();
        }
        long seed = new Random().nextLong();
        JavaPairRDD<String, PortableDataStream>[] splits = SparkUtils.balancedRandomSplit((int) totalStreams, objectsPerSplit, input, seed);
        if (this.collectTrainingStats) {
            this.stats.logSplitEnd();
        }
        for (int idx = 0; idx < splits.length; idx++) {
            doIterationPDS(null, sparkComputationGraph, splits[idx].values(), idx + 1, splits.length);
        }
        if (this.collectTrainingStats) {
            this.stats.logFitEnd((int) totalStreams);
        }
    }

    /**
     * Trains the wrapped ComputationGraph on serialized MultiDataSet objects supplied as
     * (path, PortableDataStream) pairs. The pairs are counted and randomly split with a random seed.
     * The stream values of each split are then trained and averaged in turn.
     *
     * @param sparkComputationGraph graph wrapper whose parameters are updated after every split
     * @param javaPairRDD           training data as (path, stream) pairs
     */
    @Override // org.deeplearning4j.spark.api.TrainingMaster
    public void executeTrainingMDS(SparkComputationGraph sparkComputationGraph, JavaPairRDD<String, PortableDataStream> javaPairRDD) {
        if (this.numWorkers == null) {
            this.numWorkers = sparkComputationGraph.getSparkContext().defaultParallelism();
        }
        if (this.collectTrainingStats) {
            this.stats.logFitStart();
        }
        if (this.collectTrainingStats) {
            this.stats.logCountStart();
        }
        long count = javaPairRDD.count();
        if (this.collectTrainingStats) {
            this.stats.logCountEnd();
        }
        int numDataSetObjectsPerSplit = getNumDataSetObjectsPerSplit();
        if (this.collectTrainingStats) {
            this.stats.logSplitStart();
        }
        JavaPairRDD<String, PortableDataStream>[] splits = SparkUtils.balancedRandomSplit((int) count, numDataSetObjectsPerSplit, javaPairRDD, new Random().nextLong());
        if (this.collectTrainingStats) {
            this.stats.logSplitEnd();
        }
        int splitNum = 1;
        for (JavaPairRDD<String, PortableDataStream> split : splits) {
            // doIterationPDS_MDS already repartitions the split (and records repartition stats) using
            // the same settings. Repartitioning here as well caused a redundant second shuffle.
            doIterationPDS_MDS(sparkComputationGraph, split.values(), splitNum++, splits.length);
        }
        if (this.collectTrainingStats) {
            this.stats.logFitEnd((int) count);
        }
    }

    /**
     * Trains the wrapped ComputationGraph on DataSet objects referenced by path strings.
     * The paths are counted, randomly split, and each split is trained and averaged in turn.
     *
     * @param sparkComputationGraph graph wrapper whose parameters are updated after every split
     * @param javaRDD               paths of the serialized training data
     */
    @Override // org.deeplearning4j.spark.api.TrainingMaster
    public void executeTrainingPaths(SparkComputationGraph sparkComputationGraph, JavaRDD<String> javaRDD) {
        if (this.numWorkers == null) {
            this.numWorkers = sparkComputationGraph.getSparkContext().defaultParallelism();
        }
        if (this.collectTrainingStats) {
            this.stats.logFitStart();
            this.stats.logCountStart();
        }
        long pathCount = javaRDD.count();
        if (this.collectTrainingStats) {
            this.stats.logCountEnd();
        }
        int objectsPerSplit = getNumDataSetObjectsPerSplit();
        if (this.collectTrainingStats) {
            this.stats.logSplitStart();
        }
        JavaRDD<String>[] splits = SparkUtils.balancedRandomSplit((int) pathCount, objectsPerSplit, javaRDD);
        if (this.collectTrainingStats) {
            this.stats.logSplitEnd();
        }
        for (int idx = 0; idx < splits.length; idx++) {
            doIterationPaths(null, sparkComputationGraph, splits[idx], idx + 1, splits.length);
        }
        if (this.collectTrainingStats) {
            this.stats.logFitEnd((int) pathCount);
        }
    }

    /**
     * Trains the wrapped ComputationGraph on MultiDataSet objects referenced by path strings.
     * The paths are counted, randomly split, and each split is trained and averaged in turn.
     *
     * @param sparkComputationGraph graph wrapper whose parameters are updated after every split
     * @param javaRDD               paths of the serialized training data
     */
    @Override // org.deeplearning4j.spark.api.TrainingMaster
    public void executeTrainingPathsMDS(SparkComputationGraph sparkComputationGraph, JavaRDD<String> javaRDD) {
        if (this.numWorkers == null) {
            this.numWorkers = sparkComputationGraph.getSparkContext().defaultParallelism();
        }
        if (this.collectTrainingStats) {
            this.stats.logFitStart();
            this.stats.logCountStart();
        }
        long pathCount = javaRDD.count();
        if (this.collectTrainingStats) {
            this.stats.logCountEnd();
        }
        int objectsPerSplit = getNumDataSetObjectsPerSplit();
        if (this.collectTrainingStats) {
            this.stats.logSplitStart();
        }
        JavaRDD<String>[] splits = SparkUtils.balancedRandomSplit((int) pathCount, objectsPerSplit, javaRDD);
        if (this.collectTrainingStats) {
            this.stats.logSplitEnd();
        }
        for (int idx = 0; idx < splits.length; idx++) {
            doIterationPathsMDS(sparkComputationGraph, splits[idx], idx + 1, splits.length);
        }
        if (this.collectTrainingStats) {
            this.stats.logFitEnd((int) pathCount);
        }
    }

    /**
     * Enables or disables statistics collection. Enabling creates a stats helper if none exists
     * (keeping any existing one). Disabling discards the helper and any statistics it held.
     */
    @Override // org.deeplearning4j.spark.api.TrainingMaster
    public void setCollectTrainingStats(boolean z) {
        this.collectTrainingStats = z;
        if (z) {
            if (this.stats == null) {
                this.stats = new ParameterAveragingTrainingMasterStats.ParameterAveragingTrainingMasterStatsHelper();
            }
        } else {
            this.stats = null;
        }
    }

    /** @return whether training statistics are currently being collected */
    @Override // org.deeplearning4j.spark.api.TrainingMaster
    public boolean getIsCollectTrainingStats() {
        return collectTrainingStats;
    }

    /** @return the statistics collected so far, or {@code null} if no stats helper exists */
    @Override // org.deeplearning4j.spark.api.TrainingMaster
    public SparkTrainingStats getTrainingStats() {
        return this.stats == null ? null : this.stats.build();
    }

    /**
     * Trains one split of DataSet objects on the MultiLayerNetwork. The split is repartitioned per the
     * configured policy, a worker runs on each partition, and the results are averaged.
     *
     * @param i  1-based index of this split
     * @param i2 total number of splits
     */
    private void doIteration(SparkDl4jMultiLayer sparkDl4jMultiLayer, JavaRDD<DataSet> javaRDD, int i, int i2) {
        log.info("Starting training of split {} of {}. workerMiniBatchSize={}, averagingFreq={}, Configured for {} workers",
                new Object[] {i, i2, this.batchSizePerWorker, this.averagingFrequency, this.numWorkers});
        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsStart();
            this.stats.logRepartitionStart();
        }
        JavaRDD<DataSet> partitioned = SparkUtils.repartition(javaRDD, this.repartition, this.repartitionStrategy,
                numObjectsEachWorker(), this.numWorkers.intValue());
        int partitionCount = partitioned.partitions().size();
        if (this.collectTrainingStats && this.repartition != Repartition.Never) {
            this.stats.logRepartitionEnd();
        }
        ExecuteWorkerFlatMap workerFn = new ExecuteWorkerFlatMap(getWorkerInstance(sparkDl4jMultiLayer));
        processResults(sparkDl4jMultiLayer, null, partitioned.mapPartitions(workerFn), i, i2);
        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsEnd(partitionCount);
        }
    }

    /**
     * Trains one split of PortableDataStreams (serialized DataSets). The MultiLayerNetwork wrapper is
     * used when it is non-null; otherwise the ComputationGraph wrapper is used.
     *
     * @param i  1-based index of this split
     * @param i2 total number of splits
     */
    private void doIterationPDS(SparkDl4jMultiLayer sparkDl4jMultiLayer, SparkComputationGraph sparkComputationGraph, JavaRDD<PortableDataStream> javaRDD, int i, int i2) {
        log.info("Starting training of split {} of {}. workerMiniBatchSize={}, averagingFreq={}, Configured for {} workers",
                new Object[] {i, i2, this.batchSizePerWorker, this.averagingFrequency, this.numWorkers});
        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsStart();
            this.stats.logRepartitionStart();
        }
        JavaRDD<PortableDataStream> partitioned = SparkUtils.repartition(javaRDD, this.repartition, this.repartitionStrategy,
                numObjectsEachWorker(), this.numWorkers.intValue());
        int partitionCount = partitioned.partitions().size();
        if (this.collectTrainingStats && this.repartition != Repartition.Never) {
            this.stats.logRepartitionEnd();
        }
        ParameterAveragingTrainingWorker worker;
        if (sparkDl4jMultiLayer != null) {
            worker = getWorkerInstance(sparkDl4jMultiLayer);
        } else {
            worker = getWorkerInstance(sparkComputationGraph);
        }
        processResults(sparkDl4jMultiLayer, sparkComputationGraph, partitioned.mapPartitions(new ExecuteWorkerPDSFlatMap(worker)), i, i2);
        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsEnd(partitionCount);
        }
    }

    /**
     * Trains one split of DataSet paths. The MultiLayerNetwork wrapper is used when it is non-null;
     * otherwise the ComputationGraph wrapper is used.
     *
     * @param i  1-based index of this split
     * @param i2 total number of splits
     */
    private void doIterationPaths(SparkDl4jMultiLayer sparkDl4jMultiLayer, SparkComputationGraph sparkComputationGraph, JavaRDD<String> javaRDD, int i, int i2) {
        log.info("Starting training of split {} of {}. workerMiniBatchSize={}, averagingFreq={}, Configured for {} workers",
                new Object[] {i, i2, this.batchSizePerWorker, this.averagingFrequency, this.numWorkers});
        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsStart();
            this.stats.logRepartitionStart();
        }
        JavaRDD<String> partitioned = SparkUtils.repartition(javaRDD, this.repartition, this.repartitionStrategy,
                numObjectsEachWorker(), this.numWorkers.intValue());
        int partitionCount = partitioned.partitions().size();
        if (this.collectTrainingStats && this.repartition != Repartition.Never) {
            this.stats.logRepartitionEnd();
        }
        ParameterAveragingTrainingWorker worker;
        if (sparkDl4jMultiLayer != null) {
            worker = getWorkerInstance(sparkDl4jMultiLayer);
        } else {
            worker = getWorkerInstance(sparkComputationGraph);
        }
        processResults(sparkDl4jMultiLayer, sparkComputationGraph, partitioned.mapPartitions(new ExecuteWorkerPathFlatMap(worker)), i, i2);
        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsEnd(partitionCount);
        }
    }

    /**
     * Trains one split of MultiDataSet paths on the ComputationGraph.
     *
     * @param i  1-based index of this split
     * @param i2 total number of splits
     */
    private void doIterationPathsMDS(SparkComputationGraph sparkComputationGraph, JavaRDD<String> javaRDD, int i, int i2) {
        log.info("Starting training of split {} of {}. workerMiniBatchSize={}, averagingFreq={}, Configured for {} workers",
                new Object[] {i, i2, this.batchSizePerWorker, this.averagingFrequency, this.numWorkers});
        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsStart();
            this.stats.logRepartitionStart();
        }
        JavaRDD<String> partitioned = SparkUtils.repartition(javaRDD, this.repartition, this.repartitionStrategy,
                numObjectsEachWorker(), this.numWorkers.intValue());
        int partitionCount = partitioned.partitions().size();
        if (this.collectTrainingStats && this.repartition != Repartition.Never) {
            this.stats.logRepartitionEnd();
        }
        ExecuteWorkerPathMDSFlatMap workerFn = new ExecuteWorkerPathMDSFlatMap(getWorkerInstance(sparkComputationGraph));
        processResults(null, sparkComputationGraph, partitioned.mapPartitions(workerFn), i, i2);
        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsEnd(partitionCount);
        }
    }

    /**
     * Trains one split of MultiDataSet objects on the ComputationGraph. The split is repartitioned per
     * the configured policy, a worker runs on each partition, and the results are averaged.
     *
     * @param i  1-based index of this split
     * @param i2 total number of splits
     */
    private void doIteration(SparkComputationGraph sparkComputationGraph, JavaRDD<MultiDataSet> javaRDD, int i, int i2) {
        log.info("Starting training of split {} of {}. workerMiniBatchSize={}, averagingFreq={}, Configured for {} workers", new Object[]{Integer.valueOf(i), Integer.valueOf(i2), Integer.valueOf(this.batchSizePerWorker), Integer.valueOf(this.averagingFrequency), this.numWorkers});
        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsStart();
        }
        // Record repartition timing like every other doIteration* method (previously missing here).
        if (this.collectTrainingStats) {
            this.stats.logRepartitionStart();
        }
        JavaRDD<MultiDataSet> repartitioned = SparkUtils.repartition(javaRDD, this.repartition, this.repartitionStrategy, numObjectsEachWorker(), this.numWorkers.intValue());
        // Report the partition count actually used for training, i.e. after repartitioning; this
        // previously read the input RDD's partition count.
        int size = repartitioned.partitions().size();
        if (this.collectTrainingStats && this.repartition != Repartition.Never) {
            this.stats.logRepartitionEnd();
        }
        processResults(null, sparkComputationGraph, repartitioned.mapPartitions(new ExecuteWorkerMultiDataSetFlatMap(getWorkerInstance(sparkComputationGraph))), i, i2);
        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsEnd(size);
        }
    }

    /**
     * Trains one split of PortableDataStreams (serialized MultiDataSets) on the ComputationGraph,
     * repartitioning the split per the configured policy first.
     *
     * @param i  1-based index of this split
     * @param i2 total number of splits
     */
    private void doIterationPDS_MDS(SparkComputationGraph sparkComputationGraph, JavaRDD<PortableDataStream> javaRDD, int i, int i2) {
        log.info("Starting training of split {} of {}. workerMiniBatchSize={}, averagingFreq={}, Configured for {} workers",
                new Object[] {i, i2, this.batchSizePerWorker, this.averagingFrequency, this.numWorkers});
        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsStart();
            this.stats.logRepartitionStart();
        }
        JavaRDD<PortableDataStream> partitioned = SparkUtils.repartition(javaRDD, this.repartition, this.repartitionStrategy,
                numObjectsEachWorker(), this.numWorkers.intValue());
        int partitionCount = partitioned.partitions().size();
        if (this.collectTrainingStats && this.repartition != Repartition.Never) {
            this.stats.logRepartitionEnd();
        }
        ExecuteWorkerPDSMDSFlatMap workerFn = new ExecuteWorkerPDSMDSFlatMap(getWorkerInstance(sparkComputationGraph));
        processResults(null, sparkComputationGraph, partitioned.mapPartitions(workerFn), i, i2);
        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsEnd(partitionCount);
        }
    }

    /**
     * Aggregates one split's worker results and averages them over the number of results.
     *
     * <p>Parameters are averaged, as is updater state when present. The averages are applied to the
     * wrapped network, statistics are recorded and listeners are notified. The MultiLayerNetwork
     * wrapper is used when it is non-null; otherwise the ComputationGraph wrapper is used.
     *
     * <p>If the split produced no results, the network is left unchanged and a warning is logged.
     * This happens, for example, when the split was empty: {@code aggregate} then returns its
     * {@code null} zero value.
     *
     * @param javaRDD lazily evaluated worker results; {@code aggregate} is the Spark action that runs them
     * @param i       1-based index of this split
     * @param i2      total number of splits
     */
    private void processResults(SparkDl4jMultiLayer sparkDl4jMultiLayer, SparkComputationGraph sparkComputationGraph, JavaRDD<ParameterAveragingTrainingResult> javaRDD, int i, int i2) {
        if (this.collectTrainingStats) {
            this.stats.logAggregateStartTime();
        }
        ParameterAveragingAggregationTuple tuple = (ParameterAveragingAggregationTuple) javaRDD.aggregate(null, new ParameterAveragingElementAddFunction(), new ParameterAveragingElementCombineFunction());
        if (tuple == null || tuple.getAggregationsCount() == 0) {
            // Nothing to average. Dividing would either throw an NPE or fill the parameters with
            // NaN/Infinity, so leave the network untouched.
            if (this.collectTrainingStats) {
                this.stats.logAggregationEndTime();
            }
            log.warn("No training results were returned for split {} of {}; skipping parameter update", Integer.valueOf(i), Integer.valueOf(i2));
            return;
        }
        INDArray parametersSum = tuple.getParametersSum();
        int aggregationsCount = tuple.getAggregationsCount();
        SparkTrainingStats workerStats = tuple.getSparkTrainingStats();
        if (this.collectTrainingStats) {
            this.stats.logAggregationEndTime();
            this.stats.logProcessParamsUpdaterStart();
        }
        // In-place division turns the sums into averages.
        parametersSum.divi(Integer.valueOf(aggregationsCount));
        INDArray updaterStateSum = tuple.getUpdaterStateSum();
        if (updaterStateSum != null) {
            updaterStateSum.divi(Integer.valueOf(aggregationsCount));
        }
        double averageScore = tuple.getScoreSum() / aggregationsCount;
        if (sparkDl4jMultiLayer != null) {
            MultiLayerNetwork network = sparkDl4jMultiLayer.getNetwork();
            network.setParameters(parametersSum);
            if (updaterStateSum != null) {
                network.getUpdater().setStateViewArray((Layer) null, updaterStateSum, false);
            }
            sparkDl4jMultiLayer.setScore(averageScore);
        } else {
            ComputationGraph graph = sparkComputationGraph.getNetwork();
            graph.setParams(parametersSum);
            if (updaterStateSum != null) {
                graph.getUpdater().setStateViewArray(updaterStateSum);
            }
            sparkComputationGraph.setScore(averageScore);
        }
        if (this.collectTrainingStats) {
            this.stats.logProcessParamsUpdaterEnd();
            this.stats.addWorkerStats(workerStats);
        }
        log.info("Completed training of split {} of {}", Integer.valueOf(i), Integer.valueOf(i2));
        if (this.listeners != null) {
            if (sparkDl4jMultiLayer != null) {
                MultiLayerNetwork network = sparkDl4jMultiLayer.getNetwork();
                network.setScore(sparkDl4jMultiLayer.getScore());
                for (IterationListener listener : this.listeners) {
                    listener.iterationDone(network, this.iterationCount);
                }
            } else {
                ComputationGraph graph = sparkComputationGraph.getNetwork();
                graph.setScore(sparkComputationGraph.getScore());
                for (IterationListener listener : this.listeners) {
                    listener.iterationDone(graph, this.iterationCount);
                }
            }
        }
        this.iterationCount++;
    }

    /** @return the {@code saveUpdater} flag */
    public boolean isSaveUpdater() {
        return saveUpdater;
    }

    /** @return the configured worker count; may be {@code null} */
    public Integer getNumWorkers() {
        return numWorkers;
    }

    /** @return the {@code rddDataSetNumExamples} setting */
    public int getRddDataSetNumExamples() {
        return rddDataSetNumExamples;
    }

    /** @return the per-worker batch size setting */
    public int getBatchSizePerWorker() {
        return batchSizePerWorker;
    }

    /** @return the averaging frequency setting */
    public int getAveragingFrequency() {
        return averagingFrequency;
    }

    /** @return the number of prefetched batches setting */
    public int getPrefetchNumBatches() {
        return prefetchNumBatches;
    }

    /** @return whether training statistics are being collected */
    public boolean isCollectTrainingStats() {
        return collectTrainingStats;
    }

    /** @return the training-statistics helper; may be {@code null} */
    public ParameterAveragingTrainingMasterStats.ParameterAveragingTrainingMasterStatsHelper getStats() {
        return stats;
    }

    /** @return the registered iteration listeners (the live collection, not a copy); may be {@code null} */
    public Collection<IterationListener> getListeners() {
        return listeners;
    }

    /** @return the current iteration counter */
    public int getIterationCount() {
        return iterationCount;
    }

    /** @return the repartition setting; may be {@code null} */
    public Repartition getRepartition() {
        return repartition;
    }

    /** @return the repartition strategy; may be {@code null} */
    public RepartitionStrategy getRepartitionStrategy() {
        return repartitionStrategy;
    }

    /** @param saveUpdater new value of the {@code saveUpdater} flag */
    public void setSaveUpdater(boolean saveUpdater) {
        this.saveUpdater = saveUpdater;
    }

    /** @param numWorkers new worker count; may be {@code null} */
    public void setNumWorkers(Integer numWorkers) {
        this.numWorkers = numWorkers;
    }

    /** @param rddDataSetNumExamples new {@code rddDataSetNumExamples} setting */
    public void setRddDataSetNumExamples(int rddDataSetNumExamples) {
        this.rddDataSetNumExamples = rddDataSetNumExamples;
    }

    /** @param batchSizePerWorker new per-worker batch size */
    public void setBatchSizePerWorker(int batchSizePerWorker) {
        this.batchSizePerWorker = batchSizePerWorker;
    }

    /** @param averagingFrequency new averaging frequency */
    public void setAveragingFrequency(int averagingFrequency) {
        this.averagingFrequency = averagingFrequency;
    }

    /** @param prefetchNumBatches new number of prefetched batches */
    public void setPrefetchNumBatches(int prefetchNumBatches) {
        this.prefetchNumBatches = prefetchNumBatches;
    }

    /** @param statsHelper new training-statistics helper */
    public void setStats(ParameterAveragingTrainingMasterStats.ParameterAveragingTrainingMasterStatsHelper statsHelper) {
        this.stats = statsHelper;
    }

    /** Replaces the iteration listeners; the given collection is stored as-is, not copied. */
    @Override // org.deeplearning4j.spark.api.TrainingMaster
    public void setListeners(Collection<IterationListener> listeners) {
        this.listeners = listeners;
    }

    /** @param iterationCount new value of the iteration counter */
    public void setIterationCount(int iterationCount) {
        this.iterationCount = iterationCount;
    }

    /** @param repartition new repartition setting */
    public void setRepartition(Repartition repartition) {
        this.repartition = repartition;
    }

    /** @param repartitionStrategy new repartition strategy */
    public void setRepartitionStrategy(RepartitionStrategy repartitionStrategy) {
        this.repartitionStrategy = repartitionStrategy;
    }

    /**
     * Field-by-field equality, following the Lombok {@code canEqual} pattern.
     *
     * <p>Compares every configuration field plus the runtime state: stats helper, listeners and iteration count.
     * Nullable reference fields are equal when both are {@code null} or when {@code equals} holds.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof ParameterAveragingTrainingMaster)) {
            return false;
        }
        ParameterAveragingTrainingMaster that = (ParameterAveragingTrainingMaster) obj;
        if (!that.canEqual(this)) {
            return false;
        }
        if (isSaveUpdater() != that.isSaveUpdater()) {
            return false;
        }
        Integer ownWorkers = getNumWorkers();
        Integer otherWorkers = that.getNumWorkers();
        if (ownWorkers == null ? otherWorkers != null : !ownWorkers.equals(otherWorkers)) {
            return false;
        }
        boolean scalarsMatch = getRddDataSetNumExamples() == that.getRddDataSetNumExamples()
                && getBatchSizePerWorker() == that.getBatchSizePerWorker()
                && getAveragingFrequency() == that.getAveragingFrequency()
                && getPrefetchNumBatches() == that.getPrefetchNumBatches()
                && isCollectTrainingStats() == that.isCollectTrainingStats();
        if (!scalarsMatch) {
            return false;
        }
        ParameterAveragingTrainingMasterStats.ParameterAveragingTrainingMasterStatsHelper ownStats = getStats();
        ParameterAveragingTrainingMasterStats.ParameterAveragingTrainingMasterStatsHelper otherStats = that.getStats();
        if (ownStats == null ? otherStats != null : !ownStats.equals(otherStats)) {
            return false;
        }
        Collection<IterationListener> ownListeners = getListeners();
        Collection<IterationListener> otherListeners = that.getListeners();
        if (ownListeners == null ? otherListeners != null : !ownListeners.equals(otherListeners)) {
            return false;
        }
        if (getIterationCount() != that.getIterationCount()) {
            return false;
        }
        Repartition ownRepartition = getRepartition();
        Repartition otherRepartition = that.getRepartition();
        if (ownRepartition == null ? otherRepartition != null : !ownRepartition.equals(otherRepartition)) {
            return false;
        }
        RepartitionStrategy ownStrategy = getRepartitionStrategy();
        RepartitionStrategy otherStrategy = that.getRepartitionStrategy();
        return ownStrategy == null ? otherStrategy == null : ownStrategy.equals(otherStrategy);
    }

    /**
     * Symmetry hook for {@link #equals(Object)}.
     *
     * @param obj candidate object
     * @return {@code true} when {@code obj} is a {@code ParameterAveragingTrainingMaster}
     */
    protected boolean canEqual(Object obj) {
        return obj instanceof ParameterAveragingTrainingMaster;
    }

    /**
     * Hash over the same fields, in the same order, that {@link #equals(Object)} compares.
     *
     * <p>Uses a multiplier of 59 and maps booleans to 79/97. Null references contribute 0.
     */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + (isSaveUpdater() ? 79 : 97);
        Integer workers = getNumWorkers();
        result = result * prime + (workers == null ? 0 : workers.hashCode());
        result = result * prime + getRddDataSetNumExamples();
        result = result * prime + getBatchSizePerWorker();
        result = result * prime + getAveragingFrequency();
        result = result * prime + getPrefetchNumBatches();
        result = result * prime + (isCollectTrainingStats() ? 79 : 97);
        ParameterAveragingTrainingMasterStats.ParameterAveragingTrainingMasterStatsHelper statsHelper = getStats();
        result = result * prime + (statsHelper == null ? 0 : statsHelper.hashCode());
        Collection<IterationListener> listenerCollection = getListeners();
        result = result * prime + (listenerCollection == null ? 0 : listenerCollection.hashCode());
        result = result * prime + getIterationCount();
        Repartition repartitionMode = getRepartition();
        result = result * prime + (repartitionMode == null ? 0 : repartitionMode.hashCode());
        RepartitionStrategy strategy = getRepartitionStrategy();
        result = result * prime + (strategy == null ? 0 : strategy.hashCode());
        return result;
    }

    /**
     * Renders every field, in declaration order, as {@code ParameterAveragingTrainingMaster(name=value, ...)}.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("ParameterAveragingTrainingMaster(");
        sb.append("saveUpdater=").append(isSaveUpdater())
                .append(", numWorkers=").append(getNumWorkers())
                .append(", rddDataSetNumExamples=").append(getRddDataSetNumExamples())
                .append(", batchSizePerWorker=").append(getBatchSizePerWorker())
                .append(", averagingFrequency=").append(getAveragingFrequency())
                .append(", prefetchNumBatches=").append(getPrefetchNumBatches())
                .append(", collectTrainingStats=").append(isCollectTrainingStats())
                .append(", stats=").append(getStats())
                .append(", listeners=").append(getListeners())
                .append(", iterationCount=").append(getIterationCount())
                .append(", repartition=").append(getRepartition())
                .append(", repartitionStrategy=").append(getRepartitionStrategy())
                .append(')');
        return sb.toString();
    }
}
