package edu.emory.mathcs.nlp.learning.util;

import edu.emory.mathcs.nlp.common.util.DSUtils;
import it.unimi.dsi.fastutil.ints.IntCollection;
import it.unimi.dsi.fastutil.ints.IntIterator;
import org.apache.commons.math3.util.FastMath;

/* loaded from: input_file:edu/emory/mathcs/nlp/learning/util/MLUtils.class */
public class MLUtils {
    /**
     * Converts {@code fArr} in place into a probability distribution using the softmax function.
     *
     * <p>The maximum score is subtracted before exponentiation. This is mathematically equivalent
     * ({@code softmax(x) == softmax(x - c)}) but keeps {@code exp} from overflowing to
     * {@code Infinity} for large scores, which would otherwise turn the whole output into
     * {@code NaN}. The normalizing sum is accumulated in {@code double} to reduce rounding error.
     *
     * @param fArr scores to normalize; overwritten with the softmax probabilities. An empty array is
     *     left unchanged.
     */
    public static void softmax(float[] fArr) {
        if (fArr.length == 0) {
            return;
        }

        float max = fArr[0];
        for (int i = 1; i < fArr.length; i++) {
            if (max < fArr[i]) {
                max = fArr[i];
            }
        }
        // Shifting by an infinite max would yield Inf - Inf = NaN; in that degenerate case keep the
        // unshifted computation so such inputs behave as before.
        float shift = Float.isInfinite(max) ? 0f : max;

        double sum = 0;
        for (int i = 0; i < fArr.length; i++) {
            fArr[i] = (float) FastMath.exp(fArr[i] - shift);
            sum += fArr[i];
        }
        for (int i = 0; i < fArr.length; i++) {
            fArr[i] = (float) (fArr[i] / sum);
        }
    }

    /**
     * Returns the index of the largest element of {@code fArr}.
     *
     * @param fArr scores; must contain at least one element
     * @return index of the maximum; ties resolve to the lowest index
     * @throws ArrayIndexOutOfBoundsException if {@code fArr} is empty
     */
    public static int argmax(float[] fArr) {
        return argmax(fArr, fArr.length);
    }

    /**
     * Returns the index of the largest element among the first {@code i} elements of {@code fArr}.
     *
     * <p>Index 0 is always considered, so {@code 0} is returned when {@code i <= 1}.
     *
     * @param fArr scores; must contain at least one element
     * @param i number of leading elements to scan
     * @return index of the maximum; ties resolve to the lowest index
     * @throws ArrayIndexOutOfBoundsException if {@code fArr} is empty or {@code i > fArr.length}
     */
    public static int argmax(float[] fArr, int i) {
        int best = 0;
        float bestScore = fArr[0];
        for (int j = 1; j < i; j++) {
            if (bestScore < fArr[j]) {
                best = j;
                bestScore = fArr[j];
            }
        }
        return best;
    }

    /**
     * Returns the index, drawn from {@code intCollection}, whose score in {@code fArr} is largest.
     *
     * <p>The first candidate in iteration order seeds the maximum. A non-empty collection therefore
     * always yields one of its own indices, even when every candidate scores
     * {@code Float.NEGATIVE_INFINITY} or {@code -Float.MAX_VALUE}.
     *
     * @param fArr scores indexed by candidate
     * @param intCollection candidate indices; if {@code null} or empty, all of {@code fArr} is
     *     scanned via {@link #argmax(float[])}
     * @return the best candidate index; ties resolve to the earliest in iteration order
     */
    public static int argmax(float[] fArr, IntCollection intCollection) {
        if (intCollection == null || intCollection.isEmpty()) {
            return argmax(fArr);
        }
        IntIterator it = intCollection.iterator();
        int best = it.nextInt();
        float bestScore = fArr[best];
        while (it.hasNext()) {
            int candidate = it.nextInt();
            if (bestScore < fArr[candidate]) {
                best = candidate;
                bestScore = fArr[candidate];
            }
        }
        return best;
    }

    /**
     * Returns the indices of the two largest elements of {@code fArr}.
     *
     * @param fArr scores
     * @return {@code {best, secondBest}}; see {@link #argmax2(float[], int)}
     */
    public static int[] argmax2(float[] fArr) {
        return argmax2(fArr, fArr.length);
    }

    /**
     * Returns the indices of the two largest elements among the first {@code i} elements of
     * {@code fArr}.
     *
     * @param fArr scores
     * @param i number of leading elements to scan
     * @return {@code {best, secondBest}}; when {@code i < 2} the result is {@code {0, -1}}
     */
    public static int[] argmax2(float[] fArr, int i) {
        if (i < 2) {
            return new int[]{0, -1};
        }
        int[] top = orderedPair(fArr, 0, 1);
        for (int j = 2; j < i; j++) {
            offerTop2(fArr, top, j);
        }
        return top;
    }

    /**
     * Returns the two indices, drawn from {@code intCollection}, whose scores in {@code fArr} are
     * largest.
     *
     * @param fArr scores indexed by candidate
     * @param intCollection candidate indices; if {@code null} or empty, all of {@code fArr} is
     *     scanned via {@link #argmax2(float[])}
     * @return {@code {best, secondBest}}; when the collection holds a single index {@code k} the
     *     result is {@code {k, -1}}
     */
    public static int[] argmax2(float[] fArr, IntCollection intCollection) {
        if (intCollection == null || intCollection.isEmpty()) {
            return argmax2(fArr);
        }
        IntIterator it = intCollection.iterator();
        if (intCollection.size() < 2) {
            return new int[]{it.nextInt(), -1};
        }
        int first = it.nextInt();
        int second = it.nextInt();
        int[] top = orderedPair(fArr, first, second);
        while (it.hasNext()) {
            offerTop2(fArr, top, it.nextInt());
        }
        return top;
    }

    /** Returns {@code {a, b}} ordered so the higher-scoring index comes first (ties keep {@code a} first). */
    private static int[] orderedPair(float[] scores, int a, int b) {
        return scores[a] < scores[b] ? new int[]{b, a} : new int[]{a, b};
    }

    /** Inserts {@code candidate} into the running {@code {best, secondBest}} pair if it beats either. */
    private static void offerTop2(float[] scores, int[] top, int candidate) {
        if (scores[top[0]] < scores[candidate]) {
            top[1] = top[0];
            top[0] = candidate;
        } else if (scores[top[1]] < scores[candidate]) {
            top[1] = candidate;
        }
    }
}
