package org.deeplearning4j.example.mnist;

import java.io.File;
import org.deeplearning4j.datasets.DataSet;
import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator;
import org.deeplearning4j.dbn.CDBN;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.util.SerializationUtils;
import org.jblas.DoubleMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Evaluates a previously trained {@link CDBN} on the "2s and 4s" MNIST subset.
 *
 * <p>Usage: {@code Test2sAnd4s <path-to-serialized-CDBN>}. The evaluation data is read from
 * {@code twoandfours.bin} in the working directory; if that file does not exist it is created
 * first by invoking {@code Create2sAnd4sDataSet.main}.
 */
public class Test2sAnd4s {
    private static final Logger log = LoggerFactory.getLogger(Test2sAnd4s.class);

    /**
     * Loads the 2s/4s data set and a serialized CDBN, runs the model over every example and logs
     * running and final evaluation statistics.
     *
     * @param strArr command-line arguments; {@code strArr[0]} must be the path of a serialized
     *     {@link CDBN}
     * @throws IllegalArgumentException if no model path is supplied
     * @throws Exception propagated from data-set generation/loading or model deserialization
     */
    public static void main(String[] strArr) throws Exception {
        // Fail fast: without a model path the CDBN would be null and the first prediction
        // below would crash with a NullPointerException (after doing the data-set work).
        if (strArr == null || strArr.length < 1) {
            throw new IllegalArgumentException("Usage: Test2sAnd4s <path-to-serialized-CDBN>");
        }

        File file = new File("twoandfours.bin");
        if (!file.exists()) {
            // Generate the data set on first run so the example is self-contained.
            Create2sAnd4sDataSet.main(null);
        }
        ListDataSetIterator listDataSetIterator =
                new ListDataSetIterator(DataSet.load(file).asList());

        // NOTE(review): presumably Java-native deserialization of a user-supplied path —
        // only load model files from trusted sources; confirm SerializationUtils' mechanism.
        CDBN cdbn = (CDBN) SerializationUtils.readObject(new File(strArr[0]));

        Evaluation evaluation = new Evaluation();
        while (listDataSetIterator.hasNext()) {
            DataSet dataSet = (DataSet) listDataSetIterator.next();
            // getSecond() holds the expected labels, getFirst() the input features.
            evaluation.eval(
                    (DoubleMatrix) dataSet.getSecond(),
                    cdbn.predict((DoubleMatrix) dataSet.getFirst()));
            log.info("Current stats {}", evaluation.stats());
        }
        log.info("Prediction f scores and accuracy");
        log.info(evaluation.stats());
    }
}
