package org.numenta.nupic.research;

import gnu.trove.list.array.TDoubleArrayList;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.set.hash.TIntHashSet;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import org.numenta.nupic.Connections;
import org.numenta.nupic.model.Column;
import org.numenta.nupic.model.Pool;
import org.numenta.nupic.util.ArrayUtils;
import org.numenta.nupic.util.Condition;
import org.numenta.nupic.util.SparseBinaryMatrix;
import org.numenta.nupic.util.SparseMatrix;
import org.numenta.nupic.util.SparseObjectMatrix;

/* JADX WARN: Classes with same name are omitted:
  input_file:org/numenta/nupic/examples/cortical_io/breakingnews/breaking-news-demo-1.0.0.jar:org/numenta/nupic/research/SpatialPooler.class
  input_file:org/numenta/nupic/examples/cortical_io/foxeats/FoxEatsDemo.jar:org/numenta/nupic/research/SpatialPooler.class
 */
/* loaded from: input_file:org/numenta/nupic/examples/napi/hotgym/NAPI-Hotgym-Demo-1.0.jar:org/numenta/nupic/research/SpatialPooler.class */
/**
 * Spatial pooler algorithm operating on state stored in a {@link Connections} instance.
 *
 * <p>The class declares no fields: every method reads its parameters from, and writes its
 * results to, the {@code Connections} object it is given.
 */
public class SpatialPooler {

    /**
     * Initializes the spatial pooler state held in {@code connections}: allocates the matrices
     * and per-column arrays ({@link #initMatrices}), then builds every column's potential pool
     * and initial proximal permanences ({@link #connectAndConfigureInputs}).
     *
     * @param connections the parameter/state container to initialize
     */
    public void init(Connections connections) {
        initMatrices(connections);
        connectAndConfigureInputs(connections);
    }

    /**
     * Allocates the column memory (reusing an existing one if already set), the input matrix,
     * the potential-pool and connected matrices, and all per-column arrays (tie breakers, duty
     * cycles, boost factors initialized to 1.0).
     *
     * <p>Draws one value per column from {@code connections.getRandom()} for the tie breaker.
     *
     * @param connections the parameter/state container to populate
     */
    public void initMatrices(Connections connections) {
        SparseObjectMatrix<Column> memory = connections.getMemory();
        if (memory == null) {
            memory = new SparseObjectMatrix<>(connections.getColumnDimensions());
        }
        connections.setMemory(memory);
        connections.setInputMatrix(new SparseBinaryMatrix(connections.getInputDimensions()));

        int numInputs = connections.getInputMatrix().getMaxIndex() + 1;
        int numColumns = connections.getMemory().getMaxIndex() + 1;
        connections.setNumInputs(numInputs);
        connections.setNumColumns(numColumns);

        for (int i = 0; i < numColumns; i++) {
            memory.set(i, new Column(connections.getCellsPerColumn(), i));
        }

        connections.setPotentialPools(new SparseObjectMatrix<>(connections.getMemory().getDimensions()));
        connections.setConnectedMatrix(new SparseBinaryMatrix(new int[] {numColumns, numInputs}));

        // Small per-column random offsets; inhibitColumns adds them to the scores to break ties.
        double[] tieBreaker = new double[numColumns];
        for (int i = 0; i < numColumns; i++) {
            tieBreaker[i] = 0.01d * connections.getRandom().nextDouble();
        }
        connections.setTieBreaker(tieBreaker);

        connections.setOverlapDutyCycles(new double[numColumns]);
        connections.setActiveDutyCycles(new double[numColumns]);
        connections.setMinOverlapDutyCycles(new double[numColumns]);
        connections.setMinActiveDutyCycles(new double[numColumns]);

        double[] boostFactors = new double[numColumns];
        Arrays.fill(boostFactors, 1.0d);
        connections.setBoostFactors(boostFactors);
    }

    /**
     * For every column: samples its potential pool ({@link #mapPotential} with wrap-around),
     * stores the pool in {@code connections.getPotentialPools()}, initializes its proximal
     * permanences ({@link #initPermanence}) and raises them to the stimulus threshold. Finally
     * recomputes the inhibition radius.
     *
     * @param connections the parameter/state container; matrices must already be allocated
     */
    public void connectAndConfigureInputs(Connections connections) {
        int numColumns = connections.getNumColumns();
        for (int i = 0; i < numColumns; i++) {
            int[] potential = mapPotential(connections, i, true);
            Column column = connections.getColumn(i);
            connections.getPotentialPools().set(i, column.createPotentialPool(connections, potential));
            double[] perm = initPermanence(connections, potential, i, connections.getInitConnectedPct());
            updatePermanencesForColumn(connections, perm, column, potential, true);
        }
        updateInhibitionRadius(connections);
    }

    /**
     * Runs one spatial pooler step on {@code inputVector} and writes the active columns into
     * {@code activeArray}.
     *
     * @param connections       parameter/state container
     * @param inputVector       input bits; length must equal {@code connections.getNumInputs()}
     * @param activeArray       output: cleared, then set to 1 at every active column index
     * @param learn             when true, overlaps are boosted and synapses, duty cycles and boost
     *                          factors are updated (plus inhibition radius and minimum duty cycles
     *                          on update rounds)
     * @param stripNeverLearned when {@code learn} is false, drop active columns whose active duty
     *                          cycle is not positive
     * @throws IllegalArgumentException if the input length differs from the number of inputs
     */
    public void compute(Connections connections, int[] inputVector, int[] activeArray, boolean learn, boolean stripNeverLearned) {
        if (inputVector.length != connections.getNumInputs()) {
            throw new IllegalArgumentException("Input array must be same size as the defined number of inputs: From Params: " + connections.getNumInputs() + ", From Input Vector: " + inputVector.length);
        }
        updateBookeepingVars(connections, learn);
        int[] overlaps = calculateOverlap(connections, inputVector);
        // Boost factors only influence inhibition while learning.
        double[] scores = learn
                ? ArrayUtils.multiply(connections.getBoostFactors(), overlaps)
                : ArrayUtils.toDoubleArray(overlaps);
        int[] activeColumns = inhibitColumns(connections, scores);
        if (learn) {
            adaptSynapses(connections, inputVector, activeColumns);
            updateDutyCycles(connections, overlaps, activeColumns);
            bumpUpWeakColumns(connections);
            updateBoostFactors(connections);
            if (isUpdateRound(connections)) {
                updateInhibitionRadius(connections);
                updateMinDutyCycles(connections);
            }
        } else if (stripNeverLearned) {
            activeColumns = stripUnlearnedColumns(connections, activeColumns).toArray();
        }
        Arrays.fill(activeArray, 0);
        if (activeColumns.length > 0) {
            ArrayUtils.setIndexesTo(activeArray, activeColumns, 1);
        }
    }

    /**
     * Removes from {@code activeColumns} every column whose active duty cycle is {@code <= 0}.
     *
     * @param connections   parameter/state container
     * @param activeColumns candidate active column indices
     * @return the remaining column indices, de-duplicated and sorted ascending
     */
    public TIntArrayList stripUnlearnedColumns(Connections connections, int[] activeColumns) {
        TIntHashSet kept = new TIntHashSet(activeColumns);
        double[] activeDutyCycles = connections.getActiveDutyCycles();
        int numColumns = connections.getNumColumns();
        for (int i = 0; i < numColumns; i++) {
            if (activeDutyCycles[i] <= 0.0d) {
                kept.remove(i);
            }
        }
        TIntArrayList result = new TIntArrayList(kept);
        result.sort();
        return result;
    }

    /**
     * Recomputes the minimum duty cycles. Uses the global variant when inhibition is global or
     * the inhibition radius exceeds the number of inputs; otherwise uses the local variant.
     *
     * @param connections parameter/state container
     */
    public void updateMinDutyCycles(Connections connections) {
        if (connections.getGlobalInhibition() || connections.getInhibitionRadius() > connections.getNumInputs()) {
            updateMinDutyCyclesGlobal(connections);
        } else {
            updateMinDutyCyclesLocal(connections);
        }
    }

    /**
     * Sets every column's minimum overlap/active duty cycle to the configured percentage of the
     * maximum overlap/active duty cycle over all columns.
     *
     * @param connections parameter/state container
     */
    public void updateMinDutyCyclesGlobal(Connections connections) {
        Arrays.fill(connections.getMinOverlapDutyCycles(), connections.getMinPctOverlapDutyCycles() * ArrayUtils.max(connections.getOverlapDutyCycles()));
        Arrays.fill(connections.getMinActiveDutyCycles(), connections.getMinPctActiveDutyCycles() * ArrayUtils.max(connections.getActiveDutyCycles()));
    }

    /**
     * Sets each column's minimum overlap/active duty cycle to the configured percentage of the
     * maximum duty cycle among its neighbours (wrap-around neighbourhood of the current
     * inhibition radius in column space).
     *
     * @param connections parameter/state container
     */
    public void updateMinDutyCyclesLocal(Connections connections) {
        int numColumns = connections.getNumColumns();
        double[] overlapDutyCycles = connections.getOverlapDutyCycles();
        double[] activeDutyCycles = connections.getActiveDutyCycles();
        double[] minOverlapDutyCycles = connections.getMinOverlapDutyCycles();
        double[] minActiveDutyCycles = connections.getMinActiveDutyCycles();
        for (int i = 0; i < numColumns; i++) {
            int[] neighbors = getNeighborsND(connections, i, connections.getMemory(), connections.getInhibitionRadius(), true).toArray();
            minOverlapDutyCycles[i] = ArrayUtils.max(ArrayUtils.sub(overlapDutyCycles, neighbors)) * connections.getMinPctOverlapDutyCycles();
            minActiveDutyCycles[i] = ArrayUtils.max(ArrayUtils.sub(activeDutyCycles, neighbors)) * connections.getMinPctActiveDutyCycles();
        }
    }

    /**
     * Updates the overlap and active duty cycles as moving averages (see
     * {@link #updateDutyCyclesHelper}). The averaging period is
     * {@code min(dutyCyclePeriod, iterationNum)}.
     *
     * @param connections   parameter/state container
     * @param overlaps      per-column overlap scores; a column counts as overlapping when its
     *                      score is {@code > 0}. This array is not modified.
     * @param activeColumns indices of the columns active in this step
     */
    public void updateDutyCycles(Connections connections, int[] overlaps, int[] activeColumns) {
        int numColumns = connections.getNumColumns();
        double[] overlapArray = new double[numColumns];
        double[] activeArray = new double[numColumns];
        // Binarize into a fresh array: the caller's overlaps must stay untouched, and the
        // overlap duty cycle has to see which columns actually overlapped.
        int limit = Math.min(numColumns, overlaps.length);
        for (int i = 0; i < limit; i++) {
            if (overlaps[i] > 0) {
                overlapArray[i] = 1.0d;
            }
        }
        if (activeColumns.length > 0) {
            ArrayUtils.setIndexesTo(activeArray, activeColumns, 1.0d);
        }
        // Early on, average over only as many iterations as have actually run.
        int period = Math.min(connections.getDutyCyclePeriod(), connections.getIterationNum());
        connections.setOverlapDutyCycles(updateDutyCyclesHelper(connections, connections.getOverlapDutyCycles(), overlapArray, period));
        connections.setActiveDutyCycles(updateDutyCyclesHelper(connections, connections.getActiveDutyCycles(), activeArray, period));
    }

    /**
     * Computes the moving average {@code (dutyCycles * (period - 1) + newInput) / period}.
     *
     * @param connections parameter/state container (unused here)
     * @param dutyCycles  current duty cycles
     * @param newInput    this step's per-column contribution
     * @param period      averaging period
     * @return the updated duty cycles
     */
    public double[] updateDutyCyclesHelper(Connections connections, double[] dutyCycles, double[] newInput, double period) {
        return ArrayUtils.divide(ArrayUtils.d_add(ArrayUtils.multiply(dutyCycles, period - 1.0d), newInput), period);
    }

    /**
     * Returns the average, over input dimensions, of the extent ({@code max - min + 1}) spanned
     * by the input coordinates of column {@code columnIndex}'s connected synapses.
     *
     * @param connections parameter/state container
     * @param columnIndex the column to measure
     * @return the average span, or 0 when the column has no connected synapses
     */
    public double avgConnectedSpanForColumnND(Connections connections, int columnIndex) {
        int[] inputDimensions = connections.getInputDimensions();
        int[] connected = connections.getColumn(columnIndex).getProximalDendrite().getConnectedSynapsesSparse(connections);
        if (connected == null || connected.length == 0) {
            return 0.0d;
        }
        int numDims = inputDimensions.length;
        int[] maxCoord = new int[numDims];
        int[] minCoord = new int[numDims];
        Arrays.fill(maxCoord, -1);
        Arrays.fill(minCoord, ArrayUtils.max(inputDimensions));
        SparseMatrix<?> inputMatrix = connections.getInputMatrix();
        for (int inputIndex : connected) {
            // Compute the coordinates once per synapse instead of once per bound.
            int[] coords = inputMatrix.computeCoordinates(inputIndex);
            for (int d = 0; d < numDims; d++) {
                maxCoord[d] = Math.max(maxCoord[d], coords[d]);
                minCoord[d] = Math.min(minCoord[d], coords[d]);
            }
        }
        double totalSpan = 0.0d;
        for (int d = 0; d < numDims; d++) {
            totalSpan += maxCoord[d] - minCoord[d] + 1;
        }
        return totalSpan / numDims;
    }

    /**
     * Recomputes the inhibition radius. With global inhibition it becomes the largest column
     * dimension. Otherwise it becomes {@code round(max(1, (avgSpan * avgColumnsPerInput - 1) / 2))},
     * where {@code avgSpan} is the mean connected span over all columns.
     *
     * @param connections parameter/state container
     */
    public void updateInhibitionRadius(Connections connections) {
        if (connections.getGlobalInhibition()) {
            connections.setInhibitionRadius(ArrayUtils.max(connections.getColumnDimensions()));
            return;
        }
        int numColumns = connections.getNumColumns();
        double[] spans = new double[numColumns];
        for (int i = 0; i < numColumns; i++) {
            spans[i] = avgConnectedSpanForColumnND(connections, i);
        }
        double diameter = ArrayUtils.average(spans) * avgColumnsPerInput(connections);
        double radius = Math.max(1.0d, (diameter - 1.0d) / 2.0d);
        connections.setInhibitionRadius((int) Math.round(radius));
    }

    /**
     * Returns the mean, over dimensions, of {@code columnDimension / inputDimension}. When the
     * column and input spaces have different numbers of dimensions, the shorter one is padded
     * with ones.
     *
     * @param connections parameter/state container
     * @return the average number of columns per input along each dimension
     */
    public double avgColumnsPerInput(Connections connections) {
        int[] columnDims = connections.getColumnDimensions();
        int[] inputDims = connections.getInputDimensions();
        int numDims = Math.max(columnDims.length, inputDims.length);
        double sum = 0.0d;
        for (int d = 0; d < numDims; d++) {
            double columnDim = d < columnDims.length ? columnDims[d] : 1.0d;
            double inputDim = d < inputDims.length ? inputDims[d] : 1.0d;
            sum += columnDim / inputDim;
        }
        return sum / numDims;
    }

    /**
     * Applies Hebbian-style permanence changes to every active column. Inputs that are
     * {@code > 0} get {@code +synPermActiveInc}; all other inputs get {@code -synPermInactiveDec}.
     * The adjusted permanences are then raised to threshold, trimmed and stored.
     *
     * @param connections   parameter/state container
     * @param inputVector   current input bits
     * @param activeColumns indices of the active columns to adapt
     */
    public void adaptSynapses(Connections connections, int[] inputVector, int[] activeColumns) {
        int[] activeInputs = ArrayUtils.where(inputVector, (Condition) ArrayUtils.INT_GREATER_THAN_0);
        double[] permChanges = new double[connections.getNumInputs()];
        Arrays.fill(permChanges, -1.0d * connections.getSynPermInactiveDec());
        ArrayUtils.setIndexesTo(permChanges, activeInputs, connections.getSynPermActiveInc());
        for (int column : activeColumns) {
            Pool pool = connections.getPotentialPools().getObject(column);
            double[] perm = pool.getDensePermanences(connections);
            int[] potential = pool.getSparseConnections();
            ArrayUtils.raiseValuesBy(permChanges, perm);
            updatePermanencesForColumn(connections, perm, connections.getColumn(column), potential, true);
        }
    }

    /**
     * Raises the pool permanences of every column whose overlap duty cycle is below its minimum
     * by {@code synPermBelowStimulusInc}, then raises them to threshold, trims and stores them.
     *
     * @param connections parameter/state container
     */
    public void bumpUpWeakColumns(Connections connections) {
        double[] overlapDutyCycles = connections.getOverlapDutyCycles();
        double[] minOverlapDutyCycles = connections.getMinOverlapDutyCycles();
        int numColumns = connections.getNumColumns();
        for (int i = 0; i < numColumns; i++) {
            if (overlapDutyCycles[i] < minOverlapDutyCycles[i]) {
                Pool pool = connections.getPotentialPools().getObject(i);
                double[] perm = pool.getSparsePermanences();
                ArrayUtils.raiseValuesBy(connections.getSynPermBelowStimulusInc(), perm);
                updatePermanencesForColumnSparse(connections, perm, connections.getColumn(i), pool.getSparseConnections(), true);
            }
        }
    }

    /**
     * Clips {@code perm} to [synPermMin, synPermMax], then repeatedly raises the entries at
     * {@code mask} by {@code synPermBelowStimulusInc}. It stops once at least
     * {@code stimulusThreshold} of those entries exceed {@code synPermConnected}.
     *
     * @param connections parameter/state container
     * @param perm        dense permanences, modified in place
     * @param mask        indices that may be raised (the column's potential inputs)
     * @throws IllegalStateException if the threshold can never be reached (too few maskable
     *                               entries, or a non-positive increment)
     */
    public void raisePermanenceToThreshold(Connections connections, double[] perm, int[] mask) {
        ArrayUtils.clip(perm, connections.getSynPermMin(), connections.getSynPermMax());
        if (ArrayUtils.valueGreaterCountAtIndex(connections.getSynPermConnected(), perm, mask) >= connections.getStimulusThreshold()) {
            return;
        }
        checkThresholdReachable(connections, mask.length);
        do {
            ArrayUtils.raiseValuesBy(connections.getSynPermBelowStimulusInc(), perm, mask);
        } while (ArrayUtils.valueGreaterCountAtIndex(connections.getSynPermConnected(), perm, mask) < connections.getStimulusThreshold());
    }

    /**
     * Sparse counterpart of {@link #raisePermanenceToThreshold}: every entry of {@code perm} is
     * raisable.
     *
     * @param connections parameter/state container
     * @param perm        sparse (pool-only) permanences, modified in place
     * @throws IllegalStateException if the threshold can never be reached
     */
    public void raisePermanenceToThresholdSparse(Connections connections, double[] perm) {
        ArrayUtils.clip(perm, connections.getSynPermMin(), connections.getSynPermMax());
        if (ArrayUtils.valueGreaterCount(connections.getSynPermConnected(), perm) >= connections.getStimulusThreshold()) {
            return;
        }
        checkThresholdReachable(connections, perm.length);
        do {
            ArrayUtils.raiseValuesBy(connections.getSynPermBelowStimulusInc(), perm);
        } while (ArrayUtils.valueGreaterCount(connections.getSynPermConnected(), perm) < connections.getStimulusThreshold());
    }

    /**
     * Fails fast instead of looping forever when the raise loops cannot terminate. Only
     * {@code numRaisable} entries can ever become connected, and only if the increment is
     * positive.
     */
    private static void checkThresholdReachable(Connections connections, int numRaisable) {
        if (numRaisable < connections.getStimulusThreshold()) {
            throw new IllegalStateException("Cannot raise permanences to the stimulus threshold: only "
                    + numRaisable + " synapses can be raised but stimulusThreshold is "
                    + connections.getStimulusThreshold()
                    + " (stimulusThreshold is likely too large relative to the potential pool size)");
        }
        if (!(connections.getSynPermBelowStimulusInc() > 0.0d)) {
            throw new IllegalStateException("Cannot raise permanences to the stimulus threshold: "
                    + "synPermBelowStimulusInc must be positive but is "
                    + connections.getSynPermBelowStimulusInc());
        }
    }

    /**
     * Optionally raises permanences to threshold (over {@code mask}), then does the following:
     * <ul>
     *   <li>zeroes every permanence {@code <= synPermTrimThreshold};</li>
     *   <li>clips to [synPermMin, synPermMax];</li>
     *   <li>stores the result as the column's proximal permanences.</li>
     * </ul>
     *
     * @param connections parameter/state container
     * @param perm        dense permanences, modified in place
     * @param column      column receiving the permanences
     * @param mask        raisable indices (the column's potential inputs)
     * @param raisePerm   whether to raise to threshold first
     */
    public void updatePermanencesForColumn(Connections connections, double[] perm, Column column, int[] mask, boolean raisePerm) {
        if (raisePerm) {
            raisePermanenceToThreshold(connections, perm, mask);
        }
        ArrayUtils.lessThanOrEqualXThanSetToY(perm, connections.getSynPermTrimThreshold(), 0.0d);
        ArrayUtils.clip(perm, connections.getSynPermMin(), connections.getSynPermMax());
        column.setProximalPermanences(connections, perm);
    }

    /**
     * Sparse counterpart of {@link #updatePermanencesForColumn}. {@code perm[k]} is the
     * permanence for input {@code inputIndexes[k]}.
     *
     * @param connections  parameter/state container
     * @param perm         sparse permanences, modified in place
     * @param column       column receiving the permanences
     * @param inputIndexes input indices matching {@code perm}
     * @param raisePerm    whether to raise to threshold first
     */
    public void updatePermanencesForColumnSparse(Connections connections, double[] perm, Column column, int[] inputIndexes, boolean raisePerm) {
        if (raisePerm) {
            raisePermanenceToThresholdSparse(connections, perm);
        }
        ArrayUtils.lessThanOrEqualXThanSetToY(perm, connections.getSynPermTrimThreshold(), 0.0d);
        ArrayUtils.clip(perm, connections.getSynPermMin(), connections.getSynPermMax());
        column.setProximalPermanencesSparse(connections, perm, inputIndexes);
    }

    /**
     * Returns an initial permanence for a connected synapse. The value is
     * {@code synPermConnected} plus a random amount in [0, synPermActiveInc / 4), truncated
     * to 5 decimal places.
     *
     * @param connections parameter/state container (supplies the random source)
     * @return the initial permanence
     */
    public static double initPermConnected(Connections connections) {
        return ((int) ((connections.getSynPermConnected() + ((connections.getRandom().nextDouble() * connections.getSynPermActiveInc()) / 4.0d)) * 100000.0d)) / 100000.0d;
    }

    /**
     * Returns an initial permanence for an unconnected synapse. The value is a random amount in
     * [0, synPermConnected), truncated to 5 decimal places.
     *
     * @param connections parameter/state container (supplies the random source)
     * @return the initial permanence
     */
    public static double initPermNonConnected(Connections connections) {
        return ((int) ((connections.getSynPermConnected() * connections.getRandom().nextDouble()) * 100000.0d)) / 100000.0d;
    }

    /**
     * Builds the initial dense permanences for column {@code index} and stores them on the
     * column. First, {@code round(potentialPool.length * connectedPct)} distinct pool inputs are
     * chosen at random; this count is capped at the number of distinct pool entries. The chosen
     * inputs then get {@link #initPermConnected}, the remaining pool inputs get
     * {@link #initPermNonConnected}. Finally, values below {@code synPermTrimThreshold} are zeroed.
     *
     * @param connections   parameter/state container
     * @param potentialPool input indices in the column's potential pool
     * @param index         column index
     * @param connectedPct  fraction of the pool to start connected
     * @return dense permanences of length {@code numInputs}
     */
    public double[] initPermanence(Connections connections, int[] potentialPool, int index, double connectedPct) {
        // Never request more distinct inputs than the pool holds, or the sampling loop would
        // never terminate.
        int numConnected = Math.min((int) Math.round(potentialPool.length * connectedPct), new TIntHashSet(potentialPool).size());
        TIntHashSet connectedInputs = new TIntHashSet();
        Random random = connections.getRandom();
        while (connectedInputs.size() < numConnected) {
            connectedInputs.add(potentialPool[random.nextInt(potentialPool.length)]);
        }
        double[] perm = new double[connections.getNumInputs()];
        for (int input : potentialPool) {
            double p = connectedInputs.contains(input) ? initPermConnected(connections) : initPermNonConnected(connections);
            perm[input] = p < connections.getSynPermTrimThreshold() ? 0.0d : p;
        }
        connections.getColumn(index).setProximalPermanences(connections, perm);
        return perm;
    }

    /**
     * Maps a column index to the flat index of its "centre" in input space. The column's
     * relative position in column space is scaled by the input dimensions and shifted by half of
     * {@code inputDim / columnDim}. The result is converted to ints and clipped against the input
     * dimensions, using ArrayUtils.clip with an adjustment of -1 (presumably to
     * {@code dim - 1}).
     *
     * @param connections parameter/state container
     * @param index       column index
     * @return flat input index
     */
    public int mapColumn(Connections connections, int index) {
        int[] inputDimensions = connections.getInputDimensions();
        int[] columnDimensions = connections.getColumnDimensions();
        double[] ratios = ArrayUtils.divide(ArrayUtils.toDoubleArray(connections.getMemory().computeCoordinates(index)), ArrayUtils.toDoubleArray(columnDimensions), 0.0d, 0.0d);
        double[] scaled = ArrayUtils.multiply(ArrayUtils.toDoubleArray(inputDimensions), ratios, 0.0d, 0.0d);
        double[] halfStep = ArrayUtils.multiply(ArrayUtils.divide(ArrayUtils.toDoubleArray(inputDimensions), ArrayUtils.toDoubleArray(columnDimensions), 0.0d, 0.0d), 0.5d);
        int[] inputCoords = ArrayUtils.clip(ArrayUtils.toIntArray(ArrayUtils.d_add(scaled, halfStep)), inputDimensions, -1);
        return connections.getInputMatrix().computeIndex(inputCoords);
    }

    /**
     * Samples the potential pool of column {@code index}. The candidates are its mapped centre
     * ({@link #mapColumn}) plus all input neighbours within {@code potentialRadius}. From these,
     * {@code round(size * potentialPct)} are drawn using {@code connections.getRandom()}.
     *
     * @param connections parameter/state container
     * @param index       column index
     * @param wrapAround  whether the neighbourhood wraps at the input edges
     * @return the sampled input indices
     */
    public int[] mapPotential(Connections connections, int index, boolean wrapAround) {
        int center = mapColumn(connections, index);
        TIntArrayList candidates = getNeighborsND(connections, center, connections.getInputMatrix(), connections.getPotentialRadius(), wrapAround);
        candidates.add(center);
        candidates.sort();
        int sampleSize = (int) Math.round(candidates.size() * connections.getPotentialPct());
        return ArrayUtils.sample(sampleSize, candidates, connections.getRandom());
    }

    /**
     * Returns the flat indices (in {@code topology}) of every position within {@code radius}
     * of {@code centerIndex} along each dimension, excluding {@code centerIndex} itself. With
     * wrap-around, coordinates wrap modulo the dimension size; otherwise out-of-range coordinates
     * are dropped.
     *
     * @param connections parameter/state container (unused for indexing)
     * @param centerIndex flat index of the centre in {@code topology}
     * @param topology    matrix whose dimensions define the neighbourhood space
     * @param radius      neighbourhood radius
     * @param wrapAround  whether to wrap at the edges
     * @return neighbour flat indices
     */
    public TIntArrayList getNeighborsND(Connections connections, int centerIndex, SparseMatrix<?> topology, int radius, boolean wrapAround) {
        final int[] dimensions = topology.getDimensions();
        int[] centerCoords = topology.computeCoordinates(centerIndex);
        List<int[]> dimensionCoords = new ArrayList<>(dimensions.length);
        for (int dim = 0; dim < dimensions.length; dim++) {
            int[] range = ArrayUtils.range(centerCoords[dim] - radius, centerCoords[dim] + radius + 1);
            int[] coords;
            if (wrapAround) {
                coords = new int[range.length];
                for (int j = 0; j < range.length; j++) {
                    coords[j] = (int) ArrayUtils.positiveRemainder(range[j], dimensions[dim]);
                }
            } else {
                final int dimSize = dimensions[dim];
                coords = ArrayUtils.retainLogicalAnd(range, (Condition<?>[]) new Condition[] {
                    ArrayUtils.GREATER_OR_EQUAL_0,
                    new Condition.Adapter<Integer>() {
                        @Override
                        public boolean eval(int n) {
                            return n < dimSize;
                        }
                    }
                });
            }
            dimensionCoords.add(ArrayUtils.unique(coords));
        }
        List<int[]> neighborCoords = ArrayUtils.dimensionsToCoordinateList(dimensionCoords);
        TIntArrayList neighbors = new TIntArrayList(neighborCoords.size());
        for (int[] coords : neighborCoords) {
            // Flatten with the SAME topology the coordinates came from. Using the input matrix
            // here gave wrong indices whenever topology is the column space with a different shape.
            int flatIndex = topology.computeIndex(coords, false);
            if (flatIndex != centerIndex) {
                neighbors.add(flatIndex);
            }
        }
        return neighbors;
    }

    /**
     * Returns true when the iteration number is a multiple of the update period.
     *
     * @param connections parameter/state container
     * @return whether this iteration is an update round
     */
    public boolean isUpdateRound(Connections connections) {
        return connections.getIterationNum() % connections.getUpdatePeriod() == 0;
    }

    /**
     * Increments the iteration counter and, when learning, the learning-iteration counter.
     *
     * @param connections parameter/state container
     * @param learn       whether this step learns
     */
    public void updateBookeepingVars(Connections connections, boolean learn) {
        connections.iterationNum++;
        if (learn) {
            connections.iterationLearnNum++;
        }
    }

    /**
     * Computes each column's overlap with {@code inputVector}, using the connected-counts
     * matrix. Overlaps below {@code stimulusThreshold} are zeroed.
     *
     * @param connections parameter/state container
     * @param inputVector input bits
     * @return per-column overlap scores
     */
    public int[] calculateOverlap(Connections connections, int[] inputVector) {
        int[] overlaps = new int[connections.getNumColumns()];
        connections.getConnectedCounts().rightVecSumAtNZ(inputVector, overlaps);
        // Compare against the exact threshold; truncating it to int would keep overlaps that
        // fall below a fractional threshold.
        double threshold = connections.getStimulusThreshold();
        for (int i = 0; i < overlaps.length; i++) {
            if (overlaps[i] < threshold) {
                overlaps[i] = 0;
            }
        }
        return overlaps;
    }

    /**
     * Divides each column's overlap by its number of connected synapses.
     *
     * @param connections parameter/state container
     * @param overlaps    per-column overlap scores
     * @return per-column overlap fractions
     */
    public double[] calculateOverlapPct(Connections connections, int[] overlaps) {
        return ArrayUtils.divide(overlaps, connections.getConnectedCounts().getTrueCounts());
    }

    /**
     * Selects the winning columns. Steps:
     * <ol>
     *   <li>Derive the target density: {@code localAreaDensity} if positive, otherwise
     *       {@code min(numActiveColumnsPerInhArea / inhibitionArea, 0.5)}.</li>
     *   <li>Add the tie breaker to a copy of the scores.</li>
     *   <li>Apply global inhibition when it is enabled or the radius exceeds the largest column
     *       dimension; otherwise apply local inhibition.</li>
     * </ol>
     *
     * @param connections parameter/state container
     * @param overlaps    per-column scores (not modified)
     * @return sorted indices of the active columns
     */
    public int[] inhibitColumns(Connections connections, double[] overlaps) {
        double[] scores = Arrays.copyOf(overlaps, overlaps.length);
        double density = connections.getLocalAreaDensity();
        if (density <= 0.0d) {
            double inhibitionArea = Math.min(connections.getNumColumns(), Math.pow((2 * connections.getInhibitionRadius()) + 1, connections.getColumnDimensions().length));
            density = Math.min(connections.getNumActiveColumnsPerInhArea() / inhibitionArea, 0.5d);
        }
        // Keep the returned array so the tie breaker is applied regardless of whether d_add
        // works in place.
        scores = ArrayUtils.d_add(scores, connections.getTieBreaker());
        boolean global = connections.getGlobalInhibition() || connections.getInhibitionRadius() > ArrayUtils.max(connections.getColumnDimensions());
        return global ? inhibitColumnsGlobal(connections, scores, density) : inhibitColumnsLocal(connections, scores, density);
    }

    /**
     * Picks the {@code (int) (density * numColumns)} highest-scoring columns (via
     * ArrayUtils.nGreatest) and returns their indices sorted ascending.
     *
     * @param connections parameter/state container
     * @param overlaps    per-column scores
     * @param density     desired fraction of active columns
     * @return sorted active column indices
     */
    public int[] inhibitColumnsGlobal(Connections connections, double[] overlaps, double density) {
        int[] winners = ArrayUtils.nGreatest(overlaps, (int) (density * connections.getNumColumns()));
        Arrays.sort(winners);
        return winners;
    }

    /**
     * Local inhibition. A column wins when fewer than
     * {@code (int) (0.5 + density * (numNeighbors + 1))} of its neighbours (non-wrapping, within
     * the inhibition radius in column space) have a strictly greater score.
     *
     * <p>Modifies {@code overlaps} in place: each winner's score is raised by
     * {@code max(overlaps) / 1000}, so columns examined later see it as larger.
     *
     * @param connections parameter/state container
     * @param overlaps    per-column scores, modified in place
     * @param density     desired fraction of active columns per neighbourhood
     * @return sorted active column indices
     */
    public int[] inhibitColumnsLocal(Connections connections, double[] overlaps, double density) {
        int numColumns = connections.getNumColumns();
        int[] active = new int[numColumns];
        double addToWinners = ArrayUtils.max(overlaps) / 1000.0d;
        for (int i = 0; i < numColumns; i++) {
            TIntArrayList neighbors = getNeighborsND(connections, i, connections.getMemory(), connections.getInhibitionRadius(), false);
            double[] neighborScores = ArrayUtils.sub(overlaps, neighbors.toArray());
            int numActive = (int) (0.5d + (density * (neighbors.size() + 1)));
            int numBigger = ArrayUtils.valueGreaterCount(overlaps[i], neighborScores);
            if (numBigger < numActive) {
                active[i] = 1;
                overlaps[i] += addToWinners;
            }
        }
        return ArrayUtils.where(active, (Condition) ArrayUtils.INT_GREATER_THAN_0);
    }

    /**
     * Recomputes the boost factors.
     * <ul>
     *   <li>If any minimum active duty cycle is positive, the factors become
     *       {@code (1 - maxBoost) / minActiveDutyCycle * activeDutyCycle + maxBoost}.
     *       Otherwise the current factors are kept.</li>
     *   <li>Every column whose active duty cycle exceeds its minimum is then reset to 1.0.</li>
     * </ul>
     *
     * @param connections parameter/state container
     */
    public void updateBoostFactors(Connections connections) {
        double[] activeDutyCycles = connections.getActiveDutyCycles();
        double[] minActiveDutyCycles = connections.getMinActiveDutyCycles();
        double[] boostFactors;
        if (ArrayUtils.where(minActiveDutyCycles, ArrayUtils.GREATER_THAN_0).length < 1) {
            boostFactors = connections.getBoostFactors();
        } else {
            double[] numerator = new double[connections.getNumColumns()];
            Arrays.fill(numerator, 1.0d - connections.getMaxBoost());
            boostFactors = ArrayUtils.d_add(ArrayUtils.multiply(ArrayUtils.divide(numerator, minActiveDutyCycles, 0.0d, 0.0d), activeDutyCycles, 0.0d, 0.0d), connections.getMaxBoost());
        }
        // Explicit index loop instead of a stateful Condition that relied on call order.
        for (int i = 0; i < activeDutyCycles.length; i++) {
            if (activeDutyCycles[i] > minActiveDutyCycles[i]) {
                boostFactors[i] = 1.0d;
            }
        }
        connections.setBoostFactors(boostFactors);
    }
}
