Skip to content

Commit e740a64

Browse files
thodson-usgsclaude
andcommitted
fix(waterdata): finalize resumed chunked results; add reference-table row cap
Post-review fixes layered on the async parallel chunker: - Funnel OGC post-processing through the chunker via an injected `finalize` hook so `ChunkInterrupted.call.resume()` returns the same type-coerced `(df, BaseMetadata)` as an un-interrupted call instead of the raw `(frame, httpx.Response)`. `partial_frame`/`partial_response` stay raw, so building the exception never triggers finalize's side effects (a schema network GET on an empty frame would otherwise fire inside the ctor). - Add `max_rows` to `get_reference_table`/`get_ogc_data` to preview large reference tables without downloading every page; enforced as the exact total in `_finalize_ogc` (after dedup) and validated as a positive integer (accepts numpy ints via `numbers.Integral`). - Co-locate the parallel fan-out into `ChunkedCall.resume_async`, sharing a `_pending()` generator with the serial `resume()` so the two execution paths can't drift. - Harden `ProgressReporter.note_retry` for Python 3.9-3.11 (int `wait` and `int.is_integer()`). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 1d0acb7 commit e740a64

8 files changed

Lines changed: 689 additions & 294 deletions

File tree

dataretrieval/waterdata/_progress.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,9 @@ def note_retry(self, *, attempt: int, wait: float) -> None:
162162
"""
163163
# Keep sub-second waits explicit (avoid misleading ``0s``) while
164164
# rendering whole-second waits without unnecessary ``.0`` noise.
165-
wait_1dp = round(wait, 1)
165+
# ``float()`` to support Python 3.9-3.11: ``round(int, 1)`` returns an
166+
# int and ``int.is_integer()`` (used below) only exists on 3.12+.
167+
wait_1dp = round(float(wait), 1)
166168
if wait_1dp < 1 or not wait_1dp.is_integer():
167169
secs = f"{wait_1dp:.1f}s"
168170
else:

dataretrieval/waterdata/api.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2022,6 +2022,7 @@ def get_reference_table(
20222022
collection: str,
20232023
limit: int | None = None,
20242024
query: dict | None = None,
2025+
max_rows: int | None = None,
20252026
) -> tuple[pd.DataFrame, BaseMetadata]:
20262027
"""Get metadata reference tables for the USGS Water Data API.
20272028
@@ -2046,6 +2047,12 @@ def get_reference_table(
20462047
query: dictionary, optional
20472048
The optional args parameter can be used to pass a dictionary of
20482049
query parameters to the collection API call.
2050+
max_rows : int, optional
2051+
Cap the total number of rows returned, stopping pagination early
2052+
instead of downloading the whole table. Useful for cheaply
2053+
previewing large tables (e.g. ``hydrologic-unit-codes`` has ~125k
2054+
rows). Unlike ``limit`` (the per-page size), this bounds the total
2055+
result. The default (None) downloads every page.
20492056
20502057
Returns
20512058
-------
@@ -2092,7 +2099,9 @@ def get_reference_table(
20922099
query_args = dict(query) if query else {}
20932100
if limit is not None:
20942101
query_args["limit"] = limit
2095-
return get_ogc_data(args=query_args, output_id=output_id, service=collection)
2102+
return get_ogc_data(
2103+
args=query_args, output_id=output_id, service=collection, max_rows=max_rows
2104+
)
20962105

20972106

20982107
def get_codes(code_service: CODE_SERVICES) -> pd.DataFrame:

dataretrieval/waterdata/chunking.py

Lines changed: 289 additions & 229 deletions
Large diffs are not rendered by default.

dataretrieval/waterdata/utils.py

Lines changed: 123 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import functools
55
import json
66
import logging
7+
import numbers
78
import os
89
import re
910
from collections.abc import (
@@ -15,6 +16,7 @@
1516
Mapping,
1617
)
1718
from contextlib import asynccontextmanager, contextmanager
19+
from contextvars import ContextVar
1820
from datetime import datetime, timedelta
1921
from typing import Any, TypeVar, get_args
2022
from zoneinfo import ZoneInfo
@@ -905,6 +907,26 @@ def _aggregate_paginated_response(
905907

906908
_Cursor = TypeVar("_Cursor")
907909

910+
# Optional cap on the total rows a single paginated call accumulates before it
911+
# stops following ``next`` links. ``None`` (the default the data getters use)
912+
# means "no cap — fetch the whole series". Set via :func:`_row_cap` so the deep
913+
# ``_paginate`` loop can honor it without threading the value through the
914+
# generic chunker; this mirrors the ``_progress`` ambient-reporter pattern.
915+
_row_cap_var: ContextVar[int | None] = ContextVar("waterdata_row_cap", default=None)
916+
917+
918+
@contextmanager
919+
def _row_cap(max_rows: int | None) -> Iterator[None]:
920+
"""Cap the rows any :func:`_paginate` / :func:`_paginate_async` under this
921+
context will accumulate (``None`` = uncapped). Used by
922+
:func:`get_reference_table` to preview large tables without downloading
923+
every page."""
924+
token = _row_cap_var.set(max_rows)
925+
try:
926+
yield
927+
finally:
928+
_row_cap_var.reset(token)
929+
908930

909931
def _paginate(
910932
initial_req: httpx.Request,
@@ -988,18 +1010,24 @@ def _paginate(
9881010
logger.warning("Initial response parse failed.")
9891011
raise RuntimeError(_paginated_failure_message(0, e)) from e
9901012
dfs = [df]
1013+
# Stop following ``next`` links once the optional row cap is reached
1014+
# (see :func:`_row_cap`); ``None`` means uncapped. The concatenation is
1015+
# sliced to the cap below so a final over-budget page can't exceed it.
1016+
cap = _row_cap_var.get()
1017+
nrows = len(df)
9911018
if reporter is not None:
9921019
reporter.set_rate_remaining(
9931020
resp.headers.get(_QUOTA_HEADER),
9941021
limit=resp.headers.get("x-ratelimit-limit"),
9951022
)
9961023
reporter.add_page(rows=len(df))
997-
while cursor is not None:
1024+
while cursor is not None and (cap is None or nrows < cap):
9981025
try:
9991026
resp = follow_up(cursor, client)
10001027
_raise_for_non_200(resp)
10011028
df, cursor = parse_response(resp)
10021029
dfs.append(df)
1030+
nrows += len(df)
10031031
total_elapsed += _safe_elapsed(resp)
10041032
if reporter is not None:
10051033
reporter.set_rate_remaining(
@@ -1021,7 +1049,10 @@ def _paginate(
10211049
final_response = _aggregate_paginated_response(
10221050
initial_response, resp, total_elapsed
10231051
)
1024-
return pd.concat(dfs, ignore_index=True), final_response
1052+
result = pd.concat(dfs, ignore_index=True)
1053+
if cap is not None:
1054+
result = result.head(cap)
1055+
return result, final_response
10251056

10261057

10271058
async def _paginate_async(
@@ -1037,7 +1068,7 @@ async def _paginate_async(
10371068
10381069
Runs the same per-page loop but issues HTTP asynchronously so
10391070
multiple sub-requests of a chunked call can run concurrently from
1040-
:func:`_fan_out_async`.
1071+
:meth:`~dataretrieval.waterdata.chunking.ChunkedCall.resume_async`.
10411072
"""
10421073
logger.debug("Requesting: %s", initial_req.url)
10431074
reporter = _progress.current()
@@ -1058,18 +1089,24 @@ async def _paginate_async(
10581089
logger.warning("Initial response parse failed.")
10591090
raise RuntimeError(_paginated_failure_message(0, e)) from e
10601091
dfs = [df]
1092+
# Stop following ``next`` links once the optional row cap is reached
1093+
# (see :func:`_row_cap`); ``None`` means uncapped. Mirrors the sync
1094+
# :func:`_paginate`; the concatenation is sliced to the cap below.
1095+
cap = _row_cap_var.get()
1096+
nrows = len(df)
10611097
if reporter is not None:
10621098
reporter.set_rate_remaining(
10631099
resp.headers.get(_QUOTA_HEADER),
10641100
limit=resp.headers.get("x-ratelimit-limit"),
10651101
)
10661102
reporter.add_page(rows=len(df))
1067-
while cursor is not None:
1103+
while cursor is not None and (cap is None or nrows < cap):
10681104
try:
10691105
resp = await follow_up(cursor, sess)
10701106
_raise_for_non_200(resp)
10711107
df, cursor = parse_response(resp)
10721108
dfs.append(df)
1109+
nrows += len(df)
10731110
total_elapsed += _safe_elapsed(resp)
10741111
if reporter is not None:
10751112
reporter.set_rate_remaining(
@@ -1091,7 +1128,10 @@ async def _paginate_async(
10911128
final_response = _aggregate_paginated_response(
10921129
initial_response, resp, total_elapsed
10931130
)
1094-
return pd.concat(dfs, ignore_index=True), final_response
1131+
result = pd.concat(dfs, ignore_index=True)
1132+
if cap is not None:
1133+
result = result.head(cap)
1134+
return result, final_response
10951135

10961136

10971137
def _ogc_parse_response(
@@ -1356,8 +1396,50 @@ def _sort_rows(df: pd.DataFrame) -> pd.DataFrame:
13561396
return df
13571397

13581398

1399+
def _finalize_ogc(
1400+
frame: pd.DataFrame,
1401+
response: httpx.Response,
1402+
*,
1403+
properties: list[str] | None,
1404+
output_id: str,
1405+
convert_type: bool,
1406+
service: str,
1407+
max_rows: int | None = None,
1408+
) -> tuple[pd.DataFrame, BaseMetadata]:
1409+
"""Shape a combined OGC result into the user-facing ``(df, md)``.
1410+
1411+
The single home for the OGC getters' result shaping: empties
1412+
normalized, types coerced (when ``convert_type``), the wire ``id``
1413+
renamed and columns ordered, rows sorted, optionally truncated to
1414+
``max_rows``, and the response wrapped as
1415+
:class:`~dataretrieval.utils.BaseMetadata`.
1416+
1417+
Injected into the chunker as its ``finalize`` hook (see
1418+
:data:`~dataretrieval.waterdata.chunking._Finalize`) so the
1419+
un-interrupted return *and* a resumed ``ChunkInterrupted.call.resume()``
1420+
produce the same shape — closing the gap where resume used to hand back
1421+
the chunker's raw frame and bare ``httpx.Response``.
1422+
1423+
``max_rows`` is applied here (after dedup/sort, on the *combined* frame)
1424+
rather than only per-sub-request, so a chunked call's total is bounded
1425+
to exactly ``max_rows`` and a resumed call honors the cap too — the
1426+
per-``_paginate`` ``_row_cap`` is only an early-stop download bound.
1427+
"""
1428+
frame = _deal_with_empty(frame, properties, service)
1429+
if convert_type:
1430+
frame = _type_cols(frame)
1431+
frame = _arrange_cols(frame, properties, output_id)
1432+
frame = _sort_rows(frame)
1433+
if max_rows is not None:
1434+
frame = frame.head(max_rows)
1435+
return frame, BaseMetadata(response)
1436+
1437+
13591438
def get_ogc_data(
1360-
args: dict[str, Any], output_id: str, service: str
1439+
args: dict[str, Any],
1440+
output_id: str,
1441+
service: str,
1442+
max_rows: int | None = None,
13611443
) -> tuple[pd.DataFrame, BaseMetadata]:
13621444
"""
13631445
Retrieves OGC (Open Geospatial Consortium) data from a specified
@@ -1376,6 +1458,11 @@ def get_ogc_data(
13761458
service : str
13771459
The OGC API collection name (e.g., ``"daily"``,
13781460
``"monitoring-locations"``, ``"continuous"``).
1461+
max_rows : int, optional
1462+
Stop paginating once this many rows have been collected and
1463+
truncate the result to exactly ``max_rows``. ``None`` (default)
1464+
fetches the full result. Intended for cheap previews of large,
1465+
un-chunked tables (e.g. :func:`get_reference_table`).
13791466
13801467
Returns
13811468
-------
@@ -1390,6 +1477,19 @@ def get_ogc_data(
13901477
- Handles optional arguments such as `convert_type`.
13911478
- Applies column cleanup and reordering based on service and properties.
13921479
"""
1480+
# Enforce a genuine positive integer: a float (even ``10.0``) or ``bool``
1481+
# would pass a bare ``< 1`` check and then crash deep in
1482+
# ``pd.DataFrame.head`` with an opaque ``TypeError`` after HTTP I/O has
1483+
# already fired. ``numbers.Integral`` (not ``int``) so numpy integers —
1484+
# e.g. ``max_rows`` derived from a numpy/pandas computation — are accepted;
1485+
# ``bool`` is an ``Integral`` subtype, so exclude it explicitly.
1486+
if max_rows is not None and (
1487+
not isinstance(max_rows, numbers.Integral)
1488+
or isinstance(max_rows, bool)
1489+
or max_rows < 1
1490+
):
1491+
raise ValueError(f"max_rows must be a positive integer (got {max_rows!r}).")
1492+
13931493
args = args.copy()
13941494
args["service"] = service
13951495
args = _switch_arg_id(args, id_name=output_id, service=service)
@@ -1402,15 +1502,23 @@ def get_ogc_data(
14021502
convert_type = args.pop("convert_type", False)
14031503
args = {k: v for k, v in args.items() if v is not None}
14041504

1405-
with _progress.progress_context(service=service):
1406-
return_list, response = _fetch_once(args)
1407-
return_list = _deal_with_empty(return_list, properties, service)
1408-
if convert_type:
1409-
return_list = _type_cols(return_list)
1410-
return_list = _arrange_cols(return_list, properties, output_id)
1411-
return_list = _sort_rows(return_list)
1412-
1413-
return return_list, BaseMetadata(response)
1505+
# Post-processing is injected into the chunker rather than applied here,
1506+
# so it runs on *every* exit: the normal return AND a later
1507+
# ``exc.call.resume()`` after a ChunkInterrupted (which never re-enters
1508+
# this function). ``_finalize_ogc`` is the single source of result shape;
1509+
# it also applies ``max_rows`` to the *combined* frame so the cap is the
1510+
# exact total even when the plan chunks or the call is resumed, while
1511+
# ``_row_cap`` below only early-stops each sub-request's pagination.
1512+
finalize = functools.partial(
1513+
_finalize_ogc,
1514+
properties=properties,
1515+
output_id=output_id,
1516+
convert_type=convert_type,
1517+
service=service,
1518+
max_rows=max_rows,
1519+
)
1520+
with _progress.progress_context(service=service), _row_cap(max_rows):
1521+
return _fetch_once(args, finalize=finalize)
14141522

14151523

14161524
async def _fetch_once_async(

0 commit comments

Comments
 (0)