package org.deeplearning4j.nn.layers.objdetect;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Broadcast;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.ops.transforms.Transforms;

/* loaded from: input_file:org/deeplearning4j/nn/layers/objdetect/YoloUtils.class */
public class YoloUtils {

    /**
     * Applies the per-channel output activations without a workspace manager.
     * Equivalent to {@code activate(boundingBoxPriors, input, LayerWorkspaceMgr.noWorkspaces())}.
     *
     * @param boundingBoxPriors bounding box priors; {@code size(0)} is the number of boxes per grid cell
     * @param input             raw rank-4 activations {@code [mb, b * (5 + c), h, w]}
     * @return a new array of the same shape as {@code input} holding the activated values
     */
    public static INDArray activate(INDArray boundingBoxPriors, INDArray input) {
        return activate(boundingBoxPriors, input, LayerWorkspaceMgr.noWorkspaces());
    }

    /**
     * Applies the per-channel output activations.
     *
     * <p>The channel dimension of {@code input} is split into {@code b} groups of {@code 5 + c} channels,
     * where {@code b = boundingBoxPriors.size(0)}. Within each group:
     * <ul>
     *   <li>channels 0-1: sigmoid;</li>
     *   <li>channels 2-3: {@code exp}, then multiplied by {@code boundingBoxPriors} broadcast along
     *       dimensions 1 and 2 of the {@code [mb, b, 2, h, w]} slice (so presumably priors are {@code [b, 2]});</li>
     *   <li>channel 4: sigmoid;</li>
     *   <li>channels 5 .. 5+c-1: softmax across those {@code c} channels.</li>
     * </ul>
     * {@code input} itself is not modified; all work is done on a copy.
     *
     * @param boundingBoxPriors bounding box priors; {@code size(0)} is the number of boxes per grid cell
     * @param input             raw rank-4 activations {@code [mb, b * (5 + c), h, w]}
     * @param layerWorkspaceMgr manager used to allocate the output as {@link ArrayType#ACTIVATIONS}
     * @return a new {@code 'c'}-ordered array with the same shape and data type as {@code input}
     * @throws IllegalStateException if the channel count is not {@code b * (5 + c)} for some {@code c >= 0}
     */
    public static INDArray activate(INDArray boundingBoxPriors, INDArray input, LayerWorkspaceMgr layerWorkspaceMgr) {
        int mb = (int) input.size(0);
        int h = (int) input.size(2);
        int w = (int) input.size(3);
        int b = (int) boundingBoxPriors.size(0);
        int c = numClasses(boundingBoxPriors, input);

        INDArray output = layerWorkspaceMgr.create(ArrayType.ACTIVATIONS, input.dataType(), input.shape(), 'c');
        // Per-box view of the output, used below to write the softmax results back.
        INDArray outputPerBox = output.reshape('c', mb, b, 5 + c, h, w);

        // Work on a copy so the in-place transforms never touch the caller's array.
        INDArray work = input.dup('c');
        INDArray workPerBox = work.reshape('c', mb, b, 5 + c, h, w);

        Transforms.sigmoid(channels(workPerBox, NDArrayIndex.interval(0, 2)), false);

        INDArray widthHeight = Transforms.exp(channels(workPerBox, NDArrayIndex.interval(2, 4)), false);
        Broadcast.mul(widthHeight, boundingBoxPriors, widthHeight, 1, 2);

        Transforms.sigmoid(channels(workPerBox, NDArrayIndex.point(4L)), false);

        output.assign(work);

        // Softmax over the class channels: move classes to the last axis and flatten to 2d rows.
        // The row count is computed in long arithmetic to avoid silent int overflow on large inputs.
        INDArray classRows = channels(workPerBox, NDArrayIndex.interval(5, 5 + c))
                .permute(0, 1, 3, 4, 2)
                .dup('c')
                .reshape('c', (long) mb * b * h * w, c);
        Transforms.softmax(classRows, false);
        channels(outputPerBox, NDArrayIndex.interval(5, 5 + c))
                .assign(classRows.reshape('c', mb, b, h, w, c).permute(0, 1, 4, 2, 3));
        return output;
    }

    /**
     * Length of the intersection of the intervals {@code [x1, x2]} and {@code [x3, x4]}, or 0 when they are
     * disjoint. Assumes well-formed intervals ({@code x1 <= x2}, {@code x3 <= x4}); otherwise the result
     * may be negative.
     */
    public static double overlap(double x1, double x2, double x3, double x4) {
        if (x3 < x1) {
            return x4 < x1 ? 0.0d : Math.min(x2, x4) - x1;
        }
        return x2 < x3 ? 0.0d : Math.min(x2, x4) - x3;
    }

    /**
     * Intersection-over-union of the axis-aligned boxes described by each object's center, width and height.
     * Returns NaN when both boxes have zero area (0 / 0).
     */
    public static double iou(DetectedObject o1, DetectedObject o2) {
        double x1Min = o1.getCenterX() - (o1.getWidth() / 2.0d);
        double x1Max = o1.getCenterX() + (o1.getWidth() / 2.0d);
        double y1Min = o1.getCenterY() - (o1.getHeight() / 2.0d);
        double y1Max = o1.getCenterY() + (o1.getHeight() / 2.0d);
        double x2Min = o2.getCenterX() - (o2.getWidth() / 2.0d);
        double x2Max = o2.getCenterX() + (o2.getWidth() / 2.0d);
        double y2Min = o2.getCenterY() - (o2.getHeight() / 2.0d);
        double y2Max = o2.getCenterY() + (o2.getHeight() / 2.0d);

        double intersection = overlap(x1Min, x1Max, x2Min, x2Max) * overlap(y1Min, y1Max, y2Min, y2Max);
        double union = ((o1.getWidth() * o1.getHeight()) + (o2.getWidth() * o2.getHeight())) - intersection;
        return intersection / union;
    }

    /**
     * Non-maximum suppression, performed in place on {@code objects}.
     *
     * <p>Objects are visited in order of decreasing confidence. An object is removed when an already-retained
     * object of the same predicted class has strictly higher confidence and an IoU with it greater than
     * {@code iouThreshold}. Suppressed objects never suppress others, so the result does not depend on the
     * order of the input list. Objects with equal confidence never suppress each other, and NaN confidences
     * neither suppress nor get suppressed (every comparison with NaN is false). {@code null} entries are
     * removed. Retained objects keep their original relative order.
     *
     * @param objects      detections; must support removal through its iterator
     * @param iouThreshold IoU above which the lower-confidence object is suppressed
     */
    public static void nms(List<DetectedObject> objects, double iouThreshold) {
        final int n = objects.size();
        final DetectedObject[] snapshot = objects.toArray(new DetectedObject[n]);

        Integer[] order = new Integer[n];
        for (int i = 0; i < n; i++) {
            order[i] = i;
        }
        // Highest confidence first; null entries sorted last.
        Arrays.sort(order, new Comparator<Integer>() {
            @Override
            public int compare(Integer a, Integer b) {
                DetectedObject oa = snapshot[a];
                DetectedObject ob = snapshot[b];
                if (oa == null || ob == null) {
                    return (oa == null ? 1 : 0) - (ob == null ? 1 : 0);
                }
                return Double.compare(ob.getConfidence(), oa.getConfidence());
            }
        });

        boolean[] keep = new boolean[n];
        List<DetectedObject> retained = new ArrayList<>();
        for (Integer index : order) {
            DetectedObject candidate = snapshot[index];
            if (candidate == null) {
                continue;
            }
            boolean suppressed = false;
            for (DetectedObject winner : retained) {
                if (winner.getPredictedClass() == candidate.getPredictedClass()
                        && candidate.getConfidence() < winner.getConfidence()
                        && iou(candidate, winner) > iouThreshold) {
                    suppressed = true;
                    break;
                }
            }
            if (!suppressed) {
                keep[index] = true;
                retained.add(candidate);
            }
        }

        Iterator<DetectedObject> it = objects.iterator();
        int position = 0;
        while (it.hasNext()) {
            it.next();
            if (!keep[position++]) {
                it.remove();
            }
        }
    }

    /**
     * Decodes detections from activated network output.
     *
     * <p>For every example, grid column {@code x}, grid row {@code y} and box {@code box} whose channel-4 value
     * is at least {@code confThreshold}, one {@link DetectedObject} is created. Its center is the channel 0/1 value
     * plus {@code x}/{@code y}, its width/height are channels 2/3 as-is, and its class predictions are a copy of
     * channels 5 .. 5+c-1 (allocated outside any workspace). The input is presumably the output of
     * {@link #activate}, given the [0, 1] confidence threshold — confirm with callers.
     *
     * @param boundingBoxPriors bounding box priors; only {@code size(0)} (boxes per cell) is used here
     * @param networkOutput     rank-4 array {@code [mb, b * (5 + c), h, w]}; not modified
     * @param confThreshold     minimum confidence, in [0, 1]
     * @param nmsThreshold      IoU threshold for {@link #nms}; NMS is skipped unless this is {@code > 0}
     * @return detections in (example, x, y, box) loop order, after optional NMS
     * @throws IllegalStateException if {@code networkOutput} is not rank 4, {@code confThreshold} is outside
     *                               [0, 1] or NaN, or the channel count is not {@code b * (5 + c)}
     */
    public static List<DetectedObject> getPredictedObjects(INDArray boundingBoxPriors, INDArray networkOutput, double confThreshold, double nmsThreshold) {
        if (networkOutput.rank() != 4) {
            throw new IllegalStateException("Invalid network output activations array: should be rank 4. Got array with shape " + Arrays.toString(networkOutput.shape()));
        }
        // Negated range test so that NaN (which fails every comparison) is rejected as well.
        if (!(confThreshold >= 0.0d && confThreshold <= 1.0d)) {
            throw new IllegalStateException("Invalid confidence threshold: must be in range [0,1]. Got: " + confThreshold);
        }
        int mb = (int) networkOutput.size(0);
        int h = (int) networkOutput.size(2);
        int w = (int) networkOutput.size(3);
        int b = (int) boundingBoxPriors.size(0);
        int c = numClasses(boundingBoxPriors, networkOutput);

        INDArray perBox = networkOutput.dup('c').reshape(mb, b, 5 + c, h, w);
        INDArray confidences = channels(perBox, NDArrayIndex.point(4L));
        INDArray classes = channels(perBox, NDArrayIndex.interval(5, 5 + c));

        List<DetectedObject> detections = new ArrayList<>();
        for (int example = 0; example < mb; example++) {
            for (int x = 0; x < w; x++) {
                for (int y = 0; y < h; y++) {
                    for (int box = 0; box < b; box++) {
                        double confidence = confidences.getDouble(example, box, y, x);
                        // Written as a negated >= so NaN confidences are skipped.
                        if (!(confidence >= confThreshold)) {
                            continue;
                        }
                        double centerX = perBox.getDouble(example, box, 0, y, x) + x;
                        double centerY = perBox.getDouble(example, box, 1, y, x) + y;
                        double width = perBox.getDouble(example, box, 2, y, x);
                        double height = perBox.getDouble(example, box, 3, y, x);

                        INDArray classPredictions;
                        // Detach the copy from any active workspace so it outlives this call.
                        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                            classPredictions = classes.get(NDArrayIndex.point(example), NDArrayIndex.point(box),
                                    NDArrayIndex.all(), NDArrayIndex.point(y), NDArrayIndex.point(x)).dup();
                        }
                        detections.add(new DetectedObject(example, centerX, centerY, width, height, classPredictions, confidence));
                    }
                }
            }
        }
        if (nmsThreshold > 0.0d) {
            nms(detections, nmsThreshold);
        }
        return detections;
    }

    /**
     * Derives the class count {@code c} from {@code channels == b * (5 + c)}, where
     * {@code b = boundingBoxPriors.size(0)} and {@code channels = output.size(1)}.
     *
     * @throws IllegalStateException if {@code b <= 0}, {@code channels} is not a multiple of {@code b},
     *                               or fewer than 5 channels per box are present
     */
    private static int numClasses(INDArray boundingBoxPriors, INDArray output) {
        long numBoxes = boundingBoxPriors.size(0);
        long channels = output.size(1);
        if (numBoxes <= 0 || channels % numBoxes != 0 || channels / numBoxes < 5) {
            throw new IllegalStateException("Invalid activations: channel dimension (size(1) = " + channels
                    + ") must equal numBoxes * (5 + numClasses), with numBoxes = boundingBoxPriors.size(0) = "
                    + numBoxes);
        }
        return (int) (channels / numBoxes) - 5;
    }

    /** Selects along dimension 2 (the per-box channel axis) of a rank-5 {@code [mb, b, 5 + c, h, w]} array. */
    private static INDArray channels(INDArray perBox, INDArrayIndex channelIndex) {
        return perBox.get(NDArrayIndex.all(), NDArrayIndex.all(), channelIndex, NDArrayIndex.all(), NDArrayIndex.all());
    }
}
