Skip to content

Commit cd64c3d

Browse files
d-v-bclaude
andcommitted
test: update read/write plan tests for RangeByteRequest and nested sharding
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 8831672 commit cd64c3d

2 files changed

Lines changed: 312 additions & 147 deletions

File tree

tests/test_read_plan.py

Lines changed: 21 additions & 147 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
55
The model:
66
- A shard (or non-sharded chunk) is a flat key-value space:
7-
``coords → ByteRange`` within one store key.
7+
``coords → RangeByteRequest`` within one store key.
88
- Index resolution (possibly recursive for nested sharding) produces
99
this flat mapping.
1010
- The pipeline then filters to needed coords, fetches those byte ranges,
@@ -17,47 +17,30 @@
1717
from typing import Any
1818

1919
import numpy as np
20+
import pytest
2021

2122
import zarr
23+
from zarr.abc.store import RangeByteRequest
2224

2325
# ---------------------------------------------------------------------------
2426
# Data model
2527
# ---------------------------------------------------------------------------
2628

2729

28-
@dataclass(frozen=True)
29-
class ByteRange:
30-
"""A contiguous byte range within a store value."""
31-
32-
offset: int
33-
length: int
34-
35-
3630
@dataclass(frozen=True)
3731
class ShardIndex:
3832
"""Flat mapping from inner chunk coordinates to byte ranges.
3933
40-
Produced by resolving the shard index (and any nested indexes).
41-
For non-sharded chunks, contains a single entry mapping ``(0,)``
42-
to ``None`` (meaning: read the full value).
43-
44-
Parameters
45-
----------
46-
key : str
47-
The store key for this shard/chunk.
48-
chunks : dict
49-
Mapping from inner chunk coords to their byte range within
50-
the blob at ``key``. A value of ``None`` means the chunk
51-
is absent (fill value).
34+
Uses ``RangeByteRequest`` from ``zarr.abc.store`` for byte ranges.
5235
"""
5336

5437
key: str
55-
chunks: dict[tuple[int, ...], ByteRange | None] = field(default_factory=dict)
38+
chunks: dict[tuple[int, ...], RangeByteRequest | None] = field(default_factory=dict)
5639

5740
@property
5841
def nbytes_data(self) -> int:
5942
"""Total data bytes across all present chunks."""
60-
return sum(r.length for r in self.chunks.values() if r is not None)
43+
return sum(r.end - r.start for r in self.chunks.values() if r is not None)
6144

6245
def filter(self, needed: set[tuple[int, ...]] | None = None) -> ShardIndex:
6346
"""Return a new ShardIndex with only the needed coords."""
@@ -176,10 +159,10 @@ def test_single_inner_chunk(self) -> None:
176159
idx = indices[0]
177160
# Only the needed inner chunk
178161
assert len(idx.chunks) == 1
179-
coords = list(idx.chunks.keys())[0]
162+
coords = next(iter(idx.chunks.keys()))
180163
byte_range = idx.chunks[coords]
181164
assert byte_range is not None
182-
assert byte_range.length == 10
165+
assert byte_range.end - byte_range.start == 10
183166

184167
def test_two_inner_chunks(self) -> None:
185168
arr, _store = _create_and_fill(shape=(100,), chunks=(10,), shards=(100,))
@@ -224,7 +207,7 @@ def test_single_inner_chunk_compressed(self) -> None:
224207
assert len(idx.chunks) == 1
225208
byte_range = list(idx.chunks.values())[0]
226209
assert byte_range is not None
227-
assert byte_range.length > 0
210+
assert byte_range.end - byte_range.start > 0
228211

229212

230213
# ---------------------------------------------------------------------------
@@ -234,7 +217,7 @@ def test_single_inner_chunk_compressed(self) -> None:
234217

235218
class TestNestedShardedIndex:
236219
"""For nested sharding, index resolution recurses through levels
237-
but produces the same flat coords → ByteRange mapping.
220+
but produces the same flat coords → RangeByteRequest mapping.
238221
"""
239222

240223
@staticmethod
@@ -263,7 +246,7 @@ def test_single_leaf_chunk(self) -> None:
263246
assert len(idx.chunks) == 1
264247
byte_range = list(idx.chunks.values())[0]
265248
assert byte_range is not None
266-
assert byte_range.length == 10
249+
assert byte_range.end - byte_range.start == 10
267250

268251
def test_full_inner_shard(self) -> None:
269252
"""One full inner shard (50 bytes = 5 leaf chunks)."""
@@ -299,109 +282,14 @@ def test_all_leaf_chunks(self) -> None:
299282
# ---------------------------------------------------------------------------
300283

301284

302-
def _resolve_shard_index(
303-
layout: Any,
304-
chunk_selection: Any,
305-
shard_blob: Any | None,
306-
base_offset: int = 0,
307-
) -> dict[tuple[int, ...], ByteRange | None]:
308-
"""Recursively resolve a flat coords → ByteRange mapping for a shard.
309-
310-
For fixed-size codecs, byte ranges are computed from coordinates alone
311-
(shard_blob can be None). For variable-size codecs, the shard blob
312-
is needed to read the index. For nested sharding, recurses into
313-
inner shards.
314-
"""
315-
from zarr.codecs.sharding import ShardingCodec
316-
from zarr.core.codec_pipeline import ShardedChunkLayout
317-
318-
needed_coords = layout.needed_coords(chunk_selection)
319-
if needed_coords is None:
320-
return {}
321-
322-
if layout._fixed_size:
323-
chunk_spec = layout.inner_transform.array_spec
324-
chunk_byte_length = layout.inner_chunk_byte_length(chunk_spec)
325-
return {
326-
coords: ByteRange(
327-
offset=base_offset + layout.chunk_byte_offset(coords, chunk_byte_length),
328-
length=chunk_byte_length,
329-
)
330-
for coords in needed_coords
331-
}
332-
333-
# Variable-size: need the blob to read the index
334-
assert shard_blob is not None
335-
chunk_dict = layout.unpack_blob(shard_blob)
336-
337-
# Check for nested sharding
338-
inner_ab = layout.inner_transform._ab_codec
339-
is_nested = isinstance(inner_ab, ShardingCodec)
340-
341-
if not is_nested:
342-
# Leaf level: read byte ranges from the index
343-
from zarr.codecs.sharding import ShardingCodecIndexLocation
344-
345-
if layout._index_location == ShardingCodecIndexLocation.start:
346-
index_bytes = shard_blob[: layout._index_size]
347-
else:
348-
index_bytes = shard_blob[-layout._index_size :]
349-
index = layout._decode_index(index_bytes)
350-
351-
result: dict[tuple[int, ...], ByteRange | None] = {}
352-
for coords in needed_coords:
353-
chunk_slice = index.get_chunk_slice(coords)
354-
if chunk_slice is not None:
355-
start, end = chunk_slice
356-
result[coords] = ByteRange(offset=base_offset + start, length=end - start)
357-
else:
358-
result[coords] = None
359-
return result
360-
361-
# Nested: resolve inner shard indexes and flatten
362-
from zarr.codecs.sharding import ShardingCodecIndexLocation
363-
from zarr.core.chunk_grids import ChunkGrid as _ChunkGrid
364-
from zarr.core.indexing import get_indexer
365-
366-
if layout._index_location == ShardingCodecIndexLocation.start:
367-
index_bytes = shard_blob[: layout._index_size]
368-
else:
369-
index_bytes = shard_blob[-layout._index_size :]
370-
outer_index = layout._decode_index(index_bytes)
371-
372-
inner_indexer = get_indexer(
373-
chunk_selection,
374-
shape=layout.chunk_shape,
375-
chunk_grid=_ChunkGrid.from_sizes(layout.chunk_shape, layout.inner_chunk_shape),
376-
)
377-
378-
inner_spec = layout.inner_transform.array_spec
379-
inner_layout = ShardedChunkLayout.from_sharding_codec(inner_ab, inner_spec)
380-
381-
flat: dict[tuple[int, ...], ByteRange | None] = {}
382-
for inner_coords, inner_sel, _, _ in inner_indexer:
383-
chunk_slice = outer_index.get_chunk_slice(inner_coords)
384-
if chunk_slice is None:
385-
continue
386-
start, end = chunk_slice
387-
inner_blob = shard_blob[start:end]
388-
inner_chunks = _resolve_shard_index(
389-
inner_layout, inner_sel, inner_blob, base_offset=base_offset + start
390-
)
391-
# Prefix leaf coords with the outer coords to make them globally unique
392-
for leaf_coords, byte_range in inner_chunks.items():
393-
flat[inner_coords + leaf_coords] = byte_range
394-
395-
return flat
396-
397-
398285
def _resolve_indices(arr: zarr.Array, selection: Any) -> list[ShardIndex]:
399286
"""Given an array and a selection, resolve ShardIndex for each chunk/shard.
400287
401-
Each ShardIndex is a flat mapping from inner chunk coords to byte ranges,
402-
regardless of how many levels of nesting exist.
288+
Uses the pipeline's ``layout.resolve_index`` to get the flat
289+
coords → RangeByteRequest mapping, then wraps in the test's
290+
ShardIndex (which has extra helper methods).
403291
"""
404-
from zarr.core.codec_pipeline import PhasedCodecPipeline, ShardedChunkLayout
292+
from zarr.core.codec_pipeline import PhasedCodecPipeline
405293
from zarr.core.indexing import BasicIndexer
406294

407295
aa = arr._async_array
@@ -423,27 +311,13 @@ def _resolve_indices(arr: zarr.Array, selection: Any) -> list[ShardIndex]:
423311
continue
424312

425313
layout = pipeline.layout
314+
store_path = aa.store_path / key
426315

427-
if not layout.is_sharded:
428-
indices.append(ShardIndex(key=key, chunks={(0,) * len(chunk_coords): None}))
429-
continue
430-
431-
assert isinstance(layout, ShardedChunkLayout)
432-
433-
if layout._fixed_size:
434-
# No blob needed
435-
chunks = _resolve_shard_index(layout, chunk_selection, shard_blob=None)
436-
else:
437-
# Need the blob to read indexes
438-
from zarr.core.buffer import default_buffer_prototype
439-
440-
store_path = aa.store_path / key
441-
shard_blob = store_path.get_sync(prototype=default_buffer_prototype())
442-
if shard_blob is None:
443-
indices.append(ShardIndex(key=key))
444-
continue
445-
chunks = _resolve_shard_index(layout, chunk_selection, shard_blob)
316+
# Use the pipeline's resolve_index — it handles all cases
317+
# (non-sharded, fixed-size, variable-size)
318+
pipeline_index = layout.resolve_index(store_path, key, chunk_selection=chunk_selection)
446319

447-
indices.append(ShardIndex(key=key, chunks=chunks))
320+
# Convert pipeline ShardIndex to test ShardIndex
321+
indices.append(ShardIndex(key=pipeline_index.key, chunks=dict(pipeline_index.chunks)))
448322

449323
return indices

0 commit comments

Comments
 (0)