package org.deeplearning4j.nn.graph.vertex.impl;

import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastDivOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Pair;

import java.util.Arrays;

/* loaded from: input_file:org/deeplearning4j/nn/graph/vertex/impl/L2NormalizeVertex.class */
public class L2NormalizeVertex extends BaseGraphVertex {
    private static final int[] DEFAULT_RANK2_DIMS = {1};
    private static final int[] DEFAULT_RANK3_DIMS = {1, 2};
    private static final int[] DEFAULT_RANK4_DIMS = {1, 2, 3};
    private int[] dimension;
    private double eps;

    public L2NormalizeVertex(ComputationGraph computationGraph, String str, int i, int[] iArr, double d, DataType dataType) {
        this(computationGraph, str, i, null, null, iArr, d, dataType);
    }

    public L2NormalizeVertex(ComputationGraph computationGraph, String str, int i, VertexIndices[] vertexIndicesArr, VertexIndices[] vertexIndicesArr2, int[] iArr, double d, DataType dataType) {
        super(computationGraph, str, i, vertexIndicesArr, vertexIndicesArr2, dataType);
        this.dimension = iArr;
        this.eps = d;
    }

    @Override // org.deeplearning4j.nn.graph.vertex.GraphVertex
    public boolean hasLayer() {
        return false;
    }

    @Override // org.deeplearning4j.nn.graph.vertex.GraphVertex
    public Layer getLayer() {
        return null;
    }

    @Override // org.deeplearning4j.nn.graph.vertex.GraphVertex
    public INDArray doForward(boolean z, LayerWorkspaceMgr layerWorkspaceMgr) {
        if (!canDoForward()) {
            throw new IllegalStateException("Cannot do forward pass: inputs not set (L2NormalizeVertex " + this.vertexName + " idx " + this.vertexIndex + ")");
        }
        INDArray iNDArray = this.inputs[0];
        INDArray norm2 = iNDArray.norm2(getDimensions(iNDArray));
        Transforms.max(norm2, this.eps, false);
        MemoryWorkspace notifyScopeBorrowed = layerWorkspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATIONS);
        Throwable th = null;
        try {
            if (iNDArray.rank() == 2) {
                INDArray divColumnVector = iNDArray.divColumnVector(norm2);
                if (notifyScopeBorrowed != null) {
                    if (0 != 0) {
                        try {
                            notifyScopeBorrowed.close();
                        } catch (Throwable th2) {
                            th.addSuppressed(th2);
                        }
                    } else {
                        notifyScopeBorrowed.close();
                    }
                }
                return divColumnVector;
            }
            INDArray exec = Nd4j.getExecutioner().exec((BroadcastOp) new BroadcastDivOp(iNDArray, norm2, Nd4j.createUninitialized(iNDArray.shape(), iNDArray.ordering()), 0));
            if (notifyScopeBorrowed != null) {
                if (0 != 0) {
                    try {
                        notifyScopeBorrowed.close();
                    } catch (Throwable th3) {
                        th.addSuppressed(th3);
                    }
                } else {
                    notifyScopeBorrowed.close();
                }
            }
            return exec;
        } catch (Throwable th4) {
            if (notifyScopeBorrowed != null) {
                if (0 != 0) {
                    try {
                        notifyScopeBorrowed.close();
                    } catch (Throwable th5) {
                        th.addSuppressed(th5);
                    }
                } else {
                    notifyScopeBorrowed.close();
                }
            }
            throw th4;
        }
    }

    @Override // org.deeplearning4j.nn.graph.vertex.GraphVertex
    public Pair<Gradient, INDArray[]> doBackward(boolean z, LayerWorkspaceMgr layerWorkspaceMgr) {
        INDArray createUninitialized;
        if (!canDoBackward()) {
            throw new IllegalStateException("Cannot do backward pass: errors not set (L2NormalizeVertex " + this.vertexName + " idx " + this.vertexIndex + ")");
        }
        INDArray iNDArray = this.inputs[0];
        int[] dimensions = getDimensions(iNDArray);
        INDArray norm2 = iNDArray.norm2(dimensions);
        INDArray pow = Transforms.pow(norm2, (Number) Double.valueOf(3.0d), true);
        Transforms.max(norm2, this.eps, false);
        Transforms.max(pow, this.eps, false);
        if (iNDArray.rank() == 2) {
            MemoryWorkspace notifyScopeBorrowed = layerWorkspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATION_GRAD);
            Throwable th = null;
            try {
                try {
                    createUninitialized = this.epsilon.divColumnVector(norm2);
                    if (notifyScopeBorrowed != null) {
                        if (0 != 0) {
                            try {
                                notifyScopeBorrowed.close();
                            } catch (Throwable th2) {
                                th.addSuppressed(th2);
                            }
                        } else {
                            notifyScopeBorrowed.close();
                        }
                    }
                    createUninitialized.subi(iNDArray.divColumnVector(pow).muliColumnVector(this.epsilon.mul(iNDArray).sum(1)));
                } finally {
                }
            } catch (Throwable th3) {
                if (notifyScopeBorrowed != null) {
                    if (th != null) {
                        try {
                            notifyScopeBorrowed.close();
                        } catch (Throwable th4) {
                            th.addSuppressed(th4);
                        }
                    } else {
                        notifyScopeBorrowed.close();
                    }
                }
                throw th3;
            }
        } else {
            INDArray sum = this.epsilon.mul(iNDArray).sum(dimensions);
            INDArray createUninitialized2 = Nd4j.createUninitialized(iNDArray.shape(), iNDArray.ordering());
            Nd4j.getExecutioner().exec((BroadcastOp) new BroadcastDivOp(iNDArray, pow, createUninitialized2, 0));
            Nd4j.getExecutioner().exec((BroadcastOp) new BroadcastMulOp(createUninitialized2, sum, createUninitialized2, 0));
            createUninitialized = layerWorkspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, this.epsilon.dataType(), this.epsilon.shape(), this.epsilon.ordering());
            Nd4j.getExecutioner().exec((BroadcastOp) new BroadcastDivOp(this.epsilon, norm2, createUninitialized, 0));
            createUninitialized.subi(createUninitialized2);
        }
        return new Pair<>(null, new INDArray[]{createUninitialized});
    }

    private int[] getDimensions(INDArray iNDArray) {
        if (this.dimension != null && this.dimension.length >= 1) {
            return this.dimension;
        }
        switch (iNDArray.rank()) {
            case 2:
                return DEFAULT_RANK2_DIMS;
            case 3:
                return DEFAULT_RANK3_DIMS;
            case 4:
                return DEFAULT_RANK4_DIMS;
            default:
                throw new RuntimeException();
        }
    }

    @Override // org.deeplearning4j.nn.graph.vertex.GraphVertex
    public void setBackpropGradientsViewArray(INDArray iNDArray) {
        if (iNDArray != null) {
            throw new RuntimeException("Vertex does not have gradients; gradients view array cannot be set here");
        }
    }

    @Override // org.deeplearning4j.nn.graph.vertex.GraphVertex
    public Pair<INDArray, MaskState> feedForwardMaskArrays(INDArray[] iNDArrayArr, MaskState maskState, int i) {
        if (iNDArrayArr == null || iNDArrayArr.length == 0) {
            return null;
        }
        return new Pair<>(iNDArrayArr[0], maskState);
    }

    @Override // org.deeplearning4j.nn.graph.vertex.BaseGraphVertex
    public String toString() {
        return "L2NormalizeVertex(id=" + getVertexIndex() + ",name=\"" + getVertexName() + ",dim=\"" + this.dimension + "\")";
    }
}
