Skip to content

Commit 6028dd5

Browse files
thodson-usgsclaude
andcommitted
refactor(chunking): Flatten _fan_out_async into focused helpers
Cognitive-burden refactor of the async fan-out path; behavior unchanged. Five targeted cleanups: 1. Drop the ``Semaphore | None`` Optional + ``_bounded`` wrapper. The semaphore is now unconditional, sized to ``max_concurrent or sys.maxsize`` ("unbounded" is just a very-large counter). One ``async with semaphore`` everywhere; no branching. 2. Extract ``_probe_first`` and ``_fan_out_rest`` so ``_fan_out_async``'s body reads as ``probe -> check quota -> fan out rest -> combine`` instead of inlining the two try/except blocks. Each helper has a focused docstring + ``_Track`` type alias for the shared callable. 3. Simplify the post-gather exception walker with the walrus: ``if (interrupted := call.wrap_failure(exc)) is not None``. One assignment per iteration, no double-call. 4. Use the existing ``ChunkedCall.completed_chunks`` property in the progress tick instead of poking ``len(call._chunks)``. 5. Extract ``_execute_in_parallel`` from the decorator wrapper. The wrapper now reads as ``serial-or-parallel?`` in five lines; the helper owns the ``fetch_async is None`` and running-event-loop fallbacks (with their UserWarnings). Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
1 parent 070be77 commit 6028dd5

1 file changed

Lines changed: 138 additions & 107 deletions

File tree

dataretrieval/waterdata/chunking.py

Lines changed: 138 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
import itertools
5454
import math
5555
import os
56+
import sys
5657
import warnings
5758
from collections.abc import Awaitable, Callable, Iterator
5859
from contextlib import contextmanager, suppress
@@ -247,6 +248,12 @@ def get_active_async_session() -> httpx.AsyncClient | None:
247248
_FetchOnceAsync = Callable[
248249
[dict[str, Any]], Awaitable[tuple[pd.DataFrame, httpx.Response]]
249250
]
251+
# A "tracked" sub-request issuer: takes ``(index, sub_args)``, issues
252+
# one sub-request, records the completion on the shared ``ChunkedCall``,
253+
# and ticks the progress reporter. The probe and fan-out helpers both
254+
# call into the same closure so bookkeeping happens exactly once per
255+
# success.
256+
_Track = Callable[[int, dict[str, Any]], Awaitable[tuple[pd.DataFrame, httpx.Response]]]
250257

251258

252259
class _RetryableTransportError(RuntimeError):
@@ -1429,6 +1436,51 @@ def _check_quota_remaining(self) -> None:
14291436
)
14301437

14311438

1439+
async def _probe_first(
1440+
call: ChunkedCall, sub_args: dict[str, Any], track: _Track
1441+
) -> tuple[pd.DataFrame, httpx.Response]:
1442+
"""
1443+
Issue sub-request 0 alone so its ``x-ratelimit-remaining`` header
1444+
can gate the rest of the plan before the burst goes out. A
1445+
transient failure here is routed through ``call.wrap_failure`` →
1446+
:class:`ChunkInterrupted`; non-transient failures re-raise so the
1447+
caller sees their original type.
1448+
"""
1449+
try:
1450+
return await track(0, sub_args)
1451+
except (RuntimeError, httpx.HTTPError) as exc:
1452+
if (interrupted := call.wrap_failure(exc)) is not None:
1453+
raise interrupted from exc
1454+
raise
1455+
1456+
1457+
async def _fan_out_rest(
1458+
call: ChunkedCall, sub_args_rest: list[dict[str, Any]], track: _Track
1459+
) -> None:
1460+
"""
1461+
Dispatch sub-requests 1..N-1 concurrently.
1462+
1463+
Completed pairs survive a sibling's transient failure via
1464+
``return_exceptions=True``, so the partial result stays
1465+
recoverable through :meth:`ChunkedCall.resume`. On any failure,
1466+
prefer raising the first *recognized transient* — so the user
1467+
still gets a resumable :class:`ChunkInterrupted` even when a
1468+
non-transient bug landed earlier in submission order. Fall back
1469+
to the first failure (preserving its type) when nothing is
1470+
transient.
1471+
"""
1472+
results = await asyncio.gather(
1473+
*(track(i, args) for i, args in enumerate(sub_args_rest, start=1)),
1474+
return_exceptions=True,
1475+
)
1476+
failures = [r for r in results if isinstance(r, BaseException)]
1477+
for exc in failures:
1478+
if (interrupted := call.wrap_failure(exc)) is not None:
1479+
raise interrupted from exc
1480+
if failures:
1481+
raise failures[0]
1482+
1483+
14321484
async def _fan_out_async(
14331485
plan: ChunkPlan,
14341486
fetch_once: _FetchOnce,
@@ -1443,24 +1495,27 @@ async def _fan_out_async(
14431495
The fan-out preserves the same safety contracts the serial
14441496
:class:`ChunkedCall` path provides:
14451497
1446-
* **Quota check.** The first sub-request is issued alone; its
1447-
``x-ratelimit-remaining`` header is read before any other
1448-
sub-request is dispatched. If the remaining plan can't fit the
1449-
window, :class:`RequestExceedsQuota` fires (matching
1450-
:meth:`ChunkedCall._check_quota_remaining`).
1451-
* **Resumable interruptions.** ``asyncio.gather`` runs with
1452-
``return_exceptions=True`` so completed sub-requests survive a
1453-
sibling's transient failure. On a recognized transient
1454-
(:class:`RateLimited`, :class:`ServiceUnavailable`) a
1455-
:class:`ChunkInterrupted` subclass is raised with ``.call`` set
1456-
to a :class:`ChunkedCall` carrying the completed sub-args as a
1457-
sparse index map. Calling ``exc.call.resume()`` re-issues only
1458-
the unfinished sub-requests, via the sync ``fetch_once`` path.
1459-
1460-
Bounded by an :class:`asyncio.Semaphore` when ``max_concurrent``
1461-
is set; unbounded otherwise. The shared client is published on
1498+
* **Quota check.** The first sub-request is issued alone via
1499+
:func:`_probe_first`; its ``x-ratelimit-remaining`` header is
1500+
read before any other sub-request is dispatched. If the
1501+
remaining plan can't fit the window,
1502+
:class:`RequestExceedsQuota` fires.
1503+
* **Resumable interruptions.** :func:`_fan_out_rest` runs
1504+
``asyncio.gather`` with ``return_exceptions=True`` so completed
1505+
sub-requests survive a sibling's transient failure. On a
1506+
recognized transient (:class:`RateLimited`,
1507+
:class:`ServiceUnavailable`) a :class:`ChunkInterrupted`
1508+
subclass is raised with ``.call`` set to a
1509+
:class:`ChunkedCall` carrying the sparse completed sub-args;
1510+
``exc.call.resume()`` re-issues only the unfinished ones via
1511+
the sync ``fetch_once`` path.
1512+
1513+
In-flight sub-requests are capped by an
1514+
:class:`asyncio.Semaphore`; ``max_concurrent=None`` ("unbounded")
1515+
uses ``sys.maxsize`` so every call site can take the same
1516+
``async with semaphore`` path. The shared client is published on
14621517
:data:`_chunked_async_session` so async paginated-loop helpers
1463-
downstream reuse its connection pool.
1518+
reuse its connection pool.
14641519
14651520
Parameters
14661521
----------
@@ -1497,57 +1552,37 @@ async def _fan_out_async(
14971552

14981553
# ``httpx.Limits()`` defaults to ``max_connections=100`` — at
14991554
# higher concurrency the pool would silently bottleneck the
1500-
# fan-out behind the connection cap. Pass an explicit cap that
1501-
# matches the semaphore, or ``None`` for truly unbounded.
1555+
# fan-out behind the connection cap. Match it to the semaphore,
1556+
# or ``None`` for truly unbounded.
15021557
limits = httpx.Limits(
15031558
max_connections=max_concurrent, max_keepalive_connections=max_concurrent
15041559
)
1505-
semaphore: asyncio.Semaphore | None = (
1506-
asyncio.Semaphore(max_concurrent) if max_concurrent is not None else None
1507-
)
1508-
1509-
async def _bounded(args: dict[str, Any]) -> tuple[pd.DataFrame, httpx.Response]:
1510-
if semaphore is None:
1511-
return await fetch_async(args)
1512-
async with semaphore:
1513-
return await fetch_async(args)
1514-
1560+
# ``sys.maxsize`` stands in for "unbounded": ``asyncio.Semaphore``
1561+
# only decrements a counter, never preallocates slots.
1562+
semaphore = asyncio.Semaphore(max_concurrent or sys.maxsize)
15151563
call = ChunkedCall(plan, fetch_once)
15161564

15171565
async with httpx.AsyncClient(limits=limits, **HTTPX_DEFAULTS) as client:
15181566
with _publish_async_session(client):
1519-
# Record the total so the progress line can show
1520-
# ``chunk K/N``; ``_track`` bumps K as each sub-request
1521-
# completes so the parallel display advances chunk by
1522-
# chunk just like the serial path.
15231567
reporter = _progress.current()
15241568
if reporter is not None:
15251569
reporter.set_chunks(plan.total)
15261570

1527-
async def _track(offset: int, args: dict[str, Any]):
1528-
"""Issue one sub-request, record its result, and
1529-
report completion. Used for both the probe-first call
1530-
and the gathered fan-out so the record+report happen
1531-
exactly once per success. asyncio is single-threaded
1532-
within one event loop, so the record + len read
1533-
sequence is atomic at the scheduler level."""
1534-
result = await _bounded(args)
1571+
async def track(
1572+
offset: int, args: dict[str, Any]
1573+
) -> tuple[pd.DataFrame, httpx.Response]:
1574+
"""One sub-request + record + progress tick. asyncio
1575+
is single-threaded within an event loop, so the
1576+
record-then-read sequence is atomic at the scheduler
1577+
level."""
1578+
async with semaphore:
1579+
result = await fetch_async(args)
15351580
call.record(offset, result)
15361581
if reporter is not None:
1537-
reporter.start_chunk(len(call._chunks))
1582+
reporter.start_chunk(call.completed_chunks)
15381583
return result
15391584

1540-
# Probe-first: issue index 0 alone, check quota, then
1541-
# fan out the rest. A transient failure here surfaces as a
1542-
# ChunkInterrupted whose .call has no completed sub-requests
1543-
# yet, so .call.resume() re-issues the entire plan.
1544-
try:
1545-
first_pair = await _track(0, sub_args_list[0])
1546-
except (RuntimeError, httpx.HTTPError) as exc:
1547-
interrupted = call.wrap_failure(exc)
1548-
if interrupted is not None:
1549-
raise interrupted from exc
1550-
raise
1585+
first_pair = await _probe_first(call, sub_args_list[0], track)
15511586

15521587
if len(sub_args_list) > 1 and not _quota_check_disabled():
15531588
remaining = _read_remaining(first_pair[1])
@@ -1559,26 +1594,7 @@ async def _track(offset: int, args: dict[str, Any]):
15591594
call=call,
15601595
)
15611596

1562-
# Fan out the remaining sub-requests. Completed pairs
1563-
# survive a sibling's transient failure (``return_exceptions``),
1564-
# so the partial result stays recoverable through
1565-
# ``ChunkedCall.resume()``.
1566-
results = await asyncio.gather(
1567-
*(_track(off, a) for off, a in enumerate(sub_args_list[1:], start=1)),
1568-
return_exceptions=True,
1569-
)
1570-
exceptions = [r for r in results if isinstance(r, BaseException)]
1571-
# Prefer wrapping the first *recognized transient* failure
1572-
# so the user still gets a resumable ``ChunkInterrupted``
1573-
# even if a non-transient error happened to land first by
1574-
# submission order. Only if none of the failures is
1575-
# transient do we fall back to raising the first one.
1576-
for exc in exceptions:
1577-
interrupted = call.wrap_failure(exc)
1578-
if interrupted is not None:
1579-
raise interrupted from exc
1580-
if exceptions:
1581-
raise exceptions[0]
1597+
await _fan_out_rest(call, sub_args_list[1:], track)
15821598

15831599
ordered = call._ordered_chunks()
15841600
return (
@@ -1660,48 +1676,63 @@ def wrapper(
16601676
plan = ChunkPlan(args, build_request, limit)
16611677
concurrency = _read_concurrency_env()
16621678

1663-
# Stay on the sync path for trivial plans and explicit
1664-
# opt-outs. The remaining branches all need a wired
1665-
# ``fetch_async`` and a usable event loop.
1679+
# Trivial plans and explicit opt-outs stay on the sync
1680+
# path; ``_execute_in_parallel`` owns the rest of the
1681+
# serial/parallel decision (async wiring, running loop).
16661682
if plan.total <= 1 or concurrency == 1:
16671683
return plan.execute(fetch_once)
1668-
if fetch_async is None:
1669-
warnings.warn(
1670-
f"{_CONCURRENCY_ENV} is set to {concurrency} but this "
1671-
f"call site has no async fetch sibling wired; falling "
1672-
f"back to the serial path. Either set "
1673-
f"{_CONCURRENCY_ENV}=1 to silence this warning or pass "
1674-
f"fetch_async= to @multi_value_chunked.",
1675-
UserWarning,
1676-
stacklevel=2,
1677-
)
1678-
return plan.execute(fetch_once)
1679-
# ``asyncio.run`` raises ``RuntimeError`` if an event loop
1680-
# is already running (e.g. Jupyter / IPython kernels,
1681-
# async apps). Detect that case and fall back to the
1682-
# serial path with a one-time warning so notebook users
1683-
# don't see a confusing ``RuntimeError``.
1684-
if _running_event_loop() is not None:
1685-
warnings.warn(
1686-
"Detected a running asyncio event loop; the parallel "
1687-
f"chunker path cannot run inside one. Falling back to "
1688-
f"the serial path. Set {_CONCURRENCY_ENV}=1 to silence "
1689-
f"this warning.",
1690-
UserWarning,
1691-
stacklevel=2,
1692-
)
1693-
return plan.execute(fetch_once)
1694-
return asyncio.run(
1695-
_fan_out_async(
1696-
plan, fetch_once, fetch_async, max_concurrent=concurrency
1697-
)
1698-
)
1684+
return _execute_in_parallel(plan, fetch_once, fetch_async, concurrency)
16991685

17001686
return wrapper
17011687

17021688
return decorator
17031689

17041690

1691+
def _execute_in_parallel(
1692+
plan: ChunkPlan,
1693+
fetch_once: _FetchOnce,
1694+
fetch_async: _FetchOnceAsync | None,
1695+
concurrency: int | None,
1696+
) -> tuple[pd.DataFrame, httpx.Response]:
1697+
"""
1698+
Run ``plan`` on the parallel async path, falling back to the
1699+
serial sync path when the runtime can't host an event loop.
1700+
1701+
Falls back (with a one-time :class:`UserWarning`) when:
1702+
1703+
* ``fetch_async`` wasn't wired into the decorator, or
1704+
* an asyncio event loop is already running (Jupyter / IPython
1705+
kernels, async apps — ``asyncio.run`` would raise).
1706+
1707+
Otherwise opens a fresh event loop via :func:`asyncio.run` and
1708+
drives :func:`_fan_out_async`.
1709+
"""
1710+
if fetch_async is None:
1711+
warnings.warn(
1712+
f"{_CONCURRENCY_ENV} is set to {concurrency} but this "
1713+
f"call site has no async fetch sibling wired; falling "
1714+
f"back to the serial path. Either set "
1715+
f"{_CONCURRENCY_ENV}=1 to silence this warning or pass "
1716+
f"fetch_async= to @multi_value_chunked.",
1717+
UserWarning,
1718+
stacklevel=3,
1719+
)
1720+
return plan.execute(fetch_once)
1721+
if _running_event_loop() is not None:
1722+
warnings.warn(
1723+
"Detected a running asyncio event loop; the parallel "
1724+
f"chunker path cannot run inside one. Falling back to "
1725+
f"the serial path. Set {_CONCURRENCY_ENV}=1 to silence "
1726+
f"this warning.",
1727+
UserWarning,
1728+
stacklevel=3,
1729+
)
1730+
return plan.execute(fetch_once)
1731+
return asyncio.run(
1732+
_fan_out_async(plan, fetch_once, fetch_async, max_concurrent=concurrency)
1733+
)
1734+
1735+
17051736
def _running_event_loop() -> asyncio.AbstractEventLoop | None:
17061737
"""Return the active asyncio event loop, or ``None`` when none."""
17071738
try:

0 commit comments

Comments
 (0)