package org.deeplearning4j.models.embeddings.reader.impl;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import lombok.NonNull;
import org.datavec.api.records.reader.impl.misc.SVMLightRecordReader;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.reader.ModelUtils;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectorsImpl;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Counter;
import org.nd4j.linalg.util.MathUtils;
import org.nd4j.shade.guava.collect.Lists;
import org.nd4j.util.SetUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/models/embeddings/reader/impl/BasicModelUtils.class */
public class BasicModelUtils<T extends SequenceElement> implements ModelUtils<T> {
    /** Class-wide SLF4J logger; used for the debug messages written by {@code similarity}. */
    private static final Logger log = LoggerFactory.getLogger((Class<?>) BasicModelUtils.class);
    /** Text placed into debug log messages to indicate that a value is non-null. */
    public static final String EXISTS = "exists";
    /** Counter key under which {@code accuracy} tallies correctly answered questions. */
    public static final String CORRECT = "correct";
    /** Counter key under which {@code accuracy} tallies incorrectly answered questions. */
    public static final String WRONG = "wrong";
    /** Vocabulary obtained from the lookup table; assigned in {@code init}. */
    protected volatile VocabCache<T> vocabCache;
    /** Lookup table that supplies the vectors; assigned in {@code init}. */
    protected volatile WeightLookupTable<T> lookupTable;
    /**
     * Set to true by {@code wordsNearest(INDArray, int)} after it has divided syn0 by its
     * {@code norm2(1)} column vector; reset to false by {@code init}.
     */
    protected volatile boolean normalized = false;

    /* loaded from: input_file:org/deeplearning4j/models/embeddings/reader/impl/BasicModelUtils$ArrayComparator.class */
    /**
     * Orders {@code Double[]} entries by their first element in ascending order.
     *
     * <p>A NaN first element is treated as smaller than every non-NaN value, and two
     * NaN entries compare as equal. Only index 0 of each array is inspected.
     */
    public static class ArrayComparator implements Comparator<Double[]> {
        @Override
        public int compare(Double[] dArr, Double[] dArr2) {
            final double left = dArr[0].doubleValue();
            final double right = dArr2[0].doubleValue();
            final boolean rightIsNaN = Double.isNaN(right);

            if (Double.isNaN(left)) {
                // NaN ties with NaN and sorts below everything else.
                return rightIsNaN ? 0 : -1;
            }
            if (rightIsNaN) {
                return 1;
            }
            return Double.compare(left, right);
        }
    }

    /* loaded from: input_file:org/deeplearning4j/models/embeddings/reader/impl/BasicModelUtils$SimilarityComparator.class */
    /**
     * Orders {@link WordSimilarity} entries by descending similarity.
     *
     * <p>An entry whose similarity is NaN sorts before every entry with a non-NaN
     * similarity, and two NaN entries compare as equal.
     */
    public static class SimilarityComparator implements Comparator<WordSimilarity> {
        @Override
        public int compare(WordSimilarity wordSimilarity, WordSimilarity wordSimilarity2) {
            final double first = wordSimilarity.getSimilarity();
            final double second = wordSimilarity2.getSimilarity();
            final boolean secondIsNaN = Double.isNaN(second);

            if (Double.isNaN(first)) {
                return secondIsNaN ? 0 : -1;
            }
            if (secondIsNaN) {
                return 1;
            }
            // Arguments are swapped on purpose: larger similarity comes first.
            return Double.compare(second, first);
        }
    }

    /* loaded from: input_file:org/deeplearning4j/models/embeddings/reader/impl/BasicModelUtils$WordSimilarity.class */
    /**
     * Mutable pair of a word and its similarity score.
     *
     * <p>Equality and hash code are based on both the word and the similarity value;
     * similarities are compared with {@link Double#compare(double, double)}.
     */
    public static class WordSimilarity {
        private String word;
        private double similarity;

        /**
         * @param str the word
         * @param d the similarity score associated with the word
         */
        public WordSimilarity(String str, double d) {
            this.word = str;
            this.similarity = d;
        }

        /** @return the word held by this entry */
        public String getWord() {
            return this.word;
        }

        /** @param str the new word */
        public void setWord(String str) {
            this.word = str;
        }

        /** @return the similarity score held by this entry */
        public double getSimilarity() {
            return this.similarity;
        }

        /** @param d the new similarity score */
        public void setSimilarity(double d) {
            this.similarity = d;
        }

        /** Hook that lets subclasses refuse equality with instances of other types. */
        protected boolean canEqual(Object obj) {
            return obj instanceof WordSimilarity;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof WordSimilarity)) {
                return false;
            }
            final WordSimilarity that = (WordSimilarity) obj;
            if (!that.canEqual(this)) {
                return false;
            }
            final String mine = getWord();
            final String theirs = that.getWord();
            final boolean sameWord = (mine == null) ? (theirs == null) : mine.equals(theirs);
            return sameWord && Double.compare(getSimilarity(), that.getSimilarity()) == 0;
        }

        @Override
        public int hashCode() {
            final String current = getWord();
            final int wordPart = 59 + (current == null ? 43 : current.hashCode());
            // Double.hashCode folds the 64 raw bits exactly like (int) (bits ^ (bits >>> 32)).
            return 59 * wordPart + Double.hashCode(getSimilarity());
        }

        @Override
        public String toString() {
            final StringBuilder text = new StringBuilder("BasicModelUtils.WordSimilarity(word=");
            text.append(getWord());
            text.append(", similarity=");
            text.append(getSimilarity());
            text.append(")");
            return text.toString();
        }
    }

    /**
     * Binds this utility to a lookup table and clears the normalization flag.
     *
     * @param weightLookupTable the table to read vectors and vocabulary from
     * @throws NullPointerException if {@code weightLookupTable} is null
     */
    @Override
    public void init(@NonNull WeightLookupTable<T> weightLookupTable) {
        java.util.Objects.requireNonNull(weightLookupTable, "lookupTable is marked @NonNull but is null");
        final VocabCache<T> cache = weightLookupTable.getVocabCache();
        this.vocabCache = cache;
        this.lookupTable = weightLookupTable;
        this.normalized = false;
    }

    /**
     * Computes the cosine similarity between the vectors of two labels.
     *
     * @param str the first label
     * @param str2 the second label
     * @return {@code Double.NaN} if either label is not a known token or has no vector,
     *     {@code 1.0} if both labels are equal, otherwise the value returned by
     *     {@code Transforms.cosineSim} for the two vectors
     * @throws NullPointerException if either label is null
     */
    @Override
    public double similarity(@NonNull String str, @NonNull String str2) {
        if (str == null) {
            throw new NullPointerException("label1 is marked @NonNull but is null");
        }
        if (str2 == null) {
            throw new NullPointerException("label2 is marked @NonNull but is null");
        }
        if (!this.vocabCache.hasToken(str)) {
            log.debug("Unknown token 1 requested: [{}]", str);
            return Double.NaN;
        }
        if (!this.vocabCache.hasToken(str2)) {
            log.debug("Unknown token 2 requested: [{}]", str2);
            return Double.NaN;
        }

        // Check for missing vectors BEFORE dereferencing them; the previous code called
        // dup() first, so a null vector raised a NullPointerException instead of yielding NaN.
        final INDArray first = this.lookupTable.vector(str);
        final INDArray second = this.lookupTable.vector(str2);
        if (first == null || second == null) {
            log.debug("{}: {};{} vec2:{}", str, first == null ? "null" : EXISTS, str2, second == null ? "null" : EXISTS);
            return Double.NaN;
        }
        if (str.equals(str2)) {
            return 1.0d;
        }
        // Copies are kept from the original implementation; presumably they shield the
        // table's vectors from in-place modification (not verifiable from this file).
        return Transforms.cosineSim(first.dup(), second.dup());
    }

    /**
     * Returns up to {@code i} words nearest to {@code str}, excluding {@code str} itself.
     *
     * <p>Delegates to the positive/negative overload with {@code str} as the only positive
     * word and one extra result requested, then removes {@code str} and trims to {@code i}.
     *
     * @param str the query word
     * @param i the maximum number of words to return
     * @return the nearest words, at most {@code i} of them
     */
    @Override
    public Collection<String> wordsNearest(String str, int i) {
        final List<String> neighbours =
                new ArrayList<>(wordsNearest(Arrays.asList(str), new ArrayList<String>(), i + 1));
        // remove(Object) is a no-op when the word is absent.
        neighbours.remove(str);
        for (int last = neighbours.size() - 1; neighbours.size() > i; last--) {
            neighbours.remove(last);
        }
        return neighbours;
    }

    /**
     * Evaluates analogy questions and reports the percentage answered correctly per section.
     *
     * <p>A line starting with {@code SVMLightRecordReader.FEATURE_DELIMITER} opens a new
     * section and is used as that section's key. Every other line is split on single
     * spaces into four tokens: tokens 1 and 2 are passed as positive words, token 0 as the
     * negative word, and token 3 is the expected answer. The percentage is computed as
     * {@code 100 * correct / (correct + wrong)}.
     *
     * <p>When no nearest word can be produced for a question (the nearest-word lookup
     * returns an empty collection), the question is counted as wrong. Previously this case
     * raised a {@code NoSuchElementException}.
     *
     * @param list the section headers and question lines, in order
     * @return a map from section header line to accuracy percentage
     */
    @Override
    public Map<String, Double> accuracy(List<String> list) {
        final Map<String, Double> accuracyBySection = new HashMap<>();
        final Counter<String> tally = new Counter<>();
        String section = "";
        for (String line : list) {
            if (line.startsWith(SVMLightRecordReader.FEATURE_DELIMITER)) {
                if (!section.isEmpty()) {
                    // Close the previous section before starting the next one.
                    accuracyBySection.put(section, accuracyPercent(tally));
                    tally.clear();
                }
                section = line;
                continue;
            }

            final String[] split = line.split(" ");
            final String expected = split[3];
            final Iterator<String> predicted =
                    wordsNearest(Arrays.asList(split[1], split[2]), Arrays.asList(split[0]), 1).iterator();
            // An empty result means no prediction was possible; treat it as a miss.
            final boolean hit = predicted.hasNext() && expected.equals(predicted.next());
            tally.incrementCount(hit ? CORRECT : WRONG, 1.0d);
        }
        if (!section.isEmpty()) {
            accuracyBySection.put(section, accuracyPercent(tally));
        }
        return accuracyBySection;
    }

    /** Converts the CORRECT/WRONG counts of a tally into a percentage of correct answers. */
    private static Double accuracyPercent(Counter<String> tally) {
        final double correct = tally.getCount(CORRECT);
        final double wrong = tally.getCount(WRONG);
        return Double.valueOf((100.0d * correct) / (correct + wrong));
    }

    /**
     * Collects vocabulary words whose string similarity to {@code str} is at least {@code d}.
     *
     * <p>The score comes from {@code MathUtils.stringSimilarity}; this is a comparison of
     * the strings themselves, not of their vectors.
     *
     * @param str the word to compare against
     * @param d the minimum score (inclusive) for a word to be included
     * @return the matching vocabulary words, in the vocabulary's iteration order
     */
    @Override
    public List<String> similarWordsInVocabTo(String str, double d) {
        final List<String> matches = new ArrayList<>();
        final Iterator<String> vocabulary = this.vocabCache.words().iterator();
        while (vocabulary.hasNext()) {
            final String candidate = vocabulary.next();
            final double score = MathUtils.stringSimilarity(str, candidate);
            if (score >= d) {
                matches.add(candidate);
            }
        }
        return matches;
    }

    /**
     * Finds words nearest to the combination of positive and negative words.
     *
     * <p>Returns an empty list if any input word is not contained in the vocabulary.
     * Otherwise the positive vectors and the negated negative vectors are stacked as rows,
     * averaged over dimension 0 when the stack is a matrix, and the result is passed to
     * the vector-based lookup. Input words are removed from the returned words.
     *
     * @param collection the positive words
     * @param collection2 the negative words
     * @param i the maximum number of words to return
     * @return up to {@code i} nearest words that are not among the inputs
     * @throws NullPointerException if either collection is null
     */
    @Override
    public Collection<String> wordsNearest(@NonNull Collection<String> collection, @NonNull Collection<String> collection2, int i) {
        java.util.Objects.requireNonNull(collection, "positive is marked @NonNull but is null");
        java.util.Objects.requireNonNull(collection2, "negative is marked @NonNull but is null");

        for (String word : SetUtils.union(new HashSet<String>(collection), new HashSet<String>(collection2))) {
            if (!this.vocabCache.containsWord(word)) {
                return new ArrayList<>();
            }
        }

        final int inputCount = collection.size() + collection2.size();
        final INDArray stacked = Nd4j.create(inputCount, this.lookupTable.layerSize());
        int row = 0;
        for (String positive : collection) {
            stacked.putRow(row++, this.lookupTable.vector(positive));
        }
        for (String negative : collection2) {
            stacked.putRow(row++, this.lookupTable.vector(negative).mul(-1));
        }

        final INDArray query;
        if (stacked.isMatrix()) {
            query = stacked.mean(0).reshape(1L, stacked.size(1));
        } else {
            query = stacked;
        }

        // Ask for extra candidates because the input words are filtered out below.
        final Collection<String> candidates = wordsNearest(query, i + inputCount);
        final List<String> nearest = new ArrayList<>();
        for (String candidate : candidates) {
            if (collection.contains(candidate) || collection2.contains(candidate)) {
                continue;
            }
            if (nearest.size() < i) {
                nearest.add(candidate);
            }
        }
        return nearest;
    }

    /**
     * Looks up the vector of {@code str} and delegates to the vector-based
     * {@code wordsNearestSum} overload.
     *
     * @param str the query word
     * @param i the maximum number of words to return
     * @return the words returned by the vector-based overload
     */
    @Override
    public Collection<String> wordsNearestSum(String str, int i) {
        final INDArray wordVector = this.lookupTable.vector(str);
        return wordsNearestSum(wordVector, i);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Prepares a query array for comparison against syn0 of an in-memory lookup table.
     *
     * <p>If the lookup table is not an {@code InMemoryLookupTable} the input is returned
     * untouched. Otherwise the array is cast to syn0's data type when the types differ,
     * its rank is validated, and a rank-1 array is reshaped to a single row.
     *
     * <p>Previously a data type mismatch returned the cast array immediately, so the rank
     * validation and the rank-1 reshape were skipped for such inputs.
     *
     * @param iNDArray the query array
     * @return the adjusted array (possibly the same instance)
     * @throws IllegalStateException if the array has rank 0 or rank greater than 2
     */
    public INDArray adjustRank(INDArray iNDArray) {
        if (!(this.lookupTable instanceof InMemoryLookupTable)) {
            return iNDArray;
        }
        final INDArray syn0 = ((InMemoryLookupTable<T>) this.lookupTable).getSyn0();

        INDArray adjusted = iNDArray;
        if (!adjusted.dataType().equals(syn0.dataType())) {
            adjusted = adjusted.castTo(syn0.dataType());
        }

        final int rank = adjusted.rank();
        if (rank == 0 || rank > 2) {
            throw new IllegalStateException("Invalid rank for wordsNearest method");
        }
        return rank == 1 ? adjusted.reshape(1L, -1L) : adjusted;
    }

    /**
     * Finds the words whose vectors are nearest to the given vector.
     *
     * <p>For a lookup table that is not an {@code InMemoryLookupTable}, every vocabulary
     * word is scored with {@code Transforms.cosineSim} and the top {@code i} entries of the
     * counter are returned.
     *
     * <p>For an {@code InMemoryLookupTable}, syn0 is divided by its {@code norm2(1)} column
     * vector once (guarded by the {@code normalized} flag), candidates are taken from the
     * product of the unit query vector with the transposed syn0, and the candidates are then
     * re-scored with cosine similarity and ordered with {@code SimilarityComparator}.
     *
     * @param iNDArray the query vector
     * @param i the maximum number of words to return
     * @return the nearest words
     */
    @Override
    public Collection<String> wordsNearest(INDArray iNDArray, int i) {
        final INDArray query = adjustRank(iNDArray);
        if (!(this.lookupTable instanceof InMemoryLookupTable)) {
            final Counter<String> distances = new Counter<>();
            for (String word : this.vocabCache.words()) {
                // Bind the word once so the same value is used as key and for the vector lookup.
                final double sim = Transforms.cosineSim(query, this.lookupTable.vector(word));
                distances.incrementCount(word, (float) sim);
            }
            distances.keepTopNElements(i);
            return distances.keySet();
        }

        final INDArray syn0 = ((InMemoryLookupTable<T>) this.lookupTable).getSyn0();
        if (!this.normalized) {
            synchronized (this) {
                // Re-check under the lock so the division is applied only once.
                if (!this.normalized) {
                    syn0.diviColumnVector(syn0.norm2(1));
                    this.normalized = true;
                }
            }
        }

        // 20 extra candidates are requested; presumably so that dropping UNK/STOP entries
        // below still leaves enough words (intent not stated in the source).
        final List<Double> topIndices = getTopN(Transforms.unitVec(query).mmul(syn0.transpose()), i + 20);
        final List<WordSimilarity> scored = new ArrayList<>();
        for (Double index : topIndices) {
            final String word = this.vocabCache.wordAtIndex(index.intValue());
            if (word != null && !word.equals(WordVectorsImpl.DEFAULT_UNK) && !word.equals("STOP")) {
                scored.add(new WordSimilarity(word, Transforms.cosineSim(query, this.lookupTable.vector(word))));
            }
        }
        Collections.sort(scored, new SimilarityComparator());
        return getLabels(scored, i);
    }

    /**
     * Returns the linear indices of the {@code i} largest values of an array, largest first.
     *
     * <p>A bounded priority queue ordered by {@code ArrayComparator} holds
     * {@code [value, index]} pairs; its head is the smallest retained value, which is
     * evicted whenever a larger value arrives. Indices are returned as {@code Double}.
     *
     * @param iNDArray the array of scores, read through linear indexing
     * @param i the number of entries to keep
     * @return the indices of the retained entries in descending order of value
     */
    private List<Double> getTopN(INDArray iNDArray, int i) {
        final ArrayComparator byValue = new ArrayComparator();
        final PriorityQueue<Double[]> heap = new PriorityQueue<>(iNDArray.rows(), byValue);
        for (int idx = 0; idx < iNDArray.length(); idx++) {
            final Double[] entry = new Double[] {iNDArray.getDouble(idx), (double) idx};
            if (heap.size() >= i) {
                if (byValue.compare(entry, heap.peek()) <= 0) {
                    // Not larger than the smallest retained value: discard.
                    continue;
                }
                heap.poll();
            }
            heap.add(entry);
        }

        // Polling yields ascending order; the reversed view gives largest first.
        final List<Double> ascending = new ArrayList<>();
        for (Double[] entry = heap.poll(); entry != null; entry = heap.poll()) {
            ascending.add(entry[1]);
        }
        return Lists.reverse(ascending);
    }

    /**
     * Finds the words nearest to the given vector using a weighted row sum over syn0.
     *
     * <p>For a lookup table that is not an {@code InMemoryLookupTable}, every vocabulary
     * word is scored with {@code Transforms.cosineSim} and the top {@code i} entries of the
     * counter are returned. Otherwise syn0 is multiplied row-wise by
     * {@code (1 / norm2(0)) * vector}, summed over dimension 1, and the index array returned
     * by {@code Nd4j.sortWithIndices} is walked from the front.
     *
     * <p>Entries that map to no word, to {@code WordVectorsImpl.DEFAULT_UNK} or to
     * {@code "STOP"} are skipped and the scan window is widened by one for each, but never
     * past the end of the index array. Previously reaching the end of the array while
     * widening aborted the scan and dropped the remaining valid candidates.
     *
     * @param iNDArray the query vector
     * @param i the maximum number of words to return
     * @return the nearest words, at most {@code i} of them in the in-memory case
     */
    @Override
    public Collection<String> wordsNearestSum(INDArray iNDArray, int i) {
        if (!(this.lookupTable instanceof InMemoryLookupTable)) {
            final Counter<String> distances = new Counter<>();
            for (String word : this.vocabCache.words()) {
                // Bind the word once so the same value is used as key and for the vector lookup.
                final double sim = Transforms.cosineSim(iNDArray, this.lookupTable.vector(word));
                distances.incrementCount(word, (float) sim);
            }
            distances.keepTopNElements(i);
            return distances.keySet();
        }

        final INDArray syn0 = ((InMemoryLookupTable<T>) this.lookupTable).getSyn0();
        final INDArray weights = syn0.norm2(0).rdivi(1).muli(iNDArray);
        final INDArray scores = syn0.mulRowVector(weights).sum(1);
        final INDArray sortedIndices = Nd4j.sortWithIndices(scores, 0, false)[0];

        final long available = sortedIndices.length();
        int limit = (i > available) ? (int) available : i;
        final List<String> nearest = new ArrayList<>();
        for (int pos = 0; pos < limit; pos++) {
            final String word = this.vocabCache.wordAtIndex(sortedIndices.getInt(pos));
            if (word == null || word.equals(WordVectorsImpl.DEFAULT_UNK) || word.equals("STOP")) {
                // Look one entry further to compensate, but stay inside the index array.
                if (limit < available) {
                    limit++;
                }
                continue;
            }
            nearest.add(word);
        }
        return nearest;
    }

    /**
     * Adds the vectors of the positive words, adds the negated vectors of the negative
     * words, and delegates the resulting vector to the vector-based
     * {@code wordsNearestSum} overload.
     *
     * @param collection the positive words
     * @param collection2 the negative words
     * @param i the maximum number of words to return
     * @return the words returned by the vector-based overload
     */
    @Override
    public Collection<String> wordsNearestSum(Collection<String> collection, Collection<String> collection2, int i) {
        final INDArray combined = Nd4j.create(this.lookupTable.layerSize());
        for (String positive : collection) {
            final INDArray vector = this.lookupTable.vector(positive);
            combined.addi(vector);
        }
        for (String negative : collection2) {
            final INDArray negated = this.lookupTable.vector(negative).mul(-1);
            combined.addi(negated);
        }
        return wordsNearestSum(combined, i);
    }

    /**
     * Extracts the words of the first {@code i} entries of a list.
     *
     * <p>Returns an empty list when {@code i} is zero or negative. Previously one label
     * was returned in that case because the limit was only checked after the first add.
     *
     * @param list the scored entries, in the order they should be reported
     * @param i the maximum number of labels to return
     * @return the words of the leading entries, at most {@code i} of them
     */
    public static List<String> getLabels(List<WordSimilarity> list, int i) {
        final List<String> labels = new ArrayList<>();
        final int count = Math.min(i, list.size());
        for (int idx = 0; idx < count; idx++) {
            labels.add(list.get(idx).getWord());
        }
        return labels;
    }
}
