Skip to content

Commit b627c89

Browse files
committed
Add NaN-skipping to focal_stats CUDA kernels (#1092)
All focal_stats CUDA kernels (_focal_mean_cuda, _focal_sum_cuda, _focal_std_cuda, _focal_var_cuda, _focal_range_cuda, _focal_min_cuda, _focal_max_cuda) now skip NaN neighbors with `if v != v: continue`, matching the numpy path which uses np.nanmean/nansum/nanstd/etc. Previously, NaN propagated through arithmetic, giving different results on GPU vs CPU when input contained NaN.
1 parent 443ed78 commit b627c89

File tree

1 file changed

+32
-23
lines changed

1 file changed

+32
-23
lines changed

xrspatial/focal.py

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -608,13 +608,13 @@ def _focal_min_cuda(data, kernel, out):
608608

609609
if 0 <= ii < rows and 0 <= jj < cols:
610610
v = data[ii, jj]
611+
if v != v: # NaN check
612+
continue
611613
if (not found) or (v < m):
612614
m = v
613615
found = True
614616

615-
# With your mask containing the center, found should be True.
616-
# But keep a safe fallback anyway.
617-
out[i, j] = m if found else data[i, j]
617+
out[i, j] = m if found else math.nan
618618

619619

620620
@cuda.jit
@@ -636,20 +636,20 @@ def _focal_max_cuda(data, kernel, out):
636636
for h in range(kernel.shape[1]):
637637
w = kernel[k, h]
638638
if w == 0:
639-
continue # mask says "ignore this neighbor"
639+
continue
640640

641641
ii = i + k - dr
642642
jj = j + h - dc
643643

644644
if 0 <= ii < rows and 0 <= jj < cols:
645645
v = data[ii, jj]
646+
if v != v: # NaN check
647+
continue
646648
if (not found) or (v > m):
647649
m = v
648650
found = True
649651

650-
# With your mask containing the center (1), found should always be True.
651-
# But keep this for safety.
652-
out[i, j] = m if found else data[i, j]
652+
out[i, j] = m if found else math.nan
653653

654654

655655
def _focal_range_cupy(data, kernel):
@@ -684,6 +684,8 @@ def _focal_range_cuda(data, kernel, out):
684684

685685
if 0 <= ii < rows and 0 <= jj < cols:
686686
v = data[ii, jj]
687+
if v != v: # NaN check
688+
continue
687689
if not found:
688690
mx = v
689691
mn = v
@@ -694,7 +696,7 @@ def _focal_range_cuda(data, kernel, out):
694696
if v < mn:
695697
mn = v
696698

697-
out[i, j] = (mx - mn) if found else 0.0
699+
out[i, j] = (mx - mn) if found else math.nan
698700

699701

700702
@cuda.jit
@@ -716,29 +718,29 @@ def _focal_std_cuda(data, kernel, out):
716718
for h in range(kernel.shape[1]):
717719
w = kernel[k, h]
718720
if w == 0:
719-
continue # mask says ignore
721+
continue
720722

721723
ii = i + k - dr
722724
jj = j + h - dc
723725

724726
if 0 <= ii < rows and 0 <= jj < cols:
725727
x = data[ii, jj]
728+
if x != x: # NaN check
729+
continue
726730
w_sum += w
727731
sum_wx += w * x
728732
sum_wx2 += w * x * x
729733

730-
# With your mask including the center, w_sum should be > 0. Guard anyway.
731734
if w_sum > 0.0:
732735
mean = sum_wx / w_sum
733736
var = (sum_wx2 / w_sum) - (mean * mean)
734737

735-
# Numerical safety (tiny negative due to floating point)
736738
if var < 0.0:
737739
var = 0.0
738740

739741
out[i, j] = math.sqrt(var)
740742
else:
741-
out[i, j] = 0.0
743+
out[i, j] = math.nan
742744

743745

744746
@cuda.jit
@@ -760,13 +762,15 @@ def _focal_var_cuda(data, kernel, out):
760762
for h in range(kernel.shape[1]):
761763
w = kernel[k, h]
762764
if w == 0:
763-
continue # mask says ignore
765+
continue
764766

765767
ii = i + k - dr
766768
jj = j + h - dc
767769

768770
if 0 <= ii < rows and 0 <= jj < cols:
769771
x = data[ii, jj]
772+
if x != x: # NaN check
773+
continue
770774
w_sum += w
771775
sum_wx += w * x
772776
sum_wx2 += w * x * x
@@ -775,13 +779,12 @@ def _focal_var_cuda(data, kernel, out):
775779
mean = sum_wx / w_sum
776780
var = (sum_wx2 / w_sum) - (mean * mean)
777781

778-
# numerical guard for tiny negative due to float rounding
779782
if var < 0.0:
780783
var = 0.0
781784

782785
out[i, j] = var
783786
else:
784-
out[i, j] = 0.0
787+
out[i, j] = math.nan
785788

786789

787790
@cuda.jit
@@ -852,19 +855,24 @@ def _focal_sum_cuda(data, kernel, out):
852855
dc = kernel.shape[1] // 2
853856

854857
s = 0.0
858+
found = False
855859
for k in range(kernel.shape[0]):
856860
for h in range(kernel.shape[1]):
857861
w = kernel[k, h]
858862
if w == 0:
859-
continue # mask says ignore
863+
continue
860864

861865
ii = i + k - dr
862866
jj = j + h - dc
863867

864868
if 0 <= ii < rows and 0 <= jj < cols:
865-
s += w * data[ii, jj]
869+
v = data[ii, jj]
870+
if v != v: # NaN check
871+
continue
872+
s += w * v
873+
found = True
866874

867-
out[i, j] = s
875+
out[i, j] = s if found else math.nan
868876

869877

870878
def _focal_stats_func_cupy(data, kernel, func=_focal_max_cuda):
@@ -894,21 +902,22 @@ def _focal_mean_cuda(data, kernel, out):
894902
for h in range(kernel.shape[1]):
895903
w = kernel[k, h]
896904
if w == 0:
897-
continue # mask says ignore
905+
continue
898906

899907
ii = i + k - dr
900908
jj = j + h - dc
901909

902910
if 0 <= ii < rows and 0 <= jj < cols:
903-
s += w * data[ii, jj]
911+
v = data[ii, jj]
912+
if v != v: # NaN check
913+
continue
914+
s += w * v
904915
w_sum += w
905916

906-
# With your mask including the center, w_sum should be > 0.
907-
# Guard anyway to avoid divide-by-zero.
908917
if w_sum > 0.0:
909918
out[i, j] = s / w_sum
910919
else:
911-
out[i, j] = data[i, j]
920+
out[i, j] = math.nan
912921

913922

914923
def _focal_stats_cupy(agg, kernel, stats_funcs):

0 commit comments

Comments
 (0)