Skip to content

Commit 4e44906

Browse files
thodson-usgsclaude
andcommitted
refactor(waterdata): /simplify pass — shared Session, _paginate helper, dedup state
Three improvements driven by a Python-idioms code review: - ChunkedCall.resume opens one requests.Session for the whole chunked call and publishes it on a `_chunked_session` ContextVar; _walk_pages picks it up when called with `client=None` so every sub-request of a fan-out rides the same connection pool. ContextVar (not a kwarg on _FetchOnce) keeps existing test fixtures' `def fetch(args)` signature intact. Two tests added: shared-Session across sub-requests, and isolation between resume() calls. - _walk_pages and get_stats_data now defer to a shared `_paginate` helper that owns the session-or-borrow management, the initial-then-cursor loop, the wrap-as-RuntimeError contract, and the finalize-headers/elapsed step. Each call site supplies only the page-parsing and follow-up-request strategies. ~80 lines of duplicated loop logic collapse into two ~15-line wrappers. - ChunkedCall switched from parallel `_frames` + `_responses` lists to a single `_chunks: list[tuple[DataFrame, Response]]`. Eliminates the invariant that the two lists move in lockstep — len(_chunks) is now the only cursor. Other polish from the same review: - copy.copy in _combine_chunk_responses so partial_response is idempotent across repeated property access mid-resume - _Axis gets `slots=True` (Py3.10+) - _plan_axes halving loop uses slice assignment instead of tuple-concat - Drop redundant `.upper()` on PreparedRequest.method (already normalized) - Drop dead `"datetime" # unused` config entry - Tighten `Callable[..., Any]` → `Callable[..., requests.PreparedRequest]` for build_request - `cur.retry_after` instead of `getattr(cur, "retry_after", None)` — the _RetryableTransportError base guarantees the attribute - Drift test asserting _DATE_RANGE_PARAMS ⊆ _NEVER_CHUNK so a future date-range param added to utils.py can't silently bypass the chunker's never-chunk list - Trim WHAT-narrating comments in ChunkPlan.__init__ and iter_sub_args - Fix stale `get_ogc_data` example ("wfs"/"wms" → "daily"/…) and add missing geopd param doc to _handle_stats_nesting 80 chunker + utils unit tests pass; ruff clean. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
1 parent afe073b commit 4e44906

3 files changed

Lines changed: 314 additions & 164 deletions

File tree

dataretrieval/waterdata/chunking.py

Lines changed: 87 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,13 @@
3333

3434
from __future__ import annotations
3535

36+
import copy
3637
import functools
3738
import itertools
3839
import math
3940
import os
4041
from collections.abc import Callable, Iterator
42+
from contextvars import ContextVar
4143
from dataclasses import dataclass
4244
from typing import Any, ClassVar
4345
from urllib.parse import quote_plus
@@ -93,6 +95,19 @@
9395
# Response header USGS uses to advertise remaining hourly quota.
9496
_QUOTA_HEADER = "x-ratelimit-remaining"
9597

98+
# Session shared across all sub-requests of a single chunked call.
99+
# Set by ``ChunkedCall.resume`` so paginated-loop helpers downstream
100+
# (``_walk_pages``) reuse the same connection pool across the entire
101+
# fan-out instead of opening a fresh ``requests.Session`` per
102+
# sub-request. ``None`` when not inside a chunked call — paginated
103+
# helpers fall back to their own short-lived session in that case.
104+
# Plumbed via ``ContextVar`` rather than a kwarg on ``_FetchOnce`` so
105+
# user-defined fetch functions (and test fixtures) keep the simple
106+
# ``fetch(args)`` signature.
107+
_chunked_session: ContextVar[requests.Session | None] = ContextVar(
108+
"_chunked_session", default=None
109+
)
110+
96111
# Separators the two axis kinds use to join their atoms back into
97112
# URL text. List axes comma-join values
98113
# (``site=USGS-A,USGS-B``); the filter axis OR-joins clauses
@@ -336,7 +351,7 @@ def _request_bytes(req: requests.PreparedRequest) -> int:
336351
return len(req.url) + body_len
337352

338353

339-
@dataclass(frozen=True)
354+
@dataclass(frozen=True, slots=True)
340355
class _Axis:
341356
"""
342357
A single chunkable axis of one user-level request — a list of
@@ -481,7 +496,7 @@ def _worst_case_args(
481496
def _plan_axes(
482497
axes: list[_Axis],
483498
args: dict[str, Any],
484-
build_request: Callable[..., Any],
499+
build_request: Callable[..., requests.PreparedRequest],
485500
url_limit: int,
486501
) -> dict[str, list[list[str]]]:
487502
"""
@@ -498,7 +513,7 @@ def _plan_axes(
498513
The chunkable axes to partition.
499514
args : dict[str, Any]
500515
The user-level args (used to build the worst-case request).
501-
build_request : Callable[..., Any]
516+
build_request : Callable[..., requests.PreparedRequest]
502517
Factory that turns a kwargs dict into a sized prepared
503518
request, e.g. ``_construct_api_requests``.
504519
url_limit : int
@@ -545,11 +560,7 @@ def _plan_axes(
545560
axis_chunks = chunks[biggest_axis.arg_key]
546561
chunk = axis_chunks[biggest_idx]
547562
mid = len(chunk) // 2
548-
chunks[biggest_axis.arg_key] = (
549-
axis_chunks[:biggest_idx]
550-
+ [chunk[:mid], chunk[mid:]]
551-
+ axis_chunks[biggest_idx + 1 :]
552-
)
563+
axis_chunks[biggest_idx : biggest_idx + 1] = [chunk[:mid], chunk[mid:]]
553564

554565

555566
class ChunkPlan:
@@ -572,7 +583,7 @@ class ChunkPlan:
572583
----------
573584
args : dict[str, Any]
574585
The user-level request kwargs.
575-
build_request : Callable[..., Any]
586+
build_request : Callable[..., requests.PreparedRequest]
576587
Factory that turns a kwargs dict into a sized prepared
577588
request, e.g. ``_construct_api_requests``.
578589
url_limit : int
@@ -608,28 +619,23 @@ class ChunkPlan:
608619
def __init__(
609620
self,
610621
args: dict[str, Any],
611-
build_request: Callable[..., Any],
622+
build_request: Callable[..., requests.PreparedRequest],
612623
url_limit: int,
613624
) -> None:
614625
self.args = args
615-
# Passthrough defaults; promoted at the bottom only if chunking
616-
# is actually needed.
617626
self.axes: list[_Axis] = []
618627
self.chunks: dict[str, list[list[str]]] = {}
619628
self.canonical_url: str | None = None
620629

621630
axes = _extract_axes(args)
622-
623-
# Trivial passthrough: no chunkable axes. Skip the
624-
# ``build_request`` call entirely — the common Water Data call
625-
# shape doesn't pay for an unused request prep.
631+
# No chunkable axes → skip ``build_request`` entirely; the
632+
# common Water Data call shape shouldn't pay for an unused
633+
# request prep on the passthrough hot path.
626634
if not axes:
627635
return
628636

629637
initial_request = build_request(**args)
630638
self.canonical_url = initial_request.url
631-
632-
# Already-fits passthrough: chunking is possible but unnecessary.
633639
if _request_bytes(initial_request) <= url_limit:
634640
return
635641

@@ -663,8 +669,6 @@ def iter_sub_args(self) -> Iterator[dict[str, Any]]:
663669
A copy of ``self.args`` with each axis's current chunk
664670
substituted under its ``arg_key``.
665671
"""
666-
# Trivial-passthrough fast path: skip the cartesian product
667-
# machinery and yield ``self.args`` directly.
668672
if not self.axes:
669673
yield self.args
670674
return
@@ -791,9 +795,9 @@ def _classify_chunk_error(
791795
cur: BaseException | None = exc
792796
while cur is not None:
793797
if isinstance(cur, RateLimited):
794-
return QuotaExhausted, getattr(cur, "retry_after", None)
798+
return QuotaExhausted, cur.retry_after
795799
if isinstance(cur, ServiceUnavailable):
796-
return ServiceInterrupted, getattr(cur, "retry_after", None)
800+
return ServiceInterrupted, cur.retry_after
797801
cur = cur.__cause__
798802
return None
799803

@@ -839,11 +843,12 @@ def _combine_chunk_responses(
839843
"""
840844
Fold per-sub-request responses into a single aggregated response.
841845
842-
The first response is mutated in place: ``.headers`` becomes the
843-
last response's (so ``x-ratelimit-remaining`` reflects current
844-
state), ``.elapsed`` accumulates total wall-clock, and ``.url`` is
845-
set to the canonical original-query URL so ``BaseMetadata``
846-
reflects the user's full request rather than the first chunk.
846+
Returns a shallow copy of ``responses[0]`` with ``.headers`` set to
847+
the last response's (so ``x-ratelimit-remaining`` reflects current
848+
state), ``.elapsed`` set to total wall-clock across every response,
849+
and ``.url`` set to the canonical original-query URL so
850+
``BaseMetadata`` reflects the user's full request rather than the
851+
first chunk.
847852
848853
Parameters
849854
----------
@@ -858,12 +863,20 @@ def _combine_chunk_responses(
858863
Returns
859864
-------
860865
requests.Response
861-
The first response, mutated as described above.
866+
A copy of the first response with aggregated state. The input
867+
responses are not mutated, so this function is idempotent —
868+
safe to call repeatedly via :attr:`ChunkedCall.partial_response`
869+
during error inspection or resume retries.
862870
"""
863-
head = responses[0]
871+
# copy.copy preserves the requests.Response shape but breaks the
872+
# alias to responses[0] so repeated calls accumulate fresh totals
873+
# rather than re-adding tail elapsed onto a previously-mutated head.
874+
head = copy.copy(responses[0])
864875
if len(responses) > 1:
865876
head.headers = responses[-1].headers
866-
head.elapsed = sum((r.elapsed for r in responses[1:]), start=head.elapsed)
877+
head.elapsed = sum(
878+
(r.elapsed for r in responses[1:]), start=responses[0].elapsed
879+
)
867880
if canonical_url is not None:
868881
head.url = canonical_url
869882
return head
@@ -919,12 +932,15 @@ class ChunkedCall:
919932
def __init__(self, plan: ChunkPlan, fetch_once: _FetchOnce) -> None:
920933
self.plan = plan
921934
self.fetch_once = fetch_once
922-
self._frames: list[pd.DataFrame] = []
923-
self._responses: list[requests.Response] = []
935+
# One entry per completed sub-request, in execution order.
936+
# A single list keeps the (frame, response) pair atomic so the
937+
# ``len(_chunks)`` cursor can't ever drift between two parallel
938+
# lists.
939+
self._chunks: list[tuple[pd.DataFrame, requests.Response]] = []
924940

925941
@property
926942
def completed_chunks(self) -> int:
927-
return len(self._responses)
943+
return len(self._chunks)
928944

929945
@property
930946
def total_chunks(self) -> int:
@@ -945,9 +961,9 @@ def partial_frame(self) -> pd.DataFrame:
945961
Combined frame of completed sub-requests, or an empty
946962
``DataFrame`` when nothing has completed.
947963
"""
948-
if not self._frames:
964+
if not self._chunks:
949965
return pd.DataFrame()
950-
return _combine_chunk_frames(self._frames)
966+
return _combine_chunk_frames([frame for frame, _ in self._chunks])
951967

952968
@property
953969
def partial_response(self) -> requests.Response | None:
@@ -963,14 +979,24 @@ def partial_response(self) -> requests.Response | None:
963979
Aggregated response when at least one sub-request has
964980
completed, ``None`` otherwise.
965981
"""
966-
if not self._responses:
982+
if not self._chunks:
967983
return None
968-
return _combine_chunk_responses(self._responses, self.plan.canonical_url)
984+
return _combine_chunk_responses(
985+
[resp for _, resp in self._chunks], self.plan.canonical_url
986+
)
969987

970988
def resume(self) -> tuple[pd.DataFrame, requests.Response]:
971989
"""
972990
Drive the chunked call to completion.
973991
992+
Opens one ``requests.Session`` for the run and publishes it on
993+
the ``_chunked_session`` ``ContextVar`` so paginated-loop
994+
helpers downstream (``_walk_pages``) reuse the same connection
995+
pool across every sub-request instead of handshaking fresh on
996+
each. The session is closed when ``resume`` returns or raises;
997+
a follow-up ``resume`` call (after a ``ChunkInterrupted``)
998+
opens a new one.
999+
9741000
Idempotent: starts from chunk 0 on the first call, then from
9751001
the cursor (``self.completed_chunks``) on every subsequent
9761002
call. Re-issues only sub-requests that haven't already
@@ -996,39 +1022,46 @@ def resume(self) -> tuple[pd.DataFrame, requests.Response]:
9961022
When the rate-limit window can't cover the remaining plan
9971023
(checked after the first sub-request).
9981024
"""
999-
completed = len(self._responses)
1000-
for i, sub_args in enumerate(self.plan.iter_sub_args()):
1001-
if i < completed:
1002-
continue
1003-
self._issue(sub_args)
1004-
return (
1005-
_combine_chunk_frames(self._frames),
1006-
_combine_chunk_responses(self._responses, self.plan.canonical_url),
1007-
)
1025+
with requests.Session() as session:
1026+
token = _chunked_session.set(session)
1027+
try:
1028+
completed = len(self._chunks)
1029+
for i, sub_args in enumerate(self.plan.iter_sub_args()):
1030+
if i < completed:
1031+
continue
1032+
self._issue(sub_args)
1033+
frames = [frame for frame, _ in self._chunks]
1034+
responses = [resp for _, resp in self._chunks]
1035+
return (
1036+
_combine_chunk_frames(frames),
1037+
_combine_chunk_responses(responses, self.plan.canonical_url),
1038+
)
1039+
finally:
1040+
_chunked_session.reset(token)
10081041

10091042
def _issue(self, sub_args: dict[str, Any]) -> None:
10101043
try:
1011-
frame, response = self.fetch_once(sub_args)
1044+
chunk = self.fetch_once(sub_args)
10121045
except RuntimeError as exc:
10131046
classification = _classify_chunk_error(exc)
10141047
if classification is None:
10151048
raise
10161049
interrupted_class, retry_after = classification
10171050
raise interrupted_class(
1018-
completed_chunks=len(self._responses),
1051+
completed_chunks=len(self._chunks),
10191052
total_chunks=self.plan.total,
10201053
call=self,
10211054
retry_after=retry_after,
10221055
) from exc
1023-
self._frames.append(frame)
1024-
self._responses.append(response)
1025-
if len(self._responses) == 1 and self.plan.total > 1:
1056+
self._chunks.append(chunk)
1057+
if len(self._chunks) == 1 and self.plan.total > 1:
10261058
self._check_quota_after_first()
10271059

10281060
def _check_quota_after_first(self) -> None:
10291061
if _quota_check_disabled():
10301062
return
1031-
remaining = _read_remaining(self._responses[0])
1063+
_, first_response = self._chunks[0]
1064+
remaining = _read_remaining(first_response)
10321065
if remaining is None or remaining >= self.plan.total - 1:
10331066
return
10341067
raise RequestExceedsQuota(
@@ -1040,7 +1073,7 @@ def _check_quota_after_first(self) -> None:
10401073

10411074
def multi_value_chunked(
10421075
*,
1043-
build_request: Callable[..., Any],
1076+
build_request: Callable[..., requests.PreparedRequest],
10441077
url_limit: int | None = None,
10451078
) -> Callable[[_FetchOnce], _FetchOnce]:
10461079
"""
@@ -1054,7 +1087,7 @@ def multi_value_chunked(
10541087
10551088
Parameters
10561089
----------
1057-
build_request : Callable
1090+
build_request : Callable[..., requests.PreparedRequest]
10581091
Factory that turns a kwargs dict into a sized prepared
10591092
request, e.g. ``_construct_api_requests``. Called during
10601093
planning to measure each candidate plan.

0 commit comments

Comments
 (0)