package org.nd4j.linalg.api.ops.impl.layers.convolution;

import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import onnx.OnnxProto3;
import org.deeplearning4j.nn.layers.ocnn.OCNNParamInitializer;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
import org.nd4j.imports.descriptors.properties.AttributeAdapter;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.imports.descriptors.properties.adapters.ConditionalFieldValueIntIndexArrayAdapter;
import org.nd4j.imports.descriptors.properties.adapters.ConditionalFieldValueNDArrayShapeAdapter;
import org.nd4j.imports.descriptors.properties.adapters.SizeThresholdIntArrayIntIndexAdpater;
import org.nd4j.imports.descriptors.properties.adapters.StringEqualsAdapter;
import org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
import org.nd4j.linalg.util.ArrayUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

/* loaded from: input_file:org/nd4j/linalg/api/ops/impl/layers/convolution/Conv1D.class */
/**
 * 1D convolution op ({@code conv1d}).
 *
 * <p>The op's integer arguments are derived from {@link Conv1DConfig} by {@link #addArgs()} and
 * are laid out as {@code [k, s, p, isSameMode (1/0), isNWC (1/0)]}. {@link #getValue(Field)}
 * reverses that layout when the config has to be reconstructed from the integer arguments.
 */
public class Conv1D extends DynamicCustomOp {
    private static final Logger log = LoggerFactory.getLogger((Class<?>) Conv1D.class);
    protected Conv1DConfig config;

    /** Builder for {@link Conv1D}; each field is passed unchanged to the full constructor. */
    public static class Conv1DBuilder {
        private SameDiff sameDiff;
        private SDVariable[] inputFunctions;
        private INDArray[] inputArrays;
        private INDArray[] outputs;
        private Conv1DConfig config;

        Conv1DBuilder() {
        }

        public Conv1DBuilder sameDiff(SameDiff sameDiff) {
            this.sameDiff = sameDiff;
            return this;
        }

        public Conv1DBuilder inputFunctions(SDVariable[] sDVariableArr) {
            this.inputFunctions = sDVariableArr;
            return this;
        }

        public Conv1DBuilder inputArrays(INDArray[] iNDArrayArr) {
            this.inputArrays = iNDArrayArr;
            return this;
        }

        public Conv1DBuilder outputs(INDArray[] iNDArrayArr) {
            this.outputs = iNDArrayArr;
            return this;
        }

        public Conv1DBuilder config(Conv1DConfig conv1DConfig) {
            this.config = conv1DConfig;
            return this;
        }

        public Conv1D build() {
            return new Conv1D(this.sameDiff, this.inputFunctions, this.inputArrays, this.outputs, this.config);
        }

        public String toString() {
            return "Conv1D.Conv1DBuilder(sameDiff=" + this.sameDiff + ", inputFunctions=" + Arrays.deepToString(this.inputFunctions) + ", inputArrays=" + Arrays.deepToString(this.inputArrays) + ", outputs=" + Arrays.deepToString(this.outputs) + ", config=" + this.config + ")";
        }
    }

    /**
     * Creates the op, populates its integer arguments from {@code config}, and registers it with
     * {@code sameDiff} when one is supplied.
     *
     * @param sameDiff       graph to register this op with; may be {@code null} when the op is
     *                       built directly on arrays
     * @param inputFunctions graph inputs registered via {@code SameDiff.addArgsFor}; only used
     *                       when {@code sameDiff} is non-null
     * @param inputArrays    input arrays passed to the superclass
     * @param outputs        output arrays passed to the superclass
     * @param config         convolution config; {@code null} means the builder defaults
     */
    public Conv1D(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputArrays, INDArray[] outputs, Conv1DConfig config) {
        super((String) null, inputArrays, outputs);
        this.sameDiff = sameDiff;
        this.config = config;
        addArgs();
        // The builder allows array-only construction; unconditionally dereferencing sameDiff made
        // that path fail with a NullPointerException.
        if (sameDiff != null) {
            sameDiff.putFunctionForId(getOwnName(), this);
            sameDiff.addArgsFor(inputFunctions, this);
        }
    }

    /**
     * Appends the integer arguments {@code [k, s, p, isSameMode, isNWC]} derived from the config,
     * first replacing a {@code null} config with the builder defaults.
     */
    protected void addArgs() {
        if (this.config == null) {
            this.config = Conv1DConfig.builder().build();
        }
        addIArgument(this.config.getK(), this.config.getS(), this.config.getP(), ArrayUtil.fromBoolean(this.config.isSameMode()), ArrayUtil.fromBoolean(this.config.isNWC()));
    }

    /** Returns the integer arguments, lazily populating them from the config if none exist yet. */
    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.linalg.api.ops.CustomOp
    public long[] iArgs() {
        if (this.iArguments.isEmpty()) {
            addArgs();
        }
        return super.iArgs();
    }

    /**
     * Returns the value of a config property, reconstructing the config first if it is missing.
     *
     * <p>Reconstruction reads the integer arguments in the order written by {@link #addArgs()}. If
     * there are no integer arguments, the builder defaults are used, which are the same defaults
     * {@link #addArgs()} falls back to.
     */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public Object getValue(Field field) {
        if (this.config == null) {
            this.config = this.iArguments.isEmpty() ? Conv1DConfig.builder().build() : configFromIArgs();
        }
        return this.config.getValue(field);
    }

    /**
     * Rebuilds a config from integer arguments laid out as {@code [k, s, p, isSameMode, isNWC]}.
     *
     * @throws IllegalStateException if fewer than five integer arguments are present
     */
    private Conv1DConfig configFromIArgs() {
        if (this.iArguments.size() < 5) {
            throw new IllegalStateException("Expected at least 5 integer arguments [k, s, p, isSameMode, isNWC] but got " + this.iArguments.size());
        }
        return Conv1DConfig.builder()
                .k(this.iArguments.get(0).longValue())
                .s(this.iArguments.get(1).longValue())
                .p(this.iArguments.get(2).longValue())
                .isSameMode(this.iArguments.get(3).longValue() == 1)
                // addArgs stores isNWC, so 1 means NWC (the old code had this inverted).
                .dataFormat(this.iArguments.get(4).longValue() == 1 ? Conv1DConfig.NWC : Conv1DConfig.NCW)
                .build();
    }

    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public Map<String, Object> propertiesForFunction() {
        return this.config.toProperties();
    }

    /** Initializes properties from a TensorFlow node, then appends the derived integer arguments. */
    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff sameDiff, Map<String, AttrValue> map, GraphDef graphDef) {
        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, map, nodeDef, graphDef);
        addArgs();
    }

    /** Properties of this op live on the {@code config} field rather than on the op itself. */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public boolean isConfigProperties() {
        return true;
    }

    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public String configFieldName() {
        return "config";
    }

    /** Initializes properties from an ONNX node, then appends the derived integer arguments. */
    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public void initFromOnnx(OnnxProto3.NodeProto nodeProto, SameDiff sameDiff, Map<String, OnnxProto3.AttributeProto> map, OnnxProto3.GraphProto graphProto) {
        OnnxGraphMapper.getInstance().initFunctionFromProperties(nodeProto.getOpType(), this, map, nodeProto, graphProto);
        addArgs();
    }

    /**
     * Attribute adapters keyed first by framework op name (TensorFlow, then ONNX) and then by
     * property name.
     *
     * <p>NOTE(review): the kH/kW/sH/sW/dH/dW keys mirror the 2D convolution op, while this op's
     * config exposes {@code k}/{@code s}/{@code p}. Confirm whether these adapters are ever
     * matched for Conv1D.
     */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public Map<String, Map<String, AttributeAdapter>> attributeAdaptersForFunction() {
        Map<String, Map<String, AttributeAdapter>> adapters = new HashMap<>();
        Map<String, Field> fieldsForFunction = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(this);
        Field dataFormatField = fieldsForFunction.get("dataFormat");

        Map<String, AttributeAdapter> tfAdapters = new LinkedHashMap<>();
        tfAdapters.put("kH", new ConditionalFieldValueNDArrayShapeAdapter("NHW", 2, 0, dataFormatField));
        tfAdapters.put("kW", new ConditionalFieldValueNDArrayShapeAdapter("NHW", 3, 1, dataFormatField));
        tfAdapters.put("sH", new ConditionalFieldValueIntIndexArrayAdapter("NHW", 2, 1, dataFormatField));
        tfAdapters.put("sW", new ConditionalFieldValueIntIndexArrayAdapter("NHW", 3, 2, dataFormatField));
        tfAdapters.put("isSameMode", new StringEqualsAdapter("SAME"));
        tfAdapters.put("isNHWC", new StringEqualsAdapter("NHWC"));

        Map<String, AttributeAdapter> onnxAdapters = new HashMap<>();
        onnxAdapters.put("kH", new SizeThresholdIntArrayIntIndexAdpater(0, 2, 0));
        onnxAdapters.put("kW", new SizeThresholdIntArrayIntIndexAdpater(1, 2, 0));
        onnxAdapters.put("dH", new SizeThresholdIntArrayIntIndexAdpater(0, 2, 0));
        onnxAdapters.put("dW", new SizeThresholdIntArrayIntIndexAdpater(1, 2, 0));
        onnxAdapters.put("sH", new SizeThresholdIntArrayIntIndexAdpater(0, 2, 0));
        onnxAdapters.put("sW", new SizeThresholdIntArrayIntIndexAdpater(1, 2, 0));
        onnxAdapters.put("isSameMode", new StringEqualsAdapter("SAME"));
        onnxAdapters.put("isNHWC", new StringEqualsAdapter("NHC"));

        adapters.put(tensorflowName(), tfAdapters);
        adapters.put(onnxName(), onnxAdapters);
        return adapters;
    }

    /**
     * Property mappings from framework attributes to config properties. The same mapping table is
     * registered under both the ONNX and TensorFlow op names.
     */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
        PropertyMapping strides = PropertyMapping.builder().tfAttrName("strides").onnxAttrName("strides").propertyNames(new String[]{"s"}).build();
        // The kernel size is read from dimension 0 of TF input 1 (the weights).
        PropertyMapping kernel = PropertyMapping.builder().propertyNames(new String[]{"k"}).tfInputPosition(1).shapePosition(0).onnxAttrName("kernel_shape").build();
        PropertyMapping padding = PropertyMapping.builder().onnxAttrName("padding").propertyNames(new String[]{"p"}).build();
        PropertyMapping dataFormat = PropertyMapping.builder().onnxAttrName("data_format").tfAttrName("data_format").propertyNames(new String[]{"dataFormat"}).build();
        PropertyMapping isNhwc = PropertyMapping.builder().onnxAttrName("data_format").tfAttrName("data_format").propertyNames(new String[]{"isNHWC"}).build();
        PropertyMapping sameMode = PropertyMapping.builder().onnxAttrName("auto_pad").propertyNames(new String[]{"isSameMode"}).tfAttrName("padding").build();

        Map<String, PropertyMapping> propertyMappings = new HashMap<>();
        propertyMappings.put("s", strides);
        propertyMappings.put("k", kernel);
        propertyMappings.put("p", padding);
        propertyMappings.put("isSameMode", sameMode);
        propertyMappings.put("dataFormat", dataFormat);
        propertyMappings.put("isNHWC", isNhwc);

        Map<String, Map<String, PropertyMapping>> mappings = new HashMap<>();
        try {
            mappings.put(onnxName(), propertyMappings);
        } catch (NoOpNameFoundException ignored) {
            // Deliberately best-effort: a subclass without an ONNX name simply gets no ONNX mapping.
        }
        try {
            mappings.put(tensorflowName(), propertyMappings);
        } catch (NoOpNameFoundException ignored) {
            // Deliberately best-effort: a subclass without a TF name simply gets no TF mapping.
        }
        return mappings;
    }

    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public String opName() {
        return "conv1d";
    }

    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public String onnxName() {
        return "Conv";
    }

    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public String tensorflowName() {
        return "Conv1D";
    }

    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public String[] tensorflowNames() {
        return new String[]{"Conv1D"};
    }

    public static Conv1DBuilder builder() {
        return new Conv1DBuilder();
    }

    public Conv1DConfig getConfig() {
        return this.config;
    }

    /** No-arg constructor used by the framework importers before {@code initFrom*} is called. */
    public Conv1D() {
    }
}
