package org.wowtools.neo4j.rtree.util;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import org.neo4j.graphdb.Direction;
import org.neo4j.graphdb.Node;
import org.neo4j.graphdb.Relationship;
import org.neo4j.graphdb.RelationshipType;
import org.neo4j.graphdb.ResourceIterator;
import org.wowtools.neo4j.rtree.internal.define.Labels;
import org.wowtools.neo4j.rtree.internal.define.PropertyNames;
import org.wowtools.neo4j.rtree.internal.define.Relationships;
import org.wowtools.neo4j.rtree.internal.nearest.MinDist;
import org.wowtools.neo4j.rtree.internal.nearest.MinDistComparator;
import org.wowtools.neo4j.rtree.pojo.PointNd;
import org.wowtools.neo4j.rtree.util.DistanceResult;

/**
 * Best-first nearest-neighbour search over an R-tree whose nodes are stored in Neo4j.
 *
 * <p>Nodes carrying {@link Labels#RTREE_BRANCH} are expanded by following their outgoing
 * {@link Relationships#RTREE_PARENT_TO_CHILD} relationships. Every other node is treated as a leaf
 * whose entries are read from the properties {@code entryDataId0 .. entryDataId<size-1>}, where
 * {@code size} is the leaf's {@link PropertyNames#size} property. Subclasses turn each entry into a
 * {@link DistanceResult} via {@link #createDistanceResult(PointNd, String)}.
 *
 * @param <T> result type produced by {@link #createDistanceResult(PointNd, String)}
 */
public abstract class NearestNeighbour<T extends DistanceResult> {
    private final DistanceResultNodeFilter filter;
    private final int maxHits;
    private final PointNd pointNd;

    /** Filter that accepts every candidate; used by the constructor that takes no filter. */
    public static final DistanceResultNodeFilter alwaysTrue = distanceResult -> true;

    /** Orders results by ascending {@link DistanceResult#getDist()}. */
    private static final Comparator<DistanceResult> comp =
            Comparator.comparingDouble(DistanceResult::getDist);

    /**
     * @param distanceResultNodeFilter candidates it rejects are never added to the result
     * @param maxHits                  maximum number of results returned by {@link #find(Node)}
     * @param pointNd                  the query point
     */
    public NearestNeighbour(DistanceResultNodeFilter distanceResultNodeFilter, int maxHits, PointNd pointNd) {
        this.pointNd = pointNd;
        this.filter = distanceResultNodeFilter;
        this.maxHits = maxHits;
    }

    /**
     * Creates a search that accepts every candidate ({@link #alwaysTrue}).
     *
     * @param maxHits maximum number of results returned by {@link #find(Node)}
     * @param pointNd the query point
     */
    public NearestNeighbour(int maxHits, PointNd pointNd) {
        this(alwaysTrue, maxHits, pointNd);
    }

    /**
     * Builds the result object for one leaf entry.
     *
     * @param pointNd the query point this search was created with
     * @param str     the String value stored in one of the leaf's {@code entryDataIdN} properties
     * @return the candidate; its {@link DistanceResult#getDist()} is used for ranking
     */
    public abstract T createDistanceResult(PointNd pointNd, String str);

    /**
     * Runs the search starting at {@code node}.
     *
     * @param node the node to start from (typically the tree root)
     * @return at most {@code maxHits} results sorted by ascending distance; empty when
     *         {@code maxHits == 0}
     * @throws IllegalArgumentException if {@code maxHits} is negative
     */
    public List<T> find(Node node) {
        if (maxHits < 0) {
            throw new IllegalArgumentException("maxHits must not be negative: " + maxHits);
        }
        List<T> results = new ArrayList<>(maxHits);
        if (maxHits == 0) {
            // Nothing can be returned. Bailing out also avoids results.get(-1) in the
            // pruning checks, which assume a full list is non-empty.
            return results;
        }
        // Ordering is defined by MinDistComparator (presumably ascending min distance of the
        // node's MBR to pointNd -- confirm in that class).
        PriorityQueue<Node> queue = new PriorityQueue<>(20, new MinDistComparator(pointNd));
        queue.add(node);
        while (!queue.isEmpty()) {
            Node current = queue.poll();
            if (current.hasLabel(Labels.RTREE_BRANCH)) {
                nnExpandInternal(current, results, maxHits, queue);
            } else {
                nnExpandLeaf(current, filter, results, maxHits);
            }
        }
        return results;
    }

    /**
     * Enqueues the children of a branch node. A child is skipped only when the result list is
     * already full and the child's MBR min-distance exceeds the current worst result.
     */
    private void nnExpandInternal(Node node, List<T> list, int maxHits, PriorityQueue<Node> priorityQueue) {
        // ResourceIterator is AutoCloseable; close it even if a property read throws mid-iteration.
        try (ResourceIterator<Relationship> it =
                     node.getRelationships(Direction.OUTGOING, Relationships.RTREE_PARENT_TO_CHILD).iterator()) {
            while (it.hasNext()) {
                Node child = it.next().getEndNode();
                Map<String, Object> properties = child.getProperties(PropertyNames.mbrMax, PropertyNames.mbrMin);
                double minDist = MinDist.get(
                        (double[]) properties.get(PropertyNames.mbrMin),
                        (double[]) properties.get(PropertyNames.mbrMax),
                        pointNd);
                int size = list.size();
                // '<=' keeps children that could still tie the current worst result.
                if (size < maxHits || minDist <= list.get(size - 1).getDist()) {
                    priorityQueue.add(child);
                }
            }
        }
    }

    /**
     * Scores every entry of a leaf node and merges the accepted ones into the sorted result list.
     */
    private void nnExpandLeaf(Node node, DistanceResultNodeFilter distanceResultNodeFilter, List<T> list,
                              int maxHits) {
        int entryCount = (Integer) node.getProperty(PropertyNames.size);
        String[] keys = new String[entryCount];
        for (int k = 0; k < entryCount; k++) {
            keys[k] = "entryDataId" + k;
        }
        node.getProperties(keys).forEach((key, value) -> {
            T candidate = createDistanceResult(pointNd, (String) value);
            double dist = candidate.getDist();
            if (!distanceResultNodeFilter.accept(candidate)) {
                return;
            }
            int size = list.size();
            // Strict '<': a candidate that only ties the current worst does not displace it.
            if (size < maxHits || dist < list.get(size - 1).getDist()) {
                add(list, candidate, maxHits);
            }
        });
    }

    /**
     * Inserts {@code t} into the distance-sorted {@code list}. When the list already holds
     * {@code maxHits} elements, the worst (last) element is dropped first.
     */
    private void add(List<T> list, T t, int maxHits) {
        if (list.size() == maxHits) {
            list.remove(list.size() - 1);
        }
        int index = Collections.binarySearch(list, t, comp);
        if (index < 0) {
            // binarySearch returns (-(insertion point) - 1) when the key is absent.
            index = -(index + 1);
        }
        list.add(index, t);
    }
}
