|
| 1 | +# src/zarr/core/_coalesce.py |
| 2 | +from __future__ import annotations |
| 3 | + |
| 4 | +import asyncio |
| 5 | +from typing import TYPE_CHECKING, NamedTuple |
| 6 | + |
| 7 | +from zarr.abc.store import RangeByteRequest |
| 8 | + |
| 9 | +if TYPE_CHECKING: |
| 10 | + from collections.abc import AsyncGenerator, Awaitable, Callable, Sequence |
| 11 | + |
| 12 | + from zarr.abc.store import ByteRequest |
| 13 | + from zarr.core.buffer import Buffer |
| 14 | + |
| 15 | + |
| 16 | +class _WorkerCtx(NamedTuple): |
| 17 | + """Shared state passed to the per-task worker coroutines. |
| 18 | +
|
| 19 | + Bundling these lets the workers declare their dependencies as one |
| 20 | + parameter instead of capturing them implicitly via closure. |
| 21 | + """ |
| 22 | + |
| 23 | + fetch: Callable[[ByteRequest | None], Awaitable[Buffer | None]] |
| 24 | + semaphore: asyncio.Semaphore |
| 25 | + |
| 26 | + |
| 27 | +async def _fetch_single( |
| 28 | + ctx: _WorkerCtx, idx: int, req: ByteRequest | None |
| 29 | +) -> Sequence[tuple[int, Buffer | None]]: |
| 30 | + """Fetch one byte range. Raises FileNotFoundError if the key is absent.""" |
| 31 | + async with ctx.semaphore: |
| 32 | + buf = await ctx.fetch(req) |
| 33 | + if buf is None: |
| 34 | + raise FileNotFoundError |
| 35 | + return ((idx, buf),) |
| 36 | + |
| 37 | + |
| 38 | +async def _fetch_group( |
| 39 | + ctx: _WorkerCtx, members: list[tuple[int, RangeByteRequest]] |
| 40 | +) -> Sequence[tuple[int, Buffer | None]]: |
| 41 | + """Fetch one merged byte range and slice it back into per-input buffers. |
| 42 | +
|
| 43 | + `members` must already be sorted by `start`; callers in this module |
| 44 | + build it from the sorted mergeable list. Raises `FileNotFoundError` |
| 45 | + if the key is absent. |
| 46 | + """ |
| 47 | + if len(members) == 1: |
| 48 | + solo_idx, solo_req = members[0] |
| 49 | + return await _fetch_single(ctx, solo_idx, solo_req) |
| 50 | + |
| 51 | + start = members[0][1].start |
| 52 | + end = max(r.end for _, r in members) |
| 53 | + async with ctx.semaphore: |
| 54 | + big = await ctx.fetch(RangeByteRequest(start, end)) |
| 55 | + if big is None: |
| 56 | + raise FileNotFoundError |
| 57 | + sliced = [(idx, big[r.start - start : r.end - start]) for idx, r in members] |
| 58 | + return tuple(sliced) |
| 59 | + |
| 60 | + |
| 61 | +def coalesce_ranges( |
| 62 | + byte_ranges: Sequence[ByteRequest | None], |
| 63 | + *, |
| 64 | + max_gap_bytes: int, |
| 65 | + max_coalesced_bytes: int, |
| 66 | +) -> tuple[ |
| 67 | + list[list[tuple[int, RangeByteRequest]]], |
| 68 | + list[tuple[int, ByteRequest | None]], |
| 69 | +]: |
| 70 | + """Plan a set of byte-range fetches: which inputs merge, which stand alone. |
| 71 | +
|
| 72 | + Pure (no I/O). The result is the I/O plan a caller would execute: each |
| 73 | + group corresponds to one fetch of a coalesced byte range, and each |
| 74 | + uncoalescable item corresponds to one fetch of the original request. |
| 75 | +
|
| 76 | + All tuning knobs are required keyword arguments. `Store.get_ranges` is |
| 77 | + the public entry point and owns the canonical default values; this |
| 78 | + function takes them explicitly to avoid duplicating policy. |
| 79 | +
|
| 80 | + Parameters |
| 81 | + ---------- |
| 82 | + byte_ranges |
| 83 | + Input ranges. `None` means "the whole value". |
| 84 | + max_gap_bytes |
| 85 | + Two `RangeByteRequest`s separated by at most this many bytes may be |
| 86 | + merged into one fetch. |
| 87 | + max_coalesced_bytes |
| 88 | + Upper bound on the size of a single merged fetch. |
| 89 | +
|
| 90 | + Returns |
| 91 | + ------- |
| 92 | + groups |
| 93 | + List of merged groups. Each group is a list of |
| 94 | + `(input_index, RangeByteRequest)` pairs sorted by `start`. A |
| 95 | + single-element group represents a `RangeByteRequest` that did not |
| 96 | + merge with any neighbor. |
| 97 | + uncoalescable |
| 98 | + List of `(input_index, request)` pairs for inputs that are not |
| 99 | + `RangeByteRequest` (`OffsetByteRequest`, `SuffixByteRequest`, |
| 100 | + `None`). Indices are preserved from the input order. |
| 101 | +
|
| 102 | + Notes |
| 103 | + ----- |
| 104 | + Only `RangeByteRequest` inputs participate in coalescing. Two ranges |
| 105 | + merge when both: their gap (next `start` minus current group's running |
| 106 | + `end`) is `<= max_gap_bytes`, and the resulting merged span is |
| 107 | + `<= max_coalesced_bytes`. |
| 108 | + """ |
| 109 | + indexed = list(enumerate(byte_ranges)) |
| 110 | + mergeable = [(i, r) for i, r in indexed if isinstance(r, RangeByteRequest)] |
| 111 | + uncoalescable: list[tuple[int, ByteRequest | None]] = [ |
| 112 | + (i, r) for i, r in indexed if not isinstance(r, RangeByteRequest) |
| 113 | + ] |
| 114 | + |
| 115 | + # Sort mergeables by start offset, then merge. Track running start/end of the |
| 116 | + # current group so each merge step is O(1) instead of O(group size). |
| 117 | + mergeable.sort(key=lambda pair: pair[1].start) |
| 118 | + groups: list[list[tuple[int, RangeByteRequest]]] = [] |
| 119 | + group_start = 0 |
| 120 | + group_end = 0 |
| 121 | + for pair in mergeable: |
| 122 | + _i, r = pair |
| 123 | + if groups and r.start - group_end <= max_gap_bytes: |
| 124 | + prospective_end = max(group_end, r.end) |
| 125 | + if prospective_end - group_start <= max_coalesced_bytes: |
| 126 | + groups[-1].append(pair) |
| 127 | + group_end = prospective_end |
| 128 | + continue |
| 129 | + groups.append([pair]) |
| 130 | + group_start = r.start |
| 131 | + group_end = r.end |
| 132 | + |
| 133 | + return groups, uncoalescable |
| 134 | + |
| 135 | + |
| 136 | +async def coalesced_get( |
| 137 | + fetch: Callable[[ByteRequest | None], Awaitable[Buffer | None]], |
| 138 | + byte_ranges: Sequence[ByteRequest | None], |
| 139 | + *, |
| 140 | + max_concurrency: int, |
| 141 | + max_gap_bytes: int, |
| 142 | + max_coalesced_bytes: int, |
| 143 | +) -> AsyncGenerator[Sequence[tuple[int, Buffer | None]]]: |
| 144 | + """Read many byte ranges through `fetch` with coalescing and concurrency. |
| 145 | +
|
| 146 | + Nearby ranges are merged into a single underlying I/O, and merged fetches |
| 147 | + are run concurrently. Each yield corresponds to exactly one underlying I/O |
| 148 | + operation: a sequence of `(input_index, result)` tuples for all input |
| 149 | + ranges served by that I/O. Tuples within a yielded sequence are ordered by |
| 150 | + start offset. Yields across groups are in completion order, not input |
| 151 | + order. |
| 152 | +
|
| 153 | + All tuning knobs are required keyword arguments. `Store.get_ranges` is |
| 154 | + the public entry point and owns the canonical default values; this |
| 155 | + function takes them explicitly to avoid duplicating policy. |
| 156 | +
|
| 157 | + Parameters |
| 158 | + ---------- |
| 159 | + fetch |
| 160 | + Callable that reads one byte range and returns a `Buffer` (or `None` |
| 161 | + if the underlying key does not exist). Typically constructed via |
| 162 | + `functools.partial(store.get, key, prototype)`. |
| 163 | + byte_ranges |
| 164 | + Input ranges. `None` means "the whole value". |
| 165 | + max_concurrency |
| 166 | + Maximum number of merged fetches in flight at once. |
| 167 | + max_gap_bytes |
| 168 | + Forwarded to `coalesce_ranges`. |
| 169 | + max_coalesced_bytes |
| 170 | + Forwarded to `coalesce_ranges`. |
| 171 | +
|
| 172 | + Yields |
| 173 | + ------ |
| 174 | + Sequence[tuple[int, Buffer | None]] |
| 175 | + Per-I/O batch of `(input_index, result)` tuples. |
| 176 | +
|
| 177 | + Notes |
| 178 | + ----- |
| 179 | + - Only `RangeByteRequest` inputs are coalesced. `OffsetByteRequest`, |
| 180 | + `SuffixByteRequest`, and `None` are each treated as uncoalescable |
| 181 | + (one fetch, one single-tuple yield per input). |
| 182 | + - Failures from underlying fetches surface as a `BaseExceptionGroup` |
| 183 | + (PEP 654). Inner exceptions include `FileNotFoundError` if a fetch |
| 184 | + returns `None`, plus any exception `fetch` raises. Pending fetches are |
| 185 | + cancelled as soon as one task fails, so the group typically contains a |
| 186 | + single non-`CancelledError` exception even under high concurrency. |
| 187 | + - Groups completed before the failure remain observable on the yields |
| 188 | + preceding the raise. |
| 189 | + - `GeneratorExit` raised by `aclose()` is filtered out so the iterator |
| 190 | + closes cleanly; callers don't see a group containing only it. |
| 191 | + """ |
| 192 | + if not byte_ranges: |
| 193 | + return |
| 194 | + |
| 195 | + groups, singles = coalesce_ranges( |
| 196 | + byte_ranges, |
| 197 | + max_gap_bytes=max_gap_bytes, |
| 198 | + max_coalesced_bytes=max_coalesced_bytes, |
| 199 | + ) |
| 200 | + |
| 201 | + ctx = _WorkerCtx(fetch=fetch, semaphore=asyncio.Semaphore(max_concurrency)) |
| 202 | + |
| 203 | + # Launch all work as tasks. The semaphore bounds actual I/O concurrency. |
| 204 | + # TaskGroup wraps task exceptions in BaseExceptionGroup; we propagate the |
| 205 | + # group unchanged as part of the public contract (callers handle batch |
| 206 | + # failures via `except*` / PEP 654). GeneratorExit (raised when the |
| 207 | + # consumer calls aclose()) is filtered out so close completes cleanly. |
| 208 | + try: |
| 209 | + async with asyncio.TaskGroup() as tg: |
| 210 | + tasks = [ |
| 211 | + *(tg.create_task(_fetch_group(ctx, group)) for group in groups), |
| 212 | + *(tg.create_task(_fetch_single(ctx, i, single)) for i, single in singles), |
| 213 | + ] |
| 214 | + |
| 215 | + for fut in asyncio.as_completed(tasks): |
| 216 | + yield await fut |
| 217 | + except BaseExceptionGroup as eg: |
| 218 | + # Strip GeneratorExits (consumer aclose()) and propagate whatever remains. |
| 219 | + _, other_errors = eg.split(GeneratorExit) |
| 220 | + |
| 221 | + if other_errors is not None: |
| 222 | + raise other_errors from None |
0 commit comments