package org.datavec.api.transform.ndarray;

import org.datavec.api.transform.MathOp;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.metadata.NDArrayMetaData;
import org.datavec.api.transform.transform.BaseColumnTransform;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.shade.jackson.annotation.JsonProperty;

/* loaded from: input_file:org/datavec/api/transform/ndarray/NDArrayScalarOpTransform.class */
/**
 * Column transform that applies a scalar {@link MathOp} element-wise to the {@link INDArray} held in an
 * {@link NDArrayWritable} column.
 *
 * <p>The input array is copied with {@code dup()} before the operation is applied, and the copy is returned
 * in a new {@link NDArrayWritable}. {@link MathOp#Modulus} is not supported and causes an
 * {@link UnsupportedOperationException}.
 */
public class NDArrayScalarOpTransform extends BaseColumnTransform {
    /** Operation applied between each array element and {@link #scalar}. */
    private final MathOp mathOp;
    /** Scalar operand used by {@link #mathOp}. */
    private final double scalar;

    /**
     * @param str    name of the NDArray column to transform
     * @param mathOp operation to apply element-wise
     * @param d      scalar operand for the operation
     */
    public NDArrayScalarOpTransform(@JsonProperty("columnName") String str, @JsonProperty("mathOp") MathOp mathOp, @JsonProperty("scalar") double d) {
        super(str);
        this.mathOp = mathOp;
        this.scalar = d;
    }

    /**
     * Returns a copy of the input metadata renamed to {@code str}; the transform does not change array metadata.
     *
     * @throws IllegalStateException if {@code columnMetaData} is not an {@link NDArrayMetaData}
     */
    @Override // org.datavec.api.transform.transform.BaseColumnTransform
    public ColumnMetaData getNewColumnMetaData(String str, ColumnMetaData columnMetaData) {
        if (!(columnMetaData instanceof NDArrayMetaData)) {
            throw new IllegalStateException("Column " + str + " is not a NDArray column");
        }
        NDArrayMetaData copy = ((NDArrayMetaData) columnMetaData).mo6186clone();
        copy.setName(str);
        return copy;
    }

    /**
     * Applies {@link #mathOp} with {@link #scalar} to a copy of the writable's array.
     *
     * @param writable input value; must be an {@link NDArrayWritable}
     * @return a new writable wrapping the transformed copy
     * @throws IllegalArgumentException      if {@code writable} is null or not an {@link NDArrayWritable}
     * @throws UnsupportedOperationException for {@link MathOp#Modulus} or any op not handled below
     */
    @Override // org.datavec.api.transform.transform.BaseColumnTransform
    public NDArrayWritable map(Writable writable) {
        if (!(writable instanceof NDArrayWritable)) {
            // Build the message null-safely so a null input still yields the descriptive exception
            // rather than an NPE from writable.getClass().
            throw new IllegalArgumentException("Input writable is not an NDArrayWritable: is "
                            + (writable == null ? "null" : writable.getClass()));
        }
        // Work on a copy so the caller's array is left untouched; the i-suffixed ops below modify `dup`.
        INDArray dup = ((NDArrayWritable) writable).get().dup();
        switch (this.mathOp) {
            case Add:
                dup.addi(Double.valueOf(this.scalar));
                break;
            case Subtract:
                dup.subi(Double.valueOf(this.scalar));
                break;
            case Multiply:
                dup.muli(Double.valueOf(this.scalar));
                break;
            case Divide:
                dup.divi(Double.valueOf(this.scalar));
                break;
            case Modulus:
                throw new UnsupportedOperationException(this.mathOp + " is not supported for NDArrayWritable");
            case ReverseSubtract:
                dup.rsubi(Double.valueOf(this.scalar));
                break;
            case ReverseDivide:
                dup.rdivi(Double.valueOf(this.scalar));
                break;
            case ScalarMin:
                // dup=false: result presumably written into `dup` itself (return value is not used) — confirm.
                Transforms.min(dup, this.scalar, false);
                break;
            case ScalarMax:
                // dup=false: result presumably written into `dup` itself (return value is not used) — confirm.
                Transforms.max(dup, this.scalar, false);
                break;
            default:
                throw new UnsupportedOperationException("Unknown or not supported op: " + this.mathOp);
        }
        // NOTE(review): presumably flushes any ops the executioner has queued before the array is handed out.
        Nd4j.getExecutioner().commit();
        return new NDArrayWritable(dup);
    }

    @Override // org.datavec.api.transform.transform.BaseColumnTransform, org.datavec.api.transform.transform.BaseTransform
    public String toString() {
        return "NDArrayScalarOpTransform(mathOp=" + this.mathOp + ",scalar=" + this.scalar + ")";
    }

    /**
     * Applies the transform to a raw {@link INDArray} by wrapping it in an {@link NDArrayWritable}.
     *
     * @param obj an {@link INDArray}
     * @return the transformed {@link INDArray} (a copy; the input is not modified)
     * @throws RuntimeException if {@code obj} is null or not an {@link INDArray}
     */
    @Override // org.datavec.api.transform.Transform
    public Object map(Object obj) {
        if (obj instanceof INDArray) {
            return map((Writable) new NDArrayWritable((INDArray) obj)).get();
        }
        // Null-safe message: avoid an NPE from obj.getClass() masking the intended error.
        throw new RuntimeException("Unsupported class: " + (obj == null ? "null" : obj.getClass()));
    }

    /** @return the operation applied by this transform */
    public MathOp getMathOp() {
        return this.mathOp;
    }

    /** @return the scalar operand used by this transform */
    public double getScalar() {
        return this.scalar;
    }

    /**
     * Two transforms are equal when their {@link MathOp} and scalar match (scalar compared with
     * {@link Double#compare}).
     *
     * <p>NOTE(review): the column name is not compared here, so transforms on different columns with the same
     * op and scalar are equal — confirm this is intended.
     */
    @Override // org.datavec.api.transform.transform.BaseColumnTransform, org.datavec.api.transform.transform.BaseTransform
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof NDArrayScalarOpTransform)) {
            return false;
        }
        NDArrayScalarOpTransform other = (NDArrayScalarOpTransform) obj;
        if (!other.canEqual(this)) {
            return false;
        }
        // MathOp is an enum, so identity comparison is equivalent to equals() and handles null on either side.
        return getMathOp() == other.getMathOp()
                        && Double.compare(getScalar(), other.getScalar()) == 0;
    }

    @Override // org.datavec.api.transform.transform.BaseTransform
    protected boolean canEqual(Object obj) {
        return obj instanceof NDArrayScalarOpTransform;
    }

    /** Hash over the same fields as {@link #equals(Object)}; a null op contributes the constant 43. */
    @Override // org.datavec.api.transform.transform.BaseColumnTransform, org.datavec.api.transform.transform.BaseTransform
    public int hashCode() {
        MathOp op = getMathOp();
        int result = 59 + (op == null ? 43 : op.hashCode());
        long bits = Double.doubleToLongBits(getScalar());
        return result * 59 + (int) ((bits >>> 32) ^ bits);
    }
}
