package org.wikibrain.matrix.knn;

import gnu.trove.list.TIntList;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.set.TIntSet;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.TreeSet;
import org.wikibrain.matrix.DenseMatrix;
import org.wikibrain.matrix.DenseMatrixRow;

/* loaded from: input_file:org/wikibrain/matrix/knn/KmeansKNNFinder.class */
/**
 * Approximate nearest-neighbor finder that organizes the rows of a
 * {@link DenseMatrix} into a tree of clusters and answers queries by a
 * best-first walk over that tree.
 *
 * <p>{@link #build()} clusters a random sample of rows: each node that holds
 * more than {@code maxLeaf} sampled rows is split into {@code branchingFactor}
 * children by a few rounds of "assign each row to its most cosine-similar
 * child, then recompute each child's representative row". Afterwards every row
 * of the matrix is routed down the tree and its id stored in the leaf it
 * reaches. {@link #query} then expands nodes in order of decreasing cosine
 * similarity between the query vector and the node's representative row.
 *
 * <p>{@link #save(File)} and {@link #load(File)} are not supported.
 */
public class KmeansKNNFinder implements KNNFinder {
    /** Upper bound on the number of reallocation passes used to split one node. */
    private static final int MAX_SPLIT_ITERATIONS = 5;

    /** Splitting stops iterating once the relative gain in mean similarity drops below this. */
    private static final double MIN_RELATIVE_GAIN = 0.001d;

    /** Stand-in for "mean similarity of the previous pass" before the first pass has run. */
    private static final double INITIAL_MEAN_SIMILARITY = 1.0E-9d;

    /** Lower than any cosine value, used to seed "best so far" searches. */
    private static final double BELOW_ANY_COSINE = -10.0d;

    private final DenseMatrix matrix;
    private int sampleSize = 50000;
    private int maxLeaf = 20;
    private int branchingFactor = 5;
    private Node root;

    /** A tree node paired with the score used to prioritize it during a query. */
    private static class Candidate implements Comparable<Candidate> {
        final Node n;
        final double score;

        public Candidate(Node node, double d) {
            this.n = node;
            this.score = d;
        }

        /** Orders candidates by ascending score. */
        @Override
        public int compareTo(Candidate candidate) {
            return Double.compare(this.score, candidate.score);
        }
    }

    /**
     * One node of the cluster tree.
     *
     * <p>While the tree is being built a node holds the sampled rows assigned
     * to it in {@link #members}. Once its own build step finishes,
     * {@code members} is set to {@code null} and {@link #memberIds} is created
     * to receive row ids via {@link #place(DenseMatrixRow)}.
     */
    class Node {
        /** Label made of the parent's label followed by this node's child index. */
        String path;
        /** Representative row of this node; {@code null} for the root and for children that had no members. */
        DenseMatrixRow delegate;
        /** Ids of the rows placed in this node; only leaves receive ids. */
        TIntList memberIds;
        /** Child nodes, or {@code null} when this node is a leaf. */
        Node[] children = null;
        /** Sampled rows assigned to this node during the build; {@code null} afterwards. */
        List<DenseMatrixRow> members = new ArrayList<DenseMatrixRow>();

        Node(String str) {
            this.path = str;
        }

        /**
         * Recursively splits this node until every node holds at most
         * {@code maxLeaf} sampled rows or can no longer be split.
         */
        void build() {
            if (this.members.size() <= KmeansKNNFinder.this.maxLeaf) {
                endBuild();
                return;
            }
            initializeRandomly();
            updateChildCenters();

            double previousMean = INITIAL_MEAN_SIMILARITY;
            for (int pass = 0; pass < MAX_SPLIT_ITERATIONS; pass++) {
                double currentMean = reallocateMembers();
                if ((currentMean / previousMean) - 1.0d < MIN_RELATIVE_GAIN) {
                    break;
                }
                updateChildCenters();
                previousMean = currentMean;
            }

            // If one child received every member the split made no progress.
            // Recursing into that child would repeat the same situation forever
            // (e.g. when all rows are identical), so keep this node as a leaf
            // even though it exceeds maxLeaf.
            int total = this.members.size();
            for (Node child : this.children) {
                if (child.members.size() == total) {
                    this.children = null;
                    endBuild();
                    return;
                }
            }

            Node[] built = this.children;
            endBuild();
            for (Node child : built) {
                child.build();
            }
        }

        /**
         * Routes a row down the tree, always following the most similar child,
         * and records its row index in the leaf it reaches.
         */
        void place(DenseMatrixRow denseMatrixRow) {
            if (this.children == null) {
                this.memberIds.add(denseMatrixRow.getRowIndex());
            } else {
                findClosestChild(denseMatrixRow).place(denseMatrixRow);
            }
        }

        /** Drops the sampled rows and prepares the id list filled by {@link #place}. */
        private void endBuild() {
            this.members = null;
            this.memberIds = new TIntArrayList();
        }

        /** Recomputes the representative row of every child. */
        private void updateChildCenters() {
            for (Node child : this.children) {
                child.updateCenter();
            }
        }

        /** Creates the children and deals the shuffled members out to them round-robin. */
        private void initializeRandomly() {
            int fanOut = KmeansKNNFinder.this.branchingFactor;
            this.children = new Node[fanOut];
            for (int c = 0; c < fanOut; c++) {
                this.children[c] = new Node(this.path + c);
            }
            Collections.shuffle(this.members);
            for (int m = 0; m < this.members.size(); m++) {
                this.children[m % fanOut].members.add(this.members.get(m));
            }
        }

        /**
         * Sets {@link #delegate} to the member most cosine-similar to the mean
         * of all members, or to {@code null} when there are no members.
         *
         * @return the mean cosine similarity between the members and their
         *         mean vector, or 0 when there are no members
         */
        private double updateCenter() {
            if (this.members.isEmpty()) {
                this.delegate = null;
                return 0.0d;
            }
            int count = this.members.size();
            double[] mean = new double[this.members.get(0).getNumCols()];
            for (DenseMatrixRow row : this.members) {
                for (int col = 0; col < mean.length; col++) {
                    mean[col] += row.getColValue(col);
                }
            }
            for (int col = 0; col < mean.length; col++) {
                mean[col] /= count;
            }

            double similaritySum = 0.0d;
            double best = BELOW_ANY_COSINE;
            for (DenseMatrixRow row : this.members) {
                double sim = KmeansKNNFinder.cosine(mean, row);
                similaritySum += sim;
                if (sim > best) {
                    best = sim;
                    this.delegate = row;
                }
            }
            return similaritySum / count;
        }

        /**
         * Reassigns every member of this node to its most similar child.
         *
         * @return the mean similarity between each member and the child it was
         *         assigned to
         */
        private double reallocateMembers() {
            for (Node child : this.children) {
                child.members.clear();
            }
            double similaritySum = 0.0d;
            for (DenseMatrixRow row : this.members) {
                Node closest = findClosestChild(row);
                similaritySum += closest.similarity(row);
                closest.members.add(row);
            }
            return similaritySum / this.members.size();
        }

        /**
         * Returns the child whose delegate is most similar to the row; the
         * first child wins ties.
         *
         * @throws IllegalStateException if no child yields a usable similarity
         */
        private Node findClosestChild(DenseMatrixRow denseMatrixRow) {
            double best = BELOW_ANY_COSINE;
            Node closest = null;
            for (Node child : this.children) {
                double sim = child.similarity(denseMatrixRow);
                if (sim > best) {
                    closest = child;
                    best = sim;
                }
            }
            if (closest == null) {
                throw new IllegalStateException("no closest child found under node " + this.path);
            }
            return closest;
        }

        /** Cosine similarity between this node's delegate and a row (0 when the delegate is null). */
        private double similarity(DenseMatrixRow denseMatrixRow) {
            return KmeansKNNFinder.cosine(this.delegate, denseMatrixRow);
        }

        /** Cosine similarity between this node's delegate and a raw vector (0 when the delegate is null). */
        private double similarity(float[] fArr) {
            return KmeansKNNFinder.cosine(fArr, this.delegate);
        }
    }

    /**
     * Creates a finder over the given matrix. The index is empty until
     * {@link #build()} is called.
     */
    public KmeansKNNFinder(DenseMatrix denseMatrix) {
        this.matrix = denseMatrix;
    }

    /**
     * Builds the cluster tree from a random sample of rows and then places
     * every row of the matrix into a leaf.
     *
     * @throws IOException if reading rows from the matrix fails
     */
    @Override
    public void build() throws IOException {
        this.root = new Node("R");
        this.root.members.addAll(getSample());
        this.root.build();
        Iterator<DenseMatrixRow> it = this.matrix.iterator();
        while (it.hasNext()) {
            this.root.place(it.next());
        }
    }

    /**
     * Searches the tree best-first for rows similar to the query vector.
     *
     * @param fArr    the query vector
     * @param i       capacity handed to the {@link NeighborhoodAccumulator}
     *                (presumably the number of neighbors to return)
     * @param i2      the search stops after finishing a node once at least
     *                this many rows have been scored
     * @param tIntSet must be {@code null}; restricting results to a set of ids
     *                is not supported
     * @return the neighborhood collected by the accumulator
     * @throws UnsupportedOperationException if {@code tIntSet} is not null
     * @throws IllegalStateException if {@link #build()} has not been called,
     *         or if a row cannot be read from the matrix
     */
    @Override
    public Neighborhood query(float[] fArr, int i, int i2, TIntSet tIntSet) {
        if (tIntSet != null) {
            throw new UnsupportedOperationException();
        }
        if (this.root == null) {
            throw new IllegalStateException("build() must be called before query()");
        }
        NeighborhoodAccumulator accumulator = new NeighborhoodAccumulator(i);

        // A priority queue (highest score first) rather than a sorted set:
        // a set ordered by score alone would treat equally-scored nodes as
        // duplicates and silently discard all but one of them.
        PriorityQueue<Candidate> frontier =
                new PriorityQueue<Candidate>(11, Collections.<Candidate>reverseOrder());
        frontier.add(new Candidate(this.root, -1.0d));

        int scored = 0;
        while (!frontier.isEmpty()) {
            Node node = frontier.poll().n;
            for (int rowId : node.memberIds.toArray()) {
                try {
                    DenseMatrixRow row = this.matrix.getRow(rowId);
                    accumulator.visit(row.getRowIndex(), cosine(fArr, row));
                    scored++;
                } catch (IOException e) {
                    throw new IllegalStateException(e);
                }
            }
            if (scored >= i2) {
                break;
            }
            if (node.children != null) {
                for (Node child : node.children) {
                    frontier.add(new Candidate(child, cosine(fArr, child.delegate)));
                }
            }
        }
        return accumulator.get();
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void save(File file) throws IOException {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public boolean load(File file) throws IOException {
        throw new UnsupportedOperationException();
    }

    /**
     * Sets the maximum number of rows sampled to build the tree.
     *
     * @throws IllegalArgumentException if {@code i} is negative
     */
    public void setSampleSize(int i) {
        if (i < 0) {
            throw new IllegalArgumentException("sampleSize must not be negative: " + i);
        }
        this.sampleSize = i;
    }

    /** Sets the number of sampled rows at or below which a node is not split further. */
    public void setMaxLeaf(int i) {
        this.maxLeaf = i;
    }

    /**
     * Sets the number of children created when a node is split.
     *
     * @throws IllegalArgumentException if {@code i} is less than 1
     */
    public void setBranchingFactor(int i) {
        if (i < 1) {
            throw new IllegalArgumentException("branchingFactor must be at least 1: " + i);
        }
        this.branchingFactor = i;
    }

    /**
     * Loads a uniformly shuffled selection of at most {@code sampleSize} rows
     * from the matrix.
     *
     * @throws IOException if a row cannot be read
     */
    private List<DenseMatrixRow> getSample() throws IOException {
        List<Integer> ids = new ArrayList<Integer>();
        for (int id : this.matrix.getRowIds()) {
            ids.add(Integer.valueOf(id));
        }
        Collections.shuffle(ids);
        if (ids.size() > this.sampleSize) {
            ids = ids.subList(0, this.sampleSize);
        }
        List<DenseMatrixRow> rows = new ArrayList<DenseMatrixRow>(ids.size());
        for (Integer id : ids) {
            rows.add(this.matrix.getRow(id.intValue()));
        }
        return rows;
    }

    /**
     * Turns accumulated sums into a cosine similarity.
     *
     * @param sumSqA sum of squared components of the first vector
     * @param sumSqB sum of squared components of the second vector
     * @param dot    dot product of the two vectors
     * @return the cosine, or 0 when the product of the squared norms is zero
     */
    private static double cosineFromSums(double sumSqA, double sumSqB, double dot) {
        double normProduct = sumSqA * sumSqB;
        if (normProduct == 0.0d) {
            return 0.0d;
        }
        return dot / Math.sqrt(normProduct);
    }

    /**
     * Cosine similarity between two rows, iterating over the columns of the
     * first; returns 0 if either row is {@code null}.
     */
    static double cosine(DenseMatrixRow denseMatrixRow, DenseMatrixRow denseMatrixRow2) {
        if (denseMatrixRow == null || denseMatrixRow2 == null) {
            return 0.0d;
        }
        double sumSqA = 0.0d;
        double sumSqB = 0.0d;
        double dot = 0.0d;
        int cols = denseMatrixRow.getNumCols();
        for (int col = 0; col < cols; col++) {
            double a = denseMatrixRow.getColValue(col);
            double b = denseMatrixRow2.getColValue(col);
            sumSqA += a * a;
            sumSqB += b * b;
            dot += a * b;
        }
        return cosineFromSums(sumSqA, sumSqB, dot);
    }

    /**
     * Cosine similarity between a vector and a row, iterating over the length
     * of the vector; returns 0 if either argument is {@code null}.
     */
    static double cosine(double[] dArr, DenseMatrixRow denseMatrixRow) {
        if (dArr == null || denseMatrixRow == null) {
            return 0.0d;
        }
        double sumSqA = 0.0d;
        double sumSqB = 0.0d;
        double dot = 0.0d;
        for (int col = 0; col < dArr.length; col++) {
            double a = dArr[col];
            double b = denseMatrixRow.getColValue(col);
            sumSqA += a * a;
            sumSqB += b * b;
            dot += a * b;
        }
        return cosineFromSums(sumSqA, sumSqB, dot);
    }

    /**
     * Cosine similarity between a vector and a row, iterating over the length
     * of the vector; returns 0 if either argument is {@code null}.
     */
    public static double cosine(float[] fArr, DenseMatrixRow denseMatrixRow) {
        if (fArr == null || denseMatrixRow == null) {
            return 0.0d;
        }
        double sumSqA = 0.0d;
        double sumSqB = 0.0d;
        double dot = 0.0d;
        for (int col = 0; col < fArr.length; col++) {
            double a = fArr[col];
            double b = denseMatrixRow.getColValue(col);
            sumSqA += a * a;
            sumSqB += b * b;
            dot += a * b;
        }
        return cosineFromSums(sumSqA, sumSqB, dot);
    }
}
