Skip to content

Commit dca49a7

Browse files
authored
chore(typing): set up mypy and tighten the package to mypy --strict (#314)
The package ships a ``py.typed`` marker (advertising itself as typed to downstream users) but nothing type-checked it. Add mypy and get a clean run. Setup: - [tool.mypy] in pyproject.toml: a lenient first-pass config (ignore_missing_imports, target python_version 3.9), scoped to the dataretrieval package. - mypy<2 added to the [test] extra (<2 so it can still target 3.9). - a type-check job in the CI workflow, parallel to the ruff lint job. Fixes (mypy went from 78 errors to 0 on the tracked package): - HTTPX_DEFAULTS annotated dict[str, Any] so **-splatting it into httpx.get / httpx.AsyncClient type-checks -- cleared ~55 errors across 7 call sites at once. - utils.py gains `from __future__ import annotations`: mypy (targeting 3.9) caught that the new `str | None` annotations there would be a runtime error on 3.9, because this module -- unlike the rest of the package -- lacked the future import. - BaseMetadata.comment annotated `str | None` (was inferred `None`, which rejected every subclass that assigns a comment string). - _format_api_dates: accept Sequence[str | None] (covariant) so a list[str] caller type-checks, and build the formatted list with an early return so the final join sees list[str]. - _as_str_list: delegate to _normalize_str_iterable then wrap, so the declared return type list[str] | None holds. - _next_req_url: declare next_host / cur_host as `str | None`. - ratings._search: build the query dict in a non-Optional local before aliasing it to the loop's `params` (which toggles to None per page). - nldi: drop the bool->str / Literal->str variable reuse; guard the basin branch so feature_source / feature_id are non-None before get_basin. - chunking: narrow the optional filter before _is_chunkable; fix a stale `# type: ignore` error code. The fixes are annotations and type-narrowing guards. 
The only runtime-visible change is that nldi.search() now raises a clear ValueError up front when a basin search is missing feature_source/feature_id, where the same condition previously raised deeper inside get_basin. 259 tests pass across the affected suites.
1 parent c107fc9 commit dca49a7

14 files changed

Lines changed: 284 additions & 152 deletions

File tree

.github/workflows/python-package.yml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,22 @@ jobs:
2626
ruff check . --output-format=github
2727
ruff format --check .
2828
29+
type-check:
30+
runs-on: ubuntu-latest
31+
steps:
32+
- uses: actions/checkout@v6
33+
- name: Set up Python 3.13
34+
uses: actions/setup-python@v6
35+
with:
36+
python-version: "3.13"
37+
cache: "pip"
38+
- name: Install dependencies
39+
run: |
40+
python -m pip install --upgrade pip
41+
pip install .[type-check]
42+
- name: Type-check with mypy
43+
run: mypy
44+
2945
test:
3046
needs: lint
3147
runs-on: ${{ matrix.os }}

dataretrieval/nadp.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
3030
"""
3131

32+
from __future__ import annotations
33+
3234
import io
3335
import re
3436
import warnings
@@ -45,7 +47,7 @@
4547
)
4648

4749

48-
def _warn_deprecated():
50+
def _warn_deprecated() -> None:
4951
warnings.warn(_DEPRECATION_MESSAGE, DeprecationWarning, stacklevel=3)
5052

5153

@@ -74,19 +76,19 @@ def _warn_deprecated():
7476
class NADP_ZipFile(zipfile.ZipFile):
7577
"""Extend zipfile.ZipFile for working on data from NADP"""
7678

77-
def tif_name(self):
79+
def tif_name(self) -> str:
7880
"""Get the name of the tif file in the zip file."""
7981
filenames = self.namelist()
8082
r = re.compile(".*tif$")
8183
tif_list = list(filter(r.match, filenames))
8284
return tif_list[0]
8385

84-
def tif(self):
86+
def tif(self) -> bytes:
8587
"""Read the tif file in the zip file."""
8688
return self.read(self.tif_name())
8789

8890

89-
def get_annual_MDN_map(measurement_type, year, path):
91+
def get_annual_MDN_map(measurement_type: str, year: str, path: str) -> str:
9092
"""Download an MDN map from NADP.
9193
9294
This function looks for a zip file containing gridded information at:
@@ -135,7 +137,12 @@ def get_annual_MDN_map(measurement_type, year, path):
135137
return str(path)
136138

137139

138-
def get_annual_NTN_map(measurement_type, measurement=None, year=None, path="."):
140+
def get_annual_NTN_map(
141+
measurement_type: str,
142+
measurement: str | None = None,
143+
year: str | None = None,
144+
path: str = ".",
145+
) -> str:
139146
"""Download an NTN map from NADP.
140147
141148
This function looks for a zip file containing gridded information at:
@@ -193,7 +200,7 @@ def get_annual_NTN_map(measurement_type, measurement=None, year=None, path="."):
193200
return str(path)
194201

195202

196-
def get_zip(url, filename):
203+
def get_zip(url: str, filename: str) -> NADP_ZipFile:
197204
"""Gets a ZipFile at url and returns it
198205
199206
Parameters

dataretrieval/nldi.py

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from json import JSONDecodeError
4-
from typing import Literal
4+
from typing import Any, Literal, cast
55

66
from dataretrieval.utils import query
77

@@ -16,13 +16,17 @@
1616
_VALID_NAVIGATION_MODES = ("UM", "DM", "UT", "DD")
1717

1818

19-
def _query_nldi(url, query_params, error_message):
19+
def _query_nldi(
20+
url: str,
21+
query_params: dict[str, str],
22+
error_message: str,
23+
) -> dict[str, Any] | list[Any]:
2024
# A helper function to query the NLDI API
2125
response = query(url, payload=query_params)
2226
if response.status_code != 200:
2327
raise ValueError(f"{error_message}. Error reason: {response.reason_phrase}")
2428

25-
response_data = {}
29+
response_data: dict[str, Any] | list[Any] = {}
2630
try:
2731
response_data = response.json()
2832
except JSONDecodeError:
@@ -32,7 +36,7 @@ def _query_nldi(url, query_params, error_message):
3236
return response_data
3337

3438

35-
def _features_to_gdf(feature_collection: dict) -> gpd.GeoDataFrame:
39+
def _features_to_gdf(feature_collection: dict[str, Any]) -> gpd.GeoDataFrame:
3640
"""Build a GeoDataFrame from an NLDI FeatureCollection, tolerating empties.
3741
3842
NLDI can legitimately return no features (e.g. a feature with nothing
@@ -56,7 +60,7 @@ def get_flowlines(
5660
stop_comid: int | None = None,
5761
trim_start: bool = False,
5862
as_json: bool = False,
59-
) -> gpd.GeoDataFrame | dict:
63+
) -> gpd.GeoDataFrame | dict[str, Any]:
6064
"""Gets the flowlines for the specified navigation either by comid or feature
6165
source in WGS84 lat/long coordinates as GeoDataFrame containing a polyline geometry.
6266
@@ -116,7 +120,7 @@ def get_flowlines(
116120
else:
117121
err_msg = f"Error getting flowlines for comid '{comid}'"
118122

119-
feature_collection = _query_nldi(url, query_params, err_msg)
123+
feature_collection = cast("dict[str, Any]", _query_nldi(url, query_params, err_msg))
120124
if as_json:
121125
return feature_collection
122126
gdf = _features_to_gdf(feature_collection)
@@ -129,7 +133,7 @@ def get_basin(
129133
simplified: bool = True,
130134
split_catchment: bool = False,
131135
as_json: bool = False,
132-
) -> gpd.GeoDataFrame | dict:
136+
) -> gpd.GeoDataFrame | dict[str, Any]:
133137
"""Gets the aggregated basin for the specified feature in WGS84 lat/lon
134138
as GeoDataFrame or as JSON containing a polygon geometry.
135139
@@ -162,14 +166,17 @@ def get_basin(
162166
raise ValueError("feature_id is required")
163167

164168
url = f"{NLDI_API_BASE_URL}/{feature_source}/{feature_id}/basin"
165-
simplified = str(simplified).lower()
166-
split_catchment = str(split_catchment).lower()
167-
query_params = {"simplified": simplified, "splitCatchment": split_catchment}
169+
simplified_str = str(simplified).lower()
170+
split_catchment_str = str(split_catchment).lower()
171+
query_params = {
172+
"simplified": simplified_str,
173+
"splitCatchment": split_catchment_str,
174+
}
168175
err_msg = (
169176
f"Error getting basin for feature source '{feature_source}' and "
170177
f"feature_id '{feature_id}'"
171178
)
172-
feature_collection = _query_nldi(url, query_params, err_msg)
179+
feature_collection = cast("dict[str, Any]", _query_nldi(url, query_params, err_msg))
173180
if as_json:
174181
return feature_collection
175182
gdf = _features_to_gdf(feature_collection)
@@ -187,7 +194,7 @@ def get_features(
187194
long: float | None = None,
188195
stop_comid: int | None = None,
189196
as_json: bool = False,
190-
) -> gpd.GeoDataFrame | dict:
197+
) -> gpd.GeoDataFrame | dict[str, Any]:
191198
"""Gets all features found along the specified navigation either by
192199
comid or feature source as points in WGS84 lat/long coordinates - a GeoDataFrame
193200
containing a point geometry.
@@ -285,7 +292,7 @@ def get_features(
285292
query_params = {}
286293
err_msg = _features_err_msg(feature_source, feature_id, comid, data_source)
287294

288-
feature_collection = _query_nldi(url, query_params, err_msg)
295+
feature_collection = cast("dict[str, Any]", _query_nldi(url, query_params, err_msg))
289296
if as_json:
290297
return feature_collection
291298
gdf = _features_to_gdf(feature_collection)
@@ -321,7 +328,7 @@ def get_features_by_data_source(data_source: str) -> gpd.GeoDataFrame:
321328
_validate_data_source(data_source)
322329
url = f"{NLDI_API_BASE_URL}/{data_source}"
323330
err_msg = f"Error getting features for data source '{data_source}'"
324-
feature_collection = _query_nldi(url, {}, err_msg)
331+
feature_collection = cast("dict[str, Any]", _query_nldi(url, {}, err_msg))
325332
gdf = _features_to_gdf(feature_collection)
326333
return gdf
327334

@@ -336,7 +343,7 @@ def search(
336343
lat: float | None = None,
337344
long: float | None = None,
338345
distance: int = 50,
339-
) -> dict:
346+
) -> dict[str, Any]:
340347
"""Searches for the specified feature in NLDI and returns the results
341348
as a dictionary.
342349
@@ -408,7 +415,7 @@ def search(
408415
if (lat is None) != (long is None):
409416
raise ValueError("Both lat and long are required")
410417

411-
find = find.lower()
418+
find = cast(Literal["basin", "flowlines", "features"], find.lower())
412419
if find not in ("basin", "flowlines", "features"):
413420
raise ValueError(
414421
f"Invalid value for find: {find} - allowed values are:"
@@ -428,6 +435,10 @@ def search(
428435
return get_features(lat=lat, long=long, as_json=True)
429436

430437
if find == "basin":
438+
if feature_source is None or feature_id is None:
439+
raise ValueError(
440+
"feature_source and feature_id are required to find a basin"
441+
)
431442
return get_basin(
432443
feature_source=feature_source, feature_id=feature_id, as_json=True
433444
)
@@ -458,7 +469,7 @@ def search(
458469
)
459470

460471

461-
def _validate_data_source(data_source: str):
472+
def _validate_data_source(data_source: str) -> None:
462473
# A helper function to validate user specified data source/feature source
463474

464475
global _AVAILABLE_DATA_SOURCES
@@ -487,7 +498,12 @@ def _validate_data_source(data_source: str):
487498
raise ValueError(err_msg)
488499

489500

490-
def _features_err_msg(feature_source, feature_id, comid, data_source) -> str:
501+
def _features_err_msg(
502+
feature_source: str | None,
503+
feature_id: str | None,
504+
comid: int | None,
505+
data_source: str | None,
506+
) -> str:
491507
if feature_source is not None:
492508
return (
493509
f"Error getting features for feature source '{feature_source}'"
@@ -512,7 +528,7 @@ def _validate_navigation_mode(navigation_mode: str | None) -> str:
512528

513529
def _validate_feature_source_comid(
514530
feature_source: str | None, feature_id: str | None, comid: int | None
515-
):
531+
) -> None:
516532
if feature_source is not None and feature_id is None:
517533
raise ValueError("feature_id is required if feature_source is provided")
518534
if feature_id is not None and feature_source is None:

0 commit comments

Comments
 (0)