package org.nd4j.linalg.api.ops.impl.transforms;

import java.util.Arrays;
import java.util.List;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformOp;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * Row-wise log-softmax transform: {@code z = log(softmax(x))}.
 *
 * <p>Execution is delegated to {@link OldSoftMax}, which writes the softmax of {@code x} into
 * {@code z}; the natural logarithm is then applied to {@code z} in place. Only the row-wise
 * variant (dimension {@code 1}) is supported by {@link #exec(int...)}.
 *
 * <p>NOTE(review): computing {@code log(softmax(x))} in two separate steps can yield
 * {@code -Infinity} when a softmax entry underflows to zero; a fused {@code x - logSumExp(x)}
 * formulation would be more robust if callers feed very negative logits.
 */
public class LogSoftMax extends BaseTransformOp {

    /**
     * Creates a SameDiff graph node for this op.
     *
     * @param sameDiff the owning graph
     * @param input    the input variable
     * @param inPlace  whether the op should write its result in place
     */
    public LogSoftMax(SameDiff sameDiff, SDVariable input, boolean inPlace) {
        super(sameDiff, input, inPlace);
    }

    /**
     * Creates a SameDiff graph node for this op with an explicit shape and extra arguments.
     *
     * @param sameDiff  the owning graph
     * @param input     the input variable
     * @param shape     the shape forwarded to the base op
     * @param inPlace   whether the op should write its result in place
     * @param extraArgs extra arguments forwarded to the base op
     */
    public LogSoftMax(SameDiff sameDiff, SDVariable input, long[] shape, boolean inPlace,
                      Object[] extraArgs) {
        super(sameDiff, input, shape, inPlace, extraArgs);
    }

    /**
     * Creates a SameDiff graph node for this op with extra arguments.
     *
     * @param sameDiff  the owning graph
     * @param input     the input variable
     * @param extraArgs extra arguments forwarded to the base op
     */
    public LogSoftMax(SameDiff sameDiff, SDVariable input, Object[] extraArgs) {
        super(sameDiff, input, extraArgs);
    }

    /** No-arg constructor; presumably used for reflective/op-registry instantiation — verify. */
    public LogSoftMax() {
    }

    /**
     * Creates the op over {@code x}, writing into {@code z}.
     *
     * @param x input array
     * @param z output array
     */
    public LogSoftMax(INDArray x, INDArray z) {
        this(x, (INDArray) null, z);
    }

    /**
     * Creates the op over {@code x}, writing into {@code z}, with an explicit element count.
     *
     * @param x input array
     * @param z output array
     * @param n element count forwarded to the base op
     */
    public LogSoftMax(INDArray x, INDArray z, long n) {
        this(x, null, z, n);
    }

    /**
     * Full constructor. When {@code z} is a different array from {@code x}, the contents of
     * {@code x} are copied into {@code z} up front; {@link #exec(int...)} later overwrites
     * {@code z} with the log-softmax result.
     *
     * @param x input array
     * @param y secondary operand (unused by this op's {@code exec}; may be {@code null})
     * @param z output array
     * @param n element count forwarded to the base op
     */
    public LogSoftMax(INDArray x, INDArray y, INDArray z, long n) {
        super(x, y, z, n);
        if (x != z) {
            z.assign(x);
        }
    }

    /**
     * Creates the op using {@code x.lengthLong()} as the element count.
     *
     * @param x input array
     * @param y secondary operand (may be {@code null})
     * @param z output array
     */
    public LogSoftMax(INDArray x, INDArray y, INDArray z) {
        this(x, y, z, x.lengthLong());
    }

    /**
     * Creates the op over {@code x}; output placement is decided by the base op.
     *
     * @param x input array
     */
    public LogSoftMax(INDArray x) {
        super(x);
    }

    /**
     * Legacy transform op number. NOTE(review): presumably must match the native op table
     * entry for log-softmax — confirm before changing.
     */
    @Override
    public int opNum() {
        return 40;
    }

    @Override
    public String opName() {
        return "logsoftmax";
    }

    @Override
    public String onnxName() {
        return "LogSoftmax";
    }

    @Override
    public String tensorflowName() {
        return "LogSoftmax";
    }

    /** Executes the op row-wise; equivalent to {@code exec(1)}. */
    @Override
    public void exec() {
        exec(1);
    }

    /**
     * Returns {@code true}; NOTE(review): presumably tells the executioner that this op
     * supplies its own {@link #exec(int...)} implementation rather than a plain native
     * transform — verify against the executioner.
     */
    @Override
    public boolean isExecSpecial() {
        return true;
    }

    /**
     * Computes {@code z = log(softmax(x))} along rows.
     *
     * @param dimensions the dimension to reduce along; the first entry must be {@code 1}
     * @throws IllegalArgumentException if {@code dimensions} is null, empty, or does not
     *                                  start with {@code 1}
     */
    @Override
    public void exec(int... dimensions) {
        // Reject null/empty explicitly; previously dimensions[0] failed with an unhelpful
        // NullPointerException / ArrayIndexOutOfBoundsException.
        if (dimensions == null || dimensions.length == 0 || dimensions[0] != 1) {
            throw new IllegalArgumentException("Only supports row wise calculations");
        }
        // Softmax of x goes into z, then the log is applied to z in place (dup = false).
        Nd4j.getExecutioner().exec(new OldSoftMax(this.x, this.z));
        Transforms.log(this.z, false);
    }

    /**
     * Builds the gradient for SameDiff.
     *
     * @param gradients incoming gradients; only the first element is used
     * @return a single-element list containing the gradient with respect to the input
     */
    @Override
    public List<SDVariable> doDiff(List<SDVariable> gradients) {
        return Arrays.asList(f().logSoftmaxDerivative(arg(), gradients.get(0)));
    }
}
