package org.deeplearning4j.spark.models.embeddings.word2vec;

import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.commons.math3.util.FastMath;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Spark {@link VoidFunction} that trains word vectors held in an
 * {@link InMemoryLookupTable} from one sentence (a list of {@link VocabWord}) per call.
 *
 * <p>Hyper-parameters are read from the {@link SparkConf} in {@link #setup(SparkConf)}.
 * Each invocation of {@link #call(Pair)} derives a learning rate from the broadcast word
 * count, trains on the sentence and adds the sentence length to the pair's counter.
 *
 * <p>The pseudo-random sequence is a linear congruential generator kept in an
 * {@link AtomicLong}; its read-modify-write update is not atomic as a whole.
 */
@Deprecated
/* loaded from: input_file:org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecPerformer.class */
public class Word2VecPerformer implements VoidFunction<Pair<List<VocabWord>, AtomicLong>> {
    private static final Logger log = LoggerFactory.getLogger(Word2VecPerformer.class);

    /** Dot products outside [-MAX_EXP, MAX_EXP) are not looked up in the sigmoid table. */
    private static final double MAX_EXP = 6.0d;

    /** Multiplier and increment of the linear congruential generator behind {@link #nextRandom}. */
    private static final long LCG_MULTIPLIER = 25214903917L;
    private static final long LCG_INCREMENT = 11L;

    /** Progress is logged each time the running word count moves this far from the last logged value. */
    private static final double LOG_INTERVAL = 10000.0d;

    // Configuration keys read in setup(SparkConf).
    private static final String KEY_PREFIX = "org.deeplearning4j.scaleout.perform.models.word2vec.";
    private static final String KEY_ADAGRAD = KEY_PREFIX + "adagrad";
    private static final String KEY_NEGATIVE = KEY_PREFIX + "negative";
    private static final String KEY_NUM_WORDS = KEY_PREFIX + "numwords";
    private static final String KEY_WINDOW = KEY_PREFIX + "window";
    private static final String KEY_ALPHA = KEY_PREFIX + "alpha";
    private static final String KEY_MIN_ALPHA = KEY_PREFIX + "minalpha";
    private static final String KEY_VECTOR_LENGTH = KEY_PREFIX + "length";
    private static final String KEY_TABLE = KEY_PREFIX + "table";

    /** Sampling table for negative examples; stays {@code null} unless configured in {@link #setup(SparkConf)}. */
    private INDArray table;
    private Broadcast<AtomicLong> wordCount;
    private InMemoryLookupTable weights;
    private int vectorLength;
    private boolean useAdaGrad = false;
    private double negative = 5.0d;
    private int numWords = 1;
    private int window = 5;
    private AtomicLong nextRandom = new AtomicLong(5);
    private double alpha = 0.025d;
    private double minAlpha = 0.01d;
    private int totalWords = 1;
    private int lastChecked = 0;
    /** Precomputed sigmoid values over [-MAX_EXP, MAX_EXP), filled by {@link #initExpTable()}. */
    private double[] expTable = new double[1000];

    /**
     * Creates a performer and immediately configures it from {@code sparkConf}.
     *
     * @param sparkConf           configuration holding the word2vec hyper-parameters
     * @param broadcast           broadcast counter of words processed so far
     * @param inMemoryLookupTable lookup table whose vectors are updated in place
     */
    public Word2VecPerformer(SparkConf sparkConf, Broadcast<AtomicLong> broadcast, InMemoryLookupTable inMemoryLookupTable) {
        this.weights = inMemoryLookupTable;
        this.wordCount = broadcast;
        setup(sparkConf);
    }

    /**
     * Reads all hyper-parameters from the configuration, builds the sigmoid table and, when
     * negative sampling is enabled and a table is present in the configuration, deserializes
     * the sampling table.
     *
     * <p>Both {@code numWords} and {@code totalWords} are read from the same "numwords" key.
     * If {@code negative > 0} but no table is configured, {@link #table} remains {@code null}.
     *
     * @param sparkConf configuration to read from
     */
    public void setup(SparkConf sparkConf) {
        this.useAdaGrad = sparkConf.getBoolean(KEY_ADAGRAD, false);
        this.negative = sparkConf.getDouble(KEY_NEGATIVE, 5.0d);
        this.numWords = sparkConf.getInt(KEY_NUM_WORDS, 1);
        this.window = sparkConf.getInt(KEY_WINDOW, 5);
        this.alpha = sparkConf.getDouble(KEY_ALPHA, 0.02500000037252903d);
        this.minAlpha = sparkConf.getDouble(KEY_MIN_ALPHA, 0.009999999776482582d);
        this.totalWords = sparkConf.getInt(KEY_NUM_WORDS, 1);
        this.vectorLength = sparkConf.getInt(KEY_VECTOR_LENGTH, 100);
        initExpTable();
        if (this.negative > 0.0d && sparkConf.contains(KEY_TABLE)) {
            // NOTE(review): getBytes() uses the platform default charset; confirm it matches
            // whatever encoding the producer used when storing the table as a string.
            byte[] serializedTable = sparkConf.get(KEY_TABLE).getBytes();
            this.table = Nd4j.read(new DataInputStream(new ByteArrayInputStream(serializedTable)));
        }
    }

    /**
     * Trains on every word of the sentence whose text does not end with "STOP".
     *
     * @param list sentence to train on; {@code null} or empty is a no-op
     * @param d    learning rate to use for this sentence
     */
    public void trainSentence(List<VocabWord> list, double d) {
        if (list == null || list.isEmpty()) {
            return;
        }
        for (int position = 0; position < list.size(); position++) {
            if (list.get(position).getWord().endsWith("STOP")) {
                continue;
            }
            long random = advanceRandom();
            // floorMod keeps the offset in [0, window); a plain % on the signed int could go
            // negative and make skipGram scan a wider window than configured.
            int windowOffset = Math.floorMod((int) random, this.window);
            skipGram(position, list, windowOffset, d);
        }
    }

    /**
     * Trains the word at index {@code i} against its neighbours inside the (shrunken) window.
     *
     * @param i    index of the centre word in {@code list}
     * @param list sentence containing the word; {@code null} or empty is a no-op
     * @param i2   offset that shrinks the window symmetrically on both sides
     * @param d    learning rate
     */
    public void skipGram(int i, List<VocabWord> list, int i2, double d) {
        // Guard before indexing so an empty sentence does not throw.
        if (list == null || list.isEmpty()) {
            return;
        }
        VocabWord centre = list.get(i);
        if (centre == null) {
            return;
        }
        int end = ((this.window * 2) + 1) - i2;
        for (int offset = i2; offset < end; offset++) {
            if (offset == this.window) {
                continue; // offset == window maps to the centre word itself
            }
            int neighbour = (i - this.window) + offset;
            if (neighbour >= 0 && neighbour < list.size()) {
                iterateSample(centre, list.get(neighbour), d);
            }
        }
    }

    /**
     * Applies one training update for the pair ({@code vocabWord}, {@code vocabWord2}).
     *
     * <p>First walks {@code vocabWord}'s codes/points against {@code syn1}, then, if
     * {@code negative > 0}, runs {@code negative + 1} sampling rounds against {@code syn1Neg},
     * and finally adds the accumulated error to {@code vocabWord2}'s vector.
     *
     * @param vocabWord  word whose codes, points and index drive the update
     * @param vocabWord2 word whose vector is updated; ignored if {@code null} or index &lt; 0
     * @param d          learning rate
     */
    public void iterateSample(VocabWord vocabWord, VocabWord vocabWord2, double d) {
        if (vocabWord2 == null || vocabWord2.getIndex() < 0) {
            return;
        }
        INDArray wordVector = this.weights.vector(vocabWord2.getWord());
        INDArray accumulatedError = Nd4j.create(this.vectorLength);

        for (int step = 0; step < vocabWord.getCodeLength(); step++) {
            byte code = ((Byte) vocabWord.getCodes().get(step)).byteValue();
            int point = ((Integer) vocabWord.getPoints().get(step)).intValue();
            INDArray syn1Row = this.weights.getSyn1().slice(point);
            double dot = Nd4j.getBlasWrapper().dot(wordVector, syn1Row);
            if (dot < -MAX_EXP || dot >= MAX_EXP) {
                continue;
            }
            int tableIndex = expTableIndex(dot);
            if (tableIndex >= this.expTable.length) {
                continue;
            }
            double rate = this.useAdaGrad ? vocabWord.getGradient(step, d, this.alpha) : d;
            double g = ((1 - code) - this.expTable[tableIndex]) * rate;
            Nd4j.getBlasWrapper().level1().axpy(wordVector.length(), g, syn1Row, accumulatedError);
            Nd4j.getBlasWrapper().level1().axpy(wordVector.length(), g, wordVector, syn1Row);
        }

        if (this.negative > 0.0d) {
            int target = vocabWord.getIndex();
            // NOTE(review): this row is sliced once for the word's own index and is not
            // re-sliced when a different target is sampled below — confirm that is intended.
            INDArray syn1NegRow = this.weights.getSyn1Neg().slice(target);
            for (int round = 0; round < this.negative + 1.0d; round++) {
                int label;
                if (round == 0) {
                    label = 1;
                } else {
                    long random = advanceRandom();
                    // floorMod keeps table and vocabulary indices non-negative.
                    int tablePosition = Math.floorMod((int) (random >> 16), (int) this.table.length());
                    target = this.table.getInt(new int[]{tablePosition});
                    if (target == 0) {
                        target = Math.floorMod((int) this.nextRandom.get(), this.numWords - 1) + 1;
                    }
                    if (target == vocabWord.getIndex()) {
                        // Sampled the word itself: not a negative example, skip this round.
                        continue;
                    }
                    label = 0;
                }
                double dot = Nd4j.getBlasWrapper().dot(wordVector, syn1NegRow);
                double g;
                if (dot > MAX_EXP) {
                    g = this.useAdaGrad ? vocabWord.getGradient(target, label - 1, this.alpha) : (label - 1) * d;
                } else if (dot < -MAX_EXP) {
                    g = label * (this.useAdaGrad ? vocabWord.getGradient(target, d, this.alpha) : d);
                } else {
                    // dot == MAX_EXP can map to expTable.length; clamp to the last entry.
                    double sigmoid = this.expTable[Math.min(expTableIndex(dot), this.expTable.length - 1)];
                    g = this.useAdaGrad ? vocabWord.getGradient(target, label - sigmoid, this.alpha) : (label - sigmoid) * d;
                }
                // NOTE(review): both updates below write into the word vector — verify the
                // intended destinations against the reference algorithm.
                boolean doublePrecision = syn1NegRow.data().dataType() == DataType.DOUBLE;
                if (doublePrecision) {
                    Nd4j.getBlasWrapper().axpy(g, accumulatedError, wordVector);
                    Nd4j.getBlasWrapper().axpy(g, syn1NegRow, wordVector);
                } else {
                    Nd4j.getBlasWrapper().axpy((float) g, accumulatedError, wordVector);
                    Nd4j.getBlasWrapper().axpy((float) g, syn1NegRow, wordVector);
                }
            }
        }

        if (accumulatedError.data().dataType() == DataType.DOUBLE) {
            Nd4j.getBlasWrapper().axpy(1.0d, accumulatedError, wordVector);
        } else {
            Nd4j.getBlasWrapper().axpy(1.0f, accumulatedError, wordVector);
        }
    }

    /** Advances the linear congruential generator by one step and returns the new state. */
    private long advanceRandom() {
        this.nextRandom.set((this.nextRandom.get() * LCG_MULTIPLIER) + LCG_INCREMENT);
        return this.nextRandom.get();
    }

    /** Maps a dot product to its (unclamped) position in {@link #expTable}. */
    private int expTableIndex(double dot) {
        return (int) ((dot + MAX_EXP) * ((this.expTable.length / MAX_EXP) / 2.0d));
    }

    /**
     * Fills {@link #expTable} with sigmoid values sampled evenly over [-MAX_EXP, MAX_EXP).
     */
    private void initExpTable() {
        for (int i = 0; i < this.expTable.length; i++) {
            // Divide as double: integer division would make the fraction 0 for every i.
            double fraction = i / (double) this.expTable.length;
            double exp = FastMath.exp(((fraction * 2.0d) - 1.0d) * MAX_EXP);
            this.expTable[i] = exp / (exp + 1.0d);
        }
    }

    /**
     * Trains on one sentence and adds its length to the pair's counter.
     *
     * <p>The learning rate decays linearly with the broadcast word count relative to
     * {@code totalWords}, but never drops below {@code minAlpha}.
     *
     * @param pair sentence (first) and the counter to increase by the sentence size (second)
     */
    @Override
    public void call(Pair<List<VocabWord>, AtomicLong> pair) throws Exception {
        double wordsSoFar = this.wordCount.getValue().doubleValue();
        List<VocabWord> sentence = pair.getFirst();
        double currentAlpha = Math.max(this.minAlpha, this.alpha * (1.0d - (wordsSoFar / this.totalWords)));
        trainSentence(sentence, currentAlpha);

        // trainSentence tolerates a null sentence, so the bookkeeping must as well.
        int sentenceSize = sentence == null ? 0 : sentence.size();
        double processed = sentenceSize + wordsSoFar;
        if (Math.abs(processed - this.lastChecked) >= LOG_INTERVAL) {
            this.lastChecked = (int) processed;
            log.info("Words so far {} out of {}", processed, this.totalWords);
        }
        pair.getSecond().getAndAdd(sentenceSize);
    }
}
