package org.nd4j.linalg.indexing;

import com.google.common.primitives.Ints;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.linalg.util.NDArrayUtil;

/* loaded from: input_file:org/nd4j/linalg/indexing/NDArrayIndex.class */
/**
 * An {@link INDArrayIndex} backed by an explicit list of integer positions, plus static factory and
 * resolution helpers for the other index kinds ({@link PointIndex}, {@link IntervalIndex},
 * {@link NDArrayIndexAll}, {@link NewAxis}, {@link SpecifiedIndex}, {@link NDArrayIndexEmpty}).
 */
public class NDArrayIndex implements INDArrayIndex {
    /** Shared instance returned by {@link #empty()}. */
    private static final NDArrayIndexEmpty EMPTY = new NDArrayIndexEmpty();
    /** Shared instance returned by {@link #newAxis()}. */
    private static final NewAxis NEW_AXIS = new NewAxis();

    /** Explicit positions of this index; {@code null} only if the constructor was given {@code null}. */
    private int[] indices;
    private boolean isInterval = false;

    /**
     * Creates an index that selects the single position {@code i}.
     *
     * @param i the position to select
     * @return a {@link PointIndex} for {@code i}
     */
    public static INDArrayIndex point(int i) {
        return new PointIndex(i);
    }

    /**
     * Converts each coordinate into a {@link PointIndex}.
     *
     * @param indexes coordinates, one per dimension
     * @return one point index per coordinate, in the same order
     */
    public static INDArrayIndex[] indexesFor(int... indexes) {
        INDArrayIndex[] ret = new INDArrayIndex[indexes.length];
        for (int i = 0; i < indexes.length; i++) {
            ret[i] = point(indexes[i]);
        }
        return ret;
    }

    /**
     * Computes the linear offset of the given coordinates using the array's strides.
     *
     * @param arr     array whose {@code stride()} is used
     * @param offsets per-dimension coordinates
     * @return sum of {@code offsets[i] * stride[i]}
     */
    public static int offset(INDArray arr, int... offsets) {
        return offset(arr.stride(), offsets);
    }

    /**
     * Computes the linear offset described by a set of indexes on the given array.
     *
     * @param arr     array whose shape and strides are used
     * @param indices indexes converted to per-dimension offsets via {@code Indices.offsets}
     * @return the linear offset
     */
    public static int offset(INDArray arr, INDArrayIndex... indices) {
        return offset(arr.stride(), Indices.offsets(arr.shape(), indices));
    }

    /**
     * Rewrites the shape and stride of {@code arr} in place so that every {@link NewAxis} in
     * {@code indexes} becomes a dimension of size 1 and stride 0.
     *
     * <p>If the first index spans more than one element (or is an {@link NDArrayIndexAll}), each
     * new axis is inserted at its position within {@code indexes}. Otherwise all new axes are
     * prepended. If there are no new axes, the array is left untouched.
     *
     * <p>Note: this mutates {@code arr} through {@code setShape}/{@code setStride}.
     *
     * @param arr     the array to update
     * @param indexes the indexes, possibly containing {@link NewAxis} entries
     */
    public static void updateForNewAxes(INDArray arr, INDArrayIndex... indexes) {
        int numNewAxes = numNewAxis(indexes);
        if (numNewAxes < 1) {
            return;
        }
        // Capture before any mutation so both branches read the original layout.
        int[] oldStrides = arr.stride();
        boolean leadingIndexSpans = indexes[0].length() > 1 || indexes[0] instanceof NDArrayIndexAll;
        if (!leadingIndexSpans) {
            arr.setShape(Ints.concat(ArrayUtil.nTimes(numNewAxes, 1), arr.shape()));
            arr.setStride(Ints.concat(new int[numNewAxes], oldStrides));
            return;
        }

        List<Integer> newShape = new ArrayList<>();
        List<Integer> newStrides = new ArrayList<>();
        int dim = 0;
        for (INDArrayIndex index : indexes) {
            if (index instanceof NewAxis) {
                newShape.add(1);
                newStrides.add(0);
            } else {
                newShape.add(arr.size(dim));
                newStrides.add(oldStrides[dim]);
                dim++;
            }
        }
        // Dimensions not covered by an index keep their original size and stride.
        for (; dim < arr.rank(); dim++) {
            newShape.add(arr.size(dim));
            newStrides.add(oldStrides[dim]);
        }
        arr.setShape(Ints.toArray(newShape));
        arr.setStride(Ints.toArray(newStrides));
    }

    /**
     * Computes the linear offset {@code sum(offsets[i] * strides[i])} over {@code offsets.length}
     * dimensions.
     *
     * @param strides per-dimension strides; must be at least as long as {@code offsets}
     * @param offsets per-dimension coordinates
     * @return the linear offset
     */
    public static int offset(int[] strides, int[] offsets) {
        int ret = 0;
        for (int i = 0; i < offsets.length; i++) {
            ret += offsets[i] * strides[i];
        }
        return ret;
    }

    /**
     * Returns an array holding {@code copies} references to the same index.
     *
     * @param index  the index to repeat
     * @param copies number of repetitions
     * @return the filled array
     */
    public static INDArrayIndex[] nTimes(INDArrayIndex index, int copies) {
        INDArrayIndex[] ret = new INDArrayIndex[copies];
        Arrays.fill(ret, index);
        return ret;
    }

    /**
     * Creates an index over the given explicit positions. The array is stored as-is (not copied).
     *
     * @param indices the positions of this index
     */
    public NDArrayIndex(int... indices) {
        this.indices = indices;
    }

    /** @return the shared {@link NDArrayIndexEmpty} instance */
    public static INDArrayIndex empty() {
        return EMPTY;
    }

    /** @return a new {@link NDArrayIndexAll} instance */
    public static INDArrayIndex all() {
        return new NDArrayIndexAll(true);
    }

    /** @return the shared {@link NewAxis} instance */
    public static INDArrayIndex newAxis() {
        return NEW_AXIS;
    }

    /**
     * Resolves {@code intendedIndexes} against the array's rank, padding with {@link #all()}.
     *
     * @param arr             the array being indexed
     * @param intendedIndexes the user-supplied indexes
     * @return the resolved indexes
     */
    public static INDArrayIndex[] resolve(INDArray arr, INDArrayIndex... intendedIndexes) {
        return resolve(allFor(arr), intendedIndexes);
    }

    /** Counts the indexes that are instances of {@code type}. */
    private static int countOfType(Class<?> type, INDArrayIndex... indexes) {
        int count = 0;
        for (INDArrayIndex index : indexes) {
            if (type.isInstance(index)) {
                count++;
            }
        }
        return count;
    }

    /**
     * @param indexes indexes to inspect
     * @return number of {@link PointIndex} entries
     */
    public static int numPoints(INDArrayIndex... indexes) {
        return countOfType(PointIndex.class, indexes);
    }

    /**
     * Resolves {@code intendedIndexes} against the shape stored in a shape-info buffer.
     *
     * <p>If any index is a {@link SpecifiedIndex}, every {@link NDArrayIndexAll},
     * {@link NDArrayIndexEmpty} and {@link IntervalIndex} is expanded into an equivalent
     * {@link SpecifiedIndex}. All other index kinds are returned unchanged. Otherwise each index is
     * validated against its dimension size and the result is padded with {@link #all()}.
     *
     * @param shapeInfo       shape-info buffer read via {@code Shape.rank}/{@code Shape.shapeOf}
     * @param intendedIndexes the user-supplied indexes
     * @return the resolved indexes
     */
    public static INDArrayIndex[] resolve(DataBuffer shapeInfo, INDArrayIndex... intendedIndexes) {
        if (countOfType(SpecifiedIndex.class, intendedIndexes) > 0) {
            DataBuffer shape = Shape.shapeOf(shapeInfo);
            INDArrayIndex[] ret = new INDArrayIndex[intendedIndexes.length];
            for (int i = 0; i < intendedIndexes.length; i++) {
                INDArrayIndex index = intendedIndexes[i];
                // NDArrayIndexAll is tested before IntervalIndex so it expands to the full dimension.
                if (index instanceof SpecifiedIndex) {
                    ret[i] = index;
                } else if (index instanceof NDArrayIndexAll) {
                    ret[i] = new SpecifiedIndex(ArrayUtil.range(0, shape.getInt(i)));
                } else if (index instanceof NDArrayIndexEmpty) {
                    ret[i] = new SpecifiedIndex(new int[0]);
                } else if (index instanceof IntervalIndex) {
                    IntervalIndex interval = (IntervalIndex) index;
                    // Start at the interval's own beginning, not at 0.
                    ret[i] = new SpecifiedIndex(ArrayUtil.range(interval.begin, interval.end(), interval.stride()));
                } else {
                    // PointIndex, NewAxis, etc.: keep as-is rather than leaving a null slot.
                    ret[i] = index;
                }
            }
            return ret;
        }

        int rank = Shape.rank(shapeInfo);
        DataBuffer shape = Shape.shapeOf(shapeInfo);
        int numIndexes = intendedIndexes.length;
        if (numIndexes >= rank || (Shape.isVector(shapeInfo) && numIndexes == 1)) {
            if (Shape.isRowVectorShape(shapeInfo) && numIndexes == 1) {
                int vectorLength = (shape.getInt(0L) == 1 && rank == 2) ? shape.getInt(1L) : shape.getInt(0L);
                return new INDArrayIndex[] {point(0), validate(vectorLength, intendedIndexes[0])};
            }
            INDArrayIndex[] ret = new INDArrayIndex[numIndexes];
            for (int i = 0; i < numIndexes; i++) {
                ret[i] = i < rank ? validate(shape.getInt(i), intendedIndexes[i]) : intendedIndexes[i];
            }
            return ret;
        }

        List<INDArrayIndex> ret = new ArrayList<>(numIndexes + 1);
        int numNewAxes = 0;
        if (Shape.isMatrix(shape) && numIndexes == 1) {
            ret.add(validate(shape.getInt(0L), intendedIndexes[0]));
            ret.add(all());
        } else {
            for (int i = 0; i < numIndexes; i++) {
                ret.add(validate(shape.getInt(i), intendedIndexes[i]));
                if (intendedIndexes[i] instanceof NewAxis) {
                    numNewAxes++;
                }
            }
        }
        while (ret.size() < rank + numNewAxes) {
            ret.add(all());
        }
        return ret.toArray(new INDArrayIndex[0]);
    }

    /**
     * Resolves {@code intendedIndexes} against an explicit shape. Each index is validated against
     * its dimension size and the result is padded with {@link #all()}.
     *
     * @param shape           the array shape
     * @param intendedIndexes the user-supplied indexes
     * @return the resolved indexes
     */
    public static INDArrayIndex[] resolve(int[] shape, INDArrayIndex... intendedIndexes) {
        int numIndexes = intendedIndexes.length;
        if (numIndexes >= shape.length || (Shape.isVector(shape) && numIndexes == 1)) {
            if (Shape.isRowVectorShape(shape) && numIndexes == 1) {
                int vectorLength = (shape[0] == 1 && shape.length == 2) ? shape[1] : shape[0];
                return new INDArrayIndex[] {point(0), validate(vectorLength, intendedIndexes[0])};
            }
            INDArrayIndex[] ret = new INDArrayIndex[numIndexes];
            for (int i = 0; i < numIndexes; i++) {
                ret[i] = i < shape.length ? validate(shape[i], intendedIndexes[i]) : intendedIndexes[i];
            }
            return ret;
        }

        List<INDArrayIndex> ret = new ArrayList<>(numIndexes + 1);
        int numNewAxes = 0;
        if (Shape.isMatrix(shape) && numIndexes == 1) {
            ret.add(validate(shape[0], intendedIndexes[0]));
            ret.add(all());
        } else {
            for (int i = 0; i < numIndexes; i++) {
                ret.add(validate(shape[i], intendedIndexes[i]));
                if (intendedIndexes[i] instanceof NewAxis) {
                    numNewAxes++;
                }
            }
        }
        while (ret.size() < shape.length + numNewAxes) {
            ret.add(all());
        }
        return ret.toArray(new INDArrayIndex[0]);
    }

    /**
     * Checks an index against a dimension size and clamps oversized intervals.
     *
     * <p>The range check applies to {@link IntervalIndex}/{@link PointIndex} only. It is enforced
     * only when {@code size > 1}. An {@link IntervalIndex} whose {@code end()} exceeds {@code size}
     * is replaced by an interval with the same begin and stride, ending at {@code size}.
     *
     * @param size  the dimension size
     * @param index the index to check
     * @return {@code index}, or a clamped replacement interval
     * @throws IllegalArgumentException if the index's {@code current()} is not below {@code size}
     */
    protected static INDArrayIndex validate(int size, INDArrayIndex index) {
        boolean hasStart = index instanceof IntervalIndex || index instanceof PointIndex;
        if (hasStart && size <= index.current() && size > 1) {
            throw new IllegalArgumentException("NDArrayIndex is out of range. Beginning index: " + index.current()
                    + " must be less than its size: " + size);
        }
        if (index instanceof IntervalIndex && size < index.end()) {
            index = interval(((IntervalIndex) index).begin, index.stride(), size);
        }
        return index;
    }

    /**
     * Overlays {@code intendedIndexes} onto a template of {@code allIndex.length} slots (plus one
     * slot per {@link NewAxis}). Unfilled slots get a fresh {@link #all()}. A single-position
     * {@link NDArrayIndex} is collapsed into a {@link PointIndex}. The caller's array is not modified.
     *
     * @param allIndex        template whose length gives the target rank
     * @param intendedIndexes the user-supplied indexes
     * @return the resolved indexes
     */
    public static INDArrayIndex[] resolve(INDArrayIndex[] allIndex, INDArrayIndex... intendedIndexes) {
        INDArrayIndex[] ret = new INDArrayIndex[allIndex.length + numNewAxis(intendedIndexes)];
        for (int i = 0; i < ret.length; i++) {
            ret[i] = all();
        }
        int overlap = Math.min(allIndex.length, intendedIndexes.length);
        for (int i = 0; i < overlap; i++) {
            INDArrayIndex index = intendedIndexes[i];
            if (index instanceof NDArrayIndex) {
                int[] positions = ((NDArrayIndex) index).indices;
                if (positions != null && positions.length == 1) {
                    index = new PointIndex(positions[0]);
                }
            }
            ret[i] = index;
        }
        return ret;
    }

    /**
     * @param indexes indexes to inspect
     * @return number of {@link NewAxis} entries
     */
    public static int numNewAxis(INDArrayIndex... indexes) {
        return countOfType(NewAxis.class, indexes);
    }

    /**
     * @param arr the array
     * @return one fresh {@link #all()} index per dimension of {@code arr}
     */
    public static INDArrayIndex[] allFor(INDArray arr) {
        INDArrayIndex[] ret = new INDArrayIndex[arr.rank()];
        for (int i = 0; i < ret.length; i++) {
            ret[i] = all();
        }
        return ret;
    }

    /**
     * @param shape the shape to cover
     * @return one {@code interval(0, shape[i])} per dimension
     */
    public static INDArrayIndex[] createCoveringShape(int[] shape) {
        INDArrayIndex[] ret = new INDArrayIndex[shape.length];
        for (int i = 0; i < ret.length; i++) {
            ret[i] = interval(0, shape[i]);
        }
        return ret;
    }

    /**
     * @param indexes source indexes
     * @return one {@code interval(0, indexes[i].length())} per input index
     */
    public static INDArrayIndex[] rangeOfLength(INDArrayIndex[] indexes) {
        INDArrayIndex[] ret = new INDArrayIndex[indexes.length];
        for (int i = 0; i < indexes.length; i++) {
            ret[i] = interval(0, indexes[i].length());
        }
        return ret;
    }

    /**
     * Builds {@link NDArrayIndex} instances from the values of an ndarray. For a matrix, each row
     * becomes one index, with its values truncated to {@code int}. For a vector, a single index is
     * built from {@code NDArrayUtil.toInts}.
     *
     * @param index a matrix or vector of positions
     * @return the created indexes
     * @throws IllegalArgumentException if {@code index} is neither a matrix nor a vector
     */
    public static INDArrayIndex[] create(INDArray index) {
        if (index.isMatrix()) {
            NDArrayIndex[] ret = new NDArrayIndex[index.rows()];
            for (int i = 0; i < index.rows(); i++) {
                INDArray row = index.getRow(i);
                int[] positions = new int[row.columns()];
                for (int j = 0; j < positions.length; j++) {
                    positions[j] = (int) row.getFloat(j);
                }
                ret[i] = new NDArrayIndex(positions);
            }
            return ret;
        }
        if (index.isVector()) {
            return new NDArrayIndex[] {new NDArrayIndex(NDArrayUtil.toInts(index))};
        }
        throw new IllegalArgumentException("Passed in ndarray must be a matrix or a vector");
    }

    /**
     * Creates a strided interval, adjusting {@code end} first. If {@code begin == end}, {@code end}
     * is incremented. Then, if {@code stride > 1} and the span is exactly 1, {@code end} is
     * multiplied by {@code stride}.
     *
     * <p>NOTE(review): the multiply-by-stride rule is kept from the original. Its intent is unclear
     * when {@code begin != 0}; confirm.
     *
     * @param begin  first position
     * @param stride step between positions
     * @param end    end position
     * @return the interval index
     */
    public static INDArrayIndex interval(int begin, int stride, int end) {
        if (Math.abs(begin - end) < 1) {
            end++;
        }
        if (stride > 1 && Math.abs(begin - end) == 1) {
            end *= stride;
        }
        return interval(begin, stride, end, false);
    }

    /**
     * Creates an {@link IntervalIndex} initialised with {@code init(begin, end)}.
     *
     * @param begin     first position; asserted {@code <= end} when assertions are enabled
     * @param stride    step between positions
     * @param end       end position
     * @param inclusive forwarded as the first {@code IntervalIndex} constructor argument
     *                  (presumably end-inclusivity; confirm)
     * @return the interval index
     */
    public static INDArrayIndex interval(int begin, int stride, int end, boolean inclusive) {
        assert begin <= end : "Beginning index in range must be less than end";
        IntervalIndex ret = new IntervalIndex(inclusive, stride);
        ret.init(begin, end);
        return ret;
    }

    /** Equivalent to {@code interval(begin, 1, end, false)}. */
    public static INDArrayIndex interval(int begin, int end) {
        return interval(begin, 1, end, false);
    }

    /** Equivalent to {@code interval(begin, 1, end, inclusive)}. */
    public static INDArrayIndex interval(int begin, int end, boolean inclusive) {
        return interval(begin, 1, end, inclusive);
    }

    /** @return the last position, or 0 if there are none */
    @Override
    public int end() {
        if (indices == null || indices.length == 0) {
            return 0;
        }
        return indices[indices.length - 1];
    }

    /** @return the first position, or 0 if there are none */
    @Override
    public int offset() {
        if (indices == null || indices.length == 0) {
            return 0;
        }
        return indices[0];
    }

    /** @return the number of positions (0 if none) */
    @Override
    public int length() {
        return indices == null ? 0 : indices.length;
    }

    /** @return always 1 */
    @Override
    public int stride() {
        return 1;
    }

    /** @return always 0; this index does not track an iteration cursor */
    @Override
    public int current() {
        return 0;
    }

    /** @return always {@code false} */
    @Override
    public boolean hasNext() {
        return false;
    }

    /** @return always 0 */
    @Override
    public int next() {
        return 0;
    }

    /** Reverses the stored positions in place (no-op when there are none). */
    @Override
    public void reverse() {
        if (indices != null) {
            ArrayUtil.reverse(indices);
        }
    }

    @Override
    public String toString() {
        return "NDArrayIndex{indices=" + Arrays.toString(indices) + '}';
    }

    /** Two {@code NDArrayIndex} instances are equal when their positions are equal. */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof NDArrayIndex)) {
            return false;
        }
        return Arrays.equals(indices, ((NDArrayIndex) obj).indices);
    }

    @Override
    public int hashCode() {
        return Arrays.hashCode(indices);
    }

    @Override
    public boolean isInterval() {
        return isInterval;
    }

    @Override
    public void setInterval(boolean isInterval) {
        this.isInterval = isInterval;
    }

    /** No-op for explicit-position indexes. */
    @Override
    public void init(INDArray arr, int begin, int dimension) {
    }

    /** No-op for explicit-position indexes. */
    @Override
    public void init(INDArray arr, int dimension) {
    }

    /** No-op for explicit-position indexes. */
    @Override
    public void init(int begin, int end) {
    }

    /** No-op for explicit-position indexes. */
    @Override
    public void reset() {
    }
}
