package org.numenta.nupic.algorithms;

import chaschev.lang.Pair;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.numenta.nupic.model.Cell;
import org.numenta.nupic.model.Column;
import org.numenta.nupic.model.ComputeCycle;
import org.numenta.nupic.model.Connections;
import org.numenta.nupic.model.DistalDendrite;
import org.numenta.nupic.model.Synapse;
import org.numenta.nupic.monitor.ComputeDecorator;
import org.numenta.nupic.util.GroupBy2;
import org.numenta.nupic.util.SparseObjectMatrix;
import org.numenta.nupic.util.Tuple;

/* loaded from: input_file:org/numenta/nupic/algorithms/TemporalMemory.class */
/**
 * Temporal Memory compute step.
 *
 * <p>{@link #compute} runs {@link #activateCells} followed by {@link #activateDendrites}.
 * All state is read from and written to the supplied {@link Connections}; this class
 * itself declares only static constants.
 */
public class TemporalMemory implements ComputeDecorator, Serializable {
    private static final long serialVersionUID = 1;
    /** Synapses whose adapted permanence falls below this value are destroyed. */
    private static final double EPSILON = 1.0E-5d;
    /** Tuple slot (see {@link ColumnData}) holding the active-column entry of a group. */
    private static final int ACTIVE_COLUMNS = 1;

    /**
     * Typed view over one {@link Tuple} produced by {@link GroupBy2} in
     * {@link TemporalMemory#activateCells}. Slot 0 holds the column, slot 1 the
     * active-column entries, slot 2 the active segments and slot 3 the matching segments.
     */
    public static class ColumnData implements Serializable {
        private static final long serialVersionUID = 1;
        private static final int COLUMN = 0;
        private static final int ACTIVE_SEGMENTS = 2;
        private static final int MATCHING_SEGMENTS = 3;

        Tuple t;

        public ColumnData() {
        }

        public ColumnData(Tuple tuple) {
            this.t = tuple;
        }

        /** Returns the column this group belongs to (slot 0). */
        public Column column() {
            return (Column) t.get(COLUMN);
        }

        /** Returns the raw active-column entries of this group (slot 1). */
        @SuppressWarnings("unchecked")
        public List<Column> activeColumns() {
            return (List<Column>) t.get(ACTIVE_COLUMNS);
        }

        /** Returns the active segments of this group, or an empty list if there are none. */
        public List<DistalDendrite> activeSegments() {
            return segmentsAt(ACTIVE_SEGMENTS);
        }

        /** Returns the matching segments of this group, or an empty list if there are none. */
        public List<DistalDendrite> matchingSegments() {
            return segmentsAt(MATCHING_SEGMENTS);
        }

        /**
         * Points this view at another tuple.
         *
         * @param tuple the new backing tuple
         * @return this instance, for chaining
         */
        public ColumnData set(Tuple tuple) {
            this.t = tuple;
            return this;
        }

        /**
         * Returns whether slot {@code i} holds something other than the
         * {@code GroupBy2.Slot.NONE} marker.
         */
        public boolean isNotNone(int i) {
            return !((List<?>) t.get(i)).get(0).equals(GroupBy2.Slot.NONE);
        }

        /** Decodes a segment slot: its first entry being {@code Slot.empty()} means "no segments". */
        @SuppressWarnings("unchecked")
        private List<DistalDendrite> segmentsAt(int slot) {
            List<?> entries = (List<?>) t.get(slot);
            if (entries.get(0).equals(GroupBy2.Slot.empty())) {
                return Collections.emptyList();
            }
            return (List<DistalDendrite>) entries;
        }
    }

    /**
     * Initializes the column matrix and the flat cell array of {@code connections}.
     *
     * <p>If {@code connections} has no memory matrix yet, one is created from the column
     * dimensions. Columns are created only when column 0 is absent from the matrix;
     * otherwise the existing columns are reused. The cell array is laid out column-major:
     * cell {@code j} of column {@code i} is stored at {@code i * cellsPerColumn + j}.
     *
     * @param connections the connections object to initialize
     */
    public static void init(Connections connections) {
        SparseObjectMatrix<Column> matrix = connections.getMemory() == null
            ? new SparseObjectMatrix<>(connections.getColumnDimensions())
            : connections.getMemory();
        connections.setMemory(matrix);

        int numColumns = matrix.getMaxIndex() + 1;
        connections.setNumColumns(numColumns);
        int cellsPerColumn = connections.getCellsPerColumn();
        Cell[] cells = new Cell[numColumns * cellsPerColumn];

        // Column 0 acts as a flag: when absent, the matrix has not been populated yet.
        boolean createColumns = matrix.getObject(0) == null;
        for (int col = 0; col < numColumns; col++) {
            Column column = createColumns ? new Column(cellsPerColumn, col) : matrix.getObject(col);
            for (int j = 0; j < cellsPerColumn; j++) {
                cells[col * cellsPerColumn + j] = column.getCell(j);
            }
            if (createColumns) {
                matrix.set(col, column);
            }
        }
        connections.setCells(cells);
    }

    /**
     * Runs one Temporal Memory step.
     *
     * @param connections the model state
     * @param iArr        indices of the currently active columns
     * @param z           whether learning is enabled
     * @return the cycle holding the newly computed active/winner cells and segments
     */
    @Override // org.numenta.nupic.monitor.ComputeDecorator
    public ComputeCycle compute(Connections connections, int[] iArr, boolean z) {
        ComputeCycle computeCycle = new ComputeCycle();
        activateCells(connections, computeCycle, iArr, z);
        activateDendrites(connections, computeCycle, z);
        return computeCycle;
    }

    /**
     * Activates cells in the active columns and, when learning, adapts segments.
     *
     * <p>Columns are grouped with the segments (from {@code connections}' current active and
     * matching segments) whose parent cell lies in them. Then, for each group:
     * <ul>
     *   <li>an active column with active segments goes to {@link #activatePredictedColumn};</li>
     *   <li>an active column without active segments goes to {@link #burstColumn};</li>
     *   <li>a column that is not active is passed to {@link #punishPredictedColumn}, but only when
     *       learning.</li>
     * </ul>
     * The resulting cells are added to {@code computeCycle.activeCells} and
     * {@code computeCycle.winnerCells}.
     *
     * @param connections  the model state; its active/winner cells are used as the previous step's
     * @param computeCycle receives the newly active and winner cells
     * @param iArr         indices of the currently active columns (any order; sorted here)
     * @param z            whether learning is enabled
     */
    public void activateCells(Connections connections, ComputeCycle computeCycle, int[] iArr, boolean z) {
        ColumnData columnData = new ColumnData();

        // Active/winner cells as left by the previous activateDendrites call.
        Set<Cell> prevActiveCells = connections.getActiveCells();
        Set<Cell> prevWinnerCells = connections.getWinnerCells();

        List<Column> activeColumns = Arrays.stream(iArr)
            .sorted()
            .mapToObj(i -> connections.getColumn(i))
            .collect(Collectors.toList());

        Function<Column, Column> identity = Function.identity();
        Function<DistalDendrite, Column> segmentToColumn = segment -> segment.getParentCell().getColumn();

        // chaschev Pair is used raw here, as GroupBy2.of accepts heterogeneous pairs.
        @SuppressWarnings({"rawtypes", "unchecked"})
        GroupBy2<Column> grouper = GroupBy2.of(
            new Pair(activeColumns, identity),
            new Pair(new ArrayList<>(connections.getActiveSegments()), segmentToColumn),
            new Pair(new ArrayList<>(connections.getMatchingSegments()), segmentToColumn));

        double permanenceIncrement = connections.getPermanenceIncrement();
        double permanenceDecrement = connections.getPermanenceDecrement();

        Iterator<Tuple> groups = grouper.iterator();
        while (groups.hasNext()) {
            columnData = columnData.set(groups.next());

            if (columnData.isNotNone(ACTIVE_COLUMNS)) {
                if (columnData.activeSegments().isEmpty()) {
                    Tuple burst = burstColumn(connections, columnData.column(), columnData.matchingSegments(),
                        prevActiveCells, prevWinnerCells, permanenceIncrement, permanenceDecrement,
                        connections.getRandom(), z);
                    @SuppressWarnings("unchecked")
                    List<Cell> burstCells = (List<Cell>) burst.get(0);
                    computeCycle.activeCells.addAll(burstCells);
                    computeCycle.winnerCells.add((Cell) burst.get(1));
                } else {
                    List<Cell> predictedCells = activatePredictedColumn(connections, columnData.activeSegments(),
                        columnData.matchingSegments(), prevActiveCells, prevWinnerCells,
                        permanenceIncrement, permanenceDecrement, z);
                    computeCycle.activeCells.addAll(predictedCells);
                    computeCycle.winnerCells.addAll(predictedCells);
                }
            } else if (z) {
                punishPredictedColumn(connections, columnData.activeSegments(), columnData.matchingSegments(),
                    prevActiveCells, prevWinnerCells, connections.getPredictedSegmentDecrement());
            }
        }
    }

    /**
     * Computes the active and matching segments for the cells activated in this step.
     *
     * <p>A segment is <em>active</em> when its count of active connected synapses reaches the
     * activation threshold. It is <em>matching</em> when its count of active potential synapses
     * reaches the minimum threshold. Both lists are sorted with
     * {@code connections.segmentPositionSortKey} and stored in {@code computeCycle} and
     * {@code connections}, together with the cycle's active and winner cells.
     *
     * @param connections  the model state to update
     * @param computeCycle the cycle whose {@code activeCells} drive the activity computation
     * @param z            whether learning is enabled; if so, segment activity is recorded and a
     *                     new iteration is started
     */
    public void activateDendrites(Connections connections, ComputeCycle computeCycle, boolean z) {
        Connections.Activity activity =
            connections.computeActivity(computeCycle.activeCells, connections.getConnectedPermanence());

        List<DistalDendrite> activeSegments = IntStream.range(0, activity.numActiveConnected.length)
            .filter(i -> activity.numActiveConnected[i] >= connections.getActivationThreshold())
            .mapToObj(i -> connections.segmentForFlatIdx(i))
            .collect(Collectors.toList());

        // Iterate over the array that is actually indexed.
        List<DistalDendrite> matchingSegments = IntStream.range(0, activity.numActivePotential.length)
            .filter(i -> activity.numActivePotential[i] >= connections.getMinThreshold())
            .mapToObj(i -> connections.segmentForFlatIdx(i))
            .collect(Collectors.toList());

        Collections.sort(activeSegments, connections.segmentPositionSortKey);
        Collections.sort(matchingSegments, connections.segmentPositionSortKey);

        computeCycle.activeSegments = activeSegments;
        computeCycle.matchingSegments = matchingSegments;

        connections.lastActivity = activity;
        connections.setActiveCells(new LinkedHashSet<>(computeCycle.activeCells));
        connections.setWinnerCells(new LinkedHashSet<>(computeCycle.winnerCells));
        connections.setActiveSegments(activeSegments);
        connections.setMatchingSegments(matchingSegments);

        // The result of getPredictiveCells() is discarded; the call is presumably made for its
        // side effect of rebuilding the predictive cells from the new active segments (verify in
        // Connections).
        connections.clearPredictiveCells();
        connections.getPredictiveCells();

        if (z) {
            for (DistalDendrite segment : activeSegments) {
                connections.recordSegmentActivity(segment);
            }
            connections.startNewIteration();
        }
    }

    /**
     * Clears the active cells, winner cells, active segments and matching segments held by
     * {@code connections}.
     *
     * @param connections the model state to reset
     */
    @Override // org.numenta.nupic.monitor.ComputeDecorator
    public void reset(Connections connections) {
        connections.getActiveCells().clear();
        connections.getWinnerCells().clear();
        connections.getActiveSegments().clear();
        connections.getMatchingSegments().clear();
    }

    /**
     * Activates the parent cells of a column's active segments and, when learning, reinforces
     * those segments and grows new synapses on them.
     *
     * <p>Each parent cell is added once per run of consecutive segments sharing it. This assumes
     * segments of the same cell are adjacent in {@code list}. {@link #activateDendrites} sorts them
     * with {@code segmentPositionSortKey}, presumably ordering by cell; confirm that ordering.
     *
     * @param connections the model state
     * @param list        the column's active segments
     * @param list2       the column's matching segments (not used by this implementation)
     * @param set         previously active cells
     * @param set2        previous winner cells, candidates for new synapses
     * @param d           permanence increment
     * @param d2          permanence decrement
     * @param z           whether learning is enabled
     * @return the activated cells, in segment order
     */
    public List<Cell> activatePredictedColumn(Connections connections, List<DistalDendrite> list,
            List<DistalDendrite> list2, Set<Cell> set, Set<Cell> set2, double d, double d2, boolean z) {
        List<Cell> cellsToAdd = new ArrayList<>();
        Cell previousCell = null;
        for (DistalDendrite segment : list) {
            Cell parentCell = segment.getParentCell();
            if (parentCell != previousCell) {
                cellsToAdd.add(parentCell);
                previousCell = parentCell;
            }
            if (z) {
                adaptSegment(connections, segment, set, d, d2);
                int nGrowDesired = connections.getMaxNewSynapseCount()
                    - connections.getLastActivity().numActivePotential[segment.getIndex()];
                if (nGrowDesired > 0) {
                    growSynapses(connections, set2, segment, connections.getInitialPermanence(),
                        nGrowDesired, connections.getRandom());
                }
            }
        }
        return cellsToAdd;
    }

    /**
     * Activates every cell of an unpredicted column and selects a winner cell.
     *
     * <p>With matching segments, the winner is the parent of the segment with the most active
     * potential synapses (the first such segment on ties). When learning, that segment is adapted
     * and grown. Without matching segments, the winner is a least-used cell. When learning and
     * there are previous winner cells, the winner gets a new segment with synapses to them.
     *
     * @param connections the model state
     * @param column      the bursting column
     * @param list        the column's matching segments
     * @param set         previously active cells
     * @param set2        previous winner cells, candidates for new synapses
     * @param d           permanence increment
     * @param d2          permanence decrement
     * @param random      source of randomness for cell and synapse selection
     * @param z           whether learning is enabled
     * @return a {@link Tuple} of (all cells of the column, winner cell)
     */
    public Tuple burstColumn(Connections connections, Column column, List<DistalDendrite> list,
            Set<Cell> set, Set<Cell> set2, double d, double d2, Random random, boolean z) {
        List<Cell> cells = column.getCells();
        Cell winnerCell;

        if (list.isEmpty()) {
            winnerCell = leastUsedCell(connections, cells, random);
            if (z) {
                int nGrowExact = Math.min(connections.getMaxNewSynapseCount(), set2.size());
                if (nGrowExact > 0) {
                    DistalDendrite newSegment = connections.createSegment(winnerCell);
                    growSynapses(connections, set2, newSegment, connections.getInitialPermanence(),
                        nGrowExact, random);
                }
            }
        } else {
            int[] numActivePotential = connections.getLastActivity().numActivePotential;
            DistalDendrite bestSegment = list.stream()
                .max((a, b) -> Integer.compare(numActivePotential[a.getIndex()], numActivePotential[b.getIndex()]))
                .get();
            winnerCell = bestSegment.getParentCell();
            if (z) {
                adaptSegment(connections, bestSegment, set, d, d2);
                int nGrowDesired = connections.getMaxNewSynapseCount() - numActivePotential[bestSegment.getIndex()];
                if (nGrowDesired > 0) {
                    growSynapses(connections, set2, bestSegment, connections.getInitialPermanence(),
                        nGrowDesired, random);
                }
            }
        }
        return new Tuple(cells, winnerCell);
    }

    /**
     * Weakens the matching segments of a column that was predicted but did not become active.
     *
     * <p>Synapses to previously active cells lose {@code d} permanence; other synapses are left
     * unchanged. Nothing happens when {@code d} is not positive.
     *
     * @param connections the model state
     * @param list        the column's active segments (not used by this implementation)
     * @param list2       the column's matching segments, which are punished
     * @param set         previously active cells
     * @param set2        previous winner cells (not used by this implementation)
     * @param d           the predicted-segment decrement to apply
     */
    public void punishPredictedColumn(Connections connections, List<DistalDendrite> list,
            List<DistalDendrite> list2, Set<Cell> set, Set<Cell> set2, double d) {
        if (d > 0.0d) {
            for (DistalDendrite segment : list2) {
                // Use the decrement the caller supplied, not a re-read of the Connections setting.
                adaptSegment(connections, segment, set, -d, 0.0d);
            }
        }
    }

    /**
     * Picks, uniformly at random, one of the cells with the fewest segments.
     *
     * @param connections the model state used to count segments
     * @param list        the candidate cells; must be non-empty ({@link Random#nextInt} throws otherwise)
     * @param random      source of randomness for tie breaking
     * @return a least-used cell
     */
    public Cell leastUsedCell(Connections connections, List<Cell> list, Random random) {
        List<Cell> leastUsedCells = new ArrayList<>();
        int minNumSegments = Integer.MAX_VALUE;
        for (Cell cell : list) {
            int numSegments = connections.numSegments(cell);
            if (numSegments < minNumSegments) {
                minNumSegments = numSegments;
                leastUsedCells.clear();
            }
            if (numSegments == minNumSegments) {
                leastUsedCells.add(cell);
            }
        }
        return leastUsedCells.get(random.nextInt(leastUsedCells.size()));
    }

    /**
     * Creates up to {@code i} new synapses on a segment, to randomly chosen cells from {@code set}
     * that the segment is not already connected to.
     *
     * <p>Candidates are sorted before sampling so that results depend only on {@code random}, not
     * on the iteration order of {@code set}.
     *
     * @param connections  the model state
     * @param set          candidate presynaptic cells
     * @param distalDendrite the segment to grow synapses on
     * @param d            initial permanence of new synapses
     * @param i            desired number of new synapses
     * @param random       source of randomness for candidate selection
     */
    public void growSynapses(Connections connections, Set<Cell> set, DistalDendrite distalDendrite,
            double d, int i, Random random) {
        List<Cell> candidates = new ArrayList<>(set);
        Collections.sort(candidates);

        // Skip cells the segment already has a synapse to (candidates are unique: they came from a Set).
        for (Synapse synapse : connections.getSynapses(distalDendrite)) {
            candidates.remove(synapse.getPresynapticCell());
        }

        int nActual = Math.min(i, candidates.size());
        for (int n = 0; n < nActual; n++) {
            int pick = random.nextInt(candidates.size());
            connections.createSynapse(distalDendrite, candidates.get(pick), d);
            candidates.remove(pick);
        }
    }

    /**
     * Adjusts the permanences of a segment's synapses and prunes dead synapses and segments.
     *
     * <p>Synapses to cells in {@code set} gain {@code d}; all others lose {@code d2}. Results are
     * clamped to [0, 1]. Synapses ending below {@code EPSILON} are destroyed, and the segment
     * itself is destroyed if it is left with no synapses.
     *
     * @param connections    the model state
     * @param distalDendrite the segment to adapt
     * @param set            previously active cells
     * @param d              permanence increment (may be negative, as in punishment)
     * @param d2             permanence decrement
     */
    public void adaptSegment(Connections connections, DistalDendrite distalDendrite, Set<Cell> set,
            double d, double d2) {
        // Destruction is deferred because it would modify the collection being iterated.
        List<Synapse> synapsesToDestroy = new ArrayList<>();
        for (Synapse synapse : connections.getSynapses(distalDendrite)) {
            double permanence = synapse.getPermanence();
            permanence = set.contains(synapse.getPresynapticCell()) ? permanence + d : permanence - d2;
            permanence = permanence < 0.0d ? 0.0d : permanence > 1.0d ? 1.0d : permanence;
            if (permanence < EPSILON) {
                synapsesToDestroy.add(synapse);
            } else {
                synapse.setPermanence(connections, permanence);
            }
        }
        for (Synapse synapse : synapsesToDestroy) {
            connections.destroySynapse(synapse);
        }
        if (connections.numSynapses(distalDendrite) == 0) {
            connections.destroySegment(distalDendrite);
        }
    }
}
