package org.nd4j.linalg.api.ops.impl.transforms;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.Shape;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.PropertyAccessor;

/* loaded from: input_file:org/nd4j/linalg/api/ops/impl/transforms/BaseDynamicTransformOp.class */
/**
 * Base class for dynamic custom ops that take one or two inputs and produce a single output whose
 * shape is either the shape of the only input, the common shape of both inputs, or the broadcast
 * of the two input shapes.
 *
 * <p>The output data type for the two-input case is chosen with
 * {@link Shape#pickPairwiseDataType(DataType, DataType)}.
 */
public abstract class BaseDynamicTransformOp extends DynamicCustomOp {
    private static final Logger log = LoggerFactory.getLogger(BaseDynamicTransformOp.class);

    /** Creates an op with no inputs configured. */
    public BaseDynamicTransformOp() {
    }

    /**
     * Creates an op attached to a SameDiff instance. All arguments are forwarded unchanged to the
     * superclass constructor, with a {@code null} op name.
     *
     * @param sameDiff      the SameDiff instance passed to the superclass
     * @param sDVariableArr the input variables passed to the superclass
     * @param z             boolean flag passed to the superclass (presumably "in place" - confirm
     *                      against {@code DynamicCustomOp})
     */
    public BaseDynamicTransformOp(SameDiff sameDiff, SDVariable[] sDVariableArr, boolean z) {
        super(null, sameDiff, sDVariableArr, z);
    }

    /**
     * Creates an op operating directly on arrays. Both arrays are forwarded unchanged to the
     * superclass constructor, with a {@code null} op name.
     *
     * @param iNDArrayArr  first array group passed to the superclass (presumably the inputs)
     * @param iNDArrayArr2 second array group passed to the superclass (presumably the outputs)
     */
    public BaseDynamicTransformOp(INDArray[] iNDArrayArr, INDArray[] iNDArrayArr2) {
        super((String) null, iNDArrayArr, iNDArrayArr2);
    }

    /**
     * Computes the shape descriptor of the single output.
     *
     * <ul>
     *   <li>If exactly two input arrays are registered, the superclass implementation is used.</li>
     *   <li>With fewer than two variables, the output mirrors the shape and data type of the first
     *       input array if one is present, otherwise of the first variable.</li>
     *   <li>With two or more variables, the output has the common shape when both shapes are
     *       equal, otherwise the broadcast shape (after asserting the shapes are broadcastable).</li>
     * </ul>
     *
     * @return a singleton list with the output descriptor, or an empty list when the required
     *         variables or their shapes are not available
     * @throws RuntimeException if building the descriptor for equal shapes fails
     */
    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public List<LongShapeDescriptor> calculateOutputShape() {
        if (numInputArguments() == 2) {
            return super.calculateOutputShape();
        }

        SDVariable[] args = args();
        // Without any variable there is nothing to derive a shape from; previously this case
        // failed with an index/null-pointer error instead of reporting "shape unknown".
        if (args == null || args.length == 0) {
            return Collections.emptyList();
        }

        // A registered-but-null array is treated like a missing array.
        INDArray firstArray = this.inputArguments.isEmpty() ? null : this.inputArguments.get(0);

        if (args.length < 2) {
            if (args[0] == null) {
                return Collections.emptyList();
            }
            if (firstArray != null) {
                return Collections.singletonList(
                        LongShapeDescriptor.fromShape(firstArray.shape(), firstArray.dataType()));
            }
            long[] onlyShape = args[0].getShape();
            if (onlyShape == null) {
                return Collections.emptyList();
            }
            return Collections.singletonList(LongShapeDescriptor.fromShape(onlyShape, args[0].dataType()));
        }

        // Validate the variables BEFORE dereferencing them; the checks used to run only after
        // args[i].getShape() had already been called, so a null variable caused an NPE.
        if (args[0] == null || args[0].getShape() == null) {
            return Collections.emptyList();
        }
        if (args[1] == null || args[1].getShape() == null) {
            return Collections.emptyList();
        }

        INDArray secondArray = this.inputArguments.size() > 1 ? this.inputArguments.get(1) : null;

        long[] shape;
        long[] shape2;
        if (this.inputArguments.size() == 2 && firstArray != null && secondArray != null) {
            shape = firstArray.shape();
            shape2 = secondArray.shape();
        } else {
            shape = args[0].getShape();
            shape2 = args[1].getShape();
        }

        // Prefer the data type of a concrete array when one is registered, else use the variable's.
        DataType firstType = firstArray != null ? firstArray.dataType() : args[0].dataType();
        DataType secondType = secondArray != null ? secondArray.dataType() : args[1].dataType();
        DataType pickPairwiseDataType = Shape.pickPairwiseDataType(firstType, secondType);

        if (!Arrays.equals(shape, shape2)) {
            Shape.assertBroadcastable(shape, shape2, getClass());
            return Collections.singletonList(
                    LongShapeDescriptor.fromShape(Shape.broadcastOutputShape(shape, shape2), pickPairwiseDataType));
        }
        try {
            return Collections.singletonList(LongShapeDescriptor.fromShape(shape, pickPairwiseDataType));
        } catch (Throwable th) {
            // Re-thrown with the op name so the failing op can be identified from the message.
            throw new RuntimeException("calculateOutputShape() failed for [" + opName() + "]", th);
        }
    }

    /**
     * Determines the output data type from the two input data types.
     *
     * @param list the input data types; must be non-null and contain exactly two entries
     * @return a singleton list holding the result of
     *         {@link Shape#pickPairwiseDataType(DataType, DataType)} for the two inputs
     */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public List<DataType> calculateOutputDataTypes(List<DataType> list) {
        Preconditions.checkState(list != null && list.size() == 2, "Expected exactly 2 input datatypes for %s, got input %s", getClass(), list);
        return Collections.singletonList(Shape.pickPairwiseDataType(list.get(0), list.get(1)));
    }
}
