@@ -1346,49 +1346,209 @@ def pack_blob(
13461346
13471347 return template .combine (buffers )
13481348
1349+ def _get_leaf_transform (self ) -> ChunkTransform :
1350+ """Get the innermost (leaf) transform, traversing nested ShardingCodecs."""
1351+ from zarr .codecs .sharding import ShardingCodec
1352+
1353+ transform = self .inner_transform
1354+ while isinstance (transform ._ab_codec , ShardingCodec ):
1355+ inner_sc = transform ._ab_codec
1356+ inner_spec = inner_sc ._get_chunk_spec (transform .array_spec )
1357+ inner_evolved = tuple (
1358+ c .evolve_from_array_spec (array_spec = inner_spec ) for c in inner_sc .codecs
1359+ )
1360+ transform = ChunkTransform (codecs = inner_evolved , array_spec = inner_spec )
1361+ return transform
1362+
1363+ def _fetch_index_from_blob (self , blob : Buffer ) -> Any :
1364+ """Parse the shard index from an in-memory blob."""
1365+ from zarr .codecs .sharding import ShardingCodecIndexLocation
1366+
1367+ if self ._index_location == ShardingCodecIndexLocation .start :
1368+ index_bytes = blob [: self ._index_size ]
1369+ else :
1370+ index_bytes = blob [- self ._index_size :]
1371+ return self ._decode_index (index_bytes )
1372+
13491373 # -- Phase 1: resolve index --
13501374
13511375 def resolve_index (self , byte_getter : Any , key : str , chunk_selection : SelectorTuple | None = None ) -> ShardIndex :
13521376 from zarr .abc .store import RangeByteRequest
1377+ from zarr .codecs .sharding import ShardingCodec
13531378
13541379 shard_index = self ._fetch_index_sync (byte_getter )
13551380 if shard_index is None :
1356- return ShardIndex (key = key , leaf_transform = self .inner_transform , is_sharded = True )
1381+ return ShardIndex (key = key , leaf_transform = self ._get_leaf_transform () , is_sharded = True )
13571382
13581383 if chunk_selection is not None :
13591384 needed = self .needed_coords (chunk_selection )
13601385 else :
13611386 needed = set (np .ndindex (self .chunks_per_shard ))
13621387
1363- chunks : dict [tuple [int , ...], RangeByteRequest | None ] = {}
1364- for coord in needed : # type: ignore[union-attr]
1365- chunk_slice = shard_index .get_chunk_slice (coord )
1366- if chunk_slice is not None :
1367- chunks [coord ] = RangeByteRequest (chunk_slice [0 ], chunk_slice [1 ])
1368- else :
1369- chunks [coord ] = None
1370- return ShardIndex (key = key , chunks = chunks , leaf_transform = self .inner_transform , is_sharded = True )
1388+ inner_ab = self .inner_transform ._ab_codec
1389+ if not isinstance (inner_ab , ShardingCodec ):
1390+ # Non-nested: same as before
1391+ chunks : dict [tuple [int , ...], RangeByteRequest | None ] = {}
1392+ for coord in needed : # type: ignore[union-attr]
1393+ chunk_slice = shard_index .get_chunk_slice (coord )
1394+ if chunk_slice is not None :
1395+ chunks [coord ] = RangeByteRequest (chunk_slice [0 ], chunk_slice [1 ])
1396+ else :
1397+ chunks [coord ] = None
1398+ return ShardIndex (key = key , chunks = chunks , leaf_transform = self .inner_transform , is_sharded = True )
1399+
1400+ # NESTED sharding
1401+ from zarr .core .buffer import default_buffer_prototype
1402+ from zarr .core .chunk_grids import ChunkGrid as _ChunkGrid
1403+ from zarr .core .indexing import get_indexer
1404+
1405+ leaf_transform = self ._get_leaf_transform ()
1406+
1407+ # Build inner layout for the nested ShardingCodec
1408+ inner_spec = self .inner_transform .array_spec
1409+ inner_layout = ShardedChunkLayout .from_sharding_codec (inner_ab , inner_spec )
1410+
1411+ # Build inner indexer to determine which inner shards overlap selection
1412+ sel = chunk_selection if chunk_selection is not None else tuple (
1413+ slice (0 , s ) for s in self .chunk_shape
1414+ )
1415+ inner_indexer = get_indexer (
1416+ sel ,
1417+ shape = self .chunk_shape ,
1418+ chunk_grid = _ChunkGrid .from_sizes (self .chunk_shape , self .inner_chunk_shape ),
1419+ )
1420+
1421+ flat : dict [tuple [int , ...], RangeByteRequest | None ] = {}
1422+ for inner_coords , inner_sel , _ , _ in inner_indexer :
1423+ chunk_slice = shard_index .get_chunk_slice (inner_coords )
1424+ if chunk_slice is None :
1425+ continue
1426+ start , end = chunk_slice
1427+
1428+ # Fetch the inner shard blob
1429+ inner_blob = byte_getter .get_sync (
1430+ prototype = default_buffer_prototype (),
1431+ byte_range = RangeByteRequest (start , end ),
1432+ )
1433+ if inner_blob is None :
1434+ continue
1435+
1436+ # Parse inner shard index
1437+ inner_index = inner_layout ._fetch_index_from_blob (inner_blob )
1438+ if inner_index is None :
1439+ continue
1440+
1441+ # Determine which leaf chunks within this inner shard are needed
1442+ inner_needed = inner_layout .needed_coords (inner_sel )
1443+ if inner_needed is None :
1444+ inner_needed = set (np .ndindex (inner_layout .chunks_per_shard ))
1445+
1446+ # Translate coords and byte ranges
1447+ for leaf_coord in inner_needed :
1448+ leaf_slice = inner_index .get_chunk_slice (leaf_coord )
1449+ global_coord = tuple (
1450+ ic * cps + lc
1451+ for ic , cps , lc in zip (
1452+ inner_coords , inner_layout .chunks_per_shard , leaf_coord , strict = True
1453+ )
1454+ )
1455+ if leaf_slice is not None :
1456+ abs_start = start + leaf_slice [0 ]
1457+ abs_end = start + leaf_slice [1 ]
1458+ flat [global_coord ] = RangeByteRequest (abs_start , abs_end )
1459+ else :
1460+ flat [global_coord ] = None
1461+
1462+ return ShardIndex (key = key , chunks = flat , leaf_transform = leaf_transform , is_sharded = True )
13711463
13721464 async def resolve_index_async (self , byte_getter : Any , key : str , chunk_selection : SelectorTuple | None = None ) -> ShardIndex :
13731465 from zarr .abc .store import RangeByteRequest
1466+ from zarr .codecs .sharding import ShardingCodec
13741467
13751468 shard_index = await self ._fetch_index (byte_getter )
13761469 if shard_index is None :
1377- return ShardIndex (key = key , leaf_transform = self .inner_transform , is_sharded = True )
1470+ return ShardIndex (key = key , leaf_transform = self ._get_leaf_transform () , is_sharded = True )
13781471
13791472 if chunk_selection is not None :
13801473 needed = self .needed_coords (chunk_selection )
13811474 else :
13821475 needed = set (np .ndindex (self .chunks_per_shard ))
13831476
1384- chunks : dict [tuple [int , ...], RangeByteRequest | None ] = {}
1385- for coord in needed : # type: ignore[union-attr]
1386- chunk_slice = shard_index .get_chunk_slice (coord )
1387- if chunk_slice is not None :
1388- chunks [coord ] = RangeByteRequest (chunk_slice [0 ], chunk_slice [1 ])
1389- else :
1390- chunks [coord ] = None
1391- return ShardIndex (key = key , chunks = chunks , leaf_transform = self .inner_transform , is_sharded = True )
1477+ inner_ab = self .inner_transform ._ab_codec
1478+ if not isinstance (inner_ab , ShardingCodec ):
1479+ # Non-nested: same as before
1480+ chunks : dict [tuple [int , ...], RangeByteRequest | None ] = {}
1481+ for coord in needed : # type: ignore[union-attr]
1482+ chunk_slice = shard_index .get_chunk_slice (coord )
1483+ if chunk_slice is not None :
1484+ chunks [coord ] = RangeByteRequest (chunk_slice [0 ], chunk_slice [1 ])
1485+ else :
1486+ chunks [coord ] = None
1487+ return ShardIndex (key = key , chunks = chunks , leaf_transform = self .inner_transform , is_sharded = True )
1488+
1489+ # NESTED sharding
1490+ from zarr .core .buffer import default_buffer_prototype
1491+ from zarr .core .chunk_grids import ChunkGrid as _ChunkGrid
1492+ from zarr .core .indexing import get_indexer
1493+
1494+ leaf_transform = self ._get_leaf_transform ()
1495+
1496+ # Build inner layout for the nested ShardingCodec
1497+ inner_spec = self .inner_transform .array_spec
1498+ inner_layout = ShardedChunkLayout .from_sharding_codec (inner_ab , inner_spec )
1499+
1500+ # Build inner indexer to determine which inner shards overlap selection
1501+ sel = chunk_selection if chunk_selection is not None else tuple (
1502+ slice (0 , s ) for s in self .chunk_shape
1503+ )
1504+ inner_indexer = get_indexer (
1505+ sel ,
1506+ shape = self .chunk_shape ,
1507+ chunk_grid = _ChunkGrid .from_sizes (self .chunk_shape , self .inner_chunk_shape ),
1508+ )
1509+
1510+ flat : dict [tuple [int , ...], RangeByteRequest | None ] = {}
1511+ for inner_coords , inner_sel , _ , _ in inner_indexer :
1512+ chunk_slice = shard_index .get_chunk_slice (inner_coords )
1513+ if chunk_slice is None :
1514+ continue
1515+ start , end = chunk_slice
1516+
1517+ # Fetch the inner shard blob
1518+ inner_blob = await byte_getter .get (
1519+ prototype = default_buffer_prototype (),
1520+ byte_range = RangeByteRequest (start , end ),
1521+ )
1522+ if inner_blob is None :
1523+ continue
1524+
1525+ # Parse inner shard index
1526+ inner_index = inner_layout ._fetch_index_from_blob (inner_blob )
1527+ if inner_index is None :
1528+ continue
1529+
1530+ # Determine which leaf chunks within this inner shard are needed
1531+ inner_needed = inner_layout .needed_coords (inner_sel )
1532+ if inner_needed is None :
1533+ inner_needed = set (np .ndindex (inner_layout .chunks_per_shard ))
1534+
1535+ # Translate coords and byte ranges
1536+ for leaf_coord in inner_needed :
1537+ leaf_slice = inner_index .get_chunk_slice (leaf_coord )
1538+ global_coord = tuple (
1539+ ic * cps + lc
1540+ for ic , cps , lc in zip (
1541+ inner_coords , inner_layout .chunks_per_shard , leaf_coord , strict = True
1542+ )
1543+ )
1544+ if leaf_slice is not None :
1545+ abs_start = start + leaf_slice [0 ]
1546+ abs_end = start + leaf_slice [1 ]
1547+ flat [global_coord ] = RangeByteRequest (abs_start , abs_end )
1548+ else :
1549+ flat [global_coord ] = None
1550+
1551+ return ShardIndex (key = key , chunks = flat , leaf_transform = leaf_transform , is_sharded = True )
13921552
13931553 # -- Phase 2: fetch chunk data --
13941554
0 commit comments