Skip to content

Commit ed56505

Browse files
committed
add parse_dtype as ergonomic replacement for parse_data_type, handle more JSON-like inputs, and test for round-trips
1 parent 2420f9e commit ed56505

6 files changed

Lines changed: 87 additions & 37 deletions

File tree

src/zarr/core/array.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@
7373
VariableLengthUTF8,
7474
ZDType,
7575
ZDTypeLike,
76-
parse_data_type,
76+
parse_dtype,
7777
)
7878
from zarr.core.dtype.common import HasEndianness, HasItemSize, HasObjectCodec
7979
from zarr.core.indexing import (
@@ -618,7 +618,7 @@ async def _create(
618618
Deprecated in favor of :func:`zarr.api.asynchronous.create_array`.
619619
"""
620620

621-
dtype_parsed = parse_data_type(dtype, zarr_format=zarr_format)
621+
dtype_parsed = parse_dtype(dtype, zarr_format=zarr_format)
622622
store_path = await make_store_path(store)
623623

624624
shape = parse_shapelike(shape)
@@ -4239,7 +4239,7 @@ async def init_array(
42394239

42404240
from zarr.codecs.sharding import ShardingCodec, ShardingCodecIndexLocation
42414241

4242-
zdtype = parse_data_type(dtype, zarr_format=zarr_format)
4242+
zdtype = parse_dtype(dtype, zarr_format=zarr_format)
42434243
shape_parsed = parse_shapelike(shape)
42444244
chunk_key_encoding_parsed = _parse_chunk_key_encoding(
42454245
chunk_key_encoding, zarr_format=zarr_format

src/zarr/core/dtype/__init__.py

Lines changed: 62 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from collections.abc import Sequence
34
from typing import TYPE_CHECKING, Final, TypeAlias
45

56
from zarr.core.dtype.common import (
@@ -94,6 +95,7 @@
9495
"ZDType",
9596
"data_type_registry",
9697
"parse_data_type",
98+
"parse_dtype",
9799
]
98100

99101
data_type_registry = DataTypeRegistry()
@@ -188,13 +190,69 @@ def parse_data_type(
188190
zarr_format: ZarrFormat,
189191
) -> ZDType[TBaseDType, TBaseScalar]:
190192
"""
191-
Interpret the input as a ZDType instance.
193+
Interpret the input as a ZDType.
194+
195+
This function wraps ``parse_dtype``. The only difference is the function name. This function may
196+
be deprecated in a future version of Zarr Python in favor of ``parse_dtype``.
197+
198+
Parameters
199+
----------
200+
dtype_spec : ZDTypeLike
201+
The input to be interpreted as a ZDType. This could be a ZDType, which will be returned
202+
directly, or a JSON representation of a ZDType, or a native dtype, or a python object that
203+
can be converted into a native dtype.
204+
zarr_format : ZarrFormat
205+
The Zarr format version.
206+
207+
Returns
208+
-------
209+
ZDType[TBaseDType, TBaseScalar]
210+
The ZDType corresponding to the input.
211+
212+
Examples
213+
--------
214+
>>> parse_data_type("int32", zarr_format=2)
215+
Int32(endianness='little')
216+
"""
217+
return parse_dtype(dtype_spec, zarr_format=zarr_format)
218+
219+
220+
def parse_dtype(
221+
dtype_spec: ZDTypeLike,
222+
*,
223+
zarr_format: ZarrFormat,
224+
) -> ZDType[TBaseDType, TBaseScalar]:
225+
"""
226+
Interpret the input as a ZDType.
227+
228+
Parameters
229+
----------
230+
dtype_spec : ZDTypeLike
231+
The input to be interpreted as a ZDType. This could be a ZDType, which will be returned
232+
directly, or a JSON representation of a ZDType, or a native dtype, or a python object that
233+
can be converted into a native dtype.
234+
zarr_format : ZarrFormat
235+
The Zarr format version.
236+
237+
Returns
238+
-------
239+
ZDType[TBaseDType, TBaseScalar]
240+
The ZDType corresponding to the input.
241+
242+
Examples
243+
--------
244+
>>> parse_dtype("int32", zarr_format=2)
245+
Int32(endianness='little')
192246
"""
193247
if isinstance(dtype_spec, ZDType):
194248
return dtype_spec
195-
# dict and zarr_format 3 means that we have a JSON object representation of the dtype
196-
if zarr_format == 3 and isinstance(dtype_spec, Mapping):
197-
return get_data_type_from_json(dtype_spec, zarr_format=3)
249+
# First attempt to interpret the input as JSON
250+
if isinstance(dtype_spec, Mapping | str | Sequence):
251+
try:
252+
return get_data_type_from_json(dtype_spec, zarr_format=3) # type: ignore[arg-type]
253+
except ValueError:
254+
# no data type matched this JSON-like input
255+
pass
198256
if dtype_spec in VLEN_UTF8_ALIAS:
199257
# If the dtype request is one of the aliases for variable-length UTF-8 strings,
200258
# return that dtype.

src/zarr/dtype.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
ZDType,
4040
data_type_registry,
4141
parse_data_type,
42+
parse_dtype,
4243
)
4344

4445
__all__ = [
@@ -84,4 +85,5 @@
8485
"data_type_registry",
8586
"data_type_registry",
8687
"parse_data_type",
88+
"parse_dtype",
8789
]

tests/test_array.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
VariableLengthBytes,
5454
VariableLengthUTF8,
5555
ZDType,
56-
parse_data_type,
56+
parse_dtype,
5757
)
5858
from zarr.core.dtype.common import ENDIANNESS_STR, EndiannessStr
5959
from zarr.core.dtype.npy.common import NUMPY_ENDIANNESS_STR, endianness_from_numpy_str
@@ -1308,7 +1308,7 @@ async def test_v2_chunk_encoding(
13081308
filters=filters,
13091309
)
13101310
filters_expected, compressor_expected = _parse_chunk_encoding_v2(
1311-
filters=filters, compressor=compressors, dtype=parse_data_type(dtype, zarr_format=2)
1311+
filters=filters, compressor=compressors, dtype=parse_dtype(dtype, zarr_format=2)
13121312
)
13131313
assert arr.metadata.zarr_format == 2 # guard for mypy
13141314
assert arr.metadata.compressor == compressor_expected

tests/test_dtype_registry.py

Lines changed: 16 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,13 @@
1515
AnyDType,
1616
Bool,
1717
DataTypeRegistry,
18-
DateTime64,
1918
FixedLengthUTF32,
20-
Int8,
21-
Int16,
2219
TBaseDType,
2320
TBaseScalar,
24-
VariableLengthUTF8,
2521
ZDType,
2622
data_type_registry,
2723
get_data_type_from_json,
28-
parse_data_type,
24+
parse_dtype,
2925
)
3026

3127
if TYPE_CHECKING:
@@ -174,28 +170,22 @@ def test_entrypoint_dtype(zarr_format: ZarrFormat) -> None:
174170
data_type_registry.unregister(TestDataType._zarr_v3_name)
175171

176172

177-
@pytest.mark.parametrize(
178-
("dtype_params", "expected", "zarr_format"),
179-
[
180-
("str", VariableLengthUTF8(), 2),
181-
("str", VariableLengthUTF8(), 3),
182-
("int8", Int8(), 3),
183-
(Int8(), Int8(), 3),
184-
(">i2", Int16(endianness="big"), 2),
185-
("datetime64[10s]", DateTime64(unit="s", scale_factor=10), 2),
186-
(
187-
{"name": "numpy.datetime64", "configuration": {"unit": "s", "scale_factor": 10}},
188-
DateTime64(unit="s", scale_factor=10),
189-
3,
190-
),
191-
],
192-
)
193-
def test_parse_data_type(
194-
dtype_params: Any, expected: ZDType[Any, Any], zarr_format: ZarrFormat
195-
) -> None:
173+
@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
174+
@pytest.mark.parametrize("data_type", zdtype_examples, ids=str)
175+
def test_parse_data_type(data_type: ZDType[Any, Any], zarr_format: ZarrFormat) -> None:
196176
"""
197177
Test that parse_dtype accepts alternative representations of ZDType instances, and resolves
198178
those inputs to the expected ZDType instance.
199179
"""
200-
observed = parse_data_type(dtype_params, zarr_format=zarr_format)
201-
assert observed == expected
180+
dtype_spec: Any
181+
if zarr_format == 2:
182+
dtype_spec = data_type.to_json(zarr_format=zarr_format)["name"]
183+
else:
184+
dtype_spec = data_type.to_json(zarr_format=zarr_format)
185+
if dtype_spec == "|O":
186+
msg = "Zarr data type resolution from object failed."
187+
with pytest.raises(ValueError, match=msg):
188+
parse_dtype(dtype_spec, zarr_format=zarr_format)
189+
else:
190+
observed = parse_dtype(dtype_spec, zarr_format=zarr_format) # type: ignore[arg-type]
191+
assert observed == data_type

tests/test_metadata/test_consolidated.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
open_consolidated,
1919
)
2020
from zarr.core.buffer import cpu, default_buffer_prototype
21-
from zarr.core.dtype import parse_data_type
21+
from zarr.core.dtype import parse_dtype
2222
from zarr.core.group import ConsolidatedMetadata, GroupMetadata
2323
from zarr.core.metadata import ArrayV3Metadata
2424
from zarr.core.metadata.v2 import ArrayV2Metadata
@@ -504,7 +504,7 @@ async def test_consolidated_metadata_backwards_compatibility(
504504
async def test_consolidated_metadata_v2(self):
505505
store = zarr.storage.MemoryStore()
506506
g = await AsyncGroup.from_store(store, attributes={"key": "root"}, zarr_format=2)
507-
dtype = parse_data_type("uint8", zarr_format=2)
507+
dtype = parse_dtype("uint8", zarr_format=2)
508508
await g.create_array(name="a", shape=(1,), attributes={"key": "a"}, dtype=dtype)
509509
g1 = await g.create_group(name="g1", attributes={"key": "g1"})
510510
await g1.create_group(name="g2", attributes={"key": "g2"})

0 commit comments

Comments
 (0)