Skip to content

Commit acfc59f

Browse files
committed
fix: dispatch partial-decode in SyncCodecPipeline.read_sync
read_sync was always fetching the full chunk/shard blob and decoding it through the full codec chain. For sharded arrays, this meant a single-element read fetched the entire shard (~125x more IO than needed) and decoded every inner chunk (~125x more compute). Mirror the partial-encode dispatch already in write_sync: when the AB codec implements partial decoding (i.e. ShardingCodec), let the codec own its IO via _decode_partial_sync, fetching only the inner-chunk byte ranges that overlap the read selection.

Add ShardingCodec._decode_partial_sync — the sync equivalent of _decode_partial_single. It reads the shard index (or the full shard if the selection covers everything), decodes only the needed inner chunks through the inner ChunkTransform, and scatters the results into the output buffer.

Also extend tests/test_pipeline_parity.py with test_pipeline_read_parity: parametric over (codec config, layout, selection), where selections include scalar reads, partial slices, strided reads, and full reads. The original parity test only exercised full reads — this new test covers the partial-read code path that the regression hit.

Benchmark on shape=(105,)^3, chunks=(10,)^3, shards=(50,)^3, MemoryStore:

| Selection            | batched | sync (before) | sync (after) |
|----------------------|---------|---------------|--------------|
| scalar (0,0,0)       | 0.46 ms | 1.6 ms        | 0.24 ms      |
| full slice           | 83.4 ms | (n/a)         | 17.5 ms      |
| strided 4            | 82.8 ms | (n/a)         | 16.7 ms      |
| sub-block (10:-10:4) | 42.3 ms | (n/a)         | 9.7 ms       |

Fixes the codspeed regression on test_slice_indexing[(50,50,50)-(0,0,0)-memory] (was 4.6x slower, now 1.9x faster) and similar partial-read cases.
1 parent 1a1ff73 commit acfc59f

3 files changed

Lines changed: 185 additions & 0 deletions

File tree

src/zarr/codecs/sharding.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -882,6 +882,92 @@ async def _decode_partial_single(
882882
else:
883883
return out
884884

885+
def _decode_partial_sync(
886+
self,
887+
byte_getter: Any,
888+
selection: SelectorTuple,
889+
shard_spec: ArraySpec,
890+
) -> NDBuffer | None:
891+
"""Sync equivalent of ``_decode_partial_single``.
892+
893+
Reads only the inner-chunk byte ranges that overlap ``selection``
894+
(plus the shard index) and decodes them through the inner codec
895+
chain. The store must support ``get_sync`` with byte ranges.
896+
"""
897+
shard_shape = shard_spec.shape
898+
chunk_shape = self.chunk_shape
899+
chunks_per_shard = self._get_chunks_per_shard(shard_spec)
900+
chunk_spec = self._get_chunk_spec(shard_spec)
901+
inner_transform = self._get_inner_chunk_transform(shard_spec)
902+
903+
indexer = get_indexer(
904+
selection,
905+
shape=shard_shape,
906+
chunk_grid=ChunkGrid.from_sizes(shard_shape, chunk_shape),
907+
)
908+
909+
out = shard_spec.prototype.nd_buffer.empty(
910+
shape=indexer.shape,
911+
dtype=shard_spec.dtype.to_native_dtype(),
912+
order=shard_spec.order,
913+
)
914+
915+
indexed_chunks = list(indexer)
916+
all_chunk_coords = {chunk_coords for chunk_coords, *_ in indexed_chunks}
917+
918+
# Read just the inner chunks we need.
919+
if self._is_total_shard(all_chunk_coords, chunks_per_shard):
920+
shard_bytes = byte_getter.get_sync(prototype=chunk_spec.prototype)
921+
if shard_bytes is None:
922+
return None
923+
shard_reader = self._shard_reader_from_bytes_sync(shard_bytes, chunks_per_shard)
924+
shard_dict: ShardMapping = shard_reader
925+
else:
926+
shard_index_size = self._shard_index_size(chunks_per_shard)
927+
if self.index_location == ShardingCodecIndexLocation.start:
928+
index_bytes = byte_getter.get_sync(
929+
prototype=numpy_buffer_prototype(),
930+
byte_range=RangeByteRequest(0, shard_index_size),
931+
)
932+
else:
933+
index_bytes = byte_getter.get_sync(
934+
prototype=numpy_buffer_prototype(),
935+
byte_range=SuffixByteRequest(shard_index_size),
936+
)
937+
if index_bytes is None:
938+
return None
939+
shard_index = self._decode_shard_index_sync(index_bytes, chunks_per_shard)
940+
shard_dict_mut: dict[tuple[int, ...], Buffer | None] = {}
941+
for chunk_coords in all_chunk_coords:
942+
chunk_byte_slice = shard_index.get_chunk_slice(chunk_coords)
943+
if chunk_byte_slice is not None:
944+
chunk_bytes = byte_getter.get_sync(
945+
prototype=chunk_spec.prototype,
946+
byte_range=RangeByteRequest(chunk_byte_slice[0], chunk_byte_slice[1]),
947+
)
948+
if chunk_bytes is not None:
949+
shard_dict_mut[chunk_coords] = chunk_bytes
950+
shard_dict = shard_dict_mut
951+
952+
# Decode each needed inner chunk and scatter into out.
953+
fill_value = shard_spec.fill_value
954+
if fill_value is None:
955+
fill_value = shard_spec.dtype.default_scalar()
956+
for chunk_coords, chunk_selection, out_selection, _ in indexed_chunks:
957+
try:
958+
chunk_bytes = shard_dict[chunk_coords]
959+
except KeyError:
960+
chunk_bytes = None
961+
if chunk_bytes is None:
962+
out[out_selection] = fill_value
963+
continue
964+
chunk_array = inner_transform.decode_chunk(chunk_bytes, chunk_spec)
965+
out[out_selection] = chunk_array[chunk_selection]
966+
967+
if hasattr(indexer, "sel_shape"):
968+
return out.reshape(indexer.sel_shape)
969+
return out
970+
885971
async def _encode_single(
886972
self,
887973
shard_array: NDBuffer,

src/zarr/core/codec_pipeline.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -904,6 +904,12 @@ def read_sync(
904904
When ``n_workers > 0`` and there are multiple chunks, the decode
905905
step is parallelized across threads. This helps when codecs
906906
release the GIL (e.g. gzip, blosc, zstd).
907+
908+
Mirrors ``BatchedCodecPipeline.read_batch``: when the AB codec
909+
supports partial decoding (e.g. sharding), the codec handles its
910+
own IO and only fetches the inner-chunk byte ranges that overlap
911+
the read selection. Otherwise the pipeline fetches the full
912+
blob and decodes the whole chunk.
907913
"""
908914
assert self._sync_transform is not None
909915
transform = self._sync_transform
@@ -915,6 +921,25 @@ def read_sync(
915921
fill = fill_value_or_default(batch[0][1])
916922
_missing = GetResult(status="missing")
917923

924+
# Partial-decode fast path: the AB codec owns IO (read only the
925+
# byte ranges needed for the requested selection). Same condition
926+
# and dispatch as BatchedCodecPipeline.read_batch.
927+
if self.supports_partial_decode:
928+
codec = self.array_bytes_codec
929+
assert hasattr(codec, "_decode_partial_sync")
930+
partial_results: list[GetResult] = []
931+
for byte_getter, chunk_spec, chunk_selection, out_selection, _ in batch:
932+
decoded = codec._decode_partial_sync(byte_getter, chunk_selection, chunk_spec)
933+
if decoded is None:
934+
out[out_selection] = fill
935+
partial_results.append(_missing)
936+
continue
937+
if drop_axes:
938+
decoded = decoded.squeeze(axis=drop_axes)
939+
out[out_selection] = decoded
940+
partial_results.append(GetResult(status="present"))
941+
return tuple(partial_results)
942+
918943
# Phase 1: fetch all chunks (IO, sequential)
919944
raw_buffers: list[Buffer | None] = [
920945
bg.get_sync(prototype=cs.prototype) # type: ignore[attr-defined]

tests/test_pipeline_parity.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,3 +279,77 @@ def test_pipeline_parity(
279279
sync_arr,
280280
err_msg="BatchedCodecPipeline could not correctly read SyncCodecPipeline's output",
281281
)
282+
283+
284+
# ---------------------------------------------------------------------------
285+
# Read parity: cover partial reads (not just full reads as in the matrix above)
286+
# ---------------------------------------------------------------------------
287+
288+
289+
def _read_selections(shape: tuple[int, ...]) -> list[tuple[str, Any]]:
290+
"""Selections that exercise the partial-decode path differently."""
291+
if len(shape) == 1:
292+
n = shape[0]
293+
return [
294+
("scalar-first", (0,)),
295+
("scalar-mid", (n // 2,)),
296+
("partial-slice", (slice(n // 4, 3 * n // 4),)),
297+
("strided", (slice(0, n, 3),)),
298+
("full", (slice(None),)),
299+
]
300+
return [
301+
("scalar-first", (0,) * len(shape)),
302+
("scalar-mid", tuple(s // 2 for s in shape)),
303+
("partial-slice", tuple(slice(s // 4, 3 * s // 4) for s in shape)),
304+
("full", (slice(None),) * len(shape)),
305+
]
306+
307+
308+
def _read_matrix() -> Iterator[Any]:
    """Enumerate pytest params: every codec config x layout x read selection."""
    for codec_id, codec_kwargs in CODEC_CONFIGS:
        for layout_id, layout in LAYOUT_CONFIGS:
            selections = _read_selections(layout["shape"])
            for sel_id, sel in selections:
                param_id = f"{layout_id}-{codec_id}-{sel_id}"
                yield pytest.param(codec_kwargs, layout, sel, id=param_id)
319+
320+
@pytest.mark.parametrize(
    ("codec_kwargs", "layout", "selection"),
    list(_read_matrix()),
)
def test_pipeline_read_parity(
    codec_kwargs: CodecConfig,
    layout: LayoutConfig,
    selection: Any,
) -> None:
    """Partial reads via SyncCodecPipeline must match BatchedCodecPipeline.

    The full-write/full-read parity test above never exercises partial
    reads (e.g. a single element of a sharded array), which go through a
    separate code path (``_decode_partial_single`` on the sharding
    codec). Here the array is filled once under one pipeline, then the
    same selection is read under both pipelines and the results compared.
    """
    # Fill under batched (the canonical pipeline) so the contents are
    # well-defined regardless of the codec under test.
    store, _full = _write_under_pipeline(
        _BATCHED, codec_kwargs, layout, _full_overwrite(layout["shape"]), True
    )

    def _read_with(pipeline_path: str) -> Any:
        # Open read-only under the requested pipeline, apply the selection.
        with zarr_config.set({"codec_pipeline.path": pipeline_path}):
            return zarr.open_array(store=store, mode="r")[selection]

    batched_result = _read_with(_BATCHED)
    sync_result = _read_with(_SYNC)

    np.testing.assert_array_equal(
        sync_result,
        batched_result,
        err_msg=(
            f"SyncCodecPipeline read returned different result than BatchedCodecPipeline "
            f"for selection {selection!r}"
        ),
    )

0 commit comments

Comments
 (0)