package org.numenta.nupic.research;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import org.numenta.nupic.Connections;
import org.numenta.nupic.model.Cell;
import org.numenta.nupic.model.Column;
import org.numenta.nupic.model.DistalDendrite;
import org.numenta.nupic.model.Synapse;
import org.numenta.nupic.util.SparseObjectMatrix;

/* loaded from: input_file:org/numenta/nupic/research/TemporalMemory.class */
/**
 * Temporal memory algorithm operating on a {@code Connections} instance.
 *
 * <p>This class declares no fields of its own. Every method reads its inputs from, and writes its
 * results to, the {@code Connections} (and, for the per-step phases, the {@code ComputeCycle})
 * passed to it.
 */
public class TemporalMemory {

    /**
     * Builds the column and cell structures inside {@code connections}.
     *
     * <p>If {@code connections} has no column memory yet, a new {@code SparseObjectMatrix} is
     * created from {@code connections.getColumnDimensions()}. If the matrix holds nothing at index
     * 0, a new {@code Column} is created and stored for every index. Otherwise the existing columns
     * are reused. In both cases a flat cell array is published via {@code connections.setCells}.
     * The cells of each column are stored contiguously: cell {@code j} of column {@code i} lives at
     * index {@code i * cellsPerColumn + j}.
     *
     * @param connections the model state to initialize
     */
    public void init(Connections connections) {
        SparseObjectMatrix<Column> matrix = connections.getMemory() == null
            ? new SparseObjectMatrix<Column>(connections.getColumnDimensions())
            : connections.getMemory();
        connections.setMemory(matrix);

        int numColumns = matrix.getMaxIndex() + 1;
        int cellsPerColumn = connections.getCellsPerColumn();
        Cell[] cells = new Cell[numColumns * cellsPerColumn];

        // An empty slot 0 is taken to mean the matrix has not been populated with columns yet.
        boolean createColumns = matrix.getObject(0) == null;
        for (int i = 0; i < numColumns; i++) {
            Column column = createColumns ? new Column(cellsPerColumn, i) : matrix.getObject(i);
            for (int j = 0; j < cellsPerColumn; j++) {
                cells[i * cellsPerColumn + j] = column.getCell(j);
            }
            if (createColumns) {
                // The decompiled original cast the Column to int here, which does not compile.
                matrix.set(i, column);
            }
        }
        connections.setCells(cells);
    }

    /**
     * Runs one step of the algorithm and stores the resulting state back into
     * {@code connections}.
     *
     * <p>The previous step's predictive cells, active segments, active-synapse map and winner
     * cells are copied out of {@code connections} before being handed to
     * {@link #computeFn}. Afterwards the active cells, winner cells, predictive cells,
     * successfully predicted columns, active segments, learning segments and active-synapse map
     * of the new cycle are written back.
     *
     * @param connections         the model state (read and updated)
     * @param activeColumnIndices indices of the currently active columns
     * @param learn               whether {@link #learnOnSegments} runs during this step
     * @return the cycle holding this step's results
     */
    public ComputeCycle compute(Connections connections, int[] activeColumnIndices, boolean learn) {
        ComputeCycle cycle = computeFn(
            connections,
            connections.getColumnSet(activeColumnIndices),
            new LinkedHashSet<Cell>(connections.getPredictiveCells()),
            new LinkedHashSet<DistalDendrite>(connections.getActiveSegments()),
            new LinkedHashMap<DistalDendrite, Set<Synapse>>(connections.getActiveSynapsesForSegment()),
            new LinkedHashSet<Cell>(connections.getWinnerCells()),
            learn);
        connections.setActiveCells(cycle.activeCells());
        connections.setWinnerCells(cycle.winnerCells());
        connections.setPredictiveCells(cycle.predictiveCells());
        connections.setSuccessfullyPredictedColumns(cycle.successfullyPredictedColumns());
        connections.setActiveSegments(cycle.activeSegments());
        connections.setLearningSegments(cycle.learningSegments());
        connections.setActiveSynapsesForSegment(cycle.activeSynapsesForSegment());
        return cycle;
    }

    /**
     * Computes one step from explicitly supplied previous-step state, without writing anything
     * back into {@code connections}'s step-state setters.
     *
     * <p>The step runs in this order:
     * <ol>
     *   <li>activate correctly predicted cells;</li>
     *   <li>burst unpredicted columns;</li>
     *   <li>optionally learn;</li>
     *   <li>compute active synapses and predictive cells.</li>
     * </ol>
     * {@code activeColumns} is not modified.
     *
     * @param connections                  the model state
     * @param activeColumns                currently active columns
     * @param prevPredictiveCells          cells predictive in the previous step
     * @param prevActiveSegments           segments active in the previous step
     * @param prevActiveSynapsesForSegment previous step's active synapses, keyed by segment
     * @param prevWinnerCells              winner cells of the previous step
     * @param learn                        whether to run {@link #learnOnSegments}
     * @return a new cycle holding this step's results
     */
    public ComputeCycle computeFn(Connections connections, Set<Column> activeColumns,
            Set<Cell> prevPredictiveCells, Set<DistalDendrite> prevActiveSegments,
            Map<DistalDendrite, Set<Synapse>> prevActiveSynapsesForSegment,
            Set<Cell> prevWinnerCells, boolean learn) {
        ComputeCycle cycle = new ComputeCycle();
        activateCorrectlyPredictiveCells(cycle, prevPredictiveCells, activeColumns);
        burstColumns(cycle, connections, activeColumns, cycle.successfullyPredictedColumns,
            prevActiveSynapsesForSegment);
        if (learn) {
            learnOnSegments(connections, prevActiveSegments, cycle.learningSegments,
                prevActiveSynapsesForSegment, cycle.winnerCells, prevWinnerCells);
        }
        cycle.activeSynapsesForSegment = computeActiveSynapses(connections, cycle.activeCells);
        computePredictiveCells(connections, cycle, cycle.activeSynapsesForSegment);
        return cycle;
    }

    /**
     * Handles each previously predictive cell whose parent column is now active.
     *
     * <p>Such a cell is added to the cycle's active and winner cells, and its column is recorded
     * as successfully predicted.
     *
     * @param cycle               receives the results
     * @param prevPredictiveCells cells predictive in the previous step
     * @param activeColumns       currently active columns
     */
    public void activateCorrectlyPredictiveCells(ComputeCycle cycle, Set<Cell> prevPredictiveCells,
            Set<Column> activeColumns) {
        for (Cell cell : prevPredictiveCells) {
            Column column = cell.getParentColumn();
            if (activeColumns.contains(column)) {
                cycle.activeCells.add(cell);
                cycle.winnerCells.add(cell);
                cycle.successfullyPredictedColumns.add(column);
            }
        }
    }

    /**
     * Bursts every active column that was not successfully predicted.
     *
     * <p>For each such column:
     * <ul>
     *   <li>all of its cells become active;</li>
     *   <li>the best matching cell (see {@link #getBestMatchingCell}) becomes a winner cell;</li>
     *   <li>the best matching segment becomes a learning segment. If there is none, a new segment
     *       is created on the winner cell and the connections' segment count is incremented.</li>
     * </ul>
     *
     * <p>Neither input column set is modified; the bursting columns are computed on a copy.
     *
     * @param cycle                        receives the results
     * @param connections                  the model state
     * @param activeColumns                currently active columns
     * @param predictedColumns             columns already successfully predicted this step
     * @param prevActiveSynapsesForSegment previous step's active synapses, keyed by segment
     */
    public void burstColumns(ComputeCycle cycle, Connections connections, Set<Column> activeColumns,
            Set<Column> predictedColumns, Map<DistalDendrite, Set<Synapse>> prevActiveSynapsesForSegment) {
        // Copy (preserving iteration order) so the caller's active-column set is left untouched.
        Set<Column> burstingColumns = new LinkedHashSet<Column>(activeColumns);
        burstingColumns.removeAll(predictedColumns);

        for (Column column : burstingColumns) {
            cycle.activeCells.addAll(column.getCells());

            Object[] bestMatch = getBestMatchingCell(connections, column, prevActiveSynapsesForSegment);
            DistalDendrite segment = (DistalDendrite) bestMatch[0];
            Cell winnerCell = (Cell) bestMatch[1];
            if (winnerCell != null) {
                cycle.winnerCells.add(winnerCell);
            }

            if (segment == null) {
                int segmentCount = connections.getSegmentCount();
                segment = winnerCell.createSegment(connections, segmentCount);
                connections.setSegmentCount(segmentCount + 1);
            }
            cycle.learningSegments.add(segment);
        }
    }

    /**
     * Adapts segments and grows new synapses.
     *
     * <p>The method visits the union of {@code prevActiveSegments} and {@code learningSegments},
     * each segment at most once, in insertion order. For each segment:
     * <ul>
     *   <li>If it is a learning segment, or its parent cell is a winner cell, it is adapted with
     *       the connections' permanence increment and decrement.</li>
     *   <li>If it is a learning segment, up to
     *       {@code maxNewSynapseCount - activeSynapses.size()} new synapses are created. They go
     *       to cells chosen by {@code pickCellsToLearnOn} from {@code prevWinnerCells}, with the
     *       initial permanence. The connections' synapse count is advanced accordingly.</li>
     * </ul>
     *
     * @param connections                  the model state
     * @param prevActiveSegments           segments active in the previous step
     * @param learningSegments             segments selected for learning this step
     * @param prevActiveSynapsesForSegment previous step's active synapses, keyed by segment
     * @param winnerCells                  winner cells of this step
     * @param prevWinnerCells              winner cells of the previous step
     */
    public void learnOnSegments(Connections connections, Set<DistalDendrite> prevActiveSegments,
            Set<DistalDendrite> learningSegments, Map<DistalDendrite, Set<Synapse>> prevActiveSynapsesForSegment,
            Set<Cell> winnerCells, Set<Cell> prevWinnerCells) {
        double permanenceIncrement = connections.getPermanenceIncrement();
        double permanenceDecrement = connections.getPermanenceDecrement();

        // Set union (not list concatenation): a segment present in both inputs must not be
        // adapted twice or have synapses grown on it twice.
        Set<DistalDendrite> segments = new LinkedHashSet<DistalDendrite>(prevActiveSegments);
        segments.addAll(learningSegments);

        for (DistalDendrite segment : segments) {
            boolean isLearningSegment = learningSegments.contains(segment);
            boolean isFromWinnerCell = winnerCells.contains(segment.getParentCell());
            Set<Synapse> activeSynapses = segment.getConnectedActiveSynapses(prevActiveSynapsesForSegment, 0.0d);

            if (isLearningSegment || isFromWinnerCell) {
                segment.adaptSegment(connections, activeSynapses, permanenceIncrement, permanenceDecrement);
            }

            int numNewSynapses = connections.getMaxNewSynapseCount() - activeSynapses.size();
            if (isLearningSegment && numNewSynapses > 0) {
                int synapseCount = connections.getSynapseCount();
                Iterator<Cell> presynapticCells = segment.pickCellsToLearnOn(
                    connections, numNewSynapses, prevWinnerCells, connections.getRandom()).iterator();
                while (presynapticCells.hasNext()) {
                    segment.createSynapse(connections, presynapticCells.next(),
                        connections.getInitialPermanence(), synapseCount);
                    synapseCount++;
                }
                connections.setSynapseCount(synapseCount);
            }
        }
    }

    /**
     * Marks segments active and their parent cells predictive.
     *
     * <p>For every segment key in {@code activeSynapsesForSegment}, the synapses returned by
     * {@code getConnectedActiveSynapses} (called with the connections' connected-permanence
     * threshold) are counted. If the count reaches the activation threshold, the segment is added
     * to the cycle's active segments and its parent cell to the predictive cells.
     *
     * @param connections              the model state
     * @param cycle                    receives the results
     * @param activeSynapsesForSegment this step's active synapses, keyed by segment
     */
    public void computePredictiveCells(Connections connections, ComputeCycle cycle,
            Map<DistalDendrite, Set<Synapse>> activeSynapsesForSegment) {
        for (DistalDendrite segment : activeSynapsesForSegment.keySet()) {
            int connectedActive = segment
                .getConnectedActiveSynapses(activeSynapsesForSegment, connections.getConnectedPermanence())
                .size();
            if (connectedActive >= connections.getActivationThreshold()) {
                cycle.activeSegments.add(segment);
                cycle.predictiveCells.add(segment.getParentCell());
            }
        }
    }

    /**
     * Groups the receptor synapses of the given cells by the segment they belong to.
     *
     * @param connections the model state
     * @param cells       cells whose receptor synapses are collected (typically the active cells)
     * @return an insertion-ordered map from segment to the set of those synapses on it
     */
    public Map<DistalDendrite, Set<Synapse>> computeActiveSynapses(Connections connections, Set<Cell> cells) {
        Map<DistalDendrite, Set<Synapse>> synapsesBySegment = new LinkedHashMap<DistalDendrite, Set<Synapse>>();
        for (Cell cell : cells) {
            for (Synapse synapse : cell.getReceptorSynapses(connections)) {
                DistalDendrite segment = (DistalDendrite) synapse.getSegment();
                Set<Synapse> synapses = synapsesBySegment.get(segment);
                if (synapses == null) {
                    synapses = new LinkedHashSet<Synapse>();
                    synapsesBySegment.put(segment, synapses);
                }
                synapses.add(synapse);
            }
        }
        return synapsesBySegment;
    }

    /**
     * Clears the step state held by {@code connections}.
     *
     * <p>The cleared collections are the active cells, predictive cells, active segments,
     * active-synapse map and winner cells. They are emptied in place.
     *
     * <p>NOTE(review): {@link #compute} stores the ComputeCycle's own collections into
     * {@code connections}. If {@code Connections} keeps those references, this call also empties
     * the ComputeCycle returned by the previous {@code compute} — confirm before relying on a
     * cycle after reset.
     *
     * @param connections the model state to reset
     */
    public void reset(Connections connections) {
        connections.getActiveCells().clear();
        connections.getPredictiveCells().clear();
        connections.getActiveSegments().clear();
        connections.getActiveSynapsesForSegment().clear();
        connections.getWinnerCells().clear();
    }

    /**
     * Picks the cell in {@code column} whose best matching segment has the most synapses.
     *
     * <p>Synapses are counted with {@code getConnectedActiveSynapses} called with a threshold of
     * 0.0. On ties the earlier cell wins. If no cell has a qualifying segment, the column's
     * least-used cell is returned and the segment slot is {@code null}.
     *
     * @param connections                  the model state
     * @param column                       the column to search
     * @param prevActiveSynapsesForSegment previous step's active synapses, keyed by segment
     * @return a two-element array: {@code [0]} the best segment (may be {@code null}),
     *         {@code [1]} the chosen cell
     */
    public Object[] getBestMatchingCell(Connections connections, Column column,
            Map<DistalDendrite, Set<Synapse>> prevActiveSynapsesForSegment) {
        Cell bestCell = null;
        DistalDendrite bestSegment = null;
        int maxSynapses = 0;
        for (Cell cell : column.getCells()) {
            DistalDendrite segment = getBestMatchingSegment(connections, cell, prevActiveSynapsesForSegment);
            if (segment != null) {
                int numSynapses = segment.getConnectedActiveSynapses(prevActiveSynapsesForSegment, 0.0d).size();
                if (numSynapses > maxSynapses) {
                    maxSynapses = numSynapses;
                    bestCell = cell;
                    bestSegment = segment;
                }
            }
        }
        if (bestCell == null) {
            bestCell = column.getLeastUsedCell(connections, connections.getRandom());
        }
        return new Object[] { bestSegment, bestCell };
    }

    /**
     * Returns the segment of {@code cell} with the most synapses.
     *
     * <p>Synapses are counted with {@code getConnectedActiveSynapses} called with a threshold of
     * 0.0. Only segments whose count is at least the connections' minimum threshold qualify. On
     * ties the later segment wins.
     *
     * @param connections                  the model state
     * @param cell                         the cell whose segments are searched
     * @param prevActiveSynapsesForSegment previous step's active synapses, keyed by segment
     * @return the best segment, or {@code null} if none reaches the minimum threshold
     */
    public DistalDendrite getBestMatchingSegment(Connections connections, Cell cell,
            Map<DistalDendrite, Set<Synapse>> prevActiveSynapsesForSegment) {
        int maxSynapses = connections.getMinThreshold();
        DistalDendrite bestSegment = null;
        for (DistalDendrite segment : cell.getSegments(connections)) {
            int numSynapses = segment.getConnectedActiveSynapses(prevActiveSynapsesForSegment, 0.0d).size();
            if (numSynapses >= maxSynapses) {
                maxSynapses = numSynapses;
                bestSegment = segment;
            }
        }
        return bestSegment;
    }

    /**
     * Returns the index of the column containing the cell at {@code cellIndex}.
     *
     * <p>This is computed as integer division by cells-per-column.
     */
    protected int columnForCell(Connections connections, int cellIndex) {
        return cellIndex / connections.getCellsPerColumn();
    }

    /** Returns the cell stored at {@code cellIndex} in the connections' flat cell array. */
    public Cell getCell(Connections connections, int cellIndex) {
        return connections.getCells()[cellIndex];
    }

    /**
     * Returns the cells at the given flat indices.
     *
     * <p>The result is insertion-ordered, with duplicates collapsed.
     */
    public LinkedHashSet<Cell> getCells(Connections connections, int[] cellIndices) {
        LinkedHashSet<Cell> cells = new LinkedHashSet<Cell>();
        for (int cellIndex : cellIndices) {
            cells.add(getCell(connections, cellIndex));
        }
        return cells;
    }

    /** Returns the columns at the given indices, as provided by {@code connections.getColumnSet}. */
    public LinkedHashSet<Column> getColumns(Connections connections, int[] columnIndices) {
        return connections.getColumnSet(columnIndices);
    }
}
