Skip to content

Commit 281538a

Browse files
committed
consolidate prototype testing
1 parent a3283a9 commit 281538a

File tree

1 file changed

+25
-83
lines changed

1 file changed

+25
-83
lines changed

src/zarr/testing/store.py

Lines changed: 25 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
Store,
2424
SuffixByteRequest,
2525
)
26-
from zarr.core.buffer import Buffer, default_buffer_prototype
26+
from zarr.core.buffer import Buffer, cpu, default_buffer_prototype
2727
from zarr.core.sync import _collect_aiterator, sync
2828
from zarr.storage._utils import _normalize_byte_range_index
2929
from zarr.testing.utils import assert_bytes_equal
@@ -202,6 +202,15 @@ async def test_with_read_only_store(self, open_kwargs: dict[str, Any]) -> None:
202202
):
203203
await reader.delete("foo")
204204

205+
@pytest.mark.parametrize(
206+
"prototype",
207+
[
208+
None, # Should use store's default buffer class
209+
default_buffer_prototype(), # BufferPrototype instance
210+
default_buffer_prototype().buffer, # Raw Buffer class (cpu.Buffer)
211+
],
212+
ids=["prototype=None", "prototype=BufferPrototype", "prototype=Buffer"],
213+
)
205214
@pytest.mark.parametrize("key", ["c/0", "foo/c/0.0", "foo/0/0"])
206215
@pytest.mark.parametrize(
207216
("data", "byte_range"),
@@ -213,13 +222,15 @@ async def test_with_read_only_store(self, open_kwargs: dict[str, Any]) -> None:
213222
(b"", None),
214223
],
215224
)
216-
async def test_get(self, store: S, key: str, data: bytes, byte_range: ByteRequest) -> None:
225+
async def test_get(
226+
self, store: S, key: str, data: bytes, byte_range: ByteRequest, prototype: BufferLike | None
227+
) -> None:
217228
"""
218229
Ensure that data can be read from the store using the store.get method.
219230
"""
220231
data_buf = self.buffer_cls.from_bytes(data)
221232
await self.set(store, key, data_buf)
222-
observed = await store.get(key, prototype=default_buffer_prototype(), byte_range=byte_range)
233+
observed = await store.get(key, prototype=prototype, byte_range=byte_range)
223234
start, stop = _normalize_byte_range_index(data_buf, byte_range=byte_range)
224235
expected = data_buf[start:stop]
225236
assert_bytes_equal(observed, expected)
@@ -244,32 +255,6 @@ async def test_get_raises(self, store: S) -> None:
244255
with pytest.raises((ValueError, TypeError), match=r"Unexpected byte_range, got.*"):
245256
await store.get("c/0", prototype=default_buffer_prototype(), byte_range=(0, 2)) # type: ignore[arg-type]
246257

247-
@pytest.mark.parametrize(
248-
"prototype",
249-
[
250-
None, # Should use store's default buffer class
251-
default_buffer_prototype(), # BufferPrototype instance
252-
default_buffer_prototype().buffer, # Raw Buffer class (cpu.Buffer)
253-
],
254-
ids=["prototype=None", "prototype=BufferPrototype", "prototype=Buffer"],
255-
)
256-
async def test_get_with_buffer_like(self, store: S, prototype: BufferLike | None) -> None:
257-
"""
258-
Test that store.get() works with all BufferLike variants:
259-
- None (uses store's default)
260-
- BufferPrototype instance
261-
- Raw Buffer class
262-
"""
263-
data = b"\x01\x02\x03\x04"
264-
key = "test_buffer_like"
265-
data_buf = self.buffer_cls.from_bytes(data)
266-
await self.set(store, key, data_buf)
267-
268-
# Get with the parametrized prototype
269-
observed = await store.get(key, prototype=prototype)
270-
assert observed is not None
271-
assert_bytes_equal(observed, data_buf)
272-
273258
async def test_get_many(self, store: S) -> None:
274259
"""
275260
Ensure that multiple keys can be retrieved at once with the _get_many method.
@@ -358,6 +343,15 @@ async def test_set_many(self, store: S) -> None:
358343
for k, v in store_dict.items():
359344
assert (await self.get(store, k)).to_bytes() == v.to_bytes()
360345

346+
@pytest.mark.parametrize(
347+
"prototype",
348+
[
349+
None, # Should use store's default buffer class
350+
default_buffer_prototype(), # BufferPrototype instance
351+
default_buffer_prototype().buffer, # Raw Buffer class (cpu.Buffer)
352+
],
353+
ids=["prototype=None", "prototype=BufferPrototype", "prototype=Buffer"],
354+
)
361355
@pytest.mark.parametrize(
362356
"key_ranges",
363357
[
@@ -372,65 +366,13 @@ async def test_set_many(self, store: S) -> None:
372366
],
373367
)
374368
async def test_get_partial_values(
375-
self, store: S, key_ranges: list[tuple[str, ByteRequest]]
369+
self, store: S, key_ranges: list[tuple[str, ByteRequest]], prototype: BufferLike | None
376370
) -> None:
377371
# put all of the data
378372
for key, _ in key_ranges:
379373
await self.set(store, key, self.buffer_cls.from_bytes(bytes(key, encoding="utf-8")))
380374

381375
# read back just part of it
382-
observed_maybe = await store.get_partial_values(
383-
prototype=default_buffer_prototype(), key_ranges=key_ranges
384-
)
385-
386-
observed: list[Buffer] = []
387-
expected: list[Buffer] = []
388-
389-
for obs in observed_maybe:
390-
assert obs is not None
391-
observed.append(obs)
392-
393-
for idx in range(len(observed)):
394-
key, byte_range = key_ranges[idx]
395-
result = await store.get(
396-
key, prototype=default_buffer_prototype(), byte_range=byte_range
397-
)
398-
assert result is not None
399-
expected.append(result)
400-
401-
assert all(
402-
obs.to_bytes() == exp.to_bytes() for obs, exp in zip(observed, expected, strict=True)
403-
)
404-
405-
@pytest.mark.parametrize(
406-
"prototype",
407-
[
408-
None, # Should use store's default buffer class
409-
default_buffer_prototype(), # BufferPrototype instance
410-
default_buffer_prototype().buffer, # Raw Buffer class (cpu.Buffer)
411-
],
412-
ids=["prototype=None", "prototype=BufferPrototype", "prototype=Buffer"],
413-
)
414-
async def test_get_partial_values_with_buffer_like(
415-
self, store: S, prototype: BufferLike | None
416-
) -> None:
417-
"""
418-
Test that store.get_partial_values() works with all BufferLike variants:
419-
- None (uses store's default)
420-
- BufferPrototype instance
421-
- Raw Buffer class
422-
"""
423-
key_ranges: list[tuple[str, ByteRequest | None]] = [
424-
("c/0", RangeByteRequest(0, 2)),
425-
("c/1", None),
426-
("c/2", SuffixByteRequest(2)),
427-
]
428-
429-
# put all of the data
430-
for key, _ in key_ranges:
431-
await self.set(store, key, self.buffer_cls.from_bytes(bytes(key, encoding="utf-8")))
432-
433-
# read back with the parametrized prototype
434376
observed_maybe = await store.get_partial_values(prototype=prototype, key_ranges=key_ranges)
435377

436378
observed: list[Buffer] = []
@@ -442,7 +384,7 @@ async def test_get_partial_values_with_buffer_like(
442384

443385
for idx in range(len(observed)):
444386
key, byte_range = key_ranges[idx]
445-
result = await store.get(key, prototype=prototype, byte_range=byte_range)
387+
result = await store.get(key, prototype=cpu.Buffer, byte_range=byte_range)
446388
assert result is not None
447389
expected.append(result)
448390

0 commit comments

Comments
 (0)