Skip to content

Commit 94a4747

Browse files
committed
Fixes #899: fix boolean short-circuit bug in dask zonal.stats()
The conditions `if 'mean' or 'std' or 'var' in stats_funcs` and `if 'std' or 'var' in stats_funcs` always evaluated to True: `in` binds tighter than `or`, so the first parses as `'mean' or 'std' or ('var' in stats_funcs)` (and likewise the second), and the leading non-empty string literal is truthy. This caused compute_sum, compute_count, and compute_sum_squares to always be set, wasting work on every dask zonal.stats() call regardless of which stats were requested. Fix: use `any(s in stats_funcs for s in (...))` for correct membership testing. Add regression tests covering 7 stat subsets on both numpy and dask backends to exercise each compute flag independently.
1 parent 86da4e0 commit 94a4747

File tree

2 files changed

+43
-2
lines changed

2 files changed

+43
-2
lines changed

xrspatial/tests/test_zonal.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,47 @@ def test_majority_with_ties(backend):
590590
check_results(backend, df_result, expected_result)
591591

592592

593+
@pytest.mark.parametrize("stats_funcs, expected_cols", [
594+
(['min', 'max'], ['zone', 'min', 'max']),
595+
(['mean'], ['zone', 'mean']),
596+
(['std'], ['zone', 'std']),
597+
(['var'], ['zone', 'var']),
598+
(['count'], ['zone', 'count']),
599+
(['sum'], ['zone', 'sum']),
600+
(['min', 'max', 'count'], ['zone', 'min', 'max', 'count']),
601+
])
602+
@pytest.mark.parametrize("backend", ['numpy', 'dask+numpy'])
603+
def test_stats_subset_columns(backend, data_zones, data_values_2d,
604+
stats_funcs, expected_cols):
605+
"""Requesting a subset of stats returns only those columns.
606+
607+
Regression test for GH-899: the dask path had a boolean short-circuit
608+
bug (``if 'mean' or 'std' or 'var' in stats_funcs``) that always
609+
evaluated to True, causing unnecessary intermediate stats to be
610+
computed. After the fix, each subset exercises a distinct code path
611+
for compute_sum / compute_count / compute_sum_squares flags.
612+
"""
613+
if 'dask' in backend and not dask_array_available():
614+
pytest.skip("Requires Dask")
615+
616+
df_result = stats(zones=data_zones, values=data_values_2d,
617+
stats_funcs=stats_funcs)
618+
619+
# Verify values are correct for the requested stats
620+
all_expected = {
621+
'zone': [0, 1, 2, 3],
622+
'mean': [0, 1, 2, 2.4],
623+
'max': [0, 1, 2, 3],
624+
'min': [0, 1, 2, 0],
625+
'sum': [0, 6, 8, 12],
626+
'std': [0, 0, 0, 1.2],
627+
'var': [0, 0, 0, 1.44],
628+
'count': [5, 6, 4, 5],
629+
}
630+
expected = {k: all_expected[k] for k in expected_cols}
631+
check_results(backend, df_result, expected)
632+
633+
593634
def test_zonal_stats_against_qgis(elevation_raster_no_nans, raster, qgis_zonal_stats):
594635
stats_funcs = list(set(qgis_zonal_stats.keys()) - set(['zone']))
595636
zones_agg = create_test_raster(raster)

xrspatial/zonal.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -240,11 +240,11 @@ def _stats_dask_numpy(
240240
compute_sum = False
241241
compute_count = False
242242

243-
if 'mean' or 'std' or 'var' in stats_funcs:
243+
if any(s in stats_funcs for s in ('mean', 'std', 'var')):
244244
compute_sum = True
245245
compute_count = True
246246

247-
if 'std' or 'var' in stats_funcs:
247+
if any(s in stats_funcs for s in ('std', 'var')):
248248
compute_sum_squares = True
249249

250250
basis_stats = [s for s in _DASK_BLOCK_STATS if s in stats_funcs]

0 commit comments

Comments
 (0)