package org.nd4j.linalg.learning.regularization;

import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.Axpy;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.schedule.FixedSchedule;
import org.nd4j.linalg.schedule.ISchedule;
import org.nd4j.shade.jackson.annotation.JsonProperty;

/* loaded from: input_file:org/nd4j/linalg/learning/regularization/WeightDecay.class */
/**
 * Weight-decay {@link Regularization}.
 *
 * <p>{@link #apply} adds {@code scale * param} into the update array. {@code scale} is the value
 * of the coefficient schedule for the given (iteration, epoch). It is further multiplied by the
 * learning rate when {@code applyLR} is {@code true}. The step runs after the updater (see
 * {@link #applyStep()}).
 *
 * <p>{@link #score} reports {@code coeff * 0.5 * ||param||_2^2}.
 */
public class WeightDecay implements Regularization {
    /**
     * Schedule that yields the decay coefficient for each (iteration, epoch).
     * The constructor rejects {@code null}.
     */
    protected final ISchedule coeff;
    /** When {@code true}, the scheduled coefficient is multiplied by the learning rate in {@link #apply}. */
    protected final boolean applyLR;

    /**
     * Creates a weight-decay regularizer with a constant coefficient.
     *
     * @param coeff   constant decay coefficient, wrapped in a {@link FixedSchedule}
     * @param applyLR whether to multiply the coefficient by the learning rate when applying
     */
    public WeightDecay(double coeff, boolean applyLR) {
        this(new FixedSchedule(coeff), applyLR);
    }

    /**
     * Creates a weight-decay regularizer whose coefficient follows a schedule.
     *
     * @param coeff   schedule supplying the coefficient; must not be {@code null}
     * @param applyLR whether to multiply the coefficient by the learning rate when applying
     * @throws NullPointerException if {@code coeff} is {@code null}
     */
    public WeightDecay(@NonNull @JsonProperty("coeff") ISchedule coeff, @JsonProperty("applyLR") boolean applyLR) {
        // Explicit guard mirroring the Lombok @NonNull contract.
        if (coeff == null) {
            throw new NullPointerException("coeff is marked @NonNull but is null");
        }
        this.coeff = coeff;
        this.applyLR = applyLR;
    }

    /**
     * Reports when this regularization is applied.
     *
     * @return always {@link Regularization.ApplyStep#POST_UPDATER}
     */
    @Override
    public Regularization.ApplyStep applyStep() {
        return Regularization.ApplyStep.POST_UPDATER;
    }

    /**
     * Adds the weight-decay term into {@code update}, in place.
     *
     * @param param     parameter array being regularized
     * @param update    update array; it is both an input and the output of the Axpy op
     * @param lr        learning rate; used only when {@code applyLR} is {@code true}
     * @param iteration first index passed to the coefficient schedule (presumably the iteration count)
     * @param epoch     second index passed to the coefficient schedule (presumably the epoch count)
     */
    @Override
    public void apply(INDArray param, INDArray update, double lr, int iteration, int epoch) {
        final double scheduled = this.coeff.valueAt(iteration, epoch);
        final double scale = this.applyLR ? scheduled * lr : scheduled;
        // BLAS-style axpy with 'update' passed as both y and the output: update = scale * param + update.
        Nd4j.exec(new Axpy(param, update, update, scale));
    }

    /**
     * Computes this regularizer's contribution to the score.
     *
     * @param param     parameter array being regularized
     * @param iteration first index passed to the coefficient schedule
     * @param epoch     second index passed to the coefficient schedule
     * @return {@code coeff(iteration, epoch) * 0.5 * ||param||_2^2}
     */
    @Override
    public double score(INDArray param, int iteration, int epoch) {
        final double l2Norm = param.norm2Number().doubleValue();
        // Kept as a left-to-right product so the floating-point result matches exactly.
        return this.coeff.valueAt(iteration, epoch) * 0.5d * l2Norm * l2Norm;
    }

    /**
     * Returns a copy of this regularizer with a cloned coefficient schedule.
     *
     * <p>The method name is a decompiler rename of {@code clone} and overrides the matching
     * {@link Regularization} method, so it must stay as-is.
     *
     * @return a new {@code WeightDecay} with a cloned schedule and the same {@code applyLR} flag
     */
    @Override
    public Regularization m8047clone() {
        final ISchedule copiedSchedule = this.coeff.m8065clone();
        return new WeightDecay(copiedSchedule, this.applyLR);
    }

    /** @return the coefficient schedule */
    public ISchedule getCoeff() {
        return this.coeff;
    }

    /** @return whether the coefficient is multiplied by the learning rate in {@link #apply} */
    public boolean isApplyLR() {
        return this.applyLR;
    }

    /**
     * Equality based on the coefficient schedule and the {@code applyLR} flag.
     * Uses the {@link #canEqual} idiom so that subclasses can opt out of equality.
     */
    @Override
    public boolean equals(Object other) {
        if (other == this) {
            return true;
        }
        if (!(other instanceof WeightDecay)) {
            return false;
        }
        final WeightDecay that = (WeightDecay) other;
        if (!that.canEqual(this)) {
            return false;
        }
        final ISchedule mine = getCoeff();
        final ISchedule theirs = that.getCoeff();
        final boolean sameCoeff = (mine == null) ? (theirs == null) : mine.equals(theirs);
        return sameCoeff && isApplyLR() == that.isApplyLR();
    }

    /**
     * Tells whether {@code other} may be considered equal to this instance.
     *
     * @param other object to test
     * @return {@code true} if {@code other} is a {@code WeightDecay}
     */
    protected boolean canEqual(Object other) {
        return other instanceof WeightDecay;
    }

    /** Hash code consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        final ISchedule schedule = getCoeff();
        result = result * prime + (schedule == null ? 43 : schedule.hashCode());
        result = result * prime + (isApplyLR() ? 79 : 97);
        return result;
    }

    /** @return a string of the form {@code WeightDecay(coeff=..., applyLR=...)} */
    @Override
    public String toString() {
        return new StringBuilder("WeightDecay(coeff=")
                .append(getCoeff())
                .append(", applyLR=")
                .append(isApplyLR())
                .append(")")
                .toString();
    }
}
