package org.numenta.nupic.encoders;

import gnu.trove.list.TDoubleList;
import gnu.trove.list.array.TDoubleArrayList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import org.numenta.nupic.FieldMetaType;
import org.numenta.nupic.encoders.Encoder;
import org.numenta.nupic.util.ArrayUtils;
import org.numenta.nupic.util.Condition;
import org.numenta.nupic.util.MinMax;
import org.numenta.nupic.util.SparseObjectMatrix;
import org.numenta.nupic.util.Tuple;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/numenta/nupic/encoders/ScalarEncoder.class */
/**
 * Encodes a scalar {@code double} into an {@code n}-bit array in which one contiguous block of
 * {@code w} bits is ON. The block's position moves linearly with the input value within
 * {@code [minVal, maxVal]}. Non-periodic encoders pad both ends by half the block width;
 * periodic encoders use no padding and wrap the block around the ends of the array instead.
 */
public class ScalarEncoder extends Encoder<Double> {
    private static final long serialVersionUID = 1L;
    private static final Logger LOGGER = LoggerFactory.getLogger(ScalarEncoder.class);

    /**
     * Fluent builder for {@link ScalarEncoder}.
     */
    public static class Builder extends Encoder.Builder<Builder, ScalarEncoder> {
        private Builder() {
        }

        /**
         * Creates a new encoder and runs the parent builder's {@code build()} on it (presumably
         * applying the configured parameters; that code is not visible here). Then calls
         * {@link ScalarEncoder#init()} to validate the configuration and derive the remaining
         * settings.
         *
         * @return the fully initialised encoder
         */
        @Override
        public ScalarEncoder build() {
            this.encoder = new ScalarEncoder();
            super.build();
            ((ScalarEncoder) this.encoder).init();
            return (ScalarEncoder) this.encoder;
        }
    }

    /**
     * Returns a new builder for configuring a {@link ScalarEncoder}.
     *
     * @return a fresh builder
     */
    public static Encoder.Builder<Builder, ScalarEncoder> builder() {
        return new Builder();
    }

    /**
     * {@inheritDoc}
     *
     * @return always {@code false}; this encoder encodes absolute values
     */
    @Override
    public boolean isDelta() {
        return false;
    }

    /**
     * Validates the configuration and derives the dependent settings.
     *
     * <p>This method:
     * <ul>
     *   <li>computes the half width and padding from {@code w};</li>
     *   <li>computes the internal range from {@code minVal}/{@code maxVal} when both are set;</li>
     *   <li>resolves n/radius/resolution via {@link #initEncoder};</li>
     *   <li>generates a {@code [min:max]} name if none was configured;</li>
     *   <li>appends a description entry.</li>
     * </ul>
     *
     * @throws IllegalStateException if {@code w} is even, if {@code minVal >= maxVal}, if none
     *     of n/radius/resolution is set, or (unless forced) if {@code w < 21}
     */
    public void init() {
        if (getW() % 2 == 0) {
            throw new IllegalStateException("W must be an odd number (to eliminate centering difficulty)");
        }
        setHalfWidth((getW() - 1) / 2);
        // Periodic encoders wrap around, so they need no padding at the edges.
        setPadding(isPeriodic() ? 0 : getHalfWidth());
        if (!Double.isNaN(getMinVal()) && !Double.isNaN(getMaxVal())) {
            if (getMinVal() >= getMaxVal()) {
                throw new IllegalStateException("maxVal must be > minVal");
            }
            setRangeInternal(getMaxVal() - getMinVal());
        }
        initEncoder(getW(), getMinVal(), getMaxVal(), getN(), getRadius(), getResolution());
        setNInternal(getN() - (2 * getPadding()));
        if (getName() == null) {
            if (hasFractionalPart(getMinVal()) || hasFractionalPart(getMaxVal())) {
                setName("[" + getMinVal() + ":" + getMaxVal() + "]");
            } else {
                setName("[" + ((int) getMinVal()) + ":" + ((int) getMaxVal()) + "]");
            }
        }
        if (!isForced()) {
            checkReasonableSettings();
        }
        String name = getName();
        this.name = name;
        String descriptionName = name.equals("None")
            ? "[" + ((int) getMinVal()) + ":" + ((int) getMaxVal()) + "]"
            : name;
        this.description.add(new Tuple(descriptionName, 0));
    }

    /**
     * Reports whether {@code value} has a non-zero fractional part.
     *
     * <p>NaN is reported as not fractional, so its name is still built from {@code (int) NaN}.
     * Infinities are also not fractional.
     */
    private static boolean hasFractionalPart(double value) {
        return !Double.isNaN(value) && value != Math.floor(value);
    }

    /**
     * Resolves the mutually dependent n / radius / resolution / range settings.
     *
     * <p>When {@code n == 0}:
     * <ul>
     *   <li>if {@code radius} is non-zero, the resolution is derived from the current radius;</li>
     *   <li>otherwise, if {@code resolution} is non-zero, the radius is derived from the current
     *       resolution;</li>
     *   <li>n is then computed from the resulting range.</li>
     * </ul>
     * When {@code n != 0} and both bounds are set, resolution, radius and range are derived from
     * n. The arguments only select the branch; the values themselves are read back through the
     * getters.
     *
     * @param w          number of ON bits
     * @param minVal     lower bound (NaN if unset)
     * @param maxVal     upper bound (NaN if unset)
     * @param n          total number of bits, or 0 if it should be derived
     * @param radius     configured radius, or 0
     * @param resolution configured resolution, or 0
     * @throws IllegalStateException if n, radius and resolution are all zero
     */
    public void initEncoder(int w, double minVal, double maxVal, int n, double radius, double resolution) {
        if (n == 0) {
            if (radius != 0.0d) {
                setResolution(getRadius() / w);
            } else {
                if (resolution == 0.0d) {
                    throw new IllegalStateException("One of n, radius, resolution must be specified for a ScalarEncoder");
                }
                setRadius(getResolution() * w);
            }
            setRange(isPeriodic() ? getRangeInternal() : getRangeInternal() + getResolution());
            setN((int) Math.ceil((w * (getRange() / getRadius())) + (2 * getPadding())));
            return;
        }
        if (Double.isNaN(minVal) || Double.isNaN(maxVal)) {
            return;
        }
        if (isPeriodic()) {
            setResolution(getRangeInternal() / getN());
        } else {
            setResolution(getRangeInternal() / (getN() - getW()));
        }
        setRadius(getW() * getResolution());
        setRange(isPeriodic() ? getRangeInternal() : getRangeInternal() + getResolution());
    }

    /**
     * Computes the index of the first ON bit for {@code input}.
     *
     * <p>Out-of-range input is clipped to the bounds when {@code clipInput} is set. Periodic
     * encoders never clip; for them out-of-range input is always an error. The result may be
     * negative for periodic encoders; the block then wraps around.
     *
     * @param input the value to encode
     * @return the first ON bit index, or {@code null} if {@code input} is NaN (missing value)
     * @throws IllegalStateException if {@code input} is out of range and is not clipped
     */
    public Integer getFirstOnBit(double input) {
        // NaN never compares equal to anything (including Double.NaN), so isNaN is required.
        if (Double.isNaN(input)) {
            return null;
        }
        double value = input;
        if (value < getMinVal()) {
            if (!clipInput() || isPeriodic()) {
                throw new IllegalStateException("input (" + value + ") less than range (" + getMinVal() + " - " + getMaxVal() + ")");
            }
            LOGGER.info("Clipped input {}={} to minval {}", getName(), value, getMinVal());
            value = getMinVal();
        }
        if (isPeriodic()) {
            if (value >= getMaxVal()) {
                throw new IllegalStateException("input (" + value + ") greater than periodic range (" + getMinVal() + " - " + getMaxVal() + ")");
            }
        } else if (value > getMaxVal()) {
            if (!clipInput()) {
                throw new IllegalStateException("input (" + value + ") greater than range (" + getMinVal() + " - " + getMaxVal() + ")");
            }
            LOGGER.info("Clipped input {}={} to maxval {}", getName(), value, getMaxVal());
            value = getMaxVal();
        }
        int centerBin;
        if (isPeriodic()) {
            centerBin = (int) ((value - getMinVal()) * getNInternal() / getRange()) + getPadding();
        } else {
            // Adding half a resolution step rounds to the nearest bucket before truncation.
            centerBin = (int) ((value - getMinVal() + getResolution() / 2.0d) / getResolution()) + getPadding();
        }
        return Integer.valueOf(centerBin - getHalfWidth());
    }

    /**
     * Rejects settings that are unlikely to work well downstream; skipped when the encoder is
     * forced.
     *
     * @throws IllegalStateException if {@code w < 21}
     */
    public void checkReasonableSettings() {
        if (getW() < 21) {
            throw new IllegalStateException(String.format(
                "Number of bits in the SDR (%d) must be greater than 2, and recommended >= 21 (use forced=True to override)",
                getW()));
        }
    }

    /**
     * {@inheritDoc}
     *
     * @return {@code FLOAT} and {@code INTEGER}, in that order
     */
    @Override
    public Set<FieldMetaType> getDecoderOutputFieldTypes() {
        return new LinkedHashSet<>(Arrays.asList(FieldMetaType.FLOAT, FieldMetaType.INTEGER));
    }

    /**
     * {@inheritDoc}
     *
     * @return the total number of output bits, {@code n}
     */
    @Override
    public int getWidth() {
        return getN();
    }

    /**
     * String input is not supported by this encoder.
     *
     * @return always {@code null}
     */
    @Override
    public int[] getBucketIndices(String input) {
        return null;
    }

    /**
     * Returns the bucket index for {@code input}.
     *
     * <ul>
     *   <li>Non-periodic encoders: the index of the first ON bit.</li>
     *   <li>Periodic encoders: the index of the centre bit, shifted by {@code n} if negative.</li>
     * </ul>
     *
     * @param input the value to bucket
     * @return a one-element array holding the bucket index
     * @throws IllegalStateException if {@code input} is NaN or out of range
     */
    @Override
    public int[] getBucketIndices(double input) {
        Integer firstOnBit = getFirstOnBit(input);
        if (firstOnBit == null) {
            throw new IllegalStateException("Cannot compute a bucket index for a missing (NaN) input");
        }
        int bucketIdx = firstOnBit.intValue();
        if (isPeriodic()) {
            bucketIdx += getHalfWidth();
            if (bucketIdx < 0) {
                bucketIdx += getN();
            }
        }
        return new int[] { bucketIdx };
    }

    /**
     * Encodes {@code input} into {@code output}. All entries are cleared first, then the
     * {@code w}-bit block is set; periodic encoders wrap the block around the array ends.
     * A NaN input produces an all-zero output.
     *
     * @param input  the value to encode (must not be {@code null})
     * @param output destination array, assumed to hold at least {@code n} entries
     * @throws IllegalStateException if {@code input} is out of range and is not clipped
     */
    @Override
    public void encodeIntoArray(Double input, int[] output) {
        if (Double.isNaN(input.doubleValue())) {
            Arrays.fill(output, 0);
            return;
        }
        Integer firstOnBit = getFirstOnBit(input.doubleValue());
        if (firstOnBit != null) {
            Arrays.fill(output, 0);
            int minBin = firstOnBit.intValue();
            int maxBin = minBin + 2 * getHalfWidth();
            if (isPeriodic()) {
                // Bits that fall off the top wrap to the start of the array...
                if (maxBin >= getN()) {
                    ArrayUtils.setIndexesTo(output, ArrayUtils.range(0, maxBin - getN() + 1), 1);
                    maxBin = getN() - 1;
                }
                // ...and bits that fall off the bottom wrap to the end.
                if (minBin < 0) {
                    ArrayUtils.setIndexesTo(output, ArrayUtils.range(getN() + minBin, getN()), 1);
                    minBin = 0;
                }
            }
            ArrayUtils.setIndexesTo(output, ArrayUtils.range(minBin, maxBin + 1), 1);
        }
        if (LOGGER.isTraceEnabled()) {
            LOGGER.trace("");
            LOGGER.trace("input: " + input);
            LOGGER.trace("range: " + getMinVal() + " - " + getMaxVal());
            LOGGER.trace("n:" + getN() + "w:" + getW() + "resolution:" + getResolution() + "radius:" + getRadius() + "periodic:" + isPeriodic());
            LOGGER.trace("output: " + Arrays.toString(output));
            LOGGER.trace("input desc: " + decode(output, ""));
        }
    }

    /**
     * Decodes a (possibly noisy) top-down output back into ranges of input values.
     *
     * <p>Steps:
     * <ol>
     *   <li>Treat any positive entry among the first {@code n} as ON.</li>
     *   <li>Fill short gaps of up to {@code halfWidth} zeros between ON bits.</li>
     *   <li>Map each contiguous run of ON bits to a {@code [min, max]} input range. A periodic
     *       range that crosses {@code maxVal} is split in two.</li>
     * </ol>
     *
     * @param encoded         the bits to decode; expected to hold at least {@code n} entries
     * @param parentFieldName optional prefix for the field name ({@code null}/empty for none)
     * @return the decoded ranges keyed by field name; an empty result if no bit is ON; or
     *     {@code null} if {@code encoded} is {@code null} or empty
     */
    @Override
    public DecodeResult decode(int[] encoded, String parentFieldName) {
        if (encoded == null || encoded.length < 1) {
            return null;
        }
        int n = getN();
        int[] filtered = new int[Math.min(encoded.length, n)];
        for (int i = 0; i < filtered.length; i++) {
            filtered[i] = encoded[i] > 0 ? 1 : 0;
        }
        fillHoles(filtered);
        if (LOGGER.isTraceEnabled()) {
            LOGGER.trace("raw output:" + Arrays.toString(Arrays.copyOf(encoded, filtered.length)));
            LOGGER.trace("filtered output:" + Arrays.toString(filtered));
        }

        List<int[]> runs = findRuns(filtered);
        if (runs.isEmpty()) {
            HashMap<String, RangeList> noFields = new HashMap<>();
            return new DecodeResult(noFields, new ArrayList<String>());
        }
        // A periodic encoding can wrap: merge a run touching index 0 into one ending at n - 1.
        if (isPeriodic() && runs.size() > 1) {
            int[] first = runs.get(0);
            int[] last = runs.get(runs.size() - 1);
            if (first[0] == 0 && last[0] + last[1] == n) {
                last[1] += first[1];
                runs.remove(0);
            }
        }

        List<MinMax> ranges = new ArrayList<>();
        for (int[] run : runs) {
            appendRunRanges(ranges, run[0], run[1]);
        }

        String fieldName = (parentFieldName == null || parentFieldName.isEmpty())
            ? getName()
            : String.format("%s.%s", parentFieldName, getName());
        HashMap<String, RangeList> fields = new HashMap<>();
        fields.put(fieldName, new RangeList(ranges, generateRangeDescription(ranges)));
        return new DecodeResult(fields, Arrays.asList(fieldName));
    }

    /**
     * Sets to 1 every window of the form {@code 1 0...0 1} that has between 1 and
     * {@code halfWidth} zeros. Smaller gaps are filled first. For periodic encoders the
     * windows wrap modulo {@code n}.
     */
    private void fillHoles(int[] bits) {
        int n = getN();
        int maxGap = getHalfWidth();
        for (int gap = 1; gap <= maxGap; gap++) {
            int patternLength = gap + 2;
            // Non-periodic windows must fit inside [0, n), so (start + k) % n == start + k there.
            int lastStart = isPeriodic() ? n - 1 : n - patternLength;
            for (int start = 0; start <= lastStart; start++) {
                if (isBracketedGap(bits, start, patternLength, n)) {
                    for (int k = 0; k < patternLength; k++) {
                        bits[(start + k) % n] = 1;
                    }
                }
            }
        }
    }

    /**
     * Reports whether the {@code length}-long window at {@code start} (indices taken modulo
     * {@code n}) is 1 at both ends and 0 everywhere in between.
     */
    private static boolean isBracketedGap(int[] bits, int start, int length, int n) {
        if (bits[start % n] != 1 || bits[(start + length - 1) % n] != 1) {
            return false;
        }
        for (int k = 1; k < length - 1; k++) {
            if (bits[(start + k) % n] != 0) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns each maximal run of positive entries as {@code {startIndex, runLength}}, in
     * ascending order of start index.
     */
    private static List<int[]> findRuns(int[] bits) {
        List<int[]> runs = new ArrayList<>();
        int runStart = -1;
        for (int i = 0; i < bits.length; i++) {
            if (bits[i] > 0) {
                if (runStart < 0) {
                    runStart = i;
                }
            } else if (runStart >= 0) {
                runs.add(new int[] { runStart, i - runStart });
                runStart = -1;
            }
        }
        if (runStart >= 0) {
            runs.add(new int[] { runStart, bits.length - runStart });
        }
        return runs;
    }

    /**
     * Converts one run of ON bits into input-space ranges and appends them to {@code ranges}.
     *
     * <p>Runs of at most {@code w} bits map to a single point at their centre. Longer runs are
     * inset by {@code halfWidth} on each side.
     */
    private void appendRunRanges(List<MinMax> ranges, int start, int runLength) {
        int left;
        int right;
        if (runLength <= getW()) {
            left = start + runLength / 2;
            right = left;
        } else {
            left = start + getHalfWidth();
            right = start + runLength - 1 - getHalfWidth();
        }

        double inMin;
        double inMax;
        if (isPeriodic()) {
            inMin = (left - getPadding()) * getRange() / getNInternal() + getMinVal();
            inMax = (right - getPadding()) * getRange() / getNInternal() + getMinVal();
        } else {
            inMin = (left - getPadding()) * getResolution() + getMinVal();
            inMax = (right - getPadding()) * getResolution() + getMinVal();
        }
        if (isPeriodic() && inMin >= getMaxVal()) {
            inMin -= getRange();
            inMax -= getRange();
        }
        if (inMin < getMinVal()) {
            inMin = getMinVal();
        }
        if (inMax < getMinVal()) {
            inMax = getMinVal();
        }

        if (isPeriodic() && inMax >= getMaxVal()) {
            // The range wraps past maxVal: split it into [inMin, maxVal] and [minVal, rest].
            ranges.add(new MinMax(inMin, getMaxVal()));
            ranges.add(new MinMax(getMinVal(), inMax - getRange()));
        } else {
            if (inMax > getMaxVal()) {
                inMax = getMaxVal();
            }
            if (inMin > getMaxVal()) {
                inMin = getMaxVal();
            }
            ranges.add(new MinMax(inMin, inMax));
        }
    }

    /**
     * Formats ranges as a comma-separated list. Each entry is {@code "min-max"}, or just
     * {@code "min"} when both ends are equal, using two decimals and the default locale.
     *
     * @param ranges the ranges to describe
     * @return the description, empty if {@code ranges} is empty
     */
    public String generateRangeDescription(List<MinMax> ranges) {
        StringBuilder desc = new StringBuilder();
        for (int i = 0; i < ranges.size(); i++) {
            if (i > 0) {
                desc.append(", ");
            }
            MinMax range = ranges.get(i);
            if (range.min() != range.max()) {
                desc.append(String.format("%.2f-%.2f", range.min(), range.max()));
            } else {
                desc.append(String.format("%.2f", range.min()));
            }
        }
        return desc.toString();
    }

    /**
     * Returns the lazily built mapping from bucket index to that bucket's encoding.
     *
     * <p>On first use this method:
     * <ul>
     *   <li>samples bucket values from {@code minVal} in steps of {@code resolution} (offset by
     *       half a step for periodic encoders);</li>
     *   <li>encodes each value, clamped into {@code [minVal, maxVal]};</li>
     *   <li>caches the result, so later calls return the same matrix without re-encoding.</li>
     * </ul>
     *
     * @return the bucket-index to encoding matrix
     */
    public SparseObjectMatrix<int[]> getTopDownMapping() {
        if (this.topDownMapping != null) {
            return this.topDownMapping;
        }
        double minVal = getMinVal();
        double maxVal = getMaxVal();
        double resolution = getResolution();
        if (isPeriodic()) {
            setTopDownValues(ArrayUtils.arange(minVal + (resolution / 2.0d), maxVal, resolution));
        } else {
            setTopDownValues(ArrayUtils.arange(minVal, maxVal + (resolution / 2.0d), resolution));
        }
        double[] topDownValues = getTopDownValues();
        int numCategories = topDownValues.length;
        SparseObjectMatrix<int[]> mapping = new SparseObjectMatrix<>(new int[] { numCategories });
        int[] encoded = new int[getN()];
        for (int i = 0; i < numCategories; i++) {
            double value = Math.min(Math.max(topDownValues[i], minVal), maxVal);
            encodeIntoArray(Double.valueOf(value), encoded);
            mapping.set(i, Arrays.copyOf(encoded, encoded.length));
        }
        // Publish only once fully built, so a failure part-way does not leave a partial cache.
        setTopDownMapping(mapping);
        return mapping;
    }

    /**
     * Wraps the input in a one-element scalar list.
     *
     * @param input the value; must be a {@link Double} (otherwise a ClassCastException is thrown)
     * @return a list holding {@code input} as a primitive double
     */
    @Override
    public <S> TDoubleList getScalars(S input) {
        TDoubleList scalars = new TDoubleArrayList();
        scalars.add(((Double) input).doubleValue());
        return scalars;
    }

    /**
     * Returns the representative input value of every bucket, in bucket order. The list is
     * computed once and then cached.
     *
     * @param cls element type requested by the caller (the values are {@link Double})
     * @return the bucket values
     */
    @SuppressWarnings("unchecked")
    @Override
    public <S> List<S> getBucketValues(Class<S> cls) {
        if (this.bucketValues == null) {
            int numBuckets = getTopDownMapping().getMaxIndex() + 1;
            List<Double> values = new ArrayList<>(numBuckets);
            for (int i = 0; i < numBuckets; i++) {
                values.add((Double) getBucketInfo(new int[] { i }).get(0).get(1));
            }
            this.bucketValues = values;
        }
        return (List<S>) this.bucketValues;
    }

    /**
     * Describes the bucket {@code buckets[0]}: its representative value and its encoding.
     *
     * <p>The representative value is {@code minVal + i * resolution}. Periodic encoders add
     * half a resolution step.
     *
     * @param buckets array whose first element is the bucket index
     * @return a one-element list with the bucket's value, scalar and encoding
     */
    @Override
    public List<Encoding> getBucketInfo(int[] buckets) {
        SparseObjectMatrix<int[]> topDownMapping = getTopDownMapping();
        int category = buckets[0];
        int[] encoding = topDownMapping.getObject(category);
        double inputVal = isPeriodic()
            ? getMinVal() + (getResolution() / 2.0d) + (category * getResolution())
            : getMinVal() + (category * getResolution());
        return Arrays.asList(new Encoding(Double.valueOf(inputVal), Double.valueOf(inputVal), encoding));
    }

    /**
     * Returns the bucket whose row gives the largest value in
     * {@code rightVecProd(getTopDownMapping(), encoded)}.
     *
     * @param encoded the top-down output to match
     * @return bucket info for the best-matching bucket
     */
    @Override
    public List<Encoding> topDownCompute(int[] encoded) {
        return getBucketInfo(new int[] { ArrayUtils.argmax(rightVecProd(getTopDownMapping(), encoded)) });
    }

    /**
     * Returns the encoder's parameters and derived state as (name, value) tuples.
     * Calling this builds the top-down mapping and bucket values if they are not cached yet.
     *
     * @return the parameter tuples
     */
    public List<Tuple> dict() {
        List<Tuple> entries = new ArrayList<>();
        entries.add(new Tuple("maxval", Double.valueOf(getMaxVal())));
        entries.add(new Tuple("bucketValues", getBucketValues(Double.class)));
        entries.add(new Tuple("nInternal", Integer.valueOf(getNInternal())));
        entries.add(new Tuple("name", getName()));
        entries.add(new Tuple("minval", Double.valueOf(getMinVal())));
        entries.add(new Tuple("topDownValues", Arrays.toString(getTopDownValues())));
        entries.add(new Tuple("clipInput", Boolean.valueOf(clipInput())));
        entries.add(new Tuple("n", Integer.valueOf(getN())));
        entries.add(new Tuple("padding", Integer.valueOf(getPadding())));
        entries.add(new Tuple("range", Double.valueOf(getRange())));
        entries.add(new Tuple("periodic", Boolean.valueOf(isPeriodic())));
        entries.add(new Tuple("radius", Double.valueOf(getRadius())));
        entries.add(new Tuple("w", Integer.valueOf(getW())));
        entries.add(new Tuple("topDownMappingM", getTopDownMapping()));
        entries.add(new Tuple("halfwidth", Integer.valueOf(getHalfWidth())));
        entries.add(new Tuple("resolution", Double.valueOf(getResolution())));
        entries.add(new Tuple("rangeInternal", Double.valueOf(getRangeInternal())));
        return entries;
    }
}
