Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/zarr/core/dtype/npy/float.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,15 @@ def _check_scalar(self, data: object) -> TypeGuard[FloatLike]:
TypeGuard[FloatLike]
True if the input is a valid scalar value, False otherwise.
"""
if isinstance(data, str):
# Only accept strings that are valid float representations (e.g. "NaN", "inf").
# Plain strings that cannot be converted should return False so that cast_scalar
# raises TypeError rather than a confusing ValueError.
try:
self.to_native_dtype().type(data)
return True
except (ValueError, OverflowError):
return False
return isinstance(data, FloatLike)

def _cast_scalar_unchecked(self, data: FloatLike) -> TFloatScalar_co:
Expand Down
6 changes: 3 additions & 3 deletions tests/test_dtype/test_npy/test_float.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class TestFloat16(_BaseTestFloat):
(Float16(), -1.0, np.float16(-1.0)),
(Float16(), "NaN", np.float16("NaN")),
)
invalid_scalar_params = ((Float16(), {"set!"}),)
invalid_scalar_params = ((Float16(), {"set!"}), (Float16(), "not_a_float"),)
hex_string_params = (("0x7fc0", np.nan), ("0x7fc1", np.nan), ("0x3c00", 1.0))
item_size_params = (Float16(),)

Expand Down Expand Up @@ -113,7 +113,7 @@ class TestFloat32(_BaseTestFloat):
(Float32(), -1.0, np.float32(-1.0)),
(Float32(), "NaN", np.float32("NaN")),
)
invalid_scalar_params = ((Float32(), {"set!"}),)
invalid_scalar_params = ((Float32(), {"set!"}), (Float32(), "not_a_float"),)
hex_string_params = (("0x7fc00000", np.nan), ("0x7fc00001", np.nan), ("0x3f800000", 1.0))
item_size_params = (Float32(),)

Expand Down Expand Up @@ -160,7 +160,7 @@ class TestFloat64(_BaseTestFloat):
(Float64(), -1.0, np.float64(-1.0)),
(Float64(), "NaN", np.float64("NaN")),
)
invalid_scalar_params = ((Float64(), {"set!"}),)
invalid_scalar_params = ((Float64(), {"set!"}), (Float64(), "not_a_float"),)
hex_string_params = (
("0x7ff8000000000000", np.nan),
("0x7ff8000000000001", np.nan),
Expand Down
Loading