package org.deeplearning4j.models.embeddings.inmemory;

import java.io.BufferedReader;
import java.io.DataOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.Writer;
import java.net.HttpURLConnection;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.apache.commons.math3.util.FastMath;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectorsImpl;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.plot.BarnesHutTsne;
import org.deeplearning4j.ui.UiConnectionInfo;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.legacy.AdaGrad;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.shade.guava.util.concurrent.AtomicDouble;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link WeightLookupTable} whose weight matrices are ND4J arrays held directly by this object.
 *
 * <p>{@code syn0} has one row per vocabulary element (see {@link #resetWeights(boolean)}). {@code syn1} is used by
 * the hierarchical-softmax branch of {@link #iterateSample}. {@code syn1Neg} and the sampling {@code table} are used
 * by its negative-sampling branch.
 *
 * @param <T> vocabulary element type
 */
public class InMemoryLookupTable<T extends SequenceElement> implements WeightLookupTable<T> {
    protected INDArray syn0;
    protected INDArray syn1;
    protected int vectorLength;
    protected transient Random rng;
    protected AtomicDouble lr;
    /** Precomputed sigmoid values over [-MAX_EXP, MAX_EXP); see {@link #initExpTable()}. */
    protected double[] expTable;
    protected long seed;
    /** Negative-sampling table: each entry is a vocabulary index; see {@link #makeTable(int, double)}. */
    protected INDArray table;
    protected INDArray syn1Neg;
    protected boolean useAdaGrad;
    /** Number of negative samples per update; values {@code <= 0} disable negative sampling. */
    protected double negative;
    protected boolean useHS;
    protected VocabCache<T> vocab;
    protected Map<Integer, INDArray> codes;
    protected AdaGrad adaGrad;
    protected Long tableId;
    private static final Logger log = LoggerFactory.getLogger(InMemoryLookupTable.class);
    protected static double MAX_EXP = 6.0d;

    /** Size of the sigmoid table, and the minimum size of the negative-sampling table. */
    private static final int EXP_TABLE_SIZE = 100000;
    /** Multipart boundary used when posting t-SNE output to a remote UI. */
    private static final String TSNE_BOUNDARY = "-----TSNE-POST-DATA-----";
    private static final String CRLF = "\r\n";

    /**
     * Fluent builder for {@link InMemoryLookupTable}.
     *
     * @param <T> vocabulary element type
     */
    public static class Builder<T extends SequenceElement> {
        protected VocabCache<T> vocabCache;
        protected int vectorLength = 100;
        protected boolean useAdaGrad = false;
        protected double lr = 0.025d;
        protected Random gen = Nd4j.getRandom();
        protected long seed = 123;
        protected double negative = 0.0d;
        protected boolean useHS = true;

        /** Enables or disables hierarchical softmax (enabled by default). */
        public Builder<T> useHierarchicSoftmax(boolean useHS) {
            this.useHS = useHS;
            return this;
        }

        /** Sets the vocabulary; required before {@link #build()}. */
        public Builder<T> cache(@NonNull VocabCache<T> vocab) {
            if (vocab == null) {
                throw new NullPointerException("vocab is marked @NonNull but is null");
            }
            this.vocabCache = vocab;
            return this;
        }

        /** Sets the number of negative samples (0 disables negative sampling). */
        public Builder<T> negative(double negative) {
            this.negative = negative;
            return this;
        }

        /** Sets the embedding dimensionality (default 100). */
        public Builder<T> vectorLength(int vectorLength) {
            this.vectorLength = vectorLength;
            return this;
        }

        /** Enables or disables AdaGrad gradient scaling. */
        public Builder<T> useAdaGrad(boolean useAdaGrad) {
            this.useAdaGrad = useAdaGrad;
            return this;
        }

        /** Sets the initial learning rate (default 0.025). */
        @Deprecated
        public Builder<T> lr(double lr) {
            this.lr = lr;
            return this;
        }

        /** Sets the random generator used for weight initialisation. */
        public Builder<T> gen(Random gen) {
            this.gen = gen;
            return this;
        }

        /** Sets the seed applied to the generator in {@link InMemoryLookupTable#resetWeights(boolean)}. */
        public Builder<T> seed(long seed) {
            this.seed = seed;
            return this;
        }

        /**
         * Builds the table.
         *
         * @throws IllegalStateException if no vocabulary was supplied via {@link #cache(VocabCache)}
         */
        public InMemoryLookupTable<T> build() {
            if (vocabCache == null) {
                throw new IllegalStateException("Vocab cache must be specified");
            }
            InMemoryLookupTable<T> table =
                    new InMemoryLookupTable<>(vocabCache, vectorLength, useAdaGrad, lr, gen, negative, useHS);
            table.seed = seed;
            return table;
        }
    }

    /** Iterates over the rows of {@code syn0} in index order. */
    protected class WeightIterator implements Iterator<INDArray> {
        protected int currIndex = 0;

        protected WeightIterator() {
        }

        @Override
        public boolean hasNext() {
            return currIndex < syn0.rows();
        }

        /**
         * Returns the next {@code syn0} slice.
         *
         * @throws NoSuchElementException when all rows have been returned
         */
        @Override
        public INDArray next() {
            if (!hasNext()) {
                throw new NoSuchElementException("No more rows in syn0 (rows=" + syn0.rows() + ")");
            }
            return syn0.slice(currIndex++);
        }

        @Override
        public void remove() {
            throw new UnsupportedOperationException();
        }
    }

    /**
     * Creates an empty table. It uses learning rate 0.025, seed 123, no negative sampling and hierarchical
     * softmax enabled. This constructor does not set the vocabulary, the weights or the exp table.
     */
    public InMemoryLookupTable() {
        this.rng = Nd4j.getRandom();
        this.lr = new AtomicDouble(0.025d);
        this.seed = 123L;
        this.negative = 0.0d;
        this.useHS = true;
        this.codes = new ConcurrentHashMap<>();
    }

    /**
     * @param vocab        vocabulary
     * @param vectorLength embedding dimensionality
     * @param useAdaGrad   whether to scale gradients with AdaGrad
     * @param lr           initial learning rate
     * @param gen          random generator used for weight initialisation
     * @param negative     number of negative samples (0 disables)
     * @param useHS        whether hierarchical softmax weights ({@code syn1}) are allocated
     */
    public InMemoryLookupTable(VocabCache<T> vocab, int vectorLength, boolean useAdaGrad, double lr, Random gen,
                               double negative, boolean useHS) {
        this(vocab, vectorLength, useAdaGrad, lr, gen, negative);
        this.useHS = useHS;
    }

    /**
     * Same as the 7-argument constructor, with hierarchical softmax enabled. Initialises the exp table, and the
     * AdaGrad state if {@code useAdaGrad} is set.
     */
    public InMemoryLookupTable(VocabCache<T> vocab, int vectorLength, boolean useAdaGrad, double lr, Random gen,
                               double negative) {
        this.rng = Nd4j.getRandom();
        this.lr = new AtomicDouble(0.025d);
        this.seed = 123L;
        this.negative = 0.0d;
        this.useHS = true;
        this.codes = new ConcurrentHashMap<>();
        this.vocab = vocab;
        this.vectorLength = vectorLength;
        this.useAdaGrad = useAdaGrad;
        this.lr.set(lr);
        this.rng = gen;
        this.negative = negative;
        initExpTable();
        if (useAdaGrad) {
            initAdaGrad();
        }
    }

    /** Allocates AdaGrad state shaped {@code [numWords + 1, vectorLength]}, zero-initialised. */
    protected void initAdaGrad() {
        int[] shape = {vocab.numWords() + 1, vectorLength};
        int length = ArrayUtil.prod(shape);
        adaGrad = new AdaGrad(shape, lr.get());
        adaGrad.setStateViewArray(Nd4j.zeros(shape).reshape(1L, length), shape, Nd4j.order().charValue(), true);
    }

    public double[] getExpTable() {
        return expTable;
    }

    public void setExpTable(double[] expTable) {
        this.expTable = expTable;
    }

    /** Returns the AdaGrad-scaled gradient for {@code column}, lazily creating AdaGrad state if needed. */
    @Override
    public double getGradient(int column, double gradient) {
        if (adaGrad == null) {
            initAdaGrad();
        }
        return adaGrad.getGradient(gradient, column, ArrayUtil.toInts(syn0.shape()));
    }

    @Override
    public int layerSize() {
        return vectorLength;
    }

    /**
     * (Re)initialises the weights using {@code seed}.
     *
     * <p>{@code syn0} is filled with {@code (U - 0.5) / vectorLength}, where U is drawn from {@code rng}.
     * {@code syn1} is created with {@code syn0.like()} when hierarchical softmax is enabled. Negative-sampling
     * state is created by {@link #initNegative()}.
     *
     * @param reset if true, recreate {@code syn0}/{@code syn1} even if they already exist
     */
    @Override
    public void resetWeights(boolean reset) {
        if (rng == null) {
            rng = Nd4j.getRandom();
        }
        rng.setSeed(seed);
        if (syn0 == null || reset) {
            syn0 = Nd4j.rand(new int[] {vocab.numWords(), vectorLength}, rng).subi(0.5d).divi(vectorLength);
        }
        if ((syn1 == null || reset) && useHS) {
            log.info("Initializing syn1...");
            syn1 = syn0.like();
        }
        initNegative();
    }

    /**
     * Fits {@code tsne} on the first {@code numWords} rows of {@code syn0}, capped at the vocabulary size so that
     * no padding rows are embedded.
     *
     * @return the labels of the fitted rows, in row order
     */
    private List<String> fitTsneAndGetLabels(BarnesHutTsne tsne, int numWords) {
        int rows = Math.max(0, Math.min(numWords, vocab.numWords()));
        INDArray data = Nd4j.create(rows, vectorLength);
        List<String> labels = new ArrayList<>(rows);
        for (int i = 0; i < rows; i++) {
            labels.add(vocab.wordAtIndex(i));
            data.putRow(i, syn0.slice(i));
        }
        tsne.fit(data);
        return labels;
    }

    /** The t-SNE configuration used by the overloads that do not take one. */
    private static BarnesHutTsne defaultTsne() {
        // 0.800000011920929 equals 0.8f widened to double; kept verbatim so results are unchanged.
        return new BarnesHutTsne.Builder().normalize(false).setFinalMomentum(0.800000011920929d).numDimension(2)
                .setMaxIter(1000).build();
    }

    /** Fits {@code tsne} on up to {@code numWords} vectors and saves the result to {@code file}. */
    @Override
    public void plotVocab(BarnesHutTsne tsne, int numWords, File file) {
        try {
            tsne.saveAsFile(fitTsneAndGetLabels(tsne, numWords), file.getAbsolutePath());
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public void plotVocab(int numWords, File file) {
        plotVocab(defaultTsne(), numWords, file);
    }

    @Override
    public void plotVocab(int numWords, UiConnectionInfo connectionInfo) {
        plotVocab(defaultTsne(), numWords, connectionInfo);
    }

    /**
     * Fits {@code tsne} and POSTs the reduced coordinates to {@code <firstPart>/tsne/post/<sessionId>}. The body
     * is a multipart CSV upload with one line per word: {@code x,y,...,label}.
     *
     * <p>Failures while reading the HTTP response are logged. Any other failure is rethrown wrapped in a
     * {@link RuntimeException}.
     */
    @Override
    public void plotVocab(BarnesHutTsne tsne, int numWords, UiConnectionInfo connectionInfo) {
        try {
            List<String> labels = fitTsneAndGetLabels(tsne, numWords);
            INDArray reduced = tsne.getData();
            int rows = Math.min(Math.min(reduced.rows(), numWords), labels.size());
            StringBuilder csv = new StringBuilder();
            for (int r = 0; r < rows; r++) {
                INDArray coords = reduced.getRow(r);
                for (int c = 0; c < coords.length(); c++) {
                    csv.append(String.valueOf(coords.getDouble(c))).append(',');
                }
                // One record per line; previously records were concatenated with no separator.
                csv.append(labels.get(r)).append('\n');
            }
            postTsneCsv(connectionInfo, csv.toString());
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /** Uploads {@code csv} as a single multipart/form-data part and logs a non-200 response. */
    private void postTsneCsv(UiConnectionInfo connectionInfo, String csv) throws IOException, URISyntaxException {
        URI uri = new URI(connectionInfo.getFirstPart() + "/tsne/post/" + connectionInfo.getSessionId());
        HttpURLConnection connection = (HttpURLConnection) uri.toURL().openConnection();
        connection.setRequestMethod("POST");
        connection.setRequestProperty("User-Agent", "Mozilla/5.0");
        connection.setRequestProperty("Content-Type", "multipart/form-data; boundary=" + TSNE_BOUNDARY);
        connection.setDoOutput(true);
        // Multipart requires CRLF line endings and a closing delimiter; the declared charset matches the encoder.
        try (Writer writer = new OutputStreamWriter(connection.getOutputStream(), StandardCharsets.UTF_8)) {
            writer.write("--" + TSNE_BOUNDARY + CRLF);
            writer.write("Content-Disposition: form-data; name=\"fileupload\"; filename=\"tsne.csv\"" + CRLF);
            writer.write("Content-Type: text/plain; charset=UTF-8" + CRLF);
            writer.write("Content-Transfer-Encoding: binary" + CRLF);
            writer.write(CRLF);
            writer.write(csv);
            writer.write(CRLF);
            writer.write("--" + TSNE_BOUNDARY + "--" + CRLF);
        }
        try {
            int responseCode = connection.getResponseCode();
            log.info("RESPONSE CODE: {}", responseCode);
            if (responseCode != 200) {
                // getInputStream() throws for 4xx/5xx; the server's message is only on the error stream.
                InputStream body = responseCode >= 400 ? connection.getErrorStream() : connection.getInputStream();
                log.warn("Error posting to remote UI - received response code {}\tContent: {}", responseCode,
                        readBody(body));
            }
        } catch (IOException e) {
            log.warn("Error posting to remote UI at {}", uri, e);
        }
    }

    /** Reads {@code in} fully as UTF-8, joining lines without separators; returns "" for a null stream. */
    private static String readBody(InputStream in) throws IOException {
        if (in == null) {
            return "";
        }
        StringBuilder content = new StringBuilder();
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(in, StandardCharsets.UTF_8))) {
            String line;
            while ((line = reader.readLine()) != null) {
                content.append(line);
            }
        }
        return content.toString();
    }

    @Override
    public void putCode(int codeIndex, INDArray code) {
        codes.put(codeIndex, code);
    }

    /** Returns the {@code syn1} rows at the given indices. */
    @Override
    public INDArray loadCodes(int[] codes) {
        return syn1.getRows(codes);
    }

    /**
     * Allocates negative-sampling state ({@code syn1Neg}, zero-filled with {@code syn0}'s shape, plus the sampling
     * table). Does nothing if negative sampling is disabled or the state already exists.
     */
    public synchronized void initNegative() {
        if (negative <= 0.0d || syn1Neg != null) {
            return;
        }
        syn1Neg = Nd4j.zeros(syn0.shape());
        // expTable is null when the no-arg constructor was used; fall back to the default size.
        int tableSize = expTable == null ? EXP_TABLE_SIZE : Math.max(expTable.length, EXP_TABLE_SIZE);
        makeTable(tableSize, 0.75d);
    }

    /**
     * Precomputes {@code sigmoid(x)} for {@code x = (2 * i / size - 1) * MAX_EXP}, i.e. evenly spaced over
     * [-MAX_EXP, MAX_EXP).
     */
    protected void initExpTable() {
        expTable = new double[EXP_TABLE_SIZE];
        for (int i = 0; i < expTable.length; i++) {
            // Floating-point division is essential: integer division made every entry sigmoid(-MAX_EXP).
            double e = FastMath.exp(((double) i / expTable.length * 2.0d - 1.0d) * MAX_EXP);
            expTable[i] = e / (e + 1.0d);
        }
    }

    /** Maps a dot product in [-MAX_EXP, MAX_EXP] to an {@code expTable} index (may equal length at MAX_EXP). */
    private int expTableIndex(double dot) {
        return (int) ((dot + MAX_EXP) * ((expTable.length / MAX_EXP) / 2.0d));
    }

    /** Applies {@code y += alpha * x}, choosing the BLAS overload that matches {@code syn0}'s data type. */
    private void axpyMatchingType(double alpha, INDArray x, INDArray y) {
        if (syn0.data().dataType() == DataType.DOUBLE) {
            Nd4j.getBlasWrapper().axpy(alpha, x, y);
        } else {
            Nd4j.getBlasWrapper().axpy((float) alpha, x, y);
        }
    }

    /**
     * Performs one training update for the pair ({@code w1}, {@code w2}) using the {@code syn0} row of {@code w2}.
     *
     * <p>The hierarchical-softmax branch walks {@code w1}'s codes/points against {@code syn1}. When
     * {@code negative > 0}, it also uses {@code w1} as the positive target and samples {@code negative} further
     * targets from {@code table} into {@code syn1Neg}. The accumulated error is finally added to {@code w2}'s
     * {@code syn0} row. Pairs involving null, self, "STOP" or the UNK token are skipped.
     *
     * @param nextRandom LCG state, advanced in place for each negative sample
     * @param alpha      learning rate for this update
     */
    @Override
    @Deprecated
    public void iterateSample(T w1, T w2, AtomicLong nextRandom, double alpha) {
        if (w2 == null || w2.getIndex() < 0 || w1.getIndex() == w2.getIndex() || w1.getLabel().equals("STOP")
                || w2.getLabel().equals("STOP") || w1.getLabel().equals(WordVectorsImpl.DEFAULT_UNK)
                || w2.getLabel().equals(WordVectorsImpl.DEFAULT_UNK)) {
            return;
        }
        INDArray l1 = syn0.slice(w2.getIndex());
        INDArray neu1e = Nd4j.create(vectorLength);

        for (int i = 0; i < w1.getCodeLength(); i++) {
            byte code = w1.getCodes().get(i).byteValue();
            int point = w1.getPoints().get(i).intValue();
            if (point >= syn0.rows() || point < 0) {
                throw new IllegalStateException("Illegal point " + point);
            }
            INDArray syn1Row = syn1.slice(point);
            double dot = Nd4j.getBlasWrapper().dot(l1, syn1Row);
            // Written positively so that NaN dot products are skipped too.
            if (!(dot >= -MAX_EXP && dot < MAX_EXP)) {
                continue;
            }
            int idx = expTableIndex(dot);
            if (idx >= expTable.length) {
                continue;
            }
            double f = expTable[idx];
            double g = useAdaGrad ? w1.getGradient(i, (1 - code) - f, lr.get()) : ((1 - code) - f) * alpha;
            Nd4j.getBlasWrapper().level1().axpy(syn1Row.length(), g, syn1Row, neu1e);
            Nd4j.getBlasWrapper().level1().axpy(syn1Row.length(), g, l1, syn1Row);
        }

        if (negative > 0.0d) {
            for (int d = 0; d < negative + 1.0d; d++) {
                int target;
                int label;
                if (d == 0) {
                    target = w1.getIndex();
                    label = 1;
                } else {
                    nextRandom.set((nextRandom.get() * 25214903917L) + 11);
                    target = table.getInt((int) Math.abs(((int) (nextRandom.get() >> 16)) % table.length()));
                    if (target <= 0) {
                        target = (((int) nextRandom.get()) % (vocab.numWords() - 1)) + 1;
                    }
                    // Sampling the positive word as a negative is discarded (previously a stale label was reused).
                    if (target == w1.getIndex()) {
                        continue;
                    }
                    label = 0;
                }
                if (target >= syn1Neg.rows() || target < 0) {
                    continue;
                }
                INDArray syn1NegRow = syn1Neg.slice(target);
                double dot = Nd4j.getBlasWrapper().dot(l1, syn1NegRow);
                double g;
                if (dot > MAX_EXP) {
                    g = useAdaGrad ? w1.getGradient(target, label - 1, alpha) : (label - 1) * alpha;
                } else if (dot < -MAX_EXP) {
                    g = label * (useAdaGrad ? w1.getGradient(target, alpha, alpha) : alpha);
                } else {
                    // Clamp: dot == MAX_EXP maps to expTable.length, one past the end.
                    double f = expTable[Math.min(expTableIndex(dot), expTable.length - 1)];
                    g = useAdaGrad ? w1.getGradient(target, label - f, alpha) : (label - f) * alpha;
                }
                axpyMatchingType(g, syn1NegRow, neu1e);
                axpyMatchingType(g, l1, syn1NegRow);
            }
        }
        axpyMatchingType(1.0d, neu1e, l1);
    }

    public boolean isUseAdaGrad() {
        return useAdaGrad;
    }

    public void setUseAdaGrad(boolean useAdaGrad) {
        this.useAdaGrad = useAdaGrad;
    }

    public double getNegative() {
        return negative;
    }

    public void setUseHS(boolean useHS) {
        this.useHS = useHS;
    }

    public void setNegative(double negative) {
        this.negative = negative;
    }

    /** No-op; retained for interface compatibility. */
    @Override
    @Deprecated
    public void iterate(T w1, T w2) {
    }

    /** Equivalent to {@code resetWeights(true)}. */
    @Override
    public void resetWeights() {
        resetWeights(true);
    }

    /**
     * Builds the negative-sampling {@code table}. Each vocabulary index occupies a share of the table
     * proportional to {@code frequency^power}. The last index absorbs any remainder.
     *
     * @param tableSize number of table entries
     * @param power     exponent applied to word frequencies
     */
    protected void makeTable(int tableSize, double power) {
        int vocabSize = syn0.rows();
        table = Nd4j.create(tableSize);
        double totalPow = 0.0d;
        for (String word : vocab.words()) {
            totalPow += Math.pow(vocab.wordFrequency(word), power);
        }
        int wordIdx = 0;
        double cumulative = Math.pow(vocab.wordFrequency(vocab.wordAtIndex(0)), power) / totalPow;
        for (int i = 0; i < tableSize; i++) {
            table.putScalar(i, wordIdx);
            if ((i * 1.0d) / tableSize > cumulative) {
                if (wordIdx < vocabSize - 1) {
                    wordIdx++;
                }
                String word = vocab.wordAtIndex(wordIdx);
                if (word != null) {
                    cumulative += Math.pow(vocab.wordFrequency(word), power) / totalPow;
                }
            }
        }
    }

    /**
     * Overwrites the {@code syn0} row of {@code word} with {@code vector}.
     *
     * @throws IllegalArgumentException if either argument is null or the word is not in the vocabulary
     */
    @Override
    public void putVector(String word, INDArray vector) {
        if (word == null) {
            throw new IllegalArgumentException("No null words allowed");
        }
        if (vector == null) {
            throw new IllegalArgumentException("No null vectors allowed");
        }
        int index = vocab.indexOf(word);
        if (index < 0) {
            // Guard against a silent write through a negative index.
            throw new IllegalArgumentException("Word [" + word + "] is not in the vocabulary");
        }
        syn0.slice(index).assign(vector);
    }

    public INDArray getTable() {
        return table;
    }

    public void setTable(INDArray table) {
        this.table = table;
    }

    public INDArray getSyn1Neg() {
        return syn1Neg;
    }

    public void setSyn1Neg(INDArray syn1Neg) {
        this.syn1Neg = syn1Neg;
    }

    /**
     * Returns {@code syn0.getRow(index, true)} for {@code word}, falling back to the UNK token's row.
     *
     * @return the row, or null if {@code word} is null or neither it nor UNK is in the vocabulary
     */
    @Override
    public INDArray vector(String word) {
        if (word == null) {
            return null;
        }
        int index = vocab.indexOf(word);
        if (index < 0) {
            index = vocab.indexOf(WordVectorsImpl.DEFAULT_UNK);
            if (index < 0) {
                return null;
            }
        }
        return syn0.getRow(index, true);
    }

    @Override
    public void setLearningRate(double lr) {
        this.lr.set(lr);
    }

    @Override
    public Iterator<INDArray> vectors() {
        return new WeightIterator();
    }

    @Override
    public INDArray getWeights() {
        return syn0;
    }

    public INDArray getSyn0() {
        return syn0;
    }

    /**
     * Replaces {@code syn0} and sets {@code vectorLength} to its column count.
     *
     * @throws IllegalArgumentException if {@code syn0} is empty or not rank 2
     */
    public void setSyn0(@NonNull INDArray syn0) {
        if (syn0 == null) {
            throw new NullPointerException("syn0 is marked @NonNull but is null");
        }
        Preconditions.checkArgument(!syn0.isEmpty(), "syn0 can't be empty");
        Preconditions.checkArgument(syn0.rank() == 2, "syn0 must have rank 2");
        this.syn0 = syn0;
        this.vectorLength = syn0.columns();
    }

    public INDArray getSyn1() {
        return syn1;
    }

    public void setSyn1(@NonNull INDArray syn1) {
        if (syn1 == null) {
            throw new NullPointerException("syn1 is marked @NonNull but is null");
        }
        this.syn1 = syn1;
    }

    @Override
    public VocabCache<T> getVocabCache() {
        return vocab;
    }

    public void setVectorLength(int vectorLength) {
        this.vectorLength = vectorLength;
    }

    @Deprecated
    public AtomicDouble getLr() {
        return lr;
    }

    public void setLr(AtomicDouble lr) {
        this.lr = lr;
    }

    public VocabCache getVocab() {
        return vocab;
    }

    public void setVocab(VocabCache vocab) {
        this.vocab = vocab;
    }

    public Map<Integer, INDArray> getCodes() {
        return codes;
    }

    public void setCodes(Map<Integer, INDArray> codes) {
        this.codes = codes;
    }

    @Override
    public String toString() {
        return "InMemoryLookupTable{syn0=" + syn0 + ", syn1=" + syn1 + ", vectorLength=" + vectorLength + ", rng="
                + rng + ", lr=" + lr + ", expTable=" + Arrays.toString(expTable) + ", seed=" + seed + ", table="
                + table + ", syn1Neg=" + syn1Neg + ", useAdaGrad=" + useAdaGrad + ", negative=" + negative
                + ", vocab=" + vocab + ", codes=" + codes + '}';
    }

    /**
     * Re-initialises this table's weights, then copies every row of {@code src}'s {@code syn0} (and of
     * {@code syn1}/{@code syn1Neg} where both tables have them) into this table.
     *
     * @throws IllegalStateException      if vector lengths differ, {@code src.syn0} is null, or {@code src} has
     *                                    more rows than this table
     * @throws ND4JIllegalStateException  if neither {@code syn1} nor {@code syn1Neg} could be merged
     */
    public void consume(InMemoryLookupTable<T> src) {
        if (src.vectorLength != vectorLength) {
            throw new IllegalStateException("You can't consume lookupTable with different vector lengths");
        }
        if (src.syn0 == null) {
            throw new IllegalStateException("Source lookupTable Syn0 is NULL");
        }
        resetWeights(true);
        AtomicInteger skippedSyn1 = new AtomicInteger(0);
        AtomicInteger skippedSyn1Neg = new AtomicInteger(0);
        if (src.syn0.rows() > syn0.rows()) {
            throw new IllegalStateException(
                    "You can't consume lookupTable with built for larger vocabulary without updating your vocabulary first");
        }
        for (int i = 0; i < src.syn0.rows(); i++) {
            syn0.putRow(i, src.syn0.getRow(i));
            if (syn1 != null && src.syn1 != null) {
                syn1.putRow(i, src.syn1.getRow(i));
            } else if (skippedSyn1.incrementAndGet() == 1) {
                log.info("Skipping syn1 merge");
            }
            if (syn1Neg != null && src.syn1Neg != null) {
                syn1Neg.putRow(i, src.syn1Neg.getRow(i));
            } else if (skippedSyn1Neg.incrementAndGet() == 1) {
                log.info("Skipping syn1Neg merge");
            }
            if (skippedSyn1.get() > 0 && skippedSyn1Neg.get() > 0) {
                throw new ND4JIllegalStateException("srcTable has no syn1/syn1neg");
            }
        }
    }

    @Override
    public Long getTableId() {
        return tableId;
    }

    @Override
    public void setTableId(Long tableId) {
        this.tableId = tableId;
    }
}
