@@ -202,17 +202,76 @@ def _stats_majority(data):
202202)
203203
204204
205+ def _nanreduce_preserve_allnan (blocks , func ):
206+ """Reduce across blocks, returning NaN when ALL blocks are NaN for a zone.
207+
208+ ``np.nansum`` returns 0 for all-NaN input; we want NaN so that zones
209+ with no valid values propagate NaN, consistent with the numpy backend.
210+ """
211+ result = func (blocks , axis = 0 )
212+ all_nan = np .all (np .isnan (blocks ), axis = 0 )
213+ result [all_nan ] = np .nan
214+ return result
215+
216+
# Reducers that merge per-block partial statistics (stacked along axis 0)
# into one value per zone. Each goes through _nanreduce_preserve_allnan so a
# zone that is NaN in every block stays NaN instead of collapsing to 0.
# Partial counts and partial sums of squares combine by summation, the same
# way partial sums do.
_DASK_STATS = dict(
    max=lambda blocks: _nanreduce_preserve_allnan(blocks, np.nanmax),
    min=lambda blocks: _nanreduce_preserve_allnan(blocks, np.nanmin),
    sum=lambda blocks: _nanreduce_preserve_allnan(blocks, np.nansum),
    count=lambda blocks: _nanreduce_preserve_allnan(blocks, np.nansum),
    sum_squares=lambda blocks: _nanreduce_preserve_allnan(blocks, np.nansum),
)
213- def _dask_mean (sums , counts ): return sums / counts # noqa
214- def _dask_std (sum_squares , squared_sum , n ): return np .sqrt ((sum_squares - squared_sum / n ) / n ) # noqa
215- def _dask_var (sum_squares , squared_sum , n ): return (sum_squares - squared_sum / n ) / n # noqa
224+
225+
226+ def _dask_mean (sums , counts ): # noqa
227+ return sums / counts
228+
229+
230+ def _parallel_variance (block_counts , block_sums , block_sum_squares ):
231+ """Population variance via Chan-Golub-LeVeque parallel merge.
232+
233+ Each input is (n_blocks, n_zones). Returns (n_zones,) variance,
234+ with NaN for zones that have no valid values in any block.
235+
236+ This avoids the naive ``(Σx² − (Σx)²/n) / n`` formula whose
237+ subtraction can lose most significant digits when the mean is
238+ large relative to the standard deviation.
239+ """
240+ n_blocks = block_counts .shape [0 ]
241+ n_zones = block_counts .shape [1 ]
242+
243+ n_acc = np .zeros (n_zones , dtype = np .float64 )
244+ mean_acc = np .zeros (n_zones , dtype = np .float64 )
245+ m2_acc = np .zeros (n_zones , dtype = np .float64 )
246+
247+ for i in range (n_blocks ):
248+ nc = np .asarray (block_counts [i ], dtype = np .float64 )
249+ sc = np .asarray (block_sums [i ], dtype = np .float64 )
250+ sqc = np .asarray (block_sum_squares [i ], dtype = np .float64 )
251+
252+ has_data = np .isfinite (nc ) & (nc > 0 )
253+ nc_safe = np .where (has_data , nc , 1.0 ) # avoid /0
254+
255+ with np .errstate (invalid = 'ignore' , divide = 'ignore' ):
256+ mean_b = sc / nc_safe
257+ m2_b = sqc - sc ** 2 / nc_safe # block-internal M2
258+
259+ nc = np .where (has_data , nc , 0.0 )
260+ n_ab = n_acc + nc
261+
262+ delta = mean_b - mean_acc
263+ with np .errstate (invalid = 'ignore' , divide = 'ignore' ):
264+ n_ab_safe = np .where (n_ab > 0 , n_ab , 1.0 )
265+ correction = delta ** 2 * n_acc * nc / n_ab_safe
266+ new_mean = mean_acc + delta * nc / n_ab_safe
267+
268+ m2_acc = np .where (has_data , m2_acc + m2_b + correction , m2_acc )
269+ mean_acc = np .where (has_data , new_mean , mean_acc )
270+ n_acc = np .where (has_data , n_ab , n_acc )
271+
272+ with np .errstate (invalid = 'ignore' , divide = 'ignore' ):
273+ var = np .where (n_acc > 0 , m2_acc / n_acc , np .nan )
274+ return var
216275
217276
218277@ngjit
@@ -269,7 +328,10 @@ def _calc_stats(
269328 if unique_zones [i ] in zone_ids :
270329 zone_values = values_by_zones [start :end ]
271330 # filter out non-finite and nodata_values
272- zone_values = zone_values [np .isfinite (zone_values ) & (zone_values != nodata_values )]
331+ mask = np .isfinite (zone_values )
332+ if nodata_values is not None :
333+ mask = mask & (zone_values != nodata_values )
334+ zone_values = zone_values [mask ]
273335 if len (zone_values ) > 0 :
274336 results [i ] = func (zone_values )
275337 start = end
@@ -342,9 +404,11 @@ def _stats_dask_numpy(
342404 sum = values .dtype ,
343405 count = np .int64 ,
344406 sum_squares = values .dtype ,
345- squared_sum = values .dtype ,
346407 )
347408
409+ # Keep per-block stacked arrays for the parallel variance merge
410+ stacked_blocks = {}
411+
348412 for s in basis_stats :
349413 if s == 'sum_squares' and not compute_sum_squares :
350414 continue
@@ -358,21 +422,34 @@ def _stats_dask_numpy(
358422 for z , v in zip (zones_blocks , values_blocks )
359423 ]
360424 zonal_stats = da .stack (stats_by_block , allow_unknown_chunksizes = True )
425+
426+ if compute_sum_squares and s in ('count' , 'sum' , 'sum_squares' ):
427+ stacked_blocks [s ] = zonal_stats
428+
361429 stats_func_by_block = delayed (_DASK_STATS [s ])
362430 stats_dict [s ] = da .from_delayed (
363431 stats_func_by_block (zonal_stats ), shape = (np .nan ,), dtype = np .float64
364432 )
365433
366434 if 'mean' in stats_funcs :
367435 stats_dict ['mean' ] = _dask_mean (stats_dict ['sum' ], stats_dict ['count' ])
368- if 'std' in stats_funcs :
369- stats_dict ['std' ] = _dask_std (
370- stats_dict ['sum_squares' ], stats_dict ['sum' ] ** 2 , stats_dict ['count' ]
371- )
372- if 'var' in stats_funcs :
373- stats_dict ['var' ] = _dask_var (
374- stats_dict ['sum_squares' ], stats_dict ['sum' ] ** 2 , stats_dict ['count' ]
436+
437+ if 'std' in stats_funcs or 'var' in stats_funcs :
438+ var_result = da .from_delayed (
439+ delayed (_parallel_variance )(
440+ stacked_blocks ['count' ],
441+ stacked_blocks ['sum' ],
442+ stacked_blocks ['sum_squares' ],
443+ ),
444+ shape = (np .nan ,), dtype = np .float64 ,
375445 )
446+ if 'var' in stats_funcs :
447+ stats_dict ['var' ] = var_result
448+ if 'std' in stats_funcs :
449+ stats_dict ['std' ] = da .from_delayed (
450+ delayed (np .sqrt )(var_result ),
451+ shape = (np .nan ,), dtype = np .float64 ,
452+ )
376453
377454 # generate dask dataframe
378455 stats_df = dd .concat ([dd .from_dask_array (s ) for s in stats_dict .values ()], axis = 1 , ignore_unknown_divisions = True )
@@ -846,9 +923,10 @@ def _single_zone_crosstab_2d(
846923):
847924 # 1D flatten zone_values, i.e, original data is 2D
848925 # filter out non-finite and nodata_values
849- zone_values = zone_values [
850- np .isfinite (zone_values ) & (zone_values != nodata_values )
851- ]
926+ mask = np .isfinite (zone_values )
927+ if nodata_values is not None :
928+ mask = mask & (zone_values != nodata_values )
929+ zone_values = zone_values [mask ]
852930 total_count = zone_values .shape [0 ]
853931 crosstab_dict [TOTAL_COUNT ].append (total_count )
854932
@@ -877,10 +955,10 @@ def _single_zone_crosstab_3d(
877955 if cat in cat_ids :
878956 zone_cat_data = zone_values [j ]
879957 # filter out non-finite and nodata_values
880- zone_cat_data = zone_cat_data [
881- np . isfinite ( zone_cat_data )
882- & (zone_cat_data != nodata_values )
883- ]
958+ cat_mask = np . isfinite ( zone_cat_data )
959+ if nodata_values is not None :
960+ cat_mask = cat_mask & (zone_cat_data != nodata_values )
961+ zone_cat_data = zone_cat_data [ cat_mask ]
884962 crosstab_dict [cat ].append (stats_func (zone_cat_data ))
885963
886964
0 commit comments