Skip to content

Commit b5a54fc

Browse files
committed
Coerce lists/arrays to the expected ArrayParameter subtype in ParameterSpace
When the schema specifies a subclass of ArrayParameter (e.g. Sequence), _setitem_value was hardcoding ArrayParameter as the coercion target. This meant passing a list of numpy arrays as spike_times would produce ArrayParameter elements rather than Sequence elements, causing .get() to fail. Fix by using expected_dtype in all three coercion branches. Fixes NeuralEnsemble#709
1 parent 693a260 commit b5a54fc

2 files changed

Lines changed: 35 additions & 3 deletions

File tree

pyNN/parameters.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -350,14 +350,14 @@ def _setitem_value(self, name, value):
350350
valid_parameter_names=self.schema.keys())
351351
if issubclass(expected_dtype, ArrayParameter) and isinstance(value, Sized):
352352
if len(value) == 0:
353-
value = ArrayParameter([])
353+
value = expected_dtype([])
354354
elif not isinstance(value[0], ArrayParameter):
355355
# may be a more generic way to do it, but for now this special-casing
356356
# seems like the most robust approach
357357
if isinstance(value[0], Sized): # e.g. list of tuples
358-
value = type(value)([ArrayParameter(x) for x in value])
358+
value = type(value)([expected_dtype(x) for x in value])
359359
else:
360-
value = ArrayParameter(value)
360+
value = expected_dtype(value)
361361
try:
362362
self._parameters[name] = LazyArray(value, shape=self._shape,
363363
dtype=expected_dtype)

test/unittests/test_parameters.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,38 @@ def test_create_with_list_of_lists(self):
482482
assert_array_equal(ps['a'], np.array(
483483
[Sequence([1, 2, 3]), Sequence([4, 5, 6])], dtype=Sequence))
484484

485+
def test_create_with_plain_list_produces_sequence(self):
486+
schema = {'a': Sequence}
487+
ps = ParameterSpace({'a': [1, 2, 3]}, schema, shape=(2,))
488+
ps.evaluate()
489+
result = ps['a']
490+
assert type(result[0]) == Sequence
491+
assert_array_equal(result, np.array([Sequence([1, 2, 3]), Sequence([1, 2, 3])], dtype=Sequence))
492+
493+
def test_create_with_numpy_array_produces_sequence(self):
494+
schema = {'a': Sequence}
495+
ps = ParameterSpace({'a': np.array([1, 2, 3])}, schema, shape=(2,))
496+
ps.evaluate()
497+
result = ps['a']
498+
assert type(result[0]) == Sequence
499+
assert_array_equal(result, np.array([Sequence([1, 2, 3]), Sequence([1, 2, 3])], dtype=Sequence))
500+
501+
def test_create_with_list_of_numpy_arrays_produces_sequences(self):
502+
schema = {'a': Sequence}
503+
ps = ParameterSpace({'a': [np.array([1, 2]), np.array([3, 4])]}, schema, shape=(2,))
504+
ps.evaluate()
505+
result = ps['a']
506+
assert type(result[0]) == Sequence
507+
assert type(result[1]) == Sequence
508+
assert_array_equal(result, np.array([Sequence([1, 2]), Sequence([3, 4])], dtype=Sequence))
509+
510+
def test_create_with_empty_list_produces_sequence(self):
511+
schema = {'a': Sequence}
512+
ps = ParameterSpace({'a': []}, schema, shape=(2,))
513+
ps.evaluate()
514+
result = ps['a']
515+
assert type(result[0]) == Sequence
516+
485517
def test_keys(self):
486518
ps = ParameterSpace({'a': [2, 3, 5, 8, 13], 'b': 7, 'c': lambda i: 3 * i + 2}, shape=(5,))
487519
assert list(ps.keys()) == ["a", "b", "c"]

0 commit comments

Comments
 (0)