Skip to content

Commit 47c82ec

Browse files
thodson-usgs and claude committed
refactor(waterdata): /simplify pass 2 — encapsulate session sharing, drop duplicate json parse
Follow-up cleanup on the chunker module driven by another code review:

- `_publish_session` contextmanager wraps the `_chunked_session` ContextVar set/reset token dance. `ChunkedCall.resume` becomes a single `with requests.Session() as s, _publish_session(s):` line.
- Push the ContextVar lookup INTO `_session()` itself (resolution order: caller-provided -> chunker's shared -> fresh temp). `_paginate` no longer reaches across modules into `chunking._chunked_session`.
- Add `body` kwarg to `_get_resp_data` and `_next_req_url` so `_walk_pages` can `resp.json()` once and reuse the body across both helpers — eliminates a per-page redundant JSON parse on the OGC pagination path (~halves JSON-decode CPU per page).
- TypeVar `_Cursor` on `_paginate` so the two callbacks (`parse_response`, `follow_up`) are linked through the type system rather than `Any → Any`. Type checkers can now catch a cursor-type mismatch at a single call site.
- `get_stats_data`'s `follow_up` no longer mutates the caller's `args` dict — uses `params={**args, "next_token": cursor}` instead of `args["next_token"] = cursor`.
- Switch `_walk_pages`'s lambdas to named inner functions to match `get_stats_data`'s style.
- Hoist the `RateLimited` import in `test_chunked_session_isolated_per_resume` to module level (was inside the fetch closure).
- Drop one redundant multi-line "PreparedRequest.method is already upper-cased" comment in `_walk_pages` (the inline form in `get_stats_data` is enough context for the codebase).

80 chunker + utils unit tests pass; ruff clean.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
1 parent 4e44906 commit 47c82ec

3 files changed

Lines changed: 91 additions & 54 deletions

File tree

dataretrieval/waterdata/chunking.py

Lines changed: 29 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,7 @@
3939
import math
4040
import os
4141
from collections.abc import Callable, Iterator
42+
from contextlib import contextmanager
4243
from contextvars import ContextVar
4344
from dataclasses import dataclass
4445
from typing import Any, ClassVar
@@ -96,7 +97,7 @@
9697
_QUOTA_HEADER = "x-ratelimit-remaining"
9798

9899
# Session shared across all sub-requests of a single chunked call.
99-
# Set by ``ChunkedCall.resume`` so paginated-loop helpers downstream
100+
# Published by ``_publish_session`` so paginated-loop helpers downstream
100101
# (``_walk_pages``) reuse the same connection pool across the entire
101102
# fan-out instead of opening a fresh ``requests.Session`` per
102103
# sub-request. ``None`` when not inside a chunked call — paginated
@@ -108,6 +109,21 @@
108109
"_chunked_session", default=None
109110
)
110111

112+
113+
@contextmanager
114+
def _publish_session(session: requests.Session) -> Iterator[None]:
115+
"""
116+
Make ``session`` visible to :func:`dataretrieval.waterdata.utils._session`
117+
for the duration of the ``with`` block via the ``_chunked_session``
118+
ContextVar. Wraps the set/reset token dance so callers don't have to.
119+
"""
120+
token = _chunked_session.set(session)
121+
try:
122+
yield
123+
finally:
124+
_chunked_session.reset(token)
125+
126+
111127
# Separators the two axis kinds use to join their atoms back into
112128
# URL text. List axes comma-join values
113129
# (``site=USGS-A,USGS-B``); the filter axis OR-joins clauses
@@ -1022,22 +1038,18 @@ def resume(self) -> tuple[pd.DataFrame, requests.Response]:
10221038
When the rate-limit window can't cover the remaining plan
10231039
(checked after the first sub-request).
10241040
"""
1025-
with requests.Session() as session:
1026-
token = _chunked_session.set(session)
1027-
try:
1028-
completed = len(self._chunks)
1029-
for i, sub_args in enumerate(self.plan.iter_sub_args()):
1030-
if i < completed:
1031-
continue
1032-
self._issue(sub_args)
1033-
frames = [frame for frame, _ in self._chunks]
1034-
responses = [resp for _, resp in self._chunks]
1035-
return (
1036-
_combine_chunk_frames(frames),
1037-
_combine_chunk_responses(responses, self.plan.canonical_url),
1038-
)
1039-
finally:
1040-
_chunked_session.reset(token)
1041+
with requests.Session() as session, _publish_session(session):
1042+
completed = len(self._chunks)
1043+
for i, sub_args in enumerate(self.plan.iter_sub_args()):
1044+
if i < completed:
1045+
continue
1046+
self._issue(sub_args)
1047+
frames = [frame for frame, _ in self._chunks]
1048+
responses = [resp for _, resp in self._chunks]
1049+
return (
1050+
_combine_chunk_frames(frames),
1051+
_combine_chunk_responses(responses, self.plan.canonical_url),
1052+
)
10411053

10421054
def _issue(self, sub_args: dict[str, Any]) -> None:
10431055
try:

dataretrieval/waterdata/utils.py

Lines changed: 62 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from collections.abc import Callable, Iterable, Iterator, Mapping
88
from contextlib import contextmanager
99
from datetime import datetime, timedelta
10-
from typing import Any, get_args
10+
from typing import Any, TypeVar, get_args
1111
from zoneinfo import ZoneInfo
1212

1313
import pandas as pd
@@ -678,7 +678,9 @@ def _construct_api_requests(
678678
return request.prepare()
679679

680680

681-
def _next_req_url(resp: requests.Response) -> str | None:
681+
def _next_req_url(
682+
resp: requests.Response, *, body: dict[str, Any] | None = None
683+
) -> str | None:
682684
"""
683685
Extracts the URL for the next page of results from an HTTP response from a
684686
water data endpoint.
@@ -687,6 +689,10 @@ def _next_req_url(resp: requests.Response) -> str | None:
687689
----------
688690
resp : requests.Response
689691
The HTTP response object containing JSON data and headers.
692+
body : dict, optional
693+
Pre-parsed JSON body for ``resp``. When provided, skips the
694+
``resp.json()`` call — useful when the caller has already
695+
decoded the body for its own use (avoids a second parse pass).
690696
691697
Returns
692698
-------
@@ -702,7 +708,8 @@ def _next_req_url(resp: requests.Response) -> str | None:
702708
"rel" and "href" keys.
703709
- Checks for the "next" relation in the "links" to determine the next URL.
704710
"""
705-
body = resp.json()
711+
if body is None:
712+
body = resp.json()
706713
if not body.get("numberReturned"):
707714
return None
708715
header_info = resp.headers
@@ -719,7 +726,12 @@ def _next_req_url(resp: requests.Response) -> str | None:
719726
return None
720727

721728

722-
def _get_resp_data(resp: requests.Response, geopd: bool) -> pd.DataFrame:
729+
def _get_resp_data(
730+
resp: requests.Response,
731+
geopd: bool,
732+
*,
733+
body: dict[str, Any] | None = None,
734+
) -> pd.DataFrame:
723735
"""
724736
Extracts and normalizes data from an HTTP response containing GeoJSON features.
725737
@@ -731,6 +743,10 @@ def _get_resp_data(resp: requests.Response, geopd: bool) -> pd.DataFrame:
731743
geopd : bool
732744
Indicates whether geopandas is installed and should be used to
733745
handle geometries.
746+
body : dict, optional
747+
Pre-parsed JSON body for ``resp``. When provided, skips the
748+
``resp.json()`` call — useful when the caller has already
749+
decoded the body for its own use (avoids a second parse pass).
734750
735751
Returns
736752
-------
@@ -739,8 +755,8 @@ def _get_resp_data(resp: requests.Response, geopd: bool) -> pd.DataFrame:
739755
containing the feature properties and each row's service-specific id.
740756
Returns an empty pandas DataFrame if no features are returned.
741757
"""
742-
# Check if it's an empty response
743-
body = resp.json()
758+
if body is None:
759+
body = resp.json()
744760
if not body.get("numberReturned"):
745761
return pd.DataFrame()
746762

@@ -771,28 +787,36 @@ def _get_resp_data(resp: requests.Response, geopd: bool) -> pd.DataFrame:
771787
@contextmanager
772788
def _session(client: requests.Session | None) -> Iterator[requests.Session]:
773789
"""
774-
Yield a usable session, opening a temporary one when needed.
790+
Yield a usable session, picking the best available source.
791+
792+
Resolution order:
775793
776-
Lets paginated-loop callers borrow a caller-provided session
777-
(without closing it) or fall back to a short-lived one with a
778-
single ``with`` statement, instead of repeating the
779-
``close_client = client is None`` pattern.
794+
1. ``client`` if the caller supplied one (borrowed; not closed
795+
here — the caller owns its lifecycle).
796+
2. The chunker's shared session if we're inside a ``ChunkedCall``
797+
fan-out (published via :func:`chunking._publish_session`).
798+
Borrowed; ``ChunkedCall.resume`` closes it on exit.
799+
3. A fresh short-lived ``requests.Session`` opened here and closed
800+
on context exit.
780801
781802
Parameters
782803
----------
783804
client : requests.Session or None
784-
A caller-owned session to borrow, or ``None`` to open a
785-
temporary one.
805+
A caller-owned session to borrow, or ``None`` to defer to the
806+
chunker's shared session or a temporary one.
786807
787808
Yields
788809
------
789810
requests.Session
790-
``client`` itself when provided; otherwise a freshly opened
791-
session that is closed on context exit.
811+
The chosen session.
792812
"""
793813
if client is not None:
794814
yield client
795815
return
816+
shared = chunking._chunked_session.get()
817+
if shared is not None:
818+
yield shared
819+
return
796820
with requests.Session() as new:
797821
yield new
798822

@@ -833,12 +857,15 @@ def _finalize_paginated_response(
833857
initial.elapsed = total_elapsed
834858

835859

860+
_Cursor = TypeVar("_Cursor")
861+
862+
836863
def _paginate(
837864
initial_req: requests.PreparedRequest,
838865
*,
839866
geopd: bool,
840-
parse_response: Callable[[requests.Response], tuple[pd.DataFrame, Any]],
841-
follow_up: Callable[[Any, requests.Session], requests.Response],
867+
parse_response: Callable[[requests.Response], tuple[pd.DataFrame, _Cursor | None]],
868+
follow_up: Callable[[_Cursor, requests.Session], requests.Response],
842869
client: requests.Session | None = None,
843870
) -> tuple[pd.DataFrame, requests.Response]:
844871
"""
@@ -898,12 +925,6 @@ def _paginate(
898925
"into pandas DataFrames."
899926
)
900927

901-
# Inside a chunker fan-out, reuse the shared session so every
902-
# sub-request rides the same connection pool. The fallback path
903-
# (``client=None`` and no chunker context) opens a temp session.
904-
if client is None:
905-
client = chunking._chunked_session.get()
906-
907928
with _session(client) as sess:
908929
resp = sess.send(initial_req)
909930
_raise_for_non_200(resp)
@@ -973,22 +994,25 @@ def _walk_pages(
973994
requests.exceptions.RequestException
974995
See :func:`_paginate`.
975996
"""
976-
# ``PreparedRequest.method`` is already upper-cased by
977-
# ``requests`` during preparation, so no need to normalize again.
978-
method = req.method
997+
method = req.method # ``PreparedRequest.method`` is already upper-cased.
979998
headers = dict(req.headers)
980999
content = req.body if method == "POST" else None
9811000

1001+
def parse_response(resp: requests.Response) -> tuple[pd.DataFrame, str | None]:
1002+
body = resp.json()
1003+
return (
1004+
_get_resp_data(resp, geopd=geopd, body=body),
1005+
_next_req_url(resp, body=body),
1006+
)
1007+
1008+
def follow_up(cursor: str, sess: requests.Session) -> requests.Response:
1009+
return sess.request(method, cursor, headers=headers, data=content)
1010+
9821011
return _paginate(
9831012
req,
9841013
geopd=geopd,
985-
parse_response=lambda resp: (
986-
_get_resp_data(resp, geopd=geopd),
987-
_next_req_url(resp),
988-
),
989-
follow_up=lambda cursor, sess: sess.request(
990-
method, cursor, headers=headers, data=content
991-
),
1014+
parse_response=parse_response,
1015+
follow_up=follow_up,
9921016
client=client,
9931017
)
9941018

@@ -1409,8 +1433,11 @@ def parse_response(resp: requests.Response) -> tuple[pd.DataFrame, str | None]:
14091433
return _handle_stats_nesting(body, geopd=GEOPANDAS), body.get("next")
14101434

14111435
def follow_up(cursor: str, sess: requests.Session) -> requests.Response:
1412-
args["next_token"] = cursor
1413-
return sess.request(method, url=url, params=args, headers=headers)
1436+
# Build a fresh params dict per page so the caller's ``args``
1437+
# is never mutated (the closure used to do ``args["next_token"] = ...``).
1438+
return sess.request(
1439+
method, url=url, params={**args, "next_token": cursor}, headers=headers
1440+
)
14141441

14151442
df, response = _paginate(
14161443
req,

tests/waterdata_chunking_test.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -369,8 +369,6 @@ def fetch(args):
369369
i = state["i"]
370370
state["i"] += 1
371371
if i == 1 and state["blow_up"]:
372-
from dataretrieval.waterdata.utils import RateLimited
373-
374372
raise RateLimited("429: Too many requests.")
375373
return (
376374
pd.DataFrame({"sites": list(args["sites"])}),

0 commit comments

Comments
 (0)