Skip to content

Commit a67aa8e

Browse files
committed
Add tests and fix block-level M2 for variance stability (#1090)
- Block-level sum_squares now computes M2 (sum of squared deviations from block mean) instead of raw sum(x²), avoiding float64 precision loss for large values. - Updated test_stats_all_nan_zone and test_stats_nodata_wipes_zone to expect NaN from dask (no longer 0). - Added test_stats_variance_numerical_stability_1090: values near 1e8 with spread of 1, verifying dask matches numpy to 1e-6. - Added test_stats_nodata_none_no_warning_1090: confirms no FutureWarning when nodata_values=None.
1 parent 74bfdc1 commit a67aa8e

File tree

2 files changed

+73
-29
lines changed

2 files changed

+73
-29
lines changed

xrspatial/tests/test_zonal.py

Lines changed: 67 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -680,18 +680,8 @@ def test_stats_all_nan_zone(backend):
680680
'sum': [12.0],
681681
'count': [2],
682682
}
683-
elif 'dask' in backend:
684-
# dask uses nansum reduction, so count/sum of all-NaN become 0
685-
expected = {
686-
'zone': [1, 2],
687-
'mean': [np.nan, 6.0],
688-
'max': [np.nan, 7.0],
689-
'min': [np.nan, 5.0],
690-
'sum': [0.0, 12.0],
691-
'count': [0, 2],
692-
}
693683
else:
694-
# numpy keeps empty zone with NaN for every stat
684+
# numpy and dask both return NaN for all-NaN zones
695685
expected = {
696686
'zone': [1, 2],
697687
'mean': [np.nan, 6.0],
@@ -798,16 +788,8 @@ def test_stats_nodata_wipes_zone(backend):
798788
'sum': [10.0],
799789
'count': [2],
800790
}
801-
elif 'dask' in backend:
802-
expected = {
803-
'zone': [1, 2],
804-
'mean': [np.nan, 5.0],
805-
'max': [np.nan, 7.0],
806-
'min': [np.nan, 3.0],
807-
'sum': [0.0, 10.0],
808-
'count': [0, 2],
809-
}
810791
else:
792+
# numpy and dask both return NaN for zones with no valid values
811793
expected = {
812794
'zone': [1, 2],
813795
'mean': [np.nan, 5.0],
@@ -868,6 +850,71 @@ def test_zonal_stats_inputs_unmodified(backend, data_zones, data_values_2d, resu
868850
assert_input_data_unmodified(data_values_2d, copied_data_values_2d)
869851

870852

853+
@pytest.mark.filterwarnings("ignore:All-NaN slice encountered:RuntimeWarning")
854+
@pytest.mark.filterwarnings("ignore:invalid value encountered in divide:RuntimeWarning")
855+
@pytest.mark.parametrize("backend", ['numpy', 'dask+numpy'])
856+
def test_stats_variance_numerical_stability_1090(backend):
857+
"""Dask std/var should match numpy for data with large mean, small spread.
858+
859+
Regression test for #1090: the naive one-pass formula
860+
``(Σx² − (Σx)²/n) / n`` loses precision through catastrophic
861+
cancellation. The fix uses Chan-Golub-LeVeque parallel merge.
862+
"""
863+
if 'dask' in backend and not dask_array_available():
864+
pytest.skip("Requires Dask")
865+
866+
# Values near 1e8 with a spread of 1: the naive formula would lose
867+
# most of the significant digits in float64.
868+
zones_data = np.array([[1, 1, 1, 1, 1, 1]])
869+
values_data = np.array([[1e8, 1e8 + 1, 1e8 + 2,
870+
1e8 + 3, 1e8 + 4, 1e8 + 5]], dtype=np.float64)
871+
872+
zones = create_test_raster(zones_data, backend, chunks=(1, 3))
873+
values = create_test_raster(values_data, backend, chunks=(1, 3))
874+
875+
df_result = stats(zones=zones, values=values,
876+
stats_funcs=['mean', 'std', 'var'])
877+
878+
if hasattr(df_result, 'compute'):
879+
df_result = df_result.compute()
880+
881+
# Reference: population variance of [0,1,2,3,4,5] = 35/12 ≈ 2.9167
882+
expected_var = np.var(np.arange(6, dtype=np.float64))
883+
expected_std = np.std(np.arange(6, dtype=np.float64))
884+
885+
actual_var = float(df_result['var'].iloc[0])
886+
actual_std = float(df_result['std'].iloc[0])
887+
888+
assert abs(actual_var - expected_var) < 1e-6, (
889+
f"var={actual_var}, expected={expected_var}"
890+
)
891+
assert abs(actual_std - expected_std) < 1e-6, (
892+
f"std={actual_std}, expected={expected_std}"
893+
)
894+
895+
896+
def test_stats_nodata_none_no_warning_1090():
897+
"""Passing nodata_values=None (the default) should not trigger warnings.
898+
899+
Regression test for #1090: ``zone_values != None`` triggered a numpy
900+
FutureWarning.
901+
"""
902+
import warnings
903+
904+
zones_data = np.array([[1, 1], [2, 2]], dtype=float)
905+
values_data = np.array([[1.0, 2.0], [3.0, 4.0]])
906+
zones = xr.DataArray(zones_data)
907+
values = xr.DataArray(values_data)
908+
909+
with warnings.catch_warnings():
910+
warnings.simplefilter("error")
911+
df = stats(zones=zones, values=values, nodata_values=None)
912+
913+
assert len(df) == 2
914+
assert float(df['mean'].iloc[0]) == 1.5
915+
assert float(df['mean'].iloc[1]) == 3.5
916+
917+
871918
@pytest.mark.filterwarnings("ignore:All-NaN slice encountered:RuntimeWarning")
872919
@pytest.mark.filterwarnings("ignore:invalid value encountered in divide:RuntimeWarning")
873920
@pytest.mark.parametrize("backend", ['numpy', 'dask+numpy'])

xrspatial/zonal.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def _stats_majority(data):
198198
min=lambda z: z.min(),
199199
sum=lambda z: z.sum(),
200200
count=lambda z: _stats_count(z),
201-
sum_squares=lambda z: (z**2).sum()
201+
sum_squares=lambda z: ((z - z.mean()) ** 2).sum() # block-level M2
202202
)
203203

204204

@@ -227,15 +227,13 @@ def _dask_mean(sums, counts): # noqa
227227
return sums / counts
228228

229229

230-
def _parallel_variance(block_counts, block_sums, block_sum_squares):
230+
def _parallel_variance(block_counts, block_sums, block_m2s):
231231
"""Population variance via Chan-Golub-LeVeque parallel merge.
232232
233-
Each input is (n_blocks, n_zones). Returns (n_zones,) variance,
233+
Each input is (n_blocks, n_zones). ``block_m2s`` contains
234+
per-block M2 values (sum of squared deviations from the block mean),
235+
NOT raw sum-of-squares. Returns (n_zones,) population variance,
234236
with NaN for zones that have no valid values in any block.
235-
236-
This avoids the naive ``(Σx² − (Σx)²/n) / n`` formula whose
237-
subtraction can lose most significant digits when the mean is
238-
large relative to the standard deviation.
239237
"""
240238
n_blocks = block_counts.shape[0]
241239
n_zones = block_counts.shape[1]
@@ -247,14 +245,13 @@ def _parallel_variance(block_counts, block_sums, block_sum_squares):
247245
for i in range(n_blocks):
248246
nc = np.asarray(block_counts[i], dtype=np.float64)
249247
sc = np.asarray(block_sums[i], dtype=np.float64)
250-
sqc = np.asarray(block_sum_squares[i], dtype=np.float64)
248+
m2_b = np.asarray(block_m2s[i], dtype=np.float64)
251249

252250
has_data = np.isfinite(nc) & (nc > 0)
253251
nc_safe = np.where(has_data, nc, 1.0) # avoid /0
254252

255253
with np.errstate(invalid='ignore', divide='ignore'):
256254
mean_b = sc / nc_safe
257-
m2_b = sqc - sc ** 2 / nc_safe # block-internal M2
258255

259256
nc = np.where(has_data, nc, 0.0)
260257
n_ab = n_acc + nc

0 commit comments

Comments
 (0)