diff --git a/dataretrieval/waterdata/chunking.py b/dataretrieval/waterdata/chunking.py index 6626aba9..af60bbc0 100644 --- a/dataretrieval/waterdata/chunking.py +++ b/dataretrieval/waterdata/chunking.py @@ -10,17 +10,18 @@ single-step plan — ``ChunkedCall`` has one code path either way. Concurrency: ``multi_value_chunked`` fans every pending sub-request out -under one ``asyncio.gather`` sharing a single ``httpx.AsyncClient``; -concurrency is bounded purely by the client's connection pool -(``httpx.Limits(max_connections=N, max_keepalive_connections=N)``), so -the pool throttles. ``API_USGS_CONCURRENT`` resolves -``N``: an integer N > 1 caps connections at N; ``1`` pins a single -connection (one request at a time); the literal ``unbounded`` removes -the cap (``N=None``). The default (16) is the server-friendly sweet -spot; higher values can trip USGS burst-protection 5xx in practice. The -fan-out runs in a short-lived worker thread (an ``anyio`` blocking -portal), so it works whether or not the caller is already inside an -event loop (Jupyter / IPython / async apps). +under one ``asyncio.gather`` sharing a single ``httpx.AsyncClient``. An +``asyncio.Semaphore`` — not the client's connection pool, which is +merely sized to match — caps the sub-requests in flight at ``N``; see +:meth:`ChunkedCall._run` for why the gate must be the semaphore rather +than the pool. ``API_USGS_CONCURRENT`` resolves ``N``: an integer N > 1 +allows N sub-requests in flight; ``1`` pins sequential dispatch (one +request at a time); the literal ``unbounded`` lifts the cap. The +default (16) is the server-friendly sweet spot; higher values can trip +USGS burst-protection 5xx in practice. The fan-out runs in a +short-lived worker thread (an ``anyio`` blocking portal), so it works +whether or not the caller is already inside an event loop (Jupyter / +IPython / async apps). Retries: each sub-request is retried on a transient failure (429, 5xx, connect/read timeout) with exponential backoff + full jitter, @@ -130,10 +131,10 @@ def _read_concurrency_env() -> int | None: Returns ------- int or None - ``1`` for a single connection; an integer >1 for bounded - concurrency; ``None`` to disable the per-call cap entirely - (``unbounded`` keyword). Unset → default of - ``_CONCURRENCY_DEFAULT``. + ``1`` for sequential dispatch (one sub-request at a time); an + integer >1 for bounded concurrency; ``None`` to disable the + per-call cap entirely (``unbounded`` keyword). Unset → default + of ``_CONCURRENCY_DEFAULT``. """ raw = os.environ.get(_CONCURRENCY_ENV) if raw is None: @@ -1307,9 +1308,9 @@ class ChunkedCall: :class:`httpx.AsyncClient`, applies the failure-precedence rules, and combines; :meth:`resume` drives it through an ``anyio`` blocking portal so it works whether or not the caller is already inside an - event loop. Concurrency is bounded purely by the client's connection - pool, so a single connection (``API_USGS_CONCURRENT=1``) is just a - degenerate gather. + event loop. Concurrency is bounded by a per-run ``asyncio.Semaphore`` + (see :meth:`_run`), so sequential dispatch + (``API_USGS_CONCURRENT=1``) is just a degenerate gather. A ``ChunkedCall`` is created internally when a :class:`ChunkPlan` executes; callers reach it via :attr:`ChunkInterrupted.call` on @@ -1551,20 +1552,33 @@ async def _run(self, max_concurrent: int | None) -> tuple[pd.DataFrame, Any]: ``.call``; ``exc.call.resume()`` then re-issues only the unfinished indices through this same runner. - Concurrency is bounded by the client's connection pool — - ``httpx.Limits(max_connections=N, max_keepalive_connections=N)`` - where ``N = max_concurrent`` (``None`` for unbounded). The gather - dispatches *every* pending sub-request and the pool throttles, so - ``N=1`` is just a single-connection gather (one request at a time) - and ``total <= 1`` is just a one-element gather. + The gather dispatches *every* pending sub-request at once, but an + ``asyncio.Semaphore`` caps how many fetch concurrently at + ``N = max_concurrent`` — ``None`` lifts the cap, ``N=1`` runs them + one at a time. The connection pool is sized to the same ``N`` + (``httpx.Limits(max_connections=N, max_keepalive_connections=N)``) + so the in-flight fetches reuse keepalive connections. + + The semaphore, not the pool, is deliberately the throttle. If the + pool throttled instead, the excess sub-requests would queue + *inside* httpx waiting for a connection, and that wait counts + against the pool-acquire timeout (60 s, from ``HTTPX_DEFAULTS``). + A batch of slow pages that keeps every connection busy past that + window would then trip ``httpx.PoolTimeout`` on the queued tail — + a purely client-side failure that burns the retry budget and + surfaces as a bogus resumable ``ServiceInterrupted``. Parking + sub-requests on the semaphore keeps them out of the pool until a + slot frees, so the pool timeout only fires for a genuinely stuck + connection. + The shared client is published on :data:`_chunked_client` so the paginated-loop helpers reuse its connection pool. Parameters ---------- max_concurrent : int or None - Maximum simultaneous connections (the pool cap). ``None`` - disables the cap. + Maximum sub-requests in flight (the semaphore value, and the + connection-pool size). ``None`` lifts the cap entirely. Returns ------- @@ -1583,13 +1597,19 @@ async def _run(self, max_concurrent: int | None) -> tuple[pd.DataFrame, Any]: holding the sparse completed sub-requests; ``.call.resume()`` re-issues the unfinished ones. """ - # ``httpx.Limits()`` defaults to ``max_connections=100`` — at higher - # concurrency the pool would silently bottleneck the fan-out behind - # that cap. Set it to the resolved concurrency so the pool *is* the - # throttle (``None`` for truly unbounded). + # The semaphore is the throttle; the pool is merely sized to match + # it. Left at httpx's default client limits (``max_connections=100``, + # keepalive 20) the pool would bottleneck a wider cap or churn + # connections by keeping too few alive. See the method docstring for + # why the gate can't be the pool itself. ``unbounded`` + # (``max_concurrent=None``) is a degenerate cap at the plan total — a + # semaphore that can never block — so gated is the only code path. limits = httpx.Limits( max_connections=max_concurrent, max_keepalive_connections=max_concurrent ) + semaphore = asyncio.Semaphore( + self.plan.total if max_concurrent is None else max_concurrent + ) async with httpx.AsyncClient(limits=limits, **HTTPX_DEFAULTS) as client: with _publish(client): @@ -1597,11 +1617,25 @@ async def _run(self, max_concurrent: int | None) -> tuple[pd.DataFrame, Any]: if reporter is not None: reporter.set_chunks(self.plan.total) + async def fetch_gated( + args: dict[str, Any], + ) -> tuple[pd.DataFrame, httpx.Response]: + """One fetch attempt under the concurrency gate. + + The slot is held for the attempt's full duration — + every page of a paginated sub-request — but acquired + per *attempt* (this is what ``_retry`` re-invokes), so + a sub-request sleeping off a retry backoff isn't + holding a slot while it isn't touching the server. + """ + async with semaphore: + return await self.fetch(args) + async def track( index: int, args: dict[str, Any] ) -> tuple[pd.DataFrame, httpx.Response]: """One sub-request (with retry) + result-store + progress tick.""" - result = await _retry(lambda: self.fetch(args), self.retry_policy) + result = await _retry(lambda: fetch_gated(args), self.retry_policy) self._chunks[index] = result if reporter is not None: # Chunks finish out of order under gather, so tick the @@ -1610,7 +1644,7 @@ async def track( return result # Dispatch every pending sub-request concurrently; the - # connection pool (``limits``) is the only throttle. + # semaphore (via ``fetch_gated``) is the only throttle. # ``return_exceptions`` keeps completed pairs after a sibling # fails, so partial state stays recoverable via :meth:`resume`. # Failure precedence, in order: @@ -1706,8 +1740,8 @@ def wrapper( limit = _WATERDATA_URL_BYTE_LIMIT if url_limit is None else url_limit plan = ChunkPlan(args, build_request, limit) retry_policy = RetryPolicy.from_env() - # The connection-pool cap is resolved inside ``resume()`` from - # ``API_USGS_CONCURRENT``; ``1`` is a single-connection gather, + # The concurrency cap is resolved inside ``resume()`` from + # ``API_USGS_CONCURRENT``; ``1`` is a sequential gather, # ``total <= 1`` a one-element gather — no special branch. return plan.execute(fetch, retry_policy, finalize) diff --git a/dataretrieval/waterdata/utils.py b/dataretrieval/waterdata/utils.py index 7a2c0d5c..8b74706a 100644 --- a/dataretrieval/waterdata/utils.py +++ b/dataretrieval/waterdata/utils.py @@ -1551,7 +1551,7 @@ async def _fetch_once( and iterates the cartesian product. With no chunkable inputs the decorator passes args through unchanged. The decorator gathers every sub-request over one shared :class:`httpx.AsyncClient` (concurrency - bounded by the connection pool, sized from ``API_USGS_CONCURRENT``) + bounded by a semaphore, sized from ``API_USGS_CONCURRENT``) and returns a *synchronous* wrapper, so ``get_ogc_data`` keeps calling ``_fetch_once(args, finalize=...)`` synchronously. The return shape is ``(frame, response)``. diff --git a/tests/waterdata_chunking_test.py b/tests/waterdata_chunking_test.py index 0857da4a..37e9b999 100644 --- a/tests/waterdata_chunking_test.py +++ b/tests/waterdata_chunking_test.py @@ -18,7 +18,10 @@ import asyncio import concurrent.futures import datetime +import http.server import sys +import threading +import time import warnings from unittest import mock from urllib.parse import quote_plus @@ -32,6 +35,7 @@ pytest.skip("Skip entire module on Python < 3.10", allow_module_level=True) from dataretrieval.exceptions import DataRetrievalError +from dataretrieval.utils import HTTPX_DEFAULTS from dataretrieval.waterdata import chunking as _chunking from dataretrieval.waterdata import utils as _utils from dataretrieval.waterdata.chunking import ( @@ -56,6 +60,7 @@ _retry, _retryable, _safe_request_bytes, + get_active_client, multi_value_chunked, ) from dataretrieval.waterdata.utils import _DATE_RANGE_PARAMS, _construct_api_requests @@ -1314,13 +1319,15 @@ def test_iter_sub_args_passthrough_yields_a_copy(): # --- async fan-out path ---------------------------------------------------- # # Every sub-request is gathered over one ``httpx.AsyncClient`` and -# concurrency is bounded purely by that client's connection pool, sized -# from ``API_USGS_CONCURRENT``. The conftest's ``_pin_chunker_env`` -# autouse pins ``API_USGS_CONCURRENT=1`` (a single connection) for the -# whole suite; each test below raises it so the gather can dispatch -# sub-requests under a wider pool. The decorated async fetcher is the -# SAME one used on both first-run and resume. No real ``httpx.AsyncClient`` -# round-trip occurs (the fakes return mock data), even though +# concurrency is bounded by an ``asyncio.Semaphore`` sized from +# ``API_USGS_CONCURRENT`` (the client's connection pool is sized to +# match, but the semaphore is the throttle — see ``ChunkedCall._run``). +# The conftest's ``_pin_chunker_env`` autouse pins +# ``API_USGS_CONCURRENT=1`` (sequential dispatch) for the whole suite; +# each test below raises it so the gather can dispatch sub-requests +# under a wider cap. The decorated async fetcher is the SAME one used on +# both first-run and resume. No real ``httpx.AsyncClient`` round-trip +# occurs (the fakes return mock data), even though # :meth:`ChunkedCall._run` opens one for pool management. @@ -1489,6 +1496,121 @@ async def fetch(args): assert len(df) == len(calls) +# Eight 20-char sites against ``url_limit=240`` (base 200): any two atoms +# joined overflow the 40-byte budget, so the planner lands on eight +# singleton sub-requests — enough fan-out to observe the concurrency gate. +_EIGHT_SINGLETON_SITES = [f"S{i}" * 10 for i in range(8)] + + +def _concurrency_probe(in_flight): + """An async fetch that records the high-water mark of simultaneous + invocations in ``in_flight`` (keys ``now``/``max``). The ``sleep(0)`` + yields to the loop while "in flight", so overlap is observable.""" + + async def fetch_async(args): + in_flight["now"] += 1 + in_flight["max"] = max(in_flight["max"], in_flight["now"]) + await asyncio.sleep(0) + in_flight["now"] -= 1 + return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response() + + return fetch_async + + +@pytest.mark.parametrize( + ("cap", "expected_high_water"), + [ + pytest.param(2, 2, id="capped"), + pytest.param("unbounded", len(_EIGHT_SINGLETON_SITES), id="unbounded"), + ], +) +def test_fan_out_in_flight_high_water_mark_is_the_cap( + monkeypatch, cap, expected_high_water +): + """The fetch-level high-water mark of simultaneous sub-requests IS the + ``API_USGS_CONCURRENT`` cap — genuine parallelism up to it, never past + it — and ``unbounded`` degenerates to every sub-request at once. + + Regression: the cap used to be enforced only by the shared client's + connection-pool size, so sub-requests beyond it queued on connection + *acquisition*, subject to the client's pool-acquire timeout (see + ``ChunkedCall._run``). The semaphore parks excess sub-requests before + they touch the pool. + """ + in_flight = {"now": 0, "max": 0} + fetch = _async_chunked_fetch( + monkeypatch, _concurrency_probe(in_flight), max_concurrent=cap + ) + + df, _ = fetch({"sites": list(_EIGHT_SINGLETON_SITES)}) + + assert len(df) == len(_EIGHT_SINGLETON_SITES) # all sub-requests completed + assert in_flight["max"] == expected_high_water + + +def test_fan_out_outlives_pool_timeout_on_real_transport(monkeypatch): + """End-to-end regression for the pool-timeout starvation bug: the + fan-out must survive every pooled connection staying busy past the + client's pool-acquire timeout (the stall mechanism is documented on + ``ChunkedCall._run``; at production scale think a batch of large, + slowly-streaming pages). + + Sub-requests here send real HTTP to a slow localhost server through + the chunker's shared client — fakes can't catch this, since + ``MockTransport`` bypasses the connection pool. With the pool as the + only throttle, 2 connections busy for 0.35 s each and the 0.2 s pool + timeout pinned below, the 2 queued sub-requests sat out the full + timeout with no completion to reset their clocks → + ``httpx.PoolTimeout`` → (retries exhausted, ``API_USGS_RETRIES=0``) + a spurious resumable ``ServiceInterrupted``. Gated by the semaphore, + queued sub-requests never touch the pool and the call completes. + """ + + class _SlowHandler(http.server.BaseHTTPRequestHandler): + protocol_version = "HTTP/1.1" # keepalive, so pooled connections reuse + + def do_GET(self): + time.sleep(0.35) # hold the connection busy past the pool timeout + body = b'{"ok": true}' + self.send_response(200) + self.send_header("Content-Length", str(len(body))) + self.end_headers() + self.wfile.write(body) + + def log_message(self, *args): # keep pytest output clean + pass + + server = http.server.ThreadingHTTPServer(("127.0.0.1", 0), _SlowHandler) + thread = threading.Thread(target=server.serve_forever, daemon=True) + thread.start() + # Everything past the thread start is in the try, so any setup failure + # (monkeypatch, decorator construction) still tears the server down. + try: + url = f"http://127.0.0.1:{server.server_address[1]}/" + + # Scale the production pool timeout (see ``HTTPX_DEFAULTS``) down to + # 0.2 s so the pre-semaphore failure mode reproduces in test time. + monkeypatch.setitem( + HTTPX_DEFAULTS, "timeout", httpx.Timeout(5.0, connect=1.0, pool=0.2) + ) + + async def fetch_async(args): + client = get_active_client() + assert client is not None, "sub-request must use the shared client" + resp = await client.get(url) + assert resp.status_code == 200 + return pd.DataFrame({"id": [_atom_id(args)]}), resp + + sites = _EIGHT_SINGLETON_SITES[:4] # 2 in flight + 2 queued, 2 waves + fetch = _async_chunked_fetch(monkeypatch, fetch_async, max_concurrent=2) + df, _ = fetch({"sites": sites}) + assert len(df) == len(sites) + finally: + server.shutdown() + server.server_close() + thread.join(timeout=5) + + def test_async_fan_out_runs_inside_running_event_loop(monkeypatch): """The parallel fan-out works even when the caller is already inside a running event loop (Jupyter / async apps): the anyio blocking portal