package org.deeplearning4j.example.mnist;

import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.List;
import java.util.concurrent.TimeUnit;
import org.deeplearning4j.datasets.DataSet;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.RawMnistDataSetIterator;
import org.deeplearning4j.dbn.DBN;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.util.SerializationUtils;
import org.jblas.DoubleMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/example/mnist/RawMnistGradientDescent.class */
/**
 * Example driver that trains a {@link DBN} on raw MNIST data.
 *
 * <p>The run has three phases:
 * <ol>
 *   <li>layer-wise pretraining over mini-batches, writing a checkpoint after every batch;</li>
 *   <li>fine-tuning over the same batches;</li>
 *   <li>evaluating predictions against the labels of those same (training) batches.</li>
 * </ol>
 *
 * <p>Usage:
 * <ul>
 *   <li>fewer than two arguments: build a new network and train on all examples;</li>
 *   <li>{@code <model-file> <offset>}: continue training the network serialized in
 *       {@code model-file}, skipping the first {@code offset} training examples.</li>
 * </ul>
 */
public class RawMnistGradientDescent {
    private static final Logger log = LoggerFactory.getLogger(RawMnistGradientDescent.class);

    /** Mini-batch size used for every pass over the training data. */
    private static final int BATCH_SIZE = 10;

    /** Number of raw MNIST examples requested from the data set. */
    private static final int NUM_EXAMPLES = 60000;

    /**
     * Pretrains, checkpoints, fine-tunes and evaluates a DBN on raw MNIST.
     *
     * @param strArr empty or a single element to train a new network from scratch, or
     *     {@code [modelFile, offset]} to resume from a serialized network, skipping the first
     *     {@code offset} examples
     * @throws Exception if data loading, model deserialization or checkpoint writing fails
     * @throws IllegalArgumentException if {@code offset} is outside {@code [0, number of examples]}
     */
    public static void main(String[] strArr) throws Exception {
        // Resuming needs both the model file and the offset. A lone model-file argument is
        // ignored and a fresh network is trained instead.
        boolean resume = strArr.length >= 2;

        // Both branches produce a DataSetIterator; the original declared the narrower
        // RawMnistDataSetIterator type, which a ListDataSetIterator cannot be assigned to.
        DataSetIterator listDataSetIterator;
        if (!resume) {
            listDataSetIterator = new RawMnistDataSetIterator(BATCH_SIZE, NUM_EXAMPLES);
        } else {
            int offset = Integer.parseInt(strArr[1]);
            // Fetch everything as a single batch so the first `offset` examples can be dropped.
            List<DataSet> all =
                    ((DataSet) new RawMnistDataSetIterator(NUM_EXAMPLES, NUM_EXAMPLES).next()).asList();
            if (offset < 0 || offset > all.size()) {
                throw new IllegalArgumentException(
                        "offset must be in [0, " + all.size() + "] but was " + offset);
            }
            listDataSetIterator = new ListDataSetIterator(all.subList(offset, all.size()), BATCH_SIZE);
        }

        DBN dbn = resume
                ? (DBN) SerializationUtils.readObject(new File(strArr[0]))
                : (DBN) new DBN.Builder()
                        .useAdaGrad(true)
                        .hiddenLayerSizes(new int[] {500, 400, 250})
                        .renderWeights(100)
                        .numberOfInputs(784)
                        .numberOfOutPuts(10)
                        .useRegularization(true)
                        .build();

        // Phase 1: pretraining, with a numbered checkpoint written after each batch.
        int i = 0;
        while (listDataSetIterator.hasNext()) {
            DataSet dataSet = (DataSet) listDataSetIterator.next();
            long start = System.currentTimeMillis();
            // NOTE(review): assumes pretrain(input, k, learningRate, epochs); confirm against DBN.
            dbn.pretrain((DoubleMatrix) dataSet.getFirst(), 1, 1.0E-4d, 1000);
            log.info("Pretrain took {} seconds",
                    TimeUnit.MILLISECONDS.toSeconds(System.currentTimeMillis() - start));
            saveNetwork(dbn, "mnist-pretrain-dbn.bin-" + i + "-sgd");
            log.info("Saved dbn");
            i++;
        }
        saveNetwork(dbn, "mnist-dbn.bin");
        log.info("Saved dbn");

        // Phase 2: fine-tuning.
        listDataSetIterator.reset();
        while (listDataSetIterator.hasNext()) {
            DataSet batch = (DataSet) listDataSetIterator.next();
            // NOTE(review): only the batch labels are passed; confirm that finetune pairs them
            // with the matching batch inputs rather than the last pretrained input.
            dbn.finetune((DoubleMatrix) batch.getSecond(), 0.01d, 1000);
        }

        // Phase 3: evaluation on the same training batches.
        listDataSetIterator.reset();
        Evaluation evaluation = evaluate(dbn, listDataSetIterator);
        log.info("Prediction f scores and accuracy");
        log.info(evaluation.stats());
    }

    /**
     * Serializes {@code dbn} to {@code path}. The stream is closed even if writing fails.
     *
     * @param dbn network to write
     * @param path destination file path
     * @throws IOException if the file cannot be opened, written or closed
     */
    private static void saveNetwork(DBN dbn, String path) throws IOException {
        BufferedOutputStream out = new BufferedOutputStream(new FileOutputStream(path));
        try {
            dbn.write(out);
            out.flush();
        } finally {
            out.close();
        }
    }

    /**
     * Runs {@code dbn} over every remaining batch of {@code batches} and accumulates the
     * predictions against the batch labels.
     *
     * @param dbn trained network
     * @param batches iterator positioned at the first batch to evaluate
     * @return the accumulated evaluation
     */
    private static Evaluation evaluate(DBN dbn, DataSetIterator batches) {
        Evaluation evaluation = new Evaluation();
        while (batches.hasNext()) {
            DataSet batch = (DataSet) batches.next();
            evaluation.eval((DoubleMatrix) batch.getSecond(), dbn.predict((DoubleMatrix) batch.getFirst()));
        }
        return evaluation;
    }
}
