Skip to content

Commit 9e92bc5

Browse files
committed
make v3 regression testing for string work properly
1 parent 3463b02 commit 9e92bc5

3 files changed

Lines changed: 140 additions & 15 deletions

File tree

tests/test_dtype/test_npy/test_string.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class TestVariableLengthString(BaseTestZDType):
1919
np.dtype("|S10"),
2020
)
2121
valid_json_v2 = ({"name": "|O", "object_codec_id": "vlen-utf8"},)
22-
valid_json_v3 = ("variable_length_utf8",)
22+
valid_json_v3 = ("string",)
2323
invalid_json_v2 = (
2424
"|S10",
2525
"|f8",
@@ -53,7 +53,7 @@ class TestVariableLengthString(BaseTestZDType): # type: ignore[no-redef]
5353
np.dtype("|S10"),
5454
)
5555
valid_json_v2 = ({"name": "|O", "object_codec_id": "vlen-utf8"},)
56-
valid_json_v3 = ("variable_length_utf8",)
56+
valid_json_v3 = ("string",)
5757
invalid_json_v2 = (
5858
"|S10",
5959
"|f8",

tests/test_regression/scripts/v3.0.8.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,53 @@
55
# ]
66
# ///
77

8+
# /// script
9+
# requires-python = ">=3.11"
10+
# dependencies = [
11+
# "zarr==2.18",
12+
# "numcodecs==0.15"
13+
# ]
14+
# ///
15+
16+
import argparse
17+
18+
import zarr
19+
from zarr.abc.store import Store
20+
21+
def copy_group(
    *, node: zarr.Group, store: Store, path: str, overwrite: bool
) -> zarr.Group:
    """
    Recreate ``node`` at ``path`` in ``store`` and recursively copy its members.

    The new group is created with the attributes and zarr format of ``node``.
    Members that are groups are copied with ``copy_group``; every other member
    is copied with ``copy_array``. The newly created group is returned.
    """
    new_group = zarr.create_group(
        store=store,
        path=path,
        overwrite=overwrite,
        attributes=node.attrs.asdict(),
        zarr_format=node.metadata.zarr_format,
    )
    for name, member in node.members():
        member_path = f"{path}/{name}"
        # Pick the copy routine that matches the kind of member.
        copier = copy_group if isinstance(member, zarr.Group) else copy_array
        copier(node=member, store=store, path=member_path, overwrite=overwrite)
    return new_group
37+
38+
39+
def copy_array(
    *, node: zarr.Array, store: Store, path: str, overwrite: bool
) -> zarr.Array:
    """
    Create a new array at ``path`` in ``store`` from ``node``, writing its data.

    Parameters
    ----------
    node : zarr.Array
        The source array.
    store : Store
        The destination store.
    path : str
        Name of the new array inside ``store``.
    overwrite : bool
        Whether an existing node at ``path`` may be replaced.

    Returns
    -------
    zarr.Array
        The newly created array.
    """
    # Forward ``overwrite``: previously the flag was accepted but ignored, so
    # callers asking for an overwrite got the ``from_array`` default instead.
    return zarr.from_array(
        store, name=path, data=node, write_data=True, overwrite=overwrite
    )
44+
45+
46+
def copy_node(
    node: zarr.Group | zarr.Array, store: Store, path: str, overwrite: bool
) -> zarr.Group | zarr.Array:
    """
    Copy ``node`` to ``path`` in ``store``.

    Groups are handed to ``copy_group`` and anything else to ``copy_array``;
    the value returned by the chosen routine is returned unchanged.
    """
    copier = copy_group if isinstance(node, zarr.Group) else copy_array
    return copier(node=node, store=store, path=path, overwrite=overwrite)
53+
54+
855
def cli() -> None:
956
parser = argparse.ArgumentParser(
1057
description="Copy a zarr hierarchy from one location to another"
@@ -15,7 +62,7 @@ def cli() -> None:
1562

1663
src, dst = args.source, args.destination
1764
root_src = zarr.open(src, mode="r")
18-
result = copy_node(node=root_src, store=zarr.NestedDirectoryStore(dst), path="", overwrite=True)
65+
result = copy_node(node=root_src, store=dst, path="", overwrite=True)
1966

2067
print(f"successfully created {result} at {dst}")
2168

tests/test_regression/test_v2_dtype_regression.py

Lines changed: 90 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
from numcodecs import LZ4, LZMA, Blosc, GZip, VLenBytes, VLenUTF8, Zstd
1111

1212
import zarr
13+
import zarr.abc
14+
import zarr.abc.codec
15+
import zarr.codecs as zarrcodecs
1316
from zarr.core.array import Array
1417
from zarr.core.chunk_key_encodings import V2ChunkKeyEncoding
1518
from zarr.core.dtype.npy.bytes import VariableLengthBytes
@@ -38,6 +41,7 @@ class ArrayParams:
3841
values: np.ndarray[tuple[int], np.dtype[np.generic]]
3942
fill_value: np.generic | str | int | bytes
4043
filters: tuple[numcodecs.abc.Codec, ...] = ()
44+
serializer: str | None = None
4145
compressor: numcodecs.abc.Codec
4246

4347

@@ -77,15 +81,15 @@ class ArrayParams:
7781
ArrayParams(
7882
values=np.array(["a", "bb", "ccc", "dddd"], dtype="O"),
7983
fill_value="1",
80-
filters=(VLenUTF8(),),
84+
serializer="vlen-utf8",
8185
compressor=GZip(),
8286
)
8387
]
8488
vlen_bytes_cases = [
8589
ArrayParams(
8690
values=np.array([b"a", b"bb", b"ccc", b"dddd"], dtype="O"),
8791
fill_value=b"1",
88-
filters=(VLenBytes(),),
92+
serializer="vlen-bytes",
8993
compressor=GZip(),
9094
)
9195
]
@@ -102,7 +106,7 @@ class ArrayParams:
102106

103107

104108
@pytest.fixture
105-
def source_array(tmp_path: Path, request: pytest.FixtureRequest) -> Array:
109+
def source_array_v2(tmp_path: Path, request: pytest.FixtureRequest) -> Array:
106110
"""
107111
Writes a zarr array to a temporary directory based on the provided ArrayParams. The array is
108112
returned.
@@ -113,19 +117,22 @@ def source_array(tmp_path: Path, request: pytest.FixtureRequest) -> Array:
113117
compressor = array_params.compressor
114118
chunk_key_encoding = V2ChunkKeyEncoding(separator="/")
115119
dtype: ZDTypeLike
116-
if array_params.values.dtype == np.dtype("|O") and array_params.filters == (VLenUTF8(),):
120+
if array_params.values.dtype == np.dtype("|O") and array_params.serializer == "vlen-utf8":
117121
dtype = VariableLengthUTF8() # type: ignore[assignment]
118-
elif array_params.values.dtype == np.dtype("|O") and array_params.filters == (VLenBytes(),):
122+
filters = array_params.filters + (VLenUTF8(),)
123+
elif array_params.values.dtype == np.dtype("|O") and array_params.serializer == "vlen-bytes":
119124
dtype = VariableLengthBytes()
125+
filters = array_params.filters + (VLenBytes(),)
120126
else:
121127
dtype = array_params.values.dtype
128+
filters = array_params.filters
122129
z = zarr.create_array(
123130
store,
124131
shape=array_params.values.shape,
125132
dtype=dtype,
126133
chunks=array_params.values.shape,
127134
compressors=compressor,
128-
filters=array_params.filters,
135+
filters=filters,
129136
fill_value=array_params.fill_value,
130137
order="C",
131138
chunk_key_encoding=chunk_key_encoding,
@@ -136,29 +143,100 @@ def source_array(tmp_path: Path, request: pytest.FixtureRequest) -> Array:
136143
return z
137144

138145

146+
@pytest.fixture
def source_array_v3(tmp_path: Path, request: pytest.FixtureRequest) -> Array:
    """
    Writes a zarr v3 array to a temporary directory based on the provided ArrayParams. The array
    is returned.

    The ArrayParams instance is read from ``request.param``. Its ``serializer`` string selects
    the variable-length codec used for object-dtype values; any other case uses the "auto"
    serializer.

    Raises
    ------
    ValueError
        If the requested compressor is anything other than ``GZip()``.
    """
    dest = tmp_path / "in"
    store = LocalStore(dest)
    array_params: ArrayParams = request.param
    chunk_key_encoding = V2ChunkKeyEncoding(separator="/")
    dtype: ZDTypeLike
    serializer: Literal["auto"] | zarr.abc.codec.Codec
    if array_params.values.dtype == np.dtype("|O") and array_params.serializer == "vlen-utf8":
        dtype = VariableLengthUTF8()  # type: ignore[assignment]
        serializer = zarrcodecs.VLenUTF8Codec()
    elif array_params.values.dtype == np.dtype("|O") and array_params.serializer == "vlen-bytes":
        dtype = VariableLengthBytes()
        serializer = zarrcodecs.VLenBytesCodec()
    else:
        dtype = array_params.values.dtype
        serializer = "auto"
    if array_params.compressor == GZip():
        compressor = zarrcodecs.GzipCodec()
    else:
        # The trailing space after "author" matters: the two literals are concatenated.
        msg = (
            "This test is only compatible with gzip compression at the moment, because the author "
            "did not want to implement a complete abstraction layer for v2 and v3 codecs in this test."
        )
        raise ValueError(msg)
    z = zarr.create_array(
        store,
        shape=array_params.values.shape,
        dtype=dtype,
        # A single chunk spanning the whole array.
        chunks=array_params.values.shape,
        compressors=compressor,
        filters=array_params.filters,
        serializer=serializer,
        fill_value=array_params.fill_value,
        chunk_key_encoding=chunk_key_encoding,
        write_data=True,
        zarr_format=3,
    )
    z[:] = array_params.values
    return z
190+
191+
139192
# TODO: make this dynamic based on the installed scripts
140193
script_paths = [Path(__file__).resolve().parent / "scripts" / "v2.18.py"]
141194

142195

143196
@pytest.mark.skipif(not runner_installed(), reason="no python script runner installed")
144197
@pytest.mark.parametrize(
145-
"source_array", array_cases_v2_18, indirect=True, ids=tuple(map(str, array_cases_v2_18))
198+
"source_array_v2", array_cases_v2_18, indirect=True, ids=tuple(map(str, array_cases_v2_18))
146199
)
147200
@pytest.mark.parametrize("script_path", script_paths)
148-
def test_roundtrip_v2(source_array: Array, tmp_path: Path, script_path: Path) -> None:
201+
def test_roundtrip_v2(source_array_v2: Array, tmp_path: Path, script_path: Path) -> None:
149202
out_path = tmp_path / "out"
150203
copy_op = subprocess.run(
151204
[
152205
"uv",
153206
"run",
154-
script_path,
155-
str(source_array.store).removeprefix("file://"),
207+
str(script_path),
208+
str(source_array_v2.store).removeprefix("file://"),
156209
str(out_path),
157210
],
158211
capture_output=True,
159212
text=True,
160213
)
161214
assert copy_op.returncode == 0
162215
out_array = zarr.open_array(store=out_path, mode="r", zarr_format=2)
163-
assert source_array.metadata.to_dict() == out_array.metadata.to_dict()
164-
assert np.array_equal(source_array[:], out_array[:])
216+
assert source_array_v2.metadata.to_dict() == out_array.metadata.to_dict()
217+
assert np.array_equal(source_array_v2[:], out_array[:])
218+
219+
220+
@pytest.mark.skipif(not runner_installed(), reason="no python script runner installed")
@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
@pytest.mark.parametrize(
    "source_array_v3", array_cases_v3_08, indirect=True, ids=tuple(map(str, array_cases_v3_08))
)
def test_roundtrip_v3(source_array_v3: Array, tmp_path: Path) -> None:
    """
    Copy a v3 array with the ``v3.0.8.py`` script run through ``uv`` and check that the copy
    has the same metadata and the same values as the source array.
    """
    script_path = Path(__file__).resolve().parent / "scripts" / "v3.0.8.py"
    out_path = tmp_path / "out"
    copy_op = subprocess.run(
        [
            "uv",
            "run",
            str(script_path),
            str(source_array_v3.store).removeprefix("file://"),
            str(out_path),
        ],
        capture_output=True,
        text=True,
    )
    # Output is captured, so attach stderr to the assertion; otherwise a failing
    # script leaves no trace of why it failed.
    assert copy_op.returncode == 0, copy_op.stderr
    out_array = zarr.open_array(store=out_path, mode="r", zarr_format=3)
    assert source_array_v3.metadata.to_dict() == out_array.metadata.to_dict()
    assert np.array_equal(source_array_v3[:], out_array[:])

0 commit comments

Comments
 (0)