package org.deeplearning4j.clustering.kdtree;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.ReduceOp;
import org.nd4j.linalg.api.ops.impl.reduce.bool.Any;
import org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;

/**
 * A k-d tree storing fixed-length {@link INDArray} vectors, supporting insertion, deletion,
 * nearest-neighbour lookup and fixed-radius search.
 *
 * <p>The root discriminates on coordinate 0 and each level below advances the discriminating
 * coordinate by one, wrapping at {@code dims}. Points are routed by comparing the discriminating
 * coordinate first and, on a tie, the following coordinates in wrapping order; a point goes to the
 * left subtree when it orders before the node and to the right subtree when it orders after it.
 * Points equal in every coordinate to one already stored are not inserted again.
 *
 * <p>A bounding {@link HyperRect} is kept to prune searches. It is enlarged on insertion but only
 * reset on deletion when one or zero points remain, so it may be looser than the true bounds.
 *
 * <p>This class performs no synchronization.
 */
public class KDTree implements Serializable {
    /** Direction marker for the right subtree (points ordering after the node). */
    public static final int GREATER = 1;
    /** Direction marker for the left subtree (points ordering before the node). */
    public static final int LESS = 0;

    private KDNode root;
    private final int dims;
    private int size = 0;
    /** Search bounding box; {@code null} while the tree is empty. */
    private HyperRect rect;

    /** A tree node holding one point plus links to its children and parent. */
    public static class KDNode implements Serializable {
        private INDArray point;
        private KDNode left;
        private KDNode right;
        private KDNode parent;

        public KDNode(INDArray point) {
            this.point = point;
        }

        public INDArray getPoint() {
            return point;
        }

        public KDNode getLeft() {
            return left;
        }

        public void setLeft(KDNode left) {
            this.left = left;
        }

        public KDNode getRight() {
            return right;
        }

        public void setRight(KDNode right) {
            this.right = right;
        }

        public KDNode getParent() {
            return parent;
        }

        public void setParent(KDNode parent) {
            this.parent = parent;
        }
    }

    /**
     * Creates an empty tree for points with {@code dims} coordinates.
     *
     * @param dims number of coordinates in every stored point
     * @throws IllegalArgumentException if {@code dims < 1}
     */
    public KDTree(int dims) {
        if (dims < 1) {
            // Every level computes (disc + 1) % dims, which would divide by zero.
            throw new IllegalArgumentException("dims must be at least 1, got " + dims);
        }
        this.dims = dims;
    }

    /**
     * Inserts a point. A point equal in every coordinate to one already stored is ignored.
     *
     * @param point vector of length {@code dims}; stored by reference, not copied
     * @throws IllegalArgumentException if {@code point} is null, not a vector, or of the wrong length
     */
    public void insert(INDArray point) {
        checkPoint(point);
        if (root == null) {
            root = new KDNode(point);
            rect = new HyperRect(HyperRect.point(point));
            size++;
            return;
        }
        KDNode node = root;
        int disc = 0;
        while (true) {
            int cmp = compare(point, node.getPoint(), disc);
            if (cmp == 0) {
                return; // exact duplicate: nothing to insert
            }
            KDNode child = cmp < 0 ? node.getLeft() : node.getRight();
            if (child == null) {
                KDNode inserted = new KDNode(point);
                if (cmp < 0) {
                    node.setLeft(inserted);
                } else {
                    node.setRight(inserted);
                }
                inserted.setParent(node);
                rect.enlargeTo(point);
                size++;
                return;
            }
            node = child;
            disc = (disc + 1) % dims;
        }
    }

    /**
     * Removes the stored point that is equal in every coordinate to {@code point}.
     *
     * @param point vector of length {@code dims} to remove
     * @return the array that was stored for that point, or {@code null} if no equal point is stored
     * @throws IllegalArgumentException if {@code point} is null, not a vector, or of the wrong length
     */
    public INDArray delete(INDArray point) {
        checkPoint(point);
        KDNode node = root;
        int disc = 0;
        while (node != null) {
            int cmp = compare(point, node.getPoint(), disc);
            if (cmp == 0) {
                break;
            }
            node = cmp < 0 ? node.getLeft() : node.getRight();
            disc = (disc + 1) % dims;
        }
        if (node == null) {
            return null;
        }
        // Capture first: deleteNode may overwrite node.point with a replacement from below.
        INDArray removed = node.getPoint();
        replaceChild(node.getParent(), node, deleteNode(node, disc));
        size--;
        if (size == 1) {
            // Shrink the box to the single surviving point (the root).
            rect = new HyperRect(HyperRect.point(root.getPoint()));
        } else if (size == 0) {
            rect = null;
        }
        return removed;
    }

    /**
     * Returns every stored point whose Euclidean distance to {@code point} is at most
     * {@code distance}, nearest first. Despite its name this is a fixed-radius search.
     *
     * @param point query vector
     * @param distance inclusive search radius
     * @return (distance, point) pairs sorted by ascending distance; empty if none qualify
     */
    public List<Pair<Double, INDArray>> knn(INDArray point, double distance) {
        List<Pair<Double, INDArray>> results = new ArrayList<>();
        knn(root, point, rect, distance, results, 0);
        results.sort(Comparator.comparingDouble(pair -> pair.getKey()));
        return results;
    }

    /** Collects points within {@code maxDist} from the subtree at {@code node}, pruning by {@code box}. */
    private void knn(KDNode node, INDArray point, HyperRect box, double maxDist,
                     List<Pair<Double, INDArray>> results, int disc) {
        if (node == null || box == null || box.minDistance(point) > maxDist) {
            return;
        }
        int next = (disc + 1) % dims;
        double dist = distance(point, node.point);
        if (dist <= maxDist) {
            results.add(Pair.of(dist, node.getPoint()));
        }
        HyperRect lower = box.getLower(node.point, disc);
        HyperRect upper = box.getUpper(node.point, disc);
        knn(node.getLeft(), point, lower, maxDist, results, next);
        knn(node.getRight(), point, upper, maxDist, results, next);
    }

    /**
     * Finds the stored point nearest to {@code point} by Euclidean distance.
     *
     * @return (distance, point); for an empty tree, (+Infinity, null)
     */
    public Pair<Double, INDArray> nn(INDArray point) {
        return nn(root, point, rect, Double.POSITIVE_INFINITY, null, 0);
    }

    /**
     * Searches the subtree at {@code node} for something closer than {@code bestDist}.
     *
     * @return the best (distance, point) found so far, including the incoming best when nothing
     *     closer exists in this subtree
     */
    private Pair<Double, INDArray> nn(KDNode node, INDArray point, HyperRect box, double bestDist,
                                      INDArray best, int disc) {
        if (node == null || box == null || box.minDistance(point) > bestDist) {
            return Pair.of(bestDist, best);
        }
        double dist = distance(point, node.getPoint());
        if (dist < bestDist) {
            bestDist = dist;
            best = node.getPoint();
        }
        int next = (disc + 1) % dims;
        HyperRect lower = box.getLower(node.point, disc);
        HyperRect upper = box.getUpper(node.point, disc);
        // Visit the side containing the query first so the far side can be pruned more tightly.
        boolean nearIsLeft = point.getDouble(disc) < node.point.getDouble(disc);
        KDNode near = nearIsLeft ? node.getLeft() : node.getRight();
        KDNode far = nearIsLeft ? node.getRight() : node.getLeft();
        HyperRect nearBox = nearIsLeft ? lower : upper;
        HyperRect farBox = nearIsLeft ? upper : lower;
        Pair<Double, INDArray> nearResult = nn(near, point, nearBox, bestDist, best, next);
        return nn(far, point, farBox, nearResult.getKey(), nearResult.getValue(), next);
    }

    /**
     * Removes the point held by {@code node} from the subtree rooted at {@code node}.
     *
     * <p>A leaf simply disappears. Otherwise the node takes over the point of the next point in
     * order from its right subtree (or the previous one from its left subtree when there is no
     * right subtree), and that donor node is removed recursively.
     *
     * @param node node whose point is being removed
     * @param disc discriminating coordinate of {@code node}
     * @return the root of the remaining subtree: {@code node} itself, or {@code null} if it was a leaf
     */
    private KDNode deleteNode(KDNode node, int disc) {
        int childDisc = (disc + 1) % dims;
        Pair<KDNode, Integer> replacement;
        if (node.getRight() != null) {
            replacement = min(node.getRight(), disc, childDisc);
        } else if (node.getLeft() != null) {
            replacement = max(node.getLeft(), disc, childDisc);
        } else {
            return null; // leaf: the caller unlinks it
        }
        KDNode donor = replacement.getKey();
        node.point = donor.point;
        KDNode donorParent = donor.getParent();
        replaceChild(donorParent, donor, deleteNode(donor, replacement.getValue()));
        return node;
    }

    /** Points {@code parent}'s link to {@code oldChild} (or the root, if parent is null) at {@code newChild}. */
    private void replaceChild(KDNode parent, KDNode oldChild, KDNode newChild) {
        if (parent == null) {
            root = newChild;
        } else if (parent.getLeft() == oldChild) {
            parent.setLeft(newChild);
        } else if (parent.getRight() == oldChild) {
            parent.setRight(newChild);
        }
        if (newChild != null) {
            newChild.setParent(parent);
        }
    }

    /**
     * Finds the node whose point orders last, using the ordering for coordinate {@code disc}, within
     * the subtree at {@code node}.
     *
     * @param nodeDisc discriminating coordinate of {@code node}
     * @return that node paired with its own discriminating coordinate
     */
    private Pair<KDNode, Integer> max(KDNode node, int disc, int nodeDisc) {
        int childDisc = (nodeDisc + 1) % dims;
        if (nodeDisc == disc) {
            // Same ordering as this node's split: everything larger lives in the right subtree.
            if (node.getRight() == null) {
                return Pair.of(node, nodeDisc);
            }
            return max(node.getRight(), disc, childDisc);
        }
        Pair<KDNode, Integer> best = Pair.of(node, nodeDisc);
        if (node.getLeft() != null) {
            best = later(best, max(node.getLeft(), disc, childDisc), disc);
        }
        if (node.getRight() != null) {
            best = later(best, max(node.getRight(), disc, childDisc), disc);
        }
        return best;
    }

    /**
     * Finds the node whose point orders first, using the ordering for coordinate {@code disc}, within
     * the subtree at {@code node}.
     *
     * @param nodeDisc discriminating coordinate of {@code node}
     * @return that node paired with its own discriminating coordinate
     */
    private Pair<KDNode, Integer> min(KDNode node, int disc, int nodeDisc) {
        int childDisc = (nodeDisc + 1) % dims;
        if (nodeDisc == disc) {
            // Same ordering as this node's split: everything smaller lives in the left subtree.
            if (node.getLeft() == null) {
                return Pair.of(node, nodeDisc);
            }
            return min(node.getLeft(), disc, childDisc);
        }
        Pair<KDNode, Integer> best = Pair.of(node, nodeDisc);
        if (node.getLeft() != null) {
            best = earlier(best, min(node.getLeft(), disc, childDisc), disc);
        }
        if (node.getRight() != null) {
            best = earlier(best, min(node.getRight(), disc, childDisc), disc);
        }
        return best;
    }

    /** Returns whichever candidate's point orders first for coordinate {@code disc}. */
    private Pair<KDNode, Integer> earlier(Pair<KDNode, Integer> a, Pair<KDNode, Integer> b, int disc) {
        return compare(a.getKey().getPoint(), b.getKey().getPoint(), disc) <= 0 ? a : b;
    }

    /** Returns whichever candidate's point orders last for coordinate {@code disc}. */
    private Pair<KDNode, Integer> later(Pair<KDNode, Integer> a, Pair<KDNode, Integer> b, int disc) {
        return compare(a.getKey().getPoint(), b.getKey().getPoint(), disc) >= 0 ? a : b;
    }

    /** Returns the number of points stored in the tree. */
    public int size() {
        return size;
    }

    /**
     * Orders two points for a node discriminating on coordinate {@code disc}.
     *
     * <p>Coordinate {@code disc} is compared first, then {@code disc + 1}, and so on, wrapping
     * around until all {@code dims} coordinates are checked.
     *
     * @return negative if {@code a} orders before {@code b}, positive if after, 0 if every
     *     coordinate compares neither less nor greater
     */
    private int compare(INDArray a, INDArray b, int disc) {
        for (int k = 0; k < dims; k++) {
            int coord = (disc + k) % dims;
            double x = a.getDouble(coord);
            double y = b.getDouble(coord);
            if (x < y) {
                return -1;
            }
            if (x > y) {
                return 1;
            }
        }
        return 0;
    }

    /** Validates that {@code point} is a vector with exactly {@code dims} elements. */
    private void checkPoint(INDArray point) {
        if (point == null || !point.isVector() || point.length() != dims) {
            throw new IllegalArgumentException("Point must be a vector of length " + dims);
        }
    }

    /** Euclidean distance between two vectors, reduced over all dimensions. */
    private static double distance(INDArray a, INDArray b) {
        return Nd4j.getExecutioner().execAndReturn((ReduceOp) new EuclideanDistance(a, b, new int[0]))
                .getFinalResult().doubleValue();
    }
}
