package org.numenta.nupic.encoders;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import org.junit.Assert;
import org.junit.Test;
import org.numenta.nupic.util.ArrayUtils;

/* JADX WARN: Classes with same name are omitted:
  input_file:org/numenta/nupic/examples/cortical_io/breakingnews/breaking-news-demo-1.0.0.jar:org/numenta/nupic/encoders/SDRCategoryEncoderTest.class
  input_file:org/numenta/nupic/examples/cortical_io/foxeats/FoxEatsDemo.jar:org/numenta/nupic/encoders/SDRCategoryEncoderTest.class
 */
/* loaded from: input_file:org/numenta/nupic/examples/napi/hotgym/NAPI-Hotgym-Demo-1.0.jar:org/numenta/nupic/encoders/SDRCategoryEncoderTest.class */
/**
 * Tests for {@link SDRCategoryEncoder}: encoding/decoding against a fixed category list,
 * unknown and empty inputs, tolerance to single-bit noise, overlapping categories,
 * rejection of duplicate categories, and automatic growth of the category list.
 */
public class SDRCategoryEncoderTest {

    /**
     * Exercises an encoder built with an explicit 22-entry category list, then an encoder
     * built without one, and finally checks that a duplicated category is rejected.
     */
    @Test
    public void testSDRCategoryEncoder() {
        System.out.println("Testing CategoryEncoder...");
        String[] categories = {"ES", "S1", "S2", "S3", "S4", "S5", "S6", "S7", "S8", "S9", "S10", "S11", "S12", "S13", "S14", "S15", "S16", "S17", "S18", "S19", "GB", "US"};
        SDRCategoryEncoder encoder = SDRCategoryEncoder.builder()
                .n(100).w(10).categoryList(Arrays.asList(categories)).name("foo").forced(true).build();

        // One SDR per listed category plus the <UNKNOWN> entry (scalar 0, checked below).
        Assert.assertEquals(23L, encoder.getSDRs().size());
        Assert.assertEquals(100, encoder.getSDRs().iterator().next().length);

        // A known category encodes to w=10 active bits out of n=100 and round-trips.
        int[] es = encoder.encode("ES");
        Assert.assertEquals(10, ArrayUtils.aggregateArray(es));
        Assert.assertEquals(100, es.length);
        DecodeResult esDecoded = encoder.decode(es);
        Assert.assertEquals("foo", esDecoded.getDescriptions().get(0));
        Assert.assertEquals("ES", esDecoded.getFields().get("foo").getDescription());
        EncoderResult esTopDown = encoder.topDownCompute(es).get(0);
        Assert.assertEquals("ES", esTopDown.getValue());
        Assert.assertEquals(1, esTopDown.getScalar());
        Assert.assertEquals(10, ArrayUtils.aggregateArray(esTopDown.getEncoding()));

        // Every listed category top-down computes back to itself with its own scalar.
        for (String category : categories) {
            EncoderResult topDown = encoder.topDownCompute(encoder.encode(category)).get(0);
            Assert.assertEquals(category, topDown.getValue());
            Assert.assertEquals(Integer.valueOf((int) encoder.getScalars(category).get(0)), topDown.getScalar());
            System.out.println("bucket index =>" + encoder.getBucketIndices(category)[0]);
        }

        // An unseen value maps to the <UNKNOWN> SDR (scalar 0) when the list is fixed.
        int[] unknown = encoder.encode("ASDFLKJLK");
        Assert.assertEquals(10, ArrayUtils.aggregateArray(unknown));
        Assert.assertEquals(100, unknown.length);
        Assert.assertEquals("<UNKNOWN>", encoder.decode(unknown).getFields().get("foo").getDescription());
        EncoderResult unknownTopDown = encoder.topDownCompute(unknown).get(0);
        Assert.assertEquals("<UNKNOWN>", unknownTopDown.getValue());
        Assert.assertEquals(0, unknownTopDown.getScalar());

        int[] us = encoder.encode("US");
        Assert.assertEquals(10, ArrayUtils.aggregateArray(us));
        Assert.assertEquals(100, us.length);
        Assert.assertEquals("US", encoder.decode(us).getFields().get("foo").getDescription());
        EncoderResult usTopDown = encoder.topDownCompute(us).get(0);
        Assert.assertEquals("US", usTopDown.getValue());
        // "US" is the last listed category, so its scalar equals the list length.
        Assert.assertEquals(Integer.valueOf(categories.length), usTopDown.getScalar());
        Assert.assertEquals(10, ArrayUtils.aggregateArray(usTopDown.getEncoding()));

        // null and "" encode to an all-zero vector of full width.
        for (String empty : new String[] {null, ""}) {
            int[] encoded = encoder.encode(empty);
            Assert.assertEquals(0L, ArrayUtils.aggregateArray(encoded));
            Assert.assertEquals(100, encoded.length);
        }

        // Flipping any single bit must not change the decoded category.
        assertDecodesDespiteAnySingleBitFlip(encoder, us, "foo", "US");

        // Two categories ON at once must decode to both of them, in either order.
        // (Previously written as `a.equals(d) && b.equals(d)`, which is always false,
        // so this check could never fail.)
        int[] usOrUnknown = ArrayUtils.or(unknown, us);
        String description = encoder.decode(usOrUnknown).getFields().get("foo").getDescription();
        if (!"US <UNKNOWN>".equals(description) && !"<UNKNOWN> US".equals(description)) {
            // Whatever is left after stripping the expected names is the spurious category.
            String spurious = description.replace("US", "").replace("<UNKNOWN>", "").replace(" ", "");
            System.out.println(String.format("Got: %s instead of US/<UNKNOWN>", description));
            System.out.println(String.format("US: %s", ArrayUtils.intArrayToString(us)));
            System.out.println(String.format("unknown: %s", ArrayUtils.intArrayToString(unknown)));
            System.out.println(String.format("Sum: %s", ArrayUtils.intArrayToString(usOrUnknown)));
            System.out.println(String.format("%s: %s", spurious, ArrayUtils.intArrayToString(encoder.encode(spurious))));
            Assert.fail("Decoding failure");
        }

        // Without a category list, new values get their own SDRs instead of <UNKNOWN>.
        SDRCategoryEncoder growing = SDRCategoryEncoder.builder().n(100).w(10).name("bar").forced(true).build();
        int[] esBar = growing.encode("ES");
        Assert.assertEquals(10, ArrayUtils.aggregateArray(esBar));
        Assert.assertEquals(100, esBar.length);
        DecodeResult esBarDecoded = growing.decode(esBar);
        Assert.assertEquals("bar", esBarDecoded.getDescriptions().get(0));
        Assert.assertEquals("ES", esBarDecoded.getFields().get("bar").getDescription());

        int[] usBar = growing.encode("US");
        Assert.assertEquals(10, ArrayUtils.aggregateArray(usBar));
        Assert.assertEquals(100, usBar.length);
        DecodeResult usBarDecoded = growing.decode(usBar);
        Assert.assertEquals("bar", usBarDecoded.getDescriptions().get(0));
        Assert.assertEquals("US", usBarDecoded.getFields().get("bar").getDescription());

        // Re-encoding an already-seen category yields the same SDR.
        Assert.assertArrayEquals(esBar, growing.encode("ES"));
        Assert.assertArrayEquals(usBar, growing.encode("US"));

        assertDecodesDespiteAnySingleBitFlip(growing, usBar, "bar", "US");

        String barDescription = growing.decode(ArrayUtils.or(usBar, esBar)).getFields().get("bar").getDescription();
        Assert.assertTrue("Got: " + barDescription,
                "US ES".equals(barDescription) || "ES US".equals(barDescription));

        // A category list containing a duplicate must be rejected by the builder.
        List<String> withDuplicate = new ArrayList<>(Arrays.asList(categories));
        withDuplicate.add("ES");
        try {
            SDRCategoryEncoder.builder().n(100).w(10).categoryList(withDuplicate).name("foo").forced(true).build();
            Assert.fail("Did not catch duplicate category in constructor");
        } catch (IllegalArgumentException expected) {
            // Expected: "ES" appears twice in the category list.
        }
    }

    /**
     * An encoder built without a category list assigns each new category the next scalar
     * while learning is enabled, and maps unseen values to {@code <UNKNOWN>} while disabled.
     */
    @Test
    public void testAutoGrow() {
        SDRCategoryEncoder encoder = SDRCategoryEncoder.builder().n(100).w(10).name("foo").forced(true).build();
        int[] output = new int[100]; // new int arrays are zero-filled
        Assert.assertEquals("<UNKNOWN>", encoder.topDownCompute(output).get(0).getValue());

        encoder.encodeIntoArray("catA", output);
        Assert.assertEquals(10, ArrayUtils.aggregateArray(output));
        Assert.assertEquals(1.0d, encoder.getScalars("catA").get(0), 0.0d);
        int[] catA = output.clone();

        encoder.encodeIntoArray("catB", output);
        Assert.assertEquals(10, ArrayUtils.aggregateArray(output));
        Assert.assertEquals(2.0d, encoder.getScalars("catB").get(0), 0.0d);
        int[] catB = output.clone();

        Assert.assertEquals("catA", encoder.topDownCompute(catA).get(0).getValue());
        Assert.assertEquals("catB", encoder.topDownCompute(catB).get(0).getValue());

        // null and "" produce an all-zero vector, which top-down computes to <UNKNOWN>.
        for (String empty : new String[] {null, ""}) {
            encoder.encodeIntoArray(empty, output);
            Assert.assertEquals(0L, ArrayUtils.aggregateArray(output));
            Assert.assertEquals("<UNKNOWN>", encoder.topDownCompute(output).get(0).getValue());
        }

        // With learning off, an unseen category is not added: it encodes as <UNKNOWN>.
        encoder.setLearning(false);
        encoder.encodeIntoArray("catC", output);
        Assert.assertEquals(10, ArrayUtils.aggregateArray(output));
        Assert.assertEquals(0.0d, encoder.getScalars("catC").get(0), 0.0d);
        Assert.assertEquals("<UNKNOWN>", encoder.topDownCompute(output).get(0).getValue());

        // With learning back on, the same value is appended as the next category.
        encoder.setLearning(true);
        encoder.encodeIntoArray("catC", output);
        Assert.assertEquals(10, ArrayUtils.aggregateArray(output));
        Assert.assertEquals(3.0d, encoder.getScalars("catC").get(0), 0.0d);
        Assert.assertEquals("catC", encoder.topDownCompute(output).get(0).getValue());
    }

    /**
     * Flips each bit of a copy of {@code encoding} one at a time and asserts that decoding
     * the perturbed vector still yields {@code expected} for {@code fieldName}.
     *
     * <p>This replaces a single unseeded random flip. That version was not reproducible on
     * failure and, by bounding {@code nextInt} with {@code width - 1}, never touched the last
     * bit. The caller's array is left unchanged.
     *
     * @param encoder   encoder used to decode
     * @param encoding  clean encoding of {@code expected}; not modified
     * @param fieldName key of the field in the decode result
     * @param expected  description the decode must still report
     */
    private static void assertDecodesDespiteAnySingleBitFlip(
            SDRCategoryEncoder encoder, int[] encoding, String fieldName, String expected) {
        int[] perturbed = encoding.clone();
        for (int i = 0; i < perturbed.length; i++) {
            perturbed[i] = 1 - perturbed[i];
            Assert.assertEquals("decode with bit " + i + " flipped", expected,
                    encoder.decode(perturbed).getFields().get(fieldName).getDescription());
            perturbed[i] = 1 - perturbed[i]; // restore before the next position
        }
    }
}
