Skip to content

Commit f7772cc

Browse files
d-v-b and claude committed
feat: recursive resolve_index for nested sharding with leaf_transform
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 638d57f commit f7772cc

2 files changed

Lines changed: 203 additions & 18 deletions

File tree

src/zarr/core/codec_pipeline.py

Lines changed: 178 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1346,49 +1346,209 @@ def pack_blob(
13461346

13471347
return template.combine(buffers)
13481348

1349+
def _get_leaf_transform(self) -> ChunkTransform:
    """Return the innermost (leaf) transform beneath any nested ShardingCodecs.

    Starting from ``self.inner_transform``, repeatedly check the transform's
    array-to-bytes codec (``_ab_codec``). While that codec is a
    ``ShardingCodec``, build the next-level ``ChunkTransform`` from that
    codec's own codecs, each evolved against the chunk spec the sharding
    codec derives from the current array spec. The first transform whose
    array-to-bytes codec is not a ``ShardingCodec`` is returned.
    """
    # Imported lazily, as in the rest of this module's sharding helpers.
    from zarr.codecs.sharding import ShardingCodec

    current = self.inner_transform
    while True:
        codec = current._ab_codec
        if not isinstance(codec, ShardingCodec):
            # Reached a non-sharding serializer: this is the leaf level.
            return current
        # Descend one level: the nested codec's chunks become the new "array".
        chunk_spec = codec._get_chunk_spec(current.array_spec)
        evolved = tuple(
            member.evolve_from_array_spec(array_spec=chunk_spec) for member in codec.codecs
        )
        current = ChunkTransform(codecs=evolved, array_spec=chunk_spec)
1362+
1363+
def _fetch_index_from_blob(self, blob: Buffer) -> Any:
    """Decode the shard index stored inside an already-loaded shard blob.

    The index occupies ``self._index_size`` bytes at the start of ``blob``
    when ``self._index_location`` is ``ShardingCodecIndexLocation.start``,
    and the trailing ``self._index_size`` bytes otherwise. The selected
    slice is handed to ``self._decode_index`` and its result is returned
    unchanged.
    """
    from zarr.codecs.sharding import ShardingCodecIndexLocation

    index_at_start = self._index_location == ShardingCodecIndexLocation.start
    size = self._index_size
    raw_index = blob[:size] if index_at_start else blob[-size:]
    return self._decode_index(raw_index)
1372+
13491373
# -- Phase 1: resolve index --
13501374

13511375
def resolve_index(self, byte_getter: Any, key: str, chunk_selection: SelectorTuple | None = None) -> ShardIndex:
    """Resolve the byte ranges of the chunks of one shard (synchronous).

    Parameters
    ----------
    byte_getter
        Object used to read the shard; it is passed to
        ``self._fetch_index_sync`` and, for nested sharding, its
        ``get_sync`` method is called with a byte-range request.
    key
        Key stored on the returned ``ShardIndex``.
    chunk_selection
        Selection within the shard. ``None`` means the whole shard.

    Returns
    -------
    ShardIndex
        * If no shard index could be fetched: an index without ``chunks``
          whose ``leaf_transform`` is the innermost transform.
        * If the inner array-to-bytes codec is not a ``ShardingCodec``:
          one entry per needed chunk coordinate, mapping to a
          ``RangeByteRequest`` or ``None`` when the index has no slice
          for that chunk; ``leaf_transform`` is ``self.inner_transform``.
        * Otherwise (nested sharding): every inner shard overlapping the
          selection is fetched, its own index parsed, and its leaf chunks
          are reported under flattened coordinates with byte offsets made
          absolute by adding the inner shard's start offset. Inner shards
          missing from the outer index, or whose fetch returns ``None``,
          contribute no entries.

    NOTE(review): only one level of nesting is flattened here, while
    ``_get_leaf_transform`` walks every nesting level -- confirm whether
    deeper nesting is expected to be supported.
    """
    from zarr.abc.store import RangeByteRequest
    from zarr.codecs.sharding import ShardingCodec

    shard_index = self._fetch_index_sync(byte_getter)
    if shard_index is None:
        return ShardIndex(key=key, leaf_transform=self._get_leaf_transform(), is_sharded=True)

    inner_ab = self.inner_transform._ab_codec
    if not isinstance(inner_ab, ShardingCodec):
        # Non-nested: the shard index points straight at the leaf chunks.
        chunks = self._direct_chunk_ranges(shard_index, chunk_selection)
        return ShardIndex(key=key, chunks=chunks, leaf_transform=self.inner_transform, is_sharded=True)

    # NESTED sharding
    from zarr.core.buffer import default_buffer_prototype

    leaf_transform = self._get_leaf_transform()
    # Layout describing the nested ShardingCodec's own index and chunk grid.
    inner_layout = ShardedChunkLayout.from_sharding_codec(inner_ab, self.inner_transform.array_spec)

    flat: dict[tuple[int, ...], RangeByteRequest | None] = {}
    for inner_coords, inner_sel, _, _ in self._inner_shard_indexer(chunk_selection):
        chunk_slice = shard_index.get_chunk_slice(inner_coords)
        if chunk_slice is None:
            continue
        start, end = chunk_slice

        # Fetch the inner shard blob
        inner_blob = byte_getter.get_sync(
            prototype=default_buffer_prototype(),
            byte_range=RangeByteRequest(start, end),
        )
        if inner_blob is None:
            continue

        flat.update(self._flatten_inner_shard(inner_layout, inner_blob, inner_coords, inner_sel, start))

    return ShardIndex(key=key, chunks=flat, leaf_transform=leaf_transform, is_sharded=True)


async def resolve_index_async(self, byte_getter: Any, key: str, chunk_selection: SelectorTuple | None = None) -> ShardIndex:
    """Resolve the byte ranges of the chunks of one shard (asynchronous).

    Behaves like ``resolve_index`` but awaits ``self._fetch_index`` for
    the shard index and ``byte_getter.get`` for each inner shard blob.
    See ``resolve_index`` for parameters and the shape of the result.
    """
    from zarr.abc.store import RangeByteRequest
    from zarr.codecs.sharding import ShardingCodec

    shard_index = await self._fetch_index(byte_getter)
    if shard_index is None:
        return ShardIndex(key=key, leaf_transform=self._get_leaf_transform(), is_sharded=True)

    inner_ab = self.inner_transform._ab_codec
    if not isinstance(inner_ab, ShardingCodec):
        # Non-nested: the shard index points straight at the leaf chunks.
        chunks = self._direct_chunk_ranges(shard_index, chunk_selection)
        return ShardIndex(key=key, chunks=chunks, leaf_transform=self.inner_transform, is_sharded=True)

    # NESTED sharding
    from zarr.core.buffer import default_buffer_prototype

    leaf_transform = self._get_leaf_transform()
    # Layout describing the nested ShardingCodec's own index and chunk grid.
    inner_layout = ShardedChunkLayout.from_sharding_codec(inner_ab, self.inner_transform.array_spec)

    flat: dict[tuple[int, ...], RangeByteRequest | None] = {}
    for inner_coords, inner_sel, _, _ in self._inner_shard_indexer(chunk_selection):
        chunk_slice = shard_index.get_chunk_slice(inner_coords)
        if chunk_slice is None:
            continue
        start, end = chunk_slice

        # Fetch the inner shard blob
        inner_blob = await byte_getter.get(
            prototype=default_buffer_prototype(),
            byte_range=RangeByteRequest(start, end),
        )
        if inner_blob is None:
            continue

        flat.update(self._flatten_inner_shard(inner_layout, inner_blob, inner_coords, inner_sel, start))

    return ShardIndex(key=key, chunks=flat, leaf_transform=leaf_transform, is_sharded=True)


def _direct_chunk_ranges(self, shard_index: Any, chunk_selection: SelectorTuple | None) -> dict[tuple[int, ...], Any]:
    """Map each needed chunk coordinate to its byte range in a non-nested shard.

    Values are ``RangeByteRequest`` objects, or ``None`` when
    ``shard_index`` has no slice for the coordinate.
    """
    from zarr.abc.store import RangeByteRequest

    needed = self.needed_coords(chunk_selection) if chunk_selection is not None else None
    if needed is None:
        # ``needed_coords`` may return None; treat it as "every chunk of the
        # shard", matching the fallback used for inner shards in the nested path.
        needed = set(np.ndindex(self.chunks_per_shard))

    chunks: dict[tuple[int, ...], Any] = {}
    for coord in needed:
        chunk_slice = shard_index.get_chunk_slice(coord)
        chunks[coord] = (
            None if chunk_slice is None else RangeByteRequest(chunk_slice[0], chunk_slice[1])
        )
    return chunks


def _inner_shard_indexer(self, chunk_selection: SelectorTuple | None) -> Any:
    """Build the indexer that yields the inner shards overlapping ``chunk_selection``.

    ``None`` selects the full ``self.chunk_shape``. The grid is made of
    ``self.inner_chunk_shape``-sized cells over ``self.chunk_shape``.
    """
    from zarr.core.chunk_grids import ChunkGrid as _ChunkGrid
    from zarr.core.indexing import get_indexer

    sel = chunk_selection if chunk_selection is not None else tuple(
        slice(0, s) for s in self.chunk_shape
    )
    return get_indexer(
        sel,
        shape=self.chunk_shape,
        chunk_grid=_ChunkGrid.from_sizes(self.chunk_shape, self.inner_chunk_shape),
    )


def _flatten_inner_shard(
    self,
    inner_layout: "ShardedChunkLayout",
    inner_blob: Buffer,
    inner_coords: tuple[int, ...],
    inner_sel: SelectorTuple,
    base_offset: int,
) -> dict[tuple[int, ...], Any]:
    """Translate one inner shard's index into flattened coords and absolute ranges.

    Performs no I/O: ``inner_blob`` must already hold the inner shard's
    bytes, and ``base_offset`` is the offset at which that blob starts.
    Returns an empty dict when the inner index parses to ``None``.
    """
    from zarr.abc.store import RangeByteRequest

    # Parse inner shard index
    inner_index = inner_layout._fetch_index_from_blob(inner_blob)
    if inner_index is None:
        return {}

    # Determine which leaf chunks within this inner shard are needed
    inner_needed = inner_layout.needed_coords(inner_sel)
    if inner_needed is None:
        inner_needed = set(np.ndindex(inner_layout.chunks_per_shard))

    # Translate coords and byte ranges
    translated: dict[tuple[int, ...], Any] = {}
    for leaf_coord in inner_needed:
        leaf_slice = inner_index.get_chunk_slice(leaf_coord)
        # Flattened coordinate = inner-shard coordinate scaled by the number
        # of leaf chunks per inner shard, plus the leaf's local coordinate.
        global_coord = tuple(
            ic * cps + lc
            for ic, cps, lc in zip(
                inner_coords, inner_layout.chunks_per_shard, leaf_coord, strict=True
            )
        )
        if leaf_slice is None:
            translated[global_coord] = None
        else:
            # Leaf offsets are relative to the inner shard blob; shift them by
            # the blob's own start offset.
            translated[global_coord] = RangeByteRequest(
                base_offset + leaf_slice[0], base_offset + leaf_slice[1]
            )
    return translated
13921552

13931553
# -- Phase 2: fetch chunk data --
13941554

tests/test_codec_pipeline.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,3 +366,28 @@ async def test_read_missing_chunks_true_fills(pipeline_class: str) -> None:
366366
)
367367
# Don't write anything
368368
np.testing.assert_array_equal(arr[:], np.full(20, -999.0))
369+
370+
371+
async def test_nested_sharding_roundtrip(pipeline_class: str) -> None:
    """Nested sharding: data survives write/read roundtrip."""
    # NOTE(review): ``pipeline_class`` is not referenced in the body;
    # presumably a fixture that selects the pipeline under test -- confirm.
    from zarr.codecs.bytes import BytesCodec
    from zarr.codecs.sharding import ShardingCodec

    # Two sharding levels: the outer codec uses chunk_shape (50,) and encodes
    # those chunks with a second ShardingCodec whose chunk_shape is (10,).
    leaf_level = ShardingCodec(chunk_shape=(10,), codecs=[BytesCodec()])
    top_level = ShardingCodec(chunk_shape=(50,), codecs=[leaf_level])

    expected = np.arange(100, dtype="uint8")
    arr = zarr.create_array(
        store=MemoryStore(),
        shape=(100,),
        dtype="uint8",
        chunks=(100,),
        compressors=None,
        fill_value=0,
        serializer=top_level,
    )

    arr[:] = expected
    # Full read
    np.testing.assert_array_equal(arr[:], expected)
    # Partial read; 40:60 straddles index 50, the edge between the two
    # 50-element chunks of the outer codec.
    np.testing.assert_array_equal(arr[40:60], expected[40:60])

0 commit comments

Comments
 (0)