package org.numenta.nupic.algorithms;

import gnu.trove.list.TIntList;
import gnu.trove.list.array.TIntArrayList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import no.uib.cipr.matrix.sparse.FlexCompRowMatrix;
import org.numenta.nupic.model.Persistable;
import org.numenta.nupic.util.ArrayUtils;
import org.numenta.nupic.util.Deque;
import org.numenta.nupic.util.Tuple;

/* loaded from: input_file:org/numenta/nupic/algorithms/SDRClassifier.class */
/**
 * Multi-step SDR classifier: a single-layer softmax network that maps the active bits of an
 * input pattern to a probability distribution over value buckets, with one weight matrix per
 * prediction horizon ("step").
 *
 * <p>Weight matrices are indexed {@code (bucketIdx, inputBitIdx)} and grow on demand as larger
 * input-bit or bucket indices are observed.
 */
public class SDRClassifier implements Persistable {
    private static final long serialVersionUID = 1;

    /** Debug verbosity; values {@code >= 1} print diagnostics to stdout. */
    int verbosity;
    /** Learning rate applied to every weight update. */
    double alpha;
    /** Smoothing factor for the running average of numeric actual values per bucket. */
    double actValueAlpha;
    /** Learning iteration of the current record ({@code recordNum - recordNumMinusLearnIteration}). */
    int learnIteration;
    /** Offset from record numbers to learning iterations; {@code -1} until the first compute. */
    int recordNumMinusLearnIteration;
    /** Largest input bit index seen so far; weight matrices have this many + 1 columns. */
    int maxInputIdx;
    /** Largest bucket index seen so far; weight matrices have this many + 1 rows. */
    int maxBucketIdx;
    /** One weight matrix per prediction step, keyed by the step count. */
    Map<Integer, FlexCompRowMatrix> weightMatrix;
    /** Prediction horizons (in records) this classifier learns and infers. */
    TIntList steps;
    /** Recent {@code (learnIteration, patternNZ)} tuples; constructed with {@code max(steps) + 1}. */
    Deque<Tuple> patternNZHistory;
    /** Representative actual value per bucket; {@code null} means no value seen yet. */
    List<?> actualValues;
    /** Prefix used in debug output. */
    String g_debugPrefix;

    /** Creates a classifier for 1-step prediction, alpha 0.001, actValueAlpha 0.3, verbosity 0. */
    public SDRClassifier() {
        this(new TIntArrayList(new int[] {1}), 0.001d, 0.3d, 0);
    }

    /**
     * Creates a classifier.
     *
     * @param steps prediction horizons to learn; one weight matrix is allocated per entry
     * @param alpha learning rate for weight updates
     * @param actValueAlpha smoothing factor for the per-bucket running actual value
     * @param verbosity debug verbosity ({@code >= 1} prints diagnostics)
     */
    public SDRClassifier(TIntList steps, double alpha, double actValueAlpha, int verbosity) {
        this.steps = steps;
        this.alpha = alpha;
        this.actValueAlpha = actValueAlpha;
        this.verbosity = verbosity;
        this.recordNumMinusLearnIteration = -1;
        this.maxInputIdx = 0;
        this.maxBucketIdx = 0;
        this.g_debugPrefix = "SDRClassifier";

        // Bucket 0 starts with no known actual value.
        List<Object> values = new ArrayList<>();
        values.add(null);
        this.actualValues = values;

        this.patternNZHistory = new Deque<>(ArrayUtils.max(steps.toArray()) + 1);
        this.weightMatrix = new HashMap<>();
        for (int nSteps : steps.toArray()) {
            weightMatrix.put(nSteps, new FlexCompRowMatrix(maxBucketIdx + 1, maxInputIdx + 1));
        }
    }

    /**
     * Processes one record: optionally infers bucket likelihoods for every configured step, then
     * optionally learns from the record's classification.
     *
     * @param recordNum caller-supplied record number, used to derive {@link #learnIteration}
     * @param classification record classification; learning reads {@code "bucketIdx"} (an integral
     *     {@link Number}) and {@code "actValue"}; inference reads {@code "actValue"} for defaults
     * @param patternNZ indices of the active input bits
     * @param learn whether to update actual values and weights
     * @param doInfer whether to compute and return a {@link Classification}
     * @return the inference result, or {@code null} when {@code doInfer} is false
     */
    public <T> Classification<T> compute(int recordNum, Map<String, Object> classification,
            int[] patternNZ, boolean learn, boolean doInfer) {
        if (recordNumMinusLearnIteration == -1) {
            recordNumMinusLearnIteration = recordNum - learnIteration;
        }
        learnIteration = recordNum - recordNumMinusLearnIteration;

        if (verbosity >= 1) {
            System.out.println(String.format("\n%s: compute ", g_debugPrefix));
            System.out.printf("recordNum: %d\n", recordNum);
            System.out.printf("learnIteration: %d\n", learnIteration);
            System.out.printf("patternNZ (%d): %s\n", patternNZ.length, ArrayUtils.intArrayToString(patternNZ));
            System.out.println("classificationIn: " + classification);
        }

        patternNZHistory.append(new Tuple(Integer.valueOf(learnIteration), patternNZ));

        // An empty pattern carries no new input indices; skip max() on an empty array.
        if (patternNZ.length > 0) {
            growInputs(ArrayUtils.max(patternNZ));
        }

        Classification<T> retVal = doInfer ? this.<T>infer(patternNZ, classification) : null;

        if (learn && classification != null && classification.get("bucketIdx") != null) {
            int bucketIdx = ((Number) classification.get("bucketIdx")).intValue();
            growBuckets(bucketIdx);
            updateActualValue(bucketIdx, classification.get("actValue"));
            updateWeights(classification);
        }

        if (doInfer && verbosity >= 1) {
            System.out.println(" inference: combined bucket likelihoods:");
            System.out.println("   actual bucket values: " + Arrays.toString(retVal.getActualValues()));
            // NOTE(review): getActualValue is indexed by bucket below but by step count here;
            // this mirrors the prior behavior — confirm the intended lookup.
            for (int nSteps : retVal.stepSet()) {
                if (retVal.getActualValue(nSteps) == null) {
                    continue;
                }
                System.out.println(String.format("  %d steps: %s", nSteps,
                        pFormatArray(new Object[] {retVal.getActualValue(nSteps)})));
                int bestBucketIdx = retVal.getMostProbableBucketIndex(nSteps);
                System.out.println(String.format("   most likely bucket idx: %d, value: %s ",
                        bestBucketIdx, retVal.getActualValue(bestBucketIdx)));
            }
        }
        return retVal;
    }

    /** Adds zero columns to every weight matrix so input index {@code newMaxInputIdx} fits. */
    private void growInputs(int newMaxInputIdx) {
        if (newMaxInputIdx <= maxInputIdx) {
            return;
        }
        for (int nSteps : steps.toArray()) {
            FlexCompRowMatrix matrix = weightMatrix.get(nSteps);
            for (int i = maxInputIdx; i < newMaxInputIdx; i++) {
                matrix.addCol(new double[maxBucketIdx + 1]);
            }
        }
        maxInputIdx = newMaxInputIdx;
    }

    /** Adds zero rows to every weight matrix so bucket index {@code bucketIdx} fits. */
    private void growBuckets(int bucketIdx) {
        if (bucketIdx <= maxBucketIdx) {
            return;
        }
        for (int nSteps : steps.toArray()) {
            FlexCompRowMatrix matrix = weightMatrix.get(nSteps);
            for (int i = maxBucketIdx; i < bucketIdx; i++) {
                matrix.addRow(new double[maxInputIdx + 1]);
            }
        }
        maxBucketIdx = bucketIdx;
    }

    /**
     * Records {@code actValue} for {@code bucketIdx}: numeric values are blended into the existing
     * numeric value with {@link #actValueAlpha}; anything else (including null) replaces it.
     */
    private void updateActualValue(int bucketIdx, Object actValue) {
        @SuppressWarnings("unchecked")
        List<Object> values = (List<Object>) actualValues;
        while (values.size() <= maxBucketIdx) {
            values.add(null);
        }
        Object previous = values.get(bucketIdx);
        // Both sides must be numeric to average; otherwise replace (avoids NPE/ClassCastException).
        if (previous instanceof Number && actValue instanceof Number) {
            double blended = (1.0d - actValueAlpha) * ((Number) previous).doubleValue()
                    + actValueAlpha * ((Number) actValue).doubleValue();
            values.set(bucketIdx, Double.valueOf(blended));
        } else {
            values.set(bucketIdx, actValue);
        }
    }

    /** Applies one gradient step for every history entry whose age matches a configured step. */
    private void updateWeights(Map<String, Object> classification) {
        Iterator<Tuple> it = patternNZHistory.iterator();
        while (it.hasNext()) {
            Tuple entry = it.next();
            int entryIteration = ((Integer) entry.get(0)).intValue();
            int nSteps = learnIteration - entryIteration;
            if (!steps.contains(nSteps)) {
                continue;
            }
            int[] learnPatternNZ = (int[]) entry.get(1);
            // Recomputed per entry because earlier iterations of this loop mutate the weights.
            double[] stepError = calculateError(classification).get(nSteps);
            FlexCompRowMatrix matrix = weightMatrix.get(nSteps);
            for (int bucket = 0; bucket <= maxBucketIdx; bucket++) {
                double delta = alpha * stepError[bucket];
                for (int bit : learnPatternNZ) {
                    matrix.add(bucket, bit, delta);
                }
            }
        }
    }

    /**
     * Builds the inference result: per-bucket actual values plus a likelihood distribution for
     * each configured step.
     */
    @SuppressWarnings("unchecked")
    private <T> Classification<T> infer(int[] patternNZ, Map<String, Object> classification) {
        Classification<T> retVal = new Classification<>();

        // Buckets with no known value get a placeholder. For 0-step prediction the current
        // record's actual value must not leak into inference, so 0 is used instead.
        Object defaultValue = (steps.get(0) == 0 || classification == null)
                ? Integer.valueOf(0)
                : classification.get("actValue");
        Object[] values = new Object[actualValues.size()];
        for (int i = 0; i < values.length; i++) {
            Object value = actualValues.get(i);
            values[i] = value == null ? defaultValue : value;
        }
        retVal.setActualValues((T[]) values);

        for (int nSteps : steps.toArray()) {
            retVal.setStats(nSteps, inferSingleStep(patternNZ, weightMatrix.get(nSteps)));
        }
        return retVal;
    }

    /**
     * Returns the softmax distribution over buckets {@code 0..maxBucketIdx} given the summed
     * weights of the active bits.
     */
    private double[] inferSingleStep(int[] patternNZ, FlexCompRowMatrix weights) {
        int numBuckets = maxBucketIdx + 1;
        double[] dist = new double[numBuckets];
        double maxActivation = Double.NEGATIVE_INFINITY;
        for (int bucket = 0; bucket < numBuckets; bucket++) {
            double activation = 0.0d;
            for (int bit : patternNZ) {
                activation += weights.get(bucket, bit);
            }
            dist[bucket] = activation;
            maxActivation = Math.max(maxActivation, activation);
        }
        // Shift by the max before exp(): mathematically identical softmax, but avoids overflow
        // to Infinity (which would make every probability NaN).
        double total = 0.0d;
        for (int bucket = 0; bucket < numBuckets; bucket++) {
            dist[bucket] = Math.exp(dist[bucket] - maxActivation);
            total += dist[bucket];
        }
        for (int bucket = 0; bucket < numBuckets; bucket++) {
            dist[bucket] /= total;
        }
        return dist;
    }

    /**
     * Computes {@code target - prediction} per configured step, where the target is one-hot on
     * the classification's {@code "bucketIdx"} and the prediction uses the history entry whose
     * age equals that step.
     */
    private Map<Integer, double[]> calculateError(Map<String, Object> classification) {
        int numBuckets = maxBucketIdx + 1;
        double[] targetDist = new double[numBuckets];
        targetDist[((Number) classification.get("bucketIdx")).intValue()] = 1.0d;

        Map<Integer, double[]> error = new HashMap<>();
        Iterator<Tuple> it = patternNZHistory.iterator();
        while (it.hasNext()) {
            Tuple entry = it.next();
            int entryIteration = ((Integer) entry.get(0)).intValue();
            int nSteps = learnIteration - entryIteration;
            if (!steps.contains(nSteps)) {
                continue;
            }
            double[] predictDist = inferSingleStep((int[]) entry.get(1), weightMatrix.get(nSteps));
            double[] stepError = new double[numBuckets];
            for (int bucket = 0; bucket < numBuckets; bucket++) {
                stepError[bucket] = targetDist[bucket] - predictDist[bucket];
            }
            error.put(nSteps, stepError);
        }
        return error;
    }

    /** Formats values as {@code "[ a b c ]"}, numbers with two decimals; empty for null. */
    private <T> String pFormatArray(T[] values) {
        if (values == null) {
            return "";
        }
        StringBuilder sb = new StringBuilder("[ ");
        for (int i = 0; i < values.length; i++) {
            if (i > 0) {
                sb.append(' ');
            }
            T value = values[i];
            sb.append(value instanceof Number
                    ? String.format("%.2f", ((Number) value).doubleValue())
                    : String.valueOf(value));
        }
        sb.append(" ]");
        return sb.toString();
    }
}
