Commit 9d5fbd7

thodson-usgs and claude committed
chore(typing): set up mypy and fix the type errors it surfaces
The package ships a ``py.typed`` marker (advertising itself as typed to downstream users) but nothing type-checked it. Add mypy and get a clean run.

Setup:
- [tool.mypy] in pyproject.toml: a lenient first-pass config (ignore_missing_imports, target python_version 3.9), scoped to the dataretrieval package.
- mypy<2 added to the [test] extra (<2 so it can still target 3.9).
- a type-check job in the CI workflow, parallel to the ruff lint job.

Fixes (mypy went from 78 errors to 0 on the tracked package):
- HTTPX_DEFAULTS annotated dict[str, Any] so **-splatting it into httpx.get / httpx.AsyncClient type-checks -- cleared ~55 errors across 7 call sites at once.
- utils.py gains `from __future__ import annotations`: mypy (targeting 3.9) caught that the new `str | None` annotations there would be a runtime error on 3.9, because this module -- unlike the rest of the package -- lacked the future import.
- BaseMetadata.comment annotated `str | None` (was inferred `None`, which rejected every subclass that assigns a comment string).
- _format_api_dates: accept Sequence[str | None] (covariant) so a list[str] caller type-checks, and build the formatted list with an early return so the final join sees list[str].
- _as_str_list: delegate to _normalize_str_iterable then wrap, so the declared return type list[str] | None holds.
- _next_req_url: declare next_host / cur_host as `str | None`.
- ratings._search: build the query dict in a non-Optional local before aliasing it to the loop's `params` (which toggles to None per page).
- nldi: drop the bool->str / Literal->str variable reuse; guard the basin branch so feature_source / feature_id are non-None before get_basin.
- chunking: narrow the optional filter before _is_chunkable; fix a stale `# type: ignore` error code.

The fixes are annotations and type-narrowing guards. The only runtime-visible change is that nldi.search() now raises a clear ValueError up front when a basin search is missing feature_source/feature_id, where the same condition previously raised deeper inside get_basin.

259 tests pass across the affected suites.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent 665383f commit 9d5fbd7

7 files changed

Lines changed: 71 additions & 24 deletions

.github/workflows/python-package.yml

Lines changed: 16 additions & 0 deletions
@@ -26,6 +26,22 @@ jobs:
           ruff check . --output-format=github
           ruff format --check .

+  type-check:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v6
+      - name: Set up Python 3.13
+        uses: actions/setup-python@v6
+        with:
+          python-version: "3.13"
+          cache: "pip"
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install .[test]
+      - name: Type-check with mypy
+        run: mypy
+
   test:
     needs: lint
     runs-on: ${{ matrix.os }}

dataretrieval/nldi.py

Lines changed: 12 additions & 5 deletions
@@ -1,7 +1,7 @@
 from __future__ import annotations

 from json import JSONDecodeError
-from typing import Literal
+from typing import Literal, cast

 from dataretrieval.utils import query

@@ -162,9 +162,12 @@ def get_basin(
         raise ValueError("feature_id is required")

     url = f"{NLDI_API_BASE_URL}/{feature_source}/{feature_id}/basin"
-    simplified = str(simplified).lower()
-    split_catchment = str(split_catchment).lower()
-    query_params = {"simplified": simplified, "splitCatchment": split_catchment}
+    simplified_str = str(simplified).lower()
+    split_catchment_str = str(split_catchment).lower()
+    query_params = {
+        "simplified": simplified_str,
+        "splitCatchment": split_catchment_str,
+    }
     err_msg = (
         f"Error getting basin for feature source '{feature_source}' and "
         f"feature_id '{feature_id}'"
@@ -408,7 +411,7 @@ def search(
     if (lat is None) != (long is None):
         raise ValueError("Both lat and long are required")

-    find = find.lower()
+    find = cast(Literal["basin", "flowlines", "features"], find.lower())
     if find not in ("basin", "flowlines", "features"):
         raise ValueError(
             f"Invalid value for find: {find} - allowed values are:"
@@ -428,6 +431,10 @@ def search(
         return get_features(lat=lat, long=long, as_json=True)

     if find == "basin":
+        if feature_source is None or feature_id is None:
+            raise ValueError(
+                "feature_source and feature_id are required to find a basin"
+            )
         return get_basin(
             feature_source=feature_source, feature_id=feature_id, as_json=True
         )
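
The two nldi.py patterns above (a fresh name for the string form of a bool, and a guard that narrows both Optional arguments before the call) can be sketched in isolation. This is a minimal illustration, not the package's code: `get_basin` and `search` below are hypothetical stand-ins that make no network request, the feature source and id in the usage note are example values only, and `Optional[...]` is spelled out so the snippet runs without the future import.

```python
from typing import Optional


def get_basin(
    feature_source: str, feature_id: str, simplified: bool = True
) -> dict[str, str]:
    """Hypothetical stand-in: only describes the request it would make."""
    # A new name for the string form keeps ``simplified`` a bool throughout,
    # so the checker never sees one variable change type.
    simplified_str = str(simplified).lower()
    return {
        "path": f"{feature_source}/{feature_id}/basin",
        "simplified": simplified_str,
    }


def search(
    find: str,
    feature_source: Optional[str] = None,
    feature_id: Optional[str] = None,
) -> dict[str, str]:
    """Hypothetical stand-in for a dispatcher with Optional arguments."""
    if find == "basin":
        # The guard narrows both arguments from Optional[str] to str and
        # fails fast with a message that names the missing inputs.
        if feature_source is None or feature_id is None:
            raise ValueError(
                "feature_source and feature_id are required to find a basin"
            )
        return get_basin(feature_source=feature_source, feature_id=feature_id)
    raise ValueError(f"Unsupported find value in this sketch: {find}")
```

Calling `search("basin", "nwissite", "USGS-01")` returns the request description, while `search("basin")` raises the ValueError before `get_basin` is reached.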

dataretrieval/utils.py

Lines changed: 9 additions & 2 deletions
@@ -2,16 +2,22 @@
 Useful utilities for data munging.
 """

+from __future__ import annotations
+
 import warnings
 from collections.abc import Iterable
+from typing import Any

 import httpx
 import pandas as pd

 import dataretrieval
 from dataretrieval.codes import tz

-HTTPX_DEFAULTS = {
+# Typed as ``dict[str, Any]`` (not the inferred ``dict[str, object]``) so that
+# splatting it as ``**HTTPX_DEFAULTS`` into ``httpx.get`` / ``httpx.AsyncClient``
+# type-checks: the values are a heterogeneous bag of httpx keyword arguments.
+HTTPX_DEFAULTS: dict[str, Any] = {
     "follow_redirects": True,
     "timeout": httpx.Timeout(60.0, connect=10.0),
 }
@@ -190,6 +196,7 @@ def _attach_datetime_columns(df: pd.DataFrame) -> pd.DataFrame:
     # Concat in one shot — per-column assignment on a wide CSV-derived
     # frame triggers pandas' fragmentation PerformanceWarning.
     df = pd.concat([df, pd.DataFrame(new_columns, index=df.index)], axis=1)
+    sort_key: str | None
     if "Activity_StartDateTime" in df.columns:
         sort_key = "Activity_StartDateTime"
     elif "ActivityStartDateTime" in df.columns:
@@ -234,7 +241,7 @@ def __init__(self, response) -> None:
         self.url = str(response.url)
         self.query_time = response.elapsed
         self.header = response.headers
-        self.comment = None
+        self.comment: str | None = None

         # # not sure what statistic_info is
         # self.statistic_info = None
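
The `HTTPX_DEFAULTS` annotation can be illustrated without httpx itself. In this sketch `Timeout`, `fetch`, and `DEFAULTS` are hypothetical stand-ins (assumptions, not httpx or dataretrieval APIs); the point is only that a dict with unrelated value types needs an explicit `dict[str, Any]` annotation before a checker will accept splatting it into typed keyword parameters. Runtime behaviour is the same with or without the annotation.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class Timeout:
    """Hypothetical stand-in for a timeout configuration object."""

    total: float
    connect: float


def fetch(
    url: str,
    *,
    follow_redirects: bool = False,
    timeout: Optional[Timeout] = None,
) -> str:
    """Hypothetical stand-in for an HTTP call with typed keyword arguments."""
    total = timeout.total if timeout is not None else None
    return f"{url}|{follow_redirects}|{total}"


# With a bool and a Timeout as values, the inferred value type would be
# ``object``, which matches neither keyword parameter. Declaring ``Any``
# tells the checker to accept the splat.
DEFAULTS: dict[str, Any] = {
    "follow_redirects": True,
    "timeout": Timeout(total=60.0, connect=10.0),
}

summary = fetch("https://example.invalid", **DEFAULTS)
```

The trade-off is that `Any` also silences genuine mistakes in the dict (a misspelled key would only fail at call time), which is acceptable for a small, centrally defined set of defaults.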

dataretrieval/waterdata/chunking.py

Lines changed: 2 additions & 2 deletions
@@ -681,7 +681,7 @@ def _set_response_url(response: httpx.Response, url: str | httpx.URL) -> None:
     same ``.request``.
     """
     try:
-        response.url = url  # type: ignore[misc]
+        response.url = url  # type: ignore[misc, assignment]
     except AttributeError:
         target = httpx.URL(str(url))
         try:
@@ -800,7 +800,7 @@ def _extract_axes(args: dict[str, Any]) -> list[_Axis]:
             axes.append(_Axis(arg_key=key, atoms=tuple(value), joiner=_LIST_SEP))

     filter_expr = args.get("filter")
-    if _is_chunkable(filter_expr, args.get("filter_lang")):
+    if filter_expr is not None and _is_chunkable(filter_expr, args.get("filter_lang")):
         _check_numeric_filter_pitfall(filter_expr)
         clauses = _split_top_level_or(filter_expr)
         if len(clauses) >= 2:

dataretrieval/waterdata/ratings.py

Lines changed: 7 additions & 4 deletions
@@ -246,15 +246,18 @@ def _search(
     STAC ``next`` link is followed until exhausted so a result set larger than
     one page isn't silently truncated.
     """
-    params: dict[str, Any] | None = {"limit": min(limit, 10000)}
+    query_params: dict[str, Any] = {"limit": min(limit, 10000)}
     if filter_str is not None:
-        params["filter"] = filter_str
+        query_params["filter"] = filter_str
     if time_str is not None:
-        params["datetime"] = time_str
+        query_params["datetime"] = time_str
     if bbox is not None:
-        params["bbox"] = ",".join(map(str, bbox))
+        query_params["bbox"] = ",".join(map(str, bbox))

     url: str | None = f"{STAC_URL}/search"
+    # ``params`` is sent only on the first request; each STAC ``next`` link
+    # already carries the query, so it is reset to None inside the loop.
+    params: dict[str, Any] | None = query_params
     features: list[dict[str, Any]] = []
     while url is not None:
         response = httpx.get(

dataretrieval/waterdata/utils.py

Lines changed: 14 additions & 11 deletions
@@ -14,6 +14,7 @@
     Iterable,
     Iterator,
     Mapping,
+    Sequence,
 )
 from contextlib import asynccontextmanager, contextmanager
 from contextvars import ContextVar
@@ -247,7 +248,7 @@ def _format_one(dt, *, date: bool, local_tz) -> str | None:


 def _format_api_dates(
-    datetime_input: str | list[str | None] | None, date: bool = False
+    datetime_input: str | Sequence[str | None] | None, date: bool = False
 ) -> str | None:
     """
     Formats date or datetime input(s) for use with an API.
@@ -330,11 +331,12 @@ def _format_api_dates(
     # element invalidates the range. Resolve the local tz only now — after the
     # all-NA / duration / interval guards above have had their chance to return.
     local_timezone = datetime.now().astimezone().tzinfo
-    formatted = [
-        _format_one(dt, date=date, local_tz=local_timezone) for dt in datetime_input
-    ]
-    if any(f is None for f in formatted):
-        return None
+    formatted: list[str] = []
+    for dt in datetime_input:
+        one = _format_one(dt, date=date, local_tz=local_timezone)
+        if one is None:
+            return None
+        formatted.append(one)
     return "/".join(formatted)


@@ -823,6 +825,8 @@ def _next_req_url(
         # body might supply. Guarded against mock-shaped ``resp.url``
         # attributes (tests sometimes set strings or ``MagicMock``)
         # by falling open when host extraction isn't reliable.
+        next_host: str | None
+        cur_host: str | None
         try:
             next_host = httpx.URL(href).host
             resp_url = (
@@ -1915,11 +1919,10 @@ def _as_str_list(
     ``",".join(...)`` doesn't iterate it character-by-character — and
     materializes any other iterable via :func:`_normalize_str_iterable`.
     """
-    return (
-        [value]
-        if isinstance(value, str)
-        else _normalize_str_iterable(value, param_name)
-    )
+    normalized = _normalize_str_iterable(value, param_name)
+    if isinstance(normalized, str):
+        return [normalized]
+    return normalized


 def _check_monitoring_location_id(
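
Two of the changes in this file work together: accepting a covariant `Sequence` so that a `list[str]` caller type-checks, and building the result with an early return so the final join sees `list[str]`. A minimal sketch follows, with the hypothetical helpers `format_one` and `join_range` standing in for the package's private functions (their behaviour here is an assumption made for illustration).

```python
from collections.abc import Sequence
from typing import Optional


def format_one(value: Optional[str]) -> Optional[str]:
    """Hypothetical formatter: trims whitespace, None for missing input."""
    if value is None or not value.strip():
        return None
    return value.strip()


def join_range(values: Sequence[Optional[str]]) -> Optional[str]:
    """Join formatted values with "/", or return None if any is missing."""
    # ``list`` is invariant, so a ``list[str]`` argument would be rejected by
    # a ``list[Optional[str]]`` parameter. ``Sequence`` is read-only and
    # covariant, so ``list[str]`` is accepted here.
    formatted: list[str] = []
    for value in values:
        one = format_one(value)
        if one is None:
            # Early return: from here on ``one`` is known to be ``str``, so
            # ``formatted`` never has to hold an Optional element.
            return None
        formatted.append(one)
    return "/".join(formatted)


dates: list[str] = ["2024-01-01 ", "2024-02-01"]
joined = join_range(dates)
```

A comprehension followed by `any(f is None ...)` computes the same result, but a checker cannot carry that check back into the list's element type, which is why the loop form is the one that type-checks.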

pyproject.toml

Lines changed: 11 additions & 0 deletions
@@ -39,6 +39,7 @@ test = [
     "coverage",
     "pytest-httpx",
     "ruff",
+    "mypy<2",  # <2 so it can still target Python 3.9 (the project's floor)
 ]
 doc = [
     "docutils<0.22",
@@ -102,3 +103,13 @@ skip-magic-trailing-comma = false
 line-ending = "auto"
 docstring-code-format = true
 docstring-code-line-length = 72
+
+[tool.mypy]
+# First-pass type checking, kept lenient so it can be adopted incrementally on a
+# large, largely-unannotated scientific codebase: ``ignore_missing_imports``
+# treats untyped third-party libraries (geopandas, anyio, ...) as ``Any`` rather
+# than erroring, and unannotated function bodies are not checked by default.
+# Tightening (e.g. ``disallow_untyped_defs``) can follow once annotations are in.
+python_version = "3.9"  # the project's minimum supported version
+files = ["dataretrieval"]
+ignore_missing_imports = true
