package org.numenta.nupic.encoders;

import gnu.trove.list.TDoubleList;
import gnu.trove.list.array.TDoubleArrayList;
import gnu.trove.set.hash.TIntHashSet;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import org.numenta.nupic.FieldMetaType;
import org.numenta.nupic.encoders.Encoder;
import org.numenta.nupic.util.ArrayUtils;
import org.numenta.nupic.util.Condition;
import org.numenta.nupic.util.MinMax;
import org.numenta.nupic.util.SparseObjectMatrix;
import org.numenta.nupic.util.Tuple;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Encodes string categories as random sparse distributed representations (SDRs).
 *
 * <p>Each category is assigned a randomly drawn {@code n}-bit pattern with exactly {@code w}
 * ON bits, and no two categories share the same pattern. The reserved category
 * {@code "<UNKNOWN>"} always occupies index 0. A {@code null} or empty input encodes to all
 * zeros. A non-empty input that is not a known category is either added as a new category
 * (when encoder learning is enabled) or mapped to index 0. {@link #init} enables learning only
 * when no initial category list is supplied.
 *
 * <p>Instances are created through {@link #builder()}.
 */
public class SDRCategoryEncoder extends Encoder<String> {
    private static final long serialVersionUID = 1;
    private static final Logger LOG = LoggerFactory.getLogger(SDRCategoryEncoder.class);

    /** Number of random draws {@link #newRep()} makes before giving up on finding a unique SDR. */
    private static final int MAX_REP_ATTEMPTS = 1000;

    /** Reserved category that is always registered first, and therefore sits at index 0. */
    private static final String UNKNOWN_CATEGORY = "<UNKNOWN>";

    private Random random;
    /** An input overlapping a category's SDR in more than this many bits decodes to it. */
    private int thresholdOverlap;
    /** Category -> SDR, in insertion order; insertion position is the category's bucket index. */
    private final SDRByCategoryMap sdrByCategory;

    /**
     * Builder for {@link SDRCategoryEncoder}. {@code n} and {@code w} are mandatory; scalar-only
     * options (radius, resolution, periodic, clipInput, min/max value) are rejected.
     */
    public static final class Builder extends Encoder.Builder<Builder, SDRCategoryEncoder> {
        private List<String> categoryList = new ArrayList<>();
        private int encoderSeed = 1;

        /**
         * Creates and initializes the encoder.
         *
         * @throws IllegalStateException if {@code n} or {@code w} is unset (0) or the category
         *     list is {@code null}
         */
        @Override
        public SDRCategoryEncoder build() {
            if (this.n == 0) {
                throw new IllegalStateException("\"N\" should be set");
            }
            if (this.w == 0) {
                throw new IllegalStateException("\"W\" should be set");
            }
            if (this.categoryList == null) {
                throw new IllegalStateException("Category List cannot be null");
            }
            SDRCategoryEncoder encoder = new SDRCategoryEncoder();
            encoder.init(this.n, this.w, this.categoryList, this.name, this.encoderSeed, this.forced);
            return encoder;
        }

        /** Sets the initial categories; an empty list makes the encoder learn categories on the fly. */
        public Builder categoryList(List<String> list) {
            this.categoryList = list;
            return this;
        }

        /** Sets the random seed used to generate SDRs; {@code -1} leaves the generator unseeded. */
        public Builder encoderSeed(int seed) {
            this.encoderSeed = seed;
            return this;
        }

        /** Not applicable to category encoding. @throws IllegalArgumentException always */
        @Override
        public Builder radius(double radius) {
            throw new IllegalArgumentException("Not supported for this SDRCategoryEncoder");
        }

        /** Not applicable to category encoding. @throws IllegalArgumentException always */
        @Override
        public Builder resolution(double resolution) {
            throw new IllegalArgumentException("Not supported for this SDRCategoryEncoder");
        }

        /** Not applicable to category encoding. @throws IllegalArgumentException always */
        @Override
        public Builder periodic(boolean periodic) {
            throw new IllegalArgumentException("Not supported for this SDRCategoryEncoder");
        }

        /** Not applicable to category encoding. @throws IllegalArgumentException always */
        @Override
        public Builder clipInput(boolean clipInput) {
            throw new IllegalArgumentException("Not supported for this SDRCategoryEncoder");
        }

        /** Not applicable to category encoding. @throws IllegalArgumentException always */
        @Override
        public Builder maxVal(double maxVal) {
            throw new IllegalArgumentException("Not supported for this SDRCategoryEncoder");
        }

        /** Not applicable to category encoding. @throws IllegalArgumentException always */
        @Override
        public Builder minVal(double minVal) {
            throw new IllegalArgumentException("Not supported for this SDRCategoryEncoder");
        }
    }

    /**
     * Insertion-ordered map from category to SDR, with lookups by insertion position.
     * Positional lookups walk the entry set and are therefore O(index).
     */
    public static final class SDRByCategoryMap extends LinkedHashMap<String, int[]> {
        private SDRByCategoryMap() {
        }

        /**
         * Returns the SDR stored at insertion position {@code i}.
         *
         * @throws IllegalArgumentException if {@code i} is outside {@code [0, size())}
         */
        public int[] getSdr(int i) {
            Map.Entry<String, int[]> entry = getEntry(i);
            return entry == null ? null : entry.getValue();
        }

        /**
         * Returns the category stored at insertion position {@code i}.
         *
         * @throws IllegalArgumentException if {@code i} is outside {@code [0, size())}
         */
        public String getCategory(int i) {
            Map.Entry<String, int[]> entry = getEntry(i);
            return entry == null ? null : entry.getKey();
        }

        /**
         * Returns the insertion position of {@code category}, or 0 if it is absent (for
         * {@link SDRCategoryEncoder} position 0 is the {@code "<UNKNOWN>"} category).
         */
        public int getIndexByCategory(String category) {
            int index = 0;
            for (String key : keySet()) {
                // Null-safe: LinkedHashMap permits a null key.
                if (key == null ? category == null : key.equals(category)) {
                    return index;
                }
                index++;
            }
            return 0;
        }

        private Map.Entry<String, int[]> getEntry(int i) {
            int size = size();
            // Valid positions are 0..size-1; the former "i > size" check let i == size through.
            if (i < 0 || i >= size) {
                throw new IllegalArgumentException("Index should be in following range:[0," + size + ")");
            }
            int position = 0;
            for (Map.Entry<String, int[]> entry : entrySet()) {
                if (position++ == i) {
                    return entry;
                }
            }
            return null; // unreachable after the bounds check
        }
    }

    /** Returns a new {@link Builder}. */
    public static Builder builder() {
        return new Builder();
    }

    private SDRCategoryEncoder() {
        this.sdrByCategory = new SDRByCategoryMap();
    }

    /**
     * Initializes the encoder and registers the initial categories.
     *
     * @param n total number of bits in each SDR
     * @param w number of ON bits in each SDR
     * @param categoryList initial categories; {@code null} or empty enables learning instead
     * @param name encoder name, also recorded in the description
     * @param encoderSeed seed for SDR generation; {@code -1} leaves the generator unseeded
     * @param forced skip the sanity checks on {@code n} and {@code w}
     * @throws IllegalArgumentException if not forced and {@code n / w < 2} or {@code w < 21},
     *     if {@code w > n}, or if a category appears twice
     */
    public void init(int n, int w, List<String> categoryList, String name, int encoderSeed, boolean forced) {
        this.n = n;
        this.w = w;
        this.encLearningEnabled = true;
        this.random = new Random();
        if (encoderSeed != -1) {
            this.random.setSeed(encoderSeed);
        }
        if (!forced) {
            if (n / w < 2) {
                throw new IllegalArgumentException(String.format("Number of ON bits in SDR (%d) must be much smaller than the output width (%d)", w, n));
            }
            if (w < 21) {
                throw new IllegalArgumentException(String.format("Number of bits in the SDR (%d) must be greater than 2, and should be >= 21, pass forced=True to init() to override this check", w));
            }
        }
        // Density (fraction of ON bits) is the probability that any single bit is on, so two
        // random SDRs overlap in about w * density bits. Floating-point division is required
        // here: int division yields 0 for any n > w.
        double density = (double) w / n;
        double averageOverlap = w * density;
        // Heuristic threshold halfway between average and full overlap, but never below w - 3.
        this.thresholdOverlap = (int) (averageOverlap + w) / 2;
        if (this.thresholdOverlap < w - 3) {
            this.thresholdOverlap = w - 3;
        }
        this.description.add(new Tuple(name, 0));
        this.name = name;
        addCategory(UNKNOWN_CATEGORY);
        if (categoryList == null || categoryList.isEmpty()) {
            setLearningEnabled(true);
            return;
        }
        setLearningEnabled(false);
        for (String category : categoryList) {
            addCategory(category);
        }
    }

    /** The encoded width is {@code n}. */
    @Override
    public int getWidth() {
        return getN();
    }

    @Override
    public boolean isDelta() {
        return false;
    }

    /**
     * Writes the SDR for {@code input} into {@code output}. A {@code null} or empty input zeroes
     * {@code output}. An unknown input may be added as a new category (see {@link #getScalars}).
     */
    @Override
    public void encodeIntoArray(String input, int[] output) {
        int index;
        if (input == null || input.isEmpty()) {
            Arrays.fill(output, 0);
            index = 0;
        } else {
            index = getBucketIndices(input)[0];
            int[] sdr = this.sdrByCategory.getSdr(index);
            System.arraycopy(sdr, 0, output, 0, sdr.length);
        }
        // decode() is a full pass over every category; only pay for it when tracing.
        if (LOG.isTraceEnabled()) {
            LOG.trace("input:" + input + ", index:" + index + ", output:" + ArrayUtils.intArrayToString(output));
            LOG.trace("decoded:" + decodedToStr(decode(output, "")));
        }
    }

    @Override
    public Set<FieldMetaType> getDecoderOutputFieldTypes() {
        return new HashSet<>(Arrays.asList(FieldMetaType.LIST, FieldMetaType.STRING));
    }

    /** Returns a one-element array holding the bucket (category) index of {@code input}. */
    @Override
    public int[] getBucketIndices(String input) {
        return new int[] { (int) getScalars(input).get(0) };
    }

    /**
     * Returns a one-element list holding the category index of {@code s}. It is 0 for
     * {@code null}/empty input, and 0 for an unknown category when learning is disabled. When
     * learning is enabled, an unknown category is registered and its new index is returned.
     */
    @Override
    public <S> TDoubleList getScalars(S s) {
        String category = (String) s;
        TDoubleList scalars = new TDoubleArrayList();
        if (category == null || category.isEmpty()) {
            scalars.add(0.0d);
            return scalars;
        }
        int index = 0;
        if (this.sdrByCategory.containsKey(category)) {
            index = this.sdrByCategory.getIndexByCategory(category);
        } else if (isEncoderLearningEnabled()) {
            index = this.sdrByCategory.size(); // position the new category will occupy
            addCategory(category);
        }
        scalars.add(index);
        return scalars;
    }

    /** Decodes {@code encoded} with no parent field name. */
    public DecodeResult decode(int[] encoded) {
        return decode(encoded, (String) null);
    }

    /**
     * Decodes {@code iArr} into every category whose SDR shares more than
     * {@code thresholdOverlap} ON bits with it.
     *
     * @param iArr a binary (0/1) encoding at least as long as the stored SDRs
     * @param str optional parent field name used as a prefix for the field name
     */
    @Override
    public DecodeResult decode(int[] iArr, String str) {
        assert ArrayUtils.all(iArr, new Condition.Adapter<Integer>() {
            @Override
            public boolean eval(int bit) {
                return bit <= 1;
            }
        });

        // Single pass over the entries: count shared ON bits per category.
        List<String> categories = new ArrayList<>(this.sdrByCategory.size());
        int[] overlaps = new int[this.sdrByCategory.size()];
        int row = 0;
        for (Map.Entry<String, int[]> entry : this.sdrByCategory.entrySet()) {
            categories.add(entry.getKey());
            int[] sdr = entry.getValue();
            for (int bit = 0; bit < sdr.length; bit++) {
                if (iArr[bit] == 1 && sdr[bit] == 1) {
                    overlaps[row]++;
                }
            }
            row++;
        }

        if (LOG.isTraceEnabled()) {
            LOG.trace("Overlaps for decoding:");
            for (int k = 0; k < overlaps.length; k++) {
                LOG.trace(overlaps[k] + " " + categories.get(k));
            }
        }

        int[] matches = ArrayUtils.where(overlaps, (Condition) new Condition.Adapter<Integer>() {
            @Override
            public boolean eval(int overlap) {
                return overlap > SDRCategoryEncoder.this.thresholdOverlap;
            }
        });

        StringBuilder description = new StringBuilder();
        List<MinMax> ranges = new ArrayList<>();
        for (int index : matches) {
            if (description.length() != 0) {
                description.append(" ");
            }
            description.append(categories.get(index));
            ranges.add(new MinMax(index, index));
        }

        String fieldName = (str == null || str.isEmpty()) ? getName() : String.format("%s.%s", str, getName());
        Map<String, RangeList> fields = new HashMap<>();
        fields.put(fieldName, new RangeList(ranges, description.toString()));
        return new DecodeResult(fields, Arrays.asList(fieldName));
    }

    /** Returns the single best-matching category for {@code encoded}, or an empty list if none exist. */
    @Override
    public List<Encoding> topDownCompute(int[] encoded) {
        if (this.sdrByCategory.size() == 0) {
            return new ArrayList<>();
        }
        SparseObjectMatrix<int[]> mapping = getTopDownMapping();
        return getEncoderResultsByIndex(mapping, ArrayUtils.argmax(rightVecProd(mapping, encoded)));
    }

    /** Returns the category at bucket {@code buckets[0]}, or an empty list if no categories exist. */
    @Override
    public List<Encoding> getBucketInfo(int[] buckets) {
        if (this.sdrByCategory.size() == 0) {
            return new ArrayList<>();
        }
        return getEncoderResultsByIndex(getTopDownMapping(), buckets[0]);
    }

    /**
     * Returns (building lazily) a matrix whose row {@code i} is the encoding of category {@code i}.
     * The cache is cleared whenever a category is added.
     */
    public SparseObjectMatrix<int[]> getTopDownMapping() {
        if (this.topDownMapping == null) {
            SparseObjectMatrix<int[]> mapping = new SparseObjectMatrix<>(new int[] { this.sdrByCategory.size() });
            int[] scratch = new int[getN()];
            int row = 0;
            for (String category : this.sdrByCategory.keySet()) {
                encodeIntoArray(category, scratch);
                mapping.set(row++, Arrays.copyOf(scratch, scratch.length));
            }
            // Publish only once fully built so a failure never leaves a partial cache behind.
            this.topDownMapping = mapping;
        }
        return this.topDownMapping;
    }

    /** Returns a mutable copy of all categories in index order. */
    @SuppressWarnings("unchecked")
    @Override
    public <S> List<S> getBucketValues(Class<S> cls) {
        // Categories are Strings; the cast is unchecked and callers are expected to ask for String.
        List<?> categories = new ArrayList<>(this.sdrByCategory.keySet());
        return (List<S>) categories;
    }

    /** Returns a read-only view of all SDRs in index order. */
    public Collection<int[]> getSDRs() {
        return Collections.unmodifiableCollection(this.sdrByCategory.values());
    }

    private List<Encoding> getEncoderResultsByIndex(SparseObjectMatrix<int[]> mapping, int index) {
        List<Encoding> results = new ArrayList<>();
        results.add(new Encoding(this.sdrByCategory.getCategory(index), Integer.valueOf(index), mapping.getObject(index)));
        return results;
    }

    /** Registers {@code category} with a fresh unique SDR and invalidates the top-down cache. */
    private void addCategory(String category) {
        if (this.sdrByCategory.containsKey(category)) {
            throw new IllegalArgumentException(String.format("Attempt to add encoder category '%s' that already exists", category));
        }
        this.sdrByCategory.put(category, newRep());
        this.topDownMapping = null;
    }

    /**
     * Draws {@code count} distinct integers from {@code [0, range)} and returns them sorted.
     *
     * @throws IllegalArgumentException if {@code count > range} (the draw could never finish)
     */
    private int[] getSortedSample(int range, int count) {
        if (count > range) {
            throw new IllegalArgumentException(String.format("Cannot draw %d distinct bits from %d positions", count, range));
        }
        TIntHashSet picked = new TIntHashSet();
        while (picked.size() < count) {
            picked.add(this.random.nextInt(range));
        }
        int[] sample = picked.toArray();
        Arrays.sort(sample);
        return sample;
    }

    /**
     * Generates a random {@code n}-bit SDR with {@code w} ON bits that differs from every
     * existing SDR.
     *
     * @throws RuntimeException if no unique pattern is found within {@link #MAX_REP_ATTEMPTS} draws
     */
    private int[] newRep() {
        for (int attempt = 0; attempt < MAX_REP_ATTEMPTS; attempt++) {
            int[] candidate = new int[this.n];
            for (int bit : getSortedSample(this.n, this.w)) {
                candidate[bit] = 1;
            }
            if (!isExistingSdr(candidate)) {
                return candidate;
            }
        }
        throw new RuntimeException(String.format("Error, could not find unique pattern %d after %d attempts", this.sdrByCategory.size(), MAX_REP_ATTEMPTS));
    }

    /** Returns true if some registered category already uses exactly this SDR. */
    private boolean isExistingSdr(int[] candidate) {
        for (int[] existing : this.sdrByCategory.values()) {
            if (Arrays.equals(candidate, existing)) {
                return true;
            }
        }
        return false;
    }
}
