package org.wikibrain.matrix.knn;

import gnu.trove.set.TIntSet;
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.Random;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.wikibrain.matrix.DenseMatrix;

/* loaded from: input_file:org/wikibrain/matrix/knn/LSHForestKNNFinder.class */
/**
 * Approximate k-nearest-neighbor finder over the rows of a {@link DenseMatrix}, using a small
 * forest of random-hyperplane locality-sensitive-hash signatures.
 *
 * <p>Each of the {@code numTrees} trees draws {@value #NUM_BITS} random Gaussian hyperplanes and
 * encodes every row as a {@value #NUM_BITS}-bit signature. Bit {@code b} is set when the row lies
 * on the positive side of hyperplane {@code b}. Before that test, the row is standardized per
 * dimension: the sampled mean is subtracted and the result is divided by the sampled standard
 * deviation.
 *
 * <p>A query is hashed the same way. Each row's score is the longest run of agreeing signature
 * bits, counted from the most significant bit down, maximized over all trees. Rows are admitted
 * from the highest score downward until enough candidates are collected. The admitted rows are
 * then re-ranked by exact cosine similarity.
 *
 * <p>{@link #build()} must be called before {@link #query}. {@code build()} is synchronized;
 * {@code query} performs no synchronization of its own. Persistence ({@link #save}/{@link #load})
 * is not supported.
 */
public class LSHForestKNNFinder implements KNNFinder {
    private static final Logger LOG = LoggerFactory.getLogger(LSHForestKNNFinder.class);

    /** Signature length per tree. Signatures are stored as {@code short}, so this must be at most 16. */
    private static final int NUM_BITS = 16;

    /** Mask selecting the low {@link #NUM_BITS} bits of an int. */
    private static final int SIGNATURE_MASK = (1 << NUM_BITS) - 1;

    /** Maximum number of leading rows used to estimate per-dimension means and std-devs. */
    private static final int MAX_SAMPLE_ROWS = 5000;

    private int numTrees = 5;

    /** {@code bits[tree][row]} = signature of row {@code ids[row]} in tree {@code tree}. */
    private short[][] bits;

    private final DenseMatrix matrix;
    private final int dimensions;
    private final int[] ids;

    /** {@code vectors[tree][b]} = hyperplane normal producing bit {@code b} of tree {@code tree}. */
    private double[][][] vectors;

    /** Sampled per-dimension means used for standardization. */
    private double[] means;

    /** Sampled per-dimension standard deviations (always strictly positive). */
    private double[] devs;

    /**
     * Creates a finder over all rows of {@code denseMatrix}. The index is not usable until
     * {@link #build()} has been called.
     *
     * @param denseMatrix matrix whose rows are indexed; the column count of its first row fixes the
     *     expected dimensionality
     * @throws IllegalArgumentException if the matrix has no rows
     * @throws IOException if reading the first row fails (propagated from {@code getRow})
     */
    public LSHForestKNNFinder(DenseMatrix denseMatrix) throws IOException {
        this.matrix = denseMatrix;
        this.ids = denseMatrix.getRowIds();
        if (this.ids.length == 0) {
            throw new IllegalArgumentException("Cannot build an LSH forest over a matrix with no rows");
        }
        this.dimensions = denseMatrix.getRow(this.ids[0]).getNumCols();
    }

    /**
     * Estimates standardization statistics, then builds every tree's hyperplanes and row
     * signatures.
     *
     * @throws IOException if reading a matrix row fails
     */
    @Override // org.wikibrain.matrix.knn.KNNFinder
    public synchronized void build() throws IOException {
        analyzeSample();
        // Only the outer dimension is allocated here; buildTree fills in each tree's arrays.
        this.bits = new short[this.numTrees][];
        this.vectors = new double[this.numTrees][][];
        for (int tree = 0; tree < this.numTrees; tree++) {
            buildTree(tree);
        }
    }

    /**
     * Computes per-dimension mean and standard deviation from the first
     * {@code min(MAX_SAMPLE_ROWS, rowCount)} rows.
     *
     * <p>Each std-dev is {@code sqrt((1e-4 + sum of squared deviations) / sampleSize)}, so it is
     * strictly positive and safe to divide by in {@link #project}.
     *
     * @throws IllegalStateException if a sampled row's length differs from {@code dimensions}
     */
    private void analyzeSample() throws IOException {
        int sampleSize = Math.min(MAX_SAMPLE_ROWS, this.ids.length);

        this.means = new double[this.dimensions];
        for (int r = 0; r < sampleSize; r++) {
            float[] values = this.matrix.getRow(this.ids[r]).getValues();
            if (values.length != this.dimensions) {
                throw new IllegalStateException("Row " + this.ids[r] + " has " + values.length
                        + " dimensions, expected " + this.dimensions);
            }
            for (int d = 0; d < values.length; d++) {
                this.means[d] += values[d];
            }
        }
        for (int d = 0; d < this.dimensions; d++) {
            this.means[d] /= sampleSize;
        }

        this.devs = new double[this.dimensions];
        // Small positive floor so constant dimensions do not yield a zero std-dev.
        Arrays.fill(this.devs, 1.0E-4d);
        for (int r = 0; r < sampleSize; r++) {
            float[] values = this.matrix.getRow(this.ids[r]).getValues();
            for (int d = 0; d < values.length; d++) {
                double diff = values[d] - this.means[d];
                this.devs[d] += diff * diff;
            }
        }
        for (int d = 0; d < this.dimensions; d++) {
            this.devs[d] = Math.sqrt(this.devs[d] / sampleSize);
            LOG.info("dimension {} has mean {} and std-dev {}", d, this.means[d], this.devs[d]);
        }
    }

    /**
     * Builds tree {@code tree}. It draws {@link #NUM_BITS} random unit-length hyperplanes, then
     * stores the signature of every row.
     */
    private void buildTree(int tree) throws IOException {
        Random random = new Random();
        double[][] planes = new double[NUM_BITS][this.dimensions];
        for (double[] plane : planes) {
            double sumOfSquares = 0.0d;
            for (int d = 0; d < this.dimensions; d++) {
                plane[d] = random.nextGaussian();
                sumOfSquares += plane[d] * plane[d];
            }
            // Scale to unit length. Only the sign of a dot product with a plane is ever used, so
            // this does not change any signature. The epsilon guards against division by zero.
            double norm = Math.sqrt(sumOfSquares) + 1.0E-6d;
            for (int d = 0; d < this.dimensions; d++) {
                plane[d] /= norm;
            }
        }
        // project() reads vectors[tree], so the planes must be published before hashing rows.
        this.vectors[tree] = planes;

        short[] signatures = new short[this.ids.length];
        for (int r = 0; r < this.ids.length; r++) {
            signatures[r] = project(tree, this.matrix.getRow(this.ids[r]).getValues());
        }
        this.bits[tree] = signatures;
    }

    /**
     * Hashes {@code vector} with tree {@code tree}'s hyperplanes. The vector is first standardized
     * with the sampled means and std-devs. Bit {@code b} is set iff the standardized vector has a
     * positive dot product with hyperplane {@code b}.
     *
     * @throws IllegalArgumentException if {@code vector.length != dimensions}
     */
    private short project(int tree, float[] vector) {
        if (vector.length != this.dimensions) {
            throw new IllegalArgumentException("Expected " + this.dimensions + " dimensions, found " + vector.length);
        }
        double[] standardized = new double[this.dimensions];
        for (int d = 0; d < this.dimensions; d++) {
            standardized[d] = (vector[d] - this.means[d]) / this.devs[d];
        }
        double[][] planes = this.vectors[tree];
        int signature = 0;
        for (int b = 0; b < planes.length; b++) {
            if (dot(planes[b], standardized) > 0.0d) {
                signature |= 1 << b;
            }
        }
        return (short) signature;
    }

    /** Dot product over the length of {@code a}; {@code b} must be at least as long. */
    private double dot(double[] a, double[] b) {
        double sum = 0.0d;
        for (int i = 0; i < a.length; i++) {
            sum += a[i] * b[i];
        }
        return sum;
    }

    /**
     * Finds approximate nearest neighbors of {@code vector}.
     *
     * <p>Each row is scored by the longest common high-order signature prefix it shares with the
     * query in any tree. The score threshold starts at {@link #NUM_BITS} and is lowered until at
     * least {@code numCandidates} rows qualify, or it reaches 0. Every row at or above the final
     * threshold is then scored by exact cosine similarity. More than {@code numCandidates} rows may
     * therefore be scored.
     *
     * @param vector query vector of length {@code dimensions}
     * @param maxNeighbors passed to {@code NeighborhoodAccumulator}; presumably the maximum number
     *     of neighbors returned (confirm against that class)
     * @param numCandidates minimum number of rows to admit for exact re-ranking
     * @param validIds if non-null, only rows whose id is in this set are considered
     * @throws IllegalStateException if {@link #build()} has not been called, or a row cannot be read
     * @throws IllegalArgumentException if {@code vector} has the wrong length
     */
    @Override // org.wikibrain.matrix.knn.KNNFinder
    public Neighborhood query(float[] vector, int maxNeighbors, int numCandidates, TIntSet validIds) {
        if (this.bits == null || this.vectors == null) {
            throw new IllegalStateException("build() must be called before query()");
        }
        short[] querySignatures = new short[this.numTrees];
        for (int tree = 0; tree < this.numTrees; tree++) {
            querySignatures[tree] = project(tree, vector);
        }

        // matchLength[r] = best common-prefix length of row r over all trees; -1 marks excluded rows.
        byte[] matchLength = new byte[this.ids.length];
        Arrays.fill(matchLength, (byte) -1);
        // histogram[k] = number of eligible rows whose best prefix length is exactly k.
        int[] histogram = new int[NUM_BITS + 1];
        for (int r = 0; r < this.ids.length; r++) {
            if (validIds != null && !validIds.contains(this.ids[r])) {
                continue;
            }
            int best = -1;
            for (int tree = 0; tree < this.numTrees; tree++) {
                int differing = (querySignatures[tree] ^ this.bits[tree][r]) & SIGNATURE_MASK;
                // Leading zeros of the masked XOR, minus the unused high bits of the int, is the
                // number of agreeing bits counted from the signature's most significant bit.
                int prefix = Integer.numberOfLeadingZeros(differing) - (Integer.SIZE - NUM_BITS);
                best = Math.max(best, prefix);
            }
            if (best < 0 || best > NUM_BITS) {
                throw new IllegalStateException("Invalid prefix length " + best + " (numTrees=" + this.numTrees + ")");
            }
            matchLength[r] = (byte) best;
            histogram[best]++;
        }

        // Lower the threshold until enough rows are admitted. Threshold 0 admits every eligible row.
        int threshold = NUM_BITS;
        int admitted = 0;
        while (threshold > 0) {
            admitted += histogram[threshold];
            if (admitted >= numCandidates) {
                break;
            }
            threshold--;
        }
        if (LOG.isDebugEnabled()) {
            LOG.debug("threshold is {} for {}", threshold, Arrays.toString(histogram));
        }

        NeighborhoodAccumulator accumulator = new NeighborhoodAccumulator(maxNeighbors);
        for (int r = 0; r < this.ids.length; r++) {
            // Excluded rows hold -1, which is always below a threshold >= 0.
            if (matchLength[r] >= threshold) {
                try {
                    accumulator.visit(this.ids[r], KmeansKNNFinder.cosine(vector, this.matrix.getRow(this.ids[r])));
                } catch (IOException e) {
                    throw new IllegalStateException(e);
                }
            }
        }
        return accumulator.get();
    }

    /** Not supported: this finder cannot be persisted. */
    @Override // org.wikibrain.matrix.knn.KNNFinder
    public void save(File file) throws IOException {
        throw new UnsupportedOperationException("LSHForestKNNFinder does not support save()");
    }

    /** Not supported: this finder cannot be loaded from disk. */
    @Override // org.wikibrain.matrix.knn.KNNFinder
    public boolean load(File file) throws IOException {
        throw new UnsupportedOperationException("LSHForestKNNFinder does not support load()");
    }

    /**
     * Debugging helper that renders the low 16 bits of {@code i} as a zero-padded binary string.
     * It is not called anywhere in this class.
     */
    private String paddedShortBinary(int i) {
        return String.format("%16s", Integer.toBinaryString(i & 65535)).replace(' ', '0');
    }
}
