Skip to content

Commit 35d3c76

Browse files
committed
Add num_sample to quantile() and percentiles() for memory-safe dask paths
The previous commit eliminated unknown dask chunks but still materialised the full array via .ravel().compute(). Now both functions accept num_sample (default 20_000, matching natural_breaks/maximum_breaks) and use _generate_sample_indices() + indexed access so only the sample is ever computed on dask backends.
1 parent 45c1afc commit 35d3c76

File tree

1 file changed

+46
-20
lines changed

1 file changed

+46
-20
lines changed

xrspatial/classify.py

Lines changed: 46 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,8 @@ def reclassify(agg: xr.DataArray,
393393
attrs=agg.attrs)
394394

395395

396-
def _run_quantile(data, k, module):
396+
def _run_quantile(data, num_sample, k, module):
397+
# num_sample ignored for in-memory backends
397398
w = 100.0 / k
398399
p = module.arange(w, 100 + w, w)
399400

@@ -405,41 +406,47 @@ def _run_quantile(data, k, module):
405406
return q
406407

407408

def _run_dask_quantile(data, num_sample, k):
    """Compute up to ``k`` quantile breakpoints of a dask array from a sample.

    Avoids boolean fancy indexing (``data[da.isfinite(data)]``), which creates
    unknown dask chunk sizes, and avoids materialising the full array: only
    the positions selected by ``_generate_sample_indices()`` are ever
    computed (#884).

    Parameters
    ----------
    data : dask.array.Array
        Input values; non-finite entries are ignored.
    num_sample : int or None
        Number of sampled positions; ``None`` (or any value >= ``data.size``)
        means use every element.
    k : int
        Number of quantiles requested.

    Returns
    -------
    numpy.ndarray
        Sorted unique breakpoints (may be fewer than ``k``).
    """
    w = 100.0 / k
    p = np.arange(w, 100 + w, w)
    if p[-1] > 100.0:
        p[-1] = 100.0
    num_data = data.size
    if num_sample is None or num_sample >= num_data:
        num_sample = num_data
    sample_idx = _generate_sample_indices(num_data, num_sample)
    # Compute only the sampled positions. inf and nan are both discarded by
    # the isfinite mask below, so the previous inf -> nan rewrite via
    # da.where is unnecessary (it also forced a float upcast of the whole
    # lazy array for integer inputs).
    values = np.asarray(data.ravel()[sample_idx].compute())
    values = values[np.isfinite(values)]
    q = np.percentile(values, p)
    q = np.unique(q)
    return q
421427

422428

def _run_dask_cupy_quantile(data, num_sample, k):
    # Move the chunks from GPU (cupy) to host numpy first, then reuse the
    # memory-safe dask sampling path (#884).
    host_chunks = data.map_blocks(cupy.asnumpy, dtype=data.dtype, meta=np.array(()))
    return _run_dask_quantile(host_chunks, num_sample, k)
427433

428434

def _quantile(agg, num_sample, k):
    """Dispatch quantile computation to the backend matching ``agg``'s array type."""
    backend_map = ArrayTypeFunctionMapping(
        numpy_func=lambda *args: _run_quantile(*args, module=np),
        dask_func=_run_dask_quantile,
        cupy_func=lambda *args: _run_quantile(*args, module=cupy),
        dask_cupy_func=_run_dask_cupy_quantile
    )
    run = backend_map(agg)
    return run(agg.data, num_sample, k)
438444

439445

440446
@supports_dataset
441447
def quantile(agg: xr.DataArray,
442448
k: int = 4,
449+
num_sample: Optional[int] = 20_000,
443450
name: Optional[str] = 'quantile') -> xr.DataArray:
444451
"""
445452
Reclassifies data for array `agg` into new values based on quantile
@@ -452,6 +459,12 @@ def quantile(agg: xr.DataArray,
452459
of values to be reclassified.
453460
k : int, default=4
454461
Number of quantiles to be produced.
462+
num_sample : int or None, default=20000
463+
Number of sample data points used to compute percentile
464+
breakpoints. For dask-backed arrays the sample is drawn
465+
lazily to avoid materialising the entire array into RAM.
466+
``None`` means use all data (safe for numpy/cupy,
467+
automatically capped for dask).
455468
name : str, default='quantile'
456469
Name of the output aggregate array.
457470
@@ -503,7 +516,7 @@ def quantile(agg: xr.DataArray,
503516
res: (10.0, 10.0)
504517
"""
505518

506-
q = _quantile(agg, k)
519+
q = _quantile(agg, num_sample, k)
507520
k_q = q.shape[0]
508521
if k_q < k:
509522
print("Quantile Warning: Not enough unique values"
@@ -1113,32 +1126,39 @@ def head_tail_breaks(agg: xr.DataArray,
11131126
attrs=agg.attrs)
11141127

11151128

1116-
def _run_percentiles(data, pct, module):
1129+
def _run_percentiles(data, num_sample, pct, module):
1130+
# num_sample ignored for in-memory backends
11171131
q = module.percentile(data[module.isfinite(data)], pct)
11181132
q = module.unique(q)
11191133
return q
11201134

11211135

def _run_dask_percentiles(data, num_sample, pct):
    """Compute percentile breakpoints of a dask array from a sample.

    Avoids boolean fancy indexing (``data[da.isfinite(data)]``), which creates
    unknown dask chunk sizes, and avoids materialising the full array: only
    the positions selected by ``_generate_sample_indices()`` are ever
    computed (#884).

    Parameters
    ----------
    data : dask.array.Array
        Input values; non-finite entries are ignored.
    num_sample : int or None
        Number of sampled positions; ``None`` (or any value >= ``data.size``)
        means use every element.
    pct : list of float
        Percentile values to use as breakpoints.

    Returns
    -------
    numpy.ndarray
        Sorted unique breakpoints.
    """
    num_data = data.size
    if num_sample is None or num_sample >= num_data:
        num_sample = num_data
    sample_idx = _generate_sample_indices(num_data, num_sample)
    # Compute only the sampled positions. inf and nan are both discarded by
    # the isfinite mask below, so the previous inf -> nan rewrite via
    # da.where is unnecessary (it also forced a float upcast of the whole
    # lazy array for integer inputs).
    values = np.asarray(data.ravel()[sample_idx].compute())
    values = values[np.isfinite(values)]
    q = np.percentile(values, pct)
    q = np.unique(q)
    return q
11311150

11321151

def _run_dask_cupy_percentiles(data, num_sample, pct):
    # Move the chunks from GPU (cupy) to host numpy first, then reuse the
    # memory-safe dask sampling path (#884).
    host_chunks = data.map_blocks(cupy.asnumpy, dtype=data.dtype, meta=np.array(()))
    return _run_dask_percentiles(host_chunks, num_sample, pct)
11371156

11381157

11391158
@supports_dataset
11401159
def percentiles(agg: xr.DataArray,
11411160
pct: Optional[List] = None,
1161+
num_sample: Optional[int] = 20_000,
11421162
name: Optional[str] = 'percentiles') -> xr.DataArray:
11431163
"""
11441164
Classify data based on percentile breakpoints.
@@ -1150,6 +1170,12 @@ def percentiles(agg: xr.DataArray,
11501170
of values to be classified.
11511171
pct : list of float, default=[1, 10, 50, 90, 99]
11521172
Percentile values to use as breakpoints.
1173+
num_sample : int or None, default=20000
1174+
Number of sample data points used to compute percentile
1175+
breakpoints. For dask-backed arrays the sample is drawn
1176+
lazily to avoid materialising the entire array into RAM.
1177+
``None`` means use all data (safe for numpy/cupy,
1178+
automatically capped for dask).
11531179
name : str, default='percentiles'
11541180
Name of output aggregate array.
11551181
@@ -1174,7 +1200,7 @@ def percentiles(agg: xr.DataArray,
11741200
cupy_func=lambda *args: _run_percentiles(*args, module=cupy),
11751201
dask_cupy_func=_run_dask_cupy_percentiles,
11761202
)
1177-
q = mapper(agg)(agg.data, pct)
1203+
q = mapper(agg)(agg.data, num_sample, pct)
11781204

11791205
# Materialize bin edges to numpy
11801206
if hasattr(q, 'compute'):

0 commit comments

Comments
 (0)