package edu.emory.mathcs.nlp.learning.optimization.method;

import edu.emory.mathcs.nlp.learning.activation.SoftmaxFunction;
import edu.emory.mathcs.nlp.learning.optimization.AdaptiveGradientDescent;
import edu.emory.mathcs.nlp.learning.util.FeatureVector;
import edu.emory.mathcs.nlp.learning.util.Instance;
import edu.emory.mathcs.nlp.learning.util.SparseItem;
import edu.emory.mathcs.nlp.learning.util.WeightVector;
import java.util.Iterator;

/* loaded from: input_file:edu/emory/mathcs/nlp/learning/optimization/method/AdaGradRegression.class */
/**
 * AdaGrad optimizer that drives the "regression" training path of
 * {@link AdaptiveGradientDescent}.
 *
 * <p>For each training instance it computes regression gradients, accumulates them into the
 * adaptive diagonal for every (label, feature) pair the instance touches, and then applies the
 * regression weight update with those same gradients. If the weight vector has no activation
 * function yet, a {@link SoftmaxFunction} is installed at construction time.
 */
public class AdaGradRegression extends AdaptiveGradientDescent {
    private static final long serialVersionUID = 6397042389113367031L;

    /**
     * Creates the optimizer over the given weight vector.
     *
     * @param weightVector the weights to train; a {@link SoftmaxFunction} activation is attached
     *     only when the vector does not already have one
     * @param learningRate forwarded unchanged as the superclass's second constructor argument
     *     (presumably the learning rate — confirm against AdaptiveGradientDescent)
     * @param bias forwarded unchanged as the superclass's third constructor argument
     *     (presumably the bias — confirm against AdaptiveGradientDescent)
     */
    public AdaGradRegression(WeightVector weightVector, float learningRate, float bias) {
        super(weightVector, learningRate, bias, null);
        // Keep any activation function the caller configured; only fill in a default.
        if (!weightVector.hasActivationFunction()) {
            weightVector.setActivationFunction(new SoftmaxFunction());
        }
    }

    /**
     * Accumulates {@code gradients[label] * featureValue} into the adaptive diagonal of every
     * (label, feature) pair present in the instance: sparse features first, then dense ones.
     *
     * @param instance  the training instance whose features are visited
     * @param gradients one gradient per label, indexed by label
     */
    private void updateDiagonals(Instance instance, float[] gradients) {
        final FeatureVector features = instance.getFeatureVector();

        // Sparse features carry their own index; the trailing flag is passed as true for them.
        for (Iterator<SparseItem> items = features.getSparseVector().iterator(); items.hasNext(); ) {
            final SparseItem item = items.next();
            for (int label = 0; label < gradients.length; label++) {
                updateDiagonal(label, item.getIndex(), gradients[label] * item.getValue(), true);
            }
        }

        if (!features.hasDenseVector()) {
            return;
        }

        // Dense features use their array position as the index; the flag is passed as false.
        // NOTE(review): the flag's exact meaning is defined by updateDiagonal — presumably
        // sparse vs. dense indexing; confirm in AdaptiveGradientDescent.
        final float[] dense = features.getDenseVector();
        for (int label = 0; label < gradients.length; label++) {
            final float gradient = gradients[label];
            for (int position = 0; position < dense.length; position++) {
                updateDiagonal(label, position, gradient * dense[position], false);
            }
        }
    }

    /**
     * Trains on a single instance: computes the regression gradients, folds them into the
     * adaptive diagonals, and then applies the regression weight update with the same gradients.
     *
     * @param instance the training instance
     */
    @Override
    public void trainAux(Instance instance) {
        final float[] gradients = getGradientsRegression(instance);
        updateDiagonals(instance, gradients);
        trainRegression(instance, gradients);
    }

    /**
     * Predicts a label for the instance by delegating to the superclass's regression predictor.
     *
     * @param instance the instance to label
     * @return the label returned by {@code getPredictedLabelRegression}
     */
    @Override
    public int getPredictedLabel(Instance instance) {
        return getPredictedLabelRegression(instance);
    }

    /**
     * Intentionally does nothing: this optimizer applies its updates per instance in
     * {@link #trainAux(Instance)}.
     */
    @Override
    public void updateMiniBatch() {
    }

    /** Returns the human-readable name of this optimizer. */
    @Override
    public String toString() {
        return "AdaGrad Regression";
    }
}
