Skip to content

Commit 1d0acb7

Browse files
committed
refactor(waterdata): maintainer polish for chunk retries
1 parent d2bf71f commit 1d0acb7

4 files changed

Lines changed: 235 additions & 91 deletions

File tree

dataretrieval/waterdata/_progress.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -157,10 +157,17 @@ def add_page(self, rows: int = 0) -> None:
157157
def note_retry(self, *, attempt: int, wait: float) -> None:
158158
"""Show that a sub-request is backing off before retry ``attempt``.
159159
160-
Cleared by the next :meth:`add_page` / :meth:`start_chunk` so the
161-
line returns to normal progress once the retry succeeds.
160+
Cleared by the next :meth:`add_page` / :meth:`start_chunk` (or by
161+
:meth:`close`) so the line returns to normal once the retry resolves.
162162
"""
163-
self.retry_note = f"retrying (attempt {attempt}, waiting {wait:.0f}s)"
163+
# Keep sub-second waits explicit (avoid misleading ``0s``) while
164+
# rendering whole-second waits without unnecessary ``.0`` noise.
165+
wait_1dp = round(wait, 1)
166+
if wait_1dp < 1 or not wait_1dp.is_integer():
167+
secs = f"{wait_1dp:.1f}s"
168+
else:
169+
secs = f"{wait_1dp:.0f}s"
170+
self.retry_note = f"retrying (attempt {attempt}, waiting {secs})"
164171
self._render()
165172

166173
def set_rate_remaining(
@@ -225,6 +232,13 @@ def close(self) -> None:
225232
"""
226233
if self._closed:
227234
return
235+
# A retry note set during the final backoff would otherwise freeze as
236+
# the persisted last line of a call that has since completed or given
237+
# up; clear it and redraw (while still un-closed, so ``_render`` runs)
238+
# so the final state isn't a stale "retrying".
239+
if self.enabled and self._rendered and self.retry_note is not None:
240+
self.retry_note = None
241+
self._render()
228242
self._closed = True
229243
if not (self.enabled and self._rendered):
230244
return

dataretrieval/waterdata/chunking.py

Lines changed: 132 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
from contextvars import ContextVar
6363
from dataclasses import dataclass
6464
from datetime import timedelta
65-
from typing import Any, ClassVar
65+
from typing import Any, ClassVar, TypeVar
6666
from urllib.parse import quote_plus
6767

6868
import httpx
@@ -166,13 +166,14 @@ def _read_concurrency_env() -> int | None:
166166
return value
167167

168168

169-
# Retry-with-backoff for transient sub-request failures (429 / 5xx /
170-
# connect-read timeouts). The env var is read at call time so test
171-
# ``monkeypatch.setenv`` takes effect; the timing constants are
172-
# module-level so power users (and tests) can ``monkeypatch.setattr``
173-
# them. Defaults: 4 retries, 0.5s base doubling under full jitter up to
174-
# a 30s per-attempt ceiling, and honor a server ``Retry-After`` up to
175-
# 60s before escalating to a resumable interruption instead.
169+
# Retry-with-backoff defaults for transient sub-request failures (429 /
170+
# 5xx / connect-read timeouts). All four are resolved at call time by
171+
# ``RetryPolicy.from_env`` (the env var via ``monkeypatch.setenv``, the
172+
# timing constants via ``monkeypatch.setattr`` on this module), so both
173+
# are overridable in tests and by power users. Defaults: 4 retries, 0.5s
174+
# base doubling under full jitter up to a 30s per-attempt ceiling, and
175+
# honor a server ``Retry-After`` up to 60s before escalating to a
176+
# resumable interruption instead.
176177
_RETRIES_ENV = "API_USGS_RETRIES"
177178
_RETRIES_DEFAULT = 4
178179
_RETRY_BASE_BACKOFF = 0.5
@@ -237,10 +238,31 @@ class RetryPolicy:
237238
max_backoff: float = _RETRY_MAX_BACKOFF
238239
retry_after_cap: float = _RETRY_AFTER_CAP
239240

241+
def __post_init__(self) -> None:
242+
# Guard the value object's own invariants so a misconfiguration
243+
# fails loudly at construction rather than as a downstream
244+
# ``time.sleep`` ValueError (negative delay) or a silent
245+
# asyncio.sleep-treats-negative-as-zero divergence.
246+
if self.max_retries < 0:
247+
raise ValueError(f"max_retries must be >= 0 (got {self.max_retries}).")
248+
if self.base_backoff < 0 or self.max_backoff < 0 or self.retry_after_cap < 0:
249+
raise ValueError("retry backoff settings must be non-negative.")
250+
240251
@classmethod
241252
def from_env(cls) -> RetryPolicy:
242-
"""Build a policy, resolving ``max_retries`` from ``API_USGS_RETRIES``."""
243-
return cls(max_retries=_read_retries_env())
253+
"""Build a policy from the module-level defaults, resolved now.
254+
255+
``max_retries`` comes from ``API_USGS_RETRIES``; the timing knobs
256+
are read from the ``_RETRY_*`` module constants at call time (not
257+
the dataclass field defaults, which freeze at class definition) so
258+
``monkeypatch.setattr`` on those constants takes effect.
259+
"""
260+
return cls(
261+
max_retries=_read_retries_env(),
262+
base_backoff=_RETRY_BASE_BACKOFF,
263+
max_backoff=_RETRY_MAX_BACKOFF,
264+
retry_after_cap=_RETRY_AFTER_CAP,
265+
)
244266

245267
def should_retry(self, attempt: int, retry_after: float | None) -> bool:
246268
"""Whether a just-failed ``attempt`` (1-based) warrants another try.
@@ -276,42 +298,36 @@ def backoff(self, attempt: int, retry_after: float | None) -> float:
276298
"_chunked_client", default=None
277299
)
278300

279-
# Async sibling of ``_chunked_client``. Published by
280-
# ``_publish_async_client`` during ``_fan_out_async`` so async
281-
# paginated-loop helpers reuse one ``httpx.AsyncClient`` (and its
282-
# connection pool) across every concurrent sub-request of a single
283-
# chunked call.
301+
# Async sibling of ``_chunked_client``. Published (via :func:`_publish`)
302+
# during ``_fan_out_async`` so async paginated-loop helpers reuse one
303+
# ``httpx.AsyncClient`` (and its connection pool) across every concurrent
304+
# sub-request of a single chunked call.
284305
_chunked_async_client: ContextVar[httpx.AsyncClient | None] = ContextVar(
285306
"_chunked_async_client", default=None
286307
)
287308

288-
289-
@contextmanager
290-
def _publish_client(client: httpx.Client) -> Iterator[None]:
291-
"""
292-
Make ``client`` visible to :func:`get_active_client` for the
293-
duration of the ``with`` block via the ``_chunked_client``
294-
ContextVar. Wraps the set/reset token dance so callers don't have to.
295-
"""
296-
token = _chunked_client.set(client)
297-
try:
298-
yield
299-
finally:
300-
_chunked_client.reset(token)
309+
_ClientT = TypeVar("_ClientT")
301310

302311

303312
@contextmanager
304-
def _publish_async_client(client: httpx.AsyncClient) -> Iterator[None]:
313+
def _publish(var: ContextVar[_ClientT | None], client: _ClientT) -> Iterator[None]:
305314
"""
306-
Make ``client`` visible to :func:`get_active_async_client` for the
307-
duration of the ``with`` block. Async sibling of
308-
:func:`_publish_client`.
315+
Bind ``client`` to the ContextVar ``var`` for the duration of the
316+
``with`` block (wrapping the set/reset token dance), so paginated-loop
317+
helpers can borrow the chunker's shared client via
318+
:func:`get_active_client` / :func:`get_active_async_client`.
319+
320+
Generic over the client type so the sync (:class:`httpx.Client` via
321+
``_chunked_client``) and async (:class:`httpx.AsyncClient` via
322+
``_chunked_async_client``) paths share one implementation, while the
323+
``_ClientT`` type var still lets a type checker reject a var/client
324+
type mismatch.
309325
"""
310-
token = _chunked_async_client.set(client)
326+
token = var.set(client)
311327
try:
312328
yield
313329
finally:
314-
_chunked_async_client.reset(token)
330+
var.reset(token)
315331

316332

317333
def get_active_client() -> httpx.Client | None:
@@ -325,8 +341,8 @@ def get_active_client() -> httpx.Client | None:
325341
Returns
326342
-------
327343
httpx.Client or None
328-
The client published by :func:`_publish_client` if currently
329-
inside a :class:`ChunkedCall` ``resume`` block; ``None`` otherwise.
344+
The client published via :func:`_publish` if currently inside a
345+
:class:`ChunkedCall` ``resume`` block; ``None`` otherwise.
330346
"""
331347
return _chunked_client.get()
332348

@@ -1069,27 +1085,29 @@ def _retryable(exc: BaseException) -> tuple[bool, float | None]:
10691085
"""
10701086
Decide whether ``exc`` is a transient worth an automatic retry.
10711087
1072-
Narrower than :func:`_classify_chunk_error`: it retries rate limits
1073-
(429), service errors (5xx), and genuine transport transients
1074-
(:class:`httpx.TransportError` — ``ConnectError``, ``ReadTimeout``, …)
1075-
but NOT :class:`httpx.InvalidURL` (a too-long server cursor URL won't
1076-
fix on retry, though it stays *resumable*). Walks the ``__cause__``
1077-
chain because ``_walk_pages`` re-wraps mid-pagination failures as
1078-
``RuntimeError``.
1088+
Inspects only the *top-level* exception, by design — and so is
1089+
deliberately narrower than :func:`_classify_chunk_error`, which walks
1090+
the ``__cause__`` chain for resumability. ``_paginate`` raises an
1091+
initial-request transient (429 / 5xx / :class:`httpx.TransportError`
1092+
such as ``ConnectError`` / ``ReadTimeout``) *raw*, but re-wraps any
1093+
mid-pagination failure as a ``RuntimeError``. Retrying only the raw,
1094+
top-level transient means we re-issue a sub-request that made no
1095+
progress (cheap), while a failure after partial pagination escalates
1096+
to the resumable :class:`ChunkInterrupted` instead of being re-walked
1097+
from page 1 — which would re-spend the very quota that was exhausted.
1098+
``httpx.InvalidURL`` is excluded (a too-long cursor won't fix on
1099+
retry), and it only ever arises on a follow-up page anyway.
10791100
10801101
Returns
10811102
-------
10821103
tuple[bool, float or None]
10831104
``(retryable, retry_after)`` — the server ``Retry-After`` hint
10841105
(seconds) when the transient carried one, else ``None``.
10851106
"""
1086-
cur: BaseException | None = exc
1087-
while cur is not None:
1088-
if isinstance(cur, (RateLimited, ServiceUnavailable)):
1089-
return True, cur.retry_after
1090-
if isinstance(cur, httpx.TransportError):
1091-
return True, None
1092-
cur = cur.__cause__
1107+
if isinstance(exc, (RateLimited, ServiceUnavailable)):
1108+
return True, exc.retry_after
1109+
if isinstance(exc, httpx.TransportError):
1110+
return True, None
10931111
return False, None
10941112

10951113

@@ -1334,13 +1352,19 @@ def __init__(
13341352
# subsequent ``resume()`` only re-issues the missing indices.
13351353
# On the serial path this fills contiguously from 0.
13361354
self._chunks: dict[int, tuple[pd.DataFrame, httpx.Response]] = {}
1355+
# Explicit completion order for response-header aggregation.
1356+
# Keeping this separate from ``_chunks`` avoids coupling that
1357+
# behavior to dict insertion semantics or future write patterns.
1358+
self._completion_order: list[int] = []
13371359

13381360
def record(self, index: int, pair: tuple[pd.DataFrame, httpx.Response]) -> None:
13391361
"""Record a completed sub-request's ``(frame, response)`` pair
13401362
under its sub-args index. Used by both the serial loop in
13411363
:meth:`resume` and the parallel fan-out in
13421364
:func:`_fan_out_async` so the completion set stays
13431365
encapsulated."""
1366+
if index not in self._chunks:
1367+
self._completion_order.append(index)
13441368
self._chunks[index] = pair
13451369

13461370
def wrap_failure(self, exc: BaseException) -> ChunkInterrupted | None:
@@ -1369,6 +1393,27 @@ def completed_chunks(self) -> int:
13691393
def _ordered_chunks(self) -> list[tuple[pd.DataFrame, httpx.Response]]:
13701394
return [self._chunks[i] for i in sorted(self._chunks)]
13711395

1396+
def _responses_by_completion(self) -> list[httpx.Response]:
1397+
# The final element is the most-recently completed sub-request, whose
1398+
# headers carry the freshest ``x-ratelimit-remaining`` for aggregation.
1399+
return [self._chunks[i][1] for i in self._completion_order]
1400+
1401+
def combined(self) -> tuple[pd.DataFrame, httpx.Response]:
1402+
"""Combine every recorded sub-request into one ``(frame, response)``.
1403+
1404+
Frames concatenate in sub-args *index* order (deterministic,
1405+
independent of parallel completion order); the aggregated response
1406+
takes its headers from the most-recently-*completed* sub-request, so
1407+
a fan-out that finished chunks out of index order still surfaces the
1408+
latest rate-limit state the server reported rather than a stale one.
1409+
"""
1410+
return (
1411+
_combine_chunk_frames([frame for frame, _ in self._ordered_chunks()]),
1412+
_combine_chunk_responses(
1413+
self._responses_by_completion(), self.plan.canonical_url
1414+
),
1415+
)
1416+
13721417
@property
13731418
def partial_frame(self) -> pd.DataFrame:
13741419
"""
@@ -1405,7 +1450,7 @@ def partial_response(self) -> httpx.Response | None:
14051450
if not self._chunks:
14061451
return None
14071452
return _combine_chunk_responses(
1408-
[resp for _, resp in self._ordered_chunks()], self.plan.canonical_url
1453+
self._responses_by_completion(), self.plan.canonical_url
14091454
)
14101455

14111456
def resume(self) -> tuple[pd.DataFrame, httpx.Response]:
@@ -1443,23 +1488,18 @@ def resume(self) -> tuple[pd.DataFrame, httpx.Response]:
14431488
is on ``exc.call`` — wait for the underlying condition to
14441489
clear and call ``exc.call.resume()`` again.
14451490
"""
1446-
with httpx.Client(**HTTPX_DEFAULTS) as client, _publish_client(client):
1447-
reporter = _progress.current()
1448-
if reporter is not None:
1449-
reporter.set_chunks(self.plan.total)
1450-
for i, sub_args in enumerate(self.plan.iter_sub_args()):
1451-
if i in self._chunks:
1452-
continue
1491+
with httpx.Client(**HTTPX_DEFAULTS) as client:
1492+
with _publish(_chunked_client, client):
1493+
reporter = _progress.current()
14531494
if reporter is not None:
1454-
reporter.start_chunk(i + 1)
1455-
self._issue(i, sub_args)
1456-
ordered = self._ordered_chunks()
1457-
frames = [frame for frame, _ in ordered]
1458-
responses = [resp for _, resp in ordered]
1459-
return (
1460-
_combine_chunk_frames(frames),
1461-
_combine_chunk_responses(responses, self.plan.canonical_url),
1462-
)
1495+
reporter.set_chunks(self.plan.total)
1496+
for i, sub_args in enumerate(self.plan.iter_sub_args()):
1497+
if i in self._chunks:
1498+
continue
1499+
if reporter is not None:
1500+
reporter.start_chunk(i + 1)
1501+
self._issue(i, sub_args)
1502+
return self.combined()
14631503

14641504
def _issue(self, index: int, sub_args: dict[str, Any]) -> None:
14651505
"""
@@ -1556,13 +1596,17 @@ async def _fan_out_async(
15561596
limits = httpx.Limits(
15571597
max_connections=max_concurrent, max_keepalive_connections=max_concurrent
15581598
)
1559-
# ``sys.maxsize`` stands in for "unbounded": ``asyncio.Semaphore``
1560-
# only decrements a counter, never preallocates slots.
1561-
semaphore = asyncio.Semaphore(max_concurrent or sys.maxsize)
1599+
# ``None`` means "unbounded"; ``sys.maxsize`` stands in for it since
1600+
# ``asyncio.Semaphore`` only decrements a counter, never preallocates
1601+
# slots. Test ``is None`` explicitly so a stray ``0`` isn't silently
1602+
# promoted to unbounded by a falsy-``or``.
1603+
semaphore = asyncio.Semaphore(
1604+
sys.maxsize if max_concurrent is None else max_concurrent
1605+
)
15621606
call = ChunkedCall(plan, fetch_once, retry_policy)
15631607

15641608
async with httpx.AsyncClient(limits=limits, **HTTPX_DEFAULTS) as client:
1565-
with _publish_async_client(client):
1609+
with _publish(_chunked_async_client, client):
15661610
reporter = _progress.current()
15671611
if reporter is not None:
15681612
reporter.set_chunks(plan.total)
@@ -1586,15 +1630,16 @@ async def track(
15861630
# Dispatch every sub-request concurrently. ``return_exceptions``
15871631
# keeps completed pairs after a sibling fails, so partial state
15881632
# stays recoverable via ``ChunkedCall.resume()``. Failure
1589-
# precedence:
1633+
# precedence, in order:
15901634
# 1. Cancellation / interrupt signals (CancelledError,
15911635
# KeyboardInterrupt, SystemExit — non-Exception) propagate
15921636
# unmodified; wrapping them as a transient would swallow the
15931637
# user's stop signal.
1594-
# 2. Recognized transients wrap as ChunkInterrupted so the user
1595-
# gets a resumable handle even when a non-transient failure
1596-
# landed earlier in submission order.
1597-
# 3. Otherwise re-raise the first failure, preserving its type.
1638+
# 2. A non-transient failure (a real bug — unrecognized by
1639+
# ``wrap_failure``) surfaces raw, so it isn't masked behind a
1640+
# resumable handle for a transient sibling that landed later.
1641+
# 3. Only when every failure is a recognized transient do we
1642+
# raise the first as a resumable ``ChunkInterrupted``.
15981643
results = await asyncio.gather(
15991644
*(track(i, args) for i, args in enumerate(sub_args_list)),
16001645
return_exceptions=True,
@@ -1603,17 +1648,18 @@ async def track(
16031648
for exc in failures:
16041649
if not isinstance(exc, Exception):
16051650
raise exc
1651+
first_transient: tuple[ChunkInterrupted, BaseException] | None = None
16061652
for exc in failures:
1607-
if (interrupted := call.wrap_failure(exc)) is not None:
1608-
raise interrupted from exc
1609-
if failures:
1610-
raise failures[0]
1611-
1612-
ordered = call._ordered_chunks()
1613-
return (
1614-
_combine_chunk_frames([df for df, _ in ordered]),
1615-
_combine_chunk_responses([resp for _, resp in ordered], plan.canonical_url),
1616-
)
1653+
interrupted = call.wrap_failure(exc)
1654+
if interrupted is None:
1655+
raise exc
1656+
if first_transient is None:
1657+
first_transient = (interrupted, exc)
1658+
if first_transient is not None:
1659+
interrupted, exc = first_transient
1660+
raise interrupted from exc
1661+
1662+
return call.combined()
16171663

16181664

16191665
def multi_value_chunked(

0 commit comments

Comments
 (0)