Skip to content

Commit 9345743

Browse files
thodson-usgs and claude committed
feat(wateruse): async fan-out over a shared client; surface final rate-limit
Replace the ThreadPoolExecutor fan-out with an asyncio implementation: one shared `httpx.AsyncClient` paginates each location, `asyncio.gather` (bounded by a semaphore at `MAX_CONCURRENT_REQUESTS`) fans the locations out, and input order is preserved for a deterministic concat. The single client keeps connections alive across pages and locations (the old per-call `httpx.get` opened a fresh connection every page). The event loop runs in a worker thread, so it is safe even when called inside an already-running loop (Jupyter) — a bare `asyncio.run` would raise there. `md.header` now surfaces the *final* rate-limit headers — the response with the lowest `x-ratelimit-remaining` (the quota left after the whole fan-out) — plus cumulative elapsed, instead of the first request's values. (The OGC engine already aggregates this way via `_aggregate_paginated_response` / `_combine_chunk_responses`, so only wateruse needed the fix.) Reuse: the genuinely-shared, low-coupling primitives only (`_default_headers`, `_raise_for_status(detail_from=...)` keeping NWDC's `{detail}` errors, `_network_error`, `BaseMetadata`, `HTTPX_DEFAULTS`). Deliberately NOT the OGC `_paginate` — it hardcodes `_raise_for_non_200` (the `{code, description}` envelope, wrong for NWDC's `{detail}`) and is entangled with the engine's context vars; the CSV/Link pager is ~20 lines locally. The sync→async bridge is stdlib (`asyncio.run` in a worker thread), not anyio, which isn't a declared dependency. Verified live: single/paginated/fan-out results unchanged and order-stable, the final (lowest) rate-limit header surfaces, `{detail}` errors preserved, and calls succeed inside a running event loop. Offline tests cover the rate-limit aggregation; 31 pass. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_01Sjb14HkwuCydKSKMsaXsgd
1 parent 63a13ba commit 9345743

2 files changed

Lines changed: 137 additions & 55 deletions

File tree

dataretrieval/wateruse.py

Lines changed: 102 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939

4040
from __future__ import annotations
4141

42+
import asyncio
43+
import copy
4244
import io
4345
import logging
4446
from collections.abc import Iterable
@@ -54,7 +56,7 @@
5456
HTTPX_DEFAULTS,
5557
BaseMetadata,
5658
_default_headers,
57-
_get,
59+
_network_error,
5860
_raise_for_status,
5961
to_str,
6062
)
@@ -216,27 +218,17 @@ def get_wateruse(
216218
base_params = {k: v for k, v in base_params.items() if v is not None}
217219

218220
# The NWDC queries one location per request, so fan a multi-value selector
219-
# out into a request per location and concatenate the results.
221+
# out into a request per location (concurrently — see ``_fan_out``) and
222+
# concatenate the results.
220223
locations = _resolve_locations(state, county, huc)
221-
222-
def _fetch(location: str) -> tuple[pd.DataFrame, httpx.Response]:
223-
return _fetch_all_pages(
224-
{**base_params, "location": location}, ssl_check=ssl_check
225-
)
226-
227-
if len(locations) == 1:
228-
# Common case: no pool, and no extra concat copy of the whole result.
229-
frame, response = _fetch(locations[0])
230-
return frame, BaseMetadata(response)
231-
232-
# Fan out concurrently (bounded), preserving input order. The locations are
233-
# independent single requests, so a thread pool over the synchronous fetch
234-
# needs no shared state or backoff; ``pool.map`` re-raises the first failure.
235-
workers = min(len(locations), max(1, MAX_CONCURRENT_REQUESTS))
236-
with ThreadPoolExecutor(max_workers=workers) as pool:
237-
results = list(pool.map(_fetch, locations))
238-
df = pd.concat([frame for frame, _ in results], ignore_index=True)
239-
return df, BaseMetadata(results[0][1])
224+
# Drive the async fan-out from a worker thread so it is safe even when
225+
# called inside an already-running event loop (e.g. a Jupyter notebook),
226+
# where a bare ``asyncio.run`` would raise.
227+
with ThreadPoolExecutor(max_workers=1) as pool:
228+
df, response = pool.submit(
229+
lambda: asyncio.run(_fan_out(locations, base_params, ssl_check))
230+
).result()
231+
return df, BaseMetadata(response)
240232

241233

242234
# Valid HUC code lengths (digits) → the hydrologic-unit level they query.
@@ -316,57 +308,112 @@ def _validate_huc(value: object) -> str:
316308
return code
317309

318310

319-
def _fetch_all_pages(
320-
params: dict[str, Any], *, ssl_check: bool
311+
async def _fan_out(
    locations: list[str], base_params: dict[str, Any], ssl_check: bool
) -> tuple[pd.DataFrame, httpx.Response]:
    """Fetch every location concurrently over one shared async client.

    Each location is an independent paginated request. Concurrency is bounded
    by a semaphore sized from :data:`MAX_CONCURRENT_REQUESTS` (never below 1),
    and ``asyncio.gather`` returns results in input order so the concatenation
    is deterministic. A single :class:`httpx.AsyncClient` is shared by every
    page of every location.

    Parameters
    ----------
    locations : list of str
        One ``location`` selector per request chain.
    base_params : dict
        Query parameters common to every location.
    ssl_check : bool
        Passed to the client as ``verify``.

    Returns
    -------
    tuple of (pandas.DataFrame, httpx.Response)
        The per-location frames concatenated in input order (a single frame is
        returned as-is, without a concat copy) and one response aggregated from
        every page of every location.

    Raises
    ------
    Exception
        The first failure raised while fetching any location. The remaining
        in-flight fetches are cancelled and drained before it propagates.
    """
    headers = _default_headers()
    semaphore = asyncio.Semaphore(max(1, MAX_CONCURRENT_REQUESTS))

    async with httpx.AsyncClient(verify=ssl_check, **HTTPX_DEFAULTS) as client:

        async def _one(location: str) -> tuple[pd.DataFrame, list[httpx.Response]]:
            async with semaphore:
                return await _fetch_location(client, location, base_params, headers)

        tasks = [asyncio.create_task(_one(loc)) for loc in locations]
        try:
            results = await asyncio.gather(*tasks)
        except BaseException:
            # ``gather`` propagates the first failure but leaves the sibling
            # tasks running. Cancel and drain them here so none is still using
            # the client (or spending quota) while the client is being closed.
            for task in tasks:
                task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
            raise

    frames = [frame for frame, _ in results]
    responses = [resp for _, page_responses in results for resp in page_responses]
    df = frames[0] if len(frames) == 1 else pd.concat(frames, ignore_index=True)
    return df, _aggregate_responses(responses)
337+
338+
339+
async def _fetch_location(
    client: httpx.AsyncClient,
    location: str,
    base_params: dict[str, Any],
    headers: dict[str, str],
) -> tuple[pd.DataFrame, list[httpx.Response]]:
    """Fetch and concatenate every page for one location over ``client``.

    The NWDC paginates large areas with an RFC 8288 ``Link: <...>; rel="next"``
    header (the cursor is a ``skip`` offset). The first request carries the
    query params; each subsequent page is a fully-formed URL requested bare.

    The ``seen`` set guards against a non-advancing or cyclic cursor (a server
    bug that would otherwise loop forever, accumulating frames until OOM). It
    holds both the URL that was requested and the URL the response reports, so
    a ``next`` link pointing back at the page just fetched is caught before the
    page is downloaded a second time. When the guard trips, a warning is logged
    and the pages collected so far are returned.

    Parameters
    ----------
    client : httpx.AsyncClient
        The client used for every page request.
    location : str
        Value sent as the ``location`` query parameter on the first request.
    base_params : dict
        Remaining query parameters for the first request.
    headers : dict
        Request headers sent with every page.

    Returns
    -------
    tuple of (pandas.DataFrame, list of httpx.Response)
        The pages concatenated in fetch order (a single page is returned as-is)
        and the response for each page, in the same order.

    Raises
    ------
    Exception
        Whatever ``_network_error`` builds for an ``httpx.TransportError``,
        whatever ``_raise_for_status`` raises for an error status, and
        whatever ``_read_csv_page`` raises for an unreadable body.
    """
    frames: list[pd.DataFrame] = []
    responses: list[httpx.Response] = []
    seen: set[str] = set()
    url: str | None = WATERUSE_URL
    params: dict[str, Any] | None = {**base_params, "location": location}
    while url is not None:
        if url in seen:
            # Best-effort stop rather than a crash, but never a silent one:
            # the result may be incomplete.
            logger.warning(
                "Water-use pagination cursor repeated (%s); stopping after %d page(s).",
                url,
                len(frames),
            )
            break
        seen.add(url)
        try:
            response = await client.get(url, params=params, headers=headers)
        except httpx.TransportError as exc:
            raise _network_error(url, exc) from exc
        _raise_for_status(response, detail_from=_nwdc_error_detail)
        logger.debug("Requested water-use page: %s", response.url)
        # The first request is the bare endpoint plus ``params``, so its
        # effective URL differs from the string stored above. Record it too.
        seen.add(str(response.url))
        responses.append(response)
        frames.append(_read_csv_page(response))
        url, params = _next_page_url(response), None

    df = frames[0] if len(frames) == 1 else pd.concat(frames, ignore_index=True)
    return df, responses
346372

347373

348-
def _fetch_page(
349-
url: str,
350-
params: dict[str, Any] | None,
351-
headers: dict[str, str],
352-
ssl_check: bool,
353-
) -> tuple[pd.DataFrame, httpx.Response]:
354-
"""Fetch one water-use page and parse its CSV body into a DataFrame."""
355-
response = _get(
356-
url, params=params, headers=headers, verify=ssl_check, **HTTPX_DEFAULTS
357-
)
358-
_raise_for_status(response, detail_from=_nwdc_error_detail)
359-
logger.debug("Requested water-use page: %s", response.url)
374+
def _read_csv_page(response: httpx.Response) -> pd.DataFrame:
    """Turn the CSV body of ``response`` into a DataFrame.

    The ``huc12_id`` column is read as ``str`` so leading zeros survive.

    Raises
    ------
    DataRetrievalError
        If pandas reports the body as empty.
    """
    payload = io.BytesIO(response.content)
    column_types = {_HUC12_COLUMN: str}
    try:
        table = pd.read_csv(payload, dtype=column_types)
    except pd.errors.EmptyDataError as exc:
        # NWDC normally signals "no data" with a 400 (handled by the caller's
        # status check) or rows of zeros, never an empty body. If one ever
        # arrives, keep the typed-error contract instead of leaking a bare
        # pandas exception.
        message = f"NWDC returned an empty response body (URL: {response.url})."
        raise DataRetrievalError(message) from exc
    return table
369-
return frame, response
385+
386+
387+
def _aggregate_responses(responses: list[httpx.Response]) -> httpx.Response:
    """Collapse the per-page, per-location responses into one for metadata.

    With a single response, that response is returned untouched. Otherwise the
    result is a shallow copy of the first response (so it keeps the first
    request's identity) with two fields replaced: ``headers`` is taken from
    the response ``_most_depleted`` selects, i.e. the lowest
    ``x-ratelimit-remaining``, and ``elapsed`` is the sum over all responses.
    """
    anchor = responses[0]
    if len(responses) > 1:
        merged = copy.copy(anchor)
        merged.headers = httpx.Headers(_most_depleted(responses).headers)
        # Accumulate starting from the first response's own elapsed time.
        total_elapsed = anchor.elapsed
        for later in responses[1:]:
            total_elapsed = total_elapsed + later.elapsed
        merged.elapsed = total_elapsed
        return merged
    return anchor
402+
403+
404+
def _most_depleted(responses: list[httpx.Response]) -> httpx.Response:
    """Return the response with the smallest ``x-ratelimit-remaining``.

    Responses whose header is missing or not an integer are ignored. On a tie
    the earliest response wins. If no response carries a usable header, the
    last response is returned.
    """

    def _quota(resp: httpx.Response) -> int | None:
        # ``None`` marks a response that cannot take part in the comparison.
        try:
            return int(resp.headers["x-ratelimit-remaining"])
        except (KeyError, ValueError):
            return None

    ranked = []
    for candidate in responses:
        quota = _quota(candidate)
        if quota is not None:
            ranked.append((quota, candidate))

    if not ranked:
        return responses[-1]
    # ``min`` keeps the first of equal keys, matching a strict ``<`` scan.
    return min(ranked, key=lambda pair: pair[0])[1]
370417

371418

372419
def _next_page_url(response: httpx.Response) -> str | None:

tests/wateruse_test.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,41 @@ def test_fan_out_is_serial_when_concurrency_is_one(httpx_mock, monkeypatch):
265265
assert len(httpx_mock.get_requests()) == 2
266266

267267

268+
def test_fan_out_surfaces_final_rate_limit_header(httpx_mock):
    """After a two-state fan-out, ``md.header`` carries the smallest
    ``x-ratelimit-remaining`` seen rather than the first request's value."""
    mocked_pages = (
        (re.compile(r".*location=stateCd%3ARI.*"), _CSV_P1, "900"),
        (re.compile(r".*location=stateCd%3AWI.*"), _CSV_P2, "850"),
    )
    for url_pattern, body, remaining in mocked_pages:
        httpx_mock.add_response(
            method="GET",
            url=url_pattern,
            text=body,
            headers={"x-ratelimit-remaining": remaining},
        )

    _frame, metadata = get_wateruse(model="wu-public-supply-wd", state=["RI", "WI"])

    assert metadata.header["x-ratelimit-remaining"] == "850"
287+
288+
289+
def test_most_depleted_picks_lowest_remaining():
    """The response with the smallest remaining quota wins, wherever it sits."""
    high, low, middle = (
        httpx.Response(200, headers={"x-ratelimit-remaining": quota})
        for quota in ("900", "850", "875")
    )
    assert wateruse._most_depleted([high, low, middle]) is low
296+
297+
298+
def test_most_depleted_falls_back_to_last_when_header_absent():
    """With no rate-limit header on any response, the last one is returned."""
    earlier = httpx.Response(200)
    latest = httpx.Response(200)
    assert wateruse._most_depleted([earlier, latest]) is latest
301+
302+
268303
# --- _resolve_locations unit tests (no HTTP) -------------------------------
269304

270305

0 commit comments

Comments
 (0)