package org.numenta.nupic.encoders;

import gnu.trove.list.TDoubleList;
import gnu.trove.list.array.TDoubleArrayList;
import gnu.trove.map.TIntObjectMap;
import gnu.trove.map.TObjectIntMap;
import gnu.trove.map.hash.TIntObjectHashMap;
import gnu.trove.map.hash.TObjectIntHashMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.numenta.nupic.encoders.Encoder;
import org.numenta.nupic.util.ArrayUtils;
import org.numenta.nupic.util.MinMax;
import org.numenta.nupic.util.Tuple;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Encodes a {@code String} category drawn from a fixed list of categories.
 *
 * <p>Each category in the configured list is assigned the index {@code i + 1} (its list position
 * plus one). Index {@code 0} is reserved for {@code "<UNKNOWN>"}, which is what any category not
 * in the list encodes as. The index is then encoded by an internal {@link ScalarEncoder} whose
 * range is {@code [0, ncategories - 1]}, and the output width is {@code w * ncategories}.
 */
public class CategoryEncoder extends Encoder<String> {
    private static final Logger LOG = LoggerFactory.getLogger(CategoryEncoder.class);

    /** Index reserved for categories that are not in {@link #categoryList}. */
    private static final int UNKNOWN_INDEX = 0;

    /** Number of categories including the reserved unknown slot (list size + 1). */
    protected int ncategories;
    /** Category name to its 1-based index; unlisted categories have no entry. */
    protected TObjectIntMap<String> categoryToIndex;
    /** Index to category name; index 0 maps to {@code "<UNKNOWN>"}. */
    protected TIntObjectMap<String> indexToCategory;
    /** The categories supplied through the builder, in index order. */
    protected List<String> categoryList;
    /** Output width in bits, {@code w * ncategories}. */
    protected int width;
    /** Encoder used for the numeric category index. */
    private ScalarEncoder scalarEncoder;

    /** Builder for {@link CategoryEncoder}; a category list must be supplied before building. */
    public static class Builder extends Encoder.Builder<Builder, CategoryEncoder> {
        private List<String> categoryList;

        private Builder() {
        }

        /**
         * Creates, configures and initializes a {@link CategoryEncoder}.
         *
         * @return the initialized encoder
         * @throws IllegalStateException if no category list was supplied, or if the underlying
         *     scalar encoder cannot be built
         */
        @Override
        public CategoryEncoder build() {
            this.encoder = new CategoryEncoder();
            super.build();
            if (this.categoryList == null) {
                throw new IllegalStateException("Category List cannot be null");
            }
            CategoryEncoder categoryEncoder = (CategoryEncoder) this.encoder;
            categoryEncoder.setCategoryList(this.categoryList);
            categoryEncoder.init();
            return categoryEncoder;
        }

        /**
         * Sets the list of known categories; list position {@code i} becomes index {@code i + 1}.
         *
         * @param list the categories (must not be {@code null} when {@link #build()} is called)
         * @return this builder
         */
        public Builder categoryList(List<String> list) {
            this.categoryList = list;
            return this;
        }
    }

    private CategoryEncoder() {
        this.categoryToIndex = new TObjectIntHashMap<>();
        this.indexToCategory = new TIntObjectHashMap<>();
    }

    /**
     * Returns a new builder for this encoder.
     *
     * @return a fresh {@link Builder}
     */
    public static Encoder.Builder<Builder, CategoryEncoder> builder() {
        return new Builder();
    }

    /**
     * Builds the category tables and the underlying scalar encoder from the current configuration
     * ({@code n}, {@code w}, {@code radius}, {@code periodic}, {@code forced} and the category list),
     * then sets {@code n} and {@code width} to {@code w * ncategories}.
     *
     * @throws IllegalStateException if the underlying {@link ScalarEncoder} cannot be built (the
     *     original exception is attached as the cause), or if the width check fails
     */
    public void init() {
        this.ncategories = this.categoryList == null ? 0 : this.categoryList.size() + 1;
        this.minVal = 0.0d;
        this.maxVal = this.ncategories - 1;
        try {
            this.scalarEncoder = ScalarEncoder.builder()
                .n(this.n)
                .w(this.w)
                .radius(this.radius)
                .minVal(this.minVal)
                .maxVal(this.maxVal)
                .periodic(this.periodic)
                .forced(this.forced)
                .build();
        } catch (Exception e) {
            // Always propagate: continuing without a scalar encoder would only fail later with a
            // NullPointerException. getMessage() may be null, so guard before searching it.
            String message = e.getMessage();
            int idx = message == null ? -1 : message.indexOf("ScalarEncoder");
            if (idx != -1) {
                // Report the failure against this encoder rather than its internal delegate.
                message = message.substring(0, idx).concat("CategoryEncoder");
            }
            throw new IllegalStateException(message, e);
        }

        this.indexToCategory.put(UNKNOWN_INDEX, "<UNKNOWN>");
        if (this.categoryList != null) {
            for (int i = 0; i < this.categoryList.size(); i++) {
                String category = this.categoryList.get(i);
                this.categoryToIndex.put(category, i + 1);
                this.indexToCategory.put(i + 1, category);
            }
        }

        this.width = this.w * this.ncategories;
        this.n = this.width;
        this.scalarEncoder.n = this.n;
        // NOTE(review): getWidth() returns getN(); if getN() reads the n assigned just above, this
        // check can never fail -- confirm whether scalarEncoder's width was meant instead.
        if (getWidth() != this.width) {
            throw new IllegalStateException("Width != w (num bits to represent output item) * #categories");
        }
        this.description.add(new Tuple(this.name, 0));
    }

    /**
     * Looks up the index of a category.
     *
     * @param category the category to look up
     * @return the 1-based index of a listed category, or {@code UNKNOWN_INDEX} (0) otherwise
     */
    private int indexOfCategory(Object category) {
        // containsKey avoids confusing a real index with Trove's configurable no-entry value.
        return this.categoryToIndex.containsKey(category)
            ? this.categoryToIndex.get(category)
            : UNKNOWN_INDEX;
    }

    /**
     * Returns the category index of {@code t} as a single-element list.
     *
     * @param t the category
     * @return a one-element list holding the index (0 for unlisted categories)
     */
    @Override
    public <T> TDoubleList getScalars(T t) {
        return new TDoubleArrayList(new double[] { indexOfCategory(t) });
    }

    /**
     * Returns the scalar encoder's bucket indices for the category's index.
     *
     * @param input the category; unlisted categories use index 0
     * @return the bucket indices, or {@code null} when {@code input} is {@code null}
     */
    @Override
    public int[] getBucketIndices(String input) {
        if (input == null) {
            return null;
        }
        return this.scalarEncoder.getBucketIndices(indexOfCategory(input));
    }

    /**
     * Encodes {@code input} into {@code output}.
     *
     * <p>A {@code null} input is treated as missing data: the first {@code n} bits of
     * {@code output} are cleared. Unlisted categories are encoded as index 0.
     *
     * @param input the category, or {@code null} for missing data
     * @param output destination bit array
     */
    @Override
    public void encodeIntoArray(String input, int[] output) {
        String val = null;
        double value = UNKNOWN_INDEX;
        if (input == null) {
            // Missing data has no active bits; don't leave stale contents in a reused buffer.
            Arrays.fill(output, 0, Math.min(this.n, output.length), 0);
            val = "<missing>";
        } else {
            value = indexOfCategory(input);
            this.scalarEncoder.encodeIntoArray(Double.valueOf(value), output);
        }
        LOG.trace("input: {}, val: {}, value: {}, output: {}",
            new Object[] { input, val, Double.valueOf(value), Arrays.toString(output) });
    }

    /**
     * Decodes {@code encoded} into category index ranges plus a comma-separated description of
     * the category names in those ranges.
     *
     * @param encoded the encoded bits
     * @param parentFieldName prefix for the field name; {@code null} or empty means no prefix
     * @return the decode result; returned unchanged when the scalar decode yields no fields
     * @throws IllegalStateException if the scalar decode yields more than one field
     */
    @Override
    public DecodeResult decode(int[] encoded, String parentFieldName) {
        DecodeResult scalarResult = this.scalarEncoder.decode(encoded, parentFieldName);
        Map<String, RangeList> fields = scalarResult.getFields();
        if (fields.isEmpty()) {
            return scalarResult;
        }
        if (fields.size() != 1) {
            throw new IllegalStateException("Expecting only one field");
        }

        List<MinMax> outRanges = new ArrayList<>();
        StringBuilder desc = new StringBuilder();
        for (RangeList ranges : fields.values()) {
            // NOTE(review): only the first range of the field is decoded; any further ranges the
            // scalar decoder reports are ignored -- confirm this is intended.
            MinMax range = ranges.getRange(0);
            int minIndex = (int) Math.round(range.min());
            int maxIndex = (int) Math.round(range.max());
            outRanges.add(new MinMax(minIndex, maxIndex));
            for (int index = minIndex; index <= maxIndex; index++) {
                if (desc.length() > 0) {
                    desc.append(", ");
                }
                desc.append(this.indexToCategory.get(index));
            }
        }

        String fieldName = parentFieldName != null && !parentFieldName.isEmpty()
            ? String.format("%s.%s", parentFieldName, this.name)
            : this.name;
        Map<String, RangeList> decoded = new HashMap<>();
        decoded.put(fieldName, new RangeList(outRanges, desc.toString()));
        return new DecodeResult(decoded, Arrays.asList(fieldName));
    }

    /**
     * Scores 1.0 when the first expected and actual values are equal, else 0.0. The score is
     * inverted when {@code fractional} is {@code false}.
     *
     * @param expValues expected values; only element 0 is read
     * @param actValues actual values; only element 0 is read
     * @param fractional when {@code false}, returns {@code 1 - score}
     * @return a one-element list holding the score
     */
    @Override
    public TDoubleList closenessScores(TDoubleList expValues, TDoubleList actValues, boolean fractional) {
        double closeness = expValues.get(0) == actValues.get(0) ? 1.0d : 0.0d;
        if (!fractional) {
            closeness = 1.0d - closeness;
        }
        return new TDoubleArrayList(new double[] { closeness });
    }

    /**
     * Returns the category name of every bucket, computing and caching the list on first use.
     *
     * @param cls requested element type (unchecked; the list holds {@code String} values)
     * @return the cached bucket values
     */
    @Override
    @SuppressWarnings("unchecked") // bucketValues holds Strings; callers request String.class
    public <T> List<T> getBucketValues(Class<T> cls) {
        if (this.bucketValues == null) {
            int bucketCount = this.scalarEncoder.getTopDownMapping().getMaxIndex() + 1;
            this.bucketValues = new ArrayList<>();
            for (int i = 0; i < bucketCount; i++) {
                this.bucketValues.add((String) getBucketInfo(new int[] { i }).get(0).getValue());
            }
        }
        return (List<T>) this.bucketValues;
    }

    /**
     * Returns the scalar encoder's bucket info with the first entry rewritten.
     *
     * <p>In the rewritten entry, the value is the category name, the scalar is the rounded
     * category index, and the encoding is unchanged.
     *
     * @param buckets bucket indices
     * @return the bucket info list (the scalar encoder's list, mutated in place)
     */
    @Override
    public List<EncoderResult> getBucketInfo(int[] buckets) {
        List<EncoderResult> bucketInfo = this.scalarEncoder.getBucketInfo(buckets);
        EncoderResult scalarInfo = bucketInfo.get(0);
        int categoryIndex = (int) Math.round(((Double) scalarInfo.getValue()).doubleValue());
        String category = this.indexToCategory.get(categoryIndex);
        bucketInfo.set(0, new EncoderResult(category, Integer.valueOf(categoryIndex), scalarInfo.getEncoding()));
        return bucketInfo;
    }

    /**
     * Returns the bucket info of the bucket whose top-down mapping best matches {@code encoded}.
     *
     * @param encoded the encoded bits
     * @return bucket info for the best-matching bucket
     */
    @Override
    public List<EncoderResult> topDownCompute(int[] encoded) {
        int bestBucket = ArrayUtils.argmax(rightVecProd(this.scalarEncoder.getTopDownMapping(), encoded));
        return getBucketInfo(new int[] { bestBucket });
    }

    /** @return the configured category list (may be {@code null} before configuration) */
    public List<String> getCategoryList() {
        return this.categoryList;
    }

    /** @param list the category list to use on the next {@link #init()} */
    public void setCategoryList(List<String> list) {
        this.categoryList = list;
    }

    /** @return the output width, equal to {@code getN()} */
    @Override
    public int getWidth() {
        return getN();
    }

    /** @return {@code false}; this encoder is never a delta encoder */
    @Override
    public boolean isDelta() {
        return false;
    }
}
