package org.deeplearning4j.spark.impl.computationgraph;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.spark.impl.common.misc.ScoreReport;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple3;

/* loaded from: input_file:org/deeplearning4j/spark/impl/computationgraph/IterativeReduceFlatMapCG.class */
public class IterativeReduceFlatMapCG implements FlatMapFunction<Iterator<MultiDataSet>, Tuple3<INDArray, ComputationGraphUpdater, ScoreReport>> {
    protected static Logger log = LoggerFactory.getLogger(IterativeReduceFlatMapCG.class);
    private String json;
    private Broadcast<INDArray> params;
    private Broadcast<ComputationGraphUpdater> updater;

    public IterativeReduceFlatMapCG(String str, Broadcast<INDArray> broadcast, Broadcast<ComputationGraphUpdater> broadcast2) {
        this.json = str;
        this.params = broadcast;
        this.updater = broadcast2;
        if (broadcast2.getValue() == null) {
            throw new IllegalArgumentException("Updater shouldn't be null");
        }
    }

    public Iterable<Tuple3<INDArray, ComputationGraphUpdater, ScoreReport>> call(Iterator<MultiDataSet> it) throws Exception {
        if (!it.hasNext()) {
            return Collections.emptyList();
        }
        ArrayList arrayList = new ArrayList();
        while (it.hasNext()) {
            arrayList.add(it.next());
        }
        org.nd4j.linalg.dataset.MultiDataSet merge = org.nd4j.linalg.dataset.MultiDataSet.merge(arrayList);
        ComputationGraph computationGraph = new ComputationGraph(ComputationGraphConfiguration.fromJson(this.json));
        computationGraph.setInitDone(true);
        computationGraph.init();
        computationGraph.setListeners(new IterationListener[]{new ScoreIterationListener(1)});
        INDArray dup = ((INDArray) this.params.getValue()).dup();
        ComputationGraphUpdater clone = ((ComputationGraphUpdater) this.updater.getValue()).clone();
        if (dup.length() != computationGraph.numParams(false)) {
            throw new IllegalStateException("Network did not have same number of parameters as the broadcast parameters");
        }
        computationGraph.setParams(dup);
        computationGraph.setUpdater(clone);
        computationGraph.fit(merge);
        ScoreReport scoreReport = new ScoreReport();
        scoreReport.setS(computationGraph.score());
        scoreReport.setM(Runtime.getRuntime().maxMemory());
        return Collections.singletonList(new Tuple3(computationGraph.params(false), computationGraph.getUpdater(), scoreReport));
    }
}
