package org.deeplearning4j.nn;

import java.io.Serializable;
import java.util.Objects;

import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.distribution.RealDistribution;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.nn.activation.ActivationFunction;
import org.deeplearning4j.nn.activation.Sigmoid;
import org.deeplearning4j.util.MatrixUtil;
import org.jblas.DoubleMatrix;

/* loaded from: input_file:org/deeplearning4j/nn/HiddenLayer.class */
public class HiddenLayer implements Serializable {
    private static final long serialVersionUID = 915783367350830495L;
    private int nIn;
    private int nOut;
    private DoubleMatrix W;
    private DoubleMatrix b;
    private RandomGenerator rng;
    private DoubleMatrix input;
    private ActivationFunction activationFunction;
    private RealDistribution dist;

    /* loaded from: input_file:org/deeplearning4j/nn/HiddenLayer$Builder.class */
    public static class Builder {
        private int nIn;
        private int nOut;
        private DoubleMatrix W;
        private DoubleMatrix b;
        private RandomGenerator rng;
        private DoubleMatrix input;
        private ActivationFunction activationFunction = new Sigmoid();
        private RealDistribution dist;

        public Builder dist(RealDistribution realDistribution) {
            this.dist = realDistribution;
            return this;
        }

        public Builder nIn(int i) {
            this.nIn = i;
            return this;
        }

        public Builder nOut(int i) {
            this.nOut = i;
            return this;
        }

        public Builder withWeights(DoubleMatrix doubleMatrix) {
            this.W = doubleMatrix;
            return this;
        }

        public Builder withRng(RandomGenerator randomGenerator) {
            this.rng = randomGenerator;
            return this;
        }

        public Builder withActivation(ActivationFunction activationFunction) {
            this.activationFunction = activationFunction;
            return this;
        }

        public Builder withBias(DoubleMatrix doubleMatrix) {
            this.b = doubleMatrix;
            return this;
        }

        public Builder withInput(DoubleMatrix doubleMatrix) {
            this.input = doubleMatrix;
            return this;
        }

        public HiddenLayer build() {
            HiddenLayer hiddenLayer = new HiddenLayer(this.nIn, this.nOut, this.W, this.b, this.rng, this.input);
            hiddenLayer.activationFunction = this.activationFunction;
            hiddenLayer.dist = this.dist;
            return hiddenLayer;
        }
    }

    private HiddenLayer() {
        this.activationFunction = new Sigmoid();
    }

    public HiddenLayer(int i, int i2, DoubleMatrix doubleMatrix, DoubleMatrix doubleMatrix2, RandomGenerator randomGenerator, DoubleMatrix doubleMatrix3, ActivationFunction activationFunction) {
        this(i, i2, doubleMatrix, doubleMatrix2, randomGenerator, doubleMatrix3, activationFunction, null);
    }

    public HiddenLayer(int i, int i2, DoubleMatrix doubleMatrix, DoubleMatrix doubleMatrix2, RandomGenerator randomGenerator, DoubleMatrix doubleMatrix3) {
        this(i, i2, doubleMatrix, doubleMatrix2, randomGenerator, doubleMatrix3, null, null);
    }

    public HiddenLayer(int i, int i2, DoubleMatrix doubleMatrix, DoubleMatrix doubleMatrix2, RandomGenerator randomGenerator, DoubleMatrix doubleMatrix3, ActivationFunction activationFunction, RealDistribution realDistribution) {
        this.activationFunction = new Sigmoid();
        this.nIn = i;
        this.nOut = i2;
        this.input = doubleMatrix3;
        if (activationFunction != null) {
            this.activationFunction = activationFunction;
        }
        if (randomGenerator == null) {
            this.rng = new MersenneTwister(1234);
        } else {
            this.rng = randomGenerator;
        }
        if (realDistribution == null) {
            this.dist = new NormalDistribution(this.rng, 0.0d, 0.01d, 1.0E-9d);
        } else {
            this.dist = realDistribution;
        }
        if (doubleMatrix == null) {
            this.W = DoubleMatrix.zeros(i, i2);
            for (int i3 = 0; i3 < this.W.rows; i3++) {
                this.W.putRow(i3, new DoubleMatrix(this.dist.sample(this.W.columns)));
            }
        } else {
            this.W = doubleMatrix;
        }
        if (doubleMatrix2 == null) {
            this.b = DoubleMatrix.zeros(i2);
        } else {
            this.b = doubleMatrix2;
        }
    }

    public HiddenLayer(int i, int i2, DoubleMatrix doubleMatrix, DoubleMatrix doubleMatrix2, RandomGenerator randomGenerator, DoubleMatrix doubleMatrix3, RealDistribution realDistribution) {
        this.activationFunction = new Sigmoid();
        this.nIn = i;
        this.nOut = i2;
        this.input = doubleMatrix3;
        if (randomGenerator == null) {
            this.rng = new MersenneTwister(1234);
        } else {
            this.rng = randomGenerator;
        }
        if (realDistribution == null) {
            this.dist = new NormalDistribution(this.rng, 0.0d, 0.01d, 1.0E-9d);
        } else {
            this.dist = realDistribution;
        }
        if (doubleMatrix == null) {
            this.W = DoubleMatrix.zeros(i, i2);
            for (int i3 = 0; i3 < this.W.rows; i3++) {
                this.W.putRow(i3, new DoubleMatrix(this.dist.sample(this.W.columns)));
            }
        } else {
            this.W = doubleMatrix;
        }
        if (doubleMatrix2 == null) {
            this.b = DoubleMatrix.zeros(i2);
        } else {
            this.b = doubleMatrix2;
        }
    }

    public synchronized int getnIn() {
        return this.nIn;
    }

    public synchronized void setnIn(int i) {
        this.nIn = i;
    }

    public synchronized int getnOut() {
        return this.nOut;
    }

    public synchronized void setnOut(int i) {
        this.nOut = i;
    }

    public synchronized DoubleMatrix getW() {
        return this.W;
    }

    public synchronized void setW(DoubleMatrix doubleMatrix) {
        this.W = doubleMatrix;
    }

    public synchronized DoubleMatrix getB() {
        return this.b;
    }

    public synchronized void setB(DoubleMatrix doubleMatrix) {
        this.b = doubleMatrix;
    }

    public synchronized RandomGenerator getRng() {
        return this.rng;
    }

    public synchronized void setRng(RandomGenerator randomGenerator) {
        this.rng = randomGenerator;
    }

    public synchronized DoubleMatrix getInput() {
        return this.input;
    }

    public synchronized void setInput(DoubleMatrix doubleMatrix) {
        this.input = doubleMatrix;
    }

    public synchronized ActivationFunction getActivationFunction() {
        return this.activationFunction;
    }

    public synchronized void setActivationFunction(ActivationFunction activationFunction) {
        this.activationFunction = activationFunction;
    }

    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public HiddenLayer m14clone() {
        HiddenLayer hiddenLayer = new HiddenLayer();
        hiddenLayer.b = this.b.dup();
        hiddenLayer.W = this.W.dup();
        if (this.input != null) {
            hiddenLayer.input = this.input.dup();
        }
        if (this.dist != null) {
            hiddenLayer.dist = this.dist;
        }
        hiddenLayer.activationFunction = this.activationFunction;
        hiddenLayer.nOut = this.nOut;
        hiddenLayer.nIn = this.nIn;
        hiddenLayer.rng = this.rng;
        return hiddenLayer;
    }

    public HiddenLayer transpose() {
        HiddenLayer hiddenLayer = new HiddenLayer();
        hiddenLayer.b = this.b.dup();
        hiddenLayer.W = this.W.transpose();
        if (this.input != null) {
            hiddenLayer.input = this.input.transpose();
        }
        if (this.dist != null) {
            hiddenLayer.dist = this.dist;
        }
        hiddenLayer.activationFunction = this.activationFunction;
        hiddenLayer.nOut = this.nIn;
        hiddenLayer.nIn = this.nOut;
        hiddenLayer.rng = this.rng;
        return hiddenLayer;
    }

    public synchronized DoubleMatrix activate() {
        return (DoubleMatrix) getActivationFunction().apply(getInput().mmul(getW()).addRowVector(getB()));
    }

    public synchronized DoubleMatrix activate(DoubleMatrix doubleMatrix) {
        if (doubleMatrix != null) {
            this.input = doubleMatrix;
        }
        return activate();
    }

    public DoubleMatrix sampleHGivenV(DoubleMatrix doubleMatrix) {
        this.input = doubleMatrix;
        return MatrixUtil.binomial(activate(), 1, this.rng);
    }

    public DoubleMatrix sample_h_given_v() {
        return MatrixUtil.binomial(activate(), 1, this.rng);
    }
}
