diff --git a/numtraits.py b/numtraits.py index 45d059f..df0c5c3 100644 --- a/numtraits.py +++ b/numtraits.py @@ -97,7 +97,7 @@ def validate(self, obj, value): if self.ndim is not None: if self.ndim == 0: - if not is_scalar: + if not is_scalar and num_value.ndim: raise TraitError("{0} should be a scalar value".format(self.name)) if self.ndim > 0: diff --git a/test_numtraits.py b/test_numtraits.py index 5cd9ff7..96cfef1 100644 --- a/test_numtraits.py +++ b/test_numtraits.py @@ -79,6 +79,11 @@ def test_range(self): self.sp.f = 7 assert exc.value.args[0] == "f should be in the range [3:4]" + def test_scalar_quantities(self): + """ Tests for issue #14. + """ + quantities = pytest.importorskip("quantities") + self.sp.a = 1*quantities.m class ArrayProperties(HasTraits): @@ -90,6 +95,7 @@ class ArrayProperties(HasTraits): f = NumericalTrait(domain=(3, 4), ndim=1) g = NumericalTrait(shape=(3, 4)) + class TestArray(object): def setup_method(self, method):