package org.deeplearning4j.example.mnist;

import java.io.File;
import java.util.ArrayList;
import org.deeplearning4j.datasets.DataSet;
import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator;
import org.deeplearning4j.dbn.CDBN;
import org.deeplearning4j.gradient.multilayer.WeightPlotListener;
import org.deeplearning4j.iterativereduce.actor.multilayer.ActorNetworkRunner;
import org.deeplearning4j.scaleout.conf.Conf;
import org.deeplearning4j.util.SerializationUtils;

/**
 * Example entry point that trains a {@link CDBN} on the "2s and 4s" data set using an
 * {@link ActorNetworkRunner}.
 *
 * <p>Training data is loaded from {@code twoandfours.bin} in the working directory. If that
 * file is missing, {@link Create2sAnd4sDataSet} is run first, and it is expected to create
 * the file. An optional first command-line argument names a serialized {@link CDBN} that is
 * handed to the runner instead of starting from a fresh network.
 */
public class Classify2sAnd4s {

    /** Data set file consumed by this example, relative to the working directory. */
    private static final String DATA_FILE = "twoandfours.bin";

    /**
     * Builds the training configuration and runs distributed training.
     *
     * @param args optional; {@code args[0]} is the path of a serialized {@link CDBN} to use
     * @throws IllegalStateException if the data set file is still missing after the generator ran
     * @throws Exception propagated from data loading, model deserialization, or training
     */
    public static void main(String[] args) throws Exception {
        File dataFile = new File(DATA_FILE);
        if (!dataFile.exists()) {
            // Pass an empty array rather than null so the generator can safely inspect its args.
            Create2sAnd4sDataSet.main(new String[0]);
            // Fail fast with a clear message instead of an opaque error from DataSet.load.
            if (!dataFile.exists()) {
                throw new IllegalStateException(
                        "Create2sAnd4sDataSet did not produce " + dataFile.getAbsolutePath());
            }
        }
        DataSet data = DataSet.load(dataFile);
        ListDataSetIterator iterator = new ListDataSetIterator(data.asList());

        // Empty list: no gradient listeners are registered for this run.
        // NOTE(review): raw type kept because the element type expected by
        // Conf#setMultiLayerGradientListeners is not visible here; parameterize once confirmed.
        ArrayList gradientListeners = new ArrayList();

        Conf conf = new Conf();
        conf.initFromData(data);
        conf.setFinetuneEpochs(10000);
        conf.setFinetuneLearningRate(0.1d);
        conf.setLayerSizes(new int[] {500, 400, 250});
        conf.setUseAdaGrad(true);
        conf.setSplit(10);
        conf.setNumPasses(100);
        conf.setMultiLayerClazz(CDBN.class);
        conf.setUseRegularization(false);
        // NOTE(review): positional parameters whose meaning is defined by Conf / the
        // multi-layer implementation; confirm their order before changing values.
        conf.setDeepLearningParams(new Object[] {1, 0.1d, 10000});
        conf.setMultiLayerGradientListeners(gradientListeners);

        CDBN pretrained = null;
        if (args != null && args.length >= 1) {
            // Java-native deserialization: only point this at model files you trust.
            pretrained = (CDBN) SerializationUtils.readObject(new File(args[0]));
        }

        ActorNetworkRunner runner;
        if (pretrained == null) {
            runner = new ActorNetworkRunner("master", iterator);
        } else {
            runner = new ActorNetworkRunner("master", iterator, pretrained);
        }
        runner.setup(conf);
        runner.train();
    }
}
