package edu.emory.mathcs.nlp.learning.neural;

import edu.emory.mathcs.nlp.learning.activation.ActivationFunction;
import edu.emory.mathcs.nlp.learning.activation.SoftmaxFunction;
import edu.emory.mathcs.nlp.learning.initialization.WeightGenerator;
import edu.emory.mathcs.nlp.learning.util.FeatureVector;
import edu.emory.mathcs.nlp.learning.util.Instance;
import edu.emory.mathcs.nlp.learning.util.MajorVector;
import edu.emory.mathcs.nlp.learning.util.SparseItem;
import java.util.Arrays;
import java.util.Iterator;

/* loaded from: input_file:edu/emory/mathcs/nlp/learning/neural/FeedForwardNeuralNetworkSoftmax.class */
/**
 * Feed-forward neural network whose hidden-to-output activation is a
 * {@link SoftmaxFunction}.
 *
 * <p>All weight updates are applied immediately inside the back-propagation
 * methods, scaled by the constant {@code learning_rate};
 * {@link #updateMiniBatch()} is therefore a no-op.
 *
 * <p>Every update is gated by {@code sampled_thinned_network}: a weight is
 * touched only when the mask entries of both units it connects are set.
 * Hidden/output units are looked up at {@code 1 + unitIndex} in their mask row
 * (NOTE(review): presumably slot 0 is reserved for a bias unit; confirm against
 * the superclass).
 */
public class FeedForwardNeuralNetworkSoftmax extends FeedForwardNeuralNetwork {
    private static final long serialVersionUID = 7122005284712284931L;

    /**
     * Forwards all arguments unchanged to the matching superclass constructor.
     * The meaning of each argument is defined by {@link FeedForwardNeuralNetwork}.
     */
    public FeedForwardNeuralNetworkSoftmax(int[] iArr, ActivationFunction[] activationFunctionArr, float f, float f2, WeightGenerator weightGenerator) {
        super(iArr, activationFunctionArr, f, f2, weightGenerator);
    }

    /**
     * Forwards all arguments unchanged to the matching superclass constructor.
     * The meaning of each argument is defined by {@link FeedForwardNeuralNetwork}.
     */
    public FeedForwardNeuralNetworkSoftmax(int[] iArr, ActivationFunction[] activationFunctionArr, float f, float f2, WeightGenerator weightGenerator, float[] fArr) {
        super(iArr, activationFunctionArr, f, f2, weightGenerator, fArr);
    }

    /** @return a new {@link SoftmaxFunction} used for the hidden-to-output layer. */
    @Override // edu.emory.mathcs.nlp.learning.neural.FeedForwardNeuralNetwork
    protected ActivationFunction createActivationFunctionH2O() {
        return new SoftmaxFunction();
    }

    /** Delegates to {@code getPredictedLabelRegression(instance)}. */
    @Override // edu.emory.mathcs.nlp.learning.optimization.OnlineOptimizer
    protected int getPredictedLabel(Instance instance) {
        return getPredictedLabelRegression(instance);
    }

    /**
     * @param i ignored by this implementation
     * @param z ignored by this implementation
     * @return the constant {@code learning_rate}, regardless of the arguments
     */
    @Override // edu.emory.mathcs.nlp.learning.optimization.OnlineOptimizer
    protected float getLearningRate(int i, boolean z) {
        return this.learning_rate;
    }

    /** Intentionally empty: weights are updated directly during back-propagation. */
    @Override // edu.emory.mathcs.nlp.learning.optimization.OnlineOptimizer
    public void updateMiniBatch() {
    }

    /**
     * Back-propagates from the output layer to the last hidden layer and
     * updates the hidden-to-output weights in place.
     *
     * @param instance the training instance supplying scores and gradients
     * @param fArr     values feeding the output layer, one per hidden unit
     * @return an array of length {@code fArr.length}; entry {@code h} is the sum
     *         over labels of {@code gradient * score * weight}, computed with
     *         the weight value from before this call's update of that weight
     */
    @Override // edu.emory.mathcs.nlp.learning.neural.FeedForwardNeuralNetwork
    protected float[] backwardPropagationO2H(Instance instance, float[] fArr) {
        // The scores are copied BEFORE the gradients are requested; keep this order.
        float[] scores = Arrays.copyOf(instance.getScores(), getLabelSize());
        float[] gradients = getGradientsRegression(instance);

        // Flip the sign in place so the update below can use the same
        // "-learningRate * gradient * input" form as the other layers.
        for (int label = 0; label < gradients.length; label++) {
            gradients[label] = -gradients[label];
        }

        float[] errors = new float[fArr.length];
        MajorVector weights = this.w_h2o.getDenseWeightVector();

        for (int label = 0; label < gradients.length; label++) {
            for (int hidden = 0; hidden < fArr.length; hidden++) {
                // Skip hidden units that are masked out in the last mask row.
                if (!this.sampled_thinned_network[this.sampled_thinned_network.length - 1][1 + hidden]) {
                    continue;
                }
                int index = weights.indexOf(label, hidden);
                // Accumulate the error with the weight value before it is updated.
                errors[hidden] += gradients[label] * scores[label] * weights.get(index);
                weights.add(index, -getLearningRate(index, false) * gradients[label] * fArr[hidden]);
            }
        }
        return errors;
    }

    /**
     * Back-propagates between two hidden layers and updates the connecting
     * weights in place.
     *
     * @param majorVector weights connecting the two layers
     * @param fArr        gradients of the upper layer, one per upper unit
     * @param fArr2       values of the lower layer, one per lower unit
     * @param fArr3       unused by this implementation
     * @param i           layer offset; mask row {@code i + 1} is used for the
     *                    lower units and row {@code i + 2} for the upper units
     * @return an array of length {@code fArr2.length}; entry {@code l} is the sum
     *         over upper units of {@code gradient * weight}, computed with the
     *         weight value from before this call's update of that weight
     */
    @Override // edu.emory.mathcs.nlp.learning.neural.FeedForwardNeuralNetwork
    protected float[] backwardPropagationH2H(MajorVector majorVector, float[] fArr, float[] fArr2, float[] fArr3, int i) {
        float[] errors = new float[fArr2.length];

        for (int upper = 0; upper < fArr.length; upper++) {
            for (int lower = 0; lower < fArr2.length; lower++) {
                // Both endpoints of the connection must be active.
                if (!this.sampled_thinned_network[i + 1][1 + lower] || !this.sampled_thinned_network[i + 2][1 + upper]) {
                    continue;
                }
                int index = majorVector.indexOf(upper, lower);
                // Accumulate the error with the weight value before it is updated.
                errors[lower] += fArr[upper] * majorVector.get(index);
                majorVector.add(index, -getLearningRate(index, false) * fArr[upper] * fArr2[lower]);
            }
        }
        return errors;
    }

    /**
     * Updates the input-to-first-hidden-layer weights in place, for both the
     * sparse and the dense part of the input feature vector.
     *
     * <p>Both parts apply the same rule,
     * {@code weight += -learningRate * gradient * featureValue}, matching the
     * updates performed for the other layers of this class.
     *
     * @param featureVector the input features
     * @param fArr          gradients of the first hidden layer, one per unit
     * @param fArr2         unused by this implementation
     */
    @Override // edu.emory.mathcs.nlp.learning.neural.FeedForwardNeuralNetwork
    protected void backwardPropagationH2I(FeatureVector featureVector, float[] fArr, float[] fArr2) {
        if (featureVector.hasSparseVector()) {
            MajorVector weights = this.weight_vector.getSparseWeightVector();
            Iterator<SparseItem> items = featureVector.getSparseVector().iterator();
            while (items.hasNext()) {
                SparseItem item = items.next();
                for (int hidden = 0; hidden < fArr.length; hidden++) {
                    if (!this.sampled_thinned_network[0][item.getIndex()] || !this.sampled_thinned_network[1][1 + hidden]) {
                        continue;
                    }
                    int index = weights.indexOf(hidden, item.getIndex());
                    // Descend along the gradient, scaled by the learning rate,
                    // exactly as the dense branch below does with the same gradients.
                    weights.add(index, -getLearningRate(index, false) * fArr[hidden] * item.getValue());
                }
            }
        }
        if (featureVector.hasDenseVector()) {
            MajorVector weights = this.weight_vector.getDenseWeightVector();
            float[] dense = featureVector.getDenseVector();
            for (int hidden = 0; hidden < fArr.length; hidden++) {
                for (int d = 0; d < dense.length; d++) {
                    // Dense features are looked up in the input mask right after
                    // the sparse ones.
                    // NOTE(review): getSparseVector() is dereferenced here without
                    // checking hasSparseVector(); confirm it is never null for
                    // dense-only inputs.
                    if (!this.sampled_thinned_network[0][featureVector.getSparseVector().maxIndex() + 1 + d] || !this.sampled_thinned_network[1][1 + hidden]) {
                        continue;
                    }
                    int index = weights.indexOf(hidden, d);
                    weights.add(index, -getLearningRate(index, false) * fArr[hidden] * dense[d]);
                }
            }
        }
    }
}
