Skip to content

Commit f4d1221

Browse files
flying-sheep and d-v-b authored
fix: raise error when encountering nullable string dtype (zarr-developers#3695)
* fix: raise error when encountering nullable string dtype
* add change entry
* fix typing
* move exception on na_object-bearing-string-dtype to inside from_native_dtype

---------

Co-authored-by: Davis Vann Bennett <davis.v.bennett@gmail.com>
1 parent 229a690 commit f4d1221

File tree

4 files changed

+54
-0
lines changed

4 files changed

+54
-0
lines changed

changes/3695.bugfix.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Raise error when trying to encode :class:`numpy.dtypes.StringDType` with `na_object` set.

src/zarr/core/dtype/npy/string.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -742,6 +742,43 @@ class VariableLengthUTF8(UTF8Base[np.dtypes.StringDType]): # type: ignore[type-
742742

743743
dtype_cls = np.dtypes.StringDType
744744

745+
@classmethod
746+
def from_native_dtype(cls, dtype: TBaseDType) -> Self:
747+
"""
748+
Create an instance of this data type from a compatible NumPy data type.
749+
We reject NumPy StringDType instances that have the `na_object` field set,
750+
because this is not representable by the Zarr `string` data type.
751+
752+
Parameters
753+
----------
754+
dtype : TBaseDType
755+
The native data type.
756+
757+
Returns
758+
-------
759+
Self
760+
An instance of this data type.
761+
762+
Raises
763+
------
764+
DataTypeValidationError
765+
If the input is not compatible with this data type.
766+
ValueError
767+
If the input is `numpy.dtypes.StringDType` and has `na_object` set.
768+
"""
769+
if cls._check_native_dtype(dtype):
770+
if hasattr(dtype, "na_object"):
771+
msg = (
772+
f"Zarr data type resolution from {dtype} failed. "
773+
"Attempted to resolve a zarr data type from a `numpy.dtypes.StringDType` "
774+
"with `na_object` set, which is not supported."
775+
)
776+
raise ValueError(msg)
777+
return cls()
778+
raise DataTypeValidationError(
779+
f"Invalid data type: {dtype}. Expected an instance of {cls.dtype_cls}"
780+
)
781+
745782
def to_native_dtype(self) -> np.dtypes.StringDType:
746783
"""
747784
Create a NumPy string dtype from this VariableLengthUTF8 ZDType.

src/zarr/core/dtype/registry.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,10 @@ def match_dtype(self, dtype: TBaseDType) -> ZDType[TBaseDType, TBaseScalar]:
161161
raise ValueError(msg)
162162
matched: list[ZDType[TBaseDType, TBaseScalar]] = []
163163
for val in self.contents.values():
164+
# DataTypeValidationError means "this dtype doesn't match me", which is
165+
# expected and suppressed. Other exceptions (e.g. ValueError for a dtype
166+
# that matches the type but has an invalid configuration) are propagated
167+
# to the caller.
164168
with contextlib.suppress(DataTypeValidationError):
165169
matched.append(val.from_native_dtype(dtype))
166170
if len(matched) == 1:

tests/test_dtype_registry.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
get_data_type_from_json,
1616
)
1717
from zarr.core.dtype.common import unpack_dtype_json
18+
from zarr.core.dtype.npy.string import _NUMPY_SUPPORTS_VLEN_STRING
1819
from zarr.dtype import ( # type: ignore[attr-defined]
1920
Bool,
2021
FixedLengthUTF32,
22+
VariableLengthUTF8,
2123
ZDType,
2224
data_type_registry,
2325
parse_data_type,
@@ -74,6 +76,16 @@ def test_match_dtype(
7476
data_type_registry_fixture.register(wrapper_cls._zarr_v3_name, wrapper_cls)
7577
assert isinstance(data_type_registry_fixture.match_dtype(np.dtype(dtype_str)), wrapper_cls)
7678

79+
@pytest.mark.skipif(not _NUMPY_SUPPORTS_VLEN_STRING, reason="requires numpy with T dtype")
80+
@staticmethod
81+
def test_match_dtype_string_na_object_error(
82+
data_type_registry_fixture: DataTypeRegistry,
83+
) -> None:
84+
data_type_registry_fixture.register(VariableLengthUTF8._zarr_v3_name, VariableLengthUTF8) # type: ignore[arg-type]
85+
dtype: np.dtype[Any] = np.dtypes.StringDType(na_object=None) # type: ignore[call-arg]
86+
with pytest.raises(ValueError, match=r"Zarr data type resolution from StringDType.*failed"):
87+
data_type_registry_fixture.match_dtype(dtype)
88+
7789
@staticmethod
7890
def test_unregistered_dtype(data_type_registry_fixture: DataTypeRegistry) -> None:
7991
"""

0 commit comments

Comments
 (0)