Commit c755f6b

Validate monitoring_location_id format in waterdata functions (#229)
Validate and normalize monitoring_location_id (and other multi-value filters) at the boundary

Closes #188. Catches a class of bug that previously produced silent zero-result responses or confusing JSONDecodeErrors when callers passed malformed `monitoring_location_id` values to the WaterData OGC API getters.

Behavior changes (every public waterdata function accepting `monitoring_location_id`):

- `monitoring_location_id` is now validated client-side. Non-string, non-iterable inputs (`int`, `dict`, …) raise `TypeError`; strings that don't match the AGENCY-ID hyphen-separated form (e.g. `"USGS-01646500"`) raise `ValueError`; non-string elements inside an iterable raise `TypeError`. A clear "Expected 'AGENCY-ID' format, e.g. 'USGS-01646500'" hint is appended in every error.
- Other multi-value string filters (`parameter_code`, `statistic_id`, `agency_code`, etc.) now accept any non-string iterable of strings: `list`, `tuple`, `pandas.Series`, `pandas.Index`, `numpy.ndarray`, and generators are all materialized to `list[str]` before the URL is built. A bare single string is still accepted unchanged.
- `properties` additionally accepts a single string and wraps it into a one-element list, since `",".join(...)` would otherwise iterate it as characters.
- `list[int]` filters on `get_peaks` (`water_year`, `year`, `month`, `day`, `peak_since`) and `list[float]` on `get_combined_metadata` (`thresholds`) pass through untouched.
- `_format_api_dates` now rejects `Mapping` inputs (which previously silently materialized as the keys list) and `None` is short-circuited up front.

Public API:

- Type annotations widened to `str | Iterable[str] | None` (or `str | list[str] | None` where only a list is meaningful) across all affected functions. Numpydoc parameter descriptions updated from "list of strings" to "iterable of strings" to match.
- Coverage verified on all 15 public functions that accept `monitoring_location_id`: 12 in `waterdata/api.py` (via centralized `_get_args`), `get_ratings` in `waterdata/ratings.py` (direct call), and `get_nearest_continuous` in `waterdata/nearest.py` (transitively via `get_continuous`).

Internals:

- `_normalize_str_iterable`: one O(N) walk validates element types and materializes to `list`. Generic, used by every string-filter param.
- `_check_monitoring_location_id`: composes `_normalize_str_iterable` with per-element `_check_id_format` (regex `[^-\s]+-[^-\s]+`, fullmatch).
- `_get_args`: single dispatch point that runs the right normalizer per param name. New `get_*` functions inherit validation automatically.
- `_DATE_RANGE_PARAMS`: shared constant covering `time`, `datetime`, `last_modified`, `begin`, `begin_utc`, `end`, `end_utc`; used by both `_construct_api_requests`'s date formatting and `_get_args`'s bypass.

Tests:

- Live-verified against the USGS OGC + STAC APIs with ~70 stress cases covering every iterable shape, every public function with mloc, and every rejection path.
- 100+ unit tests including regressions for: int-list parameter rejection, Series-of-IDs materialization, Mapping rejection on date inputs, AGENCY-ID format edge cases (trailing/leading hyphen, embedded comma, whitespace, multi-hyphen).
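The AGENCY-ID rule listed under Internals can be tried in isolation. The sketch below compiles the pattern quoted in the commit message and applies `fullmatch`, as the message describes; the names `AGENCY_ID_RE` and `is_agency_id` are hypothetical and exist only for this illustration, not in the library.

```python
import re

# Pattern quoted in the commit message, compiled standalone for illustration.
AGENCY_ID_RE = re.compile(r"[^-\s]+-[^-\s]+")


def is_agency_id(value: str) -> bool:
    """Hypothetical helper: True if ``value`` fully matches the pattern."""
    return AGENCY_ID_RE.fullmatch(value) is not None


if __name__ == "__main__":
    candidates = [
        "USGS-01646500",   # well-formed
        "01646500",        # no hyphen
        "USGS-",           # trailing hyphen
        "-01646500",       # leading hyphen
        "USGS-0164-6500",  # more than one hyphen
        " USGS-01646500",  # leading whitespace
    ]
    for candidate in candidates:
        print(f"{candidate!r}: {is_agency_id(candidate)}")
```

Only the first candidate is accepted. Because `fullmatch` anchors both ends, a string that merely contains a valid-looking substring is still rejected.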
1 parent da49749 commit c755f6b

6 files changed

Lines changed: 815 additions & 456 deletions

dataretrieval/waterdata/api.py

Lines changed: 428 additions & 432 deletions
Large diffs are not rendered by default.

dataretrieval/waterdata/nearest.py

Lines changed: 5 additions & 4 deletions
@@ -5,6 +5,7 @@

 from __future__ import annotations

+from collections.abc import Iterable
 from typing import Literal, get_args

 import pandas as pd
@@ -18,8 +19,8 @@

 def get_nearest_continuous(
     targets,
-    monitoring_location_id: str | list[str] | None = None,
-    parameter_code: str | list[str] | None = None,
+    monitoring_location_id: str | Iterable[str] | None = None,
+    parameter_code: str | Iterable[str] | None = None,
     *,
     window: str | pd.Timedelta = "PT7M30S",
     on_tie: OnTie = "first",
@@ -44,9 +45,9 @@ def get_nearest_continuous(
         Target timestamps. Naive datetimes are treated as UTC. Accepts a
         list, ``pandas.Series``, ``pandas.DatetimeIndex``, ``numpy.ndarray``,
         or anything ``pandas.to_datetime`` consumes.
-    monitoring_location_id : string or list of strings, optional
+    monitoring_location_id : string or iterable of strings, optional
         Forwarded to ``get_continuous``.
-    parameter_code : string or list of strings, optional
+    parameter_code : string or iterable of strings, optional
         Forwarded to ``get_continuous``.
     window : string or ``pandas.Timedelta``, default ``"PT7M30S"``
         Half-window around each target, as an ISO 8601 duration

dataretrieval/waterdata/ratings.py

Lines changed: 10 additions & 3 deletions
@@ -22,7 +22,13 @@

 from dataretrieval.rdb import extract_rdb_comment, read_rdb

-from .utils import _DURATION_RE, BASE_URL, _default_headers, _format_api_dates
+from .utils import (
+    _DURATION_RE,
+    BASE_URL,
+    _check_monitoring_location_id,
+    _default_headers,
+    _format_api_dates,
+)

 logger = logging.getLogger(__name__)

@@ -33,7 +39,7 @@


 def get_ratings(
-    monitoring_location_id: str | list[str] | None = None,
+    monitoring_location_id: str | Iterable[str] | None = None,
     file_type: RATING_FILE_TYPE | list[RATING_FILE_TYPE] = "exsa",
     file_path: str | None = None,
     time: str | list[str] | None = None,
@@ -62,7 +68,7 @@ def get_ratings(

     Parameters
     ----------
-    monitoring_location_id : string or list of strings, optional
+    monitoring_location_id : string or iterable of strings, optional
         One or more identifiers in ``AGENCY-ID`` form (e.g.
         ``"USGS-01104475"``). If omitted, the spatial / temporal filters
         determine the result set.
@@ -142,6 +148,7 @@ def get_ratings(
     ... )

     """
+    monitoring_location_id = _check_monitoring_location_id(monitoring_location_id)
     file_types = _as_list(file_type)
     invalid = [ft for ft in file_types if ft not in _VALID_FILE_TYPES]
     if invalid:

dataretrieval/waterdata/utils.py

Lines changed: 167 additions & 12 deletions
@@ -4,6 +4,7 @@
 import logging
 import os
 import re
+from collections.abc import Iterable, Mapping
 from datetime import datetime
 from typing import Any, get_args
 from zoneinfo import ZoneInfo
@@ -143,6 +144,15 @@ def _switch_properties_id(properties: list[str] | None, id_name: str, service: s
 # admits time-only forms like ``PT36H``.
 _DURATION_RE = re.compile(r"^[Pp]T?\d")

+# OGC API parameters that carry a date/datetime value (single string,
+# two-element range, or interval/duration string) rather than a multi-value
+# string list. Used by ``_construct_api_requests`` to keep them out of the
+# POST/CQL2 multi-value path and to route them through ``_format_api_dates``,
+# and by ``_NO_NORMALIZE_PARAMS`` to bypass string-iterable normalization.
+_DATE_RANGE_PARAMS = frozenset(
+    {"datetime", "last_modified", "begin", "begin_utc", "end", "end_utc", "time"}
+)
+

 def _parse_datetime(value: str) -> datetime | None:
     """Parse a single datetime string against the supported formats.
@@ -223,12 +233,24 @@ def _format_api_dates(
     converted from that offset to UTC; naive inputs are interpreted in the
     local time zone for backwards compatibility.
     """
+    if datetime_input is None:
+        return None
     # Get timezone
     local_timezone = datetime.now().astimezone().tzinfo

     # Convert single string to list for uniform processing
     if isinstance(datetime_input, str):
         datetime_input = [datetime_input]
+    elif isinstance(datetime_input, Mapping):
+        # `list(mapping)` returns keys, which silently accepts the wrong shape.
+        raise TypeError(
+            f"date input must be a string or sequence of strings, "
+            f"not {type(datetime_input).__name__}."
+        )
+    elif not isinstance(datetime_input, (list, tuple)):
+        # Materialize any other iterable (pandas.Series, numpy.ndarray,
+        # generator, ...) so the len()/subscript operations below work.
+        datetime_input = list(datetime_input)

     # Check for null or all NA and return None
     if all(pd.isna(dt) or dt == "" or dt is None for dt in datetime_input):
@@ -429,14 +451,11 @@ def _construct_api_requests(
     """
     service_url = f"{OGC_API_URL}/collections/{service}/items"

-    # Single parameters can only have one value
-    single_params = {"datetime", "last_modified", "begin", "end", "time"}
-
     # Identify which parameters should be included in the POST content body
     post_params = {
         k: v
         for k, v in kwargs.items()
-        if k not in single_params and isinstance(v, (list, tuple)) and len(v) > 1
+        if k not in _DATE_RANGE_PARAMS and isinstance(v, (list, tuple)) and len(v) > 1
     }

     # Everything else goes into the params dictionary for the URL
@@ -452,15 +471,13 @@ def _construct_api_requests(
     POST = bool(post_params)

     # Convert dates to ISO08601 format
-    time_periods = {"last_modified", "datetime", "time", "begin", "end"}
-    for i in time_periods:
+    for i in _DATE_RANGE_PARAMS:
         if i in params:
             dates = service == "daily" and i != "last_modified"
             params[i] = _format_api_dates(params[i], date=dates)

-    # String together bbox elements from a list to a comma-separated string,
-    # and string together properties if provided
-    if bbox:
+    # `len()` instead of truthiness: a numpy ndarray would raise on `if bbox:`.
+    if bbox is not None and len(bbox) > 0:
         params["bbox"] = ",".join(map(str, bbox))
     if properties:
         params["properties"] = ",".join(properties)
@@ -1168,6 +1185,129 @@ def _check_profiles(
             )


+_MONITORING_LOCATION_ID_RE = re.compile(r"[^-\s]+-[^-\s]+")
+
+
+# Iterable-shaped params that ``_get_args`` must NOT push through
+# ``_normalize_str_iterable`` (scalar non-string knobs are caught by runtime
+# type, so only iterables with special handling need to be named here):
+# - date-range params may contain ``pd.NaT``/None or interval strings
+# - ``bbox``/``boundingBox`` are ``list[float]``, sometimes ``numpy.ndarray``
+# - ``get_peaks``'s int-valued filters (``water_year`` etc.) are ``list[int]``
+# - ``get_combined_metadata``'s ``thresholds`` is ``list[float]``
+_NO_NORMALIZE_PARAMS = _DATE_RANGE_PARAMS | {
+    "bbox",
+    "boundingBox",
+    "water_year",
+    "year",
+    "month",
+    "day",
+    "peak_since",
+    "thresholds",
+}
+
+
+def _normalize_str_iterable(
+    value: str | Iterable[str] | None,
+    param_name: str = "value",
+) -> str | list[str] | None:
+    """Validate that ``value`` is None, a string, or an iterable of strings.
+
+    Non-string iterables (``list``, ``tuple``, ``pandas.Series``,
+    ``pandas.Index``, ``numpy.ndarray``, generators) are materialized to a
+    ``list`` so downstream code that branches on ``isinstance(v, (list,
+    tuple))`` keeps working. ``Mapping`` types are rejected because
+    iterating a mapping yields keys, not values.
+
+    Parameters
+    ----------
+    value : None, str, or iterable of str
+    param_name : str, optional
+        Used in error messages. Defaults to ``"value"``.
+
+    Returns
+    -------
+    None, str, or list of str
+
+    Raises
+    ------
+    TypeError
+        If the input isn't ``None``, ``str``, or a non-``Mapping``
+        iterable; or if any iterable element isn't a string.
+    """
+    if value is None:
+        return None
+    if isinstance(value, str):
+        return value
+    if isinstance(value, Mapping) or not isinstance(value, Iterable):
+        raise TypeError(
+            f"{param_name} must be a string or iterable of strings, "
+            f"not {type(value).__name__} (got {value!r})."
+        )
+    values: list[str] = []
+    for v in value:
+        if not isinstance(v, str):
+            raise TypeError(
+                f"{param_name} elements must be strings, "
+                f"not {type(v).__name__} (got {v!r})."
+            )
+        values.append(v)
+    return values
+
+
+def _check_monitoring_location_id(
+    monitoring_location_id: str | Iterable[str] | None,
+) -> str | list[str] | None:
+    """Validate and normalize a ``monitoring_location_id`` value.
+
+    Combines :func:`_normalize_str_iterable` with the AGENCY-ID format
+    check that is unique to ``monitoring_location_id`` (the OGC spec
+    requires a hyphen separator, e.g. ``USGS-01646500``).
+
+    Parameters
+    ----------
+    monitoring_location_id : None, str, or iterable of str
+        See :func:`_normalize_str_iterable`. Each string is additionally
+        required to match the AGENCY-ID hyphen-separated format.
+
+    Returns
+    -------
+    None, str, or list of str
+
+    Raises
+    ------
+    TypeError
+        If the input isn't ``None``, ``str``, or a non-``Mapping``
+        iterable; or if any iterable element isn't a string.
+    ValueError
+        If any identifier doesn't contain a hyphen separator
+        (per the OGC API spec: AGENCY-ID format, e.g. ``USGS-01646500``).
+    """
+    try:
+        value = _normalize_str_iterable(
+            monitoring_location_id, "monitoring_location_id"
+        )
+    except TypeError as exc:
+        # Re-raise with the AGENCY-ID hint the generic helper doesn't carry.
+        raise TypeError(
+            f"{exc} Expected 'AGENCY-ID' format, e.g., 'USGS-01646500'."
+        ) from None
+    if value is None:
+        return None
+    for item in (value,) if isinstance(value, str) else value:
+        _check_id_format(item)
+    return value
+
+
+def _check_id_format(value: str) -> None:
+    """Raise ``ValueError`` if ``value`` is not in ``AGENCY-ID`` format."""
+    if not _MONITORING_LOCATION_ID_RE.fullmatch(value):
+        raise ValueError(
+            f"Invalid monitoring_location_id: {value!r}. "
+            f"Expected 'AGENCY-ID' format, e.g., 'USGS-01646500'."
+        )
+
+
 def _get_args(
     local_vars: dict[str, Any], exclude: set[str] | None = None
 ) -> dict[str, Any]:
@@ -1194,6 +1334,21 @@ def _get_args(
     if exclude:
         to_exclude.update(exclude)

-    return {
-        k: v for k, v in local_vars.items() if k not in to_exclude and v is not None
-    }
+    args: dict[str, Any] = {}
+    for k, v in local_vars.items():
+        if k in to_exclude or v is None:
+            continue
+        if k == "monitoring_location_id":
+            args[k] = _check_monitoring_location_id(v)
+        elif k == "properties":
+            # `",".join(properties)` would iterate a bare string as characters.
+            args[k] = [v] if isinstance(v, str) else _normalize_str_iterable(v, k)
+        elif (
+            k in _NO_NORMALIZE_PARAMS
+            or isinstance(v, str)
+            or not isinstance(v, Iterable)
+        ):
+            args[k] = v
+        else:
+            args[k] = _normalize_str_iterable(v, k)
+    return args
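
The pitfalls named in the diff comments (a bare string joined character by character, a mapping iterated as keys, a generator that can be walked only once) can be reproduced with the standard library alone. The sketch below is a simplified illustration, not the library code: `wrap_properties` and `materialize` are hypothetical stand-ins for the `properties` branch of `_get_args` and the single-pass walk in `_normalize_str_iterable`.

```python
from collections.abc import Iterable, Mapping


def wrap_properties(value):
    """Hypothetical stand-in: wrap a bare string so join() sees one item."""
    return [value] if isinstance(value, str) else list(value)


def materialize(values: Iterable[str]) -> list[str]:
    """Hypothetical one-pass walk: type-check each element while copying it."""
    if isinstance(values, Mapping):
        raise TypeError("mappings iterate as keys, not values")
    out: list[str] = []
    for v in values:
        if not isinstance(v, str):
            raise TypeError(f"elements must be strings, got {v!r}")
        out.append(v)
    return out


if __name__ == "__main__":
    # A bare string is iterable, so join() splits it into characters.
    print(",".join("time"))                   # t,i,m,e
    print(",".join(wrap_properties("time")))  # time

    # list() on a mapping keeps only the keys.
    print(list({"begin": "2024-01-01"}))      # ['begin']

    # A generator is exhausted after one pass; the returned list is reusable.
    codes = materialize(code for code in ("00060", "00065"))
    print(codes, len(codes))                  # ['00060', '00065'] 2
```

Validating and copying in the same loop matters for generators: a separate validation pass would consume the values before they could be collected.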
