package org.deeplearning4j.nn.conf.layers.variational;

import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

/* loaded from: input_file:org/deeplearning4j/nn/conf/layers/variational/ExponentialReconstructionDistribution.class */
/**
 * Exponential reconstruction distribution for a variational autoencoder.
 *
 * <p>For each output element the distribution is {@code p(x) = lambda * exp(-lambda * x)}, where
 * {@code lambda = exp(gamma)} and {@code gamma = activationFn(preOut)}. Parameterizing through
 * {@code gamma = log(lambda)} keeps {@code lambda} strictly positive for any pre-output value. The
 * log-density is therefore {@code log p(x) = gamma - exp(gamma) * x}.
 *
 * <p>The distribution needs one parameter per data element, so {@link #distributionInputSize(int)}
 * returns its argument unchanged. NOTE(review): the exponential distribution has support
 * {@code x >= 0}; the input data is presumably non-negative, but this is not checked here.
 */
public class ExponentialReconstructionDistribution implements ReconstructionDistribution {

    /** Activation mapping the raw pre-output parameters to {@code gamma = log(lambda)}. */
    private final IActivation activationFn;

    /** Creates a distribution that uses the identity activation, i.e. {@code gamma = preOut}. */
    public ExponentialReconstructionDistribution() {
        // Use the Activation enum directly rather than a string constant borrowed from the Keras
        // model-import layer, so this configuration class does not depend on model import.
        this(Activation.IDENTITY);
    }

    /**
     * Creates a distribution whose activation is looked up by name.
     *
     * @param activationFn name of the activation, as accepted by {@code Activation.fromString}
     * @deprecated use {@link #ExponentialReconstructionDistribution(Activation)} or
     *     {@link #ExponentialReconstructionDistribution(IActivation)} instead
     */
    @Deprecated
    public ExponentialReconstructionDistribution(String activationFn) {
        this(Activation.fromString(activationFn).getActivationFunction());
    }

    /**
     * Creates a distribution using the given activation enum.
     *
     * @param activation activation used to compute {@code gamma} from the pre-output parameters
     */
    public ExponentialReconstructionDistribution(Activation activation) {
        this(activation.getActivationFunction());
    }

    /**
     * Creates a distribution using the given activation function instance.
     *
     * @param activationFn activation used to compute {@code gamma} from the pre-output parameters
     */
    public ExponentialReconstructionDistribution(IActivation activationFn) {
        this.activationFn = activationFn;
    }

    /** This distribution is not backed by a loss function; it always returns {@code false}. */
    @Override
    public boolean hasLossFunction() {
        return false;
    }

    /**
     * Returns the number of distribution parameters needed for {@code dataSize} data values. This
     * is one parameter ({@code gamma}) per value.
     */
    @Override
    public int distributionInputSize(int dataSize) {
        return dataSize;
    }

    /**
     * Computes the negative log-likelihood of {@code x}, summed over all elements.
     *
     * @param x data values
     * @param preOutDistributionParams pre-activation distribution parameters; not modified
     * @param average if {@code true}, divide the sum by {@code x.size(0)}, the number of rows
     *     (presumably examples)
     * @return {@code -sum(gamma - exp(gamma) * x)}, optionally averaged
     */
    @Override
    public double negLogProbability(INDArray x, INDArray preOutDistributionParams, boolean average) {
        INDArray gamma = activationFn.getActivation(preOutDistributionParams.dup(), false);
        INDArray lambda = Transforms.exp(gamma, true);

        // log p(x) = gamma - lambda * x; negate the total to obtain the negative log-likelihood.
        double negLogProbSum = -lambda.muli(x).rsubi(gamma).sumNumber().doubleValue();
        return average ? negLogProbSum / x.size(0) : negLogProbSum;
    }

    /**
     * Computes the negative log-likelihood for each row of {@code x}.
     *
     * @param x data values
     * @param preOutDistributionParams pre-activation distribution parameters; not modified
     * @return {@code -(gamma - exp(gamma) * x)} summed along dimension 1, i.e. one value per row
     */
    @Override
    public INDArray exampleNegLogProbability(INDArray x, INDArray preOutDistributionParams) {
        INDArray gamma = activationFn.getActivation(preOutDistributionParams.dup(), false);
        INDArray lambda = Transforms.exp(gamma, true);
        return lambda.muli(x).rsubi(gamma).sum(1).negi();
    }

    /**
     * Computes the gradient of the negative log-likelihood with respect to the pre-output
     * parameters.
     *
     * <p>With {@code log p(x) = gamma - exp(gamma) * x}, the derivative is
     * {@code d(-log p)/dgamma = exp(gamma) * x - 1}. This is then back-propagated through the
     * activation function.
     *
     * @param x data values
     * @param preOutDistributionParams pre-activation distribution parameters; not modified
     * @return gradient with respect to {@code preOutDistributionParams}
     */
    @Override
    public INDArray gradient(INDArray x, INDArray preOutDistributionParams) {
        INDArray gamma = activationFn.getActivation(preOutDistributionParams.dup(), true);
        INDArray lambda = Transforms.exp(gamma, true);

        INDArray dNegLogPdGamma = x.mul(lambda).subi(1.0);
        return activationFn.backprop(preOutDistributionParams.dup(), dNegLogPdGamma).getFirst();
    }

    /**
     * Draws one random sample per element from the distribution by inverse-transform sampling.
     *
     * @param preOutDistributionParams pre-activation distribution parameters; not modified
     * @return samples {@code -log(1 - u) / lambda} with {@code u} drawn by {@code Nd4j.rand}
     */
    @Override
    public INDArray generateRandom(INDArray preOutDistributionParams) {
        // Draw the uniform variates first, preserving the original RNG consumption order.
        INDArray u = Nd4j.rand(preOutDistributionParams.shape());

        INDArray gamma = activationFn.getActivation(preOutDistributionParams.dup(), false);
        INDArray lambda = Transforms.exp(gamma, true);

        // Inverse CDF of Exp(lambda): x = -log(1 - u) / lambda. If u ~ U[0, 1) then 1 - u lies in
        // (0, 1], so the log stays finite even when u is exactly 0 (log(u) would be -infinity).
        return Transforms.log(u.rsubi(1.0), false).divi(lambda).negi();
    }

    /**
     * Returns the distribution mean for each element.
     *
     * @param preOutDistributionParams pre-activation distribution parameters; not modified
     * @return {@code 1 / lambda = 1 / exp(gamma)}, the mean of an exponential distribution
     */
    @Override
    public INDArray generateAtMean(INDArray preOutDistributionParams) {
        INDArray gamma = activationFn.getActivation(preOutDistributionParams.dup(), false);
        return Transforms.exp(gamma, true).rdivi(1.0);
    }

    @Override
    public String toString() {
        return "ExponentialReconstructionDistribution(afn=" + this.activationFn + ")";
    }

    /** Returns the activation used to compute {@code gamma}; may be {@code null}. */
    public IActivation getActivationFn() {
        return this.activationFn;
    }

    /** Two instances are equal when their activation functions are equal (both may be null). */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof ExponentialReconstructionDistribution)) {
            return false;
        }
        ExponentialReconstructionDistribution other = (ExponentialReconstructionDistribution) obj;
        // canEqual keeps equals symmetric if a subclass narrows equality.
        if (!other.canEqual(this)) {
            return false;
        }
        IActivation mine = getActivationFn();
        IActivation theirs = other.getActivationFn();
        return mine == null ? theirs == null : mine.equals(theirs);
    }

    /** Returns whether {@code obj} may be considered equal to instances of this class. */
    protected boolean canEqual(Object obj) {
        return obj instanceof ExponentialReconstructionDistribution;
    }

    /** Hash consistent with {@link #equals(Object)}; a null activation contributes 43. */
    @Override
    public int hashCode() {
        IActivation afn = getActivationFn();
        int result = 1;
        result = result * 59 + (afn == null ? 43 : afn.hashCode());
        return result;
    }
}
