package org.numenta.nupic.util;

import gnu.trove.TIntCollection;
import gnu.trove.iterator.TIntIterator;
import gnu.trove.list.TIntList;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import java.lang.reflect.Array;
import java.util.Arrays;
import org.numenta.nupic.model.Persistable;

/* loaded from: input_file:org/numenta/nupic/util/AbstractSparseBinaryMatrix.class */
/**
 * Base class for sparse matrices used with binary cell values (the {@code or(...)} operations write
 * {@code 1} into cells, and {@link #getSparseIndices()} reports cells whose value is {@code > 0}).
 *
 * <p>In addition to the matrix cells it keeps a {@code trueCounts} array holding one integer per
 * index along the first dimension. This class only stores and resets those counters; keeping them
 * in sync with cell writes is left to subclasses.
 */
public abstract class AbstractSparseBinaryMatrix extends AbstractSparseMatrix implements Persistable {
    private static final long serialVersionUID = 1L;

    /** One counter per index along the first dimension (sized {@code dimensions[0]}). */
    private int[] trueCounts;

    /**
     * Creates a matrix with the given dimensions, passing {@code false} as the superclass flag.
     *
     * @param dimensions the size of each dimension; {@code dimensions[0]} also sizes the
     *     true-count array
     */
    public AbstractSparseBinaryMatrix(int[] dimensions) {
        this(dimensions, false);
    }

    /**
     * Creates a matrix with the given dimensions.
     *
     * @param dimensions the size of each dimension; {@code dimensions[0]} also sizes the
     *     true-count array
     * @param useColumnMajorOrdering flag forwarded unchanged to the {@code AbstractSparseMatrix}
     *     constructor (NOTE(review): name follows upstream htm.java, where it selects index
     *     ordering — confirm in the superclass)
     */
    public AbstractSparseBinaryMatrix(int[] dimensions, boolean useColumnMajorOrdering) {
        super(dimensions, useColumnMajorOrdering);
        this.trueCounts = new int[dimensions[0]];
    }

    /**
     * Returns the slice addressed by the given leading coordinates. The concrete return type is
     * chosen by the subclass (presumably an array of cell values — confirm in implementations).
     *
     * @param coordinates leading coordinates of the slice
     * @return the subclass-defined slice object
     */
    public abstract Object getSlice(int... coordinates);

    /**
     * Always throws; called when slice coordinates address a full cell (or more) instead of a
     * proper slice.
     *
     * @param coordinates the offending coordinates (not currently included in the message)
     * @throws IllegalArgumentException always
     */
    public void sliceError(int... coordinates) {
        throw new IllegalArgumentException("This method only returns the array holding the specified maximum index: " + Arrays.toString(this.dimensions));
    }

    /**
     * Returns the flat indexes of every cell whose leading coordinates equal {@code coordinates}.
     *
     * <p>The remaining dimensions are enumerated with the last dimension varying fastest. For
     * example, with dimensions {@code [a, b, c]} and coordinates {@code [r]}, the result has
     * {@code b * c} entries: the indexes of {@code [r,0,0], [r,0,1], ..., [r,b-1,c-1]}.
     *
     * @param coordinates leading coordinates; must be shorter than the number of dimensions
     * @return flat indexes of the addressed slice
     * @throws IllegalArgumentException (via {@link #sliceError}) if {@code coordinates} is not
     *     shorter than the number of dimensions
     */
    protected int[] getSliceIndexes(int[] coordinates) {
        int[] dims = getDimensions();
        if (coordinates.length >= dims.length) {
            sliceError(coordinates);
        }

        // The slice spans every dimension after those fixed by the given coordinates.
        int sliceSize = 1;
        for (int d = coordinates.length; d < dims.length; d++) {
            sliceSize *= dims[d];
        }
        int[] slice = new int[sliceSize];

        // 'axis' is the first free dimension. Iterate over ITS size: the previous code used
        // dims[dims.length - coordinates.length], which is the wrong axis for matrices with more
        // than two dimensions and could overrun or under-fill the result.
        int axis = coordinates.length;
        boolean lastAxis = axis + 1 == dims.length;
        int[] element = Arrays.copyOf(coordinates, axis + 1);
        for (int i = 0; i < dims[axis]; i++) {
            element[axis] = i;
            if (lastAxis) {
                slice[i] = computeIndex(element);
            } else {
                // The recursive call copies 'element', so it is not mutated underneath us.
                int[] sub = getSliceIndexes(element);
                System.arraycopy(sub, 0, slice, i * sub.length, sub.length);
            }
        }
        return slice;
    }

    /**
     * Subclass-defined operation (semantics not visible in this class; see implementations).
     *
     * @param inputVector first argument, passed through to the implementation
     * @param results second argument, passed through to the implementation
     */
    public abstract void rightVecSumAtNZ(int[] inputVector, int[] results);

    /**
     * Subclass-defined operation (semantics not visible in this class; see implementations).
     *
     * @param inputVector first argument, passed through to the implementation
     * @param results second argument, passed through to the implementation
     * @param threshold third argument, passed through to the implementation
     */
    public abstract void rightVecSumAtNZ(int[] inputVector, int[] results, double threshold);

    /**
     * Sets the cell at a flat index by converting it to coordinates with
     * {@code computeCoordinates} and delegating to {@link #set(int, int...)}.
     *
     * @param index flat cell index
     * @param value value to store
     * @return this matrix
     */
    @Override
    public AbstractSparseBinaryMatrix set(int index, int value) {
        return set(value, computeCoordinates(index));
    }

    /**
     * Sets the cell at the given coordinates.
     *
     * @param value value to store
     * @param coordinates cell coordinates
     * @return this matrix
     */
    @Override
    public abstract AbstractSparseBinaryMatrix set(int value, int... coordinates);

    /**
     * Sets {@code values[k]} at flat index {@code indexes[k]} for every {@code k}.
     *
     * @param indexes flat cell indexes
     * @param values values to store, parallel to {@code indexes}
     * @return this matrix
     * @throws IllegalArgumentException if the arrays differ in length (checked before any write,
     *     so the matrix is never left partially updated)
     */
    public AbstractSparseBinaryMatrix set(int[] indexes, int[] values) {
        requireSameLength(indexes, values);
        for (int k = 0; k < indexes.length; k++) {
            set(indexes[k], values[k]);
        }
        return this;
    }

    /**
     * Returns the value at the given coordinates by converting them to a flat index with
     * {@code computeIndex} and delegating to {@link #get(int)}.
     *
     * @param coordinates cell coordinates
     * @return the stored value
     */
    @Override
    public Integer get(int... coordinates) {
        return get(computeIndex(coordinates));
    }

    /**
     * Returns the value at a flat index.
     *
     * @param index flat cell index
     * @return the stored value
     */
    @Override
    public abstract Integer get(int index);

    /**
     * Alternative single-cell setter used by {@link #set(int[], int[], boolean)} when its flag is
     * {@code true} (presumably skips bookkeeping done by the regular setter — confirm in
     * implementations).
     *
     * @param index flat cell index
     * @param value value to store
     * @return this matrix
     */
    public abstract AbstractSparseBinaryMatrix setForTest(int index, int value);

    /**
     * Sets {@code values[k]} at flat index {@code indexes[k]} for every {@code k}, using
     * {@link #setForTest(int, int)} when {@code isTest} is true and {@link #set(int, int)}
     * otherwise.
     *
     * @param indexes flat cell indexes
     * @param values values to store, parallel to {@code indexes}
     * @param isTest selects {@code setForTest} instead of {@code set}
     * @return this matrix
     * @throws IllegalArgumentException if the arrays differ in length (checked before any write)
     */
    public AbstractSparseBinaryMatrix set(int[] indexes, int[] values, boolean isTest) {
        requireSameLength(indexes, values);
        for (int k = 0; k < indexes.length; k++) {
            if (isTest) {
                setForTest(indexes[k], values[k]);
            } else {
                set(indexes[k], values[k]);
            }
        }
        return this;
    }

    /**
     * Rejects index/value arrays of different lengths before anything is written.
     *
     * @throws IllegalArgumentException if the lengths differ
     */
    private static void requireSameLength(int[] indexes, int[] values) {
        if (indexes.length != values.length) {
            throw new IllegalArgumentException(
                "indexes and values must have the same length: " + indexes.length + " != " + values.length);
        }
    }

    /**
     * Returns the stored counter for the given first-dimension index.
     *
     * @param row index along the first dimension
     * @return the counter value
     */
    public int getTrueCount(int row) {
        return this.trueCounts[row];
    }

    /**
     * Overwrites the counter for the given first-dimension index.
     *
     * @param row index along the first dimension
     * @param count new counter value
     */
    public void setTrueCount(int row, int count) {
        this.trueCounts[row] = count;
    }

    /**
     * Returns the live counter array (not a copy); writes to it affect this matrix.
     *
     * @return the internal true-count array
     */
    public int[] getTrueCounts() {
        return this.trueCounts;
    }

    /**
     * Resets the counter for {@code row} to 0 and sets every cell whose first coordinate is
     * {@code row} to 0.
     *
     * @param row index along the first dimension
     */
    public void clearStatistics(int row) {
        this.trueCounts[row] = 0;
        for (int index : getSliceIndexes(new int[] {row})) {
            set(index, 0);
        }
    }

    /**
     * Returns the value at the given coordinates as a primitive {@code int}.
     *
     * @param coordinates cell coordinates
     * @return the stored value
     */
    @Override
    public int getIntValue(int... coordinates) {
        return get(computeIndex(coordinates)).intValue();
    }

    /**
     * Returns the value at a flat index as a primitive {@code int}.
     *
     * @param index flat cell index
     * @return the stored value
     */
    @Override
    public int getIntValue(int index) {
        return get(index).intValue();
    }

    /**
     * Returns, in ascending order, every flat index in {@code [0, getMaxIndex()]} whose value is
     * greater than zero. Visits every index, so the cost is proportional to the matrix size.
     *
     * @return ascending flat indexes of non-zero cells
     */
    @Override
    public int[] getSparseIndices() {
        TIntArrayList nonZero = new TIntArrayList();
        for (int i = 0; i <= getMaxIndex(); i++) {
            if (get(i).intValue() > 0) {
                nonZero.add(i);
            }
        }
        return nonZero.toArray();
    }

    /**
     * Sets to 1 every cell that is non-zero in {@code other}.
     *
     * @param other matrix whose non-zero cells are copied in as 1s
     * @return this matrix
     */
    public AbstractSparseBinaryMatrix or(AbstractSparseBinaryMatrix other) {
        return setOnes(other.getSparseIndices());
    }

    /**
     * Sets to 1 every cell whose flat index is in {@code indexes}.
     *
     * @param indexes flat cell indexes
     * @return this matrix
     */
    public AbstractSparseBinaryMatrix or(TIntCollection indexes) {
        return setOnes(indexes.toArray());
    }

    /**
     * Sets to 1 every cell whose flat index is in {@code indexes}.
     *
     * @param indexes flat cell indexes
     * @return this matrix
     */
    public AbstractSparseBinaryMatrix or(int[] indexes) {
        return setOnes(indexes);
    }

    /** Writes the value 1 at each of the given flat indexes via {@link #set(int[], int[])}. */
    private AbstractSparseBinaryMatrix setOnes(int[] indexes) {
        int[] ones = new int[indexes.length];
        Arrays.fill(ones, 1);
        return set(indexes, ones);
    }

    /**
     * Returns a new hash set containing this matrix's non-zero flat indexes.
     *
     * @return a freshly built set of non-zero indexes
     */
    protected TIntSet getSparseSet() {
        return new TIntHashSet(getSparseIndices());
    }

    /**
     * Returns whether every non-zero cell of {@code other} is also non-zero here.
     *
     * @param other matrix to compare against
     * @return {@code true} if this matrix's non-zero indexes contain all of {@code other}'s
     */
    public boolean all(AbstractSparseBinaryMatrix other) {
        return getSparseSet().containsAll(other.getSparseIndices());
    }

    /**
     * Returns whether every given flat index is non-zero in this matrix.
     *
     * @param indexes flat cell indexes
     * @return {@code true} if all are non-zero here
     */
    public boolean all(TIntCollection indexes) {
        return getSparseSet().containsAll(indexes);
    }

    /**
     * Returns whether every given flat index is non-zero in this matrix.
     *
     * @param indexes flat cell indexes
     * @return {@code true} if all are non-zero here
     */
    public boolean all(int[] indexes) {
        return getSparseSet().containsAll(indexes);
    }

    /**
     * Returns whether at least one non-zero cell of {@code other} is also non-zero here.
     *
     * @param other matrix to compare against
     * @return {@code true} on the first shared non-zero index
     */
    public boolean any(AbstractSparseBinaryMatrix other) {
        return any(other.getSparseIndices());
    }

    /**
     * Returns whether at least one listed flat index is non-zero in this matrix.
     *
     * @param indexes flat cell indexes
     * @return {@code true} on the first index found non-zero here
     */
    public boolean any(TIntList indexes) {
        TIntSet nonZero = getSparseSet();
        TIntIterator it = indexes.iterator();
        while (it.hasNext()) {
            if (nonZero.contains(it.next())) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns whether at least one given flat index is non-zero in this matrix.
     *
     * @param indexes flat cell indexes
     * @return {@code true} on the first index found non-zero here
     */
    public boolean any(int[] indexes) {
        TIntSet nonZero = getSparseSet();
        for (int index : indexes) {
            if (nonZero.contains(index)) {
                return true;
            }
        }
        return false;
    }

    /** Combines the superclass hash with the true-count array. */
    @Override
    public int hashCode() {
        return (31 * super.hashCode()) + Arrays.hashCode(this.trueCounts);
    }

    /**
     * Equal when the superclass considers the objects equal, the runtime classes match, and the
     * true-count arrays are element-wise equal.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        return super.equals(obj)
            && getClass() == obj.getClass()
            && Arrays.equals(this.trueCounts, ((AbstractSparseBinaryMatrix) obj).trueCounts);
    }
}
