Skip to content

Commit ffc89d0

Browse files
committed
Added check for geo and cood types to disable bucket formation for classifier, added error conditions
1 parent 5955282 commit ffc89d0

4 files changed

Lines changed: 173 additions & 14 deletions

File tree

src/main/java/org/numenta/nupic/network/Layer.java

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -816,7 +816,11 @@ public void start() {
816816

817817
this.encoder = encoder == null ? sensor.getEncoder() : encoder;
818818

819-
completeDispatch((T)new int[] {});
819+
try {
820+
completeDispatch((T)new int[] {});
821+
}catch(Exception e) {
822+
notifyError(e);
823+
}
820824

821825
(LAYER_THREAD = new Thread("Sensor Layer [" + getName() + "] Thread") {
822826
public void run() {
@@ -832,11 +836,20 @@ public void run() {
832836
return false;
833837
}
834838

839+
if(Thread.currentThread().isInterrupted()) {
840+
notifyError(new RuntimeException("Unknown Exception while filtering input"));
841+
}
842+
835843
return true;
836844
}).forEach(intArray -> {
837845
((ManualInput)Layer.this.factory.inference).encoding(intArray);
846+
838847
Layer.this.compute((T)intArray);
839848

849+
if(Thread.currentThread().isInterrupted()) {
850+
notifyError(new RuntimeException("Unknown Exception during compute"));
851+
}
852+
840853
// Notify all downstream observers that the stream is closed
841854
if(!sensor.hasNext()) {
842855
notifyComplete();
@@ -1194,6 +1207,16 @@ void notifyComplete() {
11941207
publisher.onCompleted();
11951208
}
11961209

1210+
void notifyError(Exception e) {
1211+
for(Observer<Inference> o : subscribers) {
1212+
o.onError(e);
1213+
}
1214+
for(Observer<Inference> o : observers) {
1215+
o.onError(e);
1216+
}
1217+
publisher.onError(e);
1218+
}
1219+
11971220
/**
11981221
* <p>
11991222
* Returns the content mask used to indicate what algorithm
@@ -1305,9 +1328,13 @@ private Map<Class<T>, Observable<ManualInput>> createDispatchMap() {
13051328
private Observable<ManualInput> mapEncoderBuckets(Observable<ManualInput> sequence) {
13061329
if(hasSensor()) {
13071330
if(getSensor().getMetaInfo().getFieldTypes().stream().anyMatch(
1308-
ft -> { return ft == FieldMetaType.SARR || ft == FieldMetaType.DARR; })) {
1331+
ft -> { return ft == FieldMetaType.SARR ||
1332+
ft == FieldMetaType.DARR ||
1333+
ft == FieldMetaType.COORD ||
1334+
ft == FieldMetaType.GEO; })) {
13091335
if(autoCreateClassifiers) {
1310-
throw new IllegalStateException("Cannot autoclassify with raw array input... Remove auto classify setting.");
1336+
throw new IllegalStateException("Cannot autoclassify with raw array input or " +
1337+
" Coordinate based encoders... Remove auto classify setting.");
13111338
}
13121339
return sequence;
13131340
}
@@ -1680,7 +1707,7 @@ public ManualInput call(String[] t1) {
16801707
sdr[i] = Integer.parseInt(t1[i]);
16811708
}
16821709

1683-
return inference.sdr(sdr).layerInput(sdr);
1710+
return inference.recordNum(getRecordNum()).sdr(sdr).layerInput(sdr);
16841711
}
16851712
});
16861713
}
@@ -1718,7 +1745,7 @@ public ManualInput call(Map t1) {
17181745

17191746
doEncoderBucketMapping(inference, t1);
17201747

1721-
return inference.layerInput(t1);
1748+
return inference.recordNum(getRecordNum()).layerInput(t1);
17221749
}
17231750
});
17241751
}
@@ -1740,7 +1767,7 @@ public Observable<ManualInput> call(Observable<int[]> t1) {
17401767
@Override
17411768
public ManualInput call(int[] t1) {
17421769
// Indicates a value that skips the encoding step
1743-
return inference.sdr(t1).layerInput(t1);
1770+
return inference.recordNum(getRecordNum()).sdr(t1).layerInput(t1);
17441771
}
17451772
});
17461773
}
@@ -1767,7 +1794,7 @@ public ManualInput call(ManualInput t1) {
17671794
swapped = true;
17681795
}
17691796
// Indicates a value that skips the encoding step
1770-
return inference.sdr(t1.getSDR()).recordNum(t1.getRecordNum()).layerInput(t1);
1797+
return inference.recordNum(getRecordNum()).sdr(t1.getSDR()).recordNum(t1.getRecordNum()).layerInput(t1);
17711798
}
17721799
});
17731800
}

src/test/java/org/numenta/nupic/network/LayerTest.java

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,7 @@ public void testLayerWithObservableInput() {
401401
fail();
402402
}
403403
}
404-
404+
405405
@Test
406406
public void testLayerWithObservableInputIntegerArray() {
407407
Publisher manual = Publisher.builder()
@@ -784,15 +784,15 @@ public void testBasicSetup_TemporalMemory_MANUAL_MODE() {
784784
int[] expected1 = { 1, 5, 11, 12, 13 };
785785
int[] expected2 = { 2, 3, 11, 12, 13, 14 };
786786
int[] expected3 = { 2, 3, 8, 9, 12, 17, 18 };
787-
int[] expected4 = { 2, 3, 8, 12, 17, 18 };
788-
int[] expected5 = { 2, 7, 8, 9, 17, 18, 19 };
789-
int[] expected6 = { 1, 7, 8, 9, 17, 18 };
790-
int[] expected7 = { 1, 5, 7, 11, 12, 16 };
787+
int[] expected4 = { 2, 3, 7, 8, 9, 12, 17, 18, 19 };
788+
int[] expected5 = { 1, 2, 3, 7, 8, 9, 12, 17, 18, 19 };
789+
int[] expected6 = { 1, 5, 7, 8, 9, 11, 12, 16, 17, 18, 19 };
790+
int[] expected7 = { 1, 5, 7, 8, 9, 11, 12, 16, 17, 18 };
791791
final int[][] expecteds = { expected1, expected2, expected3, expected4, expected5, expected6, expected7 };
792792

793793
Layer<int[]> l = new Layer<>(p, null, null, new TemporalMemory(), null, null);
794794

795-
int timeUntilStable = 400;
795+
int timeUntilStable = 600;
796796

797797
l.subscribe(new Observer<Inference>() {
798798
int test = 0;
@@ -803,7 +803,8 @@ public void testBasicSetup_TemporalMemory_MANUAL_MODE() {
803803
@Override
804804
public void onNext(Inference output) {
805805
if(seq / 7 >= timeUntilStable) {
806-
//System.out.println("seq: " + (seq) + " --> " + (test) + " output = " + Arrays.toString(output.getSDR()));
806+
// System.out.println("seq: " + (seq) + " --> " + (test) + " output = " + Arrays.toString(output.getSDR()) +
807+
// "\t\t\t\t\t exp = " + Arrays.toString(expecteds[test]));
807808
assertTrue(Arrays.equals(expecteds[test], output.getSDR()));
808809
}
809810

src/test/java/org/numenta/nupic/network/NetworkTest.java

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@
4646
import org.numenta.nupic.datagen.ResourceLocator;
4747
import org.numenta.nupic.encoders.MultiEncoder;
4848
import org.numenta.nupic.network.sensor.FileSensor;
49+
import org.numenta.nupic.network.sensor.HTMSensor;
50+
import org.numenta.nupic.network.sensor.ObservableSensor;
51+
import org.numenta.nupic.network.sensor.Publisher;
4952
import org.numenta.nupic.network.sensor.Sensor;
5053
import org.numenta.nupic.network.sensor.SensorParams;
5154
import org.numenta.nupic.network.sensor.SensorParams.Keys;
@@ -636,4 +639,129 @@ public void testThreadedStartFlagging() {
636639
}
637640
}
638641

642+
double anomaly = 1;
643+
boolean completed = false;
644+
@Test
645+
public void testObservableWithCoordinateEncoder() {
646+
Publisher manual = Publisher.builder()
647+
.addHeader("timestamp,consumption,location")
648+
.addHeader("datetime,float,geo")
649+
.addHeader("T,,").build();
650+
651+
Sensor<ObservableSensor<String[]>> sensor = Sensor.create(
652+
ObservableSensor::create, SensorParams.create(Keys::obs, "", manual));
653+
654+
Parameters p = NetworkTestHarness.getParameters().copy();
655+
p = p.union(NetworkTestHarness.getGeospatialTestEncoderParams());
656+
p.setParameterByKey(KEY.RANDOM, new MersenneTwister(42));
657+
658+
HTMSensor<ObservableSensor<String[]>> htmSensor = (HTMSensor<ObservableSensor<String[]>>)sensor;
659+
660+
Network network = Network.create("test network", p)
661+
.add(Network.createRegion("r1")
662+
.add(Network.createLayer("1", p)
663+
.add(Anomaly.create())
664+
.add(new TemporalMemory())
665+
.add(new SpatialPooler())
666+
.add(htmSensor)));
667+
668+
network.start();
669+
670+
network.observe().subscribe(new Observer<Inference>() {
671+
@Override public void onCompleted() {
672+
assertEquals(0, anomaly, 0);
673+
completed = true;
674+
}
675+
@Override public void onError(Throwable e) { e.printStackTrace(); }
676+
@Override public void onNext(Inference output) {
677+
//System.out.println(output.getRecordNum() + ": input = " + Arrays.toString(output.getEncoding()));//output = " + Arrays.toString(output.getSDR()) + ", " + output.getAnomalyScore());
678+
if(output.getAnomalyScore() < anomaly) {
679+
anomaly = output.getAnomalyScore();
680+
System.out.println("anomaly = " + anomaly);
681+
}
682+
}
683+
});
684+
685+
int x = 0;
686+
for(int i = 0;i < 100;i++) {
687+
x = i % 10;
688+
manual.onNext("7/12/10 13:10,35.3,40.6457;-73.7" + x + "692;" + x); //5 = meters per second
689+
}
690+
691+
manual.onComplete();
692+
693+
Layer<?> l = network.lookup("r1").lookup("1");
694+
try {
695+
l.getLayerThread().join();
696+
}catch(Exception e) {
697+
e.printStackTrace();
698+
}
699+
700+
assertTrue(completed);
701+
702+
}
703+
704+
String errorMessage = null;
705+
@Test
706+
public void testObservableWithCoordinateEncoder_NEGATIVE() {
707+
Publisher manual = Publisher.builder()
708+
.addHeader("timestamp,consumption,location")
709+
.addHeader("datetime,float,geo")
710+
.addHeader("T,,").build();
711+
712+
Sensor<ObservableSensor<String[]>> sensor = Sensor.create(
713+
ObservableSensor::create, SensorParams.create(Keys::obs, "", manual));
714+
715+
Parameters p = NetworkTestHarness.getParameters().copy();
716+
p = p.union(NetworkTestHarness.getGeospatialTestEncoderParams());
717+
p.setParameterByKey(KEY.RANDOM, new MersenneTwister(42));
718+
719+
HTMSensor<ObservableSensor<String[]>> htmSensor = (HTMSensor<ObservableSensor<String[]>>)sensor;
720+
721+
Network network = Network.create("test network", p)
722+
.add(Network.createRegion("r1")
723+
.add(Network.createLayer("1", p)
724+
.alterParameter(KEY.AUTO_CLASSIFY, Boolean.TRUE)
725+
.add(Anomaly.create())
726+
.add(new TemporalMemory())
727+
.add(new SpatialPooler())
728+
.add(htmSensor)));
729+
730+
network.observe().subscribe(new Observer<Inference>() {
731+
@Override public void onCompleted() {
732+
//Should never happen here.
733+
assertEquals(0, anomaly, 0);
734+
completed = true;
735+
}
736+
@Override public void onError(Throwable e) {
737+
errorMessage = e.getMessage();
738+
network.halt();
739+
}
740+
@Override public void onNext(Inference output) {}
741+
});
742+
743+
network.start();
744+
745+
int x = 0;
746+
for(int i = 0;i < 100;i++) {
747+
x = i % 10;
748+
manual.onNext("7/12/10 13:10,35.3,40.6457;-73.7" + x + "692;" + x); //1st "x" is attempt to vary coords, 2nd "x" = meters per second
749+
}
750+
751+
manual.onComplete();
752+
753+
Layer<?> l = network.lookup("r1").lookup("1");
754+
try {
755+
l.getLayerThread().join();
756+
}catch(Exception e) {
757+
assertEquals(InterruptedException.class, e.getClass());
758+
}
759+
760+
// Assert onNext condition never gets set
761+
assertFalse(completed);
762+
assertEquals("Cannot autoclassify with raw array input or " +
763+
"Coordinate based encoders... Remove auto classify setting.", errorMessage);
764+
}
765+
766+
639767
}

vocabulary.dictionary

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,3 +202,6 @@ keyframes
202202
popup
203203
affero
204204
deselected
205+
coords
206+
expecteds
207+
exp

0 commit comments

Comments
 (0)