Skip to content

Commit f2a2503

Browse files
nodohs and claude committed
fix(nwis): forward get_record state, fix major-filter validation, reject bad format
Several fixes in the deprecated `nwis` module, surfaced by the package review.

- `get_record(state=…)` was accepted and documented but never forwarded. It now reaches the request as the NWIS `stateCd` major filter; previously it was silently dropped, producing a confusing "Bad Request" when used without `sites`.
- The "must specify a major filter" validation was defeated: callers injected a filter key unconditionally (e.g. `kwargs["sites"] = kwargs.pop("sites", sites)` set the key to None when absent), so the membership-based guards in `query_waterservices`/`query_waterdata` always passed and a filterless request went out as a confusing "Bad Request" instead of the intended TypeError. The guards now require a filter that is present *and* non-None, rejecting an unset filter at the chokepoint — no per-getter pre-filtering needed. (`utils.query` already strips None-valued params, so this is purely a validation change.)
- `get_dv`/`get_iv`/`get_discharge_peaks`/`get_stats` each parse a fixed response body but passed `format=` explicitly alongside `**kwargs`, so e.g. `get_dv(sites=…, format="rdb")` raised "multiple values for 'format'". A caller-supplied non-native `format` is now rejected with a clear ValueError via the shared `_reject_unexpected_format` helper (json for dv/iv, rdb for peaks/stats).
- `get_ratings` is per-site; called without one it issued a request returning an unhelpful error page. It now fails fast with a clear TypeError (also covering `get_record(service="ratings")`).

Adds regression tests for each. Full nwis suite passes; ruff and mypy --strict clean.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent ecf2833 commit f2a2503

2 files changed

Lines changed: 154 additions & 4 deletions

File tree

dataretrieval/nwis.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,21 @@ def _parse_json_or_raise(response: httpx.Response) -> pd.DataFrame:
126126
raise
127127

128128

129+
def _reject_unexpected_format(
130+
kwargs: dict[str, Any], func_name: str, expected: str
131+
) -> None:
132+
"""Drop ``format`` from ``kwargs`` and reject any value other than ``expected``.
133+
134+
These getters always request a fixed ``format`` and parse that specific body,
135+
so a caller-supplied ``format`` would either collide with the explicit
136+
``format=`` argument or be silently overridden; reject it explicitly instead.
137+
"""
138+
if kwargs.pop("format", expected) != expected:
139+
raise ValueError(
140+
f"{func_name} returns {expected.upper()} and does not accept a `format`."
141+
)
142+
143+
129144
def format_response(
130145
df: pd.DataFrame, service: str | None = None, **kwargs: Any
131146
) -> pd.DataFrame:
@@ -277,6 +292,7 @@ def get_discharge_peaks(
277292
kwargs["end_date"] = kwargs.pop("end_date", end)
278293
kwargs["multi_index"] = multi_index
279294

295+
_reject_unexpected_format(kwargs, "get_discharge_peaks", "rdb")
280296
response = query_waterdata("peaks", format="rdb", ssl_check=ssl_check, **kwargs)
281297

282298
# Parse raw (read_rdb), not _read_rdb — the latter already runs
@@ -355,6 +371,8 @@ def get_stats(
355371
"""
356372
_check_sites_value_types(sites)
357373

374+
# get_stats parses an RDB body; reject a caller-supplied non-RDB `format`.
375+
_reject_unexpected_format(kwargs, "get_stats", "rdb")
358376
response = query_waterservices(
359377
service="stat", sites=sites, ssl_check=ssl_check, **kwargs
360378
)
@@ -392,11 +410,13 @@ def query_waterdata(
392410
"se_latitude_va",
393411
]
394412

395-
if not any(key in kwargs for key in major_params + bbox_params):
413+
# Present *and* non-None — see query_waterservices for why membership alone
414+
# would let an unset filter slip through.
415+
if not any(kwargs.get(key) is not None for key in major_params + bbox_params):
396416
raise TypeError("Query must specify a major filter: site_no, stateCd, bBox")
397417

398-
elif any(key in kwargs for key in bbox_params) and not all(
399-
key in kwargs for key in bbox_params
418+
elif any(kwargs.get(key) is not None for key in bbox_params) and not all(
419+
kwargs.get(key) is not None for key in bbox_params
400420
):
401421
raise TypeError("One or more lat/long coordinates missing or invalid.")
402422

@@ -453,8 +473,15 @@ def query_waterservices(
453473
The response object from the API request to the web service
454474
455475
"""
476+
# A major filter must be present *and* non-None. Membership alone is not
477+
# enough: callers may inject an unset filter (e.g. sites=None), and
478+
# utils.query() later strips None-valued params, so a plain `"sites" in
479+
# kwargs` would wave a filterless request through to a confusing "Bad
480+
# Request". Checking the value keeps that decision in one place instead of
481+
# forcing every getter to pre-filter its own kwargs.
456482
if not any(
457-
key in kwargs for key in ["sites", "stateCd", "bBox", "huc", "countyCd"]
483+
kwargs.get(key) is not None
484+
for key in ["sites", "stateCd", "bBox", "huc", "countyCd"]
458485
):
459486
raise TypeError(
460487
"Query must specify a major filter: sites, stateCd, bBox, huc, or countyCd"
@@ -537,6 +564,7 @@ def get_dv(
537564
kwargs["sites"] = kwargs.pop("sites", sites)
538565
kwargs["multi_index"] = multi_index
539566

567+
_reject_unexpected_format(kwargs, "get_dv", "json")
540568
response = query_waterservices("dv", format="json", ssl_check=ssl_check, **kwargs)
541569
df = _parse_json_or_raise(response)
542570

@@ -724,6 +752,7 @@ def get_iv(
724752
kwargs["sites"] = kwargs.pop("sites", sites)
725753
kwargs["multi_index"] = multi_index
726754

755+
_reject_unexpected_format(kwargs, "get_iv", "json")
727756
response = query_waterservices(
728757
service="iv", format="json", ssl_check=ssl_check, **kwargs
729758
)
@@ -788,6 +817,11 @@ def get_ratings(
788817
789818
"""
790819
site = kwargs.pop("site_no", site)
820+
# The ratings endpoint is per-site; without one it would issue a request
821+
# that returns an unhelpful error page. Fail fast with a clear message
822+
# (also covers get_record(service="ratings") called without a site).
823+
if site is None:
824+
raise TypeError("get_ratings requires a `site` (USGS site number).")
791825

792826
payload = {}
793827
url = WATERDATA_BASE_URL + "nwisweb/get_ratings/"
@@ -956,6 +990,12 @@ def get_record(
956990
if service not in WATERSERVICES_SERVICES + WATERDATA_SERVICES:
957991
raise TypeError(f"Unrecognized service: {service}")
958992

993+
# Forward the documented `state` filter as the NWIS `stateCd` major filter;
994+
# it was previously accepted but silently ignored, which produced a
995+
# confusing "Bad Request" when used without `sites`.
996+
if state is not None:
997+
kwargs.setdefault("stateCd", state)
998+
959999
if service == "iv":
9601000
df, _ = get_iv(
9611001
sites=sites,

tests/nwis_test.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,15 @@
1313
NWIS_Metadata,
1414
_read_rdb,
1515
get_discharge_measurements,
16+
get_discharge_peaks,
17+
get_dv,
1618
get_gwlevels,
1719
get_info,
1820
get_iv,
1921
get_pmcodes,
2022
get_qwdata,
2123
get_record,
24+
get_stats,
2225
get_water_use,
2326
preformat_peaks_response,
2427
what_sites,
@@ -120,6 +123,113 @@ def test_preformat_peaks_response():
120123
# Removed defunct gwlevels tests.
121124

122125

126+
def test_get_dv_requires_major_filter():
127+
"""Regression: get_dv() with no major filter must raise the documented
128+
TypeError. The getters injected ``kwargs["sites"] = kwargs.pop("sites",
129+
sites)``, which always set the key (None when absent) and so defeated
130+
query_waterservices' filter check, yielding a confusing Bad Request."""
131+
with warnings.catch_warnings():
132+
warnings.simplefilter("ignore")
133+
with pytest.raises(TypeError, match="major filter"):
134+
get_dv()
135+
136+
137+
def test_get_dv_rejects_non_json_format():
138+
"""Regression: get_dv passed format= explicitly alongside **kwargs, so
139+
get_dv(..., format="rdb") raised 'multiple values for format'. It now
140+
rejects a non-json format with a clear ValueError (it parses JSON)."""
141+
with warnings.catch_warnings():
142+
warnings.simplefilter("ignore")
143+
with pytest.raises(ValueError, match="JSON"):
144+
get_dv(sites="01646500", format="rdb")
145+
146+
147+
def test_get_record_forwards_state_as_statecd(monkeypatch):
148+
"""Regression: get_record's documented `state` arg was accepted but never
149+
forwarded. It now reaches the request as the NWIS `stateCd` major filter."""
150+
import dataretrieval.nwis as nwis_mod
151+
152+
captured: dict = {}
153+
154+
def fake_query_waterservices(service, format=None, ssl_check=True, **kw):
155+
captured.update(kw)
156+
raise RuntimeError("stop before network")
157+
158+
monkeypatch.setattr(nwis_mod, "query_waterservices", fake_query_waterservices)
159+
with warnings.catch_warnings():
160+
warnings.simplefilter("ignore")
161+
with pytest.raises(RuntimeError, match="stop"):
162+
get_record(state="OH", service="dv")
163+
assert captured.get("stateCd") == "OH"
164+
165+
166+
def test_get_stats_requires_major_filter():
167+
"""Regression: get_stats passed ``sites=sites`` explicitly, so the key was
168+
always present (None when absent) and defeated query_waterservices' filter
169+
check -- get_stats() reached the network and returned a confusing Bad
170+
Request instead of the documented TypeError."""
171+
with warnings.catch_warnings():
172+
warnings.simplefilter("ignore")
173+
with pytest.raises(TypeError, match="major filter"):
174+
get_stats()
175+
176+
177+
@pytest.mark.parametrize("service", ["stat", "site"])
178+
def test_get_record_requires_major_filter(service):
179+
"""Regression: get_record(service="stat"/"site") forwarded sites=None into
180+
get_stats/get_info, defeating the major-filter check. With no filter it must
181+
raise the documented TypeError rather than reach the network."""
182+
with warnings.catch_warnings():
183+
warnings.simplefilter("ignore")
184+
with pytest.raises(TypeError, match="major filter"):
185+
get_record(service=service)
186+
187+
188+
def test_get_discharge_peaks_rejects_format():
189+
"""Regression: get_discharge_peaks passed format="rdb" explicitly alongside
190+
**kwargs, so any caller-supplied format raised 'multiple values for format'.
191+
A non-native format is now rejected with a clear ValueError (it parses RDB)."""
192+
with warnings.catch_warnings():
193+
warnings.simplefilter("ignore")
194+
with pytest.raises(ValueError, match="RDB"):
195+
get_discharge_peaks(sites="01491000", format="json")
196+
197+
198+
def test_get_discharge_peaks_accepts_native_format(monkeypatch):
199+
"""The reported collision was that even the *native* format="rdb" raised
200+
'multiple values for format'. Popping it first resolves the collision, so
201+
format="rdb" now reaches the request rather than crashing."""
202+
import dataretrieval.nwis as nwis_mod
203+
204+
def fake_query_waterdata(service, format=None, ssl_check=True, **kw):
205+
raise RuntimeError("stop before network")
206+
207+
monkeypatch.setattr(nwis_mod, "query_waterdata", fake_query_waterdata)
208+
with warnings.catch_warnings():
209+
warnings.simplefilter("ignore")
210+
with pytest.raises(RuntimeError, match="stop"):
211+
get_discharge_peaks(sites="01491000", format="rdb")
212+
213+
214+
def test_get_stats_rejects_non_rdb_format():
215+
"""get_stats parses RDB; a caller-supplied non-RDB format would be requested
216+
and then mis-parsed. It is now rejected with a clear ValueError."""
217+
with warnings.catch_warnings():
218+
warnings.simplefilter("ignore")
219+
with pytest.raises(ValueError, match="RDB"):
220+
get_stats(sites="01646500", format="json")
221+
222+
223+
def test_get_ratings_requires_site():
224+
"""The ratings endpoint is per-site; get_ratings() / get_record(
225+
service="ratings") with no site previously issued a request that returned an
226+
unhelpful error page. It now fails fast with a clear TypeError."""
227+
with warnings.catch_warnings():
228+
warnings.simplefilter("ignore")
229+
with pytest.raises(TypeError, match="requires a `site`"):
230+
get_record(service="ratings")
231+
232+
123233
class TestDeprecationWarnings:
124234
"""Verify per-function DeprecationWarning fires with the right replacement.
125235

0 commit comments

Comments
 (0)