Commit 79d5b8f

aldenks and d-v-b authored
Optimize partial shard reads (zarr-developers#3004)
* Add performance test of partial shard reads
* WIP Consolidate reads of multiple chunks in the same shard
  Add test and make max gap and max coalesce size config options
  Code clarity and comments
  Test that chunk request coalescing reduces calls to store
  Profile a few values for coalesce_max_gap
  Update [doc]tests to include new sharding.read.* values
  document sharded read config options in user-guide/config.rst
  tweak logic: start new coalesced group if coalescing would exceed `coalesce_max_bytes`
  previous logic only started a new group if existing group was size already exceeded coalesce_max_bytes.
  set `mypy_path = "src"` to help pre-commit mypy find imported classes
  Reorder methods in sharding.py, add docstring + commenting wording
  docs fix docstring clarification
  trigger precommit on all python files changed in this pull request
  trying to get the ruff format that's happening locally during pre-commit to match the pre-commit run that is failing on CI.
  revert trigger for pre-commit ruff format
* Add changes/3004.feature.rst
* Consistently return None on failure and test partial shard read failure modes
  Use range of integers as out_selection not slice in CoordinateIndexer
  To fix issue when using vindex with repeated indexes in indexer
  test: improve formatting and add debugging breakpoint in array property tests
  test: disable hypothesis deadline for test_array_roundtrip to prevent timeout
  fix: initialize decode buffers with shard_spec.fill_value instead of 0 to fix partial shard holes
  style: reformat code for improved readability and consistency in sharding.py
  fix: revert incorrect RangeByteRequest length fix in sharding byte retrieval
* Fix and test for case where some chunks in shard are all fill
* Self review
* Removing profiling code masquerading as a skipped test
* revert change to indexing.py, not needed
* Add test for duplicate integer indexing into a coalesced group
* Undo change to fill value when initializing shard arrays
* Undo change to set mypy_path = "src"
* Commenting and revert unnecessary changes to files for smaller diff
* remove now redundant cast
* Document runtime config keys
* Improve changelog entry and .rst -> .md
* .coords -> .chunk_coords in _ChunkCoordsByteSlice dataclass
* Update test env in docs/contributing.md
* Move `config.get` calls up into `_decode_partial_single`
* Ensure no change in behavior when ByteGetter.get returns None + comment
* Add test_sharing_unit.py, focusing on coalesce behavior, but with basic tests for other components
* Fix typing errors in test_sharing_unit.py
* Only use concurrent_map over chunks within a shard if > 1 groups after coalescing
* no-op test change to retry CI after unavailable runner failure
* Another no-op test change to retry CI after unavailable runner failure
* use get_partial_values(), remove explicit coalescing and concurrent_map
* cleanup
* self review
* Unit tests for new get_partial_values implementations
* tests: work around mypy not seeing equivalent protocols as equivalent
* Re-work to use Store.get_ranges. Simplify significantly.
* Add missing assert in test_sharding.py
* Remove unit tests that tested now-removed _ShardIndex.is_dense

---------

Co-authored-by: Davis Bennett <davis.v.bennett@gmail.com>
1 parent 7531de5 commit 79d5b8f

4 files changed

Lines changed: 828 additions & 17 deletions

changes/3004.feature.md

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+Optimizes reading multiple chunks from a shard. Serial calls to `Store.get()`
+in the sharding codec have been replaced with a single call to
+`Store.get_ranges()`, which coalesces nearby byte ranges and fetches them
+concurrently.
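
The changelog entry mentions coalescing nearby byte ranges. The following is a minimal, standalone sketch of what such grouping could look like. It is not the code behind `Store.get_ranges()`: the `ByteRange` type, the `coalesce_ranges` function, and the `max_gap` / `max_bytes` parameters are hypothetical names chosen for illustration (loosely following the "max gap" and "max coalesce size" wording in the commit message).

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ByteRange:
    """Half-open byte range [start, end). Hypothetical stand-in type."""

    start: int
    end: int


def coalesce_ranges(ranges, max_gap, max_bytes):
    """Group ranges so that neighbours within `max_gap` bytes share one fetch.

    Returns a list of groups. Each group is a list of (original_index, ByteRange)
    pairs sorted by start offset. A new group is started when the gap to the
    current group exceeds `max_gap`, or when adding the range would make the
    group span more than `max_bytes`.
    """
    indexed = sorted(enumerate(ranges), key=lambda pair: pair[1].start)
    groups = []
    for idx, rng in indexed:
        if groups:
            current = groups[-1]
            group_start = current[0][1].start
            group_end = max(r.end for _, r in current)
            gap = rng.start - group_end
            span_if_added = max(group_end, rng.end) - group_start
            if gap <= max_gap and span_if_added <= max_bytes:
                current.append((idx, rng))
                continue
        groups.append([(idx, rng)])
    return groups


ranges = [ByteRange(100, 200), ByteRange(0, 50), ByteRange(210, 300), ByteRange(5000, 5100)]
groups = coalesce_ranges(ranges, max_gap=64, max_bytes=1024)
# The three nearby ranges share one group; the distant one gets its own.
print([[idx for idx, _ in group] for group in groups])  # [[1, 0, 2], [3]]
```

Keeping the original index next to each range lets a caller map fetched bytes back to the request that asked for them, regardless of how the ranges were grouped or ordered.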

src/zarr/codecs/sharding.py

Lines changed: 72 additions & 17 deletions
@@ -59,6 +59,7 @@
     parse_codecs,
 )
 from zarr.registry import get_ndbuffer_class, get_pipeline_class
+from zarr.storage._common import StorePath
 from zarr.storage._utils import _normalize_byte_range_index

 if TYPE_CHECKING:
@@ -467,32 +468,26 @@ async def _decode_partial_single(
         all_chunk_coords = {chunk_coords for chunk_coords, *_ in indexed_chunks}

         # reading bytes of all requested chunks
-        shard_dict: ShardMapping = {}
+        shard_dict_maybe: ShardMapping | None
         if self._is_total_shard(all_chunk_coords, chunks_per_shard):
             # read entire shard
             shard_dict_maybe = await self._load_full_shard_maybe(
                 byte_getter=byte_getter,
                 prototype=chunk_spec.prototype,
                 chunks_per_shard=chunks_per_shard,
             )
-            if shard_dict_maybe is None:
-                return None
-            shard_dict = shard_dict_maybe
         else:
             # read some chunks within the shard
-            shard_index = await self._load_shard_index_maybe(byte_getter, chunks_per_shard)
-            if shard_index is None:
-                return None
-            shard_dict = {}
-            for chunk_coords in all_chunk_coords:
-                chunk_byte_slice = shard_index.get_chunk_slice(chunk_coords)
-                if chunk_byte_slice:
-                    chunk_bytes = await byte_getter.get(
-                        prototype=chunk_spec.prototype,
-                        byte_range=RangeByteRequest(chunk_byte_slice[0], chunk_byte_slice[1]),
-                    )
-                    if chunk_bytes:
-                        shard_dict[chunk_coords] = chunk_bytes
+            shard_dict_maybe = await self._load_partial_shard_maybe(
+                byte_getter,
+                chunk_spec.prototype,
+                chunks_per_shard,
+                all_chunk_coords,
+            )
+
+        if shard_dict_maybe is None:
+            return None
+        shard_dict = shard_dict_maybe

         # decoding chunks and writing them into the output buffer
         await self.codec_pipeline.read(
@@ -779,6 +774,66 @@ async def _load_full_shard_maybe(
             else None
         )

+    async def _load_partial_shard_maybe(
+        self,
+        byte_getter: ByteGetter,
+        prototype: BufferPrototype,
+        chunks_per_shard: tuple[int, ...],
+        all_chunk_coords: set[tuple[int, ...]],
+    ) -> ShardMapping | None:
+        """
+        Read chunks from `byte_getter` for the case where the read is less than a full shard.
+        Returns a mapping of chunk coordinates to bytes or None.
+        """
+        shard_index = await self._load_shard_index_maybe(byte_getter, chunks_per_shard)
+        if shard_index is None:
+            return None
+
+        # Pair up chunks and their byte ranges as list[tuple[chunk_coord, byte_range]]
+        chunk_coord_byte_ranges: list[tuple[tuple[int, ...], RangeByteRequest]] = []
+        for chunk_coord in all_chunk_coords:
+            chunk_byte_slice = shard_index.get_chunk_slice(chunk_coord)
+            if chunk_byte_slice is not None:
+                chunk_coord_byte_ranges.append(
+                    (chunk_coord, RangeByteRequest(chunk_byte_slice[0], chunk_byte_slice[1]))
+                )
+
+        if not chunk_coord_byte_ranges:
+            return {}
+
+        shard_dict: ShardMutableMapping = {}
+        if isinstance(byte_getter, StorePath):
+            # External store: use Store.get_ranges for coalescing + concurrency.
+            byte_ranges = [byte_range for _, byte_range in chunk_coord_byte_ranges]
+            try:
+                async for group in byte_getter.store.get_ranges(
+                    byte_getter.path, byte_ranges, prototype=prototype
+                ):
+                    for idx, buf in group:
+                        if buf is not None:
+                            chunk_coord, _ = chunk_coord_byte_ranges[idx]
+                            shard_dict[chunk_coord] = buf
+            except BaseExceptionGroup as eg:
+                # `Store.get_ranges` raises FileNotFoundError (wrapped in a
+                # BaseExceptionGroup) if any underlying fetch indicates the key is
+                # absent. The shard index loaded above, so this typically means a
+                # race where the shard was deleted mid-read; treat it as "shard
+                # gone" to match the index-missing branch (return None). Anything
+                # else in the group (e.g. IO errors) is re-raised.
+                _, rest = eg.split(FileNotFoundError)
+                if rest is not None:
+                    raise rest from None
+                return None
+        else:
+            # Any other ByteGetter. In practice only `_ShardingByteGetter` for
+            # nested sharding, which slices an in-memory buffer (no I/O to coalesce).
+            for chunk_coord, byte_range in chunk_coord_byte_ranges:
+                buf = await byte_getter.get(prototype, byte_range)
+                if buf is not None:
+                    shard_dict[chunk_coord] = buf
+
+        return shard_dict
+
     def compute_encoded_size(self, input_byte_length: int, shard_spec: ArraySpec) -> int:
         chunks_per_shard = self._get_chunks_per_shard(shard_spec)
         return input_byte_length + self._shard_index_size(chunks_per_shard)
