package org.deeplearning4j.example.mnist;

import org.apache.commons.math3.random.MersenneTwister;
import org.deeplearning4j.da.DenoisingAutoEncoder;
import org.deeplearning4j.datasets.DataSet;
import org.deeplearning4j.datasets.iterator.impl.RawMnistDataSetIterator;
import org.deeplearning4j.datasets.mnist.draw.DrawMnistGreyScale;
import org.deeplearning4j.nn.NeuralNetwork;
import org.deeplearning4j.plot.FilterRenderer;
import org.deeplearning4j.util.MatrixUtil;
import org.jblas.DoubleMatrix;

/* loaded from: input_file:org/deeplearning4j/example/mnist/DenoisingAutoEncoderMnistExample.class */
/**
 * Example that trains a {@link DenoisingAutoEncoder} on raw MNIST data, renders the learned
 * weight matrix as filter images, and then shows each input digit next to a sampled
 * reconstruction in on-screen windows.
 */
public class DenoisingAutoEncoderMnistExample {

    /** Side length of the square images; the filter render below uses IMAGE_SIDE x IMAGE_SIDE. */
    private static final int IMAGE_SIDE = 28;

    /** Number of hidden units in the autoencoder. */
    private static final int NUM_HIDDEN = 500;

    /** Number of full passes over the iterator during training. */
    private static final int NUM_EPOCHS = 20;

    // Constructor arguments for RawMnistDataSetIterator; presumably (batch size, total number
    // of examples) — TODO confirm against RawMnistDataSetIterator.
    private static final int BATCH_SIZE = 10;
    private static final int NUM_EXAMPLES = 30;

    /** Second argument of trainTillConvergence; presumably the learning rate — confirm. */
    private static final double LEARNING_RATE = 0.1d;

    /** Seed for the generator used when sampling reconstructions, kept for reproducibility. */
    private static final int SAMPLING_SEED = 123;

    /** Factor applied to matrices before drawing them as greyscale images. */
    private static final double PIXEL_SCALE = 255.0d;

    /** How long each REAL/TEST pair stays on screen, in milliseconds. */
    private static final long DISPLAY_MILLIS = 10000L;

    /** Output path for the rendered weight filters. */
    private static final String FILTER_IMAGE = "example-render.jpg";

    /**
     * Runs the example: build, train, render filters, then visualize reconstructions.
     *
     * @param strArr command-line arguments (unused)
     * @throws Exception propagated from the training, rendering or drawing calls, or from an
     *     interrupted display wait
     */
    public static void main(String[] strArr) throws Exception {
        DenoisingAutoEncoder model = buildModel();
        RawMnistDataSetIterator iterator = new RawMnistDataSetIterator(BATCH_SIZE, NUM_EXAMPLES);

        train(model, iterator);
        new FilterRenderer().renderFilters(model.getW(), FILTER_IMAGE, IMAGE_SIDE, IMAGE_SIDE);

        // train() resets the iterator after every epoch, so this pass starts from the beginning.
        showReconstructions(model, iterator);
    }

    /** Builds the autoencoder with this example's fixed hyperparameters. */
    private static DenoisingAutoEncoder buildModel() throws Exception {
        return new DenoisingAutoEncoder.Builder()
                .numberOfVisible(IMAGE_SIDE * IMAGE_SIDE)
                .numHidden(NUM_HIDDEN)
                .normalizeByInputRows(true)
                .withLossFunction(NeuralNetwork.LossFunction.NEGATIVELOGLIKELIHOOD)
                .useAdaGrad(true)
                .useRegularization(true)
                .withSparsity(0.0d)
                .withL2(0.01d)
                .withOptmizationAlgo(NeuralNetwork.OptimizationAlgorithm.GRADIENT_DESCENT)
                .withMomentum(0.5d)
                .build();
    }

    /**
     * Trains the model for NUM_EPOCHS full passes over the iterator.
     *
     * <p>The iterator is reset at the end of every pass, so it is positioned at the start again
     * when this method returns.
     */
    private static void train(DenoisingAutoEncoder model, RawMnistDataSetIterator iterator)
            throws Exception {
        for (int epoch = 0; epoch < NUM_EPOCHS; epoch++) {
            while (iterator.hasNext()) {
                DoubleMatrix input = (DoubleMatrix) ((DataSet) iterator.next()).getFirst();
                // The meaning of the extra params (0.6, 0.1, 1000) is defined by
                // DenoisingAutoEncoder.trainTillConvergence — NOTE(review): confirm. A fresh array
                // is passed on every call, and 1000 stays an Integer, because the callee may
                // depend on either.
                model.trainTillConvergence(input, LEARNING_RATE, new Object[] {0.6d, 0.1d, 1000});
            }
            iterator.reset();
        }
    }

    /**
     * Shows every example from the iterator next to a sample drawn from the model's
     * reconstruction of it.
     */
    private static void showReconstructions(
            DenoisingAutoEncoder model, RawMnistDataSetIterator iterator) throws Exception {
        // One generator for the whole pass. Re-creating it with the same seed for every row would
        // feed each row the identical random stream, which correlates the samples.
        MersenneTwister rng = new MersenneTwister(SAMPLING_SEED);
        while (iterator.hasNext()) {
            DataSet batch = (DataSet) iterator.next();
            DoubleMatrix reconstruction = model.reconstruct((DoubleMatrix) batch.getFirst());
            for (int row = 0; row < batch.numExamples(); row++) {
                DoubleMatrix real = ((DoubleMatrix) batch.get(row).getFirst()).mul(PIXEL_SCALE);
                // Presumably draws 0/1 samples using the reconstruction row as probabilities —
                // confirm against MatrixUtil.binomial.
                DoubleMatrix sampled =
                        MatrixUtil.binomial(reconstruction.getRow(row), 1, rng).mul(PIXEL_SCALE);
                showPair(real, sampled);
            }
        }
    }

    /**
     * Draws the real and reconstructed images, waits DISPLAY_MILLIS, and closes both windows.
     *
     * <p>The windows are closed even if drawing or the wait fails, so no frames are left open.
     */
    private static void showPair(DoubleMatrix real, DoubleMatrix sampled) throws Exception {
        DrawMnistGreyScale realView = new DrawMnistGreyScale(real);
        realView.title = "REAL";
        DrawMnistGreyScale testView = null;
        try {
            realView.draw();
            // (1000, 1000) are passed through unchanged; presumably a display size — confirm.
            testView = new DrawMnistGreyScale(sampled, 1000, 1000);
            testView.title = "TEST";
            testView.draw();
            Thread.sleep(DISPLAY_MILLIS);
        } finally {
            disposeFrame(realView);
            if (testView != null) {
                disposeFrame(testView);
            }
        }
    }

    /** Disposes the view's frame if one was created. */
    private static void disposeFrame(DrawMnistGreyScale view) {
        if (view.frame != null) {
            view.frame.dispose();
        }
    }
}
