Skip to content

Commit b0f2289

Browse files
committed
Validate monitoring_location_id in get_ratings and widen annotations in nearest
Copilot found that `get_ratings` accepts `monitoring_location_id` and documents the same AGENCY-ID contract, but builds its STAC filter directly without routing through `_get_args` — so the centralized validation never ran. Call `_check_monitoring_location_id` at the top of `get_ratings` and widen the annotation/docstring to `Iterable[str]` for consistency. `get_nearest_continuous` inherits validation via its forwarded call to `get_continuous`, so its behavior is already correct — but its annotation and docstring still advertised `list[str]`. Widen both for parity.
1 parent 0cf981e commit b0f2289

2 files changed

Lines changed: 15 additions & 7 deletions

File tree

dataretrieval/waterdata/nearest.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from __future__ import annotations
77

8+
from collections.abc import Iterable
89
from typing import Literal, get_args
910

1011
import pandas as pd
@@ -18,8 +19,8 @@
1819

1920
def get_nearest_continuous(
2021
targets,
21-
monitoring_location_id: str | list[str] | None = None,
22-
parameter_code: str | list[str] | None = None,
22+
monitoring_location_id: str | Iterable[str] | None = None,
23+
parameter_code: str | Iterable[str] | None = None,
2324
*,
2425
window: str | pd.Timedelta = "PT7M30S",
2526
on_tie: OnTie = "first",
@@ -44,9 +45,9 @@ def get_nearest_continuous(
4445
Target timestamps. Naive datetimes are treated as UTC. Accepts a
4546
list, ``pandas.Series``, ``pandas.DatetimeIndex``, ``numpy.ndarray``,
4647
or anything ``pandas.to_datetime`` consumes.
47-
monitoring_location_id : string or list of strings, optional
48+
monitoring_location_id : string or iterable of strings, optional
4849
Forwarded to ``get_continuous``.
49-
parameter_code : string or list of strings, optional
50+
parameter_code : string or iterable of strings, optional
5051
Forwarded to ``get_continuous``.
5152
window : string or ``pandas.Timedelta``, default ``"PT7M30S"``
5253
Half-window around each target, as an ISO 8601 duration

dataretrieval/waterdata/ratings.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,13 @@
2222

2323
from dataretrieval.rdb import extract_rdb_comment, read_rdb
2424

25-
from .utils import _DURATION_RE, BASE_URL, _default_headers, _format_api_dates
25+
from .utils import (
26+
_DURATION_RE,
27+
BASE_URL,
28+
_check_monitoring_location_id,
29+
_default_headers,
30+
_format_api_dates,
31+
)
2632

2733
logger = logging.getLogger(__name__)
2834

@@ -33,7 +39,7 @@
3339

3440

3541
def get_ratings(
36-
monitoring_location_id: str | list[str] | None = None,
42+
monitoring_location_id: str | Iterable[str] | None = None,
3743
file_type: RATING_FILE_TYPE | list[RATING_FILE_TYPE] = "exsa",
3844
file_path: str | None = None,
3945
time: str | list[str] | None = None,
@@ -62,7 +68,7 @@ def get_ratings(
6268
6369
Parameters
6470
----------
65-
monitoring_location_id : string or list of strings, optional
71+
monitoring_location_id : string or iterable of strings, optional
6672
One or more identifiers in ``AGENCY-ID`` form (e.g.
6773
``"USGS-01104475"``). If omitted, the spatial / temporal filters
6874
determine the result set.
@@ -142,6 +148,7 @@ def get_ratings(
142148
... )
143149
144150
"""
151+
monitoring_location_id = _check_monitoring_location_id(monitoring_location_id)
145152
file_types = _as_list(file_type)
146153
invalid = [ft for ft in file_types if ft not in _VALID_FILE_TYPES]
147154
if invalid:

0 commit comments

Comments
 (0)