package org.deeplearning4j.iterativereduce.impl.single;

import java.io.IOException;
import java.util.List;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.mapreduce.RecordReader;
import org.deeplearning4j.iterativereduce.runtime.ComputableWorker;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.DeepLearningConfigurable;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.layers.factory.LayerFactories;
import org.deeplearning4j.scaleout.api.ir.ParameterVectorUpdateable;
import org.nd4j.linalg.dataset.DataSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/iterativereduce/impl/single/WorkerNode.class */
/**
 * Iterative-reduce worker that trains a single {@link Layer} on the records supplied by a Hadoop
 * {@link RecordReader} and reports the layer's parameters as a {@link ParameterVectorUpdateable}.
 *
 * <p>Expected call order (presumably driven by the iterative-reduce runtime — confirm):
 * {@link #setup(Configuration)} builds the layer, {@link #setRecordReader(RecordReader)} supplies
 * the input, {@link #compute()} trains on it, and {@link #update(ParameterVectorUpdateable)}
 * installs parameters sent back by the master.
 */
public class WorkerNode implements ComputableWorker<ParameterVectorUpdateable>, DeepLearningConfigurable {
    private static final Logger LOG = LoggerFactory.getLogger(WorkerNode.class);

    /** Hadoop configuration key whose value is the JSON-serialized {@link NeuralNetConfiguration}. */
    private static final String NEURAL_NET_CONF_KEY = "org.deeplearning4j.scaleout.neuralnetconf";

    /** Layer being trained; created in {@link #setup(Configuration)}. */
    private Layer neuralNetwork;

    /** Source of training records; set via {@link #setRecordReader(RecordReader)}. */
    private RecordReader recordParser;

    /**
     * Trains the layer on every remaining record of the current reader.
     *
     * <p>Each record value is cast to {@link DataSet} and the layer is fit on its feature matrix.
     * A record that fails to load or fit is logged and skipped, and training continues. If the
     * reader itself fails to advance, or the thread is interrupted, training stops early; on
     * interruption the thread's interrupt flag is restored. In every case the parameters learned
     * so far are returned.
     *
     * @return the layer's current parameters
     */
    @Override
    public ParameterVectorUpdateable compute() {
        try {
            while (this.recordParser.nextKeyValue()) {
                fitCurrentRecord();
            }
        } catch (IOException e) {
            LOG.error("Failed to read next record; returning parameters trained so far", e);
        } catch (InterruptedException e) {
            // Restore the flag so callers further up the stack can observe the interruption.
            Thread.currentThread().interrupt();
            LOG.warn("Interrupted while reading records; returning parameters trained so far", e);
        }
        return getResults();
    }

    /**
     * Fits the layer on the reader's current record, logging and skipping the record on failure.
     *
     * @throws InterruptedException if reading the record is interrupted; it is propagated rather
     *     than swallowed so {@link #compute()} can stop and restore the interrupt flag
     */
    private void fitCurrentRecord() throws InterruptedException {
        try {
            DataSet batch = (DataSet) this.recordParser.getCurrentValue();
            this.neuralNetwork.fit(batch.getFeatureMatrix());
        } catch (InterruptedException e) {
            throw e;
        } catch (Exception e) {
            LOG.error("Failed to fit record; skipping it", e);
        }
    }

    /**
     * Sets the reader that {@link #compute()} will consume.
     *
     * @param recordReader source of training records; values are expected to be {@link DataSet}s
     */
    @Override
    public void setRecordReader(RecordReader recordReader) {
        this.recordParser = recordReader;
    }

    /**
     * Runs {@link #compute()}. The supplied list is ignored by this worker.
     *
     * @param list ignored
     * @return the layer's parameters after training
     */
    @Override
    public ParameterVectorUpdateable compute(List<ParameterVectorUpdateable> list) {
        return compute();
    }

    /**
     * Returns the layer's current parameters.
     *
     * @return a new updateable wrapping {@code neuralNetwork.params()}
     */
    @Override
    public ParameterVectorUpdateable getResults() {
        return new ParameterVectorUpdateable(this.neuralNetwork.params());
    }

    /**
     * Builds the layer from the JSON network configuration stored under
     * {@code org.deeplearning4j.scaleout.neuralnetconf}.
     *
     * @param configuration Hadoop configuration that must contain the network configuration key
     * @throws IllegalArgumentException if the configuration key is absent
     */
    @Override
    public void setup(Configuration configuration) {
        String json = configuration.get(NEURAL_NET_CONF_KEY);
        if (json == null) {
            throw new IllegalArgumentException("Missing required configuration key: " + NEURAL_NET_CONF_KEY);
        }
        NeuralNetConfiguration conf = NeuralNetConfiguration.fromJson(json);
        this.neuralNetwork = LayerFactories.getFactory(conf.getLayer()).create(conf);
    }

    /**
     * Replaces the layer's parameters with those carried by the given updateable.
     *
     * @param parameterVectorUpdateable parameters to install, e.g. as sent back by the master
     */
    @Override
    public void update(ParameterVectorUpdateable parameterVectorUpdateable) {
        this.neuralNetwork.setParams(parameterVectorUpdateable.get());
    }

    /**
     * No-op. This worker is configured through the Hadoop {@link #setup(Configuration)} overload.
     * NOTE(review): presumably present to satisfy {@link DeepLearningConfigurable} — confirm.
     *
     * @param configuration ignored
     */
    public void setup(org.canova.api.conf.Configuration configuration) {
    }
}
