Skip to content

Commit 961be13

Browse files
thodson-usgsclaude
andcommitted
fix(waterdata.xarray): resolve remaining review findings
- samples now surface station longitude/latitude (mapped from Location_Longitude/Location_Latitude; _point_coords reads explicit lon/lat columns in addition to an OGC geometry column) - metadata cache: a single large pull is no longer subject to within-batch FIFO eviction (the call's result is built from the freshly-parsed entries), and sites with no metadata are no longer negatively cached, so they retry - dense variable naming is deterministic and unambiguous: a bare name (e.g. discharge) is used only when unique; same-named series are all disambiguated by cell method / statistic / parameter code - dense multi-unit label is deterministic (sorted) instead of row-order dependent - row_size is int64 (was int32) to avoid overflow / cumsum truncation - select_series rejects descriptor coords as keys (lon/lat float-equality footgun) and can match a null instance key 64 offline tests pass; live samples lon/lat + dense naming verified. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
1 parent e47ad68 commit 961be13

2 files changed

Lines changed: 215 additions & 31 deletions

File tree

dataretrieval/waterdata/xarray.py

Lines changed: 108 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,27 @@ def _lonlat(geom):
232232

233233

234234
def _point_coords(df, site):
235-
"""lon/lat dicts keyed by site from point geometry, or None."""
235+
"""lon/lat dicts keyed by site, or None.
236+
237+
Reads either a ``geometry`` column (the time-series getters' OGC response) or
238+
explicit ``longitude`` / ``latitude`` columns (the Samples profile, mapped via
239+
:data:`_SAMPLES_RENAME`) -- so every service surfaces station coordinates.
240+
"""
241+
if {"longitude", "latitude"}.issubset(df.columns):
242+
geo = df.dropna(subset=["longitude", "latitude"]).drop_duplicates(site)
243+
if geo.empty:
244+
return None
245+
lon, lat = {}, {}
246+
for site_id, x, y in zip(
247+
geo[site].to_numpy(),
248+
geo["longitude"].to_numpy(),
249+
geo["latitude"].to_numpy(),
250+
):
251+
try:
252+
lon[site_id], lat[site_id] = float(x), float(y)
253+
except (TypeError, ValueError):
254+
continue
255+
return (lon, lat) if lon else None
236256
if "geometry" not in df.columns:
237257
return None
238258
geo = df.dropna(subset=["geometry"]).drop_duplicates(site)
@@ -408,8 +428,9 @@ def lookup(self, site_ids):
408428
"""
409429
sites = sorted({str(s) for s in site_ids if _pd.notna(s)})
410430
# Racy read of the keys is fine: a concurrent miss just re-fetches (the
411-
# fetch is idempotent); only the writes in _ingest take the lock.
431+
# fetch is idempotent); only the writes in _store take the lock.
412432
todo = [s for s in sites if s not in self._entries]
433+
fresh: dict[str, dict] = {}
413434
if todo:
414435
try:
415436
meta, _ = self._getter(monitoring_location_id=todo)
@@ -420,12 +441,17 @@ def lookup(self, site_ids):
420441
stacklevel=2,
421442
)
422443
else:
423-
self._ingest(meta, todo)
444+
fresh = self._parse(meta, todo)
445+
self._store(fresh)
424446
param_meta: dict[str, dict] = {}
425447
site_meta: dict[str, dict] = {}
426448
with self._lock:
427449
for s in sites:
428-
entry = self._entries.get(s, {})
450+
# Prefer this call's freshly-parsed entry over the cache: the
451+
# bounded cache may have already evicted just-fetched sites when a
452+
# single pull's ``todo`` exceeds maxsize, but the current call
453+
# must still see every site it fetched.
454+
entry = fresh.get(s) or self._entries.get(s, {})
429455
param_meta.update(entry.get("params", {}))
430456
if entry.get("site"):
431457
site_meta[s] = entry["site"]
@@ -439,12 +465,8 @@ def clear(self):
439465
def __len__(self):
440466
return len(self._entries)
441467

442-
def _ingest(self, meta, todo):
443-
"""Parse ``meta`` into per-site entries, then merge + evict under lock.
444-
445-
The parsing runs lock-free on a local dict; only the (cheap) merge into
446-
the shared cache and the FIFO eviction past ``maxsize`` hold the lock.
447-
"""
468+
def _parse(self, meta, todo):
469+
"""Parse ``meta`` into per-site ``{params, site}`` entries (lock-free)."""
448470
fresh = {s: {"params": {}, "site": {}} for s in todo}
449471
if not meta.empty:
450472
name_cols = [c for c in _NAME_DESCRIPTORS if c in meta.columns]
@@ -470,8 +492,20 @@ def _ingest(self, meta, todo):
470492
}
471493
if desc:
472494
fresh[site]["site"] = desc
495+
return fresh
496+
497+
def _store(self, fresh):
498+
"""Merge non-empty entries into the bounded cache (FIFO eviction).
499+
500+
Sites that came back with no metadata are *not* cached, so a later call
501+
retries them rather than being stuck with a sticky empty result; the
502+
current call still sees them via the freshly-parsed ``fresh`` dict.
503+
"""
504+
keep = {s: e for s, e in fresh.items() if e["params"] or e["site"]}
505+
if not keep:
506+
return
473507
with self._lock:
474-
self._entries.update(fresh)
508+
self._entries.update(keep)
475509
while len(self._entries) > self._maxsize:
476510
self._entries.pop(next(iter(self._entries)))
477511

@@ -524,14 +558,24 @@ def select_series(ds, **keys):
524558
"so select by name instead, e.g. "
525559
"ds[variable].sel(monitoring_location_id=...)."
526560
)
527-
inst_coords = [c for c in ds.coords if ds[c].dims == ("timeseries",)]
561+
# Selectable keys are the series *identity* coordinates only -- exclude the
562+
# per-series descriptors (lon/lat are a float-equality footgun; unit/HUC/state
563+
# are not series identifiers).
564+
descriptors = {"longitude", "latitude", "unit_of_measure", *_SITE_DESCRIPTORS}
565+
inst_coords = [
566+
c for c in ds.coords if ds[c].dims == ("timeseries",) and c not in descriptors
567+
]
528568
mask = _np.ones(ds.sizes["timeseries"], dtype=bool)
529569
for key, value in keys.items():
530570
if key not in inst_coords:
531571
raise KeyError(
532-
f"{key!r} is not a per-series coordinate; choose from {inst_coords}."
572+
f"{key!r} is not a per-series identity coordinate; choose from "
573+
f"{inst_coords}."
533574
)
534-
mask &= ds[key].to_numpy() == value
575+
arr = ds[key].to_numpy()
576+
# NaN never equals anything, so match a missing instance key (e.g. a
577+
# characteristic with no sample fraction) by null-ness instead.
578+
mask &= _pd.isna(arr) if _is_missing(value) else (arr == value)
535579
matches = _np.flatnonzero(mask)
536580
if matches.size == 0:
537581
raise KeyError(f"no time series matches {keys}.")
@@ -563,6 +607,10 @@ def select_series(ds, **keys):
563607
"Result_SampleFraction": "sample_fraction",
564608
"Result_ResultDetectionCondition": "detection_condition",
565609
"Result_MeasureStatusIdentifier": "status",
610+
# Samples carry position as explicit columns (no OGC ``geometry``); map them
611+
# to the canonical names so _point_coords surfaces station lon/lat.
612+
"Location_Longitude": "longitude",
613+
"Location_Latitude": "latitude",
566614
}
567615
_CANONICAL_COORD_ATTRS = {
568616
"parameter_code": {"long_name": "USGS parameter code"},
@@ -786,7 +834,9 @@ def _assemble(self, work, inst_cols, ancillary, has_unit):
786834
)
787835
data_vars = {
788836
"value": ("obs", work["value"].to_numpy()),
789-
"row_size": ("timeseries", row_size.to_numpy().astype("int32")),
837+
# int64 (not int32): a single long, high-frequency series can exceed
838+
# 2^31 observations, and the select_series cumsum must not overflow.
839+
"row_size": ("timeseries", row_size.to_numpy().astype("int64")),
790840
}
791841
for c in ancillary:
792842
data_vars[c] = ("obs", work[c].to_numpy())
@@ -899,16 +949,25 @@ def _build_series(self, work, group_cols, ancillary, has_unit):
899949

900950
def _variable_datasets(self, work, group_cols, ancillary, has_unit):
901951
"""One pivoted ``(site, time)`` Dataset per (parameter, statistic)."""
902-
datasets, used = [], set()
952+
# First pass: gather each group's identity and base name, so naming can
953+
# see the whole set (a bare name is only used when it is unambiguous).
954+
specs = []
903955
for _, group in work.groupby(group_cols, dropna=False):
904956
pcode = _first_present(group, "parameter_code")
905957
stat = _first_present(group, "statistic_id")
906-
group_units = group["unit_of_measure"].dropna().unique() if has_unit else ()
907-
unit = group_units[0] if len(group_units) else None
908958
desc = self.series_meta.get(str(pcode), {}) if pcode is not None else {}
909-
910-
name = self._variable_name(desc, pcode, stat, used)
911-
used.add(name)
959+
base = _slug(_none_if_nan(desc.get("parameter_name")) or pcode or "value")
960+
specs.append((group, pcode, stat, desc, base))
961+
names = self._disambiguate([s[4] for s in specs], [(s[1], s[2]) for s in specs])
962+
963+
datasets = []
964+
for (group, pcode, stat, desc, _base), name in zip(specs, names):
965+
# Sort the units so the chosen label is deterministic across pulls
966+
# (values are not converted either way; see the multi-unit warning).
967+
group_units = (
968+
sorted(group["unit_of_measure"].dropna().unique()) if has_unit else []
969+
)
970+
unit = group_units[0] if group_units else None
912971

913972
if len(group_units) > 1:
914973
# One variable can carry only one ``units`` attr; surface the
@@ -951,15 +1010,34 @@ def _variable_datasets(self, work, group_cols, ancillary, has_unit):
9511010
return datasets
9521011

9531012
@staticmethod
954-
def _variable_name(desc, pcode, stat, used):
955-
"""A unique slug for a variable; disambiguate same-parameter series."""
956-
name = _slug(_none_if_nan(desc.get("parameter_name")) or pcode or "value")
957-
if name in used: # same parameter, different statistic -> distinct var
958-
op = CF_CELL_METHODS.get(str(stat)) or (str(stat) if stat else None)
959-
name = f"{name}_{_slug(op)}" if op else name
960-
while name in used:
961-
name += "_x"
962-
return name
1013+
def _disambiguate(bases, keys):
1014+
"""Map per-group base slugs to unique, deterministic variable names.
1015+
1016+
``keys[i]`` is the group's ``(parameter_code, statistic_id)``. A base used
1017+
by exactly one group stays bare (e.g. ``discharge``); a base shared by
1018+
several groups is disambiguated for *all* of them -- by the statistic's
1019+
cell-method operator (``discharge_maximum`` / ``discharge_mean``), falling
1020+
back to the statistic id then the parameter code -- so a bare name never
1021+
silently refers to an arbitrary one of several same-named series.
1022+
"""
1023+
counts: dict[str, int] = {}
1024+
for b in bases:
1025+
counts[b] = counts.get(b, 0) + 1
1026+
names, used = [], set()
1027+
for base, (pcode, stat) in zip(bases, keys):
1028+
if counts[base] == 1:
1029+
name = base
1030+
else:
1031+
op = CF_CELL_METHODS.get(str(stat)) if stat is not None else None
1032+
suffix = op or (str(stat) if stat is not None else None)
1033+
name = f"{base}_{_slug(suffix)}" if suffix else base
1034+
if name == base or name in used: # statistic didn't separate them
1035+
name = f"{base}_{_slug(pcode)}" if pcode is not None else base
1036+
while name in used:
1037+
name += "_x"
1038+
used.add(name)
1039+
names.append(name)
1040+
return names
9631041

9641042

9651043
class _StatsBuilder(_DatasetBuilder):

tests/waterdata_xarray_test.py

Lines changed: 107 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -655,6 +655,41 @@ def test_select_series_returns_time_indexed_single_series():
655655
assert s["value"].sel(time="2024-06-01").item() == 100
656656

657657

658+
def test_dense_same_parameter_two_statistics_no_bare_name():
659+
# 00060 under both 00001 (max) and 00003 (mean): the bare 'discharge' name is
660+
# ambiguous, so BOTH variables are disambiguated by their cell method -- no
661+
# order-dependent bare 'discharge' that silently means one of them.
662+
mx = _daily_frame(values=(500,), times=("2024-06-01",))
663+
mx["statistic_id"] = "00001"
664+
mn = _daily_frame(values=(100,), times=("2024-06-01",))
665+
ds = wdx._build_dense(
666+
pd.concat([mx, mn]), _meta(), service="daily", series_meta=_DISCHARGE_META
667+
)
668+
assert "discharge" not in ds.data_vars # no bare (ambiguous) name
669+
assert {"discharge_maximum", "discharge_mean"} <= set(ds.data_vars)
670+
assert ds["discharge_maximum"].attrs["cell_methods"] == "time: maximum"
671+
assert ds["discharge_mean"].attrs["cell_methods"] == "time: mean"
672+
673+
674+
def test_dense_single_statistic_keeps_bare_name():
675+
# The common single-statistic case keeps the clean bare name.
676+
ds = wdx._build_dense(
677+
_daily_frame(), _meta(), service="daily", series_meta=_DISCHARGE_META
678+
)
679+
assert "discharge" in ds.data_vars
680+
681+
682+
def test_select_series_matches_nan_instance_key():
683+
# An instance whose key is null (samples with no sample_fraction) must be
684+
# selectable by passing None, since `== NaN` never matches.
685+
df = _samples_frame()
686+
df["Result_SampleFraction"] = None
687+
ds = _samples_ds(df)
688+
s = wdx.select_series(ds, characteristic="Temperature, water", sample_fraction=None)
689+
assert set(s.sizes) == {"time"}
690+
assert "value" in s.data_vars
691+
692+
658693
def test_select_series_ambiguous_raises():
659694
# selecting by site alone matches both instances -> ask for more keys
660695
ds = _two_instance_ragged()
@@ -670,8 +705,11 @@ def test_select_series_no_match_raises():
670705

671706
def test_select_series_unknown_key_raises():
672707
ds = _two_instance_ragged()
673-
with pytest.raises(KeyError, match="not a per-series coordinate"):
708+
with pytest.raises(KeyError, match="not a per-series identity coordinate"):
674709
wdx.select_series(ds, bogus="x")
710+
# descriptor coords (lon/lat/unit/HUC/state) are not selectable identity keys
711+
with pytest.raises(KeyError, match="not a per-series identity coordinate"):
712+
wdx.select_series(ds, unit_of_measure="ft^3/s")
675713

676714

677715
def test_select_series_on_dense_raises_helpful_error():
@@ -760,6 +798,20 @@ def _samples_ds(frame):
760798
)
761799

762800

801+
def test_samples_surface_lonlat_from_location_columns():
802+
# Samples carry position as Location_Latitude/Location_Longitude (no OGC
803+
# geometry); the dataset must still get numeric longitude/latitude coords.
804+
frame = _samples_frame()
805+
frame["Location_Longitude"] = [-90.44]
806+
frame["Location_Latitude"] = [43.19]
807+
ds = _samples_ds(frame)
808+
assert "longitude" in ds.coords and "latitude" in ds.coords
809+
assert ds["longitude"].dtype.kind == "f"
810+
assert float(ds["longitude"].values[0]) == -90.44
811+
assert float(ds["latitude"].values[0]) == 43.19
812+
assert ds["longitude"].attrs["units"] == "degrees_east"
813+
814+
763815
def test_build_samples_single_characteristic():
764816
ds = _samples_ds(_samples_frame())
765817
assert set(ds.sizes) == {"obs", "timeseries"}
@@ -988,3 +1040,57 @@ def fake(monitoring_location_id):
9881040
wdx._FIELD_CACHE._entries["Y"] = {"params": {}, "site": {}}
9891041
wdx.clear_metadata_cache()
9901042
assert len(wdx._TS_CACHE) == 0 and len(wdx._FIELD_CACHE) == 0
1043+
1044+
1045+
def test_metadata_missing_site_is_not_negatively_cached():
1046+
# A site the metadata endpoint returns nothing for must NOT be cached as an
1047+
# empty entry (which would never be retried); a later call re-fetches it.
1048+
calls = []
1049+
1050+
def fake(monitoring_location_id):
1051+
calls.append(list(monitoring_location_id))
1052+
# respond only for S1, never for S2
1053+
rows = [
1054+
{
1055+
"monitoring_location_id": s,
1056+
"parameter_code": "00060",
1057+
"parameter_name": s,
1058+
}
1059+
for s in monitoring_location_id
1060+
if s == "S1"
1061+
]
1062+
return pd.DataFrame(rows), SimpleNamespace(url=None)
1063+
1064+
cache = wdx._MetadataCache(fake)
1065+
cache.lookup(["S1", "S2"])
1066+
cache.lookup(["S1", "S2"])
1067+
# S1 cached (hit, not re-fetched); S2 never cached, so it is re-requested.
1068+
assert calls[0] == ["S1", "S2"]
1069+
assert calls[1] == ["S2"] # only the still-uncached S2
1070+
assert "S1" in cache._entries and "S2" not in cache._entries
1071+
1072+
1073+
def test_metadata_lookup_survives_within_batch_eviction():
1074+
# A single pull whose site count exceeds maxsize must still return metadata
1075+
# for every requested site, even though the bounded cache can't hold them all.
1076+
sites = ["S0", "S1", "S2", "S3", "S4"]
1077+
1078+
def fake(monitoring_location_id):
1079+
rows = [
1080+
{
1081+
"monitoring_location_id": s,
1082+
"parameter_code": f"p{s}", # distinct per site
1083+
"parameter_name": f"name-{s}",
1084+
"hydrologic_unit_code": f"huc-{s}",
1085+
}
1086+
for s in monitoring_location_id
1087+
]
1088+
return pd.DataFrame(rows), SimpleNamespace(url=None)
1089+
1090+
cache = wdx._MetadataCache(fake, maxsize=2)
1091+
param_meta, site_meta = cache.lookup(sites)
1092+
# every requested site's metadata is in the result even though the bounded
1093+
# cache evicted most of the just-fetched batch.
1094+
assert {f"p{s}" for s in sites} <= set(param_meta)
1095+
assert set(site_meta) == set(sites)
1096+
assert len(cache) <= 2 # cache stayed bounded

0 commit comments

Comments
 (0)