package org.deeplearning4j.example.lfw;

import java.io.BufferedOutputStream;
import java.io.FileOutputStream;
import org.deeplearning4j.datasets.DataSet;
import org.deeplearning4j.datasets.iterator.impl.LFWDataSetIterator;
import org.deeplearning4j.dbn.GaussianRectifiedLinearDBN;
import org.deeplearning4j.eval.Evaluation;
import org.jblas.DoubleMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/example/lfw/LFWExample.class */
public class LFWExample {
    private static final Logger log = LoggerFactory.getLogger(LFWExample.class);

    /** Learning rate passed to both {@code pretrain} and {@code finetune} (value from the original example). */
    private static final double LEARNING_RATE = 0.001d;

    /**
     * Iteration count passed as the last argument to both {@code pretrain} and {@code finetune}.
     * NOTE(review): presumably epochs/iterations per mini-batch — confirm against the DBN API.
     */
    private static final int EPOCHS = 10000;

    /**
     * File the trained network is written to.
     * NOTE(review): the "mnist" prefix looks copied from the MNIST example; kept unchanged so
     * anything already reading this file keeps working — consider renaming to "lfw-dbn.bin".
     */
    private static final String MODEL_PATH = "mnist-dbn.bin";

    /**
     * Trains a {@link GaussianRectifiedLinearDBN} on the LFW data set, serializes it to
     * {@link #MODEL_PATH}, then evaluates it on the same iterator and logs the statistics.
     *
     * @param strArr command-line arguments (unused)
     * @throws Exception if data loading, training, or writing the model fails
     */
    public static void main(String[] strArr) throws Exception {
        // NOTE(review): the meaning of the two constructor arguments (batch size / number of
        // examples?) is not visible here — confirm against LFWDataSetIterator.
        LFWDataSetIterator iterator = new LFWDataSetIterator(10, 10);
        GaussianRectifiedLinearDBN dbn = new GaussianRectifiedLinearDBN.Builder()
                .useAdaGrad(true)
                .useRegularization(false)
                .hiddenLayerSizes(new int[]{500, 400, 250})
                .normalizeByInputRows(true)
                .numberOfInputs(iterator.inputColumns())
                .numberOfOutPuts(iterator.totalOutcomes())
                .build();

        pretrain(dbn, iterator);
        iterator.reset();

        finetune(dbn, iterator);

        save(dbn, MODEL_PATH);
        log.info("Saved dbn");
        iterator.reset();

        Evaluation evaluation = evaluate(dbn, iterator);
        log.info("Prediction f scores and accuracy");
        log.info(evaluation.stats());
    }

    /** Runs unsupervised pretraining on the input matrix of every remaining batch in {@code iterator}. */
    private static void pretrain(GaussianRectifiedLinearDBN dbn, LFWDataSetIterator iterator) {
        while (iterator.hasNext()) {
            DataSet batch = (DataSet) iterator.next();
            // The literal 1 is passed through unchanged from the original example
            // (presumably the number of sampling steps k — TODO confirm).
            dbn.pretrain((DoubleMatrix) batch.getFirst(), 1, LEARNING_RATE, EPOCHS);
        }
    }

    /** Fine-tunes the network on every remaining batch: sets its input, then trains against the batch labels. */
    private static void finetune(GaussianRectifiedLinearDBN dbn, LFWDataSetIterator iterator) {
        while (iterator.hasNext()) {
            DataSet batch = (DataSet) iterator.next();
            dbn.setInput((DoubleMatrix) batch.getFirst());
            dbn.finetune((DoubleMatrix) batch.getSecond(), LEARNING_RATE, EPOCHS);
        }
    }

    /**
     * Serializes {@code dbn} to {@code path}. Try-with-resources guarantees the file is closed
     * even if {@code write} or {@code flush} throws (the original leaked the handle in that case).
     */
    private static void save(GaussianRectifiedLinearDBN dbn, String path) throws Exception {
        try (BufferedOutputStream out = new BufferedOutputStream(new FileOutputStream(path))) {
            dbn.write(out);
            out.flush();
        }
    }

    /** Compares the network's predictions to the labels of every remaining batch and returns the accumulated evaluation. */
    private static Evaluation evaluate(GaussianRectifiedLinearDBN dbn, LFWDataSetIterator iterator) {
        Evaluation evaluation = new Evaluation();
        while (iterator.hasNext()) {
            DataSet batch = (DataSet) iterator.next();
            evaluation.eval((DoubleMatrix) batch.getSecond(), dbn.predict((DoubleMatrix) batch.getFirst()));
        }
        return evaluation;
    }
}
