package org.deeplearning4j.models.embeddings.learning.impl.elements;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.learning.ElementsLearningAlgorithm;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.glove.AbstractCoOccurrences;
import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.legacy.AdaGrad;
import org.nd4j.linalg.primitives.Counter;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * GloVe (Global Vectors) implementation of {@link ElementsLearningAlgorithm}.
 *
 * <p>{@link #pretrain(SequenceIterator)} collects co-occurrence statistics through
 * {@link AbstractCoOccurrences}. The first call to {@link #learnSequence(Sequence, AtomicLong, double)}
 * runs every configured epoch over those co-occurrence pairs; it only null-checks the sequence it is
 * given. Every later call returns {@code 0.0} immediately.
 *
 * <p>Element vectors are read from, and updated in place in, the {@code syn0} matrix of the
 * configured {@link InMemoryLookupTable}. Step sizes for both vectors and per-element biases are
 * scaled with AdaGrad. Each epoch is processed by one thread per available processor. The threads
 * share {@code syn0} and the bias vector and update them without locking.
 *
 * @param <T> vocabulary element type
 */
public class GloVe<T extends SequenceElement> implements ElementsLearningAlgorithm<T> {
    private static final Logger log = LoggerFactory.getLogger(GloVe.class);

    /** Default cap of the GloVe weighting function; shared with {@link Builder} so both paths agree. */
    private static final double DEFAULT_X_MAX = 100.0d;

    /** Emit a progress log line every this many processed (positive-count) pairs. */
    private static final long PROGRESS_LOG_INTERVAL = 1_000_000L;

    private VocabCache<T> vocabCache;
    private AbstractCoOccurrences<T> coOccurrences;
    private WeightLookupTable<T> lookupTable;
    private VectorsConfiguration configuration;
    private INDArray syn0;
    // Previously defaulted to 0 for no-arg construction, which made every pair take the
    // "count > xMax" branch (weight 1) regardless of its co-occurrence count.
    private double xMax = DEFAULT_X_MAX;
    private boolean shuffle;
    private boolean symmetric;
    private AdaGrad weightAdaGrad;
    private AdaGrad biasAdaGrad;
    private INDArray bias;
    private int vectorLength;
    // Shapes handed to AdaGrad on every update; computed once in configure().
    private int[] syn0Shape;
    private int[] biasShape;
    private final AtomicBoolean isTerminate = new AtomicBoolean(false);
    protected double alpha = 0.75d;
    protected double learningRate = 0.0d;
    protected int maxmemory = 0;
    protected int batchSize = 1000;
    private final int workers = Runtime.getRuntime().availableProcessors();

    /** Fluent builder for {@link GloVe}; unset options keep the defaults declared below. */
    public static class Builder<T extends SequenceElement> {
        protected double xMax = DEFAULT_X_MAX;
        protected double alpha = 0.75d;
        protected double learningRate = 0.0d;
        protected boolean shuffle = false;
        protected boolean symmetric = false;
        protected int maxmemory = 0;
        protected int batchSize = 1000;

        /** Number of co-occurrence pairs a worker fetches from the shared iterator at once. */
        public Builder<T> batchSize(int i) {
            this.batchSize = i;
            return this;
        }

        /** Learning rate; {@code 0} means "take it from the VectorsConfiguration in configure()". */
        public Builder<T> learningRate(double d) {
            this.learningRate = d;
            return this;
        }

        /** Exponent of the weighting function {@code (count / xMax)^alpha}. */
        public Builder<T> alpha(double d) {
            this.alpha = d;
            return this;
        }

        /** Passed to {@link AbstractCoOccurrences.Builder#maxMemory(int)} during pretraining. */
        public Builder<T> maxMemory(int i) {
            this.maxmemory = i;
            return this;
        }

        /** Co-occurrence count above which every pair gets weight 1. */
        public Builder<T> xMax(double d) {
            this.xMax = d;
            return this;
        }

        /** Whether each fetched batch of pairs is shuffled before it is processed. */
        public Builder<T> shuffle(boolean z) {
            this.shuffle = z;
            return this;
        }

        /** Passed to {@link AbstractCoOccurrences.Builder#symmetric(boolean)} during pretraining. */
        public Builder<T> symmetric(boolean z) {
            this.symmetric = z;
            return this;
        }

        /**
         * Creates a {@link GloVe} instance carrying this builder's settings.
         *
         * @return a new, unconfigured GloVe algorithm
         */
        public GloVe<T> build() {
            GloVe<T> glove = new GloVe<>();
            glove.symmetric = this.symmetric;
            glove.shuffle = this.shuffle;
            glove.xMax = this.xMax;
            glove.alpha = this.alpha;
            glove.learningRate = this.learningRate;
            glove.maxmemory = this.maxmemory;
            glove.batchSize = this.batchSize;
            return glove;
        }
    }

    /**
     * Worker thread for one epoch. It pulls batches of co-occurrence pairs from an iterator shared
     * with the other workers and applies one GloVe update per pair whose count is positive.
     */
    private class GloveCalculationsThread extends Thread {
        private final int threadId;
        private final int epochId;
        private final Iterator<Pair<Pair<T, T>, Double>> coList;
        private final AtomicLong pairsCounter;
        private final Counter<Integer> errorCounter;
        /** Throwable that terminated {@link #run()}, if any; read by the spawning thread after join(). */
        private Throwable failure;

        GloveCalculationsThread(int epochId, int threadId, @NonNull Iterator<Pair<Pair<T, T>, Double>> pairs,
                        @NonNull AtomicLong pairsCounter, @NonNull Counter<Integer> errorCounter) {
            if (pairs == null) {
                throw new NullPointerException("pairs is marked @NonNull but is null");
            }
            if (pairsCounter == null) {
                throw new NullPointerException("pairsCounter is marked @NonNull but is null");
            }
            if (errorCounter == null) {
                throw new NullPointerException("errorCounter is marked @NonNull but is null");
            }
            this.epochId = epochId;
            this.threadId = threadId;
            this.pairsCounter = pairsCounter;
            this.errorCounter = errorCounter;
            this.coList = pairs;
            setName("GloVe ELA t." + this.threadId);
        }

        @Override
        public void run() {
            try {
                List<Pair<Pair<T, T>, Double>> batch = new ArrayList<>();
                while (nextBatch(batch)) {
                    if (GloVe.this.shuffle) {
                        Collections.shuffle(batch);
                    }
                    for (Pair<Pair<T, T>, Double> pair : batch) {
                        processPair(pair);
                    }
                }
            } catch (Throwable t) {
                // Recorded instead of being left to the default uncaught-exception handler, so that
                // learnSequence() can surface it rather than silently finishing the epoch.
                failure = t;
            }
        }

        /**
         * Refills {@code batch} with up to {@code batchSize} pairs (at least one) from the shared iterator.
         *
         * @return {@code false} once the iterator is exhausted
         */
        private boolean nextBatch(List<Pair<Pair<T, T>, Double>> batch) {
            batch.clear();
            // A non-positive batch size would otherwise fetch nothing while hasNext() stays true.
            int limit = Math.max(1, GloVe.this.batchSize);
            // The iterator is shared by all workers of this epoch; hasNext()/next() must be atomic
            // or two workers can race for the last element.
            synchronized (coList) {
                while (batch.size() < limit && coList.hasNext()) {
                    batch.add(coList.next());
                }
            }
            return !batch.isEmpty();
        }

        /** Trains on a single pair and updates the shared progress and error counters. */
        private void processPair(Pair<Pair<T, T>, Double> pair) {
            T first = pair.getFirst().getFirst();
            T second = pair.getFirst().getSecond();
            double count = pair.getSecond();
            if (count <= 0.0d) {
                // Math.log(count) in iterateSample would be non-finite: skip, but count as processed.
                pairsCounter.incrementAndGet();
                return;
            }
            errorCounter.incrementCount(epochId, GloVe.this.iterateSample(first, second, count));
            long processed = pairsCounter.incrementAndGet();
            if (processed % PROGRESS_LOG_INTERVAL == 0) {
                GloVe.log.info("Processed [{}] word pairs so far...", processed);
            }
        }
    }

    public GloVe() {
    }

    @Override
    public String getCodeName() {
        return "GloVe";
    }

    @Override
    public void finish() {
        log.info("GloVe finalizer...");
    }

    /**
     * Binds the vocabulary, the lookup table and the configuration, and allocates the AdaGrad
     * state and the bias vector.
     *
     * @throws NullPointerException if any argument is null
     * @throws ClassCastException   if {@code lookupTable} is not an {@link InMemoryLookupTable}
     */
    @Override
    public void configure(@NonNull VocabCache<T> vocabCache, @NonNull WeightLookupTable<T> lookupTable,
                    @NonNull VectorsConfiguration configuration) {
        if (vocabCache == null) {
            throw new NullPointerException("vocabCache is marked @NonNull but is null");
        }
        if (lookupTable == null) {
            throw new NullPointerException("lookupTable is marked @NonNull but is null");
        }
        if (configuration == null) {
            throw new NullPointerException("configuration is marked @NonNull but is null");
        }
        this.vocabCache = vocabCache;
        this.lookupTable = lookupTable;
        this.configuration = configuration;
        this.syn0 = ((InMemoryLookupTable<?>) lookupTable).getSyn0();
        this.vectorLength = configuration.getLayersSize();
        if (this.learningRate == 0.0d) {
            // 0 means "not set via Builder": fall back to the global configuration.
            this.learningRate = configuration.getLearningRate();
        }
        this.weightAdaGrad = new AdaGrad(new int[] {this.vocabCache.numWords() + 1, this.vectorLength},
                        this.learningRate);
        this.bias = Nd4j.create(this.syn0.rows());
        this.syn0Shape = ArrayUtil.toInts(this.syn0.shape());
        this.biasShape = ArrayUtil.toInts(this.bias.shape());
        this.biasAdaGrad = new AdaGrad(this.biasShape, this.learningRate);
        log.info("GloVe params: {Max Memory: [{}], Learning rate: [{}], Alpha: [{}], xMax: [{}], Symmetric: [{}], Shuffle: [{}]}",
                        this.maxmemory, this.learningRate, this.alpha, this.xMax, this.symmetric, this.shuffle);
    }

    /**
     * Builds and fits the co-occurrence statistics over {@code iterator}.
     *
     * @throws NullPointerException if {@code iterator} is null
     */
    @Override
    public void pretrain(@NonNull SequenceIterator<T> iterator) {
        if (iterator == null) {
            throw new NullPointerException("iterator is marked @NonNull but is null");
        }
        this.coOccurrences = new AbstractCoOccurrences.Builder<T>().symmetric(this.symmetric)
                        .windowSize(this.configuration.getWindow()).iterate(iterator).workers(this.workers)
                        .vocabCache(this.vocabCache).maxMemory(this.maxmemory).build();
        this.coOccurrences.fit();
    }

    /** Batched sequence learning is not supported by GloVe. */
    @Override
    public double learnSequence(Sequence<T> sequence, AtomicLong nextRandom, double learningRate,
                    BatchSequences<T> batchSequences) {
        throw new UnsupportedOperationException();
    }

    /**
     * On the first call, runs all configured epochs over the co-occurrence pairs. Each epoch uses
     * {@code workers} threads, and the call blocks until they finish. Later calls return immediately.
     *
     * @param sequence            only null-checked; training uses the co-occurrence data
     * @param nextRandom          only null-checked
     * @param currentLearningRate unused; the rate fixed in configure() is used
     * @return always {@code 0.0}
     * @throws RuntimeException if a worker thread fails or the calling thread is interrupted
     */
    @Override
    public synchronized double learnSequence(@NonNull Sequence<T> sequence, @NonNull AtomicLong nextRandom,
                    double currentLearningRate) {
        if (sequence == null) {
            throw new NullPointerException("sequence is marked @NonNull but is null");
        }
        if (nextRandom == null) {
            throw new NullPointerException("nextRandom is marked @NonNull but is null");
        }
        if (this.isTerminate.get()) {
            return 0.0d;
        }
        AtomicLong pairsCounter = new AtomicLong(0L);
        Counter<Integer> errorCounter = new Counter<>();
        for (int epoch = 0; epoch < this.configuration.getEpochs(); epoch++) {
            Iterator<Pair<Pair<T, T>, Double>> pairs = this.coOccurrences.iterator();
            List<GloveCalculationsThread> threads = new ArrayList<>(this.workers);
            for (int worker = 0; worker < this.workers; worker++) {
                GloveCalculationsThread thread =
                                new GloveCalculationsThread(epoch, worker, pairs, pairsCounter, errorCounter);
                threads.add(thread);
                thread.start();
            }
            awaitWorkers(threads);
            log.info("Processed [{}] pairs, Error was [{}]", pairsCounter.get(), errorCounter.getCount(epoch));
        }
        this.isTerminate.set(true);
        return 0.0d;
    }

    /**
     * Joins every worker, then rethrows the first recorded worker failure (wrapped), if there was one.
     */
    private void awaitWorkers(List<GloveCalculationsThread> threads) {
        Throwable firstFailure = null;
        for (GloveCalculationsThread thread : threads) {
            try {
                thread.join();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            }
            // join() makes the worker's write to 'failure' visible here.
            if (firstFailure == null && thread.failure != null) {
                firstFailure = thread.failure;
            }
        }
        if (firstFailure != null) {
            throw new RuntimeException("GloVe worker thread failed", firstFailure);
        }
    }

    @Override
    public synchronized boolean isEarlyTerminationHit() {
        return this.isTerminate.get();
    }

    /**
     * Applies one GloVe gradient step to the pair ({@code element1}, {@code element2}).
     *
     * @param element1 first element of the pair
     * @param element2 second element of the pair
     * @param count    co-occurrence value (worker threads only pass positive values)
     * @return {@code 0.5 * f(count) * diff^2}, where {@code diff = w1.w2 + b1 + b2 - log(count)} is
     *         measured before the update
     * @throws IllegalArgumentException if either element's index is outside the rows of syn0
     */
    private double iterateSample(T element1, T element2, double count) {
        if (element1.getIndex() < 0 || element1.getIndex() >= this.syn0.rows()) {
            throw new IllegalArgumentException("Illegal index for word " + element1.getLabel());
        }
        if (element2.getIndex() < 0 || element2.getIndex() >= this.syn0.rows()) {
            throw new IllegalArgumentException("Illegal index for word " + element2.getLabel());
        }
        INDArray w1Vector = this.syn0.slice(element1.getIndex());
        INDArray w2Vector = this.syn0.slice(element2.getIndex());
        double prediction = Nd4j.getBlasWrapper().dot(w1Vector, w2Vector)
                        + ((this.bias.getDouble(element1.getIndex()) + this.bias.getDouble(element2.getIndex()))
                                        - Math.log(count));
        // Weighting f(x) = (x / xMax)^alpha below the cap, 1 above it.
        double weightedDiff = count > this.xMax ? prediction : Math.pow(count / this.xMax, this.alpha) * prediction;
        if (Double.isNaN(weightedDiff)) {
            weightedDiff = Nd4j.EPS_THRESHOLD;
        }
        double gradient = weightedDiff * this.learningRate;
        // Each vector's gradient is the *other* vector's pre-update value. w1Vector is a view that the
        // first update modifies in place, so snapshot it for the second update.
        INDArray w1Before = w1Vector.dup();
        update(element1, w1Vector, w2Vector, gradient);
        update(element2, w2Vector, w1Before, gradient);
        return 0.5d * weightedDiff * prediction;
    }

    /**
     * Subtracts the AdaGrad-scaled gradient {@code context * gradient} from {@code vector} in place,
     * and the AdaGrad-scaled {@code gradient} from the bias at {@code element}'s index.
     */
    private void update(T element, INDArray vector, INDArray context, double gradient) {
        vector.subi(this.weightAdaGrad.getGradient(context.mul(gradient), element.getIndex(), this.syn0Shape));
        this.bias.putScalar(element.getIndex(), this.bias.getDouble(element.getIndex())
                        - this.biasAdaGrad.getGradient(gradient, element.getIndex(), this.biasShape));
    }
}
