package org.datavec.api.transform.ndarray;

import java.util.Objects;

import org.datavec.api.transform.MathFunction;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.transform.BaseColumnTransform;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.shade.jackson.annotation.JsonProperty;

/* loaded from: input_file:org/datavec/api/transform/ndarray/NDArrayMathFunctionTransform.class */
/**
 * Column transform that applies a {@link MathFunction} element-wise to the {@link INDArray} stored in an
 * {@link NDArrayWritable} column.
 *
 * <p>Every ND4J operation is invoked with {@code dup = true}, and the result is wrapped in a new
 * {@link NDArrayWritable}. Writables whose array is {@code null} are returned unchanged.
 * {@link MathFunction#COSH} and {@link MathFunction#SINH} are not supported and cause an
 * {@link UnsupportedOperationException} when a value is mapped.
 */
public class NDArrayMathFunctionTransform extends BaseColumnTransform {
    /**
     * The {@code dup} flag passed to every {@link Transforms} call. The input writable may be reused
     * elsewhere, so its array must not be modified in place.
     */
    private static final boolean DUP = true;

    /** The element-wise function applied to the column's array. */
    private final MathFunction mathFunction;

    /**
     * @param str          name of the NDArray column to transform
     * @param mathFunction function to apply element-wise to the column's array
     */
    public NDArrayMathFunctionTransform(@JsonProperty("columnName") String str,
                                        @JsonProperty("mathFunction") MathFunction mathFunction) {
        super(str);
        this.mathFunction = mathFunction;
    }

    /**
     * Returns a copy of the input column's metadata, renamed to {@code str}.
     *
     * @param str            name for the output column
     * @param columnMetaData metadata of the input column
     * @return the cloned and renamed metadata
     */
    @Override
    public ColumnMetaData getNewColumnMetaData(String str, ColumnMetaData columnMetaData) {
        ColumnMetaData copy = columnMetaData.mo6558clone();
        copy.setName(str);
        return copy;
    }

    /**
     * Applies {@link #mathFunction} to the array held by {@code writable}.
     *
     * @param writable the column value; expected to be an {@link NDArrayWritable}
     * @return a new writable holding the transformed array, or {@code writable} itself when it holds no array
     * @throws ClassCastException            if {@code writable} is not an {@link NDArrayWritable}
     * @throws UnsupportedOperationException for {@code COSH} and {@code SINH}
     */
    @Override
    public NDArrayWritable map(Writable writable) {
        NDArrayWritable input = (NDArrayWritable) writable;
        INDArray array = input.get();
        if (array == null) {
            // Nothing to transform: pass the empty writable through untouched.
            return input;
        }
        NDArrayWritable result = new NDArrayWritable(applyFunction(array));
        Nd4j.getExecutioner().commit();
        return result;
    }

    /**
     * Computes {@link #mathFunction} element-wise on {@code array}, calling each {@link Transforms}
     * operation with {@link #DUP}.
     *
     * @param array the non-null input array
     * @return the transformed array
     * @throws UnsupportedOperationException for {@code COSH} and {@code SINH}
     * @throws RuntimeException              for any function value not handled below
     */
    private INDArray applyFunction(INDArray array) {
        switch (mathFunction) {
            case ABS:
                return Transforms.abs(array, DUP);
            case ACOS:
                return Transforms.acos(array, DUP);
            case ASIN:
                return Transforms.asin(array, DUP);
            case ATAN:
                return Transforms.atan(array, DUP);
            case CEIL:
                return Transforms.ceil(array, DUP);
            case COS:
                return Transforms.cos(array, DUP);
            case COSH:
                throw new UnsupportedOperationException("cosh operation not yet supported for NDArray columns");
            case EXP:
                return Transforms.exp(array, DUP);
            case FLOOR:
                return Transforms.floor(array, DUP);
            case LOG:
                return Transforms.log(array, DUP);
            case LOG10:
                // Logarithm with an explicit base of 10.
                return Transforms.log(array, 10.0d, DUP);
            case SIGNUM:
                return Transforms.sign(array, DUP);
            case SIN:
                return Transforms.sin(array, DUP);
            case SINH:
                throw new UnsupportedOperationException("sinh operation not yet supported for NDArray columns");
            case SQRT:
                return Transforms.sqrt(array, DUP);
            case TAN:
                // tan(x) = sin(x) / cos(x). sin(...) returns a fresh copy, so dividing it in place is safe.
                return Transforms.sin(array, DUP).divi(Transforms.cos(array, DUP));
            case TANH:
                return Transforms.tanh(array, DUP);
            default:
                throw new RuntimeException("Unknown function: " + mathFunction);
        }
    }

    @Override
    public String toString() {
        return "NDArrayMathFunctionTransform(column=" + this.columnName + ",function=" + this.mathFunction + ")";
    }

    /**
     * Maps a single raw column value.
     *
     * @param obj an {@link NDArrayWritable} or a raw {@link INDArray}
     * @return an {@link NDArrayWritable} for writable input, or an {@link INDArray} for array input
     * @throws UnsupportedOperationException for any other input type, including {@code null}
     */
    @Override
    public Object map(Object obj) {
        if (obj instanceof NDArrayWritable) {
            return map((Writable) obj);
        }
        if (obj instanceof INDArray) {
            return map((Writable) new NDArrayWritable((INDArray) obj)).get();
        }
        throw new UnsupportedOperationException("Unknown object type: " + (obj == null ? null : obj.getClass()));
    }

    /** @return the function this transform applies */
    public MathFunction getMathFunction() {
        return this.mathFunction;
    }

    /**
     * Two transforms are equal when they target the same column and apply the same function.
     * Previously the column name was ignored, so transforms on different columns compared equal.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof NDArrayMathFunctionTransform)) {
            return false;
        }
        NDArrayMathFunctionTransform other = (NDArrayMathFunctionTransform) obj;
        if (!other.canEqual(this)) {
            return false;
        }
        return Objects.equals(this.columnName, other.columnName)
                && Objects.equals(getMathFunction(), other.getMathFunction());
    }

    @Override
    protected boolean canEqual(Object obj) {
        return obj instanceof NDArrayMathFunctionTransform;
    }

    /** Hash over the same fields as {@link #equals(Object)}: column name and function. */
    @Override
    public int hashCode() {
        return Objects.hash(this.columnName, getMathFunction());
    }
}
