package org.nd4j.linalg.api.blas.impl;

import org.nd4j.linalg.api.blas.Lapack;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/api/blas/impl/BaseLapack.class */
public abstract class BaseLapack implements Lapack {
    /** Class-wide SLF4J logger; used for the non-fatal status warnings raised by the LAPACK wrappers. */
    private static final Logger log = LoggerFactory.getLogger(BaseLapack.class);

    /**
     * LU-factorisation entry point. Allocates a one-element int-backed status array and an
     * int-backed pivot array of {@code min(rows, columns)} entries, then forwards everything to
     * {@link #dgetrf} or {@link #sgetrf} depending on the data type of the matrix buffer.
     *
     * <p>After the hook returns, element 0 of the status array is inspected: a negative value is
     * reported as an {@link Error}, a positive value only produces a warning in the log.
     * NOTE(review): whether {@code iNDArray} is overwritten with the factors depends on the
     * concrete hook implementation, which is not visible here.
     *
     * @param iNDArray the matrix handed to the factorisation hook
     * @return the pivot array that was passed to the hook
     * @throws ND4JArraySizeException if a dimension exceeds {@code Integer.MAX_VALUE}
     * @throws UnsupportedOperationException if the buffer is neither DOUBLE nor FLOAT
     * @throws Error if the hook reports a negative status
     */
    @Override // org.nd4j.linalg.api.blas.Lapack
    public INDArray getrf(INDArray iNDArray) {
        if (iNDArray.rows() > Integer.MAX_VALUE || iNDArray.columns() > Integer.MAX_VALUE) {
            throw new ND4JArraySizeException();
        }
        final int m = iNDArray.rows();
        final int n = iNDArray.columns();

        // Status holder: a single int slot that the hook writes its result code into.
        INDArray info = Nd4j.createArrayFromShapeBuffer(
                Nd4j.getDataBufferFactory().createInt(1L),
                Nd4j.getShapeInfoProvider().createShapeInformation(new long[]{1, 1}, iNDArray.dataType()).getFirst());

        // Pivot holder: one int slot per pivot, i.e. min(m, n) of them.
        final int pivotCount = Math.min(m, n);
        INDArray pivots = Nd4j.createArrayFromShapeBuffer(
                Nd4j.getDataBufferFactory().createInt(pivotCount),
                Nd4j.getShapeInfoProvider().createShapeInformation(new long[]{1, pivotCount}, iNDArray.dataType()).getFirst());

        final DataType bufferType = iNDArray.data().dataType();
        if (bufferType == DataType.DOUBLE) {
            dgetrf(m, n, iNDArray, pivots, info);
        } else if (bufferType == DataType.FLOAT) {
            sgetrf(m, n, iNDArray, pivots, info);
        } else {
            throw new UnsupportedOperationException();
        }

        final int status = info.getInt(0);
        if (status < 0) {
            throw new Error("Parameter #" + status + " to getrf() was not valid");
        }
        if (status > 0) {
            log.warn("The matrix is singular - cannot be used for inverse op. Check L matrix at row " + status);
        }
        return pivots;
    }

    /**
     * Single-precision hook invoked by {@link #getrf(INDArray)} when the matrix buffer is FLOAT.
     *
     * @param i         row count of the matrix, as passed by {@code getrf}
     * @param i2        column count of the matrix, as passed by {@code getrf}
     * @param iNDArray  the matrix given to {@code getrf}
     * @param iNDArray2 int-backed array with {@code min(i, i2)} entries; {@code getrf} returns it to
     *                  its caller after this hook completes
     * @param iNDArray3 one-element int-backed status array; {@code getrf} reads element 0 afterwards
     *                  (negative: throws an {@link Error}, positive: logs a warning)
     */
    public abstract void sgetrf(int i, int i2, INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3);

    /**
     * Double-precision hook invoked by {@link #getrf(INDArray)} when the matrix buffer is DOUBLE.
     *
     * @param i         row count of the matrix, as passed by {@code getrf}
     * @param i2        column count of the matrix, as passed by {@code getrf}
     * @param iNDArray  the matrix given to {@code getrf}
     * @param iNDArray2 int-backed array with {@code min(i, i2)} entries; {@code getrf} returns it to
     *                  its caller after this hook completes
     * @param iNDArray3 one-element int-backed status array; {@code getrf} reads element 0 afterwards
     *                  (negative: throws an {@link Error}, positive: logs a warning)
     */
    public abstract void dgetrf(int i, int i2, INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3);

    /**
     * Cholesky-style factorisation entry point. Chooses the triangle selector byte, allocates a
     * one-element int-backed status array and forwards to {@link #dpotrf} or {@link #spotrf}
     * according to the data type of the matrix buffer.
     *
     * <p>Any non-zero status written by the hook is fatal here: a negative value is reported as an
     * invalid parameter, a positive value as a matrix that is not positive definite.
     *
     * @param iNDArray the matrix handed to the factorisation hook; its column count is passed as
     *                 the order
     * @param z        {@code true} selects {@code 'L'}, {@code false} selects {@code 'U'} as the
     *                 selector byte given to the hook
     * @throws ND4JArraySizeException if the column count exceeds {@code Integer.MAX_VALUE}
     * @throws UnsupportedOperationException if the buffer is neither DOUBLE nor FLOAT
     * @throws Error if the hook reports a non-zero status
     */
    @Override // org.nd4j.linalg.api.blas.Lapack
    public void potrf(INDArray iNDArray, boolean z) {
        if (iNDArray.columns() > Integer.MAX_VALUE) {
            throw new ND4JArraySizeException();
        }
        final byte triangle;
        if (z) {
            triangle = (byte) 'L';
        } else {
            triangle = (byte) 'U';
        }
        final int order = iNDArray.columns();

        // Single int slot that receives the result code from the hook.
        INDArray info = Nd4j.createArrayFromShapeBuffer(
                Nd4j.getDataBufferFactory().createInt(1L),
                Nd4j.getShapeInfoProvider().createShapeInformation(new long[]{1, 1}, iNDArray.dataType()).getFirst());

        final DataType bufferType = iNDArray.data().dataType();
        if (bufferType == DataType.DOUBLE) {
            dpotrf(triangle, order, iNDArray, info);
        } else if (bufferType == DataType.FLOAT) {
            spotrf(triangle, order, iNDArray, info);
        } else {
            throw new UnsupportedOperationException();
        }

        final int status = info.getInt(0);
        if (status < 0) {
            throw new Error("Parameter #" + status + " to potrf() was not valid");
        }
        if (status > 0) {
            throw new Error("The matrix is not positive definite! (potrf fails @ order " + status + ")");
        }
    }

    /**
     * Single-precision hook invoked by {@link #potrf(INDArray, boolean)} when the matrix buffer is
     * FLOAT.
     *
     * @param b         selector byte chosen by {@code potrf}: 76 ({@code 'L'}) when its flag is
     *                  {@code true}, otherwise 85 ({@code 'U'})
     * @param i         column count of the matrix, as passed by {@code potrf}
     * @param iNDArray  the matrix given to {@code potrf}
     * @param iNDArray2 one-element int-backed status array; {@code potrf} throws an {@link Error}
     *                  if element 0 is non-zero afterwards
     */
    public abstract void spotrf(byte b, int i, INDArray iNDArray, INDArray iNDArray2);

    /**
     * Double-precision hook invoked by {@link #potrf(INDArray, boolean)} when the matrix buffer is
     * DOUBLE.
     *
     * @param b         selector byte chosen by {@code potrf}: 76 ({@code 'L'}) when its flag is
     *                  {@code true}, otherwise 85 ({@code 'U'})
     * @param i         column count of the matrix, as passed by {@code potrf}
     * @param iNDArray  the matrix given to {@code potrf}
     * @param iNDArray2 one-element int-backed status array; {@code potrf} throws an {@link Error}
     *                  if element 0 is non-zero afterwards
     */
    public abstract void dpotrf(byte b, int i, INDArray iNDArray, INDArray iNDArray2);

    /**
     * QR-factorisation entry point. Checks that {@code iNDArray2} is square with a side equal to
     * the column count of {@code iNDArray}, allocates a one-element int-backed status array and
     * forwards to {@link #dgeqrf} or {@link #sgeqrf} according to the data type of the matrix
     * buffer.
     *
     * <p>A negative status written by the hook is reported as an {@link Error} naming
     * {@code geqrf()}.
     *
     * @param iNDArray  the matrix handed to the factorisation hook
     * @param iNDArray2 the N x N array passed to the hook as the R operand, where N is the column
     *                  count of {@code iNDArray}
     * @throws ND4JArraySizeException if a dimension exceeds {@code Integer.MAX_VALUE}
     * @throws UnsupportedOperationException if the buffer is neither DOUBLE nor FLOAT
     * @throws Error if {@code iNDArray2} has the wrong shape or the hook reports a negative status
     */
    @Override // org.nd4j.linalg.api.blas.Lapack
    public void geqrf(INDArray iNDArray, INDArray iNDArray2) {
        if (iNDArray.rows() > Integer.MAX_VALUE || iNDArray.columns() > Integer.MAX_VALUE) {
            throw new ND4JArraySizeException();
        }
        final int m = iNDArray.rows();
        final int n = iNDArray.columns();

        // Reject a mis-shaped R before allocating anything for the call.
        if (iNDArray2.rows() != n || iNDArray2.columns() != n) {
            throw new Error("geqrf: R must be N x N (n = columns in A)");
        }

        // Single int slot that receives the result code from the hook.
        INDArray info = Nd4j.createArrayFromShapeBuffer(
                Nd4j.getDataBufferFactory().createInt(1L),
                Nd4j.getShapeInfoProvider().createShapeInformation(new long[]{1, 1}, iNDArray.dataType()).getFirst());

        final DataType bufferType = iNDArray.data().dataType();
        if (bufferType == DataType.DOUBLE) {
            dgeqrf(m, n, iNDArray, iNDArray2, info);
        } else if (bufferType == DataType.FLOAT) {
            sgeqrf(m, n, iNDArray, iNDArray2, info);
        } else {
            throw new UnsupportedOperationException();
        }

        final int status = info.getInt(0);
        if (status < 0) {
            // The message previously named getrf(), which pointed readers at the wrong routine.
            throw new Error("Parameter #" + status + " to geqrf() was not valid");
        }
    }

    /**
     * Single-precision hook invoked by {@link #geqrf(INDArray, INDArray)} when the matrix buffer is
     * FLOAT.
     *
     * @param i         row count of the matrix, as passed by {@code geqrf}
     * @param i2        column count of the matrix, as passed by {@code geqrf}
     * @param iNDArray  the matrix given to {@code geqrf}
     * @param iNDArray2 the second argument of {@code geqrf}; already verified there to be
     *                  {@code i2 x i2}
     * @param iNDArray3 one-element int-backed status array; {@code geqrf} throws an {@link Error}
     *                  if element 0 is negative afterwards
     */
    public abstract void sgeqrf(int i, int i2, INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3);

    /**
     * Double-precision hook invoked by {@link #geqrf(INDArray, INDArray)} when the matrix buffer is
     * DOUBLE.
     *
     * @param i         row count of the matrix, as passed by {@code geqrf}
     * @param i2        column count of the matrix, as passed by {@code geqrf}
     * @param iNDArray  the matrix given to {@code geqrf}
     * @param iNDArray2 the second argument of {@code geqrf}; already verified there to be
     *                  {@code i2 x i2}
     * @param iNDArray3 one-element int-backed status array; {@code geqrf} throws an {@link Error}
     *                  if element 0 is negative afterwards
     */
    public abstract void dgeqrf(int i, int i2, INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3);

    /**
     * Eigen-solver entry point for a square matrix. Validates the operands and forwards them to
     * {@link #dsyev} or {@link #ssyev} according to the data type of the matrix buffer, returning
     * whatever status the hook returns.
     *
     * @param c         first option character, passed through to the hook untouched
     * @param c2        second option character, passed through to the hook untouched
     * @param iNDArray  the square matrix handed to the hook
     * @param iNDArray2 array whose length must equal the row count of {@code iNDArray}
     * @return the value returned by the precision-specific hook
     * @throws Error if {@code iNDArray} is not square or {@code iNDArray2} has the wrong length
     * @throws ND4JArraySizeException if a dimension exceeds {@code Integer.MAX_VALUE}
     * @throws UnsupportedOperationException if the buffer is neither DOUBLE nor FLOAT
     */
    @Override // org.nd4j.linalg.api.blas.Lapack
    public int syev(char c, char c2, INDArray iNDArray, INDArray iNDArray2) {
        final int order = iNDArray.rows();
        if (order != iNDArray.columns()) {
            throw new Error("syev: A must be square.");
        }
        if (order != iNDArray2.length()) {
            throw new Error("syev: V must be the length of the matrix dimension.");
        }
        if (iNDArray.rows() > Integer.MAX_VALUE || iNDArray.columns() > Integer.MAX_VALUE) {
            throw new ND4JArraySizeException();
        }

        final DataType bufferType = iNDArray.data().dataType();
        if (bufferType == DataType.DOUBLE) {
            return dsyev(c, c2, order, iNDArray, iNDArray2);
        }
        if (bufferType == DataType.FLOAT) {
            return ssyev(c, c2, order, iNDArray, iNDArray2);
        }
        throw new UnsupportedOperationException();
    }

    /**
     * Single-precision hook invoked by {@link #syev(char, char, INDArray, INDArray)} when the
     * matrix buffer is FLOAT.
     *
     * @param c         first option character, forwarded unchanged from {@code syev}
     * @param c2        second option character, forwarded unchanged from {@code syev}
     * @param i         row count of the (square) matrix, as passed by {@code syev}
     * @param iNDArray  the matrix given to {@code syev}
     * @param iNDArray2 the second array given to {@code syev}; its length equals {@code i}
     * @return a status value that {@code syev} returns to its caller unchanged
     */
    public abstract int ssyev(char c, char c2, int i, INDArray iNDArray, INDArray iNDArray2);

    /**
     * Double-precision hook invoked by {@link #syev(char, char, INDArray, INDArray)} when the
     * matrix buffer is DOUBLE.
     *
     * @param c         first option character, forwarded unchanged from {@code syev}
     * @param c2        second option character, forwarded unchanged from {@code syev}
     * @param i         row count of the (square) matrix, as passed by {@code syev}
     * @param iNDArray  the matrix given to {@code syev}
     * @param iNDArray2 the second array given to {@code syev}; its length equals {@code i}
     * @return a status value that {@code syev} returns to its caller unchanged
     */
    public abstract int dsyev(char c, char c2, int i, INDArray iNDArray, INDArray iNDArray2);

    /**
     * Singular-value-decomposition entry point. Derives one job byte per optional output
     * ({@code 'A'} when the output array is supplied, {@code 'N'} when it is {@code null}),
     * allocates a one-element INT status array and forwards to {@link #dgesvd} or {@link #sgesvd}
     * according to the data type of the matrix buffer.
     *
     * <p>A negative status from the hook is reported as an {@link Error}; a positive status only
     * produces a warning in the log.
     *
     * @param iNDArray  the matrix handed to the hook
     * @param iNDArray2 array passed to the hook right after the matrix
     * @param iNDArray3 optional output; {@code null} switches the first job byte to {@code 'N'}
     * @param iNDArray4 optional output; {@code null} switches the second job byte to {@code 'N'}
     * @throws ND4JArraySizeException if a dimension exceeds {@code Integer.MAX_VALUE}
     * @throws UnsupportedOperationException if the buffer is neither DOUBLE nor FLOAT
     * @throws Error if the hook reports a negative status
     */
    @Override // org.nd4j.linalg.api.blas.Lapack
    public void gesvd(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, INDArray iNDArray4) {
        if (iNDArray.rows() > Integer.MAX_VALUE || iNDArray.columns() > Integer.MAX_VALUE) {
            throw new ND4JArraySizeException();
        }
        final int m = iNDArray.rows();
        final int n = iNDArray.columns();

        // 'A' requests the corresponding output, 'N' skips it.
        final byte firstJob = (byte) (iNDArray3 != null ? 'A' : 'N');
        final byte secondJob = (byte) (iNDArray4 != null ? 'A' : 'N');

        // Single int slot that receives the result code from the hook.
        INDArray info = Nd4j.createArrayFromShapeBuffer(
                Nd4j.getDataBufferFactory().createInt(1L),
                Nd4j.getShapeInfoProvider().createShapeInformation(new long[]{1, 1}, DataType.INT).getFirst());

        final DataType bufferType = iNDArray.data().dataType();
        if (bufferType == DataType.DOUBLE) {
            dgesvd(firstJob, secondJob, m, n, iNDArray, iNDArray2, iNDArray3, iNDArray4, info);
        } else if (bufferType == DataType.FLOAT) {
            sgesvd(firstJob, secondJob, m, n, iNDArray, iNDArray2, iNDArray3, iNDArray4, info);
        } else {
            throw new UnsupportedOperationException();
        }

        final int status = info.getInt(0);
        if (status < 0) {
            throw new Error("Parameter #" + status + " to gesvd() was not valid");
        }
        if (status > 0) {
            log.warn("The matrix contains singular elements. Check S matrix at row " + status);
        }
    }

    /**
     * Single-precision hook invoked by {@link #gesvd(INDArray, INDArray, INDArray, INDArray)} when
     * the matrix buffer is FLOAT.
     *
     * @param b         first job byte: 65 ({@code 'A'}) when {@code iNDArray3} is non-null,
     *                  otherwise 78 ({@code 'N'})
     * @param b2        second job byte: 65 ({@code 'A'}) when {@code iNDArray4} is non-null,
     *                  otherwise 78 ({@code 'N'})
     * @param i         row count of the matrix, as passed by {@code gesvd}
     * @param i2        column count of the matrix, as passed by {@code gesvd}
     * @param iNDArray  the matrix given to {@code gesvd}
     * @param iNDArray2 the second argument of {@code gesvd}, forwarded unchanged
     * @param iNDArray3 the third argument of {@code gesvd}; may be {@code null}
     * @param iNDArray4 the fourth argument of {@code gesvd}; may be {@code null}
     * @param iNDArray5 one-element INT status array; {@code gesvd} reads element 0 afterwards
     *                  (negative: throws an {@link Error}, positive: logs a warning)
     */
    public abstract void sgesvd(byte b, byte b2, int i, int i2, INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, INDArray iNDArray4, INDArray iNDArray5);

    /**
     * Double-precision hook invoked by {@link #gesvd(INDArray, INDArray, INDArray, INDArray)} when
     * the matrix buffer is DOUBLE.
     *
     * @param b         first job byte: 65 ({@code 'A'}) when {@code iNDArray3} is non-null,
     *                  otherwise 78 ({@code 'N'})
     * @param b2        second job byte: 65 ({@code 'A'}) when {@code iNDArray4} is non-null,
     *                  otherwise 78 ({@code 'N'})
     * @param i         row count of the matrix, as passed by {@code gesvd}
     * @param i2        column count of the matrix, as passed by {@code gesvd}
     * @param iNDArray  the matrix given to {@code gesvd}
     * @param iNDArray2 the second argument of {@code gesvd}, forwarded unchanged
     * @param iNDArray3 the third argument of {@code gesvd}; may be {@code null}
     * @param iNDArray4 the fourth argument of {@code gesvd}; may be {@code null}
     * @param iNDArray5 one-element INT status array; {@code gesvd} reads element 0 afterwards
     *                  (negative: throws an {@link Error}, positive: logs a warning)
     */
    public abstract void dgesvd(byte b, byte b2, int i, int i2, INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, INDArray iNDArray4, INDArray iNDArray5);

    /**
     * Builds a permutation matrix from a pivot-index array. Starts from an {@code i x i} identity
     * matrix and, walking the pivots in order, swaps column {@code k} with column
     * {@code pivot - 1} whenever that one-based pivot points past position {@code k}.
     *
     * @param i        side length of the identity matrix to start from
     * @param iNDArray pivot indices, read as one-based ints
     * @return the identity matrix with the column swaps applied
     */
    @Override // org.nd4j.linalg.api.blas.Lapack
    public INDArray getPFactor(int i, INDArray iNDArray) {
        INDArray permutation = Nd4j.eye(i);
        final long pivotCount = iNDArray.length();
        for (int k = 0; k < pivotCount; k++) {
            // Pivots are stored one-based; convert to a zero-based column index.
            final int target = iNDArray.getInt(k) - 1;
            if (target <= k) {
                continue;
            }
            // Copy column k first, because putColumn below overwrites it.
            INDArray saved = permutation.getColumn(k).dup();
            permutation.putColumn(k, permutation.getColumn(target));
            permutation.putColumn(target, saved);
        }
        return permutation;
    }

    /**
     * Extracts the lower factor from a combined factorisation result. The returned array has the
     * same row and column counts as {@code iNDArray}: entries strictly below the diagonal are
     * copied from the input, entries strictly above it are set to 0, and diagonal entries are set
     * to 1.
     *
     * <p>Elements are read with {@code getDouble} so that double-precision inputs are not rounded
     * through {@code float} on the way into the result.
     * NOTE(review): the result's own data type is whatever {@code Nd4j.create(rows, columns)}
     * yields, which is not determined in this class.
     *
     * @param iNDArray the combined factor matrix to read from
     * @return a newly created {@code rows x columns} array holding the lower factor
     * @throws ND4JArraySizeException if a dimension exceeds {@code Integer.MAX_VALUE}
     */
    @Override // org.nd4j.linalg.api.blas.Lapack
    public INDArray getLFactor(INDArray iNDArray) {
        if (iNDArray.rows() > Integer.MAX_VALUE || iNDArray.columns() > Integer.MAX_VALUE) {
            throw new ND4JArraySizeException();
        }
        final int rows = iNDArray.rows();
        final int columns = iNDArray.columns();
        INDArray lower = Nd4j.create(rows, columns);
        for (int row = 0; row < rows; row++) {
            for (int col = 0; col < columns; col++) {
                if (row > col) {
                    // Strictly below the diagonal: take the stored multiplier as-is.
                    lower.putScalar(row, col, iNDArray.getDouble(row, col));
                } else if (row < col) {
                    lower.putScalar(row, col, 0.0d);
                } else {
                    // Diagonal of the lower factor is fixed at one.
                    lower.putScalar(row, col, 1.0d);
                }
            }
        }
        return lower;
    }

    /**
     * Extracts the upper factor from a combined factorisation result. The returned array is
     * {@code columns x columns}: entries on or above the diagonal that fall inside the input's row
     * range are copied from the input, and every other entry is set to 0.
     *
     * <p>Elements are read with {@code getDouble} so that double-precision inputs are not rounded
     * through {@code float} on the way into the result.
     * NOTE(review): the result's own data type is whatever {@code Nd4j.create(columns, columns)}
     * yields, which is not determined in this class.
     *
     * @param iNDArray the combined factor matrix to read from
     * @return a newly created {@code columns x columns} array holding the upper factor
     * @throws ND4JArraySizeException if a dimension exceeds {@code Integer.MAX_VALUE}
     */
    @Override // org.nd4j.linalg.api.blas.Lapack
    public INDArray getUFactor(INDArray iNDArray) {
        if (iNDArray.rows() > Integer.MAX_VALUE || iNDArray.columns() > Integer.MAX_VALUE) {
            throw new ND4JArraySizeException();
        }
        final int rows = iNDArray.rows();
        final int columns = iNDArray.columns();
        INDArray upper = Nd4j.create(columns, columns);
        for (int row = 0; row < columns; row++) {
            for (int col = 0; col < columns; col++) {
                // Below the diagonal, or past the last input row when rows < columns: no source value.
                if (row > col || row >= rows) {
                    upper.putScalar(row, col, 0.0d);
                } else {
                    upper.putScalar(row, col, iNDArray.getDouble(row, col));
                }
            }
        }
        return upper;
    }
}
