Skip to content

Commit 1437c61

Browse files
authored
Fix NaN handling in focal_stats CUDA kernels (#1092) (#1093)
* Add NaN-skipping to focal_stats CUDA kernels (#1092) All focal_stats CUDA kernels (_focal_mean_cuda, _focal_sum_cuda, _focal_std_cuda, _focal_var_cuda, _focal_range_cuda, _focal_min_cuda, _focal_max_cuda) now skip NaN neighbors with `if v != v: continue`, matching the numpy path which uses np.nanmean/nansum/nanstd/etc. Previously, NaN propagated through arithmetic, giving different results on GPU vs CPU when input contained NaN. * Add NaN tests for focal_stats CUDA kernels (#1092) - test_focal_stats_nan_handling_1092: verifies all 7 stats (mean, sum, min, max, std, var, range) skip NaN neighbors across all 4 backends. - test_focal_stats_all_nan_window_1092: all-NaN window gives NaN for mean/min/max and 0 for sum (matching numpy nansum behavior). - Fixed sum kernel to return 0 (not NaN) for all-NaN windows, matching numpy nansum semantics.
1 parent 65b354f commit 1437c61

File tree

2 files changed

+132
-23
lines changed

2 files changed

+132
-23
lines changed

xrspatial/focal.py

Lines changed: 30 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -608,13 +608,13 @@ def _focal_min_cuda(data, kernel, out):
608608

609609
if 0 <= ii < rows and 0 <= jj < cols:
610610
v = data[ii, jj]
611+
if v != v: # NaN check
612+
continue
611613
if (not found) or (v < m):
612614
m = v
613615
found = True
614616

615-
# With your mask containing the center, found should be True.
616-
# But keep a safe fallback anyway.
617-
out[i, j] = m if found else data[i, j]
617+
out[i, j] = m if found else math.nan
618618

619619

620620
@cuda.jit
@@ -636,20 +636,20 @@ def _focal_max_cuda(data, kernel, out):
636636
for h in range(kernel.shape[1]):
637637
w = kernel[k, h]
638638
if w == 0:
639-
continue # mask says "ignore this neighbor"
639+
continue
640640

641641
ii = i + k - dr
642642
jj = j + h - dc
643643

644644
if 0 <= ii < rows and 0 <= jj < cols:
645645
v = data[ii, jj]
646+
if v != v: # NaN check
647+
continue
646648
if (not found) or (v > m):
647649
m = v
648650
found = True
649651

650-
# With your mask containing the center (1), found should always be True.
651-
# But keep this for safety.
652-
out[i, j] = m if found else data[i, j]
652+
out[i, j] = m if found else math.nan
653653

654654

655655
def _focal_range_cupy(data, kernel):
@@ -684,6 +684,8 @@ def _focal_range_cuda(data, kernel, out):
684684

685685
if 0 <= ii < rows and 0 <= jj < cols:
686686
v = data[ii, jj]
687+
if v != v: # NaN check
688+
continue
687689
if not found:
688690
mx = v
689691
mn = v
@@ -694,7 +696,7 @@ def _focal_range_cuda(data, kernel, out):
694696
if v < mn:
695697
mn = v
696698

697-
out[i, j] = (mx - mn) if found else 0.0
699+
out[i, j] = (mx - mn) if found else math.nan
698700

699701

700702
@cuda.jit
@@ -716,29 +718,29 @@ def _focal_std_cuda(data, kernel, out):
716718
for h in range(kernel.shape[1]):
717719
w = kernel[k, h]
718720
if w == 0:
719-
continue # mask says ignore
721+
continue
720722

721723
ii = i + k - dr
722724
jj = j + h - dc
723725

724726
if 0 <= ii < rows and 0 <= jj < cols:
725727
x = data[ii, jj]
728+
if x != x: # NaN check
729+
continue
726730
w_sum += w
727731
sum_wx += w * x
728732
sum_wx2 += w * x * x
729733

730-
# With your mask including the center, w_sum should be > 0. Guard anyway.
731734
if w_sum > 0.0:
732735
mean = sum_wx / w_sum
733736
var = (sum_wx2 / w_sum) - (mean * mean)
734737

735-
# Numerical safety (tiny negative due to floating point)
736738
if var < 0.0:
737739
var = 0.0
738740

739741
out[i, j] = math.sqrt(var)
740742
else:
741-
out[i, j] = 0.0
743+
out[i, j] = math.nan
742744

743745

744746
@cuda.jit
@@ -760,13 +762,15 @@ def _focal_var_cuda(data, kernel, out):
760762
for h in range(kernel.shape[1]):
761763
w = kernel[k, h]
762764
if w == 0:
763-
continue # mask says ignore
765+
continue
764766

765767
ii = i + k - dr
766768
jj = j + h - dc
767769

768770
if 0 <= ii < rows and 0 <= jj < cols:
769771
x = data[ii, jj]
772+
if x != x: # NaN check
773+
continue
770774
w_sum += w
771775
sum_wx += w * x
772776
sum_wx2 += w * x * x
@@ -775,13 +779,12 @@ def _focal_var_cuda(data, kernel, out):
775779
mean = sum_wx / w_sum
776780
var = (sum_wx2 / w_sum) - (mean * mean)
777781

778-
# numerical guard for tiny negative due to float rounding
779782
if var < 0.0:
780783
var = 0.0
781784

782785
out[i, j] = var
783786
else:
784-
out[i, j] = 0.0
787+
out[i, j] = math.nan
785788

786789

787790
@cuda.jit
@@ -856,15 +859,18 @@ def _focal_sum_cuda(data, kernel, out):
856859
for h in range(kernel.shape[1]):
857860
w = kernel[k, h]
858861
if w == 0:
859-
continue # mask says ignore
862+
continue
860863

861864
ii = i + k - dr
862865
jj = j + h - dc
863866

864867
if 0 <= ii < rows and 0 <= jj < cols:
865-
s += w * data[ii, jj]
868+
v = data[ii, jj]
869+
if v != v: # NaN check
870+
continue
871+
s += w * v
866872

867-
out[i, j] = s
873+
out[i, j] = s # nansum: 0 when all NaN (matches numpy)
868874

869875

870876
def _focal_stats_func_cupy(data, kernel, func=_focal_max_cuda):
@@ -894,21 +900,22 @@ def _focal_mean_cuda(data, kernel, out):
894900
for h in range(kernel.shape[1]):
895901
w = kernel[k, h]
896902
if w == 0:
897-
continue # mask says ignore
903+
continue
898904

899905
ii = i + k - dr
900906
jj = j + h - dc
901907

902908
if 0 <= ii < rows and 0 <= jj < cols:
903-
s += w * data[ii, jj]
909+
v = data[ii, jj]
910+
if v != v: # NaN check
911+
continue
912+
s += w * v
904913
w_sum += w
905914

906-
# With your mask including the center, w_sum should be > 0.
907-
# Guard anyway to avoid divide-by-zero.
908915
if w_sum > 0.0:
909916
out[i, j] = s / w_sum
910917
else:
911-
out[i, j] = data[i, j]
918+
out[i, j] = math.nan
912919

913920

914921
def _focal_stats_cupy(agg, kernel, stats_funcs):

xrspatial/tests/test_focal.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,108 @@ def test_focal_stats_dask_cupy():
505505
equal_nan=True, rtol=1e-4)
506506

507507

508+
# --- focal_stats NaN handling (issue-1092) --------------------------------
509+
510+
@pytest.mark.parametrize("backend", ['numpy', 'cupy', 'dask+numpy', 'dask+cupy'])
511+
def test_focal_stats_nan_handling_1092(backend):
512+
"""All backends should skip NaN neighbors, not propagate them.
513+
514+
Regression test for #1092: CUDA kernels propagated NaN through
515+
arithmetic instead of skipping.
516+
"""
517+
from xrspatial.tests.general_checks import has_cuda_and_cupy
518+
if 'cupy' in backend and not has_cuda_and_cupy():
519+
pytest.skip("Requires CUDA and CuPy")
520+
if 'dask' in backend and da is None:
521+
pytest.skip("Requires Dask")
522+
523+
data = np.array([
524+
[1.0, np.nan, 3.0],
525+
[4.0, 5.0, 6.0],
526+
[7.0, 8.0, 9.0],
527+
])
528+
kernel = custom_kernel(np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]]))
529+
530+
agg = create_test_raster(data, backend=backend, chunks=(3, 3))
531+
result = focal_stats(agg, kernel,
532+
stats_funcs=['mean', 'sum', 'min', 'max', 'std', 'var', 'range'])
533+
534+
if hasattr(result.data, 'compute'):
535+
result = result.compute()
536+
537+
def _val(stat, r, c):
538+
d = result.sel(stats=stat).data
539+
if hasattr(d, 'get'):
540+
d = d.get()
541+
return float(np.asarray(d)[r, c])
542+
543+
# Center pixel (1,1): kernel hits [NaN, 4, 5, 6, 8] -> skip NaN -> [4,5,6,8]
544+
center_vals = np.array([4.0, 5.0, 6.0, 8.0])
545+
atol = 1e-3 # float32 tolerance
546+
547+
mean_val = _val('mean', 1, 1)
548+
sum_val = _val('sum', 1, 1)
549+
min_val = _val('min', 1, 1)
550+
max_val = _val('max', 1, 1)
551+
std_val = _val('std', 1, 1)
552+
var_val = _val('var', 1, 1)
553+
range_val = _val('range', 1, 1)
554+
555+
assert abs(mean_val - np.nanmean(center_vals)) < atol, f"mean={mean_val}"
556+
assert abs(sum_val - np.nansum(center_vals)) < atol, f"sum={sum_val}"
557+
assert abs(min_val - np.nanmin(center_vals)) < atol, f"min={min_val}"
558+
assert abs(max_val - np.nanmax(center_vals)) < atol, f"max={max_val}"
559+
assert abs(std_val - np.nanstd(center_vals)) < atol, f"std={std_val}"
560+
assert abs(var_val - np.nanvar(center_vals)) < atol, f"var={var_val}"
561+
assert abs(range_val - (np.nanmax(center_vals) - np.nanmin(center_vals))) < atol, (
562+
f"range={range_val}"
563+
)
564+
565+
# Top-left corner (0,0): kernel hits [NaN, 4, 1] (cross pattern)
566+
# NaN is from data[0,1] (up direction is OOB, left is OOB)
567+
# Wait: the cross kernel at (0,0) covers:
568+
# up=(-1,0)=OOB, down=(1,0)=4, left=(0,-1)=OOB, right=(0,1)=NaN, center=(0,0)=1
569+
# So valid values = [1, 4], NaN is skipped
570+
corner_vals = np.array([1.0, 4.0])
571+
mean_corner = _val('mean', 0, 0)
572+
assert abs(mean_corner - np.nanmean(corner_vals)) < atol, (
573+
f"corner mean={mean_corner}"
574+
)
575+
576+
577+
@pytest.mark.parametrize("backend", ['numpy', 'cupy'])
578+
def test_focal_stats_all_nan_window_1092(backend):
579+
"""A pixel whose entire kernel window is NaN should produce NaN."""
580+
from xrspatial.tests.general_checks import has_cuda_and_cupy
581+
if 'cupy' in backend and not has_cuda_and_cupy():
582+
pytest.skip("Requires CUDA and CuPy")
583+
584+
data = np.array([
585+
[np.nan, np.nan, np.nan],
586+
[np.nan, np.nan, np.nan],
587+
[np.nan, np.nan, 1.0],
588+
])
589+
kernel = custom_kernel(np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]]))
590+
591+
agg = create_test_raster(data, backend=backend)
592+
result = focal_stats(agg, kernel, stats_funcs=['mean', 'sum', 'min', 'max'])
593+
594+
if hasattr(result.data, 'compute'):
595+
result = result.compute()
596+
597+
def _val(stat, r, c):
598+
d = result.sel(stats=stat).data
599+
if hasattr(d, 'get'):
600+
d = d.get()
601+
return float(np.asarray(d)[r, c])
602+
603+
# Center pixel (1,1): kernel hits [NaN, NaN, NaN, NaN, NaN] -> all NaN
604+
assert np.isnan(_val('mean', 1, 1))
605+
assert _val('sum', 1, 1) == 0.0 # nansum of all-NaN = 0 (numpy behavior)
606+
assert np.isnan(_val('min', 1, 1))
607+
assert np.isnan(_val('max', 1, 1))
608+
609+
508610
# --- focal variety (issue-1040) ------------------------------------------
509611

510612
def _variety_reference_data():

0 commit comments

Comments
 (0)