package org.deeplearning4j.example.iris;

import org.deeplearning4j.datasets.DataSet;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.rbm.GaussianRectifiedLinearRBM;
import org.jblas.DoubleMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/example/iris/IrisRBMExample.class */
/**
 * Example that trains a {@link GaussianRectifiedLinearRBM} on the Iris dataset and logs the
 * normalized input data next to the RBM's reconstruction of its training input.
 */
public class IrisRBMExample {
    private static final Logger log = LoggerFactory.getLogger(IrisRBMExample.class);

    /**
     * Loads the Iris data, normalizes it, trains the RBM until convergence and logs the input
     * matrix together with its reconstruction.
     *
     * @param strArr command-line arguments (unused)
     */
    public static void main(String[] strArr) {
        // NOTE(review): the constructor arguments are presumably (batchSize, numExamples); with both
        // set to 150 the single next() call below is expected to return the whole dataset — confirm.
        IrisDataSetIterator irisDataSetIterator = new IrisDataSetIterator(150, 150);
        DataSet dataSet = (DataSet) irisDataSetIterator.next();
        dataSet.normalizeZeroMeanZeroUnitVariance();
        log.info("Training on {}", dataSet.numExamples());

        GaussianRectifiedLinearRBM build = new GaussianRectifiedLinearRBM.Builder()
                .numberOfVisible(irisDataSetIterator.inputColumns())
                .useAdaGrad(true)
                .numHidden(10)
                .normalizeByInputRows(false)
                .useRegularization(false)
                .build();

        // NOTE(review): the trailing Object[] is forwarded as extra training parameters; what each
        // entry means is defined by trainTillConvergence — confirm there before changing values.
        build.trainTillConvergence(
                (DoubleMatrix) dataSet.getFirst(),
                0.001d,
                new Object[]{1, Double.valueOf(0.001d), 2000});

        log.info("\nData {}", toMultiLine(dataSet.getFirst()));
        log.info("\nReconstruct {}", toMultiLine(build.reconstruct(build.getInput())));
    }

    /**
     * Renders {@code value} prefixed with a newline, with every {@code ';'} replaced by a newline.
     * This is meant to print one matrix row per line. It assumes the matrix's
     * {@code toString()} separates rows with {@code ';'}; confirm against the matrix type.
     *
     * @param value the object to render; {@code null} renders as {@code "\nnull"}
     * @return the multi-line rendering
     */
    private static String toMultiLine(Object value) {
        // Literal replace: ';' needs no regex semantics, so String.replace is sufficient.
        return ("\n" + value).replace(";", "\n");
    }
}
