package org.nd4j.linalg.api.ops.impl.layers.convolution;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import onnx.OnnxProto3;
import org.deeplearning4j.nn.params.BatchNormalizationParamInitializer;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.LocalResponseNormalizationConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

/* loaded from: input_file:org/nd4j/linalg/api/ops/impl/layers/convolution/LocalResponseNormalization.class */
/**
 * Local response normalization custom op (native op name {@code "lrn"}).
 *
 * <p>The op arguments are taken from a {@link LocalResponseNormalizationConfig}: {@code bias},
 * {@code alpha} and {@code beta} are added as floating point (T) arguments in that order, and
 * {@code depth} as the single integer (I) argument. The config is supplied either through
 * {@link #builder()} or populated by the TensorFlow / ONNX importers.
 */
public class LocalResponseNormalization extends DynamicCustomOp {
    private static final Logger log = LoggerFactory.getLogger(LocalResponseNormalization.class);

    // Attribute / property names used by the importers and the property mappings.
    private static final String ALPHA = "alpha";
    private static final String BETA = "beta";
    private static final String BIAS = "bias";
    private static final String DEPTH = "depth";
    private static final String TF_DEPTH_RADIUS = "depth_radius";
    private static final String ONNX_SIZE = "size";

    // Defaults declared by the TensorFlow "LRN" op registration; applied when a graph omits the attribute.
    private static final long TF_DEFAULT_DEPTH_RADIUS = 5;
    private static final double TF_DEFAULT_BIAS = 1.0;
    private static final double TF_DEFAULT_ALPHA = 1.0;
    private static final double TF_DEFAULT_BETA = 0.5;

    // Defaults declared by the ONNX "LRN" operator spec ("size" has no default and is required).
    private static final double ONNX_DEFAULT_ALPHA = 1e-4;
    private static final double ONNX_DEFAULT_BETA = 0.75;
    private static final double ONNX_DEFAULT_BIAS = 1.0;

    /** Hyper-parameters for this op; {@code null} until set by the builder or an importer. */
    protected LocalResponseNormalizationConfig config;

    /** Fluent builder for {@link LocalResponseNormalization}; obtain via {@link LocalResponseNormalization#builder()}. */
    public static class LocalResponseNormalizationBuilder {
        private SameDiff sameDiff;
        private SDVariable[] inputFunctions;
        private INDArray[] inputs;
        private INDArray[] outputs;
        private boolean inPlace;
        private LocalResponseNormalizationConfig config;

        LocalResponseNormalizationBuilder() {
        }

        /** Sets the SameDiff instance the op belongs to. */
        public LocalResponseNormalizationBuilder sameDiff(SameDiff sameDiff) {
            this.sameDiff = sameDiff;
            return this;
        }

        /** Sets the SameDiff input variables. */
        public LocalResponseNormalizationBuilder inputFunctions(SDVariable[] inputFunctions) {
            this.inputFunctions = inputFunctions;
            return this;
        }

        /** Sets the input arrays (may be {@code null} when building a SameDiff op). */
        public LocalResponseNormalizationBuilder inputs(INDArray[] inputs) {
            this.inputs = inputs;
            return this;
        }

        /** Sets the output arrays (may be {@code null}). */
        public LocalResponseNormalizationBuilder outputs(INDArray[] outputs) {
            this.outputs = outputs;
            return this;
        }

        /** Sets the in-place flag passed to {@link DynamicCustomOp}. */
        public LocalResponseNormalizationBuilder inPlace(boolean inPlace) {
            this.inPlace = inPlace;
            return this;
        }

        /** Sets the op hyper-parameters; required, as {@link #build()} converts them into op arguments. */
        public LocalResponseNormalizationBuilder config(LocalResponseNormalizationConfig config) {
            this.config = config;
            return this;
        }

        /** Creates the op from the values set on this builder. */
        public LocalResponseNormalization build() {
            return new LocalResponseNormalization(sameDiff, inputFunctions, inputs, outputs, inPlace, config);
        }

        @Override
        public String toString() {
            return "LocalResponseNormalization.LocalResponseNormalizationBuilder(sameDiff=" + this.sameDiff
                    + ", inputFunctions=" + Arrays.deepToString(this.inputFunctions)
                    + ", inputs=" + Arrays.deepToString(this.inputs)
                    + ", outputs=" + Arrays.deepToString(this.outputs)
                    + ", inPlace=" + this.inPlace
                    + ", config=" + this.config + ")";
        }
    }

    /**
     * Creates the op and immediately adds its T/I arguments from {@code config}.
     *
     * @param sameDiff       owning SameDiff instance
     * @param inputFunctions SameDiff input variables
     * @param inputs         input arrays, added only when non-null
     * @param outputs        output arrays, added only when non-null
     * @param inPlace        whether the op runs in place
     * @param config         op hyper-parameters; must not be null
     * @throws NullPointerException if {@code config} is null
     */
    public LocalResponseNormalization(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputs,
                                      INDArray[] outputs, boolean inPlace, LocalResponseNormalizationConfig config) {
        super(null, sameDiff, inputFunctions, inPlace);
        this.config = config;
        if (inputs != null) {
            addInputArgument(inputs);
        }
        if (outputs != null) {
            addOutputArgument(outputs);
        }
        addArgs();
    }

    @Override
    public Map<String, Object> propertiesForFunction() {
        return this.config.toProperties();
    }

    /** Appends bias, alpha, beta (T args) and depth (I arg) from the current config. */
    private void addArgs() {
        Objects.requireNonNull(this.config, "LocalResponseNormalizationConfig must be set before adding op arguments");
        addTArgument(this.config.getBias());
        addTArgument(this.config.getAlpha());
        addTArgument(this.config.getBeta());
        addIArgument(this.config.getDepth());
    }

    @Override
    public String opName() {
        return "lrn";
    }

    /**
     * Builds the config from a TensorFlow {@code LRN} node. Attributes missing from the node
     * fall back to the defaults declared by the TensorFlow op registration.
     */
    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode,
                                   GraphDef graph) {
        long depthRadius = nodeDef.containsAttr(TF_DEPTH_RADIUS)
                ? nodeDef.getAttrOrThrow(TF_DEPTH_RADIUS).getI()
                : TF_DEFAULT_DEPTH_RADIUS;
        this.config = LocalResponseNormalizationConfig.builder()
                .alpha(tfFloatAttr(nodeDef, ALPHA, TF_DEFAULT_ALPHA))
                .beta(tfFloatAttr(nodeDef, BETA, TF_DEFAULT_BETA))
                .bias(tfFloatAttr(nodeDef, BIAS, TF_DEFAULT_BIAS))
                .depth((int) depthRadius)
                .build();
        addArgs();
    }

    /** Returns the float attribute {@code name} of {@code nodeDef}, or {@code defaultValue} if absent. */
    private static double tfFloatAttr(NodeDef nodeDef, String name, double defaultValue) {
        return nodeDef.containsAttr(name) ? nodeDef.getAttrOrThrow(name).getF() : defaultValue;
    }

    /**
     * Builds the config from an ONNX {@code LRN} node. Missing optional attributes (alpha, beta,
     * bias) fall back to the ONNX spec defaults; the required {@code size} attribute must be present.
     *
     * @throws NullPointerException if the node has no {@code size} attribute
     */
    @Override
    public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith,
                             Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
        OnnxProto3.AttributeProto sizeAttr = Objects.requireNonNull(attributesForNode.get(ONNX_SIZE),
                "ONNX LRN node is missing required attribute 'size'");
        // The ONNX spec declares 'size' as an INT attribute, so read the integer field first; fall back
        // to the float field for graphs that serialized it as FLOAT (the only field read previously).
        long size = sizeAttr.getI();
        if (size == 0) {
            size = (long) sizeAttr.getF();
        }
        // NOTE(review): ONNX defines 'size' as the full window width and divides alpha by size, whereas the
        // TensorFlow import reads 'depth_radius'. Values are passed through unchanged here; confirm which
        // convention the native "lrn" op expects before relying on ONNX import.
        this.config = LocalResponseNormalizationConfig.builder()
                .alpha(onnxFloatAttr(attributesForNode, ALPHA, ONNX_DEFAULT_ALPHA))
                .beta(onnxFloatAttr(attributesForNode, BETA, ONNX_DEFAULT_BETA))
                .bias(onnxFloatAttr(attributesForNode, BIAS, ONNX_DEFAULT_BIAS))
                .depth((int) size)
                .build();
        addArgs();
    }

    /** Returns the float attribute {@code name} from {@code attributes}, or {@code defaultValue} if absent. */
    private static double onnxFloatAttr(Map<String, OnnxProto3.AttributeProto> attributes, String name,
                                        double defaultValue) {
        OnnxProto3.AttributeProto attr = attributes.get(name);
        return attr == null ? defaultValue : attr.getF();
    }

    @Override
    public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
        Map<String, PropertyMapping> attrMappings = new HashMap<>();
        attrMappings.put(DEPTH, PropertyMapping.builder()
                .tfAttrName(TF_DEPTH_RADIUS)
                .propertyNames(new String[]{DEPTH})
                .onnxAttrName(ONNX_SIZE)
                .build());
        attrMappings.put(ALPHA, PropertyMapping.builder()
                .tfAttrName(ALPHA)
                .onnxAttrName(ALPHA)
                .propertyNames(new String[]{ALPHA})
                .build());
        attrMappings.put(BETA, PropertyMapping.builder()
                .tfAttrName(BETA)
                .onnxAttrName(BETA)
                .propertyNames(new String[]{BETA})
                .build());
        attrMappings.put(BIAS, PropertyMapping.builder()
                .tfAttrName(BIAS)
                .onnxAttrName(BIAS)
                .propertyNames(new String[]{BIAS})
                .build());

        Map<String, Map<String, PropertyMapping>> ret = new HashMap<>();
        ret.put(tensorflowName(), attrMappings);
        ret.put(onnxName(), attrMappings);
        return ret;
    }

    /**
     * Gradient: a {@link LocalResponseNormalizationDerivative} op fed with this op's input and the
     * incoming gradient, sharing this op's config and in-place flag.
     */
    @Override
    public List<SDVariable> doDiff(List<SDVariable> gradients) {
        LocalResponseNormalizationDerivative derivative = LocalResponseNormalizationDerivative.derivativeBuilder()
                .inPlace(this.inPlace)
                .sameDiff(this.sameDiff)
                .inputFunctions(new SDVariable[]{arg(), gradients.get(0)})
                .config(this.config)
                .build();
        return Collections.singletonList(derivative.outputVariable());
    }

    @Override
    public String onnxName() {
        return "LRN";
    }

    @Override
    public String tensorflowName() {
        return "LRN";
    }

    /** The single output has the same (floating point) datatype as input 0. */
    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
        Preconditions.checkState(dataTypes != null && !dataTypes.isEmpty(),
                "Expected at least 1 input datatype for %s, got %s", getClass(), dataTypes);
        DataType inputType = dataTypes.get(0);
        Preconditions.checkState(inputType.isFPType(),
                "Input 0 should be a floating point type for %s, got %s", getClass(), inputType);
        return Collections.singletonList(inputType);
    }

    /** Returns a new builder for this op. */
    public static LocalResponseNormalizationBuilder builder() {
        return new LocalResponseNormalizationBuilder();
    }

    /** Returns the op hyper-parameters, or {@code null} if not yet set. */
    public LocalResponseNormalizationConfig getConfig() {
        return this.config;
    }

    /** No-arg constructor for import/reflection; the config is populated later by an importer. */
    public LocalResponseNormalization() {
    }
}
