Skip to content

Commit 65b354f

Browse files
authored
Fix three accuracy bugs in zonal stats dask backend (#1090) (#1091)
* Fix three accuracy bugs in zonal stats dask backend (#1090) 1. Dask sum/count/min/max now consistently return NaN for zones with all-NaN values (previously sum and count returned 0, while min/max already propagated NaN), matching the numpy backend. Uses _nanreduce_preserve_allnan wrapper around np.nansum/nanmax/nanmin. 2. Dask std/var replaced the naive one-pass formula with the Chan-Golub-LeVeque parallel merge algorithm, which avoids catastrophic cancellation when the mean is large relative to the variance. 3. _calc_stats and crosstab helpers now skip the nodata_values != comparison when nodata_values is None, avoiding numpy FutureWarning. * Add tests and fix block-level M2 for variance stability (#1090) - Block-level sum_squares now computes M2 (sum of squared deviations from block mean) instead of raw sum(x²), avoiding float64 precision loss for large values. - Updated test_stats_all_nan_zone and test_stats_nodata_wipes_zone to expect NaN from dask (no longer 0). - Added test_stats_variance_numerical_stability_1090: values near 1e8 with spread of 1, verifying dask matches numpy to 1e-6. - Added test_stats_nodata_none_no_warning_1090: confirms no FutureWarning when nodata_values=None.
1 parent e00a52a commit 65b354f

File tree

2 files changed

+168
-46
lines changed

2 files changed

+168
-46
lines changed

xrspatial/tests/test_zonal.py

Lines changed: 67 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -680,18 +680,8 @@ def test_stats_all_nan_zone(backend):
680680
'sum': [12.0],
681681
'count': [2],
682682
}
683-
elif 'dask' in backend:
684-
# dask uses nansum reduction, so count/sum of all-NaN become 0
685-
expected = {
686-
'zone': [1, 2],
687-
'mean': [np.nan, 6.0],
688-
'max': [np.nan, 7.0],
689-
'min': [np.nan, 5.0],
690-
'sum': [0.0, 12.0],
691-
'count': [0, 2],
692-
}
693683
else:
694-
# numpy keeps empty zone with NaN for every stat
684+
# numpy and dask both return NaN for all-NaN zones
695685
expected = {
696686
'zone': [1, 2],
697687
'mean': [np.nan, 6.0],
@@ -798,16 +788,8 @@ def test_stats_nodata_wipes_zone(backend):
798788
'sum': [10.0],
799789
'count': [2],
800790
}
801-
elif 'dask' in backend:
802-
expected = {
803-
'zone': [1, 2],
804-
'mean': [np.nan, 5.0],
805-
'max': [np.nan, 7.0],
806-
'min': [np.nan, 3.0],
807-
'sum': [0.0, 10.0],
808-
'count': [0, 2],
809-
}
810791
else:
792+
# numpy and dask both return NaN for zones with no valid values
811793
expected = {
812794
'zone': [1, 2],
813795
'mean': [np.nan, 5.0],
@@ -868,6 +850,71 @@ def test_zonal_stats_inputs_unmodified(backend, data_zones, data_values_2d, resu
868850
assert_input_data_unmodified(data_values_2d, copied_data_values_2d)
869851

870852

853+
@pytest.mark.filterwarnings("ignore:All-NaN slice encountered:RuntimeWarning")
854+
@pytest.mark.filterwarnings("ignore:invalid value encountered in divide:RuntimeWarning")
855+
@pytest.mark.parametrize("backend", ['numpy', 'dask+numpy'])
856+
def test_stats_variance_numerical_stability_1090(backend):
857+
"""Dask std/var should match numpy for data with large mean, small spread.
858+
859+
Regression test for #1090: the naive one-pass formula
860+
``(Σx² − (Σx)²/n) / n`` loses precision through catastrophic
861+
cancellation. The fix uses Chan-Golub-LeVeque parallel merge.
862+
"""
863+
if 'dask' in backend and not dask_array_available():
864+
pytest.skip("Requires Dask")
865+
866+
# Values near 1e8 with a spread of 1: the naive formula would lose
867+
# most of the significant digits in float64.
868+
zones_data = np.array([[1, 1, 1, 1, 1, 1]])
869+
values_data = np.array([[1e8, 1e8 + 1, 1e8 + 2,
870+
1e8 + 3, 1e8 + 4, 1e8 + 5]], dtype=np.float64)
871+
872+
zones = create_test_raster(zones_data, backend, chunks=(1, 3))
873+
values = create_test_raster(values_data, backend, chunks=(1, 3))
874+
875+
df_result = stats(zones=zones, values=values,
876+
stats_funcs=['mean', 'std', 'var'])
877+
878+
if hasattr(df_result, 'compute'):
879+
df_result = df_result.compute()
880+
881+
# Reference: population variance of [0,1,2,3,4,5] = 35/12 ≈ 2.9167
882+
expected_var = np.var(np.arange(6, dtype=np.float64))
883+
expected_std = np.std(np.arange(6, dtype=np.float64))
884+
885+
actual_var = float(df_result['var'].iloc[0])
886+
actual_std = float(df_result['std'].iloc[0])
887+
888+
assert abs(actual_var - expected_var) < 1e-6, (
889+
f"var={actual_var}, expected={expected_var}"
890+
)
891+
assert abs(actual_std - expected_std) < 1e-6, (
892+
f"std={actual_std}, expected={expected_std}"
893+
)
894+
895+
896+
def test_stats_nodata_none_no_warning_1090():
897+
"""Passing nodata_values=None (the default) should not trigger warnings.
898+
899+
Regression test for #1090: ``zone_values != None`` triggered a numpy
900+
FutureWarning.
901+
"""
902+
import warnings
903+
904+
zones_data = np.array([[1, 1], [2, 2]], dtype=float)
905+
values_data = np.array([[1.0, 2.0], [3.0, 4.0]])
906+
zones = xr.DataArray(zones_data)
907+
values = xr.DataArray(values_data)
908+
909+
with warnings.catch_warnings():
910+
warnings.simplefilter("error")
911+
df = stats(zones=zones, values=values, nodata_values=None)
912+
913+
assert len(df) == 2
914+
assert float(df['mean'].iloc[0]) == 1.5
915+
assert float(df['mean'].iloc[1]) == 3.5
916+
917+
871918
@pytest.mark.filterwarnings("ignore:All-NaN slice encountered:RuntimeWarning")
872919
@pytest.mark.filterwarnings("ignore:invalid value encountered in divide:RuntimeWarning")
873920
@pytest.mark.parametrize("backend", ['numpy', 'dask+numpy'])

xrspatial/zonal.py

Lines changed: 101 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -198,21 +198,77 @@ def _stats_majority(data):
198198
min=lambda z: z.min(),
199199
sum=lambda z: z.sum(),
200200
count=lambda z: _stats_count(z),
201-
sum_squares=lambda z: (z**2).sum()
201+
sum_squares=lambda z: ((z - z.mean()) ** 2).sum() # block-level M2
202202
)
203203

204204

205+
def _nanreduce_preserve_allnan(blocks, func):
206+
"""Reduce across blocks, returning NaN when ALL blocks are NaN for a zone.
207+
208+
``np.nansum`` returns 0 for all-NaN input; we want NaN so that zones
209+
with no valid values propagate NaN, consistent with the numpy backend.
210+
"""
211+
result = func(blocks, axis=0)
212+
all_nan = np.all(np.isnan(blocks), axis=0)
213+
result[all_nan] = np.nan
214+
return result
215+
216+
205217
_DASK_STATS = dict(
206-
max=lambda block_maxes: np.nanmax(block_maxes, axis=0),
207-
min=lambda block_mins: np.nanmin(block_mins, axis=0),
208-
sum=lambda block_sums: np.nansum(block_sums, axis=0),
209-
count=lambda block_counts: np.nansum(block_counts, axis=0),
210-
sum_squares=lambda block_sum_squares: np.nansum(block_sum_squares, axis=0),
211-
squared_sum=lambda block_sums: np.nansum(block_sums, axis=0)**2,
218+
max=lambda blocks: _nanreduce_preserve_allnan(blocks, np.nanmax),
219+
min=lambda blocks: _nanreduce_preserve_allnan(blocks, np.nanmin),
220+
sum=lambda blocks: _nanreduce_preserve_allnan(blocks, np.nansum),
221+
count=lambda blocks: _nanreduce_preserve_allnan(blocks, np.nansum),
222+
sum_squares=lambda blocks: _nanreduce_preserve_allnan(blocks, np.nansum),
212223
)
213-
def _dask_mean(sums, counts): return sums / counts # noqa
214-
def _dask_std(sum_squares, squared_sum, n): return np.sqrt((sum_squares - squared_sum/n) / n) # noqa
215-
def _dask_var(sum_squares, squared_sum, n): return (sum_squares - squared_sum/n) / n # noqa
224+
225+
226+
def _dask_mean(sums, counts): # noqa
227+
return sums / counts
228+
229+
230+
def _parallel_variance(block_counts, block_sums, block_m2s):
231+
"""Population variance via Chan-Golub-LeVeque parallel merge.
232+
233+
Each input is (n_blocks, n_zones). ``block_m2s`` contains
234+
per-block M2 values (sum of squared deviations from the block mean),
235+
NOT raw sum-of-squares. Returns (n_zones,) population variance,
236+
with NaN for zones that have no valid values in any block.
237+
"""
238+
n_blocks = block_counts.shape[0]
239+
n_zones = block_counts.shape[1]
240+
241+
n_acc = np.zeros(n_zones, dtype=np.float64)
242+
mean_acc = np.zeros(n_zones, dtype=np.float64)
243+
m2_acc = np.zeros(n_zones, dtype=np.float64)
244+
245+
for i in range(n_blocks):
246+
nc = np.asarray(block_counts[i], dtype=np.float64)
247+
sc = np.asarray(block_sums[i], dtype=np.float64)
248+
m2_b = np.asarray(block_m2s[i], dtype=np.float64)
249+
250+
has_data = np.isfinite(nc) & (nc > 0)
251+
nc_safe = np.where(has_data, nc, 1.0) # avoid /0
252+
253+
with np.errstate(invalid='ignore', divide='ignore'):
254+
mean_b = sc / nc_safe
255+
256+
nc = np.where(has_data, nc, 0.0)
257+
n_ab = n_acc + nc
258+
259+
delta = mean_b - mean_acc
260+
with np.errstate(invalid='ignore', divide='ignore'):
261+
n_ab_safe = np.where(n_ab > 0, n_ab, 1.0)
262+
correction = delta ** 2 * n_acc * nc / n_ab_safe
263+
new_mean = mean_acc + delta * nc / n_ab_safe
264+
265+
m2_acc = np.where(has_data, m2_acc + m2_b + correction, m2_acc)
266+
mean_acc = np.where(has_data, new_mean, mean_acc)
267+
n_acc = np.where(has_data, n_ab, n_acc)
268+
269+
with np.errstate(invalid='ignore', divide='ignore'):
270+
var = np.where(n_acc > 0, m2_acc / n_acc, np.nan)
271+
return var
216272

217273

218274
@ngjit
@@ -269,7 +325,10 @@ def _calc_stats(
269325
if unique_zones[i] in zone_ids:
270326
zone_values = values_by_zones[start:end]
271327
# filter out non-finite and nodata_values
272-
zone_values = zone_values[np.isfinite(zone_values) & (zone_values != nodata_values)]
328+
mask = np.isfinite(zone_values)
329+
if nodata_values is not None:
330+
mask = mask & (zone_values != nodata_values)
331+
zone_values = zone_values[mask]
273332
if len(zone_values) > 0:
274333
results[i] = func(zone_values)
275334
start = end
@@ -342,9 +401,11 @@ def _stats_dask_numpy(
342401
sum=values.dtype,
343402
count=np.int64,
344403
sum_squares=values.dtype,
345-
squared_sum=values.dtype,
346404
)
347405

406+
# Keep per-block stacked arrays for the parallel variance merge
407+
stacked_blocks = {}
408+
348409
for s in basis_stats:
349410
if s == 'sum_squares' and not compute_sum_squares:
350411
continue
@@ -358,21 +419,34 @@ def _stats_dask_numpy(
358419
for z, v in zip(zones_blocks, values_blocks)
359420
]
360421
zonal_stats = da.stack(stats_by_block, allow_unknown_chunksizes=True)
422+
423+
if compute_sum_squares and s in ('count', 'sum', 'sum_squares'):
424+
stacked_blocks[s] = zonal_stats
425+
361426
stats_func_by_block = delayed(_DASK_STATS[s])
362427
stats_dict[s] = da.from_delayed(
363428
stats_func_by_block(zonal_stats), shape=(np.nan,), dtype=np.float64
364429
)
365430

366431
if 'mean' in stats_funcs:
367432
stats_dict['mean'] = _dask_mean(stats_dict['sum'], stats_dict['count'])
368-
if 'std' in stats_funcs:
369-
stats_dict['std'] = _dask_std(
370-
stats_dict['sum_squares'], stats_dict['sum'] ** 2, stats_dict['count']
371-
)
372-
if 'var' in stats_funcs:
373-
stats_dict['var'] = _dask_var(
374-
stats_dict['sum_squares'], stats_dict['sum'] ** 2, stats_dict['count']
433+
434+
if 'std' in stats_funcs or 'var' in stats_funcs:
435+
var_result = da.from_delayed(
436+
delayed(_parallel_variance)(
437+
stacked_blocks['count'],
438+
stacked_blocks['sum'],
439+
stacked_blocks['sum_squares'],
440+
),
441+
shape=(np.nan,), dtype=np.float64,
375442
)
443+
if 'var' in stats_funcs:
444+
stats_dict['var'] = var_result
445+
if 'std' in stats_funcs:
446+
stats_dict['std'] = da.from_delayed(
447+
delayed(np.sqrt)(var_result),
448+
shape=(np.nan,), dtype=np.float64,
449+
)
376450

377451
# generate dask dataframe
378452
stats_df = dd.concat([dd.from_dask_array(s) for s in stats_dict.values()], axis=1, ignore_unknown_divisions=True)
@@ -846,9 +920,10 @@ def _single_zone_crosstab_2d(
846920
):
847921
# 1D flatten zone_values, i.e, original data is 2D
848922
# filter out non-finite and nodata_values
849-
zone_values = zone_values[
850-
np.isfinite(zone_values) & (zone_values != nodata_values)
851-
]
923+
mask = np.isfinite(zone_values)
924+
if nodata_values is not None:
925+
mask = mask & (zone_values != nodata_values)
926+
zone_values = zone_values[mask]
852927
total_count = zone_values.shape[0]
853928
crosstab_dict[TOTAL_COUNT].append(total_count)
854929

@@ -877,10 +952,10 @@ def _single_zone_crosstab_3d(
877952
if cat in cat_ids:
878953
zone_cat_data = zone_values[j]
879954
# filter out non-finite and nodata_values
880-
zone_cat_data = zone_cat_data[
881-
np.isfinite(zone_cat_data)
882-
& (zone_cat_data != nodata_values)
883-
]
955+
cat_mask = np.isfinite(zone_cat_data)
956+
if nodata_values is not None:
957+
cat_mask = cat_mask & (zone_cat_data != nodata_values)
958+
zone_cat_data = zone_cat_data[cat_mask]
884959
crosstab_dict[cat].append(stats_func(zone_cat_data))
885960

886961

0 commit comments

Comments
 (0)