@@ -844,13 +844,7 @@ def decode_chunks_from_index(
844844 order = chunk_spec .order ,
845845 )
846846
847- # Non-sharded: the transform covers the entire chunk (or more, for rectilinear edges).
848- # Sharded: the transform's inner shape is smaller than chunk_spec in at least one dim.
849- is_simple = all (
850- t >= c for t , c in zip (transform .array_spec .shape , chunk_spec .shape , strict = True )
851- )
852-
853- if is_simple :
847+ if not index .is_sharded :
854848 assert len (raw_chunks ) == 1
855849 raw = next (iter (raw_chunks .values ()))
856850 if raw is None :
@@ -1048,6 +1042,16 @@ class ChunkLayout:
10481042 def is_sharded (self ) -> bool :
10491043 return False
10501044
1045+ @property
1046+ def leaf_transform (self ) -> ChunkTransform :
1047+ """The codec chain that decodes individual leaf chunks.
1048+
1049+ For non-sharded layouts, this is the full transform.
1050+ For sharded layouts, this traverses nested ShardingCodecs to
1051+ find the innermost codec chain.
1052+ """
1053+ return self .inner_transform
1054+
10511055 def needed_coords (self , chunk_selection : SelectorTuple ) -> set [tuple [int , ...]] | None :
10521056 return None
10531057
@@ -1069,9 +1073,6 @@ async def pack_and_store_async(self, byte_setter: Any, encoded_chunks: dict[tupl
10691073
10701074 # -- Low-level helpers --
10711075
1072- def unpack_blob (self , blob : Buffer ) -> dict [tuple [int , ...], Buffer | None ]:
1073- raise NotImplementedError
1074-
10751076 def pack_blob (
10761077 self , chunk_dict : dict [tuple [int , ...], Buffer | None ], prototype : BufferPrototype
10771078 ) -> Buffer | None :
@@ -1116,10 +1117,6 @@ async def pack_and_store_async(self, byte_setter: Any, encoded_chunks: dict[tupl
11161117
11171118 # -- Low-level --
11181119
1119- def unpack_blob (self , blob : Buffer ) -> dict [tuple [int , ...], Buffer | None ]:
1120- key = (0 ,) * len (self .chunks_per_shard )
1121- return {key : blob }
1122-
11231120 def pack_blob (
11241121 self , chunk_dict : dict [tuple [int , ...], Buffer | None ], prototype : BufferPrototype
11251122 ) -> Buffer | None :
@@ -1172,6 +1169,10 @@ def supports_partial_write(self) -> bool:
11721169 """True when inner codecs are fixed-size, enabling byte-range writes."""
11731170 return self ._fixed_size
11741171
1172+ @property
1173+ def leaf_transform (self ) -> ChunkTransform :
1174+ return self ._get_leaf_transform ()
1175+
11751176 def _decode_index (self , index_bytes : Buffer ) -> Any :
11761177 from zarr .codecs .sharding import _ShardIndex
11771178
@@ -1186,24 +1187,6 @@ def _encode_index(self, index: Any) -> Buffer:
11861187 assert result is not None
11871188 return result
11881189
1189- def unpack_blob (self , blob : Buffer ) -> dict [tuple [int , ...], Buffer | None ]:
1190- from zarr .codecs .sharding import ShardingCodecIndexLocation
1191-
1192- if self ._index_location == ShardingCodecIndexLocation .start :
1193- index_bytes = blob [: self ._index_size ]
1194- else :
1195- index_bytes = blob [- self ._index_size :]
1196-
1197- index = self ._decode_index (index_bytes )
1198- result : dict [tuple [int , ...], Buffer | None ] = {}
1199- for chunk_coords in np .ndindex (self .chunks_per_shard ):
1200- chunk_slice = index .get_chunk_slice (chunk_coords )
1201- if chunk_slice is not None :
1202- result [chunk_coords ] = blob [chunk_slice [0 ] : chunk_slice [1 ]]
1203- else :
1204- result [chunk_coords ] = None
1205- return result
1206-
12071190 def pack_blob (
12081191 self , chunk_dict : dict [tuple [int , ...], Buffer | None ], prototype : BufferPrototype
12091192 ) -> Buffer | None :
@@ -1444,25 +1427,82 @@ async def resolve_index_async(self, byte_getter: Any, key: str, chunk_selection:
14441427
14451428 # -- pack and store --
14461429
1430+ def _pack_nested (
1431+ self ,
1432+ encoded_chunks : dict [tuple [int , ...], Buffer | None ],
1433+ ) -> Buffer | None :
1434+ """Pack flat leaf chunks into a nested shard blob.
1435+
1436+ Groups leaf chunks by inner shard, packs each group into an
1437+ inner shard blob, then packs inner shard blobs into the outer
1438+ shard blob.
1439+ """
1440+ from zarr .codecs .sharding import ShardingCodec
1441+ from zarr .core .buffer import default_buffer_prototype
1442+
1443+ inner_ab = self .inner_transform ._ab_codec
1444+ assert isinstance (inner_ab , ShardingCodec )
1445+
1446+ inner_spec = self .inner_transform .array_spec
1447+ inner_layout = ShardedChunkLayout .from_sharding_codec (inner_ab , inner_spec )
1448+ inner_cps = inner_layout .chunks_per_shard
1449+
1450+ # Group leaf coords by inner shard
1451+ groups : dict [tuple [int , ...], dict [tuple [int , ...], Buffer | None ]] = {}
1452+ for global_coord , chunk_bytes in encoded_chunks .items ():
1453+ inner_shard_coord = tuple (gc // cps for gc , cps in zip (global_coord , inner_cps , strict = True ))
1454+ leaf_coord = tuple (gc % cps for gc , cps in zip (global_coord , inner_cps , strict = True ))
1455+ if inner_shard_coord not in groups :
1456+ groups [inner_shard_coord ] = {}
1457+ groups [inner_shard_coord ][leaf_coord ] = chunk_bytes
1458+
1459+ # Pack each group into an inner shard blob
1460+ proto = default_buffer_prototype ()
1461+ inner_shard_blobs : dict [tuple [int , ...], Buffer | None ] = {}
1462+ for inner_shard_coord in np .ndindex (self .chunks_per_shard ):
1463+ group = groups .get (inner_shard_coord , {})
1464+ # Fill missing leaf coords with None
1465+ for lc in np .ndindex (inner_cps ):
1466+ if lc not in group :
1467+ group [lc ] = None
1468+ inner_blob = inner_layout .pack_blob (group , proto )
1469+ inner_shard_blobs [inner_shard_coord ] = inner_blob
1470+
1471+ # Pack inner shard blobs into outer shard
1472+ return self .pack_blob (inner_shard_blobs , proto )
1473+
14471474 def pack_and_store_sync (self , byte_setter : Any , encoded_chunks : dict [tuple [int , ...], Buffer | None ]) -> None :
1475+ from zarr .codecs .sharding import ShardingCodec
14481476 from zarr .core .buffer import default_buffer_prototype
14491477
14501478 if all (v is None for v in encoded_chunks .values ()):
14511479 byte_setter .delete_sync () # type: ignore[attr-defined]
14521480 return
1453- blob = self .pack_blob (encoded_chunks , default_buffer_prototype ())
1481+
1482+ # Check for nested sharding
1483+ if isinstance (self .inner_transform ._ab_codec , ShardingCodec ):
1484+ blob = self ._pack_nested (encoded_chunks )
1485+ else :
1486+ blob = self .pack_blob (encoded_chunks , default_buffer_prototype ())
1487+
14541488 if blob is None :
14551489 byte_setter .delete_sync () # type: ignore[attr-defined]
14561490 else :
14571491 byte_setter .set_sync (blob ) # type: ignore[attr-defined]
14581492
14591493 async def pack_and_store_async (self , byte_setter : Any , encoded_chunks : dict [tuple [int , ...], Buffer | None ]) -> None :
1494+ from zarr .codecs .sharding import ShardingCodec
14601495 from zarr .core .buffer import default_buffer_prototype
14611496
14621497 if all (v is None for v in encoded_chunks .values ()):
14631498 await byte_setter .delete ()
14641499 return
1465- blob = self .pack_blob (encoded_chunks , default_buffer_prototype ())
1500+
1501+ if isinstance (self .inner_transform ._ab_codec , ShardingCodec ):
1502+ blob = self ._pack_nested (encoded_chunks )
1503+ else :
1504+ blob = self .pack_blob (encoded_chunks , default_buffer_prototype ())
1505+
14661506 if blob is None :
14671507 await byte_setter .delete ()
14681508 else :
@@ -1852,7 +1892,7 @@ async def _process_chunk(
18521892
18531893 # Phase 1: resolve index (IO)
18541894 if is_complete :
1855- index = ShardIndex (key = key , leaf_transform = layout .inner_transform , is_sharded = layout .is_sharded )
1895+ index = ShardIndex (key = key , leaf_transform = layout .leaf_transform , is_sharded = layout .is_sharded )
18561896 elif layout .is_sharded :
18571897 async with sem :
18581898 index = await layout .resolve_index_async (byte_setter , key , chunk_selection = None ) # ALL coords
@@ -1974,7 +2014,7 @@ def write_sync(
19742014
19752015 # Phase 1: resolve index
19762016 if is_complete :
1977- index = ShardIndex (key = key , leaf_transform = layout .inner_transform , is_sharded = layout .is_sharded )
2017+ index = ShardIndex (key = key , leaf_transform = layout .leaf_transform , is_sharded = layout .is_sharded )
19782018 elif layout .is_sharded :
19792019 index = layout .resolve_index (bs , key , chunk_selection = None ) # ALL coords
19802020 else :
0 commit comments