Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 68 additions & 34 deletions dataretrieval/waterdata/chunking.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,18 @@
single-step plan — ``ChunkedCall`` has one code path either way.

Concurrency: ``multi_value_chunked`` fans every pending sub-request out
under one ``asyncio.gather`` sharing a single ``httpx.AsyncClient``;
concurrency is bounded purely by the client's connection pool
(``httpx.Limits(max_connections=N, max_keepalive_connections=N)``), so
the pool throttles. ``API_USGS_CONCURRENT`` resolves
``N``: an integer N > 1 caps connections at N; ``1`` pins a single
connection (one request at a time); the literal ``unbounded`` removes
the cap (``N=None``). The default (16) is the server-friendly sweet
spot; higher values can trip USGS burst-protection 5xx in practice. The
fan-out runs in a short-lived worker thread (an ``anyio`` blocking
portal), so it works whether or not the caller is already inside an
event loop (Jupyter / IPython / async apps).
under one ``asyncio.gather`` sharing a single ``httpx.AsyncClient``. An
``asyncio.Semaphore`` — not the client's connection pool, which is
merely sized to match — caps the sub-requests in flight at ``N``; see
:meth:`ChunkedCall._run` for why the gate must be the semaphore rather
than the pool. ``API_USGS_CONCURRENT`` resolves ``N``: an integer N > 1
allows N sub-requests in flight; ``1`` pins sequential dispatch (one
request at a time); the literal ``unbounded`` lifts the cap. The
default (16) is the server-friendly sweet spot; higher values can trip
USGS burst-protection 5xx in practice. The fan-out runs in a
short-lived worker thread (an ``anyio`` blocking portal), so it works
whether or not the caller is already inside an event loop (Jupyter /
IPython / async apps).

Retries: each sub-request is retried on a transient failure (429,
5xx, connect/read timeout) with exponential backoff + full jitter,
Expand Down Expand Up @@ -130,10 +131,10 @@ def _read_concurrency_env() -> int | None:
Returns
-------
int or None
``1`` for a single connection; an integer >1 for bounded
concurrency; ``None`` to disable the per-call cap entirely
(``unbounded`` keyword). Unset → default of
``_CONCURRENCY_DEFAULT``.
``1`` for sequential dispatch (one sub-request at a time); an
integer >1 for bounded concurrency; ``None`` to disable the
per-call cap entirely (``unbounded`` keyword). Unset → default
of ``_CONCURRENCY_DEFAULT``.
"""
raw = os.environ.get(_CONCURRENCY_ENV)
if raw is None:
Expand Down Expand Up @@ -1307,9 +1308,9 @@ class ChunkedCall:
:class:`httpx.AsyncClient`, applies the failure-precedence rules, and
combines; :meth:`resume` drives it through an ``anyio`` blocking
portal so it works whether or not the caller is already inside an
event loop. Concurrency is bounded purely by the client's connection
pool, so a single connection (``API_USGS_CONCURRENT=1``) is just a
degenerate gather.
event loop. Concurrency is bounded by a per-run ``asyncio.Semaphore``
(see :meth:`_run`), so sequential dispatch
(``API_USGS_CONCURRENT=1``) is just a degenerate gather.

A ``ChunkedCall`` is created internally when a :class:`ChunkPlan`
executes; callers reach it via :attr:`ChunkInterrupted.call` on
Expand Down Expand Up @@ -1551,20 +1552,33 @@ async def _run(self, max_concurrent: int | None) -> tuple[pd.DataFrame, Any]:
``.call``; ``exc.call.resume()`` then re-issues only the unfinished
indices through this same runner.

Concurrency is bounded by the client's connection pool —
``httpx.Limits(max_connections=N, max_keepalive_connections=N)``
where ``N = max_concurrent`` (``None`` for unbounded). The gather
dispatches *every* pending sub-request and the pool throttles, so
``N=1`` is just a single-connection gather (one request at a time)
and ``total <= 1`` is just a one-element gather.
The gather dispatches *every* pending sub-request at once, but an
``asyncio.Semaphore`` caps how many fetch concurrently at
``N = max_concurrent`` — ``None`` lifts the cap, ``N=1`` runs them
one at a time. The connection pool is sized to the same ``N``
(``httpx.Limits(max_connections=N, max_keepalive_connections=N)``)
so the in-flight fetches reuse keepalive connections.

The semaphore, not the pool, is deliberately the throttle. If the
pool throttled instead, the excess sub-requests would queue
*inside* httpx waiting for a connection, and that wait counts
against the pool-acquire timeout (60 s, from ``HTTPX_DEFAULTS``).
A batch of slow pages that keeps every connection busy past that
window would then trip ``httpx.PoolTimeout`` on the queued tail —
a purely client-side failure that burns the retry budget and
surfaces as a bogus resumable ``ServiceInterrupted``. Parking
sub-requests on the semaphore keeps them out of the pool until a
slot frees, so the pool timeout only fires for a genuinely stuck
connection.

The shared client is published on :data:`_chunked_client` so
the paginated-loop helpers reuse its connection pool.

Parameters
----------
max_concurrent : int or None
Maximum simultaneous connections (the pool cap). ``None``
disables the cap.
Maximum sub-requests in flight (the semaphore value, and the
connection-pool size). ``None`` lifts the cap entirely.

Returns
-------
Expand All @@ -1583,25 +1597,45 @@ async def _run(self, max_concurrent: int | None) -> tuple[pd.DataFrame, Any]:
holding the sparse completed sub-requests; ``.call.resume()``
re-issues the unfinished ones.
"""
# ``httpx.Limits()`` defaults to ``max_connections=100`` — at higher
# concurrency the pool would silently bottleneck the fan-out behind
# that cap. Set it to the resolved concurrency so the pool *is* the
# throttle (``None`` for truly unbounded).
# The semaphore is the throttle; the pool is merely sized to match
# it. Left at httpx's default client limits (``max_connections=100``,
# keepalive 20) the pool would bottleneck a wider cap or churn
# connections by keeping too few alive. See the method docstring for
# why the gate can't be the pool itself. ``unbounded``
# (``max_concurrent=None``) is a degenerate cap at the plan total — a
# semaphore that can never block — so gated is the only code path.
limits = httpx.Limits(
max_connections=max_concurrent, max_keepalive_connections=max_concurrent
)
semaphore = asyncio.Semaphore(
self.plan.total if max_concurrent is None else max_concurrent
)

async with httpx.AsyncClient(limits=limits, **HTTPX_DEFAULTS) as client:
with _publish(client):
reporter = _progress.current()
if reporter is not None:
reporter.set_chunks(self.plan.total)

async def fetch_gated(
args: dict[str, Any],
) -> tuple[pd.DataFrame, httpx.Response]:
"""One fetch attempt under the concurrency gate.

The slot is held for the attempt's full duration —
every page of a paginated sub-request — but acquired
per *attempt* (this is what ``_retry`` re-invokes), so
a sub-request sleeping off a retry backoff isn't
holding a slot while it isn't touching the server.
"""
async with semaphore:
return await self.fetch(args)

async def track(
index: int, args: dict[str, Any]
) -> tuple[pd.DataFrame, httpx.Response]:
"""One sub-request (with retry) + result-store + progress tick."""
result = await _retry(lambda: self.fetch(args), self.retry_policy)
result = await _retry(lambda: fetch_gated(args), self.retry_policy)
self._chunks[index] = result
if reporter is not None:
# Chunks finish out of order under gather, so tick the
Expand All @@ -1610,7 +1644,7 @@ async def track(
return result

# Dispatch every pending sub-request concurrently; the
# connection pool (``limits``) is the only throttle.
# semaphore (via ``fetch_gated``) is the only throttle.
# ``return_exceptions`` keeps completed pairs after a sibling
# fails, so partial state stays recoverable via :meth:`resume`.
# Failure precedence, in order:
Expand Down Expand Up @@ -1706,8 +1740,8 @@ def wrapper(
limit = _WATERDATA_URL_BYTE_LIMIT if url_limit is None else url_limit
plan = ChunkPlan(args, build_request, limit)
retry_policy = RetryPolicy.from_env()
# The connection-pool cap is resolved inside ``resume()`` from
# ``API_USGS_CONCURRENT``; ``1`` is a single-connection gather,
# The concurrency cap is resolved inside ``resume()`` from
# ``API_USGS_CONCURRENT``; ``1`` is a sequential gather,
# ``total <= 1`` a one-element gather — no special branch.
return plan.execute(fetch, retry_policy, finalize)

Expand Down
2 changes: 1 addition & 1 deletion dataretrieval/waterdata/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1551,7 +1551,7 @@ async def _fetch_once(
and iterates the cartesian product. With no chunkable inputs the
decorator passes args through unchanged. The decorator gathers every
sub-request over one shared :class:`httpx.AsyncClient` (concurrency
bounded by the connection pool, sized from ``API_USGS_CONCURRENT``)
bounded by a semaphore, sized from ``API_USGS_CONCURRENT``)
and returns a *synchronous* wrapper, so ``get_ogc_data`` keeps calling
``_fetch_once(args, finalize=...)`` synchronously. The return shape is
``(frame, response)``.
Expand Down
136 changes: 129 additions & 7 deletions tests/waterdata_chunking_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@
import asyncio
import concurrent.futures
import datetime
import http.server
import sys
import threading
import time
import warnings
from unittest import mock
from urllib.parse import quote_plus
Expand All @@ -32,6 +35,7 @@
pytest.skip("Skip entire module on Python < 3.10", allow_module_level=True)

from dataretrieval.exceptions import DataRetrievalError
from dataretrieval.utils import HTTPX_DEFAULTS
from dataretrieval.waterdata import chunking as _chunking
from dataretrieval.waterdata import utils as _utils
from dataretrieval.waterdata.chunking import (
Expand All @@ -56,6 +60,7 @@
_retry,
_retryable,
_safe_request_bytes,
get_active_client,
multi_value_chunked,
)
from dataretrieval.waterdata.utils import _DATE_RANGE_PARAMS, _construct_api_requests
Expand Down Expand Up @@ -1314,13 +1319,15 @@ def test_iter_sub_args_passthrough_yields_a_copy():
# --- async fan-out path ----------------------------------------------------
#
# Every sub-request is gathered over one ``httpx.AsyncClient`` and
# concurrency is bounded purely by that client's connection pool, sized
# from ``API_USGS_CONCURRENT``. The conftest's ``_pin_chunker_env``
# autouse pins ``API_USGS_CONCURRENT=1`` (a single connection) for the
# whole suite; each test below raises it so the gather can dispatch
# sub-requests under a wider pool. The decorated async fetcher is the
# SAME one used on both first-run and resume. No real ``httpx.AsyncClient``
# round-trip occurs (the fakes return mock data), even though
# concurrency is bounded by an ``asyncio.Semaphore`` sized from
# ``API_USGS_CONCURRENT`` (the client's connection pool is sized to
# match, but the semaphore is the throttle — see ``ChunkedCall._run``).
# The conftest's ``_pin_chunker_env`` autouse pins
# ``API_USGS_CONCURRENT=1`` (sequential dispatch) for the whole suite;
# each test below raises it so the gather can dispatch sub-requests
# under a wider cap. The decorated async fetcher is the SAME one used on
# both first-run and resume. No real ``httpx.AsyncClient`` round-trip
# occurs (the fakes return mock data), even though
# :meth:`ChunkedCall._run` opens one for pool management.


Expand Down Expand Up @@ -1489,6 +1496,121 @@ async def fetch(args):
assert len(df) == len(calls)


# Eight 20-char sites against ``url_limit=240`` (base 200): any two atoms
# joined overflow the 40-byte budget, so the planner lands on eight
# singleton sub-requests — enough fan-out to observe the concurrency gate.
_EIGHT_SINGLETON_SITES = [f"S{i}" * 10 for i in range(8)]


def _concurrency_probe(in_flight):
"""An async fetch that records the high-water mark of simultaneous
invocations in ``in_flight`` (keys ``now``/``max``). The ``sleep(0)``
yields to the loop while "in flight", so overlap is observable."""

async def fetch_async(args):
in_flight["now"] += 1
in_flight["max"] = max(in_flight["max"], in_flight["now"])
await asyncio.sleep(0)
in_flight["now"] -= 1
return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response()

return fetch_async


@pytest.mark.parametrize(
("cap", "expected_high_water"),
[
pytest.param(2, 2, id="capped"),
pytest.param("unbounded", len(_EIGHT_SINGLETON_SITES), id="unbounded"),
],
)
def test_fan_out_in_flight_high_water_mark_is_the_cap(
monkeypatch, cap, expected_high_water
):
"""The fetch-level high-water mark of simultaneous sub-requests IS the
``API_USGS_CONCURRENT`` cap — genuine parallelism up to it, never past
it — and ``unbounded`` degenerates to every sub-request at once.

Regression: the cap used to be enforced only by the shared client's
connection-pool size, so sub-requests beyond it queued on connection
*acquisition*, subject to the client's pool-acquire timeout (see
``ChunkedCall._run``). The semaphore parks excess sub-requests before
they touch the pool.
"""
in_flight = {"now": 0, "max": 0}
fetch = _async_chunked_fetch(
monkeypatch, _concurrency_probe(in_flight), max_concurrent=cap
)

df, _ = fetch({"sites": list(_EIGHT_SINGLETON_SITES)})

assert len(df) == len(_EIGHT_SINGLETON_SITES) # all sub-requests completed
assert in_flight["max"] == expected_high_water


def test_fan_out_outlives_pool_timeout_on_real_transport(monkeypatch):
"""End-to-end regression for the pool-timeout starvation bug: the
fan-out must survive every pooled connection staying busy past the
client's pool-acquire timeout (the stall mechanism is documented on
``ChunkedCall._run``; at production scale think a batch of large,
slowly-streaming pages).

Sub-requests here send real HTTP to a slow localhost server through
the chunker's shared client — fakes can't catch this, since
``MockTransport`` bypasses the connection pool. With the pool as the
only throttle, 2 connections busy for 0.35 s each and the 0.2 s pool
timeout pinned below, the 2 queued sub-requests sat out the full
timeout with no completion to reset their clocks →
``httpx.PoolTimeout`` → (retries exhausted, ``API_USGS_RETRIES=0``)
a spurious resumable ``ServiceInterrupted``. Gated by the semaphore,
queued sub-requests never touch the pool and the call completes.
"""

class _SlowHandler(http.server.BaseHTTPRequestHandler):
protocol_version = "HTTP/1.1" # keepalive, so pooled connections reuse

def do_GET(self):
time.sleep(0.35) # hold the connection busy past the pool timeout
body = b'{"ok": true}'
self.send_response(200)
self.send_header("Content-Length", str(len(body)))
self.end_headers()
self.wfile.write(body)

def log_message(self, *args): # keep pytest output clean
pass

server = http.server.ThreadingHTTPServer(("127.0.0.1", 0), _SlowHandler)
thread = threading.Thread(target=server.serve_forever, daemon=True)
thread.start()
# Everything past the thread start is in the try, so any setup failure
# (monkeypatch, decorator construction) still tears the server down.
try:
url = f"http://127.0.0.1:{server.server_address[1]}/"

# Scale the production pool timeout (see ``HTTPX_DEFAULTS``) down to
# 0.2 s so the pre-semaphore failure mode reproduces in test time.
monkeypatch.setitem(
HTTPX_DEFAULTS, "timeout", httpx.Timeout(5.0, connect=1.0, pool=0.2)
)

async def fetch_async(args):
client = get_active_client()
assert client is not None, "sub-request must use the shared client"
resp = await client.get(url)
assert resp.status_code == 200
return pd.DataFrame({"id": [_atom_id(args)]}), resp

sites = _EIGHT_SINGLETON_SITES[:4] # 2 in flight + 2 queued, 2 waves
fetch = _async_chunked_fetch(monkeypatch, fetch_async, max_concurrent=2)
df, _ = fetch({"sites": sites})
assert len(df) == len(sites)
finally:
server.shutdown()
server.server_close()
thread.join(timeout=5)


def test_async_fan_out_runs_inside_running_event_loop(monkeypatch):
"""The parallel fan-out works even when the caller is already inside a
running event loop (Jupyter / async apps): the anyio blocking portal
Expand Down
Loading