Skip to content

Commit 74bfdc1

Browse files
committed
Fix three accuracy bugs in zonal stats dask backend (#1090)

1. Dask sum/count/min/max now return NaN (not 0) for zones whose values are all NaN, matching the numpy backend. This is done via a `_nanreduce_preserve_allnan` wrapper around `np.nansum`/`np.nanmax`/`np.nanmin`.
2. Dask std/var: replaced the naive one-pass formula with the Chan–Golub–LeVeque parallel merge algorithm, which avoids catastrophic cancellation when the mean is large relative to the variance.
3. `_calc_stats` and the crosstab helpers now skip the `!= nodata_values` comparison when `nodata_values` is None, avoiding a numpy FutureWarning.
1 parent 443ed78 commit 74bfdc1
1 parent 443ed78 commit 74bfdc1

File tree

1 file changed

+103
-25
lines changed

1 file changed

+103
-25
lines changed

xrspatial/zonal.py

Lines changed: 103 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -202,17 +202,76 @@ def _stats_majority(data):
202202
)
203203

204204

205+
def _nanreduce_preserve_allnan(blocks, func):
206+
"""Reduce across blocks, returning NaN when ALL blocks are NaN for a zone.
207+
208+
``np.nansum`` returns 0 for all-NaN input; we want NaN so that zones
209+
with no valid values propagate NaN, consistent with the numpy backend.
210+
"""
211+
result = func(blocks, axis=0)
212+
all_nan = np.all(np.isnan(blocks), axis=0)
213+
result[all_nan] = np.nan
214+
return result
215+
216+
# Per-statistic reducers applied across the stacked (n_blocks, n_zones)
# block results. ``count`` deliberately shares the all-NaN-preserving
# nansum with ``sum`` so that zones with no valid values in any block
# report NaN rather than 0, matching the numpy backend.
_DASK_STATS = {
    'max': lambda stacked: _nanreduce_preserve_allnan(stacked, np.nanmax),
    'min': lambda stacked: _nanreduce_preserve_allnan(stacked, np.nanmin),
    'sum': lambda stacked: _nanreduce_preserve_allnan(stacked, np.nansum),
    'count': lambda stacked: _nanreduce_preserve_allnan(stacked, np.nansum),
    'sum_squares': lambda stacked: _nanreduce_preserve_allnan(stacked, np.nansum),
}
213-
def _dask_mean(sums, counts): return sums / counts # noqa
214-
def _dask_std(sum_squares, squared_sum, n): return np.sqrt((sum_squares - squared_sum/n) / n) # noqa
215-
def _dask_var(sum_squares, squared_sum, n): return (sum_squares - squared_sum/n) / n # noqa
224+
225+
226+
def _dask_mean(sums, counts): # noqa
227+
return sums / counts
228+
229+
230+
def _parallel_variance(block_counts, block_sums, block_sum_squares):
231+
"""Population variance via Chan-Golub-LeVeque parallel merge.
232+
233+
Each input is (n_blocks, n_zones). Returns (n_zones,) variance,
234+
with NaN for zones that have no valid values in any block.
235+
236+
This avoids the naive ``(Σx² − (Σx)²/n) / n`` formula whose
237+
subtraction can lose most significant digits when the mean is
238+
large relative to the standard deviation.
239+
"""
240+
n_blocks = block_counts.shape[0]
241+
n_zones = block_counts.shape[1]
242+
243+
n_acc = np.zeros(n_zones, dtype=np.float64)
244+
mean_acc = np.zeros(n_zones, dtype=np.float64)
245+
m2_acc = np.zeros(n_zones, dtype=np.float64)
246+
247+
for i in range(n_blocks):
248+
nc = np.asarray(block_counts[i], dtype=np.float64)
249+
sc = np.asarray(block_sums[i], dtype=np.float64)
250+
sqc = np.asarray(block_sum_squares[i], dtype=np.float64)
251+
252+
has_data = np.isfinite(nc) & (nc > 0)
253+
nc_safe = np.where(has_data, nc, 1.0) # avoid /0
254+
255+
with np.errstate(invalid='ignore', divide='ignore'):
256+
mean_b = sc / nc_safe
257+
m2_b = sqc - sc ** 2 / nc_safe # block-internal M2
258+
259+
nc = np.where(has_data, nc, 0.0)
260+
n_ab = n_acc + nc
261+
262+
delta = mean_b - mean_acc
263+
with np.errstate(invalid='ignore', divide='ignore'):
264+
n_ab_safe = np.where(n_ab > 0, n_ab, 1.0)
265+
correction = delta ** 2 * n_acc * nc / n_ab_safe
266+
new_mean = mean_acc + delta * nc / n_ab_safe
267+
268+
m2_acc = np.where(has_data, m2_acc + m2_b + correction, m2_acc)
269+
mean_acc = np.where(has_data, new_mean, mean_acc)
270+
n_acc = np.where(has_data, n_ab, n_acc)
271+
272+
with np.errstate(invalid='ignore', divide='ignore'):
273+
var = np.where(n_acc > 0, m2_acc / n_acc, np.nan)
274+
return var
216275

217276

218277
@ngjit
@@ -269,7 +328,10 @@ def _calc_stats(
269328
if unique_zones[i] in zone_ids:
270329
zone_values = values_by_zones[start:end]
271330
# filter out non-finite and nodata_values
272-
zone_values = zone_values[np.isfinite(zone_values) & (zone_values != nodata_values)]
331+
mask = np.isfinite(zone_values)
332+
if nodata_values is not None:
333+
mask = mask & (zone_values != nodata_values)
334+
zone_values = zone_values[mask]
273335
if len(zone_values) > 0:
274336
results[i] = func(zone_values)
275337
start = end
@@ -342,9 +404,11 @@ def _stats_dask_numpy(
342404
sum=values.dtype,
343405
count=np.int64,
344406
sum_squares=values.dtype,
345-
squared_sum=values.dtype,
346407
)
347408

409+
# Keep per-block stacked arrays for the parallel variance merge
410+
stacked_blocks = {}
411+
348412
for s in basis_stats:
349413
if s == 'sum_squares' and not compute_sum_squares:
350414
continue
@@ -358,21 +422,34 @@ def _stats_dask_numpy(
358422
for z, v in zip(zones_blocks, values_blocks)
359423
]
360424
zonal_stats = da.stack(stats_by_block, allow_unknown_chunksizes=True)
425+
426+
if compute_sum_squares and s in ('count', 'sum', 'sum_squares'):
427+
stacked_blocks[s] = zonal_stats
428+
361429
stats_func_by_block = delayed(_DASK_STATS[s])
362430
stats_dict[s] = da.from_delayed(
363431
stats_func_by_block(zonal_stats), shape=(np.nan,), dtype=np.float64
364432
)
365433

366434
if 'mean' in stats_funcs:
367435
stats_dict['mean'] = _dask_mean(stats_dict['sum'], stats_dict['count'])
368-
if 'std' in stats_funcs:
369-
stats_dict['std'] = _dask_std(
370-
stats_dict['sum_squares'], stats_dict['sum'] ** 2, stats_dict['count']
371-
)
372-
if 'var' in stats_funcs:
373-
stats_dict['var'] = _dask_var(
374-
stats_dict['sum_squares'], stats_dict['sum'] ** 2, stats_dict['count']
436+
437+
if 'std' in stats_funcs or 'var' in stats_funcs:
438+
var_result = da.from_delayed(
439+
delayed(_parallel_variance)(
440+
stacked_blocks['count'],
441+
stacked_blocks['sum'],
442+
stacked_blocks['sum_squares'],
443+
),
444+
shape=(np.nan,), dtype=np.float64,
375445
)
446+
if 'var' in stats_funcs:
447+
stats_dict['var'] = var_result
448+
if 'std' in stats_funcs:
449+
stats_dict['std'] = da.from_delayed(
450+
delayed(np.sqrt)(var_result),
451+
shape=(np.nan,), dtype=np.float64,
452+
)
376453

377454
# generate dask dataframe
378455
stats_df = dd.concat([dd.from_dask_array(s) for s in stats_dict.values()], axis=1, ignore_unknown_divisions=True)
@@ -846,9 +923,10 @@ def _single_zone_crosstab_2d(
846923
):
847924
# 1D flatten zone_values, i.e, original data is 2D
848925
# filter out non-finite and nodata_values
849-
zone_values = zone_values[
850-
np.isfinite(zone_values) & (zone_values != nodata_values)
851-
]
926+
mask = np.isfinite(zone_values)
927+
if nodata_values is not None:
928+
mask = mask & (zone_values != nodata_values)
929+
zone_values = zone_values[mask]
852930
total_count = zone_values.shape[0]
853931
crosstab_dict[TOTAL_COUNT].append(total_count)
854932

@@ -877,10 +955,10 @@ def _single_zone_crosstab_3d(
877955
if cat in cat_ids:
878956
zone_cat_data = zone_values[j]
879957
# filter out non-finite and nodata_values
880-
zone_cat_data = zone_cat_data[
881-
np.isfinite(zone_cat_data)
882-
& (zone_cat_data != nodata_values)
883-
]
958+
cat_mask = np.isfinite(zone_cat_data)
959+
if nodata_values is not None:
960+
cat_mask = cat_mask & (zone_cat_data != nodata_values)
961+
zone_cat_data = zone_cat_data[cat_mask]
884962
crosstab_dict[cat].append(stats_func(zone_cat_data))
885963

886964

0 commit comments

Comments
 (0)