package org.numenta.nupic.algorithms;

import gnu.trove.map.hash.TObjectIntHashMap;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import org.numenta.nupic.ComputeCycle;
import org.numenta.nupic.Connections;
import org.numenta.nupic.model.Cell;
import org.numenta.nupic.model.Column;
import org.numenta.nupic.model.DistalDendrite;
import org.numenta.nupic.model.Synapse;
import org.numenta.nupic.monitor.ComputeDecorator;
import org.numenta.nupic.util.SparseObjectMatrix;

/* loaded from: input_file:org/numenta/nupic/algorithms/TemporalMemory.class */
/**
 * Temporal Memory compute step.
 *
 * <p>For each input (a set of active column indices) this class:
 * <ol>
 *   <li>activates previously predicted cells in active columns,</li>
 *   <li>bursts the remaining active columns,</li>
 *   <li>optionally learns on distal segments, and</li>
 *   <li>computes the predictive / matching cells and segments for the next step.</li>
 * </ol>
 *
 * <p>Previous-step state is read from a {@link Connections} instance and the new state is
 * written back to it by {@link #compute(Connections, int[], boolean)}.
 */
public class TemporalMemory implements ComputeDecorator {

    /**
     * Groups active cells, winner cells and learning segments.
     * NOTE(review): not referenced anywhere within this class.
     */
    class BurstResult {
        Set<Cell> activeCells;
        Set<Cell> winnerCells;
        Set<DistalDendrite> learningSegments;

        public BurstResult(Set<Cell> activeCells, Set<Cell> winnerCells,
                Set<DistalDendrite> learningSegments) {
            this.activeCells = activeCells;
            this.winnerCells = winnerCells;
            this.learningSegments = learningSegments;
        }
    }

    /**
     * Result of {@link #getBestMatchingCell}: the chosen cell, and the segment on it that best
     * matched (or {@code null} if the cell was picked as a least-used fallback).
     */
    public class CellSearch {
        Cell bestCell;
        DistalDendrite bestSegment;

        public CellSearch(Cell bestCell, DistalDendrite bestSegment) {
            this.bestCell = bestCell;
            this.bestSegment = bestSegment;
        }
    }

    /**
     * Result of {@link #getBestMatchingSegment}: the best segment (or {@code null}) and the
     * number of active synapses it had (0 when no segment qualified).
     */
    public class SegmentSearch {
        DistalDendrite bestSegment;
        int numActiveSynapses;

        public SegmentSearch(DistalDendrite bestSegment, int numActiveSynapses) {
            this.bestSegment = bestSegment;
            this.numActiveSynapses = numActiveSynapses;
        }
    }

    /**
     * Initializes the column and cell structures held by {@code connections}.
     *
     * <p>Creates the column matrix from the configured column dimensions if
     * {@code connections} has none. If the matrix holds no column at index 0, it is treated
     * as unpopulated: a new {@link Column} is created and stored at every index. Otherwise
     * the existing columns are reused.
     *
     * <p>The column count ({@code maxIndex + 1}) and a flat cell array are stored back on
     * {@code connections}. In that array, cell {@code k} of column {@code c} lives at
     * {@code c * cellsPerColumn + k}.
     *
     * @param connections the memory to initialize
     */
    public void init(Connections connections) {
        SparseObjectMatrix<Column> matrix = connections.getMemory();
        if (matrix == null) {
            matrix = new SparseObjectMatrix<>(connections.getColumnDimensions());
        }
        connections.setMemory(matrix);

        int numColumns = matrix.getMaxIndex() + 1;
        connections.setNumColumns(numColumns);

        int cellsPerColumn = connections.getCellsPerColumn();
        Cell[] cells = new Cell[numColumns * cellsPerColumn];
        // An empty slot 0 is taken to mean the matrix has not been populated yet.
        boolean populate = matrix.getObject(0) == null;
        for (int col = 0; col < numColumns; col++) {
            Column column = populate ? new Column(cellsPerColumn, col) : matrix.getObject(col);
            for (int k = 0; k < cellsPerColumn; k++) {
                cells[col * cellsPerColumn + k] = column.getCell(k);
            }
            if (populate) {
                // Store the Column itself; the decompiled "(int) column" cast was invalid Java.
                matrix.set(col, column);
            }
        }
        connections.setCells(cells);
    }

    /**
     * Runs one time step.
     *
     * <p>Feeds the previous-step state held in {@code connections} into
     * {@link #computeFn}, then writes the resulting state back.
     *
     * @param connections   memory holding parameters and previous-step state
     * @param activeColumns indices of the currently active columns
     * @param learn         whether to perform learning this step
     * @return the computed cycle (also written into {@code connections})
     */
    @Override // org.numenta.nupic.monitor.ComputeDecorator
    public ComputeCycle compute(Connections connections, int[] activeColumns, boolean learn) {
        ComputeCycle cycle = computeFn(connections,
                connections.getColumnSet(activeColumns),
                connections.getPredictiveCells(),
                connections.getActiveSegments(),
                connections.getActiveCells(),
                connections.getWinnerCells(),
                connections.getMatchingSegments(),
                connections.getMatchingCells(),
                learn);
        connections.setActiveCells(cycle.activeCells);
        connections.setWinnerCells(cycle.winnerCells);
        connections.setPredictiveCells(cycle.predictiveCells);
        connections.setSuccessfullyPredictedColumns(cycle.successfullyPredictedColumns);
        connections.setActiveSegments(cycle.activeSegments);
        connections.setLearningSegments(cycle.learningSegments);
        connections.setMatchingSegments(cycle.matchingSegments);
        connections.setMatchingCells(cycle.matchingCells);
        return cycle;
    }

    /**
     * Pure form of one compute step: reads only its arguments and returns a new cycle.
     *
     * <p>It does not write state back; learning still mutates segments/synapses via
     * {@code connections}.
     *
     * @param connections          memory holding parameters and synapses
     * @param activeColumns        currently active columns
     * @param prevPredictiveCells  cells predicted in the previous step
     * @param prevActiveSegments   segments active in the previous step
     * @param prevActiveCells      cells active in the previous step
     * @param prevWinnerCells      winner cells of the previous step
     * @param prevMatchingSegments segments matching in the previous step
     * @param prevMatchingCells    cells matching in the previous step
     * @param learn                whether to call {@link #learnOnSegments}
     * @return the newly computed cycle
     */
    public ComputeCycle computeFn(Connections connections, Set<Column> activeColumns,
            Set<Cell> prevPredictiveCells, Set<DistalDendrite> prevActiveSegments,
            Set<Cell> prevActiveCells, Set<Cell> prevWinnerCells,
            Set<DistalDendrite> prevMatchingSegments, Set<Cell> prevMatchingCells,
            boolean learn) {
        ComputeCycle cycle = new ComputeCycle();
        activateCorrectlyPredictiveCells(connections, cycle, prevPredictiveCells,
                prevMatchingCells, activeColumns);
        burstColumns(cycle, connections, activeColumns, cycle.successfullyPredictedColumns,
                prevActiveCells, prevWinnerCells);
        if (learn) {
            learnOnSegments(connections, prevActiveSegments, cycle.learningSegments,
                    prevActiveCells, cycle.winnerCells, prevWinnerCells,
                    cycle.predictedInactiveCells, prevMatchingSegments);
        }
        computePredictiveCells(connections, cycle, cycle.activeCells);
        return cycle;
    }

    /**
     * Handles previously predicted cells whose column is now active.
     *
     * <p>Each such cell becomes both active and a winner, and its column is recorded as
     * successfully predicted. When the predicted-segment decrement is positive, previously
     * matching cells in columns that are NOT active are recorded as predicted-inactive.
     *
     * @param connections         memory holding parameters
     * @param cycle               cycle receiving the results
     * @param prevPredictiveCells cells predicted in the previous step
     * @param prevMatchingCells   cells matching in the previous step
     * @param activeColumns       currently active columns
     */
    public void activateCorrectlyPredictiveCells(Connections connections, ComputeCycle cycle,
            Set<Cell> prevPredictiveCells, Set<Cell> prevMatchingCells,
            Set<Column> activeColumns) {
        for (Cell cell : prevPredictiveCells) {
            Column column = cell.getColumn();
            if (activeColumns.contains(column)) {
                cycle.activeCells.add(cell);
                cycle.winnerCells.add(cell);
                cycle.successfullyPredictedColumns.add(column);
            }
        }
        if (connections.getPredictedSegmentDecrement() > 0.0d) {
            for (Cell cell : prevMatchingCells) {
                if (!activeColumns.contains(cell.getColumn())) {
                    cycle.predictedInactiveCells.add(cell);
                }
            }
        }
    }

    /**
     * Bursts every active column that was not successfully predicted.
     *
     * <p>For each such column:
     * <ul>
     *   <li>all of its cells become active;</li>
     *   <li>the best matching cell (or a least-used cell) becomes the winner;</li>
     *   <li>the matching segment becomes a learning segment. If no segment matched and there
     *       were previous winner cells, a new segment is created on the winner instead.</li>
     * </ul>
     *
     * <p>{@code activeColumns} is no longer modified. Predicted columns are skipped during
     * iteration instead of being removed from the caller's set.
     *
     * @param cycle            cycle receiving the results
     * @param connections      memory holding parameters and segments
     * @param activeColumns    currently active columns (not modified)
     * @param predictedColumns columns already activated by correct prediction
     * @param prevActiveCells  cells active in the previous step
     * @param prevWinnerCells  winner cells of the previous step
     */
    public void burstColumns(ComputeCycle cycle, Connections connections,
            Set<Column> activeColumns, Set<Column> predictedColumns,
            Set<Cell> prevActiveCells, Set<Cell> prevWinnerCells) {
        for (Column column : activeColumns) {
            if (predictedColumns.contains(column)) {
                continue;
            }
            List<Cell> cells = column.getCells();
            cycle.activeCells.addAll(cells);

            CellSearch search = getBestMatchingCell(connections, cells, prevActiveCells);
            cycle.winnerCells.add(search.bestCell);

            DistalDendrite learningSegment = search.bestSegment;
            if (learningSegment == null && !prevWinnerCells.isEmpty()) {
                learningSegment = search.bestCell.createSegment(connections);
            }
            if (learningSegment != null) {
                cycle.learningSegments.add(learningSegment);
            }
        }
    }

    /**
     * Applies learning to distal segments.
     *
     * <p>The union of previously active segments and learning segments is processed as
     * follows:
     * <ul>
     *   <li>A segment is reinforced with the configured increment/decrement when it is a
     *       learning segment or its parent cell is a current winner.</li>
     *   <li>Learning segments additionally grow up to
     *       {@code maxNewSynapseCount - activeSynapses} new synapses to previous winner cells.</li>
     * </ul>
     *
     * <p>When the predicted-segment decrement is positive, previously matching segments whose
     * parent cell is predicted-inactive are weakened by that decrement.
     *
     * @param connections            memory holding parameters and synapses
     * @param prevActiveSegments     segments active in the previous step
     * @param learningSegments       segments chosen for learning this step
     * @param prevActiveCells        cells active in the previous step
     * @param winnerCells            winner cells of this step
     * @param prevWinnerCells        winner cells of the previous step
     * @param predictedInactiveCells matching cells whose column did not activate
     * @param prevMatchingSegments   segments matching in the previous step
     */
    public void learnOnSegments(Connections connections, Set<DistalDendrite> prevActiveSegments,
            Set<DistalDendrite> learningSegments, Set<Cell> prevActiveCells,
            Set<Cell> winnerCells, Set<Cell> prevWinnerCells, Set<Cell> predictedInactiveCells,
            Set<DistalDendrite> prevMatchingSegments) {
        double permanenceIncrement = connections.getPermanenceIncrement();
        double permanenceDecrement = connections.getPermanenceDecrement();

        Set<DistalDendrite> candidates = new HashSet<>(prevActiveSegments);
        candidates.addAll(learningSegments);
        for (DistalDendrite segment : candidates) {
            boolean isLearningSegment = learningSegments.contains(segment);
            boolean isFromWinnerCell = winnerCells.contains(segment.getParentCell());
            Set<Synapse> activeSynapses = segment.getActiveSynapses(connections, prevActiveCells);
            if (isLearningSegment || isFromWinnerCell) {
                segment.adaptSegment(connections, activeSynapses, permanenceIncrement,
                        permanenceDecrement);
            }
            // Measured after adaptation, matching the original statement order.
            int newSynapseCount = connections.getMaxNewSynapseCount() - activeSynapses.size();
            if (isLearningSegment && newSynapseCount > 0) {
                Iterator<Cell> presynaptic = segment.pickCellsToLearnOn(connections,
                        newSynapseCount, prevWinnerCells, connections.getRandom()).iterator();
                while (presynaptic.hasNext()) {
                    segment.createSynapse(connections, presynaptic.next(),
                            connections.getInitialPermanence());
                }
            }
        }

        double predictedSegmentDecrement = connections.getPredictedSegmentDecrement();
        if (predictedSegmentDecrement > 0.0d) {
            for (DistalDendrite segment : prevMatchingSegments) {
                if (predictedInactiveCells.contains(segment.getParentCell())) {
                    segment.adaptSegment(connections,
                            segment.getActiveSynapses(connections, prevActiveCells),
                            -predictedSegmentDecrement, 0.0d);
                }
            }
        }
    }

    /**
     * Computes active/matching segments and predictive/matching cells for the next step.
     *
     * <p>Every synapse whose presynaptic cell is in {@code activeCells} is counted toward its
     * segment:
     * <ul>
     *   <li>Synapses with permanence {@code >= connectedPermanence} count toward activity. A
     *       segment reaching {@code activationThreshold} becomes active and its parent cell
     *       predictive.</li>
     *   <li>When the predicted-segment decrement is positive, synapses with permanence
     *       {@code > 0} count toward matching. A segment reaching {@code minThreshold}
     *       becomes matching.</li>
     * </ul>
     *
     * @param connections memory holding parameters and synapses
     * @param cycle       cycle receiving the results
     * @param activeCells cells active in this step
     */
    public void computePredictiveCells(Connections connections, ComputeCycle cycle,
            Set<Cell> activeCells) {
        // Loop-invariant parameters, read once instead of per synapse.
        double connectedPermanence = connections.getConnectedPermanence();
        int activationThreshold = connections.getActivationThreshold();
        int minThreshold = connections.getMinThreshold();
        boolean trackMatching = connections.getPredictedSegmentDecrement() > 0.0d;

        TObjectIntHashMap<DistalDendrite> connectedCounts = new TObjectIntHashMap<>();
        TObjectIntHashMap<DistalDendrite> potentialCounts = new TObjectIntHashMap<>();
        for (Cell cell : activeCells) {
            for (Synapse synapse : connections.getReceptorSynapses(cell)) {
                DistalDendrite segment = (DistalDendrite) synapse.getSegment();
                double permanence = synapse.getPermanence();
                if (permanence >= connectedPermanence) {
                    // adjustOrPutValue returns the count after the increment.
                    int connected = connectedCounts.adjustOrPutValue(segment, 1, 1);
                    if (connected >= activationThreshold) {
                        cycle.activeSegments.add(segment);
                        cycle.predictiveCells.add(segment.getParentCell());
                    }
                }
                if (permanence > 0.0d && trackMatching) {
                    int potential = potentialCounts.adjustOrPutValue(segment, 1, 1);
                    if (potential >= minThreshold) {
                        cycle.matchingSegments.add(segment);
                        cycle.matchingCells.add(segment.getParentCell());
                    }
                }
            }
        }
    }

    /**
     * Clears the per-step state held in {@code connections}.
     *
     * <p>The cleared sets are active, predictive, winner and matching cells, plus active and
     * matching segments.
     *
     * @param connections memory whose state is cleared
     */
    @Override // org.numenta.nupic.monitor.ComputeDecorator
    public void reset(Connections connections) {
        connections.getActiveCells().clear();
        connections.getPredictiveCells().clear();
        connections.getActiveSegments().clear();
        connections.getWinnerCells().clear();
        connections.getMatchingCells().clear();
        connections.getMatchingSegments().clear();
    }

    /**
     * Picks the cell in {@code columnCells} whose best matching segment has the most active
     * synapses.
     *
     * <p>Only strictly positive counts qualify, and ties keep the earlier cell. If no cell
     * qualifies, a least-used cell is chosen and the returned segment is {@code null}.
     *
     * @param connections memory holding segments and synapses
     * @param columnCells cells of one column
     * @param activeCells cells considered active when counting synapses
     * @return the chosen cell and its matching segment (possibly {@code null})
     */
    public CellSearch getBestMatchingCell(Connections connections, List<Cell> columnCells,
            Set<Cell> activeCells) {
        int bestCount = 0;
        Cell bestCell = null;
        DistalDendrite bestSegment = null;
        for (Cell cell : columnCells) {
            SegmentSearch search = getBestMatchingSegment(connections, cell, activeCells);
            if (search.bestSegment != null && search.numActiveSynapses > bestCount) {
                bestCount = search.numActiveSynapses;
                bestCell = cell;
                bestSegment = search.bestSegment;
            }
        }
        if (bestCell == null) {
            bestCell = getLeastUsedCell(connections, columnCells);
        }
        return new CellSearch(bestCell, bestSegment);
    }

    /**
     * Finds the segment of {@code cell} with the most active synapses.
     *
     * <p>A synapse counts as active when its presynaptic cell is in {@code activeCells} and
     * its permanence is {@code > 0}. A segment qualifies when its count is at least
     * {@code minThreshold}, and afterwards at least the best count seen so far; ties go to the
     * later segment.
     *
     * @param connections memory holding segments and synapses
     * @param cell        cell whose segments are searched
     * @param activeCells cells considered active
     * @return the best segment (or {@code null}) and its active-synapse count (0 if none)
     */
    public SegmentSearch getBestMatchingSegment(Connections connections, Cell cell,
            Set<Cell> activeCells) {
        int threshold = connections.getMinThreshold();
        DistalDendrite bestSegment = null;
        int bestCount = 0;
        for (DistalDendrite segment : connections.getSegments(cell)) {
            int count = 0;
            for (Synapse synapse : connections.getSynapses(segment)) {
                if (activeCells.contains(synapse.getPresynapticCell())
                        && synapse.getPermanence() > 0.0d) {
                    count++;
                }
            }
            // ">=" raises the bar to the best count so far and lets later segments win ties.
            if (count >= threshold) {
                threshold = count;
                bestSegment = segment;
                bestCount = count;
            }
        }
        return new SegmentSearch(bestSegment, bestCount);
    }

    /**
     * Returns a random cell among those in {@code cells} with the fewest segments.
     *
     * <p>Candidates are sorted by their natural order before indexing, so the pick depends
     * only on the random draw, not on input order.
     *
     * @param connections memory holding segments and the random generator
     * @param cells       candidate cells; must not be empty
     * @return one of the least-used cells
     * @throws IllegalArgumentException if {@code cells} is empty
     */
    public Cell getLeastUsedCell(Connections connections, List<Cell> cells) {
        if (cells.isEmpty()) {
            throw new IllegalArgumentException("cells must not be empty");
        }
        Set<Cell> leastUsed = new LinkedHashSet<>();
        int fewestSegments = Integer.MAX_VALUE;
        for (Cell cell : cells) {
            int numSegments = connections.getSegments(cell).size();
            if (numSegments < fewestSegments) {
                fewestSegments = numSegments;
                leastUsed.clear();
            }
            if (numSegments == fewestSegments) {
                leastUsed.add(cell);
            }
        }
        int pick = connections.getRandom().nextInt(leastUsed.size());
        List<Cell> sorted = new ArrayList<>(leastUsed);
        Collections.sort(sorted);
        return sorted.get(pick);
    }
}
