package org.deeplearning4j.eval;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.deeplearning4j.eval.ROC;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Conditions;

/* loaded from: input_file:org/deeplearning4j/eval/ROCMultiClass.class */
/**
 * ROC evaluation for multi-class classifiers, treating each class (column) independently as a
 * binary "this class vs. the rest" problem.
 *
 * <p>For every class, true-positive and false-positive counts are accumulated at
 * {@code thresholdSteps + 1} evenly spaced thresholds {@code 0, 1/thresholdSteps, ..., ~1}. From
 * these counts, per-class ROC curves, AUC values and precision/recall curves are derived.
 */
public class ROCMultiClass extends BaseEvaluation<ROCMultiClass> {

    /** Number of threshold intervals; {@code thresholdSteps + 1} thresholds are evaluated. */
    private final int thresholdSteps;
    /** Per-class number of positive labels seen so far; {@code null} until data is collected. */
    private long[] countActualPositive;
    /** Per-class number of negative labels seen so far; {@code null} until data is collected. */
    private long[] countActualNegative;
    /** Class index -> (threshold -> TP/FP counts); thresholds are inserted in ascending order. */
    private final Map<Integer, Map<Double, ROC.CountsForThreshold>> counts = new LinkedHashMap<>();

    /**
     * @param thresholdSteps number of threshold intervals in [0, 1]; must be positive
     * @throws IllegalArgumentException if {@code thresholdSteps <= 0}
     */
    public ROCMultiClass(int thresholdSteps) {
        if (thresholdSteps <= 0) {
            // 0 would give a threshold step of 1/0 = Infinity (NaN threshold); negative values
            // would produce no thresholds at all.
            throw new IllegalArgumentException("thresholdSteps must be > 0; got " + thresholdSteps);
        }
        this.thresholdSteps = thresholdSteps;
    }

    /**
     * Accumulates per-class ROC counts for one batch.
     *
     * @param labels      rank 2 array [numExamples, numClasses], or rank 3 time series (delegated
     *                    to the inherited {@code evalTimeSeries}). Each column is expected to hold 0/1
     *                    labels, since positives are counted as the column sum and negatives as
     *                    {@code length - sum}.
     * @param predictions array of the same shape holding per-class scores, presumably
     *                    probabilities in [0, 1] (the thresholds span that range)
     * @throws IllegalArgumentException if the shapes are invalid or mismatched, or the number of
     *                                  classes differs from earlier calls
     */
    @Override
    public void eval(INDArray labels, INDArray predictions) {
        if (labels.rank() == 3 && predictions.rank() == 3) {
            // evalTimeSeries is inherited (not visible here); presumably it reshapes to 2d and
            // re-enters eval. Without this return the rank check below would always throw.
            evalTimeSeries(labels, predictions);
            return;
        }
        if (labels.rank() > 2 || predictions.rank() > 2
                || labels.size(0) != predictions.size(0) || labels.size(1) != predictions.size(1)) {
            throw new IllegalArgumentException("Invalid input data shape: labels shape = "
                    + Arrays.toString(labels.shape()) + ", predictions shape = "
                    + Arrays.toString(predictions.shape())
                    + "; require rank 2 arrays of identical shape [numExamples, numClasses]");
        }

        int numClasses = labels.size(1);
        if (countActualPositive == null) {
            initCounts(numClasses);
        }
        if (countActualPositive.length != numClasses) {
            throw new IllegalArgumentException("Cannot evaluate data: number of label classes does not match previous call. Got " + labels.size(1) + " labels (from array shape " + Arrays.toString(labels.shape()) + ") vs. expected number of label classes = " + countActualPositive.length);
        }

        // Must match the threshold computation in initCounts exactly: thresholds are map keys.
        double step = 1.0 / thresholdSteps;
        for (int c = 0; c < numClasses; c++) {
            INDArray label = labels.getColumn(c);
            INDArray prob = predictions.getColumn(c);

            long positives = label.sumNumber().longValue();
            countActualPositive[c] += positives;
            countActualNegative[c] += label.length() - positives;

            // 1 where the example is a negative for this class; independent of the threshold.
            INDArray isNegative = label.rsub(1.0);
            Map<Double, ROC.CountsForThreshold> classCounts = counts.get(c);
            for (int t = 0; t <= thresholdSteps; t++) {
                double threshold = t * step;
                INDArray predictedPositive = binarize(prob, threshold);
                int truePositives = predictedPositive.mul(label).sumNumber().intValue();
                int falsePositives = predictedPositive.mul(isNegative).sumNumber().intValue();

                ROC.CountsForThreshold thresholdCounts = classCounts.get(threshold);
                thresholdCounts.incrementTruePositive(truePositives);
                thresholdCounts.incrementFalsePositive(falsePositives);
            }
        }
    }

    /** Allocates the per-class label counters and the (class -> threshold -> counts) maps. */
    private void initCounts(int numClasses) {
        countActualPositive = new long[numClasses];
        countActualNegative = new long[numClasses];
        double step = 1.0 / thresholdSteps;
        for (int c = 0; c < numClasses; c++) {
            Map<Double, ROC.CountsForThreshold> classCounts = new LinkedHashMap<>();
            for (int t = 0; t <= thresholdSteps; t++) {
                double threshold = t * step;
                classCounts.put(threshold, new ROC.CountsForThreshold(threshold));
            }
            counts.put(c, classCounts);
        }
    }

    /**
     * Converts scores into 0/1 predictions for one threshold, working on a copy of {@code prob}.
     *
     * <p>Two CompareAndSet passes are applied. First, entries {@code >= threshold} are set to 1.0.
     * Then entries {@code <= threshold} are set to 0.0. For thresholds below 1.0 the result is 1.0
     * exactly where {@code prob >= threshold}. For a threshold {@code >= 1.0} the second pass also
     * zeroes the 1.0 entries, so nothing is predicted positive.
     */
    private static INDArray binarize(INDArray prob, double threshold) {
        Op setPositive = new CompareAndSet(prob.dup(), 1.0, Conditions.greaterThanOrEqual(threshold));
        INDArray partial = Nd4j.getExecutioner().execAndReturn(setPositive);
        Op setNegative = new CompareAndSet(partial, 0.0, Conditions.lessThanOrEqual(threshold));
        return Nd4j.getExecutioner().execAndReturn(setNegative);
    }

    /**
     * Returns the ROC points (threshold, TPR, FPR) for one class, in ascending threshold order.
     * Rates are NaN when the class had no actual positives (TPR) or no actual negatives (FPR).
     *
     * @param classIdx class (column) index
     * @throws IllegalStateException    if no data has been collected
     * @throws IllegalArgumentException if {@code classIdx} is out of range
     */
    public List<ROC.ROCValue> getResults(int classIdx) {
        assertHasBeenFit(classIdx);
        Map<Double, ROC.CountsForThreshold> classCounts = counts.get(classIdx);
        List<ROC.ROCValue> out = new ArrayList<>(classCounts.size());
        for (Map.Entry<Double, ROC.CountsForThreshold> entry : classCounts.entrySet()) {
            ROC.CountsForThreshold c = entry.getValue();
            // Divide in double: long/long would truncate every rate to 0 or 1.
            double tpr = c.getCountTruePositive() / (double) countActualPositive[classIdx];
            double fpr = c.getCountFalsePositive() / (double) countActualNegative[classIdx];
            out.add(new ROC.ROCValue(entry.getKey().doubleValue(), tpr, fpr));
        }
        return out;
    }

    /**
     * Returns the ROC curve for one class as {@code [2][thresholdSteps + 1]}. Row 0 holds the
     * false positive rates and row 1 the true positive rates, in ascending threshold order.
     *
     * @param classIdx class (column) index
     * @throws IllegalStateException    if no data has been collected
     * @throws IllegalArgumentException if {@code classIdx} is out of range
     */
    public double[][] getResultsAsArray(int classIdx) {
        assertHasBeenFit(classIdx);
        double[][] out = new double[2][thresholdSteps + 1];
        int idx = 0;
        for (ROC.CountsForThreshold c : counts.get(classIdx).values()) {
            out[0][idx] = c.getCountFalsePositive() / (double) countActualNegative[classIdx];
            out[1][idx] = c.getCountTruePositive() / (double) countActualPositive[classIdx];
            idx++;
        }
        return out;
    }

    /**
     * Computes the area under the ROC curve for one class using the trapezoidal rule over the
     * threshold points. The result is NaN if the class had no positives or no negatives.
     *
     * @param classIdx class (column) index
     * @throws IllegalStateException    if no data has been collected
     * @throws IllegalArgumentException if {@code classIdx} is out of range
     */
    public double calculateAUC(int classIdx) {
        List<ROC.ROCValue> points = getResults(classIdx); // validates classIdx
        double auc = 0.0;
        for (int p = 1; p < points.size(); p++) {
            ROC.ROCValue prev = points.get(p - 1);
            ROC.ROCValue curr = points.get(p);
            // FPR decreases as the threshold increases, hence abs() for the trapezoid width.
            double width = Math.abs(curr.getFalsePositiveRate() - prev.getFalsePositiveRate());
            double avgHeight = (prev.getTruePositiveRate() + curr.getTruePositiveRate()) / 2.0;
            auc += width * avgHeight;
        }
        return auc;
    }

    /**
     * Returns the unweighted mean of {@link #calculateAUC(int)} over all classes.
     *
     * @throws IllegalStateException if no data has been collected
     */
    public double calculateAverageAUC() {
        assertHasBeenFit(0);
        double sum = 0.0;
        for (int c = 0; c < countActualPositive.length; c++) {
            sum += calculateAUC(c);
        }
        return sum / countActualPositive.length;
    }

    /**
     * Returns (threshold, precision, recall) points for one class, in ascending threshold order.
     * Precision is reported as 1.0 when nothing was predicted positive. Recall is reported as 1.0
     * when the class had no actual positives.
     *
     * @param classIdx class (column) index
     * @throws IllegalStateException    if no data has been collected
     * @throws IllegalArgumentException if {@code classIdx} is out of range
     */
    public List<ROC.PrecisionRecallPoint> getPrecisionRecallCurve(int classIdx) {
        assertHasBeenFit(classIdx);
        Map<Double, ROC.CountsForThreshold> classCounts = counts.get(classIdx);
        List<ROC.PrecisionRecallPoint> out = new ArrayList<>(classCounts.size());
        long actualPositives = countActualPositive[classIdx];
        for (ROC.CountsForThreshold c : classCounts.values()) {
            long truePositives = c.getCountTruePositive();
            long falsePositives = c.getCountFalsePositive();
            double precision = (truePositives == 0 && falsePositives == 0)
                    ? 1.0
                    : truePositives / (double) (truePositives + falsePositives);
            double recall = actualPositives == 0 ? 1.0 : truePositives / (double) actualPositives;
            out.add(new ROC.PrecisionRecallPoint(c.getThreshold(), precision, recall));
        }
        return out;
    }

    /**
     * Adds the counts collected by {@code other} into this instance. If this instance is still
     * empty, it receives deep copies of {@code other}'s counts.
     *
     * @param other evaluation to merge in; ignored if it has collected no data
     * @throws IllegalArgumentException if the two instances use a different number of threshold
     *                                  steps or a different number of classes
     */
    @Override
    public void merge(ROCMultiClass other) {
        if (other.countActualPositive == null) {
            return; // other has not collected any data
        }
        if (other.thresholdSteps != thresholdSteps) {
            // Threshold keys would not line up (lookups return null / array sizes disagree).
            throw new IllegalArgumentException("Cannot merge ROCMultiClass instances with different thresholdSteps: "
                    + thresholdSteps + " vs. " + other.thresholdSteps);
        }
        if (countActualPositive == null) {
            countActualPositive = Arrays.copyOf(other.countActualPositive, other.countActualPositive.length);
            countActualNegative = Arrays.copyOf(other.countActualNegative, other.countActualNegative.length);
            for (Map.Entry<Integer, Map<Double, ROC.CountsForThreshold>> classEntry : other.counts.entrySet()) {
                Map<Double, ROC.CountsForThreshold> copy = new LinkedHashMap<>();
                for (Map.Entry<Double, ROC.CountsForThreshold> e : classEntry.getValue().entrySet()) {
                    // NOTE(review): m3375clone() appears to be the decompiler-assigned name of
                    // CountsForThreshold's copying clone(); kept to match the ROC class as compiled.
                    copy.put(e.getKey(), e.getValue().m3375clone());
                }
                counts.put(classEntry.getKey(), copy);
            }
            return;
        }
        if (other.countActualPositive.length != countActualPositive.length) {
            throw new IllegalArgumentException("Cannot merge ROCMultiClass instances with different numbers of classes: "
                    + countActualPositive.length + " vs. " + other.countActualPositive.length);
        }

        for (int c = 0; c < countActualPositive.length; c++) {
            countActualPositive[c] += other.countActualPositive[c];
            countActualNegative[c] += other.countActualNegative[c];
        }
        for (Map.Entry<Integer, Map<Double, ROC.CountsForThreshold>> classEntry : counts.entrySet()) {
            Map<Double, ROC.CountsForThreshold> otherClassCounts = other.counts.get(classEntry.getKey());
            for (Map.Entry<Double, ROC.CountsForThreshold> e : classEntry.getValue().entrySet()) {
                ROC.CountsForThreshold source = otherClassCounts.get(e.getKey());
                e.getValue().incrementTruePositive(source.getCountTruePositive());
                e.getValue().incrementFalsePositive(source.getCountFalsePositive());
            }
        }
    }

    /**
     * Validates that data has been collected and that {@code classIdx} is a valid class index.
     *
     * @throws IllegalStateException    if no data has been collected
     * @throws IllegalArgumentException if {@code classIdx} is outside {@code [0, numClasses)}
     */
    private void assertHasBeenFit(int classIdx) {
        if (countActualPositive == null) {
            throw new IllegalStateException("Cannot get results: no data has been collected");
        }
        if (classIdx < 0 || classIdx >= countActualPositive.length) {
            throw new IllegalArgumentException("Invalid class index (" + classIdx
                    + "): must be in range 0 to numClasses-1 (numClasses = " + countActualPositive.length + ")");
        }
    }

    /** @return the number of threshold intervals this evaluation was created with */
    public int getThresholdSteps() {
        return thresholdSteps;
    }

    /** @return the internal per-class positive-label counts (not a copy); {@code null} before any data */
    public long[] getCountActualPositive() {
        return countActualPositive;
    }

    /** @return the internal per-class negative-label counts (not a copy); {@code null} before any data */
    public long[] getCountActualNegative() {
        return countActualNegative;
    }

    /** @return the internal class -> (threshold -> counts) map (not a copy) */
    public Map<Integer, Map<Double, ROC.CountsForThreshold>> getCounts() {
        return counts;
    }
}
