forked from zarr-developers/zarr-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_dtype_registry.py
More file actions
201 lines (170 loc) · 6.99 KB
/
test_dtype_registry.py
File metadata and controls
201 lines (170 loc) · 6.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
from __future__ import annotations
import re
import sys
from pathlib import Path
from typing import TYPE_CHECKING, Any, get_args
import numpy as np
import pytest
import zarr
from tests.conftest import skip_object_dtype
from zarr.core.config import config
from zarr.core.dtype import (
AnyDType,
Bool,
DataTypeRegistry,
DateTime64,
FixedLengthUTF32,
Int8,
Int16,
TBaseDType,
TBaseScalar,
VariableLengthUTF8,
ZDType,
data_type_registry,
get_data_type_from_json,
parse_data_type,
)
if TYPE_CHECKING:
from collections.abc import Generator
from zarr.core.common import ZarrFormat
from .test_dtype.conftest import zdtype_examples
@pytest.fixture
def data_type_registry_fixture() -> DataTypeRegistry:
    """
    Provide a newly constructed ``DataTypeRegistry`` for a single test.

    Tests use this instead of the module-level ``data_type_registry`` so that
    registrations made during the test stay local to the returned object.
    """
    fresh_registry = DataTypeRegistry()
    return fresh_registry
class TestRegistry:
    """Tests for registering and looking up data types in a ``DataTypeRegistry``."""

    @staticmethod
    def test_register(data_type_registry_fixture: DataTypeRegistry) -> None:
        """
        Test that a dtype registered under its Zarr V3 name can be fetched by that name
        and is matched from the equivalent numpy dtype.
        """
        registry = data_type_registry_fixture
        bool_name = Bool._zarr_v3_name
        registry.register(bool_name, Bool)

        assert registry.get(bool_name) == Bool
        matched = registry.match_dtype(np.dtype("bool"))
        assert isinstance(matched, Bool)

    @staticmethod
    def test_override(data_type_registry_fixture: DataTypeRegistry) -> None:
        """
        Test that registering a second dtype class under an already-used name replaces
        the first one.
        """
        registry = data_type_registry_fixture
        registry.register(Bool._zarr_v3_name, Bool)

        # Subclass that keeps the inherited name, so it is registered under the same key.
        class NewBool(Bool):
            def default_scalar(self) -> np.bool_:
                return np.True_

        registry.register(NewBool._zarr_v3_name, NewBool)

        matched = registry.match_dtype(np.dtype("bool"))
        assert isinstance(matched, NewBool)

    @staticmethod
    @pytest.mark.parametrize(
        ("wrapper_cls", "dtype_str"), [(Bool, "bool"), (FixedLengthUTF32, "|U4")]
    )
    def test_match_dtype(
        data_type_registry_fixture: DataTypeRegistry,
        wrapper_cls: type[ZDType[TBaseDType, TBaseScalar]],
        dtype_str: str,
    ) -> None:
        """
        Test that match_dtype resolves a numpy dtype into an instance of the corresponding
        wrapper class for that dtype.
        """
        registry = data_type_registry_fixture
        registry.register(wrapper_cls._zarr_v3_name, wrapper_cls)

        native_dtype = np.dtype(dtype_str)
        assert isinstance(registry.match_dtype(native_dtype), wrapper_cls)

    @staticmethod
    def test_unregistered_dtype(data_type_registry_fixture: DataTypeRegistry) -> None:
        """
        Test that looking up a dtype that was never registered fails: match_dtype raises
        ValueError and get raises KeyError.
        """
        registry = data_type_registry_fixture
        outside_dtype_name = "int8"
        outside_dtype = np.dtype(outside_dtype_name)

        msg = f"No Zarr data type found that matches dtype '{outside_dtype!r}'"
        # The message contains regex metacharacters, so it must be escaped for ``match``.
        expected_pattern = re.escape(msg)
        with pytest.raises(ValueError, match=expected_pattern):
            registry.match_dtype(outside_dtype)

        with pytest.raises(KeyError):
            registry.get(outside_dtype_name)

    @staticmethod
    @pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
    @pytest.mark.parametrize("zdtype", zdtype_examples)
    def test_registered_dtypes_match_dtype(zdtype: ZDType[TBaseDType, TBaseScalar]) -> None:
        """
        Test that the global registry maps the native dtype of each example back to an
        equal ZDType instance.
        """
        # Helper from tests.conftest; presumably skips examples backed by object dtypes.
        skip_object_dtype(zdtype)
        native_dtype = zdtype.to_native_dtype()
        assert data_type_registry.match_dtype(native_dtype) == zdtype

    @staticmethod
    @pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
    @pytest.mark.parametrize("zdtype", zdtype_examples)
    def test_registered_dtypes_match_json(
        zdtype: ZDType[TBaseDType, TBaseScalar], zarr_format: ZarrFormat
    ) -> None:
        """
        Test that the global registry maps the JSON form of each example back to an
        equal ZDType instance.
        """
        as_json = zdtype.to_json(zarr_format=zarr_format)
        resolved = data_type_registry.match_json(as_json, zarr_format=zarr_format)
        assert resolved == zdtype

    @staticmethod
    @pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
    @pytest.mark.parametrize("zdtype", zdtype_examples)
    def test_match_dtype_unique(
        zdtype: ZDType[Any, Any],
        data_type_registry_fixture: DataTypeRegistry,
        zarr_format: ZarrFormat,
    ) -> None:
        """
        Test that each data type is matched by exactly one registered class. A local
        registry is filled with every class in AnyDType except the class of the example
        under test; both the native dtype and the JSON form of the example must then
        fail to match anything in that registry.
        """
        # Helper from tests.conftest; presumably skips examples backed by object dtypes.
        skip_object_dtype(zdtype)

        registry = data_type_registry_fixture
        excluded_cls = type(zdtype)
        for candidate in get_args(AnyDType):
            if candidate is excluded_cls:
                continue
            registry.register(candidate._zarr_v3_name, candidate)

        dtype_instance = zdtype.to_native_dtype()
        dtype_msg = f"No Zarr data type found that matches dtype '{dtype_instance!r}'"
        with pytest.raises(ValueError, match=re.escape(dtype_msg)):
            registry.match_dtype(dtype_instance)

        instance_dict = zdtype.to_json(zarr_format=zarr_format)
        json_msg = f"No Zarr data type found that matches {instance_dict!r}"
        with pytest.raises(ValueError, match=re.escape(json_msg)):
            registry.match_json(instance_dict, zarr_format=zarr_format)
# NOTE: this path setup is copied from the registry tests -- we should deduplicate.
# Absolute path (as a string) of the directory containing this test file; the
# ``set_path`` fixture appends it to ``sys.path`` and removes it again on teardown.
here = str(Path(__file__).parent.absolute())
@pytest.fixture
def set_path() -> Generator[None, None, None]:
    """
    Temporarily put this test directory on ``sys.path`` and re-collect entrypoints.

    Setup appends ``here`` to ``sys.path`` and calls
    ``zarr.registry._collect_entrypoints()``. Teardown removes the path entry,
    collects entrypoints again, clears the ``lazy_load_list`` of every registry
    returned by that call, and resets the zarr config.
    """
    sys.path.append(here)
    zarr.registry._collect_entrypoints()

    yield

    sys.path.remove(here)
    # Re-collect now that ``here`` is off the path, then drop any pending lazy loads.
    for collected_registry in zarr.registry._collect_entrypoints():
        collected_registry.lazy_load_list.clear()
    config.reset()
@pytest.mark.usefixtures("set_path")
def test_entrypoint_dtype(zarr_format: ZarrFormat) -> None:
    """
    Test that a data type provided by ``package_with_entrypoint`` can be resolved
    from its JSON representation via ``get_data_type_from_json``.

    The ``set_path`` fixture puts this test directory on ``sys.path`` before the
    import below runs (presumably that is where ``package_with_entrypoint`` lives).
    The data type is always unregistered from the global ``data_type_registry`` at
    the end, even when the assertion fails, so it cannot leak into other tests.
    """
    # Imported inside the test because the module is only importable once
    # ``set_path`` has modified ``sys.path``.
    from package_with_entrypoint import TestDataType

    data_type_registry.lazy_load()
    try:
        instance = TestDataType()
        dtype_json = instance.to_json(zarr_format=zarr_format)
        assert get_data_type_from_json(dtype_json, zarr_format=zarr_format) == instance
    finally:
        # Cleanup must run on failure too: the global registry is shared by every
        # test in the session, and ``set_path`` teardown does not unregister this type.
        data_type_registry.unregister(TestDataType._zarr_v3_name)
@pytest.mark.parametrize(
    ("dtype_params", "expected", "zarr_format"),
    [
        # String shorthands.
        ("str", VariableLengthUTF8(), 2),
        ("str", VariableLengthUTF8(), 3),
        ("int8", Int8(), 3),
        # An existing ZDType instance.
        (Int8(), Int8(), 3),
        # Numpy-style dtype strings.
        (">i2", Int16(endianness="big"), 2),
        ("datetime64[10s]", DateTime64(unit="s", scale_factor=10), 2),
        # A JSON-style dict with a name and a configuration.
        (
            {"name": "numpy.datetime64", "configuration": {"unit": "s", "scale_factor": 10}},
            DateTime64(unit="s", scale_factor=10),
            3,
        ),
    ],
)
def test_parse_data_type(
    dtype_params: Any, expected: ZDType[Any, Any], zarr_format: ZarrFormat
) -> None:
    """
    Test that parse_data_type accepts several alternative representations of a data
    type (strings, ZDType instances, JSON-style dicts) and resolves each of them to
    the expected ZDType instance for the given Zarr format.
    """
    assert parse_data_type(dtype_params, zarr_format=zarr_format) == expected