@@ -198,21 +198,77 @@ def _stats_majority(data):
198198 min = lambda z : z .min (),
199199 sum = lambda z : z .sum (),
200200 count = lambda z : _stats_count (z ),
201- sum_squares = lambda z : (z ** 2 ).sum ()
201+ sum_squares = lambda z : (( z - z . mean ()) ** 2 ).sum () # block-level M2
202202)
203203
204204
205+ def _nanreduce_preserve_allnan (blocks , func ):
206+ """Reduce across blocks, returning NaN when ALL blocks are NaN for a zone.
207+
208+ ``np.nansum`` returns 0 for all-NaN input; we want NaN so that zones
209+ with no valid values propagate NaN, consistent with the numpy backend.
210+ """
211+ result = func (blocks , axis = 0 )
212+ all_nan = np .all (np .isnan (blocks ), axis = 0 )
213+ result [all_nan ] = np .nan
214+ return result
215+
216+
# Per-statistic reducers that collapse the stacked per-block results
# (shape ``(n_blocks, n_zones)``) down to a single value per zone.
# Every entry delegates to ``_nanreduce_preserve_allnan`` so that zones
# with no valid data in any block come out as NaN rather than 0.
_DASK_STATS = {
    'max': lambda blocks: _nanreduce_preserve_allnan(blocks, np.nanmax),
    'min': lambda blocks: _nanreduce_preserve_allnan(blocks, np.nanmin),
    'sum': lambda blocks: _nanreduce_preserve_allnan(blocks, np.nansum),
    'count': lambda blocks: _nanreduce_preserve_allnan(blocks, np.nansum),
    'sum_squares': lambda blocks: _nanreduce_preserve_allnan(blocks, np.nansum),
}
213- def _dask_mean (sums , counts ): return sums / counts # noqa
214- def _dask_std (sum_squares , squared_sum , n ): return np .sqrt ((sum_squares - squared_sum / n ) / n ) # noqa
215- def _dask_var (sum_squares , squared_sum , n ): return (sum_squares - squared_sum / n ) / n # noqa
224+
225+
226+ def _dask_mean (sums , counts ): # noqa
227+ return sums / counts
228+
229+
230+ def _parallel_variance (block_counts , block_sums , block_m2s ):
231+ """Population variance via Chan-Golub-LeVeque parallel merge.
232+
233+ Each input is (n_blocks, n_zones). ``block_m2s`` contains
234+ per-block M2 values (sum of squared deviations from the block mean),
235+ NOT raw sum-of-squares. Returns (n_zones,) population variance,
236+ with NaN for zones that have no valid values in any block.
237+ """
238+ n_blocks = block_counts .shape [0 ]
239+ n_zones = block_counts .shape [1 ]
240+
241+ n_acc = np .zeros (n_zones , dtype = np .float64 )
242+ mean_acc = np .zeros (n_zones , dtype = np .float64 )
243+ m2_acc = np .zeros (n_zones , dtype = np .float64 )
244+
245+ for i in range (n_blocks ):
246+ nc = np .asarray (block_counts [i ], dtype = np .float64 )
247+ sc = np .asarray (block_sums [i ], dtype = np .float64 )
248+ m2_b = np .asarray (block_m2s [i ], dtype = np .float64 )
249+
250+ has_data = np .isfinite (nc ) & (nc > 0 )
251+ nc_safe = np .where (has_data , nc , 1.0 ) # avoid /0
252+
253+ with np .errstate (invalid = 'ignore' , divide = 'ignore' ):
254+ mean_b = sc / nc_safe
255+
256+ nc = np .where (has_data , nc , 0.0 )
257+ n_ab = n_acc + nc
258+
259+ delta = mean_b - mean_acc
260+ with np .errstate (invalid = 'ignore' , divide = 'ignore' ):
261+ n_ab_safe = np .where (n_ab > 0 , n_ab , 1.0 )
262+ correction = delta ** 2 * n_acc * nc / n_ab_safe
263+ new_mean = mean_acc + delta * nc / n_ab_safe
264+
265+ m2_acc = np .where (has_data , m2_acc + m2_b + correction , m2_acc )
266+ mean_acc = np .where (has_data , new_mean , mean_acc )
267+ n_acc = np .where (has_data , n_ab , n_acc )
268+
269+ with np .errstate (invalid = 'ignore' , divide = 'ignore' ):
270+ var = np .where (n_acc > 0 , m2_acc / n_acc , np .nan )
271+ return var
216272
217273
218274@ngjit
@@ -269,7 +325,10 @@ def _calc_stats(
269325 if unique_zones [i ] in zone_ids :
270326 zone_values = values_by_zones [start :end ]
271327 # filter out non-finite and nodata_values
272- zone_values = zone_values [np .isfinite (zone_values ) & (zone_values != nodata_values )]
328+ mask = np .isfinite (zone_values )
329+ if nodata_values is not None :
330+ mask = mask & (zone_values != nodata_values )
331+ zone_values = zone_values [mask ]
273332 if len (zone_values ) > 0 :
274333 results [i ] = func (zone_values )
275334 start = end
@@ -342,9 +401,11 @@ def _stats_dask_numpy(
342401 sum = values .dtype ,
343402 count = np .int64 ,
344403 sum_squares = values .dtype ,
345- squared_sum = values .dtype ,
346404 )
347405
406+ # Keep per-block stacked arrays for the parallel variance merge
407+ stacked_blocks = {}
408+
348409 for s in basis_stats :
349410 if s == 'sum_squares' and not compute_sum_squares :
350411 continue
@@ -358,21 +419,34 @@ def _stats_dask_numpy(
358419 for z , v in zip (zones_blocks , values_blocks )
359420 ]
360421 zonal_stats = da .stack (stats_by_block , allow_unknown_chunksizes = True )
422+
423+ if compute_sum_squares and s in ('count' , 'sum' , 'sum_squares' ):
424+ stacked_blocks [s ] = zonal_stats
425+
361426 stats_func_by_block = delayed (_DASK_STATS [s ])
362427 stats_dict [s ] = da .from_delayed (
363428 stats_func_by_block (zonal_stats ), shape = (np .nan ,), dtype = np .float64
364429 )
365430
366431 if 'mean' in stats_funcs :
367432 stats_dict ['mean' ] = _dask_mean (stats_dict ['sum' ], stats_dict ['count' ])
368- if 'std' in stats_funcs :
369- stats_dict ['std' ] = _dask_std (
370- stats_dict ['sum_squares' ], stats_dict ['sum' ] ** 2 , stats_dict ['count' ]
371- )
372- if 'var' in stats_funcs :
373- stats_dict ['var' ] = _dask_var (
374- stats_dict ['sum_squares' ], stats_dict ['sum' ] ** 2 , stats_dict ['count' ]
433+
434+ if 'std' in stats_funcs or 'var' in stats_funcs :
435+ var_result = da .from_delayed (
436+ delayed (_parallel_variance )(
437+ stacked_blocks ['count' ],
438+ stacked_blocks ['sum' ],
439+ stacked_blocks ['sum_squares' ],
440+ ),
441+ shape = (np .nan ,), dtype = np .float64 ,
375442 )
443+ if 'var' in stats_funcs :
444+ stats_dict ['var' ] = var_result
445+ if 'std' in stats_funcs :
446+ stats_dict ['std' ] = da .from_delayed (
447+ delayed (np .sqrt )(var_result ),
448+ shape = (np .nan ,), dtype = np .float64 ,
449+ )
376450
377451 # generate dask dataframe
378452 stats_df = dd .concat ([dd .from_dask_array (s ) for s in stats_dict .values ()], axis = 1 , ignore_unknown_divisions = True )
@@ -846,9 +920,10 @@ def _single_zone_crosstab_2d(
846920):
847921 # 1D flatten zone_values, i.e, original data is 2D
848922 # filter out non-finite and nodata_values
849- zone_values = zone_values [
850- np .isfinite (zone_values ) & (zone_values != nodata_values )
851- ]
923+ mask = np .isfinite (zone_values )
924+ if nodata_values is not None :
925+ mask = mask & (zone_values != nodata_values )
926+ zone_values = zone_values [mask ]
852927 total_count = zone_values .shape [0 ]
853928 crosstab_dict [TOTAL_COUNT ].append (total_count )
854929
@@ -877,10 +952,10 @@ def _single_zone_crosstab_3d(
877952 if cat in cat_ids :
878953 zone_cat_data = zone_values [j ]
879954 # filter out non-finite and nodata_values
880- zone_cat_data = zone_cat_data [
881- np . isfinite ( zone_cat_data )
882- & (zone_cat_data != nodata_values )
883- ]
955+ cat_mask = np . isfinite ( zone_cat_data )
956+ if nodata_values is not None :
957+ cat_mask = cat_mask & (zone_cat_data != nodata_values )
958+ zone_cat_data = zone_cat_data [ cat_mask ]
884959 crosstab_dict [cat ].append (stats_func (zone_cat_data ))
885960
886961