Skip to content

Commit 45c1afc

Browse files
committed
Fixes #884: replace boolean fancy indexing with dask-safe percentile path
quantile() and percentiles() used data[module.isfinite(data)] on dask arrays, which creates unknown chunk sizes that degrade scheduling and can force unexpected materialisations. Replace with dedicated dask functions that use da.where to clean inf→nan (preserving known chunks), compute to numpy, then use np.nanpercentile + np.unique.
1 parent cf5aaa7 commit 45c1afc

File tree

2 files changed

+92
-6
lines changed

2 files changed

+92
-6
lines changed

xrspatial/classify.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -405,17 +405,31 @@ def _run_quantile(data, k, module):
405405
return q
406406

407407

408+
def _run_dask_quantile(data, k):
409+
# Avoid boolean fancy indexing (data[da.isfinite(data)]) which creates
410+
# unknown dask chunk sizes. Instead, replace inf with nan (preserves
411+
# known chunks), compute to numpy, then use np.nanpercentile (#884).
412+
w = 100.0 / k
413+
p = np.arange(w, 100 + w, w)
414+
if p[-1] > 100.0:
415+
p[-1] = 100.0
416+
clean = da.where(da.isinf(data), np.nan, data)
417+
values = clean.ravel().compute()
418+
q = np.nanpercentile(values, p)
419+
q = np.unique(q)
420+
return q
421+
422+
408423
def _run_dask_cupy_quantile(data, k):
409-
# Convert dask+cupy chunks to numpy one at a time via map_blocks,
410-
# then use dask's streaming approximate percentile (no full materialization).
424+
# Convert dask+cupy chunks to numpy, then same safe path as dask (#884).
411425
data_cpu = data.map_blocks(cupy.asnumpy, dtype=data.dtype, meta=np.array(()))
412-
return _run_quantile(data_cpu, k, da)
426+
return _run_dask_quantile(data_cpu, k)
413427

414428

415429
def _quantile(agg, k):
416430
mapper = ArrayTypeFunctionMapping(
417431
numpy_func=lambda *args: _run_quantile(*args, module=np),
418-
dask_func=lambda *args: _run_quantile(*args, module=da),
432+
dask_func=_run_dask_quantile,
419433
cupy_func=lambda *args: _run_quantile(*args, module=cupy),
420434
dask_cupy_func=_run_dask_cupy_quantile
421435
)
@@ -1105,9 +1119,21 @@ def _run_percentiles(data, pct, module):
11051119
return q
11061120

11071121

1122+
def _run_dask_percentiles(data, pct):
1123+
# Avoid boolean fancy indexing (data[da.isfinite(data)]) which creates
1124+
# unknown dask chunk sizes. Replace inf with nan, compute to numpy,
1125+
# then use np.nanpercentile (#884).
1126+
clean = da.where(da.isinf(data), np.nan, data)
1127+
values = clean.ravel().compute()
1128+
q = np.nanpercentile(values, pct)
1129+
q = np.unique(q)
1130+
return q
1131+
1132+
11081133
def _run_dask_cupy_percentiles(data, pct):
1134+
# Convert dask+cupy chunks to numpy, then same safe path as dask (#884).
11091135
data_cpu = data.map_blocks(cupy.asnumpy, dtype=data.dtype, meta=np.array(()))
1110-
return _run_percentiles(data_cpu, pct, da)
1136+
return _run_dask_percentiles(data_cpu, pct)
11111137

11121138

11131139
@supports_dataset
@@ -1144,7 +1170,7 @@ def percentiles(agg: xr.DataArray,
11441170

11451171
mapper = ArrayTypeFunctionMapping(
11461172
numpy_func=lambda *args: _run_percentiles(*args, module=np),
1147-
dask_func=lambda *args: _run_percentiles(*args, module=da),
1173+
dask_func=_run_dask_percentiles,
11481174
cupy_func=lambda *args: _run_percentiles(*args, module=cupy),
11491175
dask_cupy_func=_run_dask_cupy_percentiles,
11501176
)

xrspatial/tests/test_classify.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -945,3 +945,63 @@ def test_maximum_breaks_dask_num_sample():
945945
assert result1.shape == elevation.shape
946946
unique_vals = np.unique(result1.data.compute())
947947
assert len(unique_vals) <= 3 + 1 # at most k classes + possible nan
948+
949+
950+
# ===================================================================
951+
# Regression tests: dask paths must not use boolean fancy indexing
952+
# ===================================================================
953+
954+
@dask_array_available
955+
def test_quantile_dask_no_unknown_chunks():
956+
"""quantile on dask must not create unknown chunk sizes (#884)."""
957+
elevation = np.arange(100, dtype=np.float64).reshape(10, 10)
958+
numpy_agg = xr.DataArray(elevation)
959+
dask_agg = xr.DataArray(da.from_array(elevation, chunks=(5, 5)))
960+
961+
numpy_result = quantile(numpy_agg, k=5)
962+
dask_result = quantile(dask_agg, k=5)
963+
964+
# Dask percentile is approximate, so just check same shape and k classes
965+
assert dask_result.shape == numpy_result.shape
966+
dask_vals = dask_result.data.compute()
967+
unique_vals = np.unique(dask_vals[np.isfinite(dask_vals)])
968+
assert len(unique_vals) == 5
969+
970+
971+
@dask_array_available
972+
def test_quantile_dask_with_nan_inf():
973+
"""quantile on dask handles NaN and inf without unknown chunks (#884)."""
974+
elevation = np.array([
975+
[-np.inf, 2., 3., 4., np.nan],
976+
[5., 6., 7., 8., 9.],
977+
[10., 11., 12., 13., 14.],
978+
[15., 16., 17., 18., np.inf],
979+
])
980+
dask_agg = xr.DataArray(da.from_array(elevation, chunks=(2, 5)))
981+
result = quantile(dask_agg, k=5)
982+
result_data = result.data.compute()
983+
# NaN and inf inputs should produce NaN in the output
984+
assert np.isnan(result_data[0, 0]) # was -inf
985+
assert np.isnan(result_data[0, 4]) # was nan
986+
assert np.isnan(result_data[3, 4]) # was inf
987+
# Finite values should be classified
988+
finite_result = result_data[np.isfinite(result_data)]
989+
assert len(np.unique(finite_result)) == 5
990+
991+
992+
@dask_array_available
993+
def test_percentiles_dask_no_unknown_chunks():
994+
"""percentiles on dask must not create unknown chunk sizes (#884)."""
995+
from xrspatial import percentiles as percentiles_fn
996+
elevation = np.arange(100, dtype=np.float64).reshape(10, 10)
997+
numpy_agg = xr.DataArray(elevation)
998+
dask_agg = xr.DataArray(da.from_array(elevation, chunks=(5, 5)))
999+
1000+
numpy_result = percentiles_fn(numpy_agg)
1001+
dask_result = percentiles_fn(dask_agg)
1002+
1003+
np.testing.assert_allclose(
1004+
numpy_result.data,
1005+
dask_result.data.compute(),
1006+
equal_nan=True,
1007+
)

0 commit comments

Comments
 (0)