package edu.emory.mathcs.nlp.learning.optimization;

import edu.emory.mathcs.nlp.learning.optimization.reguralization.Regularizer;
import edu.emory.mathcs.nlp.learning.util.MajorVector;
import edu.emory.mathcs.nlp.learning.util.WeightVector;
import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
import it.unimi.dsi.fastutil.ints.IntSet;
import java.util.Arrays;

/* loaded from: input_file:edu/emory/mathcs/nlp/learning/optimization/AdaptiveGradientDescentMiniBatch.class */
/**
 * Mini-batch variant of {@link AdaptiveGradientDescent}.
 *
 * <p>Gradients handed to {@link #updateWeight(int, int, float, boolean)} are not applied
 * immediately. They are summed into a separate {@link #gradients} vector and the touched
 * positions are remembered. A call to {@link #updateMiniBatch()} then refreshes the
 * per-position diagonal values, applies the summed gradients to the weights, and resets the
 * accumulator.
 *
 * <p>Subclasses supply {@link #getDiagonal(float, float)}, which decides how a stored diagonal
 * value is combined with a summed gradient.
 *
 * <p>NOTE(review): all four fields declared here are {@code transient} and are only initialised
 * in the constructor of this class; confirm how instances restored through Java serialization
 * are expected to get them back.
 */
public abstract class AdaptiveGradientDescentMiniBatch extends AdaptiveGradientDescent {
    private static final long serialVersionUID = -9070887527388228842L;

    /** Gradient accumulator, created as a zero vector from the weight vector given to the constructor. */
    protected transient WeightVector gradients;

    /** Positions of the sparse gradient vector that received a gradient since the last flush. */
    protected transient IntSet sparse_updated_indices;

    /** Positions of the dense gradient vector that received a gradient since the last flush. */
    protected transient IntSet dense_updated_indices;

    /** Starts at 1 and grows by one on every {@link #updateMiniBatch()} call. */
    protected transient int batch_steps;

    /**
     * Creates an optimizer without a regularizer (delegates with {@code null}).
     *
     * @param weightVector the weight vector to optimize
     * @param f            forwarded unchanged to the superclass constructor
     *                     (presumably a hyper-parameter such as the learning rate; TODO confirm)
     * @param f2           forwarded unchanged to the superclass constructor (meaning not visible here)
     */
    public AdaptiveGradientDescentMiniBatch(WeightVector weightVector, float f, float f2) {
        this(weightVector, f, f2, null);
    }

    /**
     * Creates an optimizer and sets up the empty mini-batch state.
     *
     * @param weightVector the weight vector to optimize; also used to create the zero-valued
     *                     gradient accumulator
     * @param f            forwarded unchanged to the superclass constructor
     * @param f2           forwarded unchanged to the superclass constructor
     * @param regularizer  forwarded unchanged to the superclass constructor; {@code null} is
     *                     what the three-argument constructor passes
     */
    public AdaptiveGradientDescentMiniBatch(WeightVector weightVector, float f, float f2, Regularizer regularizer) {
        super(weightVector, f, f2, regularizer);
        batch_steps = 1;
        gradients = weightVector.createZeroVector();
        sparse_updated_indices = new IntOpenHashSet();
        dense_updated_indices = new IntOpenHashSet();
    }

    /**
     * Expands the optimizer state and, only when the superclass reports that it expanded,
     * expands the gradient accumulator with the same three arguments.
     *
     * <p>NOTE(review): the decompiler reported that this method's access modifier was changed
     * from {@code protected}; it is kept {@code public} here so the signature stays as found.
     *
     * @param i  first size argument, forwarded as is
     * @param i2 second size argument, forwarded as is
     * @param i3 third size argument, forwarded as is
     * @return whatever {@code super.expand(i, i2, i3)} returned
     */
    @Override
    public boolean expand(int i, int i2, int i3) {
        if (!super.expand(i, i2, i3)) {
            return false;
        }
        gradients.expand(i, i2, i3);
        return true;
    }

    /**
     * Records a gradient instead of applying it: the value is added to the accumulator and
     * the affected position is remembered for the next {@link #updateMiniBatch()}.
     *
     * @param i  first coordinate passed to {@code MajorVector.indexOf}
     * @param i2 second coordinate passed to {@code MajorVector.indexOf}
     * @param f  the gradient value to add
     * @param z  {@code true} selects the sparse vector and index set, {@code false} the dense ones
     */
    @Override
    protected void updateWeight(int i, int i2, float f, boolean z) {
        final MajorVector accumulator = gradients.getMajorVector(z);
        final int position = accumulator.indexOf(i, i2);
        accumulator.add(position, f);
        (z ? sparse_updated_indices : dense_updated_indices).add(position);
    }

    /**
     * Flushes the accumulated gradients, sparse part first and dense part second, then
     * advances {@link #batch_steps}.
     */
    @Override
    public void updateMiniBatch() {
        update(true);
        update(false);
        batch_steps += 1;
    }

    /**
     * Flushes one half (sparse or dense) of the accumulated gradients.
     *
     * <p>The touched positions are copied to an array and sorted ascending; in that order the
     * diagonals are refreshed, the weights are updated, and the accumulator entries are zeroed.
     * Finally the set of touched positions is emptied.
     *
     * @param z {@code true} to flush the sparse vectors, {@code false} for the dense ones
     */
    protected void update(boolean z) {
        final IntSet touched;
        if (z) {
            touched = sparse_updated_indices;
        } else {
            touched = dense_updated_indices;
        }

        final MajorVector weights = weight_vector.getMajorVector(z);
        final MajorVector diagonalValues = diagonals.getMajorVector(z);
        final MajorVector accumulated = gradients.getMajorVector(z);

        final int[] positions = touched.toIntArray();
        Arrays.sort(positions);

        updateDiagonals(diagonalValues, accumulated, positions);
        updateWeights(weights, accumulated, positions, z);
        clearGraidents(accumulated, positions);
        touched.clear();
    }

    /**
     * Replaces each listed entry of {@code majorVector} with
     * {@code getDiagonal(oldEntry, majorVector2 entry)}.
     *
     * @param majorVector  vector that is rewritten ({@link #update(boolean)} passes the diagonals)
     * @param majorVector2 vector that is only read ({@link #update(boolean)} passes the gradients)
     * @param iArr         positions to process, in the given order
     */
    protected void updateDiagonals(MajorVector majorVector, MajorVector majorVector2, int[] iArr) {
        for (int k = 0; k < iArr.length; k++) {
            final int position = iArr[k];
            majorVector.set(position, getDiagonal(majorVector.get(position), majorVector2.get(position)));
        }
    }

    /**
     * Applies the accumulated gradient at every listed position.
     *
     * <p>For each position: without L1 regularization the gradient times the learning rate is
     * added to {@code majorVector}; with L1 regularization the update is delegated to
     * {@code l1_regularizer}, which additionally receives the current {@link #batch_steps}.
     *
     * @param majorVector  vector that receives the plain (non-L1) update
     *                     ({@link #update(boolean)} passes the weights)
     * @param majorVector2 vector the gradients are read from
     * @param iArr         positions to process, in the given order
     * @param z            sparse/dense flag forwarded to {@code getLearningRate} and the regularizer
     */
    protected void updateWeights(MajorVector majorVector, MajorVector majorVector2, int[] iArr, boolean z) {
        for (int k = 0; k < iArr.length; k++) {
            final int position = iArr[k];
            if (!isL1Regularization()) {
                majorVector.add(position, majorVector2.get(position) * getLearningRate(position, z));
            } else {
                l1_regularizer.updateWeight(position, majorVector2.get(position), getLearningRate(position, z), batch_steps, z);
            }
        }
    }

    /**
     * Sets every listed entry of {@code majorVector} to {@code 0.0f}.
     *
     * @param majorVector vector to reset ({@link #update(boolean)} passes the gradients)
     * @param iArr        positions to reset
     */
    protected void clearGraidents(MajorVector majorVector, int[] iArr) {
        for (int k = 0; k < iArr.length; k++) {
            majorVector.set(iArr[k], 0.0f);
        }
    }

    /**
     * Computes the new diagonal value for one position.
     *
     * @param f  the diagonal value currently stored at that position
     * @param f2 the accumulated gradient at that position
     * @return the value to store as the new diagonal
     */
    protected abstract float getDiagonal(float f, float f2);
}
