Commit 638d57f

refactor: write paths use generic merge_and_encode with ShardIndex.leaf_transform

1 parent: ba60797

1 file changed: src/zarr/core/codec_pipeline.py (216 additions, 14 deletions)
```diff
@@ -877,6 +877,152 @@ def decode_chunks_from_index(
     return out


+def merge_and_encode_from_index(
+    existing_raw: dict[tuple[int, ...], Buffer | None],
+    index: ShardIndex,
+    value: NDBuffer,
+    chunk_spec: ArraySpec,
+    chunk_selection: SelectorTuple,
+    out_selection: SelectorTuple,
+    drop_axes: tuple[int, ...],
+) -> dict[tuple[int, ...], Buffer | None]:
+    """Merge new data into existing chunk(s) and encode, using index.leaf_transform.
+
+    For non-sharded layouts (``index.is_sharded`` is False): decode the single
+    existing chunk (or create from fill value), merge *value* at the given
+    selection, and encode. Returns ``{(0,...): encoded}``.
+
+    For sharded layouts (``index.is_sharded`` is True): start with existing raw
+    chunks, fill missing coords with None, then iterate over affected inner
+    chunks using ``get_indexer``. Decode/merge/encode each. Returns the full
+    chunk dict for subsequent packing into a shard blob.
+    """
+    from zarr.core.indexing import get_indexer
+
+    assert index.leaf_transform is not None
+    transform = index.leaf_transform
+
+    if not index.is_sharded:
+        # --- Simple (non-sharded) path ---
+        coord = next(iter(existing_raw)) if existing_raw else (0,) * len(chunk_spec.shape)
+
+        existing_bytes = existing_raw.get(coord)
+        if existing_bytes is not None:
+            chunk_array = transform.decode_chunk(existing_bytes, chunk_shape=chunk_spec.shape)
+            if not chunk_array.as_ndarray_like().flags.writeable:  # type: ignore[attr-defined]
+                chunk_array = chunk_spec.prototype.nd_buffer.from_ndarray_like(
+                    chunk_array.as_ndarray_like().copy()
+                )
+        else:
+            chunk_array = chunk_spec.prototype.nd_buffer.create(
+                shape=chunk_spec.shape,
+                dtype=chunk_spec.dtype.to_native_dtype(),
+                fill_value=fill_value_or_default(chunk_spec),
+            )
+
+        # Merge value
+        if chunk_selection == () or is_scalar(
+            value.as_ndarray_like(), chunk_spec.dtype.to_native_dtype()
+        ):
+            chunk_value = value
+        else:
+            chunk_value = value[out_selection]
+            if drop_axes:
+                item = tuple(
+                    None if idx in drop_axes else slice(None) for idx in range(chunk_spec.ndim)
+                )
+                chunk_value = chunk_value[item]
+        chunk_array[chunk_selection] = chunk_value
+
+        # Check write_empty_chunks
+        if not chunk_spec.config.write_empty_chunks and chunk_array.all_equal(
+            chunk_spec.fill_value
+        ):
+            return {coord: None}
+
+        chunk_shape = chunk_spec.shape if chunk_spec.shape != transform.array_spec.shape else None
+        encoded = transform.encode_chunk(chunk_array, chunk_shape=chunk_shape)
+        return {coord: encoded}
+
+    # --- Sharded path ---
+    inner_shape = transform.array_spec.shape
+    chunks_per_shard = tuple(
+        s // cs for s, cs in zip(chunk_spec.shape, inner_shape, strict=True)
+    )
+
+    chunk_dict: dict[tuple[int, ...], Buffer | None] = dict(existing_raw)
+
+    # Fill missing coords with None
+    for coord in np.ndindex(chunks_per_shard):
+        if coord not in chunk_dict:
+            chunk_dict[coord] = None
+
+    inner_spec = ArraySpec(
+        shape=inner_shape,
+        dtype=chunk_spec.dtype,
+        fill_value=chunk_spec.fill_value,
+        config=chunk_spec.config,
+        prototype=chunk_spec.prototype,
+    )
+
+    # Extract the shard's portion of the write value
+    if is_scalar(value.as_ndarray_like(), chunk_spec.dtype.to_native_dtype()):
+        shard_value = value
+    else:
+        shard_value = value[out_selection]
+        if drop_axes:
+            item = tuple(
+                None if idx in drop_axes else slice(None)
+                for idx in range(len(chunk_spec.shape))
+            )
+            shard_value = shard_value[item]
+
+    # Determine which inner chunks are affected
+    from zarr.core.chunk_grids import ChunkGrid as _ChunkGrid
+
+    indexer = get_indexer(
+        chunk_selection,
+        shape=chunk_spec.shape,
+        chunk_grid=_ChunkGrid.from_sizes(chunk_spec.shape, inner_shape),
+    )
+
+    for inner_coords, inner_sel, value_sel, _ in indexer:
+        existing_bytes = chunk_dict.get(inner_coords)
+
+        # Decode just this inner chunk
+        if existing_bytes is not None:
+            inner_array = transform.decode_chunk(existing_bytes)
+            if not inner_array.as_ndarray_like().flags.writeable:  # type: ignore[attr-defined]
+                inner_array = inner_spec.prototype.nd_buffer.from_ndarray_like(
+                    inner_array.as_ndarray_like().copy()
+                )
+        else:
+            inner_array = inner_spec.prototype.nd_buffer.create(
+                shape=inner_spec.shape,
+                dtype=inner_spec.dtype.to_native_dtype(),
+                fill_value=fill_value_or_default(inner_spec),
+            )
+
+        # Merge new data
+        if inner_sel == () or is_scalar(
+            shard_value.as_ndarray_like(), inner_spec.dtype.to_native_dtype()
+        ):
+            inner_value = shard_value
+        else:
+            inner_value = shard_value[value_sel]
+        inner_array[inner_sel] = inner_value
+
+        # Re-encode
+        if not chunk_spec.config.write_empty_chunks and inner_array.all_equal(
+            chunk_spec.fill_value
+        ):
+            chunk_dict[inner_coords] = None
+        else:
+            chunk_dict[inner_coords] = transform.encode_chunk(inner_array)
+
+    return chunk_dict
+
+
 class ChunkLayout:
     """Describes how a stored blob maps to one or more inner chunks.
```
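For orientation: in the non-sharded case the function above reduces to decode-or-fill, assign at the selection, then elide the chunk when `write_empty_chunks` is off. Here is a runnable toy sketch of just that semantics, with raw bytes standing in for `Buffer` and `tobytes`/`frombuffer` standing in for the leaf transform's `encode_chunk`/`decode_chunk` (`toy_merge_and_encode` is an illustrative name, not part of the diff):

```python
import numpy as np


def toy_merge_and_encode(existing, value, selection, shape=(4, 4),
                         dtype="i4", fill_value=0, write_empty_chunks=False):
    """Toy model of the non-sharded path: decode-or-fill, merge, maybe elide."""
    # Decode the existing chunk, or materialize a fresh one from the fill value.
    if existing is not None:
        # .copy() makes the buffer writable, mirroring the writeable check above.
        chunk = np.frombuffer(existing, dtype=dtype).reshape(shape).copy()
    else:
        chunk = np.full(shape, fill_value, dtype=dtype)
    # Merge the new data at the requested selection.
    chunk[selection] = value
    # With write_empty_chunks off, an all-fill result is returned as None,
    # signalling "delete the stored chunk" to the packing step.
    if not write_empty_chunks and np.all(chunk == fill_value):
        return None
    return chunk.tobytes()  # stand-in for transform.encode_chunk


# Writing the fill value everywhere elides the chunk entirely:
assert toy_merge_and_encode(None, 0, slice(None)) is None
# A partial write of real data decodes, merges, and re-encodes:
blob = toy_merge_and_encode(None, 7, (0, 0))
assert blob is not None and np.frombuffer(blob, dtype="i4")[0] == 7
```

The sharded branch runs the same decode/merge/encode loop once per affected inner chunk, keyed by `inner_coords` from the indexer.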
```diff
@@ -934,6 +1080,12 @@ def store_chunks_sync(self, byte_setter: Any, encoded_chunks: dict[tuple[int, ...], Buffer | None], chunk_spec: ArraySpec) -> None:
     async def store_chunks_async(self, byte_setter: Any, encoded_chunks: dict[tuple[int, ...], Buffer | None], chunk_spec: ArraySpec) -> None:
         raise NotImplementedError

+    def pack_and_store_sync(self, byte_setter: Any, encoded_chunks: dict[tuple[int, ...], Buffer | None]) -> None:
+        raise NotImplementedError
+
+    async def pack_and_store_async(self, byte_setter: Any, encoded_chunks: dict[tuple[int, ...], Buffer | None]) -> None:
+        raise NotImplementedError
+
     # -- Low-level helpers --

     def unpack_blob(self, blob: Buffer) -> dict[tuple[int, ...], Buffer | None]:
```
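These stubs fix the interface both write paths call; the concrete layouts override them below. A minimal structural sketch of the contract, using `typing.Protocol` with plain bytes in place of `Buffer` (`PacksAndStores` is an illustrative name, not from the diff):

```python
from typing import Any, Protocol


class PacksAndStores(Protocol):
    """Toy structural type for the new ChunkLayout storage contract.

    Implementations take the coord -> encoded-bytes-or-None dict produced by
    merge_and_encode_from_index and perform exactly one store or delete.
    """

    def pack_and_store_sync(
        self, byte_setter: Any, encoded_chunks: dict[tuple[int, ...], bytes | None]
    ) -> None: ...

    async def pack_and_store_async(
        self, byte_setter: Any, encoded_chunks: dict[tuple[int, ...], bytes | None]
    ) -> None: ...
```

Note that, unlike `store_chunks_*`, the new pair takes no `chunk_spec`: everything needed for packing already lives on the layout and in the encoded dict.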
```diff
@@ -1054,6 +1206,22 @@ async def store_chunks_async(self, byte_setter: Any, encoded_chunks: dict[tuple[
         else:
             await byte_setter.set(blob)

+    def pack_and_store_sync(self, byte_setter: Any, encoded_chunks: dict[tuple[int, ...], Buffer | None]) -> None:
+        coord = (0,) * len(self.chunks_per_shard)
+        blob = encoded_chunks.get(coord)
+        if blob is None:
+            byte_setter.delete_sync()  # type: ignore[attr-defined]
+        else:
+            byte_setter.set_sync(blob)  # type: ignore[attr-defined]
+
+    async def pack_and_store_async(self, byte_setter: Any, encoded_chunks: dict[tuple[int, ...], Buffer | None]) -> None:
+        coord = (0,) * len(self.chunks_per_shard)
+        blob = encoded_chunks.get(coord)
+        if blob is None:
+            await byte_setter.delete()
+        else:
+            await byte_setter.set(blob)
+
     # -- Low-level --

     def unpack_blob(self, blob: Buffer) -> dict[tuple[int, ...], Buffer | None]:
```
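For this single-blob layout the packing step is trivial: the dict holds exactly one entry at the all-zeros coordinate, and `None` means delete. A runnable toy of that behavior with a recording byte setter (`RecordingByteSetter` and `toy_pack_and_store_sync` are illustrative names, not the diff's classes):

```python
class RecordingByteSetter:
    """Toy byte setter that records stores and deletes for inspection."""

    def __init__(self) -> None:
        self.ops: list[str] = []

    def set_sync(self, blob: bytes) -> None:
        self.ops.append(f"set({len(blob)} bytes)")

    def delete_sync(self) -> None:
        self.ops.append("delete")


def toy_pack_and_store_sync(byte_setter, encoded_chunks, ndim=2):
    # Single-blob layout: the dict holds exactly one entry, keyed by the
    # all-zeros coordinate; None means the chunk was elided upstream.
    blob = encoded_chunks.get((0,) * ndim)
    if blob is None:
        byte_setter.delete_sync()  # all-fill chunk: remove the stored key
    else:
        byte_setter.set_sync(blob)


bs = RecordingByteSetter()
toy_pack_and_store_sync(bs, {(0, 0): b"\x01\x02"})
toy_pack_and_store_sync(bs, {(0, 0): None})
assert bs.ops == ["set(2 bytes)", "delete"]
```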
```diff
@@ -1347,6 +1515,30 @@ async def store_chunks_async(self, byte_setter: Any, encoded_chunks: dict[tuple[
         else:
             await byte_setter.set(blob)

+    def pack_and_store_sync(self, byte_setter: Any, encoded_chunks: dict[tuple[int, ...], Buffer | None]) -> None:
+        from zarr.core.buffer import default_buffer_prototype
+
+        if all(v is None for v in encoded_chunks.values()):
+            byte_setter.delete_sync()  # type: ignore[attr-defined]
+            return
+        blob = self.pack_blob(encoded_chunks, default_buffer_prototype())
+        if blob is None:
+            byte_setter.delete_sync()  # type: ignore[attr-defined]
+        else:
+            byte_setter.set_sync(blob)  # type: ignore[attr-defined]
+
+    async def pack_and_store_async(self, byte_setter: Any, encoded_chunks: dict[tuple[int, ...], Buffer | None]) -> None:
+        from zarr.core.buffer import default_buffer_prototype
+
+        if all(v is None for v in encoded_chunks.values()):
+            await byte_setter.delete()
+            return
+        blob = self.pack_blob(encoded_chunks, default_buffer_prototype())
+        if blob is None:
+            await byte_setter.delete()
+        else:
+            await byte_setter.set(blob)
+
     def _decode_per_chunk(
         self,
         chunk_dict: dict[tuple[int, ...], Buffer | None],
```
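The sharded variant adds one rule on top of the single-blob case: a shard whose inner chunks are all `None` is deleted outright, before any packing work. Otherwise the surviving chunks are packed into one blob via `pack_blob`. A self-contained toy of that rule, with a trivial stand-in for `pack_blob` (the real one also writes a shard index):

```python
def toy_pack_and_store_shard(byte_setter, encoded_chunks):
    # A shard whose inner chunks are all elided is deleted outright,
    # before any packing work happens.
    if all(v is None for v in encoded_chunks.values()):
        byte_setter.delete_sync()
        return
    # Stand-in for pack_blob: concatenate the surviving inner chunks.
    # The real method also emits a shard index; omitted here.
    blob = b"".join(v for v in encoded_chunks.values() if v is not None)
    byte_setter.set_sync(blob)


class ToyByteSetter:
    def __init__(self) -> None:
        self.ops: list[str] = []

    def set_sync(self, blob: bytes) -> None:
        self.ops.append(f"set({len(blob)})")

    def delete_sync(self) -> None:
        self.ops.append("delete")


bs = ToyByteSetter()
toy_pack_and_store_shard(bs, {(0, 0): None, (0, 1): None})    # whole shard empty
toy_pack_and_store_shard(bs, {(0, 0): b"\xff", (0, 1): None})
assert bs.ops == ["delete", "set(1)"]
```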
```diff
@@ -1760,7 +1952,7 @@ async def _process_chunk(

         # Phase 1: resolve index (IO)
         if is_complete:
-            index = ShardIndex(key=key)
+            index = ShardIndex(key=key, leaf_transform=layout.inner_transform, is_sharded=layout.is_sharded)
         elif layout.is_sharded:
             async with sem:
                 index = await layout.resolve_index_async(byte_setter, key, chunk_selection=None)  # ALL coords
```
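This one-line change matters because `merge_and_encode_from_index` is layout-agnostic: it reads the leaf transform and the sharded flag off the index itself. A toy model of the complete-overwrite case, assuming `ShardIndex` behaves roughly like this dataclass (`ToyShardIndex` is illustrative, not the diff's class):

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class ToyShardIndex:
    """Toy model of the ShardIndex handed to merge_and_encode_from_index.

    For a complete overwrite no store IO is needed: the index starts with no
    known chunks, but it must still carry the leaf transform and sharding flag,
    since the generic merge function reads both from the index, not the layout.
    """

    key: str
    leaf_transform: Any = None
    is_sharded: bool = False
    chunks: dict[tuple[int, ...], Any] = field(default_factory=dict)


index = ToyShardIndex(key="c/0/0", leaf_transform=object(), is_sharded=False)
assert not index.chunks  # nothing to fetch in Phase 2 -> existing == {}
```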
```diff
@@ -1771,25 +1963,26 @@
         # Phase 2: fetch existing chunks (IO)
         if index.chunks:
             async with sem:
-                existing_chunks = await layout.fetch_chunks_async(byte_setter, index, prototype=chunk_spec.prototype)
+                existing = await fetch_chunks_async(byte_setter, index, prototype=chunk_spec.prototype)
         else:
-            existing_chunks = {}
+            existing = {}

         # Phase 3: merge and encode (compute)
-        encoded_chunks = await loop.run_in_executor(
+        encoded = await loop.run_in_executor(
             pool,
-            layout.merge_and_encode,
-            existing_chunks,
+            merge_and_encode_from_index,
+            existing,
+            index,
             value,
             chunk_spec,
             chunk_selection,
             out_selection,
             drop_axes,
         )

-        # Phase 4: store (IO)
+        # Phase 4: pack + store (IO)
         async with sem:
-            await layout.store_chunks_async(byte_setter, encoded_chunks, chunk_spec)
+            await layout.pack_and_store_async(byte_setter, encoded)

     await asyncio.gather(
         *[
```
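The rewritten phases keep a fixed concurrency shape: store IO runs under a semaphore, while the pure-compute merge/encode runs in an executor so it never blocks the event loop. A self-contained skeleton of that pattern, with sleeps and a `bytes()` call standing in for the real fetch, merge, and store:

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor


async def process_one(sem: asyncio.Semaphore, pool: ThreadPoolExecutor, key: str) -> None:
    loop = asyncio.get_running_loop()
    # Phases 1-2: IO, bounded by the semaphore so only a few fetches run at once.
    async with sem:
        existing = await asyncio.sleep(0, result=key.encode())  # stand-in for fetch
    # Phase 3: CPU-bound merge/encode, pushed off the event loop into the pool.
    encoded = await loop.run_in_executor(pool, bytes, existing)
    # Phase 4: IO again, bounded by the same semaphore.
    async with sem:
        await asyncio.sleep(0)  # stand-in for layout.pack_and_store_async


async def main() -> None:
    sem = asyncio.Semaphore(4)
    with ThreadPoolExecutor() as pool:
        await asyncio.gather(*[process_one(sem, pool, f"c/{i}") for i in range(8)])


asyncio.run(main())
```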
```diff
@@ -1868,24 +2061,33 @@ def write_sync(
         if not batch:
             return

+        assert self.layout is not None
+        default_layout = self.layout
+
         for bs, chunk_spec, chunk_selection, out_selection, is_complete in batch:
             layout = (
-                self.layout
-                if self.layout is not None and chunk_spec.shape == self.layout.chunk_shape
+                default_layout
+                if chunk_spec.shape == default_layout.chunk_shape
                 else self._get_layout(chunk_spec)
             )
             key = bs.path if hasattr(bs, "path") else ""

+            # Phase 1: resolve index
             if is_complete:
-                index = ShardIndex(key=key)
+                index = ShardIndex(key=key, leaf_transform=layout.inner_transform, is_sharded=layout.is_sharded)
             elif layout.is_sharded:
                 index = layout.resolve_index(bs, key, chunk_selection=None)  # ALL coords
             else:
                 index = layout.resolve_index(bs, key, chunk_selection=chunk_selection)

-            existing_chunks = layout.fetch_chunks(bs, index, prototype=chunk_spec.prototype) if index.chunks else {}
-            encoded_chunks = layout.merge_and_encode(existing_chunks, value, chunk_spec, chunk_selection, out_selection, drop_axes)
-            layout.store_chunks_sync(bs, encoded_chunks, chunk_spec)
+            # Phase 2: fetch existing
+            existing = fetch_chunks_sync(bs, index, prototype=chunk_spec.prototype) if index.chunks else {}
+
+            # Phase 3: merge + encode (compute)
+            encoded = merge_and_encode_from_index(existing, index, value, chunk_spec, chunk_selection, out_selection, drop_axes)
+
+            # Phase 4: pack + store
+            layout.pack_and_store_sync(bs, encoded)


 register_pipeline(PhasedCodecPipeline)
```
