Skip to content

Commit fa8387e

Browse files
thodson-usgsclaude
andcommitted
Widen _check_monitoring_location_id to accept iterables of strings
The previous `(str, list)` isinstance check rejected legitimate inputs: tuple, pandas.Series, pandas.Index, numpy.ndarray, generators. Pre-fix these would round-trip through requests and either work (tuple) or silently break the URL (numpy/pandas). Post-fix the function: - Accepts any non-string iterable whose elements are strings - Materializes the iterable to a list so downstream comma-join / POST-CQL2 logic in _construct_api_requests keeps working uniformly - Returns the (possibly-normalized) value, so callers reassign: monitoring_location_id = _check_monitoring_location_id(monitoring_location_id) - Rejects Mapping (e.g. dict) explicitly — iterating a dict yields keys, which is a footgun Live-verified against api.waterdata.usgs.gov: passing tuple, pd.Series, pd.Index, and np.ndarray of "USGS-01646500" all return 3 rows for the 2024-06-01/2024-06-03 window. Passing pd.Series([1646500]) raises TypeError("monitoring_location_id elements must be strings, not int...") before any network call. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent c6b262e commit fa8387e

3 files changed

Lines changed: 93 additions & 33 deletions

File tree

dataretrieval/waterdata/api.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def get_daily(
214214
... approval_status = "Approved",
215215
... time = "2024-01-01/.."
216216
"""
217-
_check_monitoring_location_id(monitoring_location_id)
217+
monitoring_location_id = _check_monitoring_location_id(monitoring_location_id)
218218
service = "daily"
219219
output_id = "daily_id"
220220

@@ -403,7 +403,7 @@ def get_continuous(
403403
... filter_lang="cql-text",
404404
... )
405405
"""
406-
_check_monitoring_location_id(monitoring_location_id)
406+
monitoring_location_id = _check_monitoring_location_id(monitoring_location_id)
407407
service = "continuous"
408408
output_id = "continuous_id"
409409

@@ -702,7 +702,7 @@ def get_monitoring_locations(
702702
... properties=["monitoring_location_id", "state_name", "country_name"],
703703
... )
704704
"""
705-
_check_monitoring_location_id(monitoring_location_id)
705+
monitoring_location_id = _check_monitoring_location_id(monitoring_location_id)
706706
service = "monitoring-locations"
707707
output_id = "monitoring_location_id"
708708

@@ -926,7 +926,7 @@ def get_time_series_metadata(
926926
... begin="1990-01-01/..",
927927
... )
928928
"""
929-
_check_monitoring_location_id(monitoring_location_id)
929+
monitoring_location_id = _check_monitoring_location_id(monitoring_location_id)
930930
service = "time-series-metadata"
931931
output_id = "time_series_id"
932932

@@ -1106,7 +1106,7 @@ def get_latest_continuous(
11061106
... monitoring_location_id=["USGS-05114000", "USGS-09423350"]
11071107
... )
11081108
"""
1109-
_check_monitoring_location_id(monitoring_location_id)
1109+
monitoring_location_id = _check_monitoring_location_id(monitoring_location_id)
11101110
service = "latest-continuous"
11111111
output_id = "latest_continuous_id"
11121112

@@ -1288,7 +1288,7 @@ def get_latest_daily(
12881288
... monitoring_location_id=["USGS-05114000", "USGS-09423350"]
12891289
... )
12901290
"""
1291-
_check_monitoring_location_id(monitoring_location_id)
1291+
monitoring_location_id = _check_monitoring_location_id(monitoring_location_id)
12921292
service = "latest-daily"
12931293
output_id = "latest_daily_id"
12941294

@@ -1469,7 +1469,7 @@ def get_field_measurements(
14691469
... time = "P20Y"
14701470
... )
14711471
"""
1472-
_check_monitoring_location_id(monitoring_location_id)
1472+
monitoring_location_id = _check_monitoring_location_id(monitoring_location_id)
14731473
service = "field-measurements"
14741474
output_id = "field_measurement_id"
14751475

@@ -1925,7 +1925,7 @@ def get_stats_por(
19251925
... )
19261926
"""
19271927
# Build argument dictionary, omitting None values
1928-
_check_monitoring_location_id(monitoring_location_id)
1928+
monitoring_location_id = _check_monitoring_location_id(monitoring_location_id)
19291929
params = _get_args(locals(), exclude={"expand_percentiles"})
19301930

19311931
return get_stats_data(
@@ -2055,7 +2055,7 @@ def get_stats_date_range(
20552055
... )
20562056
"""
20572057
# Build argument dictionary, omitting None values
2058-
_check_monitoring_location_id(monitoring_location_id)
2058+
monitoring_location_id = _check_monitoring_location_id(monitoring_location_id)
20592059
params = _get_args(locals(), exclude={"expand_percentiles"})
20602060

20612061
return get_stats_data(
@@ -2228,7 +2228,7 @@ def get_channel(
22282228
... monitoring_location_id="USGS-02238500",
22292229
... )
22302230
"""
2231-
_check_monitoring_location_id(monitoring_location_id)
2231+
monitoring_location_id = _check_monitoring_location_id(monitoring_location_id)
22322232
service = "channel-measurements"
22332233
output_id = "channel_measurements_id"
22342234

dataretrieval/waterdata/utils.py

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import logging
55
import os
66
import re
7+
from collections.abc import Iterable, Mapping
78
from datetime import datetime
89
from typing import Any, get_args
910

@@ -1155,44 +1156,53 @@ def _check_profiles(
11551156
_MONITORING_LOCATION_ID_RE = re.compile(r"^.+-.+$")
11561157

11571158

1158-
def _check_monitoring_location_id(
1159-
monitoring_location_id: str | list[str] | None,
1160-
) -> None:
1161-
"""Validate the format of a monitoring_location_id value.
1159+
def _check_monitoring_location_id(monitoring_location_id):
1160+
"""Validate and normalize a ``monitoring_location_id`` value.
11621161
11631162
Parameters
11641163
----------
1165-
monitoring_location_id : str, list of str, or None
1166-
One or more monitoring location identifiers.
1164+
monitoring_location_id : None, str, or iterable of str
1165+
Accepts ``None``, a single string, or any non-string iterable of
1166+
strings (``list``, ``tuple``, ``pandas.Series``, ``pandas.Index``,
1167+
``numpy.ndarray``, ...). Iterables are materialized to a ``list``
1168+
so downstream code that branches on ``isinstance(v, list)`` keeps
1169+
working.
1170+
1171+
Returns
1172+
-------
1173+
None, str, or list of str
1174+
``None`` and ``str`` inputs are returned unchanged; non-string
1175+
iterables are returned as a ``list``.
11671176
11681177
Raises
11691178
------
11701179
TypeError
1171-
If any identifier is not a string (e.g. an integer was passed).
1180+
If the input is not ``None``, a string, or an iterable, or if any
1181+
iterable element is not a string.
11721182
ValueError
1173-
If any string identifier does not follow the required
1174-
``'AGENCY-ID'`` format (e.g. ``'USGS-01646500'``).
1183+
If any identifier doesn't contain a hyphen separator
1184+
(per the OGC API spec: AGENCY-ID format, e.g. ``USGS-01646500``).
11751185
"""
11761186
if monitoring_location_id is None:
1177-
return
1187+
return None
11781188

1179-
if not isinstance(monitoring_location_id, (str, list)):
1189+
if isinstance(monitoring_location_id, str):
1190+
ids = [monitoring_location_id]
1191+
elif isinstance(monitoring_location_id, Iterable) and not isinstance(
1192+
monitoring_location_id, Mapping
1193+
):
1194+
ids = list(monitoring_location_id)
1195+
else:
11801196
raise TypeError(
1181-
f"monitoring_location_id must be a string or list of strings, "
1197+
f"monitoring_location_id must be a string or iterable of strings, "
11821198
f"not {type(monitoring_location_id).__name__}. "
11831199
f"Expected format: 'AGENCY-ID', e.g., 'USGS-{monitoring_location_id}'."
11841200
)
11851201

1186-
ids = (
1187-
[monitoring_location_id]
1188-
if isinstance(monitoring_location_id, str)
1189-
else monitoring_location_id
1190-
)
1191-
11921202
for id_ in ids:
11931203
if not isinstance(id_, str):
11941204
raise TypeError(
1195-
f"monitoring_location_id must be a string or list of strings, "
1205+
f"monitoring_location_id elements must be strings, "
11961206
f"not {type(id_).__name__}. "
11971207
f"Expected format: 'AGENCY-ID', e.g., 'USGS-{id_}'."
11981208
)
@@ -1202,6 +1212,8 @@ def _check_monitoring_location_id(
12021212
f"Expected 'AGENCY-ID' format, e.g., 'USGS-01646500'."
12031213
)
12041214

1215+
return monitoring_location_id if isinstance(monitoring_location_id, str) else ids
1216+
12051217

12061218
def _get_args(
12071219
local_vars: dict[str, Any], exclude: set[str] | None = None

tests/waterdata_test.py

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import datetime
22
import sys
33

4+
import pandas as pd
45
import pytest
56
from pandas import DataFrame
67

@@ -389,16 +390,17 @@ class TestCheckMonitoringLocationId:
389390
"""
390391

391392
def test_valid_string(self):
392-
"""A correctly formatted string passes without error."""
393-
_check_monitoring_location_id("USGS-01646500")
393+
"""A correctly formatted string passes and is returned unchanged."""
394+
assert _check_monitoring_location_id("USGS-01646500") == "USGS-01646500"
394395

395396
def test_valid_list(self):
396397
"""A list of correctly formatted strings passes without error."""
397-
_check_monitoring_location_id(["USGS-01646500", "USGS-02238500"])
398+
ids = ["USGS-01646500", "USGS-02238500"]
399+
assert _check_monitoring_location_id(ids) == ids
398400

399401
def test_none_passes(self):
400402
"""None is allowed (optional parameter)."""
401-
_check_monitoring_location_id(None)
403+
assert _check_monitoring_location_id(None) is None
402404

403405
def test_integer_raises_type_error(self):
404406
"""An integer ID raises TypeError with a helpful message."""
@@ -425,6 +427,52 @@ def test_get_daily_integer_id_raises(self):
425427
with pytest.raises(TypeError):
426428
get_daily(monitoring_location_id=5129115, parameter_code="00060")
427429

430+
def test_tuple_normalizes_to_list(self):
431+
"""A tuple of valid strings is accepted and normalized to list."""
432+
result = _check_monitoring_location_id(("USGS-01646500", "USGS-02238500"))
433+
assert result == ["USGS-01646500", "USGS-02238500"]
434+
assert isinstance(result, list)
435+
436+
def test_pandas_series_normalizes_to_list(self):
437+
"""A pandas.Series of valid strings is accepted and normalized to list."""
438+
s = pd.Series(["USGS-01646500", "USGS-02238500"])
439+
result = _check_monitoring_location_id(s)
440+
assert result == ["USGS-01646500", "USGS-02238500"]
441+
assert isinstance(result, list)
442+
443+
def test_pandas_index_normalizes_to_list(self):
444+
"""A pandas.Index of valid strings is accepted and normalized to list."""
445+
idx = pd.Index(["USGS-01646500", "USGS-02238500"])
446+
result = _check_monitoring_location_id(idx)
447+
assert result == ["USGS-01646500", "USGS-02238500"]
448+
assert isinstance(result, list)
449+
450+
def test_numpy_array_normalizes_to_list(self):
451+
"""A numpy.ndarray of valid strings is accepted and normalized to list."""
452+
import numpy as np
453+
454+
arr = np.array(["USGS-01646500", "USGS-02238500"])
455+
result = _check_monitoring_location_id(arr)
456+
assert result == ["USGS-01646500", "USGS-02238500"]
457+
assert isinstance(result, list)
458+
459+
def test_numpy_int_array_raises_type_error(self):
460+
"""An iterable whose elements aren't strings (numpy int array) raises."""
461+
import numpy as np
462+
463+
with pytest.raises(TypeError, match="elements must be strings"):
464+
_check_monitoring_location_id(np.array([1, 2, 3]))
465+
466+
def test_pandas_series_of_ints_raises_type_error(self):
467+
"""An iterable whose elements aren't strings (Series of ints) raises."""
468+
with pytest.raises(TypeError, match="elements must be strings"):
469+
_check_monitoring_location_id(pd.Series([1, 2, 3]))
470+
471+
def test_dict_raises_type_error(self):
472+
"""Mappings are rejected — iterating a dict yields keys, which is a footgun."""
473+
with pytest.raises(TypeError, match="not dict"):
474+
_check_monitoring_location_id({"USGS-01646500": "site"})
475+
428476
def test_get_daily_malformed_id_raises(self):
429477
"""get_daily raises ValueError for a malformed string ID."""
430478
with pytest.raises(ValueError):

0 commit comments

Comments
 (0)