package org.deeplearning4j.eval;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.Abs;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/* loaded from: input_file:org/deeplearning4j/eval/RegressionEvaluation.class */
/**
 * Evaluation for regression (continuous-valued) outputs.
 *
 * <p>Accumulates per-column statistics over one or more calls to {@link #eval(INDArray, INDArray)}
 * (or the time-series variants) and reports, for each output column:
 * <ul>
 *   <li>MSE  - mean squared error</li>
 *   <li>MAE  - mean absolute error</li>
 *   <li>RMSE - root mean squared error</li>
 *   <li>RSE  - relative squared error</li>
 *   <li>the correlation between labels and predictions (see {@link #correlationR2(int)})</li>
 * </ul>
 *
 * <p>Before any data has been evaluated the example count is zero, so MSE, MAE, RMSE and the
 * correlation evaluate to NaN (0/0), and RSE evaluates to positive infinity.
 *
 * <p>Instances hold mutable accumulator state and perform no synchronization.
 */
public class RegressionEvaluation {
    /** Number of digits after the decimal point used by {@link #stats()} unless overridden. */
    public static final int DEFAULT_PRECISION = 5;

    private final List<String> columnNames;
    private final int precision;
    private int exampleCount;

    // Per-column accumulators; each array has one entry per column.
    private final INDArray sumSquaredErrorsPerColumn;
    private final INDArray sumAbsErrorsPerColumn;
    private final INDArray currentMean;           // running mean of the labels
    private final INDArray currentPredictionMean; // running mean of the predictions
    private final INDArray sumOfProducts;         // sum of labels * predictions
    private final INDArray sumSquaredLabels;
    private final INDArray sumSquaredPredicted;

    /**
     * Creates an evaluation with {@code nColumns} default-named columns ({@code col_0}, {@code col_1}, ...)
     * and {@link #DEFAULT_PRECISION}.
     *
     * @param nColumns number of output columns
     */
    public RegressionEvaluation(int nColumns) {
        this(createDefaultColumnNames(nColumns), DEFAULT_PRECISION);
    }

    /**
     * Creates an evaluation with {@code nColumns} default-named columns.
     *
     * @param nColumns  number of output columns
     * @param precision number of digits after the decimal point in {@link #stats()}
     */
    public RegressionEvaluation(int nColumns, int precision) {
        this(createDefaultColumnNames(nColumns), precision);
    }

    /**
     * Creates an evaluation with the given column names and {@link #DEFAULT_PRECISION}.
     *
     * @param columnNames one name per output column
     */
    public RegressionEvaluation(String... columnNames) {
        this(Arrays.asList(columnNames), DEFAULT_PRECISION);
    }

    /**
     * Creates an evaluation with the given column names and {@link #DEFAULT_PRECISION}.
     *
     * @param columnNames one name per output column
     */
    public RegressionEvaluation(List<String> columnNames) {
        this(columnNames, DEFAULT_PRECISION);
    }

    /**
     * Creates an evaluation with the given column names and output precision.
     *
     * @param columnNames one name per output column; the list is copied, so later changes
     *                    to it by the caller do not affect this evaluation
     * @param precision   number of digits after the decimal point in {@link #stats()}
     */
    public RegressionEvaluation(List<String> columnNames, int precision) {
        // Copy so the column count cannot drift from the accumulator array lengths.
        this.columnNames = new ArrayList<>(columnNames);
        this.precision = precision;
        this.exampleCount = 0;

        int nColumns = this.columnNames.size();
        this.sumSquaredErrorsPerColumn = Nd4j.zeros(nColumns);
        this.sumAbsErrorsPerColumn = Nd4j.zeros(nColumns);
        this.currentMean = Nd4j.zeros(nColumns);
        this.currentPredictionMean = Nd4j.zeros(nColumns);
        this.sumOfProducts = Nd4j.zeros(nColumns);
        this.sumSquaredLabels = Nd4j.zeros(nColumns);
        this.sumSquaredPredicted = Nd4j.zeros(nColumns);
    }

    /** Builds the names {@code col_0 .. col_(n-1)}. */
    private static List<String> createDefaultColumnNames(int nColumns) {
        List<String> names = new ArrayList<>(nColumns);
        for (int col = 0; col < nColumns; col++) {
            names.add("col_" + col);
        }
        return names;
    }

    /**
     * Accumulates statistics for a batch of examples.
     *
     * <p>Rows are treated as examples (row count is added to the example count) and statistics
     * are summed over dimension 0, giving one value per column.
     *
     * @param labels      true values, one row per example
     * @param predictions predicted values, same shape as {@code labels}
     * @throws IllegalArgumentException if the two arrays have different shapes
     */
    public void eval(INDArray labels, INDArray predictions) {
        if (!Arrays.equals(labels.shape(), predictions.shape())) {
            throw new IllegalArgumentException("Labels and predicted have different shapes: labels="
                    + Arrays.toString(labels.shape()) + ", predictions="
                    + Arrays.toString(predictions.shape()));
        }

        // sub() (not subi) allocates a fresh array, so it is safe to modify it in place below.
        INDArray diff = predictions.sub(labels);
        sumSquaredErrorsPerColumn.addi(diff.mul(diff).sum(0));
        INDArray absDiff = Nd4j.getExecutioner().execAndReturn(new Abs(diff));
        sumAbsErrorsPerColumn.addi(absDiff.sum(0));

        sumOfProducts.addi(labels.mul(predictions).sum(0));
        sumSquaredLabels.addi(labels.mul(labels).sum(0));
        sumSquaredPredicted.addi(predictions.mul(predictions).sum(0));

        int batchSize = labels.size(0);
        int newCount = exampleCount + batchSize;
        // Running means: mean_new = (mean_old * n_old + sum(batch)) / n_new
        currentMean.muli(exampleCount).addi(labels.sum(0)).divi(newCount);
        currentPredictionMean.muli(exampleCount).addi(predictions.sum(0)).divi(newCount);
        exampleCount = newCount;
    }

    /**
     * Accumulates statistics for time-series data.
     *
     * <p>Rank-2 inputs are passed straight to {@link #eval(INDArray, INDArray)}. Rank-3 inputs of shape
     * {@code [miniBatch, nColumns, timeSteps]} are flattened so that every (example, time step) pair
     * becomes one row of a {@code [miniBatch * timeSteps, nColumns]} matrix before evaluation.
     *
     * @param labels    true values
     * @param predicted predicted values, same shape as {@code labels}
     * @throws IllegalArgumentException if the labels are neither rank 2 nor rank 3, or the shapes differ
     */
    public void evalTimeSeries(INDArray labels, INDArray predicted) {
        if (labels.rank() == 2 && predicted.rank() == 2) {
            eval(labels, predicted);
            return; // previously fell through and was rejected as "not rank 3"
        }
        if (labels.rank() != 3) {
            throw new IllegalArgumentException("Invalid input: labels are not rank 3 (rank=" + labels.rank() + ")");
        }
        if (!Arrays.equals(labels.shape(), predicted.shape())) {
            throw new IllegalArgumentException("Labels and predicted have different shapes: labels="
                    + Arrays.toString(labels.shape()) + ", predictions=" + Arrays.toString(predicted.shape()));
        }
        if (labels.ordering() == 'f') {
            labels = Shape.toOffsetZeroCopy(labels, 'c');
        }
        if (predicted.ordering() == 'f') {
            predicted = Shape.toOffsetZeroCopy(predicted, 'c');
        }
        int[] shape = labels.shape();
        int rows = shape[0] * shape[2];
        int columns = shape[1];
        // [miniBatch, nColumns, timeSteps] -> [miniBatch, timeSteps, nColumns] -> [rows, nColumns]
        INDArray labels2d = labels.permute(new int[]{0, 2, 1}).reshape(rows, columns);
        INDArray predicted2d = predicted.permute(new int[]{0, 2, 1}).reshape(rows, columns);
        eval(labels2d, predicted2d);
    }

    /**
     * Accumulates statistics for masked time-series data.
     *
     * <p>Only (example, time step) positions whose mask entry is non-zero are evaluated. If no mask
     * entry is non-zero, nothing is accumulated.
     *
     * @param labels     true values, indexed as {@code [example, column, timeStep]}
     * @param predicted  predicted values, same shape as {@code labels}
     * @param outputMask mask indexed as {@code [example, timeStep]}; non-zero means "include"
     */
    public void evalTimeSeries(INDArray labels, INDArray predicted, INDArray outputMask) {
        // Collect included positions explicitly rather than using the mask sum, so that
        // non-binary masks do not mis-size the buffers.
        List<int[]> included = new ArrayList<>();
        int nExamples = outputMask.size(0);
        int nSteps = outputMask.size(1);
        for (int ex = 0; ex < nExamples; ex++) {
            for (int t = 0; t < nSteps; t++) {
                if (outputMask.getDouble(ex, t) != 0.0d) {
                    included.add(new int[]{ex, t});
                }
            }
        }
        if (included.isEmpty()) {
            return;
        }

        int nColumns = labels.size(1);
        INDArray labels2d = Nd4j.create(included.size(), nColumns);
        INDArray predicted2d = Nd4j.create(included.size(), nColumns);
        for (int row = 0; row < included.size(); row++) {
            int ex = included.get(row)[0];
            int t = included.get(row)[1];
            // Fresh index objects per get(): index instances are not reused across arrays.
            labels2d.putRow(row, labels.get(new INDArrayIndex[]{
                    NDArrayIndex.point(ex), NDArrayIndex.all(), NDArrayIndex.point(t)}));
            predicted2d.putRow(row, predicted.get(new INDArrayIndex[]{
                    NDArrayIndex.point(ex), NDArrayIndex.all(), NDArrayIndex.point(t)}));
        }
        eval(labels2d, predicted2d);
    }

    /**
     * Formats a table with one row per column: name, MSE, MAE, RMSE, RSE and the correlation
     * (headed "R^2"; see {@link #correlationR2(int)}), using scientific notation with the
     * configured precision.
     *
     * @return the formatted table, each line terminated by {@code '\n'}
     */
    public String stats() {
        int maxNameLength = 0;
        for (String name : columnNames) {
            maxNameLength = Math.max(maxNameLength, name.length());
        }
        int nameWidth = maxNameLength + 5;
        int valueWidth = precision + 10;

        StringBuilder headerFormat = new StringBuilder("%-").append(nameWidth).append('s');
        StringBuilder rowFormat = new StringBuilder("%-").append(nameWidth).append('s');
        for (int metric = 0; metric < 5; metric++) {
            headerFormat.append("%-").append(valueWidth).append('s');
            rowFormat.append("%-").append(valueWidth).append('.').append(precision).append('e');
        }

        StringBuilder sb = new StringBuilder();
        sb.append(String.format(headerFormat.toString(), "Column", "MSE", "MAE", "RMSE", "RSE", "R^2"));
        sb.append("\n");
        for (int col = 0; col < columnNames.size(); col++) {
            sb.append(String.format(rowFormat.toString(), columnNames.get(col),
                    meanSquaredError(col), meanAbsoluteError(col), rootMeanSquaredError(col),
                    relativeSquaredError(col), correlationR2(col)));
            sb.append("\n");
        }
        return sb.toString();
    }

    /** @return the number of output columns being evaluated */
    public int numColumns() {
        return columnNames.size();
    }

    /** @return mean squared error for the given column (NaN before any data is evaluated) */
    public double meanSquaredError(int column) {
        return sumSquaredErrorsPerColumn.getDouble(column) / exampleCount;
    }

    /** @return mean absolute error for the given column (NaN before any data is evaluated) */
    public double meanAbsoluteError(int column) {
        return sumAbsErrorsPerColumn.getDouble(column) / exampleCount;
    }

    /** @return root mean squared error for the given column (NaN before any data is evaluated) */
    public double rootMeanSquaredError(int column) {
        return Math.sqrt(meanSquaredError(column));
    }

    /**
     * Computes the Pearson correlation coefficient r between labels and predictions for a column:
     * {@code (Σxy - n·x̄·ȳ) / (sqrt(Σx² - n·x̄²) · sqrt(Σy² - n·ȳ²))}.
     *
     * <p>Note: despite the method name and the "R^2" header in {@link #stats()}, the value is
     * <em>not</em> squared, so it can be negative.
     *
     * @param column column index
     * @return the correlation coefficient (NaN if either series has zero variance)
     */
    public double correlationR2(int column) {
        double sumXY = sumOfProducts.getDouble(column);
        double predMean = currentPredictionMean.getDouble(column);
        double labelMean = currentMean.getDouble(column);
        double numerator = sumXY - exampleCount * predMean * labelMean;
        double labelSpread = Math.sqrt(sumSquaredLabels.getDouble(column) - exampleCount * labelMean * labelMean);
        double predSpread = Math.sqrt(sumSquaredPredicted.getDouble(column) - exampleCount * predMean * predMean);
        return numerator / (labelSpread * predSpread);
    }

    /**
     * Computes the relative squared error for a column: the sum of squared errors divided by the
     * total sum of squares of the labels about their mean.
     *
     * @param column column index
     * @return the relative squared error, or {@link Double#POSITIVE_INFINITY} when the label
     *         sum of squares is within {@code Nd4j.EPS_THRESHOLD} of zero
     */
    public double relativeSquaredError(int column) {
        // Mathematically equal to Σp² - 2Σpl + Σl², but read from the accumulated squared errors
        // to avoid catastrophic cancellation when predictions are close to the labels.
        double numerator = sumSquaredErrorsPerColumn.getDouble(column);
        double mean = currentMean.getDouble(column);
        double denominator = sumSquaredLabels.getDouble(column) - exampleCount * mean * mean;
        if (Math.abs(denominator) > Nd4j.EPS_THRESHOLD) {
            return numerator / denominator;
        }
        return Double.POSITIVE_INFINITY;
    }
}
