Skip to content

Commit 9b726c5

Browse files
committed
fix OOM in dask classification backends for large datasets
- quantile (dask+cupy): replace full materialization with map_blocks(cupy.asnumpy) to convert chunks to CPU one at a time, then delegate to dask's streaming approximate percentile.
- natural_breaks (dask backends): sample lazily from the dask array and materialize only the sample (default 20k points), not the entire dataset. Add a _generate_sample_indices helper that uses O(num_sample) memory for large datasets, falling back to the original linspace+shuffle approach for small datasets to preserve determinism with the numpy backend.
1 parent 835b5b1 commit 9b726c5

File tree

1 file changed

+53
-16
lines changed

1 file changed

+53
-16
lines changed

xrspatial/classify.py

Lines changed: 53 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -398,18 +398,10 @@ def _run_quantile(data, k, module):
398398

399399

400400
def _run_dask_cupy_quantile(data, k):
    # Move each GPU chunk to host memory lazily via map_blocks, then let
    # dask's streaming approximate percentile handle the quantiles — the
    # full array is never materialized at once.
    host_chunks = data.map_blocks(
        cupy.asnumpy, dtype=data.dtype, meta=np.array(())
    )
    return _run_quantile(host_chunks, k, da)
413405

414406

415407
def _quantile(agg, k):
@@ -661,22 +653,67 @@ def _run_cupy_natural_break(agg, num_sample, k):
661653
return out
662654

663655

656+
def _generate_sample_indices(num_data, num_sample, seed=1234567890):
657+
"""Generate sorted sample indices for natural breaks sampling.
658+
659+
For small datasets (<=10M elements), uses the same linspace+shuffle
660+
approach as the numpy backend for exact reproducibility.
661+
For large datasets, uses memory-efficient RandomState.choice()
662+
which is O(num_sample) rather than O(num_data).
663+
"""
664+
generator = np.random.RandomState(seed)
665+
if num_data <= 10_000_000:
666+
idx = np.linspace(
667+
0, num_data, num_data, endpoint=False, dtype=np.uint32
668+
)
669+
generator.shuffle(idx)
670+
sample_idx = idx[:num_sample]
671+
else:
672+
sample_idx = generator.choice(
673+
num_data, size=num_sample, replace=False
674+
)
675+
# sort for efficient dask chunk access
676+
sample_idx.sort()
677+
return sample_idx
678+
679+
664680
def _run_dask_natural_break(agg, num_sample, k):
    # Natural breaks over a dask-backed array: when a sample size is
    # requested and smaller than the data, only the sampled elements are
    # pulled into memory; otherwise the whole array is materialized.
    data = agg.data
    max_data = float(da.max(data[da.isfinite(data)]).compute())

    total = data.size
    if num_sample is not None and num_sample < total:
        # Fancy-index the lazy flattened array so compute() only
        # materializes num_sample points.
        indices = _generate_sample_indices(total, num_sample)
        values_np = np.asarray(data.ravel()[indices].compute())
    else:
        values_np = np.asarray(data.ravel().compute())

    bins, uvk = _compute_natural_break_bins(values_np, None, k, max_data)
    out = _bin(agg, bins, np.arange(uvk))
    return out
672698

673699

674700
def _run_dask_cupy_natural_break(agg, num_sample, k):
    # Natural breaks over a dask+cupy array: compute only the sampled
    # elements when possible, then transfer that (small) result from GPU
    # to CPU before computing the bins.
    data = agg.data
    max_data = float(da.max(data[da.isfinite(data)]).compute().item())

    total = data.size
    if num_sample is not None and num_sample < total:
        # Fancy-index the lazy flattened array so compute() only
        # materializes num_sample points on the GPU.
        indices = _generate_sample_indices(total, num_sample)
        values_np = cupy.asnumpy(data.ravel()[indices].compute())
    else:
        values_np = cupy.asnumpy(data.ravel().compute())

    bins, uvk = _compute_natural_break_bins(values_np, None, k, max_data)
    out = _bin(agg, bins, np.arange(uvk))
    return out
682719

0 commit comments

Comments
 (0)