package org.deeplearning4j.nn.conf.layers.variational;

import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationHardSigmoid;
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldLessThan;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/conf/layers/variational/BernoulliReconstructionDistribution.class */
/**
 * Reconstruction distribution that models every output element as an independent Bernoulli variable.
 * <p>
 * The pre-output values are passed through the configured activation function to obtain the Bernoulli
 * probability {@code p}. The element-wise log probability of an observation {@code x} is
 * {@code x * log(p) + (1 - x) * log(1 - p)}.
 * <p>
 * Because {@code p} must lie in [0, 1], sigmoid or hard sigmoid activations are recommended. Any other
 * activation is accepted, but the constructor logs a warning.
 */
public class BernoulliReconstructionDistribution implements ReconstructionDistribution {
    private static final Logger log = LoggerFactory.getLogger(BernoulliReconstructionDistribution.class);

    /** Maps pre-output values to Bernoulli probabilities. */
    private final IActivation activationFn;

    /** Creates a Bernoulli reconstruction distribution using the sigmoid activation. */
    public BernoulliReconstructionDistribution() {
        this(Activation.SIGMOID);
    }

    /**
     * @param activation activation used to map pre-output values to Bernoulli probabilities
     */
    public BernoulliReconstructionDistribution(Activation activation) {
        this(activation.getActivationFunction());
    }

    /**
     * @param activationFn activation used to map pre-output values to Bernoulli probabilities. Activations
     *                     other than sigmoid / hard sigmoid are accepted, but a warning is logged because
     *                     their output is not guaranteed to stay within [0, 1].
     */
    public BernoulliReconstructionDistribution(IActivation activationFn) {
        this.activationFn = activationFn;
        if (!(activationFn instanceof ActivationSigmoid) && !(activationFn instanceof ActivationHardSigmoid)) {
            log.warn("Using BernoulliReconstructionDistribution with activation function \"{}\". "
                            + "Using sigmoid/hard sigmoid is recommended to bound probabilities in range 0 to 1",
                            activationFn);
        }
    }

    /** @return always {@code false}: this distribution computes its own probabilities and gradients */
    @Override
    public boolean hasLossFunction() {
        return false;
    }

    /**
     * One distribution parameter (the Bernoulli probability, before activation) per data element.
     *
     * @param dataSize number of data elements
     * @return {@code dataSize}
     */
    @Override
    public int distributionInputSize(int dataSize) {
        return dataSize;
    }

    /**
     * Negative log probability of {@code x}, summed over all elements.
     *
     * @param x                        observed values, cast to the parameters' data type before use
     * @param preOutDistributionParams pre-activation distribution parameters
     * @param average                  if true, divide the sum by {@code x.size(0)}
     *                                 (presumably the minibatch size — rows are assumed to be examples)
     * @return the (optionally averaged) negative log probability
     */
    @Override
    public double negLogProbability(INDArray x, INDArray preOutDistributionParams, boolean average) {
        double negLogProb = -calcLogProbArray(x, preOutDistributionParams).sumNumber().doubleValue();
        return average ? negLogProb / x.size(0) : negLogProb;
    }

    /**
     * Negative log probability summed along dimension 1, keeping that dimension.
     * This is one value per row (assumed to be one value per example).
     */
    @Override
    public INDArray exampleNegLogProbability(INDArray x, INDArray preOutDistributionParams) {
        return calcLogProbArray(x, preOutDistributionParams).sum(true, 1).negi();
    }

    /**
     * Element-wise Bernoulli log probability {@code x * log(p) + (1 - x) * log(1 - p)}, where
     * {@code p = activationFn(preOutDistributionParams)}. Neither input array is modified.
     */
    private INDArray calcLogProbArray(INDArray x, INDArray preOutDistributionParams) {
        // Run the arithmetic below in the parameters' data type
        INDArray labels = x.castTo(preOutDistributionParams.dataType());
        INDArray p = preOutDistributionParams.dup();
        activationFn.getActivation(p, false);

        // log(p) is taken on a copy because p is overwritten in place with (1 - p) on the next line
        INDArray logP = Transforms.log(p, true);
        INDArray log1MinusP = Transforms.log(p.rsubi(1.0), false);

        // p == 0 or p == 1 makes one of the logs -infinity, and 0 * -infinity would be NaN.
        // Infinite logs are therefore replaced with 0, so the corresponding term contributes 0.
        BooleanIndexing.replaceWhere(logP, 0.0, Conditions.isInfinite());
        BooleanIndexing.replaceWhere(log1MinusP, 0.0, Conditions.isInfinite());

        return logP.muli(labels).addi(labels.rsub(1.0).muli(log1MinusP));
    }

    /**
     * Gradient of the negative log probability with respect to the pre-activation parameters.
     * <p>
     * The derivative with respect to {@code p} is {@code (x - p) / (p * (1 - p))}. It is back-propagated
     * through the activation function and then negated.
     */
    @Override
    public INDArray gradient(INDArray x, INDArray preOutDistributionParams) {
        INDArray p = preOutDistributionParams.dup();
        activationFn.getActivation(p, true);

        INDArray dLogProbDp = x.castTo(preOutDistributionParams.dataType()).sub(p)
                        .divi(p.rsub(1.0).muli(p));
        INDArray grad = activationFn.backprop(preOutDistributionParams.dup(), dLogProbDp).getFirst();

        // When p is exactly 0 or 1 the division above can produce NaN after backprop (e.g. 0/0 or inf * 0).
        // Those entries are given a zero gradient.
        BooleanIndexing.replaceWhere(grad, 0.0, Conditions.isNan());
        return grad.negi();
    }

    /**
     * Draws one Bernoulli sample per element.
     * Each element is 1 where a uniform random value is less than {@code p}, and 0 otherwise.
     *
     * @return the samples as a FLOAT array, regardless of the input data type
     */
    @Override
    public INDArray generateRandom(INDArray preOutDistributionParams) {
        INDArray p = preOutDistributionParams.dup();
        activationFn.getActivation(p, false);

        // Uniform samples are drawn in p's own data type. Nd4j.rand(shape) would use the global default
        // floating-point type, which need not match p for the two-input comparison below.
        INDArray rand = Nd4j.rand(p.dataType(), p.shape());
        INDArray sample = Nd4j.createUninitialized(DataType.BOOL, p.shape());
        Nd4j.getExecutioner().execAndReturn((TransformOp) new OldLessThan(rand, p, sample));
        return sample.castTo(DataType.FLOAT);
    }

    /** Returns the distribution mean, which for a Bernoulli variable is the probability {@code p} itself. */
    @Override
    public INDArray generateAtMean(INDArray preOutDistributionParams) {
        INDArray p = preOutDistributionParams.dup();
        activationFn.getActivation(p, false);
        return p;
    }

    @Override
    public String toString() {
        return "BernoulliReconstructionDistribution(afn=" + activationFn + ")";
    }

    /** @return the activation mapping pre-output values to Bernoulli probabilities */
    public IActivation getActivationFn() {
        return activationFn;
    }

    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof BernoulliReconstructionDistribution)) {
            return false;
        }
        BernoulliReconstructionDistribution other = (BernoulliReconstructionDistribution) obj;
        if (!other.canEqual(this)) {
            return false;
        }
        IActivation mine = getActivationFn();
        IActivation theirs = other.getActivationFn();
        return mine == null ? theirs == null : mine.equals(theirs);
    }

    /** Symmetry hook for {@link #equals(Object)} with subclasses. */
    protected boolean canEqual(Object obj) {
        return obj instanceof BernoulliReconstructionDistribution;
    }

    @Override
    public int hashCode() {
        // Same value as before: 1 * 59 + hash(activationFn), where a null activation hashes to 43
        IActivation afn = getActivationFn();
        return 59 + (afn == null ? 43 : afn.hashCode());
    }
}
