package org.deeplearning4j.example.mnist;

import org.deeplearning4j.datasets.DataSet;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.dbn.GaussianRectifiedLinearDBN;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.activation.Activations;
import org.jblas.DoubleMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/example/mnist/IrisExample.class */
/**
 * Example that trains a {@link GaussianRectifiedLinearDBN} on data obtained from
 * {@link IrisDataSetIterator} and logs the resulting {@link Evaluation} statistics.
 *
 * <p>The network is pretrained on the data set's first matrix, finetuned with its second matrix,
 * and then evaluated on the very same examples it was trained on. The reported statistics
 * therefore describe fit to the training data, not generalization to unseen data.
 */
public class IrisExample {
    private static final Logger log = LoggerFactory.getLogger(IrisExample.class);

    /** First IrisDataSetIterator constructor argument; presumably the batch size — TODO confirm. */
    private static final int BATCH_SIZE = 150;

    /** Second IrisDataSetIterator constructor argument; presumably the example count — TODO confirm. */
    private static final int NUM_EXAMPLES = 150;

    /** Second argument to {@code pretrain}; presumably the contrastive-divergence k — TODO confirm. */
    private static final int PRETRAIN_K = 1;

    /** Step-size argument shared by {@code pretrain} and {@code finetune}; presumably a learning rate. */
    private static final double LEARNING_RATE = 1.0E-4d;

    /** Iteration-count argument shared by {@code pretrain} and {@code finetune}; presumably epochs. */
    private static final int EPOCHS = 1000;

    /**
     * Loads and normalizes the Iris data, trains the network, and logs evaluation statistics.
     *
     * @param strArr command-line arguments (unused)
     */
    public static void main(String[] strArr) {
        DataSet dataSet = (DataSet) new IrisDataSetIterator(BATCH_SIZE, NUM_EXAMPLES).next();
        dataSet.normalizeZeroMeanZeroUnitVariance();
        log.info("Training on {}", dataSet.numExamples());

        // Fetched after normalization so any changes it makes to the data set are reflected here.
        DoubleMatrix features = (DoubleMatrix) dataSet.getFirst();
        DoubleMatrix labels = (DoubleMatrix) dataSet.getSecond();

        // Presumably 4 inputs / 3 outputs match Iris' 4 features and 3 classes — TODO confirm.
        GaussianRectifiedLinearDBN network = new GaussianRectifiedLinearDBN.Builder()
                .hiddenLayerSizes(new int[] {4, 2, 3})
                .normalizeByInputRows(true)
                .numberOfInputs(4)
                .numberOfOutPuts(3)
                .useAdaGrad(true)
                .useHiddenActivationsForwardProp(true)
                // NOTE(review): an L2 coefficient is configured, yet useRegularization(false)
                // follows; presumably the L2 value is then never applied — confirm intent.
                .withL2(0.01d)
                .useRegularization(false)
                .withActivation(Activations.tanh())
                .withMomentum(0.1d)
                .build();

        network.pretrain(features, PRETRAIN_K, LEARNING_RATE, EPOCHS);
        network.finetune(labels, LEARNING_RATE, EPOCHS);

        // Evaluated on the training examples themselves: stats reflect training fit only.
        Evaluation evaluation = new Evaluation();
        evaluation.eval(labels, network.predict(features));
        log.info(evaluation.stats());
    }
}
