package org.numenta.nupic.network;

import io.cortical.rest.DefaultValues;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import org.junit.Assert;
import org.junit.Test;
import org.numenta.nupic.Parameters;
import org.numenta.nupic.algorithms.Anomaly;
import org.numenta.nupic.datagen.ResourceLocator;
import org.numenta.nupic.encoders.MultiEncoder;
import org.numenta.nupic.network.sensor.FileSensor;
import org.numenta.nupic.network.sensor.Sensor;
import org.numenta.nupic.network.sensor.SensorParams;
import org.numenta.nupic.research.SpatialPooler;
import org.numenta.nupic.research.TemporalMemory;
import org.numenta.nupic.util.MersenneTwister;
import rx.Observer;
import rx.Subscriber;

/* JADX WARN: Classes with same name are omitted:
  input_file:org/numenta/nupic/examples/cortical_io/breakingnews/breaking-news-demo-1.0.0.jar:org/numenta/nupic/network/RegionTest.class
  input_file:org/numenta/nupic/examples/cortical_io/foxeats/FoxEatsDemo.jar:org/numenta/nupic/network/RegionTest.class
 */
/* loaded from: input_file:org/numenta/nupic/examples/napi/hotgym/NAPI-Hotgym-Demo-1.0.jar:org/numenta/nupic/network/RegionTest.class */
/**
 * Tests for {@link Region}: closing, layer registration, resets, halting, and multi-layer
 * assembly with and without a sensor.
 */
public class RegionTest {
    /** Set by the testHalt subscriber and polled from a separate thread, so it must be volatile. */
    volatile boolean isHalted;
    /** Emission counters for the three subscribers in test2LayerAssemblyWithSensor. */
    int idx0 = 0;
    int idx1 = 0;
    int idx2 = 0;

    /*
     * NOTE(review): these two layer names look like decompiler artifacts - compile-time String
     * constants attributed to the unrelated io.cortical.rest.DefaultValues class. Judging by the
     * neighbouring layer names ("1", ?, "3", "4") the values are presumably "2" and "5"; confirm
     * before replacing them with literals. They are aliased here so the tests read by intent.
     */
    /** Name of the anomaly layer wired between layers "1" and "3". */
    private static final String LAYER_2 = DefaultValues.DEF_VALUE_PLOT_SCALAR;
    /** Name of a layer added to an already-closed region, which must be rejected. */
    private static final String EXTRA_LAYER = DefaultValues.DEF_VALUE_MAX_CONTEXTS_COUNT;

    @Test
    public void testClose() {
        Parameters params = dayDemoParameters();
        Network network = Network.create("test network", params)
            .add(Network.createRegion("r1")
                .add(Network.createLayer("4", params)
                    .add(MultiEncoder.builder().name("").build()))
                .close());
        Region region = network.lookup("r1");
        Assert.assertTrue(region.isClosed());
        assertRejectsLayerWhenClosed(region, params);
    }

    @Test
    public void testResetMethod() {
        Parameters params = NetworkTestHarness.getParameters();

        Region withTm = Network.createRegion("r1");
        withTm.add(Network.createLayer("l1", params).add(new TemporalMemory()));
        resetOrFail(withTm);
        Assert.assertTrue(withTm.lookup("l1").hasTemporalMemory());

        Region withoutTm = Network.createRegion("r1");
        withoutTm.add(Network.createLayer("l1", params).add(new SpatialPooler()));
        resetOrFail(withoutTm);
        Assert.assertFalse(withoutTm.lookup("l1").hasTemporalMemory());
    }

    @Test
    public void testResetRecordNum() {
        Parameters params = NetworkTestHarness.getParameters();
        Region region = Network.createRegion("r1");
        region.add(Network.createLayer("l1", params).add(new TemporalMemory()));

        FailureRecordingSubscriber printer = new FailureRecordingSubscriber() {
            @Override
            public void onNext(Inference inference) {
                System.out.println("output = " + Arrays.toString(inference.getSDR()));
            }
        };
        region.observe().subscribe(printer);

        region.compute(new int[] {2, 3, 4});
        region.compute(new int[] {2, 3, 4});
        printer.rethrowFailure();

        // Two computes leave the record number at 1; resetRecordNum() brings it back to 0.
        Assert.assertEquals(1L, region.lookup("l1").getRecordNum());
        region.resetRecordNum();
        Assert.assertEquals(0L, region.lookup("l1").getRecordNum());
    }

    @Test
    public void testAutomaticClose() {
        Parameters params = dayDemoParameters();
        Region region = Network.create("test network", params)
            .add(Network.createRegion("r1")
                .add(Network.createLayer("4", params)
                    .add(MultiEncoder.builder().name("").build())))
            .lookup("r1");

        // start() is expected to close the region implicitly.
        region.start();
        Assert.assertTrue(region.isClosed());
        assertRejectsLayerWhenClosed(region, params);
    }

    @Test
    public void testAdd() {
        Parameters params = dayDemoParameters();
        Region region = Network.create("test network", params)
            .add(Network.createRegion("r1")
                .add(Network.createLayer("4", params)
                    .add(MultiEncoder.builder().name("").build())))
            .lookup("r1");

        Layer<?> layer = region.lookup("4");
        Assert.assertNotNull(layer);
        Assert.assertEquals("r1:4", layer.getName());

        try {
            region.add(Network.createLayer("4", params));
            Assert.fail();
        } catch (IllegalArgumentException e) {
            Assert.assertEquals("A Layer with the name: 4 has already been added to this Region.", e.getMessage());
        }
    }

    @Test
    public void testHalt() {
        Parameters params = dayDemoParameters();
        final Region region = Network.create("test network", params)
            .add(createRegionWithUpperLayers(params)
                .add(Network.createLayer("4", params)
                    .add(Sensor.create(FileSensor::create, daysOfWeekCsv()))
                    .add(new SpatialPooler()))
                .connect("1", LAYER_2)
                .connect(LAYER_2, "3")
                .connect("3", "4"))
            .lookup("r1");

        region.observe().subscribe(new Subscriber<Inference>() {
            private int seq = 0;

            @Override
            public void onCompleted() {
                System.out.println("onCompleted() called");
            }

            @Override
            public void onError(Throwable th) {
                // Reported only: this test checks that halt() lets the layer thread terminate.
                th.printStackTrace();
            }

            @Override
            public void onNext(Inference inference) {
                // Ask for the halt once the third inference arrives.
                if (seq == 2) {
                    RegionTest.this.isHalted = true;
                }
                seq++;
                System.out.println("output: " + Arrays.toString(inference.getSDR()));
            }
        });

        // halt() is issued from its own thread once the subscriber raises the flag.
        Thread halter = new Thread(() -> {
            while (!isHalted) {
                try {
                    Thread.sleep(1L);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    break;
                }
            }
            region.halt();
        });
        // Daemon so a run that never reaches the third inference cannot leave this thread spinning.
        halter.setDaemon(true);
        halter.start();

        region.start();
        joinLayerThread(region, "4");
    }

    @Test
    public void testInputDimensionsAutomaticallyInferredFromEncoderWidth() {
        Parameters params = dayDemoParameters();
        params.setParameterByKey(Parameters.KEY.INPUT_DIMENSIONS, new int[] {40, 40});
        Network.create("test network", params)
            .add(Network.createRegion("r1")
                .add(Network.createLayer("4", params)
                    .add(MultiEncoder.builder().name("").build()))
                .close());
        // The 40x40 value set above should have been replaced by the inferred encoder width.
        int[] inputDimensions = (int[]) params.getParameterByKey(Parameters.KEY.INPUT_DIMENSIONS);
        Assert.assertTrue(Arrays.equals(new int[] {8}, inputDimensions));
    }

    @Test
    public void testEncoderPassesUpToTopLayer() {
        Parameters params = dayDemoParameters();
        Region region = Network.create("test network", params)
            .add(createRegionWithUpperLayers(params)
                .add(Network.createLayer("4", params)
                    .add(new SpatialPooler())
                    .add(MultiEncoder.builder().name("").build())))
            .lookup("r1");
        region.connect("1", LAYER_2).connect(LAYER_2, "3").connect("3", "4");
        // The encoder attached to layer "4" should be reachable from layer "1" after connecting.
        Assert.assertNotNull(region.lookup("1").getEncoder());
    }

    @Test
    public void testMultiLayerAssemblyNoSensor() {
        Parameters params = dayDemoParameters();
        params.setParameterByKey(Parameters.KEY.COLUMN_DIMENSIONS, new int[] {30});
        params.setParameterByKey(Parameters.KEY.SYN_PERM_INACTIVE_DEC, 0.1);
        params.setParameterByKey(Parameters.KEY.SYN_PERM_ACTIVE_INC, 0.1);
        params.setParameterByKey(Parameters.KEY.SYN_PERM_TRIM_THRESHOLD, 0.05);
        params.setParameterByKey(Parameters.KEY.SYN_PERM_CONNECTED, 0.4);
        params.setParameterByKey(Parameters.KEY.MAX_BOOST, 10.0);
        params.setParameterByKey(Parameters.KEY.DUTY_CYCLE_PERIOD, 7);

        Region region = Network.create("test network", params)
            .add(createRegionWithUpperLayers(params)
                .add(Network.createLayer("4", params)
                    .add(new SpatialPooler())
                    .add(MultiEncoder.builder().name("").build()))
                .connect("1", LAYER_2)
                .connect(LAYER_2, "3")
                .connect("3", "4"))
            .lookup("r1");
        // Layer "3" is handed layer "4"'s Connections.
        region.lookup("3").using(region.lookup("4").getConnections());

        FailureRecordingSubscriber training = new FailureRecordingSubscriber() {
            @Override
            public void onNext(Inference inference) {
            }
        };
        region.observe().subscribe(training);

        Map<String, Object> input = new HashMap<>();
        for (int cycle = 0; cycle < 400; cycle++) {
            for (double day = 0.0; day < 7.0; day++) {
                input.put("dayOfWeek", day);
                region.compute(input);
            }
        }

        FailureRecordingSubscriber prediction = new FailureRecordingSubscriber() {
            @Override
            public void onNext(Inference inference) {
                Number predicted = (Number) inference.getClassification("dayOfWeek").getMostProbableValue(1);
                Assert.assertEquals(6L, (int) Math.rint(predicted.doubleValue()));
            }
        };
        region.observe().subscribe(prediction);

        // After training on days 0..6, input 5 is expected to predict 6.
        input.put("dayOfWeek", 5.0);
        region.compute(input);

        // NOTE(review): assumes compute() delivers to subscribers before returning (no layer thread here).
        training.rethrowFailure();
        prediction.rethrowFailure();
    }

    @Test
    public void testMultiLayerAssemblyWithSensor() {
        Parameters params = dayDemoParameters();
        Region region = Network.create("test network", params)
            .add(createRegionWithUpperLayers(params)
                .add(Network.createLayer("4", params)
                    .add(Sensor.create(FileSensor::create, daysOfWeekCsv()))
                    .add(new SpatialPooler()))
                .connect("1", LAYER_2)
                .connect(LAYER_2, "3")
                .connect("3", "4"))
            .lookup("r1");

        // Expected classifier stats for the first seven inferences, in arrival order.
        final double[][] expectedStats = {
            {1.0},
            {1.0},
            {0.5, 0.5},
            {0.3333333333333333, 0.3333333333333333, 0.3333333333333333},
            {0.25, 0.25, 0.25, 0.25},
            {0.2, 0.2, 0.2, 0.2, 0.2},
            {0.16666666666666666, 0.16666666666666666, 0.16666666666666666,
                0.16666666666666666, 0.16666666666666666, 0.16666666666666666}
        };

        FailureRecordingSubscriber checker = new FailureRecordingSubscriber() {
            private int idx = 0;

            @Override
            public void onNext(Inference inference) {
                // Only the first seven inferences are checked; later ones are ignored.
                if (idx < expectedStats.length) {
                    Assert.assertEquals(1.0, inference.getAnomalyScore(), 0.0);
                    double[] stats = inference.getClassification("dayOfWeek").getStats(1);
                    Assert.assertTrue(Arrays.equals(expectedStats[idx], stats));
                }
                idx++;
            }
        };
        region.observe().subscribe(checker);

        region.start();
        joinLayerThread(region, "4");
        checker.rethrowFailure();
    }

    @Test
    public void test2LayerAssemblyWithSensor() {
        Parameters params = dayDemoParameters();
        Region region = Network.create("test network", params)
            .add(Network.createRegion("r1")
                .add(Network.createLayer("2/3", params)
                    .alterParameter(Parameters.KEY.AUTO_CLASSIFY, Boolean.TRUE)
                    .add(new TemporalMemory()))
                .add(Network.createLayer("4", params)
                    .add(Sensor.create(FileSensor::create, daysOfWeekCsv()))
                    .add(new SpatialPooler()))
                .connect("2/3", "4"))
            .lookup("r1");

        // Expected encoder output for each of the seven records, in arrival order.
        final int[][] expectedEncodings = {
            {1, 1, 0, 0, 0, 0, 0, 1},
            {1, 1, 1, 0, 0, 0, 0, 0},
            {0, 1, 1, 1, 0, 0, 0, 0},
            {0, 0, 1, 1, 1, 0, 0, 0},
            {0, 0, 0, 1, 1, 1, 0, 0},
            {0, 0, 0, 0, 1, 1, 1, 0},
            {0, 0, 0, 0, 0, 1, 1, 1}
        };

        FailureRecordingSubscriber layer4Checker = new FailureRecordingSubscriber() {
            @Override
            public void onNext(Inference inference) {
                Assert.assertTrue(Arrays.equals(expectedEncodings[RegionTest.this.idx0++], inference.getEncoding()));
            }
        };
        FailureRecordingSubscriber layer23Checker = new FailureRecordingSubscriber() {
            @Override
            public void onNext(Inference inference) {
                Assert.assertTrue(Arrays.equals(expectedEncodings[RegionTest.this.idx1++], inference.getEncoding()));
            }
        };
        FailureRecordingSubscriber regionChecker = new FailureRecordingSubscriber() {
            @Override
            public void onNext(Inference inference) {
                Assert.assertTrue(Arrays.equals(expectedEncodings[RegionTest.this.idx2++], inference.getEncoding()));
            }
        };
        region.lookup("4").observe().subscribe(layer4Checker);
        region.lookup("2/3").observe().subscribe(layer23Checker);
        region.observe().subscribe(regionChecker);

        region.start();
        joinLayerThread(region, "4");

        layer4Checker.rethrowFailure();
        layer23Checker.rethrowFailure();
        regionChecker.rethrowFailure();
        Assert.assertEquals(7L, this.idx0);
        Assert.assertEquals(7L, this.idx1);
        Assert.assertEquals(7L, this.idx2);
    }

    @Test
    public void testAlgorithmRepetitionDetection() {
        Parameters params = dayDemoParameters();

        // Each algorithm lives in exactly one layer.
        Region distinct = Network.create("test network", params)
            .add(Network.createRegion("r1")
                .add(Network.createLayer("2/3", params)
                    .alterParameter(Parameters.KEY.AUTO_CLASSIFY, Boolean.TRUE)
                    .add(new TemporalMemory()))
                .add(Network.createLayer("4", params)
                    .add(Sensor.create(FileSensor::create, daysOfWeekCsv()))
                    .add(new SpatialPooler()))
                .connect("2/3", "4"))
            .lookup("r1");
        Assert.assertTrue(distinct.layersDistinct);
        // Per the masks asserted below: 1 = SpatialPooler, 2 = TemporalMemory, 4 = auto-classify.
        // Clearing all three bits must leave the accumulated region flags empty.
        Assert.assertEquals(0L, (byte) (distinct.flagAccumulator ^ (1 | 2 | 4)));
        Assert.assertEquals(6L, distinct.lookup("2/3").getMask());
        Assert.assertEquals(1L, distinct.lookup("4").getMask());

        // Layer "4" now also holds a TemporalMemory, overlapping layer "2/3".
        Region overlapping = Network.create("test network", params)
            .add(Network.createRegion("r1")
                .add(Network.createLayer("2/3", params)
                    .alterParameter(Parameters.KEY.AUTO_CLASSIFY, Boolean.TRUE)
                    .add(new TemporalMemory()))
                .add(Network.createLayer("4", params)
                    .add(Sensor.create(FileSensor::create, daysOfWeekCsv()))
                    .add(new TemporalMemory())
                    .add(new SpatialPooler()))
                .connect("2/3", "4"))
            .lookup("r1");
        Assert.assertFalse(overlapping.layersDistinct);
        Assert.assertEquals(6L, overlapping.lookup("2/3").getMask());
        Assert.assertEquals(3L, overlapping.lookup("4").getMask());
    }

    /** Day-demo parameters with a fixed-seed RNG so that runs are reproducible. */
    private static Parameters dayDemoParameters() {
        Parameters params = NetworkTestHarness.getParameters()
            .union(NetworkTestHarness.getDayDemoTestEncoderParams());
        params.setParameterByKey(Parameters.KEY.RANDOM, new MersenneTwister(42L));
        return params;
    }

    /** Anomaly configuration selecting {@code Anomaly.Mode.PURE}. */
    private static Map<String, Object> pureAnomalyParams() {
        Map<String, Object> anomalyParams = new HashMap<>();
        anomalyParams.put(Anomaly.KEY_MODE, Anomaly.Mode.PURE);
        return anomalyParams;
    }

    /** Sensor parameters pointing a FileSensor at the bundled days-of-week.csv resource. */
    private static SensorParams daysOfWeekCsv() {
        return SensorParams.create(
            (Supplier<SensorParams.Keys.Args>) SensorParams.Keys::path, "", ResourceLocator.path("days-of-week.csv"));
    }

    /**
     * Creates region "r1" holding layer "1" (auto-classify), {@link #LAYER_2} (pure anomaly) and
     * layer "3" (temporal memory). Callers add layer "4" and wire the connections themselves.
     */
    private static Region createRegionWithUpperLayers(Parameters params) {
        return Network.createRegion("r1")
            .add(Network.createLayer("1", params)
                .alterParameter(Parameters.KEY.AUTO_CLASSIFY, Boolean.TRUE))
            .add(Network.createLayer(LAYER_2, params)
                .add(Anomaly.create(pureAnomalyParams())))
            .add(Network.createLayer("3", params)
                .add(new TemporalMemory()));
    }

    /** Asserts that adding another layer to a closed region fails with IllegalStateException. */
    private static void assertRejectsLayerWhenClosed(Region region, Parameters params) {
        try {
            region.add(Network.createLayer(EXTRA_LAYER, params));
            Assert.fail();
        } catch (IllegalStateException e) {
            Assert.assertEquals("Cannot add Layers when Region has already been closed.", e.getMessage());
        }
    }

    /** Calls {@code reset()}, turning any exception into a test failure that keeps its cause. */
    private static void resetOrFail(Region region) {
        try {
            region.reset();
        } catch (Exception e) {
            throw new AssertionError("Region.reset() threw", e);
        }
    }

    /** Waits for the named layer's thread to finish; being interrupted fails the test. */
    private static void joinLayerThread(Region region, String layerName) {
        try {
            region.lookup(layerName).getLayerThread().join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new AssertionError("Interrupted while waiting for layer " + layerName, e);
        }
    }

    /**
     * Subscriber that stores the first error it receives instead of printing it.
     *
     * <p>RxJava 1's {@code subscribe(Subscriber)} wraps subscribers in a SafeSubscriber, which
     * turns an exception thrown from {@code onNext} (including a JUnit AssertionError) into an
     * {@code onError} call rather than propagating it to the test thread. Tests therefore call
     * {@link #rethrowFailure()} after the stream has been driven.
     */
    private abstract static class FailureRecordingSubscriber extends Subscriber<Inference> {
        private final AtomicReference<Throwable> failure = new AtomicReference<>();

        @Override
        public void onCompleted() {
        }

        @Override
        public void onError(Throwable th) {
            failure.compareAndSet(null, th);
        }

        /** Rethrows the first recorded error, if any, so that it fails the calling test. */
        void rethrowFailure() {
            Throwable th = failure.get();
            if (th instanceof Error) {
                throw (Error) th;
            }
            if (th instanceof RuntimeException) {
                throw (RuntimeException) th;
            }
            if (th != null) {
                throw new AssertionError(th);
            }
        }
    }
}
