package net.jamu.matrix;

import java.util.Arrays;
import net.frobenius.TTrans;
import net.frobenius.lapack.PlainLapack;

/* loaded from: input_file:net/jamu/matrix/SimpleMatrixF.class */
/**
 * Dense single-precision {@link MatrixF} whose matrix products, linear solves
 * and decompositions are delegated to BLAS ({@code Matrices.getBlas()}) and
 * LAPACK ({@code PlainLapack} / {@code Matrices.getLapack()}).
 * <p>
 * Every BLAS/LAPACK call in this class passes the row count as the leading
 * dimension, i.e. the backing array {@code a} is treated as column-major.
 */
public class SimpleMatrixF extends MatrixFBase implements MatrixF {

    /** GEMM beta factor of 1: products are accumulated into the target, not written over it. */
    private static final float BETA = 1.0f;

    /**
     * Creates a {@code rows x cols} matrix over a freshly allocated (hence
     * zero-initialized) array whose length is obtained from
     * {@code Checks.checkArrayLength(rows, cols)}.
     *
     * @param rows number of rows
     * @param cols number of columns
     */
    public SimpleMatrixF(int rows, int cols) {
        this(rows, cols, new float[Checks.checkArrayLength(rows, cols)]);
    }

    /**
     * Creates a {@code rows x cols} matrix with every element set to {@code initialValue}.
     *
     * @param rows number of rows
     * @param cols number of columns
     * @param initialValue value written into every element
     */
    public SimpleMatrixF(int rows, int cols, float initialValue) {
        this(rows, cols);
        Arrays.fill(a, initialValue);
    }

    /**
     * Copy constructor backing {@link #copy()}.
     * <p>
     * NOTE(review): the trailing {@code true} presumably asks MatrixFBase to
     * clone {@code other.a} instead of sharing it (the other constructors pass
     * {@code false} together with arrays that need no cloning) — confirm
     * against MatrixFBase.
     *
     * @param other matrix whose dimensions and backing array are passed to MatrixFBase
     */
    private SimpleMatrixF(SimpleMatrixF other) {
        super(other.rows, other.cols, other.a, true);
    }

    /**
     * Creates a {@code rows x cols} matrix over {@code data}, handing the array
     * to MatrixFBase with the flag {@code false} (the same flag the other public
     * constructors use for freshly allocated arrays).
     * <p>
     * NOTE(review): the decompiler reported this constructor was originally
     * {@code protected}; it is kept {@code public} so existing callers still compile.
     *
     * @param rows number of rows
     * @param cols number of columns
     * @param data backing element array
     */
    public SimpleMatrixF(int rows, int cols, float[] data) {
        super(rows, cols, data, false);
    }

    /** Factory hook for MatrixFBase: a new {@code rows x cols} SimpleMatrixF. */
    @Override
    protected MatrixF create(int rows, int cols) {
        return new SimpleMatrixF(rows, cols);
    }

    /** Factory hook for MatrixFBase: a new {@code rows x cols} SimpleMatrixF over {@code data}. */
    @Override
    protected MatrixF create(int rows, int cols, float[] data) {
        return new SimpleMatrixF(rows, cols, data);
    }

    /**
     * Computes {@code C = alpha * A * B + C}, where {@code A} is this matrix.
     *
     * @param alpha scale factor applied to the product
     * @param b the factor {@code B}
     * @param c the accumulator {@code C}, updated in place
     * @return {@code c}
     */
    @Override
    public MatrixF multAdd(float alpha, MatrixF b, MatrixF c) {
        Checks.checkMultAdd(this, b, c);
        return gemm(TTrans.NO_TRANS, TTrans.NO_TRANS, cols, alpha, b, c);
    }

    /**
     * Computes {@code C = alpha * A^T * B^T + C}, where {@code A} is this matrix.
     *
     * @param alpha scale factor applied to the product
     * @param b the factor {@code B} (used transposed)
     * @param c the accumulator {@code C}, updated in place
     * @return {@code c}
     */
    @Override
    public MatrixF transABmultAdd(float alpha, MatrixF b, MatrixF c) {
        Checks.checkTransABmultAdd(this, b, c);
        return gemm(TTrans.TRANS, TTrans.TRANS, rows, alpha, b, c);
    }

    /**
     * Computes {@code C = alpha * A^T * B + C}, where {@code A} is this matrix.
     *
     * @param alpha scale factor applied to the product
     * @param b the factor {@code B}
     * @param c the accumulator {@code C}, updated in place
     * @return {@code c}
     */
    @Override
    public MatrixF transAmultAdd(float alpha, MatrixF b, MatrixF c) {
        Checks.checkTransAmultAdd(this, b, c);
        return gemm(TTrans.TRANS, TTrans.NO_TRANS, rows, alpha, b, c);
    }

    /**
     * Computes {@code C = alpha * A * B^T + C}, where {@code A} is this matrix.
     *
     * @param alpha scale factor applied to the product
     * @param b the factor {@code B} (used transposed)
     * @param c the accumulator {@code C}, updated in place
     * @return {@code c}
     */
    @Override
    public MatrixF transBmultAdd(float alpha, MatrixF b, MatrixF c) {
        Checks.checkTransBmultAdd(this, b, c);
        return gemm(TTrans.NO_TRANS, TTrans.TRANS, cols, alpha, b, c);
    }

    /**
     * Issues a single {@code sgemm}: {@code C = alpha * op(A) * op(B) + BETA * C}
     * with this matrix as {@code A}; the result dimensions are taken from {@code c}.
     *
     * @param transA operation applied to this matrix
     * @param transB operation applied to {@code b}
     * @param depth inner (summation) dimension {@code k} of the product
     * @param alpha scale factor applied to the product
     * @param b the factor {@code B}
     * @param c the accumulator {@code C}, updated in place
     * @return {@code c}
     */
    private MatrixF gemm(TTrans transA, TTrans transB, int depth, float alpha, MatrixF b, MatrixF c) {
        final int ldA = Math.max(1, rows);
        final int ldB = Math.max(1, b.numRows());
        final int ldC = Math.max(1, c.numRows());
        Matrices.getBlas().sgemm(transA.val(), transB.val(), c.numRows(), c.numColumns(), depth,
                alpha, a, ldA, b.getArrayUnsafe(), ldB, BETA, c.getArrayUnsafe(), ldC);
        return c;
    }

    /**
     * Solves {@code A * X = B} for {@code X}, with {@code A} being this matrix.
     * Square matrices go through an LU-based solve ({@code sgesv}); all other
     * shapes go through a QR-based least-squares solve ({@code sgels}).
     *
     * @param b the right-hand side {@code B}
     * @param x receives the solution {@code X}
     * @return {@code x}
     */
    @Override
    public MatrixF solve(MatrixF b, MatrixF x) {
        Checks.checkSolve(this, b, x);
        if (isSquareMatrix()) {
            return lusolve(this, x, b);
        }
        return qrsolve(this, x, b);
    }

    /**
     * Builds a singular value decomposition of this matrix.
     *
     * @param full flag forwarded unchanged to the {@code SvdF} constructor
     * @return a new {@code SvdF}
     */
    @Override
    public SvdF svd(boolean full) {
        return new SvdF(this, full);
    }

    /** Builds the economy-size singular value decomposition ({@code SvdEconF}) of this matrix. */
    @Override
    public SvdEconF svdEcon() {
        return new SvdEconF(this);
    }

    /**
     * Builds an eigenvalue decomposition of this matrix.
     *
     * @param full flag forwarded unchanged to the {@code EvdF} constructor
     * @return a new {@code EvdF}
     * @throws IllegalArgumentException if this matrix is not square
     */
    @Override
    public EvdF evd(boolean full) {
        if (!isSquareMatrix()) {
            throw new IllegalArgumentException("EVD only works for square matrices");
        }
        return new EvdF(this, full);
    }

    /** Builds the QR decomposition ({@code QrdF}) of this matrix. */
    @Override
    public QrdF qrd() {
        return new QrdF(this);
    }

    /** Builds the LU decomposition ({@code LudF}) of this matrix. */
    @Override
    public LudF lud() {
        return new LudF(this);
    }

    /**
     * Returns the 2-norm reported by an {@code SvdF} of this matrix
     * (constructed with the flag {@code false}).
     */
    @Override
    public float norm2() {
        SvdF decomposition = new SvdF(this, false);
        return decomposition.norm2();
    }

    /**
     * LU solve of a square system via LAPACK {@code sgesv}.
     *
     * @param coeffs the square coefficient matrix {@code A}; left untouched (a clone is factored)
     * @param solution receives the solution; it is first loaded with {@code rhs}
     *                 because {@code sgesv} solves in place in its {@code B} argument
     * @param rhs the right-hand side {@code B}
     * @return {@code solution}
     */
    private static MatrixF lusolve(MatrixF coeffs, MatrixF solution, MatrixF rhs) {
        solution.setInplace(rhs);
        final int n = coeffs.numRows();
        // sgesv overwrites its coefficient array with the LU factors, so hand it a copy
        final float[] factors = coeffs.getArrayUnsafe().clone();
        final int[] pivots = new int[n];
        PlainLapack.sgesv(Matrices.getLapack(), n, rhs.numColumns(), factors, Math.max(1, n), pivots,
                solution.getArrayUnsafe(), Math.max(1, n));
        return solution;
    }

    /**
     * QR-based least-squares / minimum-norm solve via LAPACK {@code sgels}.
     *
     * @param coeffs the {@code m x n} coefficient matrix {@code A}; left untouched (a clone is used)
     * @param solution receives the {@code n x nrhs} solution
     * @param rhs the {@code m x nrhs} right-hand side {@code B}
     * @return {@code solution}
     */
    private static MatrixF qrsolve(MatrixF coeffs, MatrixF solution, MatrixF rhs) {
        final int nrhs = rhs.numColumns();
        final int m = coeffs.numRows();
        final int n = coeffs.numColumns();
        final int workRows = Math.max(m, n);
        // sgels reads the m-row right-hand side from, and writes the n-row solution
        // into, the same array, so that array needs max(m, n) rows
        SimpleMatrixF work = new SimpleMatrixF(workRows, nrhs);
        copyBlock(rhs, work, m, nrhs);
        PlainLapack.sgels(Matrices.getLapack(), TTrans.NO_TRANS, m, n, nrhs, coeffs.getArrayUnsafe().clone(),
                Math.max(1, m), work.getArrayUnsafe(), Math.max(1, workRows));
        copyBlock(work, solution, n, nrhs);
        return solution;
    }

    /** Copies the top-left {@code rowCount x colCount} elements of {@code from} into {@code to}, column by column. */
    private static void copyBlock(MatrixF from, MatrixF to, int rowCount, int colCount) {
        for (int col = 0; col < colCount; ++col) {
            for (int row = 0; row < rowCount; ++row) {
                to.setUnsafe(row, col, from.getUnsafe(row, col));
            }
        }
    }

    /** Returns a new SimpleMatrixF created through the private copy constructor. */
    @Override
    public MatrixF copy() {
        return new SimpleMatrixF(this);
    }
}
