Commit 8831672

d-v-b and claude committed
refactor: remove dead layout methods — ChunkLayout owns only resolve_index + pack_and_store
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent f7772cc commit 8831672

1 file changed

Lines changed: 15 additions & 275 deletions

src/zarr/core/codec_pipeline.py

@@ -1026,14 +1026,17 @@ def merge_and_encode_from_index(
 class ChunkLayout:
     """Describes how a stored blob maps to one or more inner chunks.
 
-    The pipeline interacts with the layout in four phases:
-
-    1. **Resolve index** (IO) — read shard indexes to determine where
-       chunk data lives. Returns a ``ShardIndex``.
-    2. **Fetch chunks** (IO) — read the byte ranges from the index.
-    3. **Decode / merge+encode** (compute) — decode fetched bytes, or
-       merge new data and re-encode.
-    4. **Store** (IO) — write results back.
+    The pipeline interacts with the layout through two IO responsibilities:
+
+    - ``resolve_index`` — read shard indexes (if any) to determine byte
+      ranges for inner chunks. Returns a ``ShardIndex``.
+    - ``pack_and_store`` — assemble encoded chunks into a blob and write
+      it to the store.
+
+    Fetching, decoding, merging, and encoding are handled by module-level
+    functions (``fetch_chunks_sync``, ``decode_chunks_from_index``,
+    ``merge_and_encode_from_index``) that operate on the ``ShardIndex``
+    returned by ``resolve_index``.
     """
 
     chunk_shape: tuple[int, ...]
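The new docstring above splits the work between the layout (IO) and module-level helpers (compute). As a rough sketch only, not taken from this commit, a read path might chain them like this; the argument order of fetch_chunks_sync and decode_chunks_from_index is an assumption, since their signatures are not shown in the diff:

# Hypothetical read-path sketch; helper signatures are assumed, not shown in this diff.
from zarr.core.codec_pipeline import decode_chunks_from_index, fetch_chunks_sync

def read_chunk(layout, byte_getter, key, chunk_spec, prototype):
    # IO: consult the shard index (if any) to find where the inner chunks live.
    index = layout.resolve_index(byte_getter, key)
    # IO: fetch the byte ranges recorded in the ShardIndex.
    raw_chunks = fetch_chunks_sync(byte_getter, index, prototype)
    # Compute: decode the fetched bytes into an NDBuffer for this chunk.
    return decode_chunks_from_index(index, raw_chunks, chunk_spec)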
@@ -1048,37 +1051,15 @@ def is_sharded(self) -> bool:
     def needed_coords(self, chunk_selection: SelectorTuple) -> set[tuple[int, ...]] | None:
         return None
 
-    # -- Phase 1: resolve index --
+    # -- resolve index (IO) --
 
     def resolve_index(self, byte_getter: Any, key: str, chunk_selection: SelectorTuple | None = None) -> ShardIndex:
         raise NotImplementedError
 
     async def resolve_index_async(self, byte_getter: Any, key: str, chunk_selection: SelectorTuple | None = None) -> ShardIndex:
         raise NotImplementedError
 
-    # -- Phase 2: fetch chunk data --
-
-    def fetch_chunks(self, byte_getter: Any, index: ShardIndex, prototype: BufferPrototype) -> dict[tuple[int, ...], Buffer | None]:
-        raise NotImplementedError
-
-    async def fetch_chunks_async(self, byte_getter: Any, index: ShardIndex, prototype: BufferPrototype) -> dict[tuple[int, ...], Buffer | None]:
-        raise NotImplementedError
-
-    # -- Phase 3: compute --
-
-    def decode_chunks(self, raw_chunks: dict[tuple[int, ...], Buffer | None], chunk_spec: ArraySpec) -> NDBuffer:
-        raise NotImplementedError
-
-    def merge_and_encode(self, existing_chunks: dict[tuple[int, ...], Buffer | None], value: NDBuffer, chunk_spec: ArraySpec, chunk_selection: SelectorTuple, out_selection: SelectorTuple, drop_axes: tuple[int, ...]) -> dict[tuple[int, ...], Buffer | None]:
-        raise NotImplementedError
-
-    # -- Phase 4: store --
-
-    def store_chunks_sync(self, byte_setter: Any, encoded_chunks: dict[tuple[int, ...], Buffer | None], chunk_spec: ArraySpec) -> None:
-        raise NotImplementedError
-
-    async def store_chunks_async(self, byte_setter: Any, encoded_chunks: dict[tuple[int, ...], Buffer | None], chunk_spec: ArraySpec) -> None:
-        raise NotImplementedError
+    # -- pack and store (IO) --
 
     def pack_and_store_sync(self, byte_setter: Any, encoded_chunks: dict[tuple[int, ...], Buffer | None]) -> None:
         raise NotImplementedError
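After this removal, a concrete layout's contract is just the two abstract methods kept above. A hypothetical minimal subclass might look like the sketch below; the field values passed to ShardIndex (leaf_transform, is_sharded) are illustrative assumptions, not the library's actual defaults.

# Hypothetical single-blob layout; ShardIndex field values are assumptions.
class SingleBlobLayout(ChunkLayout):
    def resolve_index(self, byte_getter, key, chunk_selection=None):
        # One inner chunk stored as the whole blob, so there is no byte range to resolve.
        coord = (0,) * len(self.chunk_shape)
        return ShardIndex(key=key, chunks={coord: None}, leaf_transform=None, is_sharded=False)

    def pack_and_store_sync(self, byte_setter, encoded_chunks):
        # A single encoded chunk is the blob itself; None means "delete the key".
        blob = next(iter(encoded_chunks.values()))
        if blob is None:
            byte_setter.delete_sync()
        else:
            byte_setter.set_sync(blob)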
@@ -1115,96 +1096,7 @@ def resolve_index(self, byte_getter: Any, key: str, chunk_selection: SelectorTup
     async def resolve_index_async(self, byte_getter: Any, key: str, chunk_selection: SelectorTuple | None = None) -> ShardIndex:
         return self.resolve_index(byte_getter, key, chunk_selection)
 
-    # -- Phase 2: fetch chunk data --
-
-    def fetch_chunks(self, byte_getter: Any, index: ShardIndex, prototype: BufferPrototype) -> dict[tuple[int, ...], Buffer | None]:
-        coord = next(iter(index.chunks))
-        raw = byte_getter.get_sync(prototype=prototype)
-        return {coord: raw}  # type: ignore[no-any-return]
-
-    async def fetch_chunks_async(self, byte_getter: Any, index: ShardIndex, prototype: BufferPrototype) -> dict[tuple[int, ...], Buffer | None]:
-        coord = next(iter(index.chunks))
-        raw = await byte_getter.get(prototype=prototype)
-        return {coord: raw}  # type: ignore[no-any-return]
-
-    # -- Phase 3: compute --
-
-    def decode_chunks(self, raw_chunks: dict[tuple[int, ...], Buffer | None], chunk_spec: ArraySpec) -> NDBuffer:
-        raw = next(iter(raw_chunks.values()))
-        if raw is None:
-            return chunk_spec.prototype.nd_buffer.create(
-                shape=chunk_spec.shape,
-                dtype=chunk_spec.dtype.to_native_dtype(),
-                order=chunk_spec.order,
-                fill_value=fill_value_or_default(chunk_spec),
-            )
-        chunk_shape = chunk_spec.shape if chunk_spec.shape != self.chunk_shape else None
-        return self.inner_transform.decode_chunk(raw, chunk_shape=chunk_shape)
-
-    def encode(
-        self,
-        chunk_array: NDBuffer,
-        chunk_spec: ArraySpec,
-    ) -> Buffer | None:
-        chunk_shape = chunk_spec.shape if chunk_spec.shape != self.chunk_shape else None
-        return self.inner_transform.encode_chunk(chunk_array, chunk_shape=chunk_shape)
-
-    def merge_and_encode(self, existing_chunks: dict[tuple[int, ...], Buffer | None], value: NDBuffer, chunk_spec: ArraySpec, chunk_selection: SelectorTuple, out_selection: SelectorTuple, drop_axes: tuple[int, ...]) -> dict[tuple[int, ...], Buffer | None]:
-        coord = next(iter(existing_chunks)) if existing_chunks else (0,) * len(self.chunks_per_shard)
-
-        # Decode existing
-        existing_raw = existing_chunks.get(coord)
-        if existing_raw is not None:
-            chunk_array = self.inner_transform.decode_chunk(existing_raw, chunk_shape=chunk_spec.shape)
-            if not chunk_array.as_ndarray_like().flags.writeable:  # type: ignore[attr-defined]
-                chunk_array = chunk_spec.prototype.nd_buffer.from_ndarray_like(
-                    chunk_array.as_ndarray_like().copy()
-                )
-        else:
-            chunk_array = chunk_spec.prototype.nd_buffer.create(
-                shape=chunk_spec.shape,
-                dtype=chunk_spec.dtype.to_native_dtype(),
-                fill_value=fill_value_or_default(chunk_spec),
-            )
-
-        # Merge value
-        if chunk_selection == () or is_scalar(
-            value.as_ndarray_like(), chunk_spec.dtype.to_native_dtype()
-        ):
-            chunk_value = value
-        else:
-            chunk_value = value[out_selection]
-            if drop_axes:
-                item = tuple(
-                    None if idx in drop_axes else slice(None) for idx in range(chunk_spec.ndim)
-                )
-                chunk_value = chunk_value[item]
-        chunk_array[chunk_selection] = chunk_value
-
-        # Check write_empty_chunks
-        if not chunk_spec.config.write_empty_chunks and chunk_array.all_equal(
-            chunk_spec.fill_value
-        ):
-            return {coord: None}
-
-        encoded = self.encode(chunk_array, chunk_spec)
-        return {coord: encoded}
-
-    # -- Phase 4: store --
-
-    def store_chunks_sync(self, byte_setter: Any, encoded_chunks: dict[tuple[int, ...], Buffer | None], chunk_spec: ArraySpec) -> None:
-        blob = next(iter(encoded_chunks.values()))
-        if blob is None:
-            byte_setter.delete_sync()  # type: ignore[attr-defined]
-        else:
-            byte_setter.set_sync(blob)  # type: ignore[attr-defined]
-
-    async def store_chunks_async(self, byte_setter: Any, encoded_chunks: dict[tuple[int, ...], Buffer | None], chunk_spec: ArraySpec) -> None:
-        blob = next(iter(encoded_chunks.values()))
-        if blob is None:
-            await byte_setter.delete()
-        else:
-            await byte_setter.set(blob)
+    # -- pack and store --
 
     def pack_and_store_sync(self, byte_setter: Any, encoded_chunks: dict[tuple[int, ...], Buffer | None]) -> None:
         coord = (0,) * len(self.chunks_per_shard)
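On the write side, the removed merge_and_encode and store_chunks_* methods above give way to the module-level merge_and_encode_from_index followed by pack_and_store. A rough sketch of the resulting flow; the signature of merge_and_encode_from_index is assumed to mirror the removed method, and fetch_chunks_sync is the same assumed helper as in the read-path sketch:

# Hypothetical write-path sketch; helper signatures are assumptions, not shown in this diff.
from zarr.core.codec_pipeline import fetch_chunks_sync, merge_and_encode_from_index

def write_chunk(layout, byte_getter, byte_setter, key, value, chunk_spec,
                chunk_selection, out_selection, drop_axes, prototype):
    # IO: resolve the index and fetch whatever is already stored for this chunk/shard.
    index = layout.resolve_index(byte_getter, key, chunk_selection)
    existing = fetch_chunks_sync(byte_getter, index, prototype)
    # Compute: merge the new value into the existing chunks and re-encode.
    encoded = merge_and_encode_from_index(
        index, existing, value, chunk_spec, chunk_selection, out_selection, drop_axes
    )
    # IO: pack the encoded chunks into a blob and write (or delete) it.
    layout.pack_and_store_sync(byte_setter, encoded)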
@@ -1550,130 +1442,7 @@ async def resolve_index_async(self, byte_getter: Any, key: str, chunk_selection:
 
         return ShardIndex(key=key, chunks=flat, leaf_transform=leaf_transform, is_sharded=True)
 
-    # -- Phase 2: fetch chunk data --
-
-    def fetch_chunks(self, byte_getter: Any, index: ShardIndex, prototype: BufferPrototype) -> dict[tuple[int, ...], Buffer | None]:
-        result: dict[tuple[int, ...], Buffer | None] = {}
-        for coord, byte_range in index.chunks.items():
-            if byte_range is None:
-                result[coord] = None
-            else:
-                result[coord] = byte_getter.get_sync(prototype=prototype, byte_range=byte_range)  # type: ignore[no-any-return]
-        return result
-
-    async def fetch_chunks_async(self, byte_getter: Any, index: ShardIndex, prototype: BufferPrototype) -> dict[tuple[int, ...], Buffer | None]:
-        result: dict[tuple[int, ...], Buffer | None] = {}
-        for coord, byte_range in index.chunks.items():
-            if byte_range is None:
-                result[coord] = None
-            else:
-                result[coord] = await byte_getter.get(prototype=prototype, byte_range=byte_range)
-        return result
-
-    # -- Phase 3: compute --
-
-    def decode_chunks(self, raw_chunks: dict[tuple[int, ...], Buffer | None], chunk_spec: ArraySpec) -> NDBuffer:
-        return self._decode_per_chunk(raw_chunks, chunk_spec)
-
-    def merge_and_encode(self, existing_chunks: dict[tuple[int, ...], Buffer | None], value: NDBuffer, chunk_spec: ArraySpec, chunk_selection: SelectorTuple, out_selection: SelectorTuple, drop_axes: tuple[int, ...]) -> dict[tuple[int, ...], Buffer | None]:
-        from zarr.core.chunk_grids import ChunkGrid as _ChunkGrid
-        from zarr.core.indexing import get_indexer
-
-        chunk_dict = dict(existing_chunks)
-
-        # Fill missing coords with None
-        for coord in np.ndindex(self.chunks_per_shard):
-            if coord not in chunk_dict:
-                chunk_dict[coord] = None
-
-        inner_spec = ArraySpec(
-            shape=self.inner_chunk_shape,
-            dtype=chunk_spec.dtype,
-            fill_value=chunk_spec.fill_value,
-            config=chunk_spec.config,
-            prototype=chunk_spec.prototype,
-        )
-
-        # Extract the shard's portion of the write value.
-        if is_scalar(value.as_ndarray_like(), chunk_spec.dtype.to_native_dtype()):
-            shard_value = value
-        else:
-            shard_value = value[out_selection]
-            if drop_axes:
-                item = tuple(
-                    None if idx in drop_axes else slice(None)
-                    for idx in range(len(chunk_spec.shape))
-                )
-                shard_value = shard_value[item]
-
-        # Determine which inner chunks are affected
-        indexer = get_indexer(
-            chunk_selection,
-            shape=chunk_spec.shape,
-            chunk_grid=_ChunkGrid.from_sizes(chunk_spec.shape, self.inner_chunk_shape),
-        )
-
-        for inner_coords, inner_sel, value_sel, _ in indexer:
-            existing_bytes = chunk_dict.get(inner_coords)
-
-            # Decode just this inner chunk
-            if existing_bytes is not None:
-                inner_array = self.inner_transform.decode_chunk(existing_bytes)
-                if not inner_array.as_ndarray_like().flags.writeable:  # type: ignore[attr-defined]
-                    inner_array = inner_spec.prototype.nd_buffer.from_ndarray_like(
-                        inner_array.as_ndarray_like().copy()
-                    )
-            else:
-                inner_array = inner_spec.prototype.nd_buffer.create(
-                    shape=inner_spec.shape,
-                    dtype=inner_spec.dtype.to_native_dtype(),
-                    fill_value=fill_value_or_default(inner_spec),
-                )
-
-            # Merge new data
-            if inner_sel == () or is_scalar(
-                shard_value.as_ndarray_like(), inner_spec.dtype.to_native_dtype()
-            ):
-                inner_value = shard_value
-            else:
-                inner_value = shard_value[value_sel]
-            inner_array[inner_sel] = inner_value
-
-            # Re-encode
-            if not chunk_spec.config.write_empty_chunks and inner_array.all_equal(
-                chunk_spec.fill_value
-            ):
-                chunk_dict[inner_coords] = None
-            else:
-                chunk_dict[inner_coords] = self.inner_transform.encode_chunk(inner_array)
-
-        return chunk_dict
-
-    # -- Phase 4: store --
-
-    def store_chunks_sync(self, byte_setter: Any, encoded_chunks: dict[tuple[int, ...], Buffer | None], chunk_spec: ArraySpec) -> None:
-        from zarr.core.buffer import default_buffer_prototype
-
-        if all(v is None for v in encoded_chunks.values()):
-            byte_setter.delete_sync()  # type: ignore[attr-defined]
-        else:
-            blob = self.pack_blob(encoded_chunks, default_buffer_prototype())
-            if blob is None:
-                byte_setter.delete_sync()  # type: ignore[attr-defined]
-            else:
-                byte_setter.set_sync(blob)  # type: ignore[attr-defined]
-
-    async def store_chunks_async(self, byte_setter: Any, encoded_chunks: dict[tuple[int, ...], Buffer | None], chunk_spec: ArraySpec) -> None:
-        from zarr.core.buffer import default_buffer_prototype
-
-        if all(v is None for v in encoded_chunks.values()):
-            await byte_setter.delete()
-        else:
-            blob = self.pack_blob(encoded_chunks, default_buffer_prototype())
-            if blob is None:
-                await byte_setter.delete()
-            else:
-                await byte_setter.set(blob)
+    # -- pack and store --
 
     def pack_and_store_sync(self, byte_setter: Any, encoded_chunks: dict[tuple[int, ...], Buffer | None]) -> None:
         from zarr.core.buffer import default_buffer_prototype
@@ -1699,35 +1468,6 @@ async def pack_and_store_async(self, byte_setter: Any, encoded_chunks: dict[tupl
         else:
             await byte_setter.set(blob)
 
-    def _decode_per_chunk(
-        self,
-        chunk_dict: dict[tuple[int, ...], Buffer | None],
-        shard_spec: ArraySpec,
-    ) -> NDBuffer:
-        """Assemble inner chunk buffers into a chunk-shaped array."""
-        out = shard_spec.prototype.nd_buffer.empty(
-            shape=shard_spec.shape,
-            dtype=shard_spec.dtype.to_native_dtype(),
-            order=shard_spec.order,
-        )
-
-        inner_shape = self.inner_chunk_shape
-        fill = fill_value_or_default(shard_spec)
-        decode = self.inner_transform.decode_chunk
-
-        for coords, chunk_bytes in chunk_dict.items():
-            out_selection = tuple(
-                slice(c * s, min((c + 1) * s, sh))
-                for c, s, sh in zip(coords, inner_shape, shard_spec.shape, strict=True)
-            )
-            if chunk_bytes is not None:
-                chunk_array = decode(chunk_bytes)
-                out[out_selection] = chunk_array
-            else:
-                out[out_selection] = fill
-
-        return out
-
     async def _fetch_index(self, byte_getter: Any) -> Any:
         from zarr.abc.store import RangeByteRequest, SuffixByteRequest
         from zarr.codecs.sharding import ShardingCodecIndexLocation
