Skip to content

Commit fd78352

Browse files
authored
Fixes #877, #876, #884: prevent OOM and unknown chunks in classify.py dask paths (#895)
* Fixes #877, #876: prevent OOM from full dask materialisation in classify.py. The natural_breaks and maximum_breaks dask code paths called .ravel().compute() on the full array, materialising the entire dataset into RAM. Replace with capped sampling via _generate_sample_indices() + indexed access so only the sample is ever computed. Add num_sample parameter to maximum_breaks (default 20_000, matching natural_breaks).

* Fixes #884: replace boolean fancy indexing with a dask-safe percentile path. quantile() and percentiles() used data[module.isfinite(data)] on dask arrays, which creates unknown chunk sizes that degrade scheduling and can force unexpected materialisations. Replace with dedicated dask functions that use da.where to clean inf→nan (preserving known chunks), compute to numpy, then use np.nanpercentile + np.unique.

* Add num_sample to quantile() and percentiles() for memory-safe dask paths. The previous commit eliminated unknown dask chunks but still materialised the full array via .ravel().compute(). Now both functions accept num_sample (default 20_000, matching natural_breaks/maximum_breaks) and use _generate_sample_indices() + indexed access so only the sample is ever computed on dask backends.
1 parent 0cf6ce2 commit fd78352

File tree

2 files changed

+220
-41
lines changed

2 files changed

+220
-41
lines changed

xrspatial/classify.py

Lines changed: 98 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,8 @@ def reclassify(agg: xr.DataArray,
393393
attrs=agg.attrs)
394394

395395

396-
def _run_quantile(data, k, module):
396+
def _run_quantile(data, num_sample, k, module):
397+
# num_sample ignored for in-memory backends
397398
w = 100.0 / k
398399
p = module.arange(w, 100 + w, w)
399400

@@ -405,27 +406,47 @@ def _run_quantile(data, k, module):
405406
return q
406407

407408

408-
def _run_dask_cupy_quantile(data, k):
409-
# Convert dask+cupy chunks to numpy one at a time via map_blocks,
410-
# then use dask's streaming approximate percentile (no full materialization).
409+
def _run_dask_quantile(data, num_sample, k):
410+
# Avoid boolean fancy indexing (data[da.isfinite(data)]) which creates
411+
# unknown dask chunk sizes. Use sampling via indexed access to avoid
412+
# materialising the full array (#884).
413+
w = 100.0 / k
414+
p = np.arange(w, 100 + w, w)
415+
if p[-1] > 100.0:
416+
p[-1] = 100.0
417+
clean = da.where(da.isinf(data), np.nan, data)
418+
num_data = data.size
419+
if num_sample is None or num_sample >= num_data:
420+
num_sample = num_data
421+
sample_idx = _generate_sample_indices(num_data, num_sample)
422+
values = np.asarray(clean.ravel()[sample_idx].compute())
423+
values = values[np.isfinite(values)]
424+
q = np.percentile(values, p)
425+
q = np.unique(q)
426+
return q
427+
428+
429+
def _run_dask_cupy_quantile(data, num_sample, k):
430+
# Convert dask+cupy chunks to numpy, then same safe path as dask (#884).
411431
data_cpu = data.map_blocks(cupy.asnumpy, dtype=data.dtype, meta=np.array(()))
412-
return _run_quantile(data_cpu, k, da)
432+
return _run_dask_quantile(data_cpu, num_sample, k)
413433

414434

415-
def _quantile(agg, k):
435+
def _quantile(agg, num_sample, k):
416436
mapper = ArrayTypeFunctionMapping(
417437
numpy_func=lambda *args: _run_quantile(*args, module=np),
418-
dask_func=lambda *args: _run_quantile(*args, module=da),
438+
dask_func=_run_dask_quantile,
419439
cupy_func=lambda *args: _run_quantile(*args, module=cupy),
420440
dask_cupy_func=_run_dask_cupy_quantile
421441
)
422-
out = mapper(agg)(agg.data, k)
442+
out = mapper(agg)(agg.data, num_sample, k)
423443
return out
424444

425445

426446
@supports_dataset
427447
def quantile(agg: xr.DataArray,
428448
k: int = 4,
449+
num_sample: Optional[int] = 20_000,
429450
name: Optional[str] = 'quantile') -> xr.DataArray:
430451
"""
431452
Reclassifies data for array `agg` into new values based on quantile
@@ -438,6 +459,12 @@ def quantile(agg: xr.DataArray,
438459
of values to be reclassified.
439460
k : int, default=4
440461
Number of quantiles to be produced.
462+
num_sample : int or None, default=20000
463+
Number of sample data points used to compute percentile
464+
breakpoints. For dask-backed arrays the sample is drawn
465+
lazily to avoid materialising the entire array into RAM.
466+
``None`` means use all data (safe for numpy/cupy,
467+
automatically capped for dask).
441468
name : str, default='quantile'
442469
Name of the output aggregate array.
443470
@@ -489,7 +516,7 @@ def quantile(agg: xr.DataArray,
489516
res: (10.0, 10.0)
490517
"""
491518

492-
q = _quantile(agg, k)
519+
q = _quantile(agg, num_sample, k)
493520
k_q = q.shape[0]
494521
if k_q < k:
495522
print("Quantile Warning: Not enough unique values"
@@ -695,16 +722,11 @@ def _run_dask_natural_break(agg, num_sample, k):
695722
max_data = float(da.nanmax(da.where(da.isinf(data), np.nan, data)).compute())
696723

697724
num_data = data.size
698-
if num_sample is not None and num_sample < num_data:
699-
# Sample lazily from dask array; only materialize the sample
700-
sample_idx = _generate_sample_indices(num_data, num_sample)
701-
sample_data_np = np.asarray(data.ravel()[sample_idx].compute())
702-
bins, uvk = _compute_natural_break_bins(
703-
sample_data_np, None, k, max_data)
704-
else:
705-
data_flat_np = np.asarray(data.ravel().compute())
706-
bins, uvk = _compute_natural_break_bins(
707-
data_flat_np, None, k, max_data)
725+
if num_sample is None or num_sample >= num_data:
726+
num_sample = num_data # cap: still uses indexed access, never .compute() all
727+
sample_idx = _generate_sample_indices(num_data, num_sample)
728+
sample_data_np = np.asarray(data.ravel()[sample_idx].compute())
729+
bins, uvk = _compute_natural_break_bins(sample_data_np, None, k, max_data)
708730

709731
out = _bin(agg, bins, np.arange(uvk))
710732
return out
@@ -717,17 +739,12 @@ def _run_dask_cupy_natural_break(agg, num_sample, k):
717739
max_data = float(da.nanmax(da.where(da.isinf(data), np.nan, data)).compute().item())
718740

719741
num_data = data.size
720-
if num_sample is not None and num_sample < num_data:
721-
# Sample lazily from dask array; only materialize the sample
722-
sample_idx = _generate_sample_indices(num_data, num_sample)
723-
sample_data = data.ravel()[sample_idx].compute()
724-
sample_data_np = cupy.asnumpy(sample_data)
725-
bins, uvk = _compute_natural_break_bins(
726-
sample_data_np, None, k, max_data)
727-
else:
728-
data_flat_np = cupy.asnumpy(data.ravel().compute())
729-
bins, uvk = _compute_natural_break_bins(
730-
data_flat_np, None, k, max_data)
742+
if num_sample is None or num_sample >= num_data:
743+
num_sample = num_data # cap: still uses indexed access, never .compute() all
744+
sample_idx = _generate_sample_indices(num_data, num_sample)
745+
sample_data = data.ravel()[sample_idx].compute()
746+
sample_data_np = cupy.asnumpy(sample_data)
747+
bins, uvk = _compute_natural_break_bins(sample_data_np, None, k, max_data)
731748

732749
out = _bin(agg, bins, np.arange(uvk))
733750
return out
@@ -1109,20 +1126,39 @@ def head_tail_breaks(agg: xr.DataArray,
11091126
attrs=agg.attrs)
11101127

11111128

1112-
def _run_percentiles(data, pct, module):
1129+
def _run_percentiles(data, num_sample, pct, module):
1130+
# num_sample ignored for in-memory backends
11131131
q = module.percentile(data[module.isfinite(data)], pct)
11141132
q = module.unique(q)
11151133
return q
11161134

11171135

1118-
def _run_dask_cupy_percentiles(data, pct):
1136+
def _run_dask_percentiles(data, num_sample, pct):
1137+
# Avoid boolean fancy indexing (data[da.isfinite(data)]) which creates
1138+
# unknown dask chunk sizes. Use sampling via indexed access to avoid
1139+
# materialising the full array (#884).
1140+
clean = da.where(da.isinf(data), np.nan, data)
1141+
num_data = data.size
1142+
if num_sample is None or num_sample >= num_data:
1143+
num_sample = num_data
1144+
sample_idx = _generate_sample_indices(num_data, num_sample)
1145+
values = np.asarray(clean.ravel()[sample_idx].compute())
1146+
values = values[np.isfinite(values)]
1147+
q = np.percentile(values, pct)
1148+
q = np.unique(q)
1149+
return q
1150+
1151+
1152+
def _run_dask_cupy_percentiles(data, num_sample, pct):
1153+
# Convert dask+cupy chunks to numpy, then same safe path as dask (#884).
11191154
data_cpu = data.map_blocks(cupy.asnumpy, dtype=data.dtype, meta=np.array(()))
1120-
return _run_percentiles(data_cpu, pct, da)
1155+
return _run_dask_percentiles(data_cpu, num_sample, pct)
11211156

11221157

11231158
@supports_dataset
11241159
def percentiles(agg: xr.DataArray,
11251160
pct: Optional[List] = None,
1161+
num_sample: Optional[int] = 20_000,
11261162
name: Optional[str] = 'percentiles') -> xr.DataArray:
11271163
"""
11281164
Classify data based on percentile breakpoints.
@@ -1134,6 +1170,12 @@ def percentiles(agg: xr.DataArray,
11341170
of values to be classified.
11351171
pct : list of float, default=[1, 10, 50, 90, 99]
11361172
Percentile values to use as breakpoints.
1173+
num_sample : int or None, default=20000
1174+
Number of sample data points used to compute percentile
1175+
breakpoints. For dask-backed arrays the sample is drawn
1176+
lazily to avoid materialising the entire array into RAM.
1177+
``None`` means use all data (safe for numpy/cupy,
1178+
automatically capped for dask).
11371179
name : str, default='percentiles'
11381180
Name of output aggregate array.
11391181
@@ -1154,11 +1196,11 @@ def percentiles(agg: xr.DataArray,
11541196

11551197
mapper = ArrayTypeFunctionMapping(
11561198
numpy_func=lambda *args: _run_percentiles(*args, module=np),
1157-
dask_func=lambda *args: _run_percentiles(*args, module=da),
1199+
dask_func=_run_dask_percentiles,
11581200
cupy_func=lambda *args: _run_percentiles(*args, module=cupy),
11591201
dask_cupy_func=_run_dask_cupy_percentiles,
11601202
)
1161-
q = mapper(agg)(agg.data, pct)
1203+
q = mapper(agg)(agg.data, num_sample, pct)
11621204

11631205
# Materialize bin edges to numpy
11641206
if hasattr(q, 'compute'):
@@ -1204,7 +1246,8 @@ def _compute_maximum_break_bins(values_np, k):
12041246
return bins
12051247

12061248

1207-
def _run_maximum_breaks(agg, k, module):
1249+
def _run_maximum_breaks(agg, num_sample, k, module):
1250+
# num_sample ignored for in-memory backends
12081251
data = agg.data
12091252
if module == cupy:
12101253
finite = data[cupy.isfinite(data)]
@@ -1218,21 +1261,29 @@ def _run_maximum_breaks(agg, k, module):
12181261
return out
12191262

12201263

1221-
def _run_dask_maximum_breaks(agg, k):
1264+
def _run_dask_maximum_breaks(agg, num_sample, k):
12221265
data = agg.data
12231266
data_clean = da.where(da.isinf(data), np.nan, data)
1224-
values_np = np.asarray(data_clean.ravel().compute())
1267+
num_data = data.size
1268+
if num_sample is None or num_sample >= num_data:
1269+
num_sample = num_data
1270+
sample_idx = _generate_sample_indices(num_data, num_sample)
1271+
values_np = np.asarray(data_clean.ravel()[sample_idx].compute())
12251272
values_np = values_np[np.isfinite(values_np)]
12261273
bins = _compute_maximum_break_bins(values_np, k)
12271274
out = _bin(agg, bins, np.arange(len(bins)))
12281275
return out
12291276

12301277

1231-
def _run_dask_cupy_maximum_breaks(agg, k):
1278+
def _run_dask_cupy_maximum_breaks(agg, num_sample, k):
12321279
data = agg.data
12331280
data_clean = da.where(da.isinf(data), np.nan, data)
12341281
data_cpu = data_clean.map_blocks(cupy.asnumpy, dtype=data.dtype, meta=np.array(()))
1235-
values_np = np.asarray(data_cpu.ravel().compute())
1282+
num_data = data.size
1283+
if num_sample is None or num_sample >= num_data:
1284+
num_sample = num_data
1285+
sample_idx = _generate_sample_indices(num_data, num_sample)
1286+
values_np = np.asarray(data_cpu.ravel()[sample_idx].compute())
12361287
values_np = values_np[np.isfinite(values_np)]
12371288
bins = _compute_maximum_break_bins(values_np, k)
12381289
out = _bin(agg, bins, np.arange(len(bins)))
@@ -1242,6 +1293,7 @@ def _run_dask_cupy_maximum_breaks(agg, k):
12421293
@supports_dataset
12431294
def maximum_breaks(agg: xr.DataArray,
12441295
k: int = 5,
1296+
num_sample: Optional[int] = 20_000,
12451297
name: Optional[str] = 'maximum_breaks') -> xr.DataArray:
12461298
"""
12471299
Classify data using the Maximum Breaks algorithm.
@@ -1256,6 +1308,11 @@ def maximum_breaks(agg: xr.DataArray,
12561308
of values to be classified.
12571309
k : int, default=5
12581310
Number of classes to be produced.
1311+
num_sample : int or None, default=20000
1312+
Number of sample data points used to fit the model.
1313+
For dask-backed arrays the sample is drawn lazily to avoid
1314+
materialising the entire array into RAM. ``None`` means use
1315+
all data (safe for numpy/cupy, automatically capped for dask).
12591316
name : str, default='maximum_breaks'
12601317
Name of output aggregate array.
12611318
@@ -1277,7 +1334,7 @@ def maximum_breaks(agg: xr.DataArray,
12771334
cupy_func=lambda *args: _run_maximum_breaks(*args, module=cupy),
12781335
dask_cupy_func=_run_dask_cupy_maximum_breaks,
12791336
)
1280-
out = mapper(agg)(agg, k)
1337+
out = mapper(agg)(agg, num_sample, k)
12811338
return xr.DataArray(out,
12821339
name=name,
12831340
dims=agg.dims,

0 commit comments

Comments
 (0)