package edu.emory.mathcs.nlp.learning.optimization;

import edu.emory.mathcs.nlp.common.util.MathUtils;
import edu.emory.mathcs.nlp.component.template.train.HyperParameter;
import edu.emory.mathcs.nlp.learning.optimization.reguralization.Regularizer;
import edu.emory.mathcs.nlp.learning.util.FeatureVector;
import edu.emory.mathcs.nlp.learning.util.Instance;
import edu.emory.mathcs.nlp.learning.util.LabelMap;
import edu.emory.mathcs.nlp.learning.util.MLUtils;
import edu.emory.mathcs.nlp.learning.util.SparseVector;
import edu.emory.mathcs.nlp.learning.util.WeightVector;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.StringJoiner;

/* loaded from: input_file:edu/emory/mathcs/nlp/learning/optimization/OnlineOptimizer.class */
public abstract class OnlineOptimizer implements Serializable {
    private static final long serialVersionUID = -7750497048585331648L;

    /** Model weights that this optimizer scores with and that subclasses update. */
    protected WeightVector weight_vector;

    /** Mapping between string labels and their integer indices (both directions are used below). */
    protected LabelMap label_map;

    /** Bias value injected into every feature vector by {@link #augment(FeatureVector)}. */
    protected float bias;

    // The fields below are transient: on a deserialized optimizer they hold Java defaults
    // (null / 0.0f / 0) until reset. adapt(HyperParameter) restores the regularizer and the
    // learning rate; nothing in this class resets `steps`.
    // NOTE(review): subclasses that divide by `steps` should confirm it is re-initialised after
    // deserialization.

    /** Optional L1 regularizer; {@code null} disables L1 regularization. */
    protected transient Regularizer l1_regularizer;

    /** Base learning rate. */
    protected transient float learning_rate;

    /** Step counter; starts at 1 and is incremented once per {@link #train(Instance, boolean)} call. */
    protected transient int steps;

    /**
     * Creates an optimizer without L1 regularization.
     *
     * @param weightVector the weight vector to score with and train
     * @param learningRate the base learning rate
     * @param bias         the bias value injected into every feature vector
     */
    public OnlineOptimizer(WeightVector weightVector, float learningRate, float bias) {
        this(weightVector, learningRate, bias, null);
    }

    /**
     * Creates an optimizer with an optional L1 regularizer.
     *
     * @param weightVector the weight vector to score with and train
     * @param learningRate the base learning rate
     * @param bias         the bias value injected into every feature vector
     * @param regularizer  the L1 regularizer, or {@code null} to disable L1 regularization
     */
    public OnlineOptimizer(WeightVector weightVector, float learningRate, float bias, Regularizer regularizer) {
        this.label_map = new LabelMap();
        // The weights must be set before the regularizer, because setL1Regularizer binds the
        // regularizer to the current weight vector.
        setWeightVector(weightVector);
        setBias(bias);
        setLearningRate(learningRate);
        setL1Regularizer(regularizer);
        this.steps = 1;
    }

    /**
     * Re-applies the L1 regularizer and learning rate taken from the given hyper-parameters.
     * Because both fields are transient, this can also be used to restore them on a
     * deserialized optimizer.
     *
     * @param hyperParameter source of the new regularizer and learning rate
     */
    public void adapt(HyperParameter hyperParameter) {
        setL1Regularizer(hyperParameter.getL1Regularizer());
        setLearningRate(hyperParameter.getLearningRate());
    }

    /** @return the weight vector this optimizer scores with and trains */
    public WeightVector getWeightVector() {
        return this.weight_vector;
    }

    /**
     * Replaces the weight vector. If an L1 regularizer is attached, it is rebound to the new
     * vector so that it does not keep regularizing the previous one.
     *
     * @param weightVector the new weight vector
     */
    public void setWeightVector(WeightVector weightVector) {
        this.weight_vector = weightVector;
        if (isL1Regularization()) {
            this.l1_regularizer.setWeightVector(weightVector);
        }
    }

    /** @return the base learning rate */
    public float getLearningRate() {
        return this.learning_rate;
    }

    /** @param learningRate the new base learning rate */
    public void setLearningRate(float learningRate) {
        this.learning_rate = learningRate;
    }

    /** @return the bias value injected into feature vectors */
    public float getBias() {
        return this.bias;
    }

    /** @param bias the bias value to inject into feature vectors from now on */
    public void setBias(float bias) {
        this.bias = bias;
    }

    /** @return the attached L1 regularizer, or {@code null} if L1 regularization is disabled */
    public Regularizer getL1Regularizer() {
        return this.l1_regularizer;
    }

    /**
     * Attaches (or, with {@code null}, detaches) an L1 regularizer and binds it to the current
     * weight vector.
     *
     * @param regularizer the regularizer, or {@code null} to disable L1 regularization
     */
    public void setL1Regularizer(Regularizer regularizer) {
        this.l1_regularizer = regularizer;
        if (isL1Regularization()) {
            this.l1_regularizer.setWeightVector(this.weight_vector);
        }
    }

    /** @return {@code true} if an L1 regularizer is attached */
    public boolean isL1Regularization() {
        return this.l1_regularizer != null;
    }

    /** @param labelMap the label map to use from now on */
    public void setLabelMap(LabelMap labelMap) {
        this.label_map = labelMap;
    }

    /** @return the label map */
    public LabelMap getLabelMap() {
        return this.label_map;
    }

    /**
     * @param index a label index
     * @return the string label stored at {@code index} in the label map
     */
    public String getLabel(int index) {
        return this.label_map.getLabel(index);
    }

    /**
     * @param label a string label
     * @return the index the label map reports for {@code label}
     */
    public int getLabelIndex(String label) {
        return this.label_map.index(label);
    }

    /**
     * Maps each label to its index, in the collection's iteration order.
     *
     * @param collection the labels to look up
     * @return the label indices
     */
    public int[] getLabelIndexArray(Collection<String> collection) {
        return collection.stream().mapToInt(this::getLabelIndex).toArray();
    }

    /** @return the number of labels in the label map */
    public int getLabelSize() {
        return this.label_map.size();
    }

    /**
     * Adds a label to the label map.
     *
     * @param label the label to add
     * @return the value returned by the label map's {@code add} (used as the gold label index)
     */
    public int addLabel(String label) {
        return this.label_map.add(label);
    }

    /**
     * Adds every label in the collection to the label map.
     *
     * @param collection the labels to add
     */
    public void addLabels(Collection<String> collection) {
        for (String label : collection) {
            addLabel(label);
        }
    }

    /**
     * Trains on one instance, augmenting it with the bias (and registering its string label) first.
     *
     * @param instance the training instance
     */
    public void train(Instance instance) {
        train(instance, true);
    }

    /**
     * Trains on one instance. The steps are:
     * <ol>
     *   <li>optionally augment the instance;</li>
     *   <li>grow the weights to fit its features and the current label set;</li>
     *   <li>compute or accumulate its scores;</li>
     *   <li>record the predicted label;</li>
     *   <li>delegate to {@link #trainAux(Instance)} only when the prediction is not a gold label.</li>
     * </ol>
     * The step counter is incremented on every call.
     *
     * @param instance        the training instance
     * @param augmentInstance whether to call {@link #augment(Instance)} first
     */
    public void train(Instance instance, boolean augmentInstance) {
        if (augmentInstance) {
            augment(instance);
        }
        expand(instance.getFeatureVector());
        // Reuse existing scores only if they cover exactly the current label set.
        if (instance.hasScores() && instance.getScores().length == getLabelSize()) {
            addScores(instance.getFeatureVector(), instance.getScores());
        } else {
            // NOTE(review): scores(FeatureVector) augments the vector again, so when
            // augmentInstance is true the bias is added twice on this branch (once on the other).
            // Confirm whether SparseVector.addBias is idempotent before relying on this.
            instance.setScores(scores(instance.getFeatureVector()));
        }
        int predictedLabel = getPredictedLabel(instance);
        instance.setPredictedLabel(predictedLabel);
        if (!instance.isGoldLabel(predictedLabel)) {
            trainAux(instance);
        }
        this.steps++;
    }

    /**
     * Registers the instance's string label (if any) as its gold label, then injects the bias
     * into its feature vector.
     *
     * @param instance the instance to augment in place
     */
    public void augment(Instance instance) {
        if (instance.hasStringLabel()) {
            instance.setGoldLabel(addLabel(instance.getStringLabel()));
        }
        augment(instance.getFeatureVector());
    }

    /**
     * Injects the bias into the feature vector's sparse part, creating a sparse vector holding
     * only the bias if there is none yet.
     *
     * @param featureVector the feature vector to modify in place
     */
    public void augment(FeatureVector featureVector) {
        if (featureVector.hasSparseVector()) {
            featureVector.getSparseVector().addBias(this.bias);
        } else {
            featureVector.setSparseVector(new SparseVector(this.bias));
        }
    }

    /**
     * Grows the weights (and the L1 regularizer) so they can hold this feature vector's sparse
     * and dense features and the current number of labels.
     *
     * @param featureVector the feature vector whose dimensions must fit
     */
    protected void expand(FeatureVector featureVector) {
        int sparseSize = featureVector.hasSparseVector() ? featureVector.getSparseVector().maxIndex() + 1 : 0;
        int denseSize = featureVector.hasDenseVector() ? featureVector.getDenseVector().length : 0;
        expand(sparseSize, denseSize, getLabelSize());
    }

    /**
     * Expands the weight vector to the given dimensions. The L1 regularizer is expanded too, but
     * only when the weight vector reports that it actually changed.
     * NOTE: the decompiler reports this was originally protected; it is kept public for compatibility.
     *
     * @param sparseFeatureSize required number of sparse features
     * @param denseFeatureSize  required number of dense features
     * @param labelSize         required number of labels
     * @return the value returned by {@code WeightVector.expand} (whether an expansion happened)
     */
    public boolean expand(int sparseFeatureSize, int denseFeatureSize, int labelSize) {
        boolean expanded = this.weight_vector.expand(sparseFeatureSize, denseFeatureSize, labelSize);
        if (expanded && isL1Regularization()) {
            this.l1_regularizer.expand(sparseFeatureSize, denseFeatureSize, labelSize);
        }
        return expanded;
    }

    /**
     * Performs the algorithm-specific update for a mispredicted instance.
     *
     * @param instance the instance, whose scores and predicted label are already set
     */
    protected abstract void trainAux(Instance instance);

    /** Applies any updates accumulated since the last mini-batch (algorithm-specific). */
    public abstract void updateMiniBatch();

    /**
     * Chooses the predicted label for an instance whose scores are already set.
     * NOTE: the decompiler reports this was originally protected.
     *
     * @param instance the scored instance
     * @return the predicted label index
     */
    public abstract int getPredictedLabel(Instance instance);

    /**
     * Hinge-loss prediction: subtracts a margin of 1 from the gold label's score, then takes the
     * argmax. The score array returned by {@code instance.getScores()} is modified in place.
     * NOTE: the decompiler reports this was originally protected.
     *
     * @param instance the scored instance
     * @return the label index with the highest margin-adjusted score
     */
    public int getPredictedLabelHingeLoss(Instance instance) {
        float[] scores = instance.getScores();
        int goldLabel = instance.getGoldLabel();
        scores[goldLabel] -= 1.0f;
        return argmax(scores);
    }

    /**
     * Regression-style prediction: returns the gold label if its score is at least 1, otherwise
     * the argmax of the scores.
     * NOTE: the decompiler reports this was originally protected.
     *
     * @param instance the scored instance
     * @return the predicted label index
     */
    public int getPredictedLabelRegression(Instance instance) {
        float[] scores = instance.getScores();
        int goldLabel = instance.getGoldLabel();
        return scores[goldLabel] >= 1.0f ? goldLabel : argmax(scores);
    }

    /**
     * Regression gradients: a copy of the first {@link #getLabelSize()} scores, negated, with 1
     * added at the gold label (i.e. {@code onehot(gold) - scores}). The instance's scores are not
     * modified.
     * NOTE: the decompiler reports this was originally protected.
     *
     * @param instance the scored instance
     * @return a new gradient array of length {@link #getLabelSize()}
     */
    public float[] getGradientsRegression(Instance instance) {
        float[] gradients = Arrays.copyOf(instance.getScores(), getLabelSize());
        MathUtils.multiply(gradients, -1);
        int goldLabel = instance.getGoldLabel();
        gradients[goldLabel] += 1.0f;
        return gradients;
    }

    /**
     * Algorithm-specific learning rate; the meaning of the arguments is defined by subclasses.
     * NOTE: the decompiler reports this was originally protected.
     */
    public abstract float getLearningRate(int i, boolean z);

    /**
     * Returns {@code MLUtils.argmax(scores, getLabelSize())}. If that winner's score is exactly 0
     * and its index is positive, it returns {@code MLUtils.argmax(scores, winner)} instead
     * (presumably the best label among the lower indices; confirm against MLUtils).
     * NOTE: the decompiler reports this was originally protected.
     *
     * @param scores the label scores
     * @return the selected label index
     */
    public int argmax(float[] scores) {
        int best = MLUtils.argmax(scores, getLabelSize());
        if (scores[best] == 0.0f && best > 0) {
            return MLUtils.argmax(scores, best);
        }
        return best;
    }

    /**
     * Builds a one-line description of the hyper-parameters:
     * {@code "<title>: learning rate = ..., bias = ...[, l1 = ...][, extras...]"}.
     * Null entries in {@code extras}, or a null {@code extras} array, are skipped.
     *
     * @param title  prefix printed before the colon
     * @param extras additional descriptions to append
     * @return the formatted description
     */
    public String toString(String title, String... extras) {
        StringJoiner joiner = new StringJoiner(", ");
        joiner.add("learning rate = " + this.learning_rate);
        joiner.add("bias = " + this.bias);
        if (isL1Regularization()) {
            joiner.add("l1 = " + this.l1_regularizer.getRate());
        }
        if (extras != null) {
            for (String extra : extras) {
                if (extra != null) {
                    joiner.add(extra);
                }
            }
        }
        return title + ": " + joiner;
    }

    /**
     * Scores a feature vector after injecting the bias into it.
     *
     * @param featureVector the feature vector (modified in place by augmentation)
     * @return one score per label, as computed by the weight vector
     */
    public float[] scores(FeatureVector featureVector) {
        return scores(featureVector, true);
    }

    /**
     * Scores a feature vector, optionally injecting the bias first.
     *
     * @param featureVector the feature vector
     * @param augmentVector whether to call {@link #augment(FeatureVector)} first
     * @return one score per label, as computed by the weight vector
     */
    public float[] scores(FeatureVector featureVector, boolean augmentVector) {
        if (augmentVector) {
            augment(featureVector);
        }
        return this.weight_vector.scores(featureVector);
    }

    /**
     * Adds this feature vector's weighted scores into an existing score array.
     *
     * @param featureVector the feature vector
     * @param scores        the score array to accumulate into
     */
    public void addScores(FeatureVector featureVector, float[] scores) {
        this.weight_vector.addScores(featureVector, scores);
    }
}
