Skip to content

Commit 36d3909

Browse files
committed
Unit tests for new get_partial_values implementations
1 parent 56ea19c commit 36d3909

1 file changed

Lines changed: 94 additions & 6 deletions

File tree

tests/test_codecs/test_sharding_unit.py

Lines changed: 94 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,94 @@ def test_shard_index_is_dense_with_empty_chunks() -> None:
179179
assert index.is_dense(chunk_byte_length) is True
180180

181181

182+
# ============================================================================
183+
# _ShardingByteGetter.get_partial_values tests
184+
# ============================================================================
185+
186+
187+
async def test_sharding_byte_getter_get_partial_values_returns_slices() -> None:
188+
"""Test that get_partial_values returns correct slices from the shard dict."""
189+
from zarr.codecs.sharding import _ShardingByteGetter
190+
191+
chunk_data = Buffer.from_bytes(b"AAAABBBB")
192+
shard_dict: dict[tuple[int, ...], Buffer | None] = {(0,): chunk_data}
193+
getter = _ShardingByteGetter(shard_dict, (0,))
194+
195+
from zarr.abc.store import RangeByteRequest
196+
197+
results = await getter.get_partial_values(
198+
default_buffer_prototype(),
199+
[RangeByteRequest(0, 4), RangeByteRequest(4, 8)],
200+
)
201+
202+
assert len(results) == 2
203+
assert results[0] is not None
204+
assert results[0].as_numpy_array().tobytes() == b"AAAA"
205+
assert results[1] is not None
206+
assert results[1].as_numpy_array().tobytes() == b"BBBB"
207+
208+
209+
async def test_sharding_byte_getter_get_partial_values_missing_chunk() -> None:
210+
"""Test that get_partial_values returns None for a missing chunk."""
211+
from zarr.codecs.sharding import _ShardingByteGetter
212+
213+
shard_dict: dict[tuple[int, ...], Buffer | None] = {}
214+
getter = _ShardingByteGetter(shard_dict, (0,))
215+
216+
from zarr.abc.store import RangeByteRequest
217+
218+
results = await getter.get_partial_values(
219+
default_buffer_prototype(),
220+
[RangeByteRequest(0, 10)],
221+
)
222+
223+
assert results == [None]
224+
225+
226+
# ============================================================================
227+
# StorePath.get_partial_values tests
228+
# ============================================================================
229+
230+
231+
async def test_store_path_get_partial_values() -> None:
232+
"""Test that StorePath.get_partial_values delegates to Store.get_partial_values."""
233+
from zarr.abc.store import RangeByteRequest
234+
from zarr.storage._common import StorePath
235+
from zarr.storage._memory import MemoryStore
236+
237+
store = MemoryStore()
238+
await store.set("key", Buffer.from_bytes(b"0123456789"))
239+
path = StorePath(store, "key")
240+
241+
results = await path.get_partial_values(
242+
default_buffer_prototype(),
243+
[RangeByteRequest(0, 3), RangeByteRequest(7, 10)],
244+
)
245+
246+
assert len(results) == 2
247+
assert results[0] is not None
248+
assert results[0].as_numpy_array().tobytes() == b"012"
249+
assert results[1] is not None
250+
assert results[1].as_numpy_array().tobytes() == b"789"
251+
252+
253+
async def test_store_path_get_partial_values_missing_key() -> None:
254+
"""Test that StorePath.get_partial_values returns None for a missing key."""
255+
from zarr.abc.store import RangeByteRequest
256+
from zarr.storage._common import StorePath
257+
from zarr.storage._memory import MemoryStore
258+
259+
store = MemoryStore()
260+
path = StorePath(store, "nonexistent")
261+
262+
results = await path.get_partial_values(
263+
default_buffer_prototype(),
264+
[RangeByteRequest(0, 10)],
265+
)
266+
267+
assert results == [None]
268+
269+
182270
# ============================================================================
183271
# Mock ByteGetter for _load_partial_shard_maybe tests
184272
# ============================================================================
@@ -262,7 +350,7 @@ async def test_load_partial_shard_maybe_index_load_fails() -> None:
262350
all_chunk_coords: set[tuple[int, ...]] = {(0,)}
263351

264352
result = await codec._load_partial_shard_maybe(
265-
byte_getter=byte_getter, # type: ignore[arg-type] # mypy false positive: identical signatures
353+
byte_getter=byte_getter,
266354
prototype=default_buffer_prototype(),
267355
chunks_per_shard=chunks_per_shard,
268356
all_chunk_coords=all_chunk_coords,
@@ -299,7 +387,7 @@ async def mock_load_index(
299387
all_chunk_coords: set[tuple[int, ...]] = {(0,), (1,), (2,)}
300388

301389
result = await codec._load_partial_shard_maybe(
302-
byte_getter=byte_getter, # type: ignore[arg-type] # mypy false positive: identical signatures
390+
byte_getter=byte_getter,
303391
prototype=default_buffer_prototype(),
304392
chunks_per_shard=chunks_per_shard,
305393
all_chunk_coords=all_chunk_coords,
@@ -335,7 +423,7 @@ async def mock_load_index(
335423
all_chunk_coords: set[tuple[int, ...]] = {(0,), (1,), (2,)}
336424

337425
result = await codec._load_partial_shard_maybe(
338-
byte_getter=byte_getter, # type: ignore[arg-type] # mypy false positive: identical signatures
426+
byte_getter=byte_getter,
339427
prototype=default_buffer_prototype(),
340428
chunks_per_shard=chunks_per_shard,
341429
all_chunk_coords=all_chunk_coords,
@@ -369,7 +457,7 @@ async def mock_load_index(
369457
all_chunk_coords: set[tuple[int, ...]] = {(0,), (1,)}
370458

371459
result = await codec._load_partial_shard_maybe(
372-
byte_getter=byte_getter, # type: ignore[arg-type] # mypy false positive: identical signatures
460+
byte_getter=byte_getter,
373461
prototype=default_buffer_prototype(),
374462
chunks_per_shard=chunks_per_shard,
375463
all_chunk_coords=all_chunk_coords,
@@ -407,7 +495,7 @@ async def mock_load_index(
407495
all_chunk_coords: set[tuple[int, ...]] = {(1,)}
408496

409497
result = await codec._load_partial_shard_maybe(
410-
byte_getter=byte_getter, # type: ignore[arg-type] # mypy false positive: identical signatures
498+
byte_getter=byte_getter,
411499
prototype=default_buffer_prototype(),
412500
chunks_per_shard=chunks_per_shard,
413501
all_chunk_coords=all_chunk_coords,
@@ -446,7 +534,7 @@ async def mock_load_index(
446534
all_chunk_coords: set[tuple[int, ...]] = {(0,)}
447535

448536
result = await codec._load_partial_shard_maybe(
449-
byte_getter=byte_getter, # type: ignore[arg-type] # mypy false positive: identical signatures
537+
byte_getter=byte_getter,
450538
prototype=default_buffer_prototype(),
451539
chunks_per_shard=chunks_per_shard,
452540
all_chunk_coords=all_chunk_coords,

0 commit comments

Comments
 (0)