package org.wikibrain.matrix.knn;

import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.Arrays;
import java.util.Random;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.wikibrain.matrix.DenseMatrix;

/* loaded from: input_file:org/wikibrain/matrix/knn/RandomProjectionKNNFinder.class */
/**
 * Approximate k-nearest-neighbour search over the rows of a {@link DenseMatrix} using
 * random-hyperplane signatures.
 *
 * <p>{@link #build()} standardizes every row with per-dimension means and standard deviations
 * (estimated from the first {@code STATS_SAMPLE_SIZE} rows), projects it onto
 * {@link #NUM_BITS} random Gaussian directions and keeps the sign of each projection as one
 * bit of the row's signature. {@link #query} computes the same signature for the query vector,
 * keeps the rows whose signatures agree with it on the most bits, and scores only those rows
 * exactly with {@code KmeansKNNFinder.cosine}.
 */
public class RandomProjectionKNNFinder implements KNNFinder {
    private static final Logger LOG = LoggerFactory.getLogger(RandomProjectionKNNFinder.class);

    /** Number of random hyperplanes, i.e. bits per row signature. Must be a multiple of 64. */
    public static final int NUM_BITS = 128;

    /** Number of {@code long} words holding one signature. */
    private static final int WORDS = NUM_BITS / Long.SIZE;

    /** Maximum number of leading rows sampled to estimate per-dimension means and deviations. */
    private static final int STATS_SAMPLE_SIZE = 5000;

    /**
     * Added to each dimension's sum of squared deviations so that a constant dimension never
     * produces a zero divisor during standardization.
     */
    private static final double DEV_FLOOR = 1.0E-4d;

    /**
     * Tag stored as the last element of a saved model. Models written before signatures were
     * packed with 64-bit shifts do not carry it and are rejected by {@link #load(File)}, so
     * they get rebuilt instead of silently mismatching freshly computed query signatures.
     */
    private static final int FORMAT_VERSION = 2;

    private final DenseMatrix matrix;
    private final int dimensions;
    private final int[] ids;

    /** Row signatures: {@code WORDS} longs per row, rows in the order of {@link #ids}. */
    private long[] bits;
    /** {@link #NUM_BITS} projection directions, each of length {@link #dimensions}. */
    private double[][] vectors;
    /** Per-dimension means used to standardize vectors before projection. */
    private double[] means;
    /** Per-dimension standard deviations (floored) used to standardize vectors. */
    private double[] devs;

    /**
     * Creates a finder over all rows of {@code denseMatrix}. Call {@link #build()} or
     * {@link #load(File)} before {@link #query}.
     *
     * @param denseMatrix matrix whose rows are indexed; must contain at least one row
     * @throws IOException if the first row cannot be read
     * @throws IllegalArgumentException if the matrix has no rows
     */
    public RandomProjectionKNNFinder(DenseMatrix denseMatrix) throws IOException {
        this.matrix = denseMatrix;
        this.ids = denseMatrix.getRowIds();
        if (this.ids.length == 0) {
            throw new IllegalArgumentException("Cannot build a KNN finder over a matrix with no rows");
        }
        this.dimensions = denseMatrix.getRow(this.ids[0]).getNumCols();
    }

    /**
     * Estimates standardization statistics, draws fresh random hyperplanes and computes the
     * signature of every row.
     *
     * @throws IOException if a row cannot be read
     * @throws IllegalStateException if a sampled row does not have the expected number of values
     * @throws IllegalArgumentException if a non-sampled row does not have the expected number of values
     */
    @Override
    public void build() throws IOException {
        computeStatistics();
        makeProjectionVectors();
        long[] codes = new long[ids.length * WORDS];
        long[] code = new long[WORDS];
        for (int row = 0; row < ids.length; row++) {
            project(matrix.getRow(ids[row]).getValues(), code);
            System.arraycopy(code, 0, codes, row * WORDS, WORDS);
        }
        // Publish the signatures only once all rows have been projected.
        this.bits = codes;
    }

    /** Fills {@link #means} and {@link #devs} from the first {@code STATS_SAMPLE_SIZE} rows. */
    private void computeStatistics() throws IOException {
        int sampleSize = Math.min(STATS_SAMPLE_SIZE, ids.length);

        means = new double[dimensions];
        for (int row = 0; row < sampleSize; row++) {
            float[] values = matrix.getRow(ids[row]).getValues();
            if (values.length != dimensions) {
                throw new IllegalStateException("Row " + ids[row] + " has " + values.length
                        + " values, expected " + dimensions);
            }
            for (int d = 0; d < dimensions; d++) {
                means[d] += values[d];
            }
        }
        for (int d = 0; d < dimensions; d++) {
            means[d] /= sampleSize;
        }

        devs = new double[dimensions];
        Arrays.fill(devs, DEV_FLOOR);
        for (int row = 0; row < sampleSize; row++) {
            float[] values = matrix.getRow(ids[row]).getValues();
            for (int d = 0; d < dimensions; d++) {
                double diff = values[d] - means[d];
                devs[d] += diff * diff;
            }
        }
        for (int d = 0; d < dimensions; d++) {
            devs[d] = Math.sqrt(devs[d] / sampleSize);
            if (LOG.isDebugEnabled()) {
                LOG.debug("dimension " + d + " has mean " + means[d] + " and std-dev " + devs[d]);
            }
        }
    }

    /**
     * Draws {@link #NUM_BITS} random Gaussian directions into {@link #vectors}, each scaled to
     * (approximately) unit length. Only the sign of dot products with these directions is ever
     * used, so the scaling does not affect signatures.
     */
    private void makeProjectionVectors() {
        Random random = new Random();
        vectors = new double[NUM_BITS][dimensions];
        for (double[] direction : vectors) {
            double sumOfSquares = 0.0d;
            for (int d = 0; d < dimensions; d++) {
                direction[d] = random.nextGaussian();
                sumOfSquares += direction[d] * direction[d];
            }
            double norm = Math.sqrt(sumOfSquares) + 1.0E-6d;
            for (int d = 0; d < dimensions; d++) {
                direction[d] /= norm;
            }
        }
    }

    /**
     * Computes the signature of {@code vector}: bit {@code b} is set when the standardized
     * vector has a positive dot product with direction {@code b}. Bit {@code b} is stored in
     * word {@code b / 64} at position {@code b % 64}.
     *
     * @param vector raw (unstandardized) vector of length {@link #dimensions}
     * @param code output array of {@code WORDS} longs; fully overwritten
     * @throws IllegalArgumentException if {@code vector} has the wrong length
     */
    private void project(float[] vector, long[] code) {
        if (vector.length != dimensions) {
            throw new IllegalArgumentException("Expected " + this.dimensions + " dimensions, found " + vector.length);
        }
        double[] standardized = new double[dimensions];
        for (int d = 0; d < dimensions; d++) {
            standardized[d] = (vector[d] - means[d]) / devs[d];
        }
        Arrays.fill(code, 0L);
        for (int b = 0; b < vectors.length; b++) {
            if (dot(vectors[b], standardized) > 0.0d) {
                // 1L (not 1): an int shift would mask the distance to 5 bits and alias bits.
                code[b / Long.SIZE] |= 1L << (b % Long.SIZE);
            }
        }
    }

    /** Returns the dot product of two vectors; {@code b} must be at least as long as {@code a}. */
    private double dot(double[] a, double[] b) {
        double sum = 0.0d;
        for (int i = 0; i < a.length; i++) {
            sum += a[i] * b[i];
        }
        return sum;
    }

    /** Returns how many of the {@link #NUM_BITS} signature bits of row index {@code row} equal {@code code}. */
    private int matchingBits(int row, long[] code) {
        int offset = row * WORDS;
        int differing = 0;
        for (int w = 0; w < WORDS; w++) {
            differing += Long.bitCount(bits[offset + w] ^ code[w]);
        }
        return NUM_BITS - differing;
    }

    /**
     * Returns approximate nearest neighbours of {@code vector}.
     *
     * <p>Every eligible row is ranked by how many signature bits it shares with the query. The
     * bit-agreement threshold is lowered from {@link #NUM_BITS} until at least
     * {@code targetCandidates} eligible rows meet it (or it reaches zero, admitting every
     * eligible row). All rows meeting the threshold are then scored with
     * {@code KmeansKNNFinder.cosine} and passed to a {@code NeighborhoodAccumulator}.
     *
     * @param vector query vector; must have {@link #dimensions} entries
     * @param k passed to the {@code NeighborhoodAccumulator} constructor (presumably the
     *          maximum number of neighbours kept — confirm against that class)
     * @param targetCandidates number of candidate rows to score exactly; ties at the threshold
     *          can push the actual count above it
     * @param validIds if non-null, only rows whose ids are in this set are considered; the set
     *          is copied, never modified
     * @return the accumulated neighbourhood
     * @throws IllegalStateException if neither {@link #build()} nor {@link #load(File)} has
     *          succeeded, or a candidate row cannot be read
     * @throws IllegalArgumentException if {@code vector} has the wrong length
     */
    @Override
    public Neighborhood query(float[] vector, int k, int targetCandidates, TIntSet validIds) {
        if (bits == null) {
            throw new IllegalStateException("KNN model not initialized: call build() or load() first");
        }
        TIntSet allowed = null;
        if (validIds != null) {
            allowed = new TIntHashSet(validIds.size() * 4);
            allowed.addAll(validIds);
        }

        long[] code = new long[WORDS];
        project(vector, code);

        // First pass: score each eligible row (-1 marks ineligible rows) and histogram the scores.
        int[] similarity = new int[ids.length];
        int[] histogram = new int[NUM_BITS + 1];
        for (int row = 0; row < ids.length; row++) {
            if (allowed == null || allowed.contains(ids[row])) {
                int score = matchingBits(row, code);
                similarity[row] = score;
                histogram[score]++;
            } else {
                similarity[row] = -1;
            }
        }

        // Lower the threshold until enough candidates qualify.
        int threshold = NUM_BITS;
        int candidates = 0;
        while (threshold > 0) {
            candidates += histogram[threshold];
            if (candidates >= targetCandidates) {
                break;
            }
            threshold--;
        }

        // Second pass: exact scoring of the candidates.
        NeighborhoodAccumulator accumulator = new NeighborhoodAccumulator(k);
        for (int row = 0; row < ids.length; row++) {
            if (similarity[row] >= threshold) {
                try {
                    accumulator.visit(ids[row], KmeansKNNFinder.cosine(vector, matrix.getRow(ids[row])));
                } catch (IOException e) {
                    throw new IllegalStateException(e);
                }
            }
        }
        return accumulator.get();
    }

    /**
     * Serializes the model (projection directions, row signatures, means, deviations and a
     * format tag) to {@code file}, creating parent directories as needed.
     *
     * @throws IllegalStateException if the model has not been built or loaded
     * @throws IOException if the file cannot be written
     */
    @Override
    public void save(File file) throws IOException {
        if (bits == null) {
            throw new IllegalStateException("Nothing to save: call build() or load() first");
        }
        File parent = file.getAbsoluteFile().getParentFile();
        if (parent != null) {
            parent.mkdirs();
        }
        FileOutputStream fileOut = new FileOutputStream(file);
        try {
            ObjectOutputStream out = new ObjectOutputStream(fileOut);
            out.writeObject(new Object[]{vectors, bits, means, devs, FORMAT_VERSION});
            out.flush();
        } finally {
            fileOut.close();
        }
    }

    /**
     * Loads a model previously written by {@link #save(File)}.
     *
     * <p>Returns {@code false} (leaving this finder unchanged) when the file is missing, older
     * than the matrix file, written in an older or unrecognized format, or inconsistent with
     * this matrix's row count or dimensionality.
     *
     * @return {@code true} if the model was loaded
     * @throws IOException if the file cannot be read or deserialized
     */
    @Override
    public boolean load(File file) throws IOException {
        if (!file.isFile()) {
            LOG.warn("Not loading knn model. File doesn't exist: " + file);
            return false;
        }
        if (file.lastModified() < this.matrix.getPath().lastModified()) {
            LOG.warn("Not loading knn model. File " + file + " older than matrix: " + this.matrix.getPath());
            return false;
        }

        // NOTE(review): Java deserialization trusts the file's contents. This is acceptable only
        // because the file is a local cache produced by save(); never load untrusted files here.
        Object payload;
        FileInputStream fileIn = new FileInputStream(file);
        try {
            payload = new ObjectInputStream(fileIn).readObject();
        } catch (ClassNotFoundException e) {
            throw new IOException(e);
        } finally {
            fileIn.close();
        }

        if (!(payload instanceof Object[])) {
            LOG.warn("Not loading knn model. Unrecognized contents in " + file);
            return false;
        }
        Object[] parts = (Object[]) payload;
        if (parts.length != 5 || !Integer.valueOf(FORMAT_VERSION).equals(parts[4])) {
            LOG.warn("Not loading knn model. Unsupported model format in " + file);
            return false;
        }
        if (!(parts[0] instanceof double[][]) || !(parts[1] instanceof long[])
                || !(parts[2] instanceof double[]) || !(parts[3] instanceof double[])) {
            LOG.warn("Not loading knn model. Unrecognized contents in " + file);
            return false;
        }
        double[][] loadedVectors = (double[][]) parts[0];
        long[] loadedBits = (long[]) parts[1];
        double[] loadedMeans = (double[]) parts[2];
        double[] loadedDevs = (double[]) parts[3];

        if (loadedBits.length != this.ids.length * WORDS) {
            LOG.warn("Not loading knn model. Expected " + (WORDS * this.ids.length) + " longs, found " + loadedBits.length);
            return false;
        }
        boolean vectorsValid = loadedVectors.length == NUM_BITS;
        for (int b = 0; vectorsValid && b < loadedVectors.length; b++) {
            vectorsValid = loadedVectors[b] != null && loadedVectors[b].length == this.dimensions;
        }
        if (!vectorsValid) {
            LOG.warn("Not loading knn model. Invalid vectors dimensions.");
            return false;
        }
        if (loadedMeans.length != this.dimensions || loadedDevs.length != this.dimensions) {
            LOG.warn("Not loading knn model. Invalid mean or devs dimensions.");
            return false;
        }

        this.vectors = loadedVectors;
        this.bits = loadedBits;
        this.means = loadedMeans;
        this.devs = loadedDevs;
        return true;
    }
}
