Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions xrspatial/flood.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,8 +328,8 @@ def _cn_runoff_numpy(p, cn):
s = (25400.0 / cn) - 254.0
ia = 0.2 * s
q = np.where(p > ia, (p - ia) ** 2 / (p + 0.8 * s), 0.0)
# propagate NaN from rainfall
q = np.where(np.isnan(p), np.nan, q)
# propagate NaN from rainfall or curve number
q = np.where(np.isnan(p) | np.isnan(cn), np.nan, q)
return q


Expand All @@ -340,7 +340,7 @@ def _cn_runoff_cupy(p, cn):
s = (25400.0 / cn) - 254.0
ia = 0.2 * s
q = cp.where(p > ia, (p - ia) ** 2 / (p + 0.8 * s), 0.0)
q = cp.where(cp.isnan(p), cp.nan, q)
q = cp.where(cp.isnan(p) | cp.isnan(cn), cp.nan, q)
return q


Expand All @@ -349,7 +349,7 @@ def _cn_runoff_dask(p, cn):
s = (25400.0 / cn) - 254.0
ia = 0.2 * s
q = _da.where(p > ia, (p - ia) ** 2 / (p + 0.8 * s), 0.0)
q = _da.where(_da.isnan(p), np.nan, q)
q = _da.where(_da.isnan(p) | _da.isnan(cn), np.nan, q)
return q


Expand Down
26 changes: 26 additions & 0 deletions xrspatial/tests/test_flood.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,32 @@ def test_numpy_equals_dask_cupy(self):
expected_results=result_np.data)


def test_cn_runoff_nan_curve_number_1104():
"""NaN in curve_number should produce NaN output, not 0.

Regression test for #1104: P > NaN is always False, so np.where
took the else-branch and wrote 0.0 instead of NaN.
"""
rainfall = xr.DataArray(
np.array([[100.0, 100.0, 100.0]], dtype=np.float64)
)
cn_data = np.array([[80.0, np.nan, 90.0]], dtype=np.float64)
cn_raster = xr.DataArray(cn_data)

result = curve_number_runoff(rainfall, curve_number=cn_raster)
data = result.data
if hasattr(data, 'compute'):
data = data.compute()
data = np.asarray(data)

# Cell 0 (CN=80): valid runoff
assert np.isfinite(data[0, 0]) and data[0, 0] > 0
# Cell 1 (CN=NaN): must be NaN, not 0
assert np.isnan(data[0, 1]), f"expected NaN, got {data[0, 1]}"
# Cell 2 (CN=90): valid runoff
assert np.isfinite(data[0, 2]) and data[0, 2] > 0


# ===================================================================
# travel_time
# ===================================================================
Expand Down
Loading