package org.deeplearning4j.autoencoder;

import java.io.Serializable;
import org.deeplearning4j.nn.BaseMultiLayerNetwork;
import org.jblas.DoubleMatrix;

/* loaded from: input_file:org/deeplearning4j/autoencoder/DeepAutoEncoder.class */
/**
 * Pairs an encoder network with a decoder network derived from it.
 *
 * <p>The encoder is supplied by the caller; the decoder is created inside
 * {@link #train(DoubleMatrix, DoubleMatrix, double)} and does not exist before
 * that method has completed successfully. The same {@code trainingParams}
 * array is passed to both networks' {@code trainNetwork} calls.
 */
public class DeepAutoEncoder implements Serializable {
    private static final long serialVersionUID = -3571832097247806784L;

    /** Network supplied at construction time; used by {@link #encode(DoubleMatrix)}. */
    private BaseMultiLayerNetwork encoder;

    /** Built by {@code train}; {@code null} until a training run has completed. */
    private BaseMultiLayerNetwork decoder;

    /** Opaque parameters forwarded unchanged to {@code trainNetwork} on both networks. */
    private Object[] trainingParams;

    /**
     * Creates an auto-encoder around an existing network.
     *
     * @param baseMultiLayerNetwork network to use as the encoder
     * @param objArr                parameters forwarded to every {@code trainNetwork} call;
     *                              the array is stored by reference, not copied
     */
    public DeepAutoEncoder(BaseMultiLayerNetwork baseMultiLayerNetwork, Object[] objArr) {
        this.encoder = baseMultiLayerNetwork;
        this.trainingParams = objArr;
    }

    /**
     * Trains the encoder, then builds and trains a decoder from it.
     *
     * <p>The decoder is an empty network of the encoder's runtime class, configured via
     * {@code asDecoder(encoder)} and trained with the encoder's predictions as input and
     * the original {@code doubleMatrix} as the target.
     *
     * <p>The decoder field is only assigned once decoder training has finished. If any step
     * throws, this object is left without a decoder, so {@link #decode(DoubleMatrix)} reports
     * a clear error instead of using a partially configured network.
     *
     * @param doubleMatrix  input data passed to the encoder and used as the decoder's target
     * @param doubleMatrix2 second argument passed to the encoder's {@code trainNetwork}
     *                      (presumably labels -- confirm against BaseMultiLayerNetwork)
     * @param d             currently unused; kept so existing callers keep compiling
     */
    public void train(DoubleMatrix doubleMatrix, DoubleMatrix doubleMatrix2, double d) {
        // Invalidate any decoder from a previous run: it would no longer match the
        // encoder once the encoder has been retrained below.
        this.decoder = null;

        this.encoder.trainNetwork(doubleMatrix, doubleMatrix2, this.trainingParams);

        BaseMultiLayerNetwork freshDecoder =
                new BaseMultiLayerNetwork.Builder().withClazz(this.encoder.getClass()).buildEmpty();
        freshDecoder.asDecoder(this.encoder);
        freshDecoder.trainNetwork(this.encoder.predict(doubleMatrix), doubleMatrix, this.trainingParams);

        // Publish only after every step succeeded.
        this.decoder = freshDecoder;
    }

    /**
     * Runs the input through the encoder network.
     *
     * @param doubleMatrix data to encode
     * @return the result of the encoder's {@code predict} call
     */
    public DoubleMatrix encode(DoubleMatrix doubleMatrix) {
        return this.encoder.predict(doubleMatrix);
    }

    /**
     * Runs the input through the decoder network.
     *
     * @param doubleMatrix data to decode
     * @return the result of the decoder's {@code predict} call
     * @throws IllegalStateException if no training run has completed yet, so no decoder exists
     */
    public DoubleMatrix decode(DoubleMatrix doubleMatrix) {
        if (this.decoder == null) {
            throw new IllegalStateException(
                    "Decoder is not available: train(...) must complete successfully before decode(...)");
        }
        return this.decoder.predict(doubleMatrix);
    }
}
