Skip to content

Commit ed2e230

Browse files
d-v-b and claude committed
fix: nested sharding write + leaf_transform property + remove dead unpack_blob
- Add leaf_transform property to ChunkLayout base class (returns inner_transform) and override it on ShardedChunkLayout (traverses nested ShardingCodecs to find the innermost codec chain).
- Fix the write path's complete-overwrite case to use layout.leaf_transform instead of layout.inner_transform (it was using the wrong transform for nested sharding).
- Fix decode_chunks_from_index to use index.is_sharded instead of the fragile shape-based is_simple heuristic.
- Add _pack_nested to ShardedChunkLayout: groups flat leaf chunks by inner shard, packs each group into an inner shard blob, then packs those blobs into the outer shard, producing the correct nested shard structure.
- Remove dead unpack_blob from all layout classes.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent ae48c67 commit ed2e230

1 file changed

Lines changed: 76 additions & 36 deletions

File tree

src/zarr/core/codec_pipeline.py

Lines changed: 76 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -844,13 +844,7 @@ def decode_chunks_from_index(
844844
order=chunk_spec.order,
845845
)
846846

847-
# Non-sharded: the transform covers the entire chunk (or more, for rectilinear edges).
848-
# Sharded: the transform's inner shape is smaller than chunk_spec in at least one dim.
849-
is_simple = all(
850-
t >= c for t, c in zip(transform.array_spec.shape, chunk_spec.shape, strict=True)
851-
)
852-
853-
if is_simple:
847+
if not index.is_sharded:
854848
assert len(raw_chunks) == 1
855849
raw = next(iter(raw_chunks.values()))
856850
if raw is None:
@@ -1048,6 +1042,16 @@ class ChunkLayout:
10481042
def is_sharded(self) -> bool:
10491043
return False
10501044

1045+
@property
1046+
def leaf_transform(self) -> ChunkTransform:
1047+
"""The codec chain that decodes individual leaf chunks.
1048+
1049+
For non-sharded layouts, this is the full transform.
1050+
For sharded layouts, this traverses nested ShardingCodecs to
1051+
find the innermost codec chain.
1052+
"""
1053+
return self.inner_transform
1054+
10511055
def needed_coords(self, chunk_selection: SelectorTuple) -> set[tuple[int, ...]] | None:
10521056
return None
10531057

@@ -1069,9 +1073,6 @@ async def pack_and_store_async(self, byte_setter: Any, encoded_chunks: dict[tupl
10691073

10701074
# -- Low-level helpers --
10711075

1072-
def unpack_blob(self, blob: Buffer) -> dict[tuple[int, ...], Buffer | None]:
1073-
raise NotImplementedError
1074-
10751076
def pack_blob(
10761077
self, chunk_dict: dict[tuple[int, ...], Buffer | None], prototype: BufferPrototype
10771078
) -> Buffer | None:
@@ -1116,10 +1117,6 @@ async def pack_and_store_async(self, byte_setter: Any, encoded_chunks: dict[tupl
11161117

11171118
# -- Low-level --
11181119

1119-
def unpack_blob(self, blob: Buffer) -> dict[tuple[int, ...], Buffer | None]:
1120-
key = (0,) * len(self.chunks_per_shard)
1121-
return {key: blob}
1122-
11231120
def pack_blob(
11241121
self, chunk_dict: dict[tuple[int, ...], Buffer | None], prototype: BufferPrototype
11251122
) -> Buffer | None:
@@ -1172,6 +1169,10 @@ def supports_partial_write(self) -> bool:
11721169
"""True when inner codecs are fixed-size, enabling byte-range writes."""
11731170
return self._fixed_size
11741171

1172+
@property
1173+
def leaf_transform(self) -> ChunkTransform:
1174+
return self._get_leaf_transform()
1175+
11751176
def _decode_index(self, index_bytes: Buffer) -> Any:
11761177
from zarr.codecs.sharding import _ShardIndex
11771178

@@ -1186,24 +1187,6 @@ def _encode_index(self, index: Any) -> Buffer:
11861187
assert result is not None
11871188
return result
11881189

1189-
def unpack_blob(self, blob: Buffer) -> dict[tuple[int, ...], Buffer | None]:
1190-
from zarr.codecs.sharding import ShardingCodecIndexLocation
1191-
1192-
if self._index_location == ShardingCodecIndexLocation.start:
1193-
index_bytes = blob[: self._index_size]
1194-
else:
1195-
index_bytes = blob[-self._index_size :]
1196-
1197-
index = self._decode_index(index_bytes)
1198-
result: dict[tuple[int, ...], Buffer | None] = {}
1199-
for chunk_coords in np.ndindex(self.chunks_per_shard):
1200-
chunk_slice = index.get_chunk_slice(chunk_coords)
1201-
if chunk_slice is not None:
1202-
result[chunk_coords] = blob[chunk_slice[0] : chunk_slice[1]]
1203-
else:
1204-
result[chunk_coords] = None
1205-
return result
1206-
12071190
def pack_blob(
12081191
self, chunk_dict: dict[tuple[int, ...], Buffer | None], prototype: BufferPrototype
12091192
) -> Buffer | None:
@@ -1444,25 +1427,82 @@ async def resolve_index_async(self, byte_getter: Any, key: str, chunk_selection:
14441427

14451428
# -- pack and store --
14461429

1430+
def _pack_nested(
1431+
self,
1432+
encoded_chunks: dict[tuple[int, ...], Buffer | None],
1433+
) -> Buffer | None:
1434+
"""Pack flat leaf chunks into a nested shard blob.
1435+
1436+
Groups leaf chunks by inner shard, packs each group into an
1437+
inner shard blob, then packs inner shard blobs into the outer
1438+
shard blob.
1439+
"""
1440+
from zarr.codecs.sharding import ShardingCodec
1441+
from zarr.core.buffer import default_buffer_prototype
1442+
1443+
inner_ab = self.inner_transform._ab_codec
1444+
assert isinstance(inner_ab, ShardingCodec)
1445+
1446+
inner_spec = self.inner_transform.array_spec
1447+
inner_layout = ShardedChunkLayout.from_sharding_codec(inner_ab, inner_spec)
1448+
inner_cps = inner_layout.chunks_per_shard
1449+
1450+
# Group leaf coords by inner shard
1451+
groups: dict[tuple[int, ...], dict[tuple[int, ...], Buffer | None]] = {}
1452+
for global_coord, chunk_bytes in encoded_chunks.items():
1453+
inner_shard_coord = tuple(gc // cps for gc, cps in zip(global_coord, inner_cps, strict=True))
1454+
leaf_coord = tuple(gc % cps for gc, cps in zip(global_coord, inner_cps, strict=True))
1455+
if inner_shard_coord not in groups:
1456+
groups[inner_shard_coord] = {}
1457+
groups[inner_shard_coord][leaf_coord] = chunk_bytes
1458+
1459+
# Pack each group into an inner shard blob
1460+
proto = default_buffer_prototype()
1461+
inner_shard_blobs: dict[tuple[int, ...], Buffer | None] = {}
1462+
for inner_shard_coord in np.ndindex(self.chunks_per_shard):
1463+
group = groups.get(inner_shard_coord, {})
1464+
# Fill missing leaf coords with None
1465+
for lc in np.ndindex(inner_cps):
1466+
if lc not in group:
1467+
group[lc] = None
1468+
inner_blob = inner_layout.pack_blob(group, proto)
1469+
inner_shard_blobs[inner_shard_coord] = inner_blob
1470+
1471+
# Pack inner shard blobs into outer shard
1472+
return self.pack_blob(inner_shard_blobs, proto)
1473+
14471474
def pack_and_store_sync(self, byte_setter: Any, encoded_chunks: dict[tuple[int, ...], Buffer | None]) -> None:
1475+
from zarr.codecs.sharding import ShardingCodec
14481476
from zarr.core.buffer import default_buffer_prototype
14491477

14501478
if all(v is None for v in encoded_chunks.values()):
14511479
byte_setter.delete_sync() # type: ignore[attr-defined]
14521480
return
1453-
blob = self.pack_blob(encoded_chunks, default_buffer_prototype())
1481+
1482+
# Check for nested sharding
1483+
if isinstance(self.inner_transform._ab_codec, ShardingCodec):
1484+
blob = self._pack_nested(encoded_chunks)
1485+
else:
1486+
blob = self.pack_blob(encoded_chunks, default_buffer_prototype())
1487+
14541488
if blob is None:
14551489
byte_setter.delete_sync() # type: ignore[attr-defined]
14561490
else:
14571491
byte_setter.set_sync(blob) # type: ignore[attr-defined]
14581492

14591493
async def pack_and_store_async(self, byte_setter: Any, encoded_chunks: dict[tuple[int, ...], Buffer | None]) -> None:
1494+
from zarr.codecs.sharding import ShardingCodec
14601495
from zarr.core.buffer import default_buffer_prototype
14611496

14621497
if all(v is None for v in encoded_chunks.values()):
14631498
await byte_setter.delete()
14641499
return
1465-
blob = self.pack_blob(encoded_chunks, default_buffer_prototype())
1500+
1501+
if isinstance(self.inner_transform._ab_codec, ShardingCodec):
1502+
blob = self._pack_nested(encoded_chunks)
1503+
else:
1504+
blob = self.pack_blob(encoded_chunks, default_buffer_prototype())
1505+
14661506
if blob is None:
14671507
await byte_setter.delete()
14681508
else:
@@ -1852,7 +1892,7 @@ async def _process_chunk(
18521892

18531893
# Phase 1: resolve index (IO)
18541894
if is_complete:
1855-
index = ShardIndex(key=key, leaf_transform=layout.inner_transform, is_sharded=layout.is_sharded)
1895+
index = ShardIndex(key=key, leaf_transform=layout.leaf_transform, is_sharded=layout.is_sharded)
18561896
elif layout.is_sharded:
18571897
async with sem:
18581898
index = await layout.resolve_index_async(byte_setter, key, chunk_selection=None) # ALL coords
@@ -1974,7 +2014,7 @@ def write_sync(
19742014

19752015
# Phase 1: resolve index
19762016
if is_complete:
1977-
index = ShardIndex(key=key, leaf_transform=layout.inner_transform, is_sharded=layout.is_sharded)
2017+
index = ShardIndex(key=key, leaf_transform=layout.leaf_transform, is_sharded=layout.is_sharded)
19782018
elif layout.is_sharded:
19792019
index = layout.resolve_index(bs, key, chunk_selection=None) # ALL coords
19802020
else:

0 commit comments

Comments
 (0)