Skip to content

Commit 0d4bb18

Browse files
authored
Propagate NaN from curve_number in curve_number_runoff (#1104) (#1105)
The NaN fixup only checked np.isnan(p), missing NaN in the curve number raster. When CN is NaN, P > NaN is False, so np.where wrote 0.0 instead of NaN. Now checks both p and cn in all three backends.
1 parent 18177b4 commit 0d4bb18

File tree

2 files changed

+30
-4
lines changed

2 files changed

+30
-4
lines changed

xrspatial/flood.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -328,8 +328,8 @@ def _cn_runoff_numpy(p, cn):
328328
s = (25400.0 / cn) - 254.0
329329
ia = 0.2 * s
330330
q = np.where(p > ia, (p - ia) ** 2 / (p + 0.8 * s), 0.0)
331-
# propagate NaN from rainfall
332-
q = np.where(np.isnan(p), np.nan, q)
331+
# propagate NaN from rainfall or curve number
332+
q = np.where(np.isnan(p) | np.isnan(cn), np.nan, q)
333333
return q
334334

335335

@@ -340,7 +340,7 @@ def _cn_runoff_cupy(p, cn):
340340
s = (25400.0 / cn) - 254.0
341341
ia = 0.2 * s
342342
q = cp.where(p > ia, (p - ia) ** 2 / (p + 0.8 * s), 0.0)
343-
q = cp.where(cp.isnan(p), cp.nan, q)
343+
q = cp.where(cp.isnan(p) | cp.isnan(cn), cp.nan, q)
344344
return q
345345

346346

@@ -349,7 +349,7 @@ def _cn_runoff_dask(p, cn):
349349
s = (25400.0 / cn) - 254.0
350350
ia = 0.2 * s
351351
q = _da.where(p > ia, (p - ia) ** 2 / (p + 0.8 * s), 0.0)
352-
q = _da.where(_da.isnan(p), np.nan, q)
352+
q = _da.where(_da.isnan(p) | _da.isnan(cn), np.nan, q)
353353
return q
354354

355355

xrspatial/tests/test_flood.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,32 @@ def test_numpy_equals_dask_cupy(self):
362362
expected_results=result_np.data)
363363

364364

365+
def test_cn_runoff_nan_curve_number_1104():
366+
"""NaN in curve_number should produce NaN output, not 0.
367+
368+
Regression test for #1104: P > NaN is always False, so np.where
369+
took the else-branch and wrote 0.0 instead of NaN.
370+
"""
371+
rainfall = xr.DataArray(
372+
np.array([[100.0, 100.0, 100.0]], dtype=np.float64)
373+
)
374+
cn_data = np.array([[80.0, np.nan, 90.0]], dtype=np.float64)
375+
cn_raster = xr.DataArray(cn_data)
376+
377+
result = curve_number_runoff(rainfall, curve_number=cn_raster)
378+
data = result.data
379+
if hasattr(data, 'compute'):
380+
data = data.compute()
381+
data = np.asarray(data)
382+
383+
# Cell 0 (CN=80): valid runoff
384+
assert np.isfinite(data[0, 0]) and data[0, 0] > 0
385+
# Cell 1 (CN=NaN): must be NaN, not 0
386+
assert np.isnan(data[0, 1]), f"expected NaN, got {data[0, 1]}"
387+
# Cell 2 (CN=90): valid runoff
388+
assert np.isfinite(data[0, 2]) and data[0, 2] > 0
389+
390+
365391
# ===================================================================
366392
# travel_time
367393
# ===================================================================

0 commit comments

Comments
 (0)