package org.deeplearning4j.nn.conf.dropout;

import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.RandomOp;
import org.nd4j.linalg.api.ops.random.impl.DropOutInverted;
import org.nd4j.linalg.factory.Broadcast;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.schedule.ISchedule;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonProperty;

/**
 * Spatial dropout: a single keep/drop decision is drawn for every (dimension 0, dimension 1) index
 * pair of the activations and broadcast over all remaining dimensions. Whole slices are therefore
 * kept or zeroed together, rather than individual elements.
 *
 * <p>Only rank 3, 4 and 5 activations are supported. NOTE(review): dimension 1 is assumed to be the
 * channel / feature-map dimension (channels-first layouts). Confirm this for the calling layers.
 *
 * <p>The retain probability is either the fixed value {@code p} or, when {@code pSchedule} is
 * non-null, the schedule's value for the current call. The mask produced by the last forward pass is
 * cached so that {@link #backprop} can reuse it. That cache is transient, ignored by JSON
 * serialization, and excluded from {@link #equals}/{@link #hashCode}.
 */
@JsonIgnoreProperties({"mask"})
public class SpatialDropout implements IDropout {
    /** Fixed retain probability; NaN when a schedule is used instead. */
    private double p;
    /** Optional schedule for the retain probability; takes precedence over {@code p} when non-null. */
    private ISchedule pSchedule;
    /** Mask from the most recent forward pass; consumed (set to null) by {@link #backprop}. */
    private transient INDArray mask;

    /**
     * Creates spatial dropout with a fixed retain probability.
     *
     * @param activationRetainProbability probability of retaining each slice; must be in (0, 1]
     * @throws IllegalArgumentException if the probability is negative, zero, NaN or greater than 1
     */
    public SpatialDropout(double activationRetainProbability) {
        this(activationRetainProbability, null);
        if (activationRetainProbability < 0.0d) {
            throw new IllegalArgumentException("Activation retain probability must be > 0. Got: " + activationRetainProbability);
        }
        if (activationRetainProbability == 0.0d) {
            throw new IllegalArgumentException("Invalid probability value: Dropout with 0.0 probability of retaining activations is not supported");
        }
        // NaN slips through both comparisons above; values above 1 are not valid probabilities.
        if (Double.isNaN(activationRetainProbability) || activationRetainProbability > 1.0d) {
            throw new IllegalArgumentException("Activation retain probability must be in range (0, 1]. Got: " + activationRetainProbability);
        }
    }

    /**
     * Creates spatial dropout whose retain probability is taken from a schedule.
     *
     * @param activationRetainProbabilitySchedule schedule queried on every forward pass
     */
    public SpatialDropout(ISchedule activationRetainProbabilitySchedule) {
        this(Double.NaN, activationRetainProbabilitySchedule);
    }

    /**
     * Raw constructor (also used for JSON deserialization); performs no validation.
     *
     * @param activationRetainProbability fixed retain probability (ignored when a schedule is given)
     * @param activationRetainProbabilitySchedule optional schedule, may be null
     */
    protected SpatialDropout(@JsonProperty("p") double activationRetainProbability,
                             @JsonProperty("pSchedule") ISchedule activationRetainProbabilitySchedule) {
        this.p = activationRetainProbability;
        this.pSchedule = activationRetainProbabilitySchedule;
    }

    /**
     * Applies spatial dropout to {@code inputActivations} and writes the result into {@code output}.
     * The generated mask is cached for {@link #backprop}.
     *
     * @param inputActivations activations of rank 3, 4 or 5
     * @param output array that receives the masked activations; returned to the caller
     * @param iteration first argument passed to {@code pSchedule.valueAt}
     * @param epoch second argument passed to {@code pSchedule.valueAt}
     * @param workspaceMgr manager used to allocate the mask (as an {@code ArrayType.INPUT} array)
     * @return {@code output}
     * @throws IllegalArgumentException if the activations are not rank 3, 4 or 5
     */
    @Override
    public INDArray applyDropout(INDArray inputActivations, INDArray output, int iteration, int epoch,
                                 LayerWorkspaceMgr workspaceMgr) {
        Preconditions.checkArgument(inputActivations.rank() == 5 || inputActivations.rank() == 4 || inputActivations.rank() == 3, "Cannot apply spatial dropout to activations of rank %s: spatial dropout can only be used for rank 3, 4 or 5 activations (input activations shape: %s)", Integer.valueOf(inputActivations.rank()), inputActivations.shape());

        double retainProbability = pSchedule != null ? pSchedule.valueAt(iteration, epoch) : p;

        // One mask entry per (dim 0, dim 1) pair. Broadcasting along dims 0 and 1 applies the same
        // factor to every element of the remaining dimensions.
        mask = workspaceMgr.createUninitialized(ArrayType.INPUT, inputActivations.size(0), inputActivations.size(1))
                .assign(1.0);
        // Applied in place on the all-ones mask. Presumably each entry becomes 0 or 1/retainProbability
        // (inverted dropout) — confirm against ND4J's DropOutInverted.
        Nd4j.getExecutioner().exec((RandomOp) new DropOutInverted(mask, retainProbability));
        Broadcast.mul(inputActivations, mask, output, 0, 1);
        return output;
    }

    /**
     * Applies the mask cached by the last {@link #applyDropout} call to the incoming gradient, then
     * clears the mask. Backprop can therefore run only once per forward pass.
     *
     * @param gradAtOutput gradient with respect to the dropout output
     * @param gradAtInput array that receives the masked gradient; returned to the caller
     * @param iteration unused here
     * @param epoch unused here
     * @return {@code gradAtInput}
     * @throws IllegalStateException if no mask is cached (no forward pass, or already cleared)
     */
    @Override
    public INDArray backprop(INDArray gradAtOutput, INDArray gradAtInput, int iteration, int epoch) {
        Preconditions.checkState(mask != null, "Cannot perform backprop: Dropout mask array is absent (already cleared?)");
        Broadcast.mul(gradAtOutput, mask, gradAtInput, 0, 1);
        mask = null;
        return gradAtInput;
    }

    /** Discards the cached mask. */
    @Override
    public void clear() {
        mask = null;
    }

    /**
     * Returns a new instance with the same probability and schedule. The schedule reference is
     * shared, not copied, and the cached mask is not copied.
     */
    @Override
    public IDropout clone() {
        return new SpatialDropout(p, pSchedule);
    }

    public double getP() {
        return this.p;
    }

    public ISchedule getPSchedule() {
        return this.pSchedule;
    }

    public INDArray getMask() {
        return this.mask;
    }

    public void setP(double p) {
        this.p = p;
    }

    public void setPSchedule(ISchedule pSchedule) {
        this.pSchedule = pSchedule;
    }

    public void setMask(INDArray mask) {
        this.mask = mask;
    }

    @Override
    public String toString() {
        return "SpatialDropout(p=" + getP() + ", pSchedule=" + getPSchedule() + ", mask=" + getMask() + ")";
    }

    /** Equality is based on {@code p} and {@code pSchedule} only; the transient mask is ignored. */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof SpatialDropout)) {
            return false;
        }
        SpatialDropout other = (SpatialDropout) obj;
        if (!other.canEqual(this) || Double.compare(getP(), other.getP()) != 0) {
            return false;
        }
        ISchedule mySchedule = getPSchedule();
        ISchedule otherSchedule = other.getPSchedule();
        return mySchedule == null ? otherSchedule == null : mySchedule.equals(otherSchedule);
    }

    /** Lets subclasses opt out of symmetric equality with this class. */
    protected boolean canEqual(Object obj) {
        return obj instanceof SpatialDropout;
    }

    /** Hash over {@code p} and {@code pSchedule}, consistent with {@link #equals}. */
    @Override
    public int hashCode() {
        long pBits = Double.doubleToLongBits(getP());
        int result = (1 * 59) + ((int) ((pBits >>> 32) ^ pBits));
        ISchedule schedule = getPSchedule();
        return (result * 59) + (schedule == null ? 43 : schedule.hashCode());
    }
}
