package org.nd4j.linalg.api.ops.factory;

import org.deeplearning4j.nn.params.BatchNormalizationParamInitializer;
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.GradientOp;
import org.nd4j.linalg.api.ops.IndexAccumulation;
import org.nd4j.linalg.api.ops.LossFunction;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.impl.accum.StandardDeviation;
import org.nd4j.linalg.api.ops.impl.accum.Variance;
import org.nd4j.linalg.api.ops.impl.transforms.Pow;
import org.nd4j.linalg.api.ops.impl.transforms.RectifedLinear;
import org.nd4j.linalg.api.ops.impl.transforms.Set;
import org.nd4j.linalg.api.ops.impl.transforms.Step;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftMaxDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.PropertyAccessor;

/* loaded from: input_file:org/nd4j/linalg/api/ops/factory/DefaultOpFactory.class */
public class DefaultOpFactory implements OpFactory {
    private static final Logger log = LoggerFactory.getLogger(DefaultOpFactory.class);

    /** Constructor signature {@code (INDArray, INDArray)}. */
    private static final Class<?>[] TWO_ARRAYS = {INDArray.class, INDArray.class};

    /** Constructor signature {@code (INDArray, INDArray, INDArray)}. */
    private static final Class<?>[] THREE_ARRAYS = {INDArray.class, INDArray.class, INDArray.class};

    /** Constructor signature {@code (INDArray, INDArray, INDArray, long)} used for accumulations. */
    private static final Class<?>[] ACCUMULATION_SIGNATURE =
            {INDArray.class, INDArray.class, INDArray.class, long.class};

    /** Constructor signature {@code (INDArray, INDArray, INDArray, long, Number)} used for scalar ops. */
    private static final Class<?>[] SCALAR_SIGNATURE =
            {INDArray.class, INDArray.class, INDArray.class, long.class, Number.class};

    /** Constructor signature {@code (INDArray, INDArray, INDArray, int[])} used for broadcast ops. */
    private static final Class<?>[] BROADCAST_SIGNATURE =
            {INDArray.class, INDArray.class, INDArray.class, int[].class};

    /**
     * Creates one of the known gradient ops by name.
     *
     * @throws IllegalStateException if {@code name} is not one of the supported gradient ops
     */
    @Override // org.nd4j.linalg.api.ops.factory.OpFactory
    public GradientOp createGradientOp(String name, INDArray x, INDArray y, INDArray z) {
        switch (name) {
            case "softmaxderivative":
                return new SoftMaxDerivative(x, y, z);
            case "sigmoidderivative":
                return new SigmoidDerivative(x, y, z);
            case "tanhderivative":
                return new TanhDerivative(x, y, z);
            case GradientBackwardsMarker.OP_NAME:
                return new GradientBackwardsMarker(x, y, z);
            default:
                throw new IllegalStateException("Illegal opName " + name);
        }
    }

    /**
     * Shape ops are not supported by this factory; always throws.
     *
     * @throws IllegalArgumentException always
     */
    @Override // org.nd4j.linalg.api.ops.factory.OpFactory
    public Op createShape(String name, INDArray x, INDArray z, Object[] extraArgs) {
        throw new IllegalArgumentException("Illegal opName for create shape op: " + name);
    }

    /**
     * Reflectively creates a loss function via its declared {@code (INDArray, INDArray)} constructor.
     *
     * @throws IllegalArgumentException if the op cannot be resolved or instantiated (cause attached)
     */
    @Override // org.nd4j.linalg.api.ops.factory.OpFactory
    public LossFunction createLossFunction(String name, INDArray x, INDArray y) {
        try {
            return (LossFunction) opClass(name).getDeclaredConstructor(INDArray.class, INDArray.class)
                    .newInstance(x, y);
        } catch (Exception e) {
            throw new IllegalArgumentException("Illegal op " + name, e);
        }
    }

    @Override // org.nd4j.linalg.api.ops.factory.OpFactory
    public Accumulation createAccum(String name, INDArray x) {
        return createAccum(name, x, null, x, null);
    }

    @Override // org.nd4j.linalg.api.ops.factory.OpFactory
    public Accumulation createAccum(String name, INDArray x, INDArray y, INDArray z) {
        return createAccum(name, x, y, z, null);
    }

    /**
     * Creates an accumulation op.
     *
     * <p>{@code "std"} and {@code "var"} are built directly; the first extra argument, if present,
     * is read as a {@code Boolean} and passed to the op constructor. When it is absent this method
     * passes {@code true} (NOTE(review): confirm this matches StandardDeviation/Variance defaults).
     * Every other name is resolved reflectively via an {@code (INDArray, INDArray, INDArray, long)}
     * constructor, with {@code x.length()} as the last argument.
     *
     * @throws RuntimeException if a registry-resolved op cannot be instantiated (cause attached)
     */
    @Override // org.nd4j.linalg.api.ops.factory.OpFactory
    public Accumulation createAccum(String name, INDArray x, INDArray y, INDArray z, Object[] extraArgs) {
        Accumulation accumulation;
        switch (name) {
            case "std":
                accumulation = new StandardDeviation(x, y, z, x.length(), booleanArg(extraArgs, true));
                break;
            case "var":
                accumulation = new Variance(x, y, z, x.length(), booleanArg(extraArgs, true));
                break;
            default:
                // "mmul" is intentionally not special-cased: it is not a standard-deviation reduction,
                // so it is resolved through the registry like any other name.
                accumulation = newOp(Accumulation.class, name, ACCUMULATION_SIGNATURE, x, y, z, x.length());
                break;
        }
        accumulation.setExtraArgs(extraArgs);
        return accumulation;
    }

    @Override // org.nd4j.linalg.api.ops.factory.OpFactory
    public Accumulation createAccum(String name, INDArray x, INDArray y) {
        return createAccum(name, x, y, x, null);
    }

    /**
     * Reflectively creates an index accumulation via its {@code (INDArray, INDArray)} constructor,
     * passing {@code (x, y)}.
     *
     * <p>NOTE(review): {@code z} is accepted but not forwarded to the constructor — confirm intent.
     *
     * @throws RuntimeException if the op cannot be instantiated (cause attached)
     */
    @Override // org.nd4j.linalg.api.ops.factory.OpFactory
    public IndexAccumulation createIndexAccum(String name, INDArray x, INDArray y, INDArray z, Object[] extraArgs) {
        IndexAccumulation indexAccumulation = newOp(IndexAccumulation.class, name, TWO_ARRAYS, x, y);
        indexAccumulation.setExtraArgs(extraArgs);
        return indexAccumulation;
    }

    @Override // org.nd4j.linalg.api.ops.factory.OpFactory
    public IndexAccumulation createIndexAccum(String name, INDArray x) {
        return createIndexAccum(name, x, null, x, null);
    }

    @Override // org.nd4j.linalg.api.ops.factory.OpFactory
    public IndexAccumulation createIndexAccum(String name, INDArray x, INDArray y) {
        return createIndexAccum(name, x, y, x, null);
    }

    @Override // org.nd4j.linalg.api.ops.factory.OpFactory
    public TransformOp createTransform(String name, INDArray x, INDArray y) {
        return createTransform(name, x, y, x, null);
    }

    @Override // org.nd4j.linalg.api.ops.factory.OpFactory
    public TransformOp createTransform(String name, INDArray x) {
        return createTransform(name, x, null, x, null);
    }

    @Override // org.nd4j.linalg.api.ops.factory.OpFactory
    public TransformOp createTransform(String name, INDArray x, Object[] extraArgs) {
        return createTransform(name, x, null, x, extraArgs);
    }

    @Override // org.nd4j.linalg.api.ops.factory.OpFactory
    public TransformOp createTransform(String name, INDArray x, INDArray y, INDArray z) {
        return createTransform(name, x, y, z, null);
    }

    /**
     * Creates a transform op.
     *
     * <p>A few names are built directly. {@code "relu"} and {@code "step"} take an optional numeric
     * cutoff as the first extra argument, defaulting to {@code 0.0}. {@code "pow"} requires the
     * exponent as the first extra argument. Any other name is resolved reflectively: via an
     * {@code (INDArray, INDArray)} constructor with {@code (x, z)} when {@code y} is null, otherwise
     * via an {@code (INDArray, INDArray, INDArray)} constructor with {@code (x, y, z)}.
     *
     * @throws IllegalArgumentException if {@code "pow"} is requested without an exponent
     * @throws RuntimeException if a registry-resolved op cannot be instantiated (cause attached)
     */
    @Override // org.nd4j.linalg.api.ops.factory.OpFactory
    public TransformOp createTransform(String name, INDArray x, INDArray y, INDArray z, Object[] extraArgs) {
        TransformOp transformOp;
        switch (name) {
            case "_softmaxderivative":
                // Fully qualified: the imported SoftMaxDerivative is the gradient-op variant.
                transformOp = new org.nd4j.linalg.api.ops.impl.transforms.SoftMaxDerivative(x, z);
                break;
            case "set":
                transformOp = new Set(x, y, z, z.length());
                break;
            case "relu":
                transformOp = new RectifedLinear(x, z, x.length(), doubleArg(extraArgs, 0.0d));
                break;
            case "step":
                transformOp = new Step(x, y, z, x.length(), doubleArg(extraArgs, 0.0d));
                break;
            case "pow":
                transformOp = new Pow(x, z, requiredDoubleArg(name, extraArgs));
                break;
            default:
                transformOp = y == null
                        ? newOp(TransformOp.class, name, TWO_ARRAYS, x, z)
                        : newOp(TransformOp.class, name, THREE_ARRAYS, x, y, z);
                break;
        }
        transformOp.setExtraArgs(extraArgs);
        return transformOp;
    }

    @Override // org.nd4j.linalg.api.ops.factory.OpFactory
    public ScalarOp createScalarTransform(String name, INDArray x, INDArray y, double scalar) {
        return createScalarTransform(name, x, y, x, null, scalar);
    }

    @Override // org.nd4j.linalg.api.ops.factory.OpFactory
    public ScalarOp createScalarTransform(String name, INDArray x, double scalar) {
        return createScalarTransform(name, x, null, x, null, scalar);
    }

    @Override // org.nd4j.linalg.api.ops.factory.OpFactory
    public ScalarOp createScalarTransform(String name, INDArray x, Object[] extraArgs, double scalar) {
        // Forward the caller's extra arguments instead of discarding them.
        return createScalarTransform(name, x, null, x, extraArgs, scalar);
    }

    @Override // org.nd4j.linalg.api.ops.factory.OpFactory
    public ScalarOp createScalarTransform(String name, INDArray x, INDArray y, INDArray z, double scalar) {
        return createScalarTransform(name, x, y, z, null, scalar);
    }

    /**
     * Reflectively creates a scalar op via its {@code (INDArray, INDArray, INDArray, long, Number)}
     * constructor, passing {@code x.length()} and the boxed scalar.
     *
     * @throws RuntimeException if the op cannot be instantiated (cause attached)
     */
    @Override // org.nd4j.linalg.api.ops.factory.OpFactory
    public ScalarOp createScalarTransform(String name, INDArray x, INDArray y, INDArray z, Object[] extraArgs, double scalar) {
        ScalarOp scalarOp = newOp(ScalarOp.class, name, SCALAR_SIGNATURE, x, y, z, x.length(), Double.valueOf(scalar));
        scalarOp.setExtraArgs(extraArgs);
        return scalarOp;
    }

    @Override // org.nd4j.linalg.api.ops.factory.OpFactory
    public BroadcastOp createBroadcastOp(String name, INDArray x, INDArray y, INDArray z, int... dimension) {
        return createBroadcastOp(name, x, y, z, null, dimension);
    }

    /**
     * Reflectively creates a broadcast op via its {@code (INDArray, INDArray, INDArray, int[])}
     * constructor.
     *
     * @throws RuntimeException if the op cannot be instantiated (cause attached)
     */
    @Override // org.nd4j.linalg.api.ops.factory.OpFactory
    public BroadcastOp createBroadcastOp(String name, INDArray x, INDArray y, INDArray z, Object[] extraArgs, int... dimension) {
        BroadcastOp broadcastOp = newOp(BroadcastOp.class, name, BROADCAST_SIGNATURE, x, y, z, dimension);
        broadcastOp.setExtraArgs(extraArgs);
        return broadcastOp;
    }

    @Override // org.nd4j.linalg.api.ops.factory.OpFactory
    public BroadcastOp createBroadcastOp(String name, INDArray x, INDArray y, int... dimension) {
        return createBroadcastOp(name, x, y, x, null, dimension);
    }

    /**
     * Returns the op number registered for {@code name}.
     *
     * @throws RuntimeException if the lookup fails (cause attached)
     */
    @Override // org.nd4j.linalg.api.ops.factory.OpFactory
    public int getOpNumByName(String name) {
        try {
            return DifferentialFunctionClassHolder.getInstance().getInstance(name).opNum();
        } catch (Exception e) {
            throw new RuntimeException("OpName failed: [" + name + "]", e);
        }
    }

    /** Returns the op number for {@code name}, or {@code -1} if the registry has no such name. */
    @Override // org.nd4j.linalg.api.ops.factory.OpFactory
    public int getOpNumIfExists(String name) {
        if (DifferentialFunctionClassHolder.getInstance().hasName(name)) {
            return getOpNumByName(name);
        }
        return -1;
    }

    /** Returns the registry's instance for {@code name}, cast to {@link Op}. */
    @Override // org.nd4j.linalg.api.ops.factory.OpFactory
    public Op getOpByName(String name) {
        return (Op) DifferentialFunctionClassHolder.getInstance().getInstance(name);
    }

    /** Resolves the concrete class registered for {@code name} in the function registry. */
    private static Class<?> opClass(String name) {
        return DifferentialFunctionClassHolder.getInstance().getInstance(name).getClass();
    }

    /**
     * Instantiates the op registered for {@code name} through the public constructor matching
     * {@code signature}, and casts it to {@code type}. Any failure (lookup, missing constructor,
     * constructor exception, wrong type) is wrapped in a {@link RuntimeException}.
     */
    private static <T> T newOp(Class<T> type, String name, Class<?>[] signature, Object... args) {
        try {
            return type.cast(opClass(name).getConstructor(signature).newInstance(args));
        } catch (Exception e) {
            throw new RuntimeException("Could not create op [" + name + "]", e);
        }
    }

    /** Returns {@code args[0]} as a boolean, or {@code defaultValue} when it is absent or null. */
    private static boolean booleanArg(Object[] args, boolean defaultValue) {
        if (args == null || args.length == 0 || args[0] == null) {
            return defaultValue;
        }
        return (Boolean) args[0];
    }

    /** Returns {@code args[0]} as a double, or {@code defaultValue} when it is absent or null. */
    private static double doubleArg(Object[] args, double defaultValue) {
        if (args == null || args.length == 0 || args[0] == null) {
            return defaultValue;
        }
        return ((Number) args[0]).doubleValue();
    }

    /**
     * Returns {@code args[0]} as a double.
     *
     * @throws IllegalArgumentException if the argument is missing
     */
    private static double requiredDoubleArg(String name, Object[] args) {
        if (args == null || args.length == 0 || args[0] == null) {
            throw new IllegalArgumentException("Op " + name + " requires a numeric first extra argument");
        }
        return ((Number) args[0]).doubleValue();
    }
}
