package org.deeplearning4j.example.mnist;

import java.io.BufferedOutputStream;
import java.io.FileOutputStream;
import org.deeplearning4j.datasets.DataSet;
import org.deeplearning4j.datasets.iterator.impl.RawMnistDataSetIterator;
import org.deeplearning4j.dbn.GaussianRectifiedLinearDBN;
import org.deeplearning4j.eval.Evaluation;
import org.jblas.DoubleMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/example/mnist/RawDBNMnistExample.class */
/**
 * Example that trains a {@link GaussianRectifiedLinearDBN} on raw MNIST data: layer-wise
 * pretraining, then supervised fine-tuning, then serialization of the network to
 * {@code mnist-dbn.bin}, and finally logging of evaluation statistics.
 *
 * <p>Note: evaluation iterates over the same {@link RawMnistDataSetIterator} that was used for
 * training, so the reported scores are training-set scores, not held-out accuracy.
 */
public class RawDBNMnistExample {
    private static final Logger log = LoggerFactory.getLogger(RawDBNMnistExample.class);

    /** File the trained network is written to. */
    private static final String MODEL_FILE = "mnist-dbn.bin";

    /**
     * Trains, saves and evaluates the DBN.
     *
     * @param strArr command-line arguments (unused)
     * @throws Exception if training, serialization or evaluation fails
     */
    public static void main(String[] strArr) throws Exception {
        // NOTE(review): the meaning of (10, 40) is defined by RawMnistDataSetIterator;
        // presumably batch size and total number of examples -- confirm.
        RawMnistDataSetIterator iter = new RawMnistDataSetIterator(10, 40);
        GaussianRectifiedLinearDBN dbn = new GaussianRectifiedLinearDBN.Builder()
                .useAdaGrad(true)
                .useRegularization(false)
                .hiddenLayerSizes(new int[]{500, 400, 250})
                .normalizeByInputRows(true)
                .numberOfInputs(784)
                .numberOfOutPuts(10)
                .build();

        // Phase 1: unsupervised pretraining on normalized inputs.
        // NOTE(review): pretrain args (1, 1e-4, 10000) presumably are k / learning rate / epochs -- confirm.
        while (iter.hasNext()) {
            DataSet batch = (DataSet) iter.next();
            batch.normalizeZeroMeanZeroUnitVariance();
            dbn.pretrain((DoubleMatrix) batch.getFirst(), 1, 1.0E-4d, 10000);
        }

        // Phase 2: supervised fine-tuning against the labels, with the same normalization.
        iter.reset();
        while (iter.hasNext()) {
            DataSet batch = (DataSet) iter.next();
            batch.normalizeZeroMeanZeroUnitVariance();
            dbn.setInput((DoubleMatrix) batch.getFirst());
            dbn.finetune((DoubleMatrix) batch.getSecond(), 0.001d, 10000);
        }

        // Phase 3: persist the network; close the stream even if write/flush throws.
        BufferedOutputStream out = new BufferedOutputStream(new FileOutputStream(MODEL_FILE));
        try {
            dbn.write(out);
            out.flush();
        } finally {
            out.close();
        }
        log.info("Saved dbn");

        // Phase 4: evaluate. Inputs must be normalized exactly as during training,
        // otherwise the network is scored on a different input distribution.
        iter.reset();
        Evaluation evaluation = new Evaluation();
        while (iter.hasNext()) {
            DataSet batch = (DataSet) iter.next();
            batch.normalizeZeroMeanZeroUnitVariance();
            evaluation.eval((DoubleMatrix) batch.getSecond(), dbn.predict((DoubleMatrix) batch.getFirst()));
        }
        log.info("Prediction f scores and accuracy");
        log.info(evaluation.stats());
    }
}
