package org.csstudio.ndarray;

import java.util.Arrays;
import org.epics.util.array.ListNumber;

/* loaded from: input_file:org/csstudio/ndarray/NDMatrix.class */
/**
 * Static helpers that create, reshape and multiply {@link NDArray}s.
 *
 * <p>The method names mirror the NumPy functions of the same name ({@code zeros},
 * {@code ones}, {@code arange}, {@code linspace}, {@code reshape}, {@code transpose},
 * {@code dot}, {@code inner}).
 */
public class NDMatrix {
    /**
     * Create an array of the given type and shape whose elements are all zero.
     *
     * @param type Element type of the new array
     * @param shape Shape of the new array
     * @return New array. Its storage comes from {@code NDArray.createDataArray}, which is
     *         assumed to be zero-initialised (TODO confirm).
     */
    public static NDArray zeros(final NDType type, final NDShape shape) {
        // The third argument is true only for BOOL arrays; presumably NDArray uses it as a
        // "boolean array" flag (NOTE(review): confirm against the NDArray constructor).
        return new NDArray(NDArray.createDataArray(type, shape.getSize()), shape, type == NDType.BOOL);
    }

    /**
     * Create an array of the given type and shape whose elements are all one.
     *
     * @param type Element type of the new array
     * @param shape Shape of the new array
     * @return New array with every element set via {@code setInt(i, 1)}
     */
    public static NDArray ones(final NDType type, final NDShape shape) {
        final int size = shape.getSize();
        final ListNumber data = NDArray.createDataArray(type, size);
        for (int i = 0; i < size; i++) {
            data.setInt(i, 1);
        }
        return new NDArray(data, shape, type == NDType.BOOL);
    }

    /**
     * Create a 1-D {@code FLOAT64} array of evenly stepped values.
     *
     * <p>Same as {@code arange(start, stop, step, NDType.FLOAT64)}.
     *
     * @param start First value
     * @param stop Limit of the range; how it bounds the count is decided by {@code NDArray.getCount}
     * @param step Increment between consecutive values
     * @return New 1-D array
     */
    public static NDArray arange(final double start, final double stop, final double step) {
        return arange(start, stop, step, NDType.FLOAT64);
    }

    /**
     * Create a 1-D array holding {@code start, start + step, start + 2*step, ...}.
     *
     * @param start First value
     * @param stop Limit of the range; the number of elements is whatever
     *             {@code NDArray.getCount(start, stop, step)} returns
     * @param step Increment between consecutive values
     * @param type Element type of the result
     * @return New 1-D array created by {@code NDArray.create}
     */
    public static NDArray arange(final double start, final double stop, final double step, final NDType type) {
        final int count = NDArray.getCount(start, stop, step);
        final double[] values = new double[count];
        for (int i = 0; i < count; i++) {
            // Multiply instead of accumulating so rounding errors do not build up along the range
            values[i] = start + i * step;
        }
        return NDArray.create(values, type);
    }

    /**
     * Create a 1-D array of {@code num} evenly spaced values from {@code start} to
     * {@code stop}, both end points included.
     *
     * @param start First value
     * @param stop Last value
     * @param num Number of values
     * @param type Element type of the result
     * @return New 1-D array; for {@code num < 2} a single-element array holding {@code start}
     */
    public static NDArray linspace(final double start, final double stop, final int num, final NDType type) {
        if (num < 2) {
            // NOTE(review): num == 0 and negative num also yield [start] here, whereas NumPy
            // returns an empty array for 0 and rejects negative counts. Kept for compatibility.
            return NDArray.create(new double[] { start }, type);
        }
        final double step = (stop - start) / (num - 1);
        final NDArray result = new NDArray(type, new NDShape(num));
        final int last = num - 1;
        for (int i = 0; i < last; i++) {
            result.setFlatDouble(i, start + i * step);
        }
        // Store 'stop' exactly: start + (num-1)*step may differ from it by rounding
        result.setFlatDouble(last, stop);
        return result;
    }

    /**
     * Reshape an array to the given dimension sizes.
     *
     * @param array Array to reshape
     * @param shape New dimension sizes; their product must equal the array's element count
     * @return Array with the new shape
     * @throws IllegalArgumentException if the element count would change
     */
    public static NDArray reshape(final NDArray array, final int... shape) {
        return reshape(array, new NDShape(shape));
    }

    /**
     * Reshape an array to a new shape with the same total number of elements.
     *
     * <p>The result is built with {@code new NDArray(array, shape, new NDStrides(shape))},
     * presumably a view on the same data using default strides for the new shape.
     * NOTE(review): if {@code array} is itself a non-contiguous view (for example, one
     * returned by {@link #transpose}), confirm that default strides still address its
     * elements in the intended order.
     *
     * @param array Array to reshape
     * @param shape New shape
     * @return Array with the new shape
     * @throws IllegalArgumentException if {@code shape} has a different size than the array
     */
    public static NDArray reshape(final NDArray array, final NDShape shape) {
        if (array.getShape().getSize() != shape.getSize()) {
            throw new IllegalArgumentException("Cannot change shape from " + array.getShape() + " to " + shape);
        }
        return new NDArray(array, shape, new NDStrides(shape));
    }

    /**
     * Permute the axes of an array.
     *
     * <p>The result is created from {@code array} with permuted shape and strides, without
     * copying element values in this method. Arrays with at most one dimension are
     * returned unchanged.
     *
     * @param array Array to transpose
     * @param axes Permutation of {@code 0 .. dimensions-1}: axis {@code i} of the result is
     *             axis {@code axes[i]} of the input. When empty, the axis order is reversed.
     * @return Transposed array
     * @throws IllegalArgumentException if {@code axes} has the wrong length, contains an
     *         out-of-range axis, or repeats an axis
     */
    public static NDArray transpose(final NDArray array, final int... axes) {
        final NDShape shape = array.getShape();
        final int dims = shape.getDimensions();
        if (dims <= 1) {
            return array;
        }

        final int[] order;
        if (axes.length == 0) {
            // Default: reverse the axes, e.g. rows <-> columns for a 2-D array
            order = new int[dims];
            for (int i = 0; i < dims; i++) {
                order[i] = dims - 1 - i;
            }
        } else {
            if (axes.length != dims) {
                throw new IllegalArgumentException("Axes " + Arrays.toString(axes) + " don't match array shape " + shape);
            }
            // Axes must form a permutation: each one in range and used exactly once
            final boolean[] used = new boolean[dims];
            for (int axis : axes) {
                if (axis < 0 || axis >= dims) {
                    throw new IllegalArgumentException("Invalid axis " + axis + " for array shape " + shape);
                }
                if (used[axis]) {
                    throw new IllegalArgumentException("Repeated axis " + axis + " in " + Arrays.toString(axes)
                            + " for array shape " + shape);
                }
                used[axis] = true;
            }
            order = axes;
        }

        final NDStrides strides = array.getStrides();
        final int[] newSizes = new int[dims];
        final int[] newStrides = new int[dims];
        for (int i = 0; i < dims; i++) {
            newSizes[i] = shape.getSize(order[i]);
            newStrides[i] = strides.getStride(order[i]);
        }
        return new NDArray(array, new NDShape(newSizes), new NDStrides(newStrides));
    }

    /**
     * Matrix product of two arrays.
     *
     * <ul>
     * <li>2-D x 2-D: matrix-matrix product</li>
     * <li>2-D x 1-D: matrix-vector product</li>
     * <li>1-D x 1-D: inner product, see {@link #inner}</li>
     * </ul>
     *
     * @param a Left operand
     * @param b Right operand
     * @return Product, with element type {@code NDType.determineSuperType(a.getType(), b.getType())}
     * @throws IllegalArgumentException for any other combination of dimensions, or when the
     *         inner dimensions do not match
     */
    public static NDArray dot(final NDArray a, final NDArray b) {
        final NDShape shapeA = a.getShape();
        final NDShape shapeB = b.getShape();
        if (shapeA.getDimensions() == 2 && shapeB.getDimensions() == 2) {
            return dot2x2(a, b, shapeA, shapeB);
        }
        if (shapeA.getDimensions() == 2 && shapeB.getDimensions() == 1) {
            return dot2x1(a, b, shapeA);
        }
        if (shapeA.getDimensions() == 1 && shapeB.getDimensions() == 1) {
            return inner(a, b);
        }
        throw new IllegalArgumentException("Matrix multiplication not supported for arrays with shapes " + shapeA + " and " + shapeB);
    }

    /** Matrix-matrix product: (rows x inner) times (inner x cols) gives (rows x cols). */
    private static NDArray dot2x2(final NDArray a, final NDArray b, final NDShape shapeA, final NDShape shapeB) {
        final NDType type = NDType.determineSuperType(a.getType(), b.getType());
        final int rows = shapeA.getSize(0);
        final int inner = shapeA.getSize(1);
        if (inner != shapeB.getSize(0)) {
            throw new IllegalArgumentException("For matrix multiplication, number of columns in first array must match number of rows in second array, but got shapes " + shapeA + " and " + shapeB);
        }
        final int cols = shapeB.getSize(1);
        final NDArray result = zeros(type, new NDShape(rows, cols));
        for (int row = 0; row < rows; row++) {
            for (int col = 0; col < cols; col++) {
                // Accumulate in double regardless of the result's element type
                double sum = 0.0;
                for (int k = 0; k < inner; k++) {
                    sum += a.getDouble(row, k) * b.getDouble(k, col);
                }
                result.setDouble(sum, row, col);
            }
        }
        return result;
    }

    /** Matrix-vector product: (rows x cols) times a vector of length cols gives a vector of length rows. */
    private static NDArray dot2x1(final NDArray matrix, final NDArray vector, final NDShape matrixShape) {
        final NDType type = NDType.determineSuperType(matrix.getType(), vector.getType());
        final int rows = matrixShape.getSize(0);
        final int cols = matrixShape.getSize(1);
        if (cols != vector.getSize()) {
            throw new IllegalArgumentException("For matrix multiplication, number of columns in first array must match number of rows in second array, but got shapes " + matrixShape + " and " + vector.getShape());
        }
        // One output element per matrix row (previously sized by 'cols' by mistake)
        final NDArray result = zeros(type, new NDShape(rows));
        for (int row = 0; row < rows; row++) {
            double sum = 0.0;
            for (int col = 0; col < cols; col++) {
                sum += matrix.getDouble(row, col) * vector.getFlatDouble(col);
            }
            result.setDouble(sum, row);
        }
        return result;
    }

    /**
     * Inner product of two 1-D arrays of equal size.
     *
     * @param a First vector
     * @param b Second vector
     * @return Array of shape (1) holding the sum of element-wise products, with element type
     *         {@code NDType.determineSuperType(a.getType(), b.getType())}
     * @throws IllegalArgumentException if either array is not 1-D or the sizes differ
     */
    public static NDArray inner(final NDArray a, final NDArray b) {
        final NDShape shapeA = a.getShape();
        final NDShape shapeB = b.getShape();
        // Both operands must be 1-D: the loop below indexes them with a single index
        if (shapeA.getDimensions() != 1 || shapeB.getDimensions() != 1) {
            throw new IllegalArgumentException("Inner product only supported for 1-D arrays, not for shapes " + shapeA + " and " + shapeB);
        }
        final int size = shapeA.getSize();
        if (size != shapeB.getSize()) {
            throw new IllegalArgumentException("Inner product arrays must have same size, not shapes " + shapeA + " and " + shapeB);
        }
        final NDArray result = zeros(NDType.determineSuperType(a.getType(), b.getType()), new NDShape(1));
        double sum = 0.0;
        for (int i = 0; i < size; i++) {
            sum += a.getDouble(i) * b.getDouble(i);
        }
        result.setDouble(sum, 0);
        return result;
    }
}
