package org.deeplearning4j.util;

import org.apache.commons.math3.distribution.UniformRealDistribution;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.stat.descriptive.moment.StandardDeviation;
import org.apache.commons.math3.stat.regression.SimpleRegression;
import org.deeplearning4j.datasets.DataSet;
import org.jblas.DoubleMatrix;
import org.jblas.MatrixFunctions;
import org.jblas.SimpleBlas;
import org.jblas.ranges.IntervalRange;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/util/MatrixUtil.class */
/**
 * Static helpers for building, inspecting and normalizing jblas {@link DoubleMatrix} instances.
 *
 * <p>Many helpers modify the matrix they are given instead of returning a copy; each method's
 * documentation states which of the two it does.
 */
public class MatrixUtil {
    private static final Logger log = LoggerFactory.getLogger(MatrixUtil.class);

    /**
     * Rejects null matrices and matrices whose row counts differ.
     *
     * <p>Only the row counts are compared; column counts are not checked.
     *
     * @param doubleMatrix  first matrix
     * @param doubleMatrix2 second matrix
     * @throws IllegalArgumentException if either matrix is null or the row counts differ
     */
    public static void complainAboutMissMatchedMatrices(DoubleMatrix doubleMatrix, DoubleMatrix doubleMatrix2) {
        if (doubleMatrix == null || doubleMatrix2 == null) {
            throw new IllegalArgumentException("No null matrices allowed");
        }
        if (doubleMatrix.rows != doubleMatrix2.rows) {
            throw new IllegalArgumentException("Matrices must have same rows");
        }
    }

    /**
     * Builds a random two-input data set with two-column one-hot labels.
     *
     * <p>Inputs are uniform samples thresholded at 0.5. A row whose two inputs are equal is
     * labelled in column 0, any other row in column 1.
     *
     * @param i number of examples (rows)
     * @return the generated inputs and labels
     */
    public static DataSet xorData(int i) {
        DoubleMatrix inputs = DoubleMatrix.rand(i, 2).gti(0.5d);
        DoubleMatrix labels = DoubleMatrix.zeros(i, 2);
        for (int row = 0; row < inputs.rows; row++) {
            int labelColumn = inputs.get(row, 0) == inputs.get(row, 1) ? 0 : 1;
            labels.put(row, labelColumn, 1.0d);
        }
        return new DataSet(inputs, labels);
    }

    /**
     * Builds a random data set with {@code i2} input columns and two-column one-hot labels.
     *
     * <p>Two thresholded random matrices are compared element by element and the comparison is
     * then compared against zero, so an input is 1 where the two random bits differ. A row is
     * labelled in column 0 when the sum over columns {@code [0, i2/2)} exceeds the sum over
     * columns {@code [i2/2, i2)}, and in column 1 otherwise.
     *
     * @param i  number of examples (rows)
     * @param i2 number of input columns
     * @return the generated inputs and labels
     */
    public static DataSet xorData(int i, int i2) {
        DoubleMatrix firstBits = DoubleMatrix.rand(i, i2).gti(0.5d);
        DoubleMatrix secondBits = DoubleMatrix.rand(i, i2).gti(0.5d);
        DoubleMatrix inputs = firstBits.eq(secondBits).eq(DoubleMatrix.zeros(i, i2));
        int half = i2 / 2;
        DoubleMatrix labels = new DoubleMatrix(i, 2);
        for (int row = 0; row < labels.rows; row++) {
            double leftSum = inputs.get(row, new IntervalRange(0, half)).sum();
            double rightSum = inputs.get(row, new IntervalRange(half, i2)).sum();
            labels.put(row, leftSum > rightSum ? 0 : 1, 1.0d);
        }
        return new DataSet(inputs, labels);
    }

    /**
     * Computes the square root of the sum of squared elements.
     *
     * @param doubleMatrix matrix to measure
     * @return the Euclidean length of the matrix viewed as one flat vector
     */
    public static double magnitude(DoubleMatrix doubleMatrix) {
        double sumOfSquares = 0.0d;
        for (int k = 0; k < doubleMatrix.length; k++) {
            double value = doubleMatrix.get(k);
            sumOfSquares += value * value;
        }
        return Math.sqrt(sumOfSquares);
    }

    /**
     * Copies every element, in linear-index order, into a new single-row matrix.
     *
     * @param doubleMatrix matrix to flatten
     * @return a new {@code 1 x length} matrix
     */
    public static DoubleMatrix unroll(DoubleMatrix doubleMatrix) {
        DoubleMatrix flat = new DoubleMatrix(1, doubleMatrix.length);
        for (int k = 0; k < doubleMatrix.length; k++) {
            flat.put(k, doubleMatrix.get(k));
        }
        return flat;
    }

    /**
     * Stores, for each row, the index that {@code SimpleBlas.iamax} reports for that row.
     *
     * @param doubleMatrix matrix with one example per row
     * @return a new {@code rows x 1} matrix of indices
     */
    public static DoubleMatrix outcomes(DoubleMatrix doubleMatrix) {
        DoubleMatrix result = new DoubleMatrix(doubleMatrix.rows, 1);
        for (int row = 0; row < doubleMatrix.rows; row++) {
            result.put(row, SimpleBlas.iamax(doubleMatrix.getRow(row)));
        }
        return result;
    }

    /**
     * Computes the cosine similarity of two vectors as the dot product of their unit vectors.
     *
     * <p>The arguments are copied first: {@link #unitVec(DoubleMatrix)} rescales the matrix it
     * receives, and a similarity query should not alter its inputs.
     *
     * @param doubleMatrix  first vector
     * @param doubleMatrix2 second vector
     * @return the dot product of the two normalized copies
     */
    public static double cosineSim(DoubleMatrix doubleMatrix, DoubleMatrix doubleMatrix2) {
        DoubleMatrix unitFirst = unitVec(doubleMatrix.dup());
        DoubleMatrix unitSecond = unitVec(doubleMatrix2.dup());
        return unitFirst.dot(unitSecond);
    }

    /**
     * Rescales the matrix in place to {@code (x - min) / (max - min)}.
     *
     * <p>The range is computed before the matrix is shifted; measuring the maximum after the
     * in-place subtraction would use an already shifted value. When all elements are equal the
     * range is zero, so the matrix is only shifted (giving zeros) and no division is performed.
     *
     * @param doubleMatrix matrix to rescale; modified in place
     * @return the same matrix instance
     */
    public static DoubleMatrix normalize(DoubleMatrix doubleMatrix) {
        double min = doubleMatrix.min();
        double range = doubleMatrix.max() - min;
        doubleMatrix.subi(min);
        if (range == 0.0d) {
            return doubleMatrix;
        }
        return doubleMatrix.divi(range);
    }

    /**
     * Computes the square root of the sum of the squared elements.
     *
     * <p>NOTE(review): despite the name, no cosine is taken; the arithmetic is the same as
     * {@link #magnitude(DoubleMatrix)}.
     *
     * @param doubleMatrix matrix to measure
     * @return {@code sqrt(sum(x^2))}
     */
    public static double cosine(DoubleMatrix doubleMatrix) {
        return 1.0d * Math.sqrt(MatrixFunctions.pow(doubleMatrix, 2.0d).sum());
    }

    /**
     * Scales the matrix by the reciprocal of its 2-norm using {@code SimpleBlas.scal}.
     *
     * <p>{@code SimpleBlas.scal} operates on the matrix it is given, so the argument itself is
     * rescaled. A matrix whose norm is not positive is returned untouched.
     *
     * @param doubleMatrix vector to normalize
     * @return the result of {@code SimpleBlas.scal}, or the argument when its norm is not positive
     */
    public static DoubleMatrix unitVec(DoubleMatrix doubleMatrix) {
        double norm = doubleMatrix.norm2();
        if (norm > 0.0d) {
            return SimpleBlas.scal(1.0d / norm, doubleMatrix);
        }
        return doubleMatrix;
    }

    /**
     * Fills a new matrix with samples from a uniform distribution on the interval 0 to 1.
     *
     * @param randomGenerator source of randomness for the distribution
     * @param i               number of rows
     * @param i2              number of columns
     * @return a new {@code i x i2} matrix of samples
     */
    public static DoubleMatrix uniform(RandomGenerator randomGenerator, int i, int i2) {
        UniformRealDistribution distribution = new UniformRealDistribution(randomGenerator, 0.0d, 1.0d);
        DoubleMatrix samples = new DoubleMatrix(i, i2);
        for (int row = 0; row < samples.rows; row++) {
            for (int col = 0; col < samples.columns; col++) {
                samples.put(row, col, distribution.sample());
            }
        }
        return samples;
    }

    /**
     * Reports whether at least one element is strictly positive.
     *
     * @param doubleMatrix outcome matrix to inspect
     * @return true if any element is greater than zero
     */
    public static boolean isValidOutcome(DoubleMatrix doubleMatrix) {
        for (int k = 0; k < doubleMatrix.length; k++) {
            if (doubleMatrix.get(k) > 0.0d) {
                return true;
            }
        }
        return false;
    }

    /**
     * Finds the smallest element by a linear scan.
     *
     * @param doubleMatrix matrix to scan; element 0 is read unconditionally
     * @return the smallest element
     */
    public static double min(DoubleMatrix doubleMatrix) {
        double smallest = doubleMatrix.get(0);
        for (int k = 0; k < doubleMatrix.length; k++) {
            double value = doubleMatrix.get(k);
            if (value < smallest) {
                smallest = value;
            }
        }
        return smallest;
    }

    /**
     * Finds the largest element by a linear scan.
     *
     * @param doubleMatrix matrix to scan; element 0 is read unconditionally
     * @return the largest element
     */
    public static double max(DoubleMatrix doubleMatrix) {
        double largest = doubleMatrix.get(0);
        for (int k = 0; k < doubleMatrix.length; k++) {
            double value = doubleMatrix.get(k);
            if (value > largest) {
                largest = value;
            }
        }
        return largest;
    }

    /**
     * Makes sure the matrix holds at least one strictly positive element.
     *
     * <p>When none is found a warning is logged and element 0 is overwritten with 1.
     *
     * @param doubleMatrix outcome matrix; may be modified in place
     */
    public static void ensureValidOutcomeMatrix(DoubleMatrix doubleMatrix) {
        if (isValidOutcome(doubleMatrix)) {
            return;
        }
        log.warn("Found invalid matrix assuming; nothing which means adding a 1 to the first spot");
        doubleMatrix.put(0, 1.0d);
    }

    /**
     * Verifies that every element survives a round trip through {@code int} unchanged.
     *
     * @param doubleMatrix matrix to check
     * @throws IllegalArgumentException at the first element that differs from its int cast
     */
    public static void assertIntMatrix(DoubleMatrix doubleMatrix) {
        for (int k = 0; k < doubleMatrix.length; k++) {
            double value = doubleMatrix.get(k);
            if (((int) value) != value) {
                throw new IllegalArgumentException("Found something that is not an integer at linear index " + k);
            }
        }
    }

    /**
     * Reports whether any element is positive or negative infinity.
     *
     * @param doubleMatrix matrix to inspect
     * @return true if at least one element is infinite
     */
    public static boolean isInfinite(DoubleMatrix doubleMatrix) {
        // Checked element by element (like isNaN) so no intermediate flag matrix is allocated.
        for (int k = 0; k < doubleMatrix.length; k++) {
            if (Double.isInfinite(doubleMatrix.get(k))) {
                return true;
            }
        }
        return false;
    }

    /**
     * Reports whether any element is NaN.
     *
     * @param doubleMatrix matrix to inspect
     * @return true if at least one element is NaN
     */
    public static boolean isNaN(DoubleMatrix doubleMatrix) {
        for (int k = 0; k < doubleMatrix.length; k++) {
            if (Double.isNaN(doubleMatrix.get(k))) {
                return true;
            }
        }
        return false;
    }

    /**
     * Replaces every column, in place, by its values passed through {@code MathUtils.discretize}.
     *
     * <p>Each call receives the value together with that column's minimum and maximum and the
     * requested bin count. The column extremes are captured once, before any column is rewritten.
     *
     * @param doubleMatrix matrix whose columns are rewritten
     * @param i            bin count forwarded to {@code MathUtils.discretize}
     */
    public static void discretizeColumns(DoubleMatrix doubleMatrix, int i) {
        DoubleMatrix columnMaxs = doubleMatrix.columnMaxs();
        DoubleMatrix columnMins = doubleMatrix.columnMins();
        for (int col = 0; col < doubleMatrix.columns; col++) {
            double lower = columnMins.get(col);
            double upper = columnMaxs.get(col);
            DoubleMatrix source = doubleMatrix.getColumn(col);
            DoubleMatrix discretized = new DoubleMatrix(source.length);
            for (int k = 0; k < source.length; k++) {
                discretized.put(k, MathUtils.discretize(source.get(k), lower, upper, i));
            }
            doubleMatrix.putColumn(col, discretized);
        }
    }

    /**
     * Returns a new matrix whose elements are {@code Math.round(x * d) / d}.
     *
     * @param doubleMatrix source matrix; not modified
     * @param d            scale applied before rounding and removed afterwards
     * @return a new matrix of the same shape
     */
    public static DoubleMatrix roundToTheNearest(DoubleMatrix doubleMatrix, double d) {
        DoubleMatrix rounded = new DoubleMatrix(doubleMatrix.rows, doubleMatrix.columns);
        for (int row = 0; row < doubleMatrix.rows; row++) {
            for (int col = 0; col < doubleMatrix.columns; col++) {
                rounded.put(row, col, Math.round(doubleMatrix.get(row, col) * d) / d);
            }
        }
        return rounded;
    }

    /**
     * Divides every column, in place, by the sum of that column.
     *
     * @param doubleMatrix matrix to rescale
     */
    public static void columnNormalizeBySum(DoubleMatrix doubleMatrix) {
        for (int col = 0; col < doubleMatrix.columns; col++) {
            DoubleMatrix column = doubleMatrix.getColumn(col);
            doubleMatrix.putColumn(col, column.div(column.sum()));
        }
    }

    /**
     * Builds a one-hot row vector.
     *
     * @param i  position to set to 1; must be a valid index into an array of length {@code i2}
     * @param i2 number of entries
     * @return a {@code 1 x i2} matrix with a single 1 at position {@code i}
     */
    public static DoubleMatrix toOutcomeVector(int i, int i2) {
        int[] oneHot = new int[i2];
        oneHot[i] = 1;
        return toMatrix(oneHot);
    }

    /**
     * Converts a two-dimensional int array to a matrix of the same layout.
     *
     * <p>The column count is taken from the first row. An array with no rows produces a
     * {@code 0 x 0} matrix.
     *
     * @param iArr source values, indexed {@code [row][column]}
     * @return a new matrix holding the values as doubles
     */
    public static DoubleMatrix toMatrix(int[][] iArr) {
        if (iArr.length == 0) {
            return new DoubleMatrix(0, 0);
        }
        DoubleMatrix matrix = new DoubleMatrix(iArr.length, iArr[0].length);
        for (int row = 0; row < iArr.length; row++) {
            for (int col = 0; col < iArr[row].length; col++) {
                matrix.put(row, col, iArr[row][col]);
            }
        }
        return matrix;
    }

    /**
     * Converts an int array to a single-row matrix.
     *
     * @param iArr source values
     * @return a matrix reshaped to {@code 1 x iArr.length}
     */
    public static DoubleMatrix toMatrix(int[] iArr) {
        DoubleMatrix matrix = new DoubleMatrix(iArr.length);
        for (int k = 0; k < iArr.length; k++) {
            matrix.put(k, iArr[k]);
        }
        matrix.reshape(1, matrix.length);
        return matrix;
    }

    /**
     * Adds the second matrix to the first using the in-place {@code addi}.
     *
     * @param doubleMatrix  matrix that receives the sum
     * @param doubleMatrix2 matrix to add
     * @return the result of {@code doubleMatrix.addi(doubleMatrix2)}
     */
    public static DoubleMatrix add(DoubleMatrix doubleMatrix, DoubleMatrix doubleMatrix2) {
        return doubleMatrix.addi(doubleMatrix2);
    }

    /**
     * Applies a row-wise softmax using in-place operations.
     *
     * <p>Each row has its maximum subtracted before exponentiation, then is divided by the sum of
     * its exponentials.
     *
     * @param doubleMatrix matrix with one example per row; modified in place
     * @return the result of the final in-place division
     */
    public static DoubleMatrix softmax(DoubleMatrix doubleMatrix) {
        DoubleMatrix shifted = doubleMatrix.subiColumnVector(doubleMatrix.rowMaxs());
        MatrixFunctions.expi(shifted);
        return doubleMatrix.diviColumnVector(doubleMatrix.rowSums());
    }

    /**
     * Computes means along one axis of the matrix.
     *
     * @param doubleMatrix source matrix
     * @param i            0 for column means, 1 for row means
     * @return the requested means; for any other axis value, a new zero-filled
     *         {@code rows x 1} matrix
     */
    public static DoubleMatrix mean(DoubleMatrix doubleMatrix, int i) {
        if (i == 0) {
            return doubleMatrix.columnMeans();
        }
        if (i == 1) {
            // Row means of the input itself, not of a freshly allocated placeholder.
            return doubleMatrix.rowMeans();
        }
        return new DoubleMatrix(doubleMatrix.rows, 1);
    }

    /**
     * Computes sums along one axis of the matrix.
     *
     * @param doubleMatrix source matrix
     * @param i            0 for per-column sums; any other value gives per-row sums
     * @return a new column vector: {@code columns x 1} for axis 0, {@code rows x 1} otherwise
     */
    public static DoubleMatrix sum(DoubleMatrix doubleMatrix, int i) {
        if (i == 0) {
            // Sized by column count so that every column gets exactly one slot.
            DoubleMatrix columnTotals = new DoubleMatrix(doubleMatrix.columns, 1);
            for (int col = 0; col < doubleMatrix.columns; col++) {
                columnTotals.put(col, doubleMatrix.getColumn(col).sum());
            }
            return columnTotals;
        }
        DoubleMatrix rowTotals = new DoubleMatrix(doubleMatrix.rows, 1);
        for (int row = 0; row < doubleMatrix.rows; row++) {
            rowTotals.put(row, doubleMatrix.getRow(row).sum());
        }
        return rowTotals;
    }

    /**
     * Draws one value per element through {@code MathUtils.binomial}.
     *
     * @param doubleMatrix    per-element value forwarded as the last argument of the draw
     * @param i               count forwarded as the second argument of the draw
     * @param randomGenerator source of randomness
     * @return a new matrix of the same shape holding the drawn values
     */
    public static DoubleMatrix binomial(DoubleMatrix doubleMatrix, int i, RandomGenerator randomGenerator) {
        DoubleMatrix draws = new DoubleMatrix(doubleMatrix.rows, doubleMatrix.columns);
        for (int k = 0; k < draws.length; k++) {
            draws.put(k, MathUtils.binomial(randomGenerator, i, doubleMatrix.get(k)));
        }
        return draws;
    }

    /**
     * Computes the mean of every column.
     *
     * @param doubleMatrix source matrix
     * @param i            not used by the computation; kept so existing callers still compile
     * @return a new vector with one mean per column
     */
    public static DoubleMatrix columnWiseMean(DoubleMatrix doubleMatrix, int i) {
        DoubleMatrix means = DoubleMatrix.zeros(doubleMatrix.columns);
        for (int col = 0; col < doubleMatrix.columns; col++) {
            // Index by the loop variable so each slot holds the mean of its own column.
            means.put(col, doubleMatrix.getColumn(col).mean());
        }
        return means;
    }

    /**
     * Averages the given matrices element by element.
     *
     * @param doubleMatrixArr matrices to average
     * @return null for a null or empty argument; the sole matrix itself (not a copy) when one
     *         matrix is given; otherwise a new matrix holding the average
     */
    public static DoubleMatrix avg(DoubleMatrix... doubleMatrixArr) {
        if (doubleMatrixArr == null || doubleMatrixArr.length == 0) {
            return null;
        }
        if (doubleMatrixArr.length == 1) {
            return doubleMatrixArr[0];
        }
        DoubleMatrix total = doubleMatrixArr[0];
        for (int k = 1; k < doubleMatrixArr.length; k++) {
            total = total.add(doubleMatrixArr[k]);
        }
        return total.div(doubleMatrixArr.length);
    }

    /**
     * Finds the first linear index whose element equals the matrix maximum.
     *
     * @param doubleMatrix matrix to search
     * @return the index, or -1 when no element compares equal to the maximum
     */
    public static int maxIndex(DoubleMatrix doubleMatrix) {
        double largest = doubleMatrix.max();
        for (int k = 0; k < doubleMatrix.length; k++) {
            if (doubleMatrix.get(k) == largest) {
                return k;
            }
        }
        return -1;
    }

    /**
     * Computes {@code 1 / (1 + exp(-x))} for every element.
     *
     * @param doubleMatrix source matrix
     * @return a new matrix of the same shape
     */
    public static DoubleMatrix sigmoid(DoubleMatrix doubleMatrix) {
        DoubleMatrix ones = DoubleMatrix.ones(doubleMatrix.rows, doubleMatrix.columns);
        DoubleMatrix denominator = ones.add(MatrixFunctions.exp(doubleMatrix.neg()));
        return ones.div(denominator);
    }

    /**
     * Returns a scalar dot product or a matrix product, depending on the operand shapes.
     *
     * <p>The scalar path is taken when the first operand is a column vector, when the second
     * operand is a row vector, or when the first is a row vector and the second a column vector.
     * NOTE(review): the grouping of these conditions is reproduced as found; confirm it matches
     * the intended vector detection.
     *
     * @param doubleMatrix  left operand
     * @param doubleMatrix2 right operand
     * @return a scalar matrix holding the dot product, or the matrix product
     */
    public static DoubleMatrix dot(DoubleMatrix doubleMatrix, DoubleMatrix doubleMatrix2) {
        boolean useScalarDot = doubleMatrix.isColumnVector()
                || (doubleMatrix.isRowVector() && doubleMatrix2.isColumnVector())
                || doubleMatrix2.isRowVector();
        if (useScalarDot) {
            return DoubleMatrix.scalar(doubleMatrix.dot(doubleMatrix2));
        }
        return doubleMatrix.mmul(doubleMatrix2);
    }

    /**
     * Multiplies the two matrices with {@code mmul}.
     *
     * @param doubleMatrix  left operand
     * @param doubleMatrix2 right operand
     * @return the matrix product
     */
    public static DoubleMatrix out(DoubleMatrix doubleMatrix, DoubleMatrix doubleMatrix2) {
        return doubleMatrix.mmul(doubleMatrix2);
    }

    /**
     * Computes {@code 1 - x} for every element.
     *
     * @param doubleMatrix source matrix
     * @return a new matrix of the same shape
     */
    public static DoubleMatrix oneMinus(DoubleMatrix doubleMatrix) {
        DoubleMatrix ones = DoubleMatrix.ones(doubleMatrix.rows, doubleMatrix.columns);
        return ones.sub(doubleMatrix);
    }

    /**
     * Computes {@code 1 / x} for every element.
     *
     * <p>Elements equal to zero are first overwritten with 0.01 in the argument itself, so the
     * input is modified whenever it contains zeros.
     *
     * @param doubleMatrix source matrix; zeros are replaced in place
     * @return a new matrix of reciprocals
     */
    public static DoubleMatrix oneDiv(DoubleMatrix doubleMatrix) {
        for (int row = 0; row < doubleMatrix.rows; row++) {
            for (int col = 0; col < doubleMatrix.columns; col++) {
                if (doubleMatrix.get(row, col) == 0.0d) {
                    doubleMatrix.put(row, col, 0.01d);
                }
            }
        }
        DoubleMatrix ones = DoubleMatrix.ones(doubleMatrix.rows, doubleMatrix.columns);
        return ones.div(doubleMatrix);
    }

    /**
     * Evaluates a {@link StandardDeviation} over each column's backing data array.
     *
     * @param doubleMatrix source matrix
     * @return a new {@code 1 x columns} matrix of standard deviations
     */
    public static DoubleMatrix columnStd(DoubleMatrix doubleMatrix) {
        DoubleMatrix deviations = new DoubleMatrix(1, doubleMatrix.columns);
        for (int col = 0; col < doubleMatrix.columns; col++) {
            double[] columnValues = doubleMatrix.getColumn(col).data;
            deviations.put(col, new StandardDeviation().evaluate(columnValues));
        }
        return deviations;
    }

    /**
     * Returns the mean square error reported by a {@link SimpleRegression} fitted to the paired
     * elements of the two matrices.
     *
     * <p>Element {@code k} of the first matrix is the x value and element {@code k} of the second
     * the y value of observation {@code k}.
     *
     * @param doubleMatrix  x values
     * @param doubleMatrix2 y values
     * @return {@code SimpleRegression.getMeanSquareError()} for the paired observations
     * @throws IllegalArgumentException if the matrices differ in length
     */
    public static double meanSquaredError(DoubleMatrix doubleMatrix, DoubleMatrix doubleMatrix2) {
        return pairedRegression(doubleMatrix, doubleMatrix2).getMeanSquareError();
    }

    /**
     * Computes an element-wise natural logarithm in which zero maps to zero.
     *
     * @param doubleMatrix source matrix
     * @return a new matrix of the same shape
     */
    public static DoubleMatrix log(DoubleMatrix doubleMatrix) {
        DoubleMatrix result = new DoubleMatrix(doubleMatrix.rows, doubleMatrix.columns);
        for (int k = 0; k < doubleMatrix.length; k++) {
            double value = doubleMatrix.get(k);
            result.put(k, value == 0.0d ? 0.0d : Math.log(value));
        }
        return result;
    }

    /**
     * Returns the sum of squared errors reported by a {@link SimpleRegression} fitted to the
     * paired elements of the two matrices.
     *
     * <p>Element {@code k} of the first matrix is the x value and element {@code k} of the second
     * the y value of observation {@code k}.
     *
     * @param doubleMatrix  x values
     * @param doubleMatrix2 y values
     * @return {@code SimpleRegression.getSumSquaredErrors()} for the paired observations
     * @throws IllegalArgumentException if the matrices differ in length
     */
    public static double sumSquaredError(DoubleMatrix doubleMatrix, DoubleMatrix doubleMatrix2) {
        return pairedRegression(doubleMatrix, doubleMatrix2).getSumSquaredErrors();
    }

    /** Loads one (x, y) observation per linear index into a fresh regression model. */
    private static SimpleRegression pairedRegression(DoubleMatrix xValues, DoubleMatrix yValues) {
        if (xValues.length != yValues.length) {
            throw new IllegalArgumentException("Matrices must be same length");
        }
        SimpleRegression regression = new SimpleRegression();
        for (int k = 0; k < xValues.length; k++) {
            regression.addData(xValues.get(k), yValues.get(k));
        }
        return regression;
    }

    /**
     * Subtracts the column means and then divides by the column standard deviations, in place.
     *
     * <p>The standard deviations are computed after the means have been subtracted.
     *
     * @param doubleMatrix matrix to standardize
     */
    public static void normalizeMatrix(DoubleMatrix doubleMatrix) {
        doubleMatrix.subiRowVector(doubleMatrix.columnMeans());
        doubleMatrix.diviRowVector(columnStd(doubleMatrix));
    }

    /**
     * Divides every column, in place, by that column's sum.
     *
     * <p>All sums are captured before any column is rewritten.
     *
     * @param doubleMatrix matrix to rescale
     * @return the same matrix instance
     */
    public static DoubleMatrix normalizeByColumnSums(DoubleMatrix doubleMatrix) {
        DoubleMatrix columnSums = doubleMatrix.columnSums();
        for (int col = 0; col < doubleMatrix.columns; col++) {
            DoubleMatrix scaled = doubleMatrix.getColumn(col).div(columnSums.get(col));
            doubleMatrix.putColumn(col, scaled);
        }
        return doubleMatrix;
    }

    /**
     * Evaluates a {@link StandardDeviation} over a copy of each column, logging a warning
     * whenever the result is NaN.
     *
     * @param doubleMatrix source matrix
     * @return a new {@code 1 x columns} matrix of standard deviations
     */
    public static DoubleMatrix columnStdDeviation(DoubleMatrix doubleMatrix) {
        DoubleMatrix deviations = new DoubleMatrix(1, doubleMatrix.columns);
        for (int col = 0; col < deviations.length; col++) {
            double deviation = new StandardDeviation().evaluate(doubleMatrix.getColumn(col).toArray());
            if (Double.isNaN(deviation)) {
                log.warn("WTF");
            }
            deviations.put(col, deviation);
        }
        return deviations;
    }

    /**
     * Divides every column, in place, by that column's standard deviation.
     *
     * @param doubleMatrix matrix to rescale
     * @return the same matrix instance
     */
    public static DoubleMatrix divColumnsByStDeviation(DoubleMatrix doubleMatrix) {
        DoubleMatrix deviations = columnStdDeviation(doubleMatrix);
        for (int col = 0; col < doubleMatrix.columns; col++) {
            DoubleMatrix scaled = doubleMatrix.getColumn(col).div(deviations.get(col));
            doubleMatrix.putColumn(col, scaled);
        }
        return doubleMatrix;
    }

    /**
     * Subtracts from every column, in place, that column's mean.
     *
     * @param doubleMatrix matrix to centre
     * @return the same matrix instance
     */
    public static DoubleMatrix normalizeByColumnMeans(DoubleMatrix doubleMatrix) {
        DoubleMatrix columnMeans = doubleMatrix.columnMeans();
        for (int col = 0; col < doubleMatrix.columns; col++) {
            DoubleMatrix centred = doubleMatrix.getColumn(col).sub(columnMeans.get(col));
            doubleMatrix.putColumn(col, centred);
        }
        return doubleMatrix;
    }

    /**
     * Divides every row, in place, by that row's sum.
     *
     * @param doubleMatrix matrix to rescale
     * @return the same matrix instance
     */
    public static DoubleMatrix normalizeByRowSums(DoubleMatrix doubleMatrix) {
        DoubleMatrix rowSums = doubleMatrix.rowSums();
        for (int row = 0; row < doubleMatrix.rows; row++) {
            DoubleMatrix scaled = doubleMatrix.getRow(row).div(rowSums.get(row));
            doubleMatrix.putRow(row, scaled);
        }
        return doubleMatrix;
    }
}
