package org.wikibrain.matrix.knn;

import gnu.trove.list.array.TIntArrayList;
import gnu.trove.set.TIntSet;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.TreeSet;
import org.wikibrain.matrix.DenseMatrix;
import org.wikibrain.matrix.DenseMatrixRow;

/* loaded from: input_file:org/wikibrain/matrix/knn/KDTreeKNN.class */
/**
 * Approximate nearest-neighbor finder that recursively splits the rows of a {@link DenseMatrix}
 * into a binary tree using a few passes of 2-means with dot-product similarity.
 *
 * <p>Each leaf cluster keeps a normalized centroid and its member row ids. A query ranks the
 * leaves by centroid/query dot product and scans them best-first. It stops once at least
 * {@code maxTraversal} rows have been scored or every leaf has been scanned.
 *
 * <p>Similarity scores are plain dot products. They equal cosine similarity only when rows and
 * query are unit length. NOTE(review): presumably the matrix rows are pre-normalized; confirm.
 */
public class KDTreeKNN implements KNNFinder {
    /** Number of 2-means passes per split; the first pass is a fixed half/half seed. */
    private static final int SPLIT_ITERATIONS = 5;

    private final DenseMatrix matrix;
    private final int[] allIds;
    private final int dimensions;
    private int maxSampleSize = 5000;
    private int maxLeaf = 100;

    /** Centroid of each leaf cluster, index-aligned with {@code members}; set by {@link #build()}. */
    List<float[]> centroids;
    /** Row ids of each leaf cluster, index-aligned with {@code centroids}; set by {@link #build()}. */
    List<int[]> members;

    /** A leaf cluster paired with its similarity to the current query vector. */
    private static class Candidate implements Comparable<Candidate> {
        final int clusterNum;
        final double score;

        public Candidate(int clusterNum, double score) {
            this.clusterNum = clusterNum;
            this.score = score;
        }

        /**
         * Orders by score, then by cluster number. The tie-break matters because candidates are
         * kept in a TreeSet: without it, clusters with equal scores would compare as equal and
         * all but one would be silently dropped from the search.
         */
        @Override
        public int compareTo(Candidate other) {
            int byScore = Double.compare(this.score, other.score);
            return byScore != 0 ? byScore : Integer.compare(this.clusterNum, other.clusterNum);
        }
    }

    /** A node of the split tree. */
    public static class Node {
        /** Position in the tree: "R" for the root, extended with 'L' or 'R' on each descent. */
        String path;
        /** Normalized centroid chosen when the parent was split; computed on demand for a root leaf. */
        float[] centroid;
        Node left;
        Node right;
        /** Row ids assigned to this node. */
        int[] memberIds;

        public Node(String path) {
            this.path = path;
        }
    }

    /**
     * Creates an (unbuilt) index over the rows of {@code matrix}. Call {@link #build()} before
     * {@link #query}.
     *
     * @param matrix source rows; the column count of its first row fixes the vector dimension
     * @throws IOException if reading the first row fails
     */
    public KDTreeKNN(DenseMatrix matrix) throws IOException {
        this.matrix = matrix;
        this.allIds = matrix.getRowIds();
        // Guard against an empty matrix instead of failing with an out-of-bounds access.
        this.dimensions = allIds.length == 0 ? 0 : matrix.getRow(allIds[0]).getNumCols();
    }

    /**
     * Builds the tree and replaces any previously built leaf clusters.
     *
     * @throws IOException if reading matrix rows fails
     */
    @Override
    public void build() throws IOException {
        Node root = new Node("R");
        root.memberIds = Arrays.copyOf(allIds, allIds.length);
        // Shuffling makes the prefix that chooseSplitCentroids samples a random sample.
        shuffle(root.memberIds);
        centroids = new ArrayList<float[]>();
        members = new ArrayList<int[]>();
        build(root);
    }

    /**
     * Recursively splits {@code node} until fewer than {@code maxLeaf} members remain, or until a
     * split fails to separate the members. Each resulting leaf is appended to
     * {@code centroids}/{@code members}.
     */
    private void build(Node node) throws IOException {
        if (node.memberIds.length < maxLeaf) {
            addLeaf(node);
            return;
        }
        node.left = new Node(node.path + "L");
        node.right = new Node(node.path + "R");
        chooseSplitCentroids(node);

        // Assign every member (not just the sample) to its closer centroid; ties go left.
        TIntArrayList leftIds = new TIntArrayList();
        TIntArrayList rightIds = new TIntArrayList();
        for (int id : node.memberIds) {
            DenseMatrixRow row = matrix.getRow(id);
            if (row.dot(node.left.centroid) >= row.dot(node.right.centroid)) {
                leftIds.add(id);
            } else {
                rightIds.add(id);
            }
        }

        if (leftIds.size() == 0 || rightIds.size() == 0) {
            // Degenerate split (e.g. identical rows): recursing would rebuild this same node
            // forever, so keep it as an oversized leaf instead.
            node.left = null;
            node.right = null;
            addLeaf(node);
            return;
        }
        node.left.memberIds = leftIds.toArray();
        node.right.memberIds = rightIds.toArray();
        build(node.left);
        build(node.right);
    }

    /**
     * Runs {@link #SPLIT_ITERATIONS} passes of 2-means over the first
     * {@code min(members, maxSampleSize)} members of {@code node}. The normalized centroids are
     * stored on {@code node.left} and {@code node.right}.
     */
    private void chooseSplitCentroids(Node node) throws IOException {
        int sampleSize = Math.min(node.memberIds.length, maxSampleSize);
        double[] leftSum = new double[dimensions];
        double[] rightSum = new double[dimensions];
        for (int iter = 0; iter < SPLIT_ITERATIONS; iter++) {
            Arrays.fill(leftSum, 0.0);
            Arrays.fill(rightSum, 0.0);
            for (int s = 0; s < sampleSize; s++) {
                DenseMatrixRow row = matrix.getRow(node.memberIds[s]);
                boolean goesLeft;
                if (iter == 0) {
                    // Seed pass: first half of the shuffled sample goes left, the rest right.
                    goesLeft = s < sampleSize / 2;
                } else {
                    goesLeft = row.dot(node.left.centroid) >= row.dot(node.right.centroid);
                }
                addRow(goesLeft ? leftSum : rightSum, row);
            }
            normalize(leftSum);
            normalize(rightSum);
            node.left.centroid = toFloats(leftSum);
            node.right.centroid = toFloats(rightSum);
        }
    }

    /**
     * Records {@code node} as a leaf cluster. If the node has no centroid (a root that was never
     * split), one is computed from its members so that query never sees a null centroid.
     */
    private void addLeaf(Node node) throws IOException {
        if (node.centroid == null) {
            node.centroid = meanDirection(node.memberIds);
        }
        centroids.add(node.centroid);
        members.add(node.memberIds);
    }

    /** Returns the normalized sum of the rows with the given ids (all zeros if {@code ids} is empty). */
    private float[] meanDirection(int[] ids) throws IOException {
        double[] sum = new double[dimensions];
        for (int id : ids) {
            addRow(sum, matrix.getRow(id));
        }
        normalize(sum);
        return toFloats(sum);
    }

    /** Adds the first {@code dimensions} column values of {@code row} into {@code sum}, in place. */
    private void addRow(double[] sum, DenseMatrixRow row) {
        for (int c = 0; c < dimensions; c++) {
            sum[c] += row.getColValue(c);
        }
    }

    /** Narrows a double vector to a newly allocated float vector. */
    private static float[] toFloats(double[] values) {
        float[] out = new float[values.length];
        for (int i = 0; i < values.length; i++) {
            out[i] = (float) values[i];
        }
        return out;
    }

    /**
     * Scans leaf clusters in descending order of centroid similarity to {@code vector}. Every
     * allowed member is scored and fed to a {@link NeighborhoodAccumulator}.
     *
     * <p>At least one cluster is always scanned. Scanning stops after the cluster during which
     * the number of scored rows reaches {@code maxTraversal}.
     *
     * @param vector query vector (length should match the matrix column count)
     * @param k passed to {@link NeighborhoodAccumulator}; presumably the number of neighbors kept
     * @param maxTraversal minimum number of scored rows after which scanning stops
     * @param validIds if non-null, only these row ids are scored
     * @return the accumulated neighborhood
     * @throws IllegalStateException if {@link #build()} has not been called, or a row read fails
     */
    @Override
    public Neighborhood query(float[] vector, int k, int maxTraversal, TIntSet validIds) {
        if (centroids == null || members == null) {
            throw new IllegalStateException("build() must be called before query()");
        }
        TreeSet<Candidate> ranked = new TreeSet<Candidate>();
        for (int c = 0; c < centroids.size(); c++) {
            ranked.add(new Candidate(c, dot(centroids.get(c), vector)));
        }

        NeighborhoodAccumulator accumulator = new NeighborhoodAccumulator(k);
        int visited = 0;
        while (!ranked.isEmpty()) {
            Candidate best = ranked.pollLast();
            for (int id : members.get(best.clusterNum)) {
                if (validIds != null && !validIds.contains(id)) {
                    continue;
                }
                try {
                    DenseMatrixRow row = matrix.getRow(id);
                    accumulator.visit(row.getRowIndex(), cosine(vector, row));
                    visited++;
                } catch (IOException e) {
                    throw new IllegalStateException(e);
                }
            }
            if (visited >= maxTraversal) {
                break;
            }
        }
        return accumulator.get();
    }

    /** Persistence is not supported by this implementation. */
    @Override
    public void save(File file) throws IOException {
        throw new UnsupportedOperationException();
    }

    /** Persistence is not supported by this implementation. */
    @Override
    public boolean load(File file) throws IOException {
        throw new UnsupportedOperationException();
    }

    /** Sets the maximum number of members sampled when choosing each split (default 5000). */
    public void setMaxSampleSize(int maxSampleSize) {
        this.maxSampleSize = maxSampleSize;
    }

    /** Sets the member count below which a node becomes a leaf (default 100). */
    public void setMaxLeaf(int maxLeaf) {
        this.maxLeaf = maxLeaf;
    }

    /** Dot product of two rows, or 0 if either is null (cosine only for unit-length rows). */
    static double cosine(DenseMatrixRow a, DenseMatrixRow b) {
        if (a == null || b == null) {
            return 0.0d;
        }
        return a.dot(b);
    }

    /** Dot product of a vector and a row, or 0 if either is null (cosine only for unit-length inputs). */
    static double cosine(float[] vector, DenseMatrixRow row) {
        if (vector == null || row == null) {
            return 0.0d;
        }
        return row.dot(vector);
    }

    /** In-place Fisher-Yates shuffle. */
    private static void shuffle(int[] values) {
        Random random = new Random();
        for (int i = values.length - 1; i > 0; i--) {
            int j = random.nextInt(i + 1);
            int tmp = values[j];
            values[j] = values[i];
            values[i] = tmp;
        }
    }

    /** Dot product of two equal-length vectors, accumulated in double precision. */
    private static double dot(float[] a, float[] b) {
        double sum = 0.0d;
        for (int i = 0; i < a.length; i++) {
            sum += (double) a[i] * b[i];
        }
        return sum;
    }

    /**
     * Scales {@code values} in place by 1 / (L2 norm + 1e-5). The epsilon avoids division by zero,
     * so an all-zero vector stays all zeros.
     */
    private static void normalize(double[] values) {
        double squares = 0.0d;
        for (double v : values) {
            squares += v * v;
        }
        double denom = Math.sqrt(squares) + 1.0E-5d;
        for (int i = 0; i < values.length; i++) {
            values[i] /= denom;
        }
    }
}
