package net.jamu.matrix;

import java.util.Arrays;
import net.frobenius.TTrans;

/* loaded from: input_file:net/jamu/matrix/TensorF.class */
/**
 * A dense, float-valued tensor: a stack of {@code depth} matrices ("slices") of identical
 * {@code rows x cols} shape, stored contiguously in a single backing array {@link #a}.
 * <p>
 * Slice {@code k} starts at {@code startIdx(k)} and occupies {@code stride()} consecutive
 * elements. The BLAS calls below pass {@code rows} as the leading dimension, which suggests
 * column-major storage inside a slice (NOTE(review): the actual coordinate mapping lives in
 * {@code TensorBase#idx} — confirm there).
 * <p>
 * The {@code *multAdd} methods accumulate into the destination ({@code C = alpha*op(A)*op(B) + C});
 * the {@code *mult} methods zero the destination first. All slice-wise operations are carried out
 * over the smallest depth among the participating tensors.
 */
public class TensorF extends TensorBase {
    /** Multiplier applied to the existing contents of C in every BLAS call (1 = accumulate). */
    private static final float BETA = 1.0f;
    /** Default multiplier for op(A) * op(B) used by the overloads without an explicit alpha. */
    private static final float ALPHA = 1.0f;
    /** Start offset used for every array handed to System.arraycopy and BLAS (arrays are never offset). */
    private static final int OFFS = 0;

    /** Backing storage for all slices. */
    protected float[] a;

    /**
     * Creates a zero-filled tensor of depth 1.
     *
     * @param rows number of rows of each slice
     * @param cols number of columns of each slice
     */
    public TensorF(int rows, int cols) {
        this(rows, cols, 1);
    }

    /**
     * Creates a zero-filled tensor.
     *
     * @param rows  number of rows of each slice
     * @param cols  number of columns of each slice
     * @param depth number of slices
     */
    public TensorF(int rows, int cols, int depth) {
        super(rows, cols, depth);
        // 'length' is computed by TensorBase from the dimensions passed to super(...)
        this.a = new float[this.length];
    }

    /**
     * Creates a tensor of depth 1 holding a copy of the given matrix's data.
     *
     * @param matrix the matrix to copy; its array is copied, not shared
     */
    public TensorF(MatrixF matrix) {
        super(matrix.numRows(), matrix.numColumns(), 1);
        float[] src = matrix.getArrayUnsafe();
        this.a = Arrays.copyOf(src, src.length);
    }

    /**
     * Copy constructor: creates a deep copy of {@code other}.
     *
     * @param other the tensor to copy
     */
    public TensorF(TensorF other) {
        super(other.rows, other.cols, other.depth);
        this.a = Arrays.copyOf(other.a, other.a.length);
    }

    /**
     * Sets element ({@code row}, {@code col}) of slice {@code slice} after validating the indices
     * with {@code checkIndex}.
     *
     * @return this tensor
     */
    public TensorF set(int row, int col, int slice, float value) {
        checkIndex(row, col, slice);
        return setUnsafe(row, col, slice, value);
    }

    /**
     * Returns element ({@code row}, {@code col}) of slice {@code slice} after validating the
     * indices with {@code checkIndex}.
     */
    public float get(int row, int col, int slice) {
        checkIndex(row, col, slice);
        return getUnsafe(row, col, slice);
    }

    /**
     * Sets element ({@code row}, {@code col}) of slice {@code slice} without index validation.
     *
     * @return this tensor
     */
    public TensorF setUnsafe(int row, int col, int slice, float value) {
        this.a[idx(row, col, slice)] = value;
        return this;
    }

    /**
     * Returns element ({@code row}, {@code col}) of slice {@code slice} without index validation.
     */
    public float getUnsafe(int row, int col, int slice) {
        return this.a[idx(row, col, slice)];
    }

    /**
     * Overwrites slice {@code k} with a copy of {@code matrix}. The matrix must have this tensor's
     * row and column counts.
     *
     * @return this tensor
     */
    public TensorF set(MatrixF matrix, int k) {
        Checks.checkEqualDimension(this, matrix);
        int start = startIdx(k);
        float[] src = matrix.getArrayUnsafe();
        System.arraycopy(src, OFFS, this.a, start, src.length);
        return this;
    }

    /**
     * Returns a copy of slice {@code k} as a new {@code rows x cols} matrix.
     */
    public MatrixF get(int k) {
        int start = startIdx(k);
        int sliceLen = stride();
        float[] copy = new float[sliceLen];
        System.arraycopy(this.a, start, copy, OFFS, sliceLen);
        return new SimpleMatrixF(this.rows, this.cols, copy);
    }

    /**
     * Appends a copy of {@code matrix} as a new last slice. This grows the backing array and
     * increments {@code depth}. The matrix must have this tensor's row and column counts.
     *
     * @return this tensor
     */
    public TensorF append(MatrixF matrix) {
        Checks.checkEqualDimension(this, matrix);
        float[] grown = growAndCopyForAppend(matrix);
        float[] src = matrix.getArrayUnsafe();
        // the new slice goes right behind the existing data
        System.arraycopy(src, OFFS, grown, this.length, src.length);
        this.a = grown;
        this.length = grown.length;
        this.depth++;
        return this;
    }

    /**
     * Computes {@code C = alpha * A * B + C} slice by slice, where A is this tensor.
     *
     * @param alpha multiplier for A * B
     * @param b     the right-hand factor B
     * @param c     the accumulator C
     * @return {@code c}
     */
    public TensorF multAdd(float alpha, TensorF b, TensorF c) {
        Checks.checkMultAdd(this, b, c);
        // inner dimension: columns of A
        return batchedGemm(TTrans.NO_TRANS, TTrans.NO_TRANS, this.cols, alpha, b, c);
    }

    /**
     * Computes {@code C = alpha * A<sup>T</sup> * B + C} slice by slice, where A is this tensor.
     *
     * @return {@code c}
     */
    public TensorF transAmultAdd(float alpha, TensorF b, TensorF c) {
        Checks.checkTransAmultAdd(this, b, c);
        // inner dimension: columns of A^T == rows of A
        return batchedGemm(TTrans.TRANS, TTrans.NO_TRANS, this.rows, alpha, b, c);
    }

    /**
     * Computes {@code C = alpha * A * B<sup>T</sup> + C} slice by slice, where A is this tensor.
     *
     * @return {@code c}
     */
    public TensorF transBmultAdd(float alpha, TensorF b, TensorF c) {
        Checks.checkTransBmultAdd(this, b, c);
        return batchedGemm(TTrans.NO_TRANS, TTrans.TRANS, this.cols, alpha, b, c);
    }

    /**
     * Computes {@code C = alpha * A<sup>T</sup> * B<sup>T</sup> + C} slice by slice, where A is
     * this tensor.
     *
     * @return {@code c}
     */
    public TensorF transABmultAdd(float alpha, TensorF b, TensorF c) {
        Checks.checkTransABmultAdd(this, b, c);
        return batchedGemm(TTrans.TRANS, TTrans.TRANS, this.rows, alpha, b, c);
    }

    /** {@code C = A<sup>T</sup> * B<sup>T</sup> + C}. */
    public TensorF transABmultAdd(TensorF b, TensorF c) {
        return transABmultAdd(ALPHA, b, c);
    }

    /** {@code C = alpha * A<sup>T</sup> * B<sup>T</sup>}; C is zeroed first. */
    public TensorF transABmult(float alpha, TensorF b, TensorF c) {
        return transABmultAdd(alpha, b, c.zeroInplace());
    }

    /** {@code C = A<sup>T</sup> * B<sup>T</sup>}; C is zeroed first. */
    public TensorF transABmult(TensorF b, TensorF c) {
        return transABmult(ALPHA, b, c);
    }

    /** {@code C = A * B<sup>T</sup> + C}. */
    public TensorF transBmultAdd(TensorF b, TensorF c) {
        return transBmultAdd(ALPHA, b, c);
    }

    /** {@code C = alpha * A * B<sup>T</sup>}; C is zeroed first. */
    public TensorF transBmult(float alpha, TensorF b, TensorF c) {
        return transBmultAdd(alpha, b, c.zeroInplace());
    }

    /** {@code C = A * B<sup>T</sup>}; C is zeroed first. */
    public TensorF transBmult(TensorF b, TensorF c) {
        return transBmult(ALPHA, b, c);
    }

    /** {@code C = A<sup>T</sup> * B + C}. */
    public TensorF transAmultAdd(TensorF b, TensorF c) {
        return transAmultAdd(ALPHA, b, c);
    }

    /** {@code C = alpha * A<sup>T</sup> * B}; C is zeroed first. */
    public TensorF transAmult(float alpha, TensorF b, TensorF c) {
        return transAmultAdd(alpha, b, c.zeroInplace());
    }

    /** {@code C = A<sup>T</sup> * B}; C is zeroed first. */
    public TensorF transAmult(TensorF b, TensorF c) {
        return transAmult(ALPHA, b, c);
    }

    /** {@code C = A * B + C}. */
    public TensorF multAdd(TensorF b, TensorF c) {
        return multAdd(ALPHA, b, c);
    }

    /** {@code C = alpha * A * B}; C is zeroed first. */
    public TensorF mult(float alpha, TensorF b, TensorF c) {
        return multAdd(alpha, b, c.zeroInplace());
    }

    /** {@code C = A * B}; C is zeroed first. */
    public TensorF mult(TensorF b, TensorF c) {
        return mult(ALPHA, b, c);
    }

    /**
     * Writes the element-wise product of this tensor and {@code b} into {@code c}. All three must
     * share row and column counts. Only the first {@code min(depth)} slices are written; any
     * further elements of {@code c} are left untouched.
     *
     * @return {@code c}
     */
    public TensorF hadamard(TensorF b, TensorF c) {
        Checks.checkEqualDimension(this, b);
        Checks.checkEqualDimension(this, c);
        int count = Math.min(Math.min(this.depth, b.depth), c.depth) * stride();
        float[] left = this.a;
        float[] right = b.getArrayUnsafe();
        float[] out = c.getArrayUnsafe();
        for (int i = 0; i < count; i++) {
            out[i] = left[i] * right[i];
        }
        return c;
    }

    /**
     * Returns a new tensor holding the element-wise product of this tensor and {@code b}. Its depth
     * is {@code min(this.depth, b.depth)}.
     */
    public TensorF hadamard(TensorF b) {
        Checks.checkEqualDimension(this, b);
        return hadamard(b, create(this.rows, this.cols, Math.min(this.depth, b.depth)));
    }

    /**
     * Returns {@code R} with {@code R(r, c, k) = this(r, c, k) * b(c, r, k)}. This is the
     * element-wise product of this tensor with the slice-wise transpose of {@code b}. {@code R}
     * has this tensor's slice shape and depth {@code min(this.depth, b.depth)}.
     */
    public TensorF hadamardTransposed(TensorF b) {
        Checks.checkTrans(this, b);
        int nRows = this.rows;
        int nCols = this.cols;
        int nSlices = Math.min(this.depth, b.depth);
        TensorF result = create(nRows, nCols, nSlices);
        float[] left = this.a;
        float[] right = b.getArrayUnsafe();
        float[] out = result.getArrayUnsafe();
        for (int k = 0; k < nSlices; k++) {
            for (int col = 0; col < nCols; col++) {
                for (int row = 0; row < nRows; row++) {
                    // result has this tensor's row/column shape, so this.idx addresses it too
                    int pos = idx(row, col, k);
                    out[pos] = left[pos] * right[b.idx(col, row, k)];
                }
            }
        }
        return result;
    }

    /**
     * Returns {@code R} with {@code R(r, c, k) = this(c, r, k) * b(r, c, k)}. This is the
     * element-wise product of the slice-wise transpose of this tensor with {@code b}. {@code R}
     * has {@code b}'s slice shape and depth {@code min(this.depth, b.depth)}.
     */
    public TensorF transposedHadamard(TensorF b) {
        Checks.checkTrans(this, b);
        int nRows = b.numRows();
        int nCols = b.numColumns();
        int nSlices = Math.min(this.depth, b.depth);
        TensorF result = create(nRows, nCols, nSlices);
        float[] left = this.a;
        float[] right = b.getArrayUnsafe();
        float[] out = result.getArrayUnsafe();
        for (int k = 0; k < nSlices; k++) {
            for (int col = 0; col < nCols; col++) {
                for (int row = 0; row < nRows; row++) {
                    // result has b's row/column shape, so b.idx addresses it too
                    int pos = b.idx(row, col, k);
                    out[pos] = right[pos] * left[idx(col, row, k)];
                }
            }
        }
        return result;
    }

    /** Returns a new tensor {@code A * B} of depth {@code min(this.depth, b.depth)}. */
    public TensorF times(TensorF b) {
        return mult(b, create(this.rows, b.numColumns(), Math.min(this.depth, b.depth)));
    }

    /** Returns a new tensor {@code A * B<sup>T</sup>} of depth {@code min(this.depth, b.depth)}. */
    public TensorF timesTransposed(TensorF b) {
        return transBmult(b, create(this.rows, b.numRows(), Math.min(this.depth, b.depth)));
    }

    /** Returns a new tensor {@code A<sup>T</sup> * B} of depth {@code min(this.depth, b.depth)}. */
    public TensorF transposedTimes(TensorF b) {
        return transAmult(b, create(this.cols, b.numColumns(), Math.min(this.depth, b.depth)));
    }

    /**
     * Sets every element to zero.
     *
     * @return this tensor
     */
    public TensorF zeroInplace() {
        Arrays.fill(this.a, 0.0f);
        return this;
    }

    /**
     * Multiplies every element by {@code alpha}. {@code alpha == 0} is a plain zero fill (so
     * NaN/Infinity elements become 0), and {@code alpha == 1} is a no-op.
     *
     * @return this tensor
     */
    public TensorF scaleInplace(float alpha) {
        if (alpha == 0.0f) {
            return zeroInplace();
        }
        if (alpha == 1.0f) {
            return this;
        }
        float[] data = this.a;
        for (int i = 0; i < data.length; i++) {
            data[i] *= alpha;
        }
        return this;
    }

    /**
     * Replaces every element {@code x} with {@code fn.apply(x)}.
     *
     * @return this tensor
     */
    public TensorF mapInplace(FFunction fn) {
        float[] data = this.a;
        for (int i = 0; i < data.length; i++) {
            data[i] = fn.apply(data[i]);
        }
        return this;
    }

    /** Returns a new tensor whose elements are {@code fn.apply(x)} of this tensor's elements. */
    public TensorF map(FFunction fn) {
        return new TensorF(this).mapInplace(fn);
    }

    /**
     * Clamps every element into {@code [lower, upper]}. If {@code lower > upper}, every element
     * becomes {@code upper}.
     *
     * @return this tensor
     */
    public TensorF clampInplace(float lower, float upper) {
        float[] data = this.a;
        for (int i = 0; i < data.length; i++) {
            data[i] = Math.min(Math.max(data[i], lower), upper);
        }
        return this;
    }

    /** Returns the backing array itself (not a copy); writes to it modify this tensor. */
    public float[] getArrayUnsafe() {
        return this.a;
    }

    /** Returns a deep copy of this tensor. */
    public TensorF copy() {
        return new TensorF(this);
    }

    /**
     * Randomly permutes the order of the slices in place, using a freshly constructed
     * {@code XoShiRo256StarStar}.
     *
     * @return this tensor
     */
    public TensorF shuffle() {
        return shuffle((XoShiRo256StarStar) null);
    }

    /**
     * Randomly permutes the order of the slices in place, using a generator seeded with
     * {@code seed}.
     *
     * @return this tensor
     */
    public TensorF shuffle(long seed) {
        return shuffle(new XoShiRo256StarStar(seed));
    }

    /**
     * Linearly rescales all elements so that the current minimum maps to {@code lower} and the
     * current maximum maps to {@code upper}. If all elements are equal, every element becomes
     * {@code lower}.
     * <p>
     * The mapping is computed in {@code double}. In {@code float}, {@code max - min} or
     * {@code upper - lower} can overflow to Infinity for wide ranges, which turned the extreme
     * elements into NaN.
     *
     * @return this tensor
     */
    public TensorF rescaleInplace(float lower, float upper) {
        float[] data = this.a;
        float min = Float.MAX_VALUE;
        float max = -Float.MAX_VALUE;
        for (float v : data) {
            if (v < min) {
                min = v;
            }
            if (v > max) {
                max = v;
            }
        }
        double targetSpan = (double) upper - lower;
        // a tiny non-zero divisor keeps the constant case finite: (x - min) is 0 there anyway
        double sourceSpan = (min == max) ? Float.MIN_NORMAL : (double) max - min;
        for (int i = 0; i < data.length; i++) {
            data[i] = (float) (lower + ((data[i] - (double) min) * targetSpan) / sourceSpan);
        }
        return this;
    }

    /**
     * Shared BLAS call for all product variants. For each of the first {@code min(depth)} slices it
     * computes {@code C = alpha * op(A) * op(B) + BETA * C}, where A is this tensor.
     * (NOTE(review): presumably {@code sgemm_multi} is a batched sgemm that advances by the given
     * per-operand strides — confirm in the BLAS wrapper.)
     *
     * @param transA whether A is transposed
     * @param transB whether B is transposed
     * @param k      inner dimension of the product (columns of op(A) == rows of op(B))
     * @param alpha  multiplier for op(A) * op(B)
     * @param b      the right-hand factor B
     * @param c      the accumulator C
     * @return {@code c}
     */
    private TensorF batchedGemm(TTrans transA, TTrans transB, int k, float alpha, TensorF b, TensorF c) {
        int batchCount = Math.min(Math.min(this.depth, b.depth), c.depth);
        Matrices.getBlas().sgemm_multi(transA.val(), transB.val(),
                c.numRows(), c.numColumns(), k,
                alpha,
                this.a, OFFS, leadingDim(this.rows),
                b.getArrayUnsafe(), OFFS, leadingDim(b.numRows()),
                BETA,
                c.getArrayUnsafe(), OFFS, leadingDim(c.numRows()),
                batchCount, stride(), b.stride(), c.stride());
        return c;
    }

    /** BLAS requires a leading dimension of at least 1, even for matrices with zero rows. */
    private static int leadingDim(int rows) {
        return Math.max(1, rows);
    }

    /**
     * Fisher–Yates shuffle of whole slices. This assumes {@code rng.nextInt(n)} returns a value in
     * {@code [0, n)}. A {@code null} generator is replaced by a new default-constructed one.
     */
    private TensorF shuffle(XoShiRo256StarStar rng) {
        int sliceLen = stride();
        float[] data = this.a;
        float[] scratch = new float[sliceLen];
        XoShiRo256StarStar random = (rng != null) ? rng : new XoShiRo256StarStar();
        for (int n = this.depth; n > 1; n--) {
            swap((n - 1) * sliceLen, data, scratch, sliceLen, random.nextInt(n) * sliceLen);
        }
        return this;
    }

    /** Factory for result tensors; all results of this class are plain TensorF instances. */
    private TensorF create(int rows, int cols, int depth) {
        return new TensorF(rows, cols, depth);
    }

    /**
     * Allocates an array large enough for one additional slice (size from
     * {@code checkNewArrayLength}) and copies the current contents into it.
     */
    private float[] growAndCopyForAppend(Dimensions dims) {
        return copyForAppend(new float[checkNewArrayLength(dims)]);
    }

    /** Copies the first {@code length} elements of the backing array into {@code dest}. */
    private float[] copyForAppend(float[] dest) {
        System.arraycopy(this.a, OFFS, dest, OFFS, this.length);
        return dest;
    }

    /**
     * Swaps the {@code len}-element ranges starting at {@code from} and {@code to} in
     * {@code data}, using {@code scratch} (length at least {@code len}) as temporary storage.
     */
    private static void swap(int from, float[] data, float[] scratch, int len, int to) {
        if (from == to) {
            return;
        }
        System.arraycopy(data, from, scratch, OFFS, len);
        System.arraycopy(data, to, data, from, len);
        System.arraycopy(scratch, OFFS, data, to, len);
    }
}
