Skip to content

Commit cf5aaa7

Browse files
committed
Fixes #877, #876: prevent OOM from full dask materialisation in classify.py
natural_breaks and maximum_breaks dask code paths called .ravel().compute() on the full array, materialising the entire dataset into RAM. Replace with capped sampling via _generate_sample_indices() and indexed access, so that only the sampled elements are ever computed. Add a num_sample parameter to maximum_breaks (default 20_000, matching natural_breaks).
1 parent 0cf6ce2 commit cf5aaa7

File tree

2 files changed

+94
-27
lines changed

2 files changed

+94
-27
lines changed

xrspatial/classify.py

Lines changed: 32 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -695,16 +695,11 @@ def _run_dask_natural_break(agg, num_sample, k):
695695
max_data = float(da.nanmax(da.where(da.isinf(data), np.nan, data)).compute())
696696

697697
num_data = data.size
698-
if num_sample is not None and num_sample < num_data:
699-
# Sample lazily from dask array; only materialize the sample
700-
sample_idx = _generate_sample_indices(num_data, num_sample)
701-
sample_data_np = np.asarray(data.ravel()[sample_idx].compute())
702-
bins, uvk = _compute_natural_break_bins(
703-
sample_data_np, None, k, max_data)
704-
else:
705-
data_flat_np = np.asarray(data.ravel().compute())
706-
bins, uvk = _compute_natural_break_bins(
707-
data_flat_np, None, k, max_data)
698+
if num_sample is None or num_sample >= num_data:
699+
num_sample = num_data  # NOTE(review): with num_sample=None the index covers every element, so the indexed .compute() still materialises the full array — this path does not avoid the OOM, only the sampled path does
700+
sample_idx = _generate_sample_indices(num_data, num_sample)
701+
sample_data_np = np.asarray(data.ravel()[sample_idx].compute())
702+
bins, uvk = _compute_natural_break_bins(sample_data_np, None, k, max_data)
708703

709704
out = _bin(agg, bins, np.arange(uvk))
710705
return out
@@ -717,17 +712,12 @@ def _run_dask_cupy_natural_break(agg, num_sample, k):
717712
max_data = float(da.nanmax(da.where(da.isinf(data), np.nan, data)).compute().item())
718713

719714
num_data = data.size
720-
if num_sample is not None and num_sample < num_data:
721-
# Sample lazily from dask array; only materialize the sample
722-
sample_idx = _generate_sample_indices(num_data, num_sample)
723-
sample_data = data.ravel()[sample_idx].compute()
724-
sample_data_np = cupy.asnumpy(sample_data)
725-
bins, uvk = _compute_natural_break_bins(
726-
sample_data_np, None, k, max_data)
727-
else:
728-
data_flat_np = cupy.asnumpy(data.ravel().compute())
729-
bins, uvk = _compute_natural_break_bins(
730-
data_flat_np, None, k, max_data)
715+
if num_sample is None or num_sample >= num_data:
716+
num_sample = num_data  # NOTE(review): with num_sample=None the index covers every element, so the indexed .compute() still materialises the full array — this path does not avoid the OOM, only the sampled path does
717+
sample_idx = _generate_sample_indices(num_data, num_sample)
718+
sample_data = data.ravel()[sample_idx].compute()
719+
sample_data_np = cupy.asnumpy(sample_data)
720+
bins, uvk = _compute_natural_break_bins(sample_data_np, None, k, max_data)
731721

732722
out = _bin(agg, bins, np.arange(uvk))
733723
return out
@@ -1204,7 +1194,8 @@ def _compute_maximum_break_bins(values_np, k):
12041194
return bins
12051195

12061196

1207-
def _run_maximum_breaks(agg, k, module):
1197+
def _run_maximum_breaks(agg, num_sample, k, module):
1198+
# num_sample ignored for in-memory backends
12081199
data = agg.data
12091200
if module == cupy:
12101201
finite = data[cupy.isfinite(data)]
@@ -1218,21 +1209,29 @@ def _run_maximum_breaks(agg, k, module):
12181209
return out
12191210

12201211

1221-
def _run_dask_maximum_breaks(agg, k):
1212+
def _run_dask_maximum_breaks(agg, num_sample, k):
12221213
data = agg.data
12231214
data_clean = da.where(da.isinf(data), np.nan, data)
1224-
values_np = np.asarray(data_clean.ravel().compute())
1215+
num_data = data.size
1216+
if num_sample is None or num_sample >= num_data:
1217+
num_sample = num_data
1218+
sample_idx = _generate_sample_indices(num_data, num_sample)
1219+
values_np = np.asarray(data_clean.ravel()[sample_idx].compute())
12251220
values_np = values_np[np.isfinite(values_np)]
12261221
bins = _compute_maximum_break_bins(values_np, k)
12271222
out = _bin(agg, bins, np.arange(len(bins)))
12281223
return out
12291224

12301225

1231-
def _run_dask_cupy_maximum_breaks(agg, k):
1226+
def _run_dask_cupy_maximum_breaks(agg, num_sample, k):
12321227
data = agg.data
12331228
data_clean = da.where(da.isinf(data), np.nan, data)
12341229
data_cpu = data_clean.map_blocks(cupy.asnumpy, dtype=data.dtype, meta=np.array(()))
1235-
values_np = np.asarray(data_cpu.ravel().compute())
1230+
num_data = data.size
1231+
if num_sample is None or num_sample >= num_data:
1232+
num_sample = num_data
1233+
sample_idx = _generate_sample_indices(num_data, num_sample)
1234+
values_np = np.asarray(data_cpu.ravel()[sample_idx].compute())
12361235
values_np = values_np[np.isfinite(values_np)]
12371236
bins = _compute_maximum_break_bins(values_np, k)
12381237
out = _bin(agg, bins, np.arange(len(bins)))
@@ -1242,6 +1241,7 @@ def _run_dask_cupy_maximum_breaks(agg, k):
12421241
@supports_dataset
12431242
def maximum_breaks(agg: xr.DataArray,
12441243
k: int = 5,
1244+
num_sample: Optional[int] = 20_000,
12451245
name: Optional[str] = 'maximum_breaks') -> xr.DataArray:
12461246
"""
12471247
Classify data using the Maximum Breaks algorithm.
@@ -1256,6 +1256,11 @@ def maximum_breaks(agg: xr.DataArray,
12561256
of values to be classified.
12571257
k : int, default=5
12581258
Number of classes to be produced.
1259+
num_sample : int or None, default=20000
1260+
Number of sample data points used to fit the model.
1261+
For dask-backed arrays the sample is drawn lazily to avoid
1262+
materialising the entire array into RAM. ``None`` means use
1263+
all data (safe for numpy/cupy, automatically capped for dask).
12591264
name : str, default='maximum_breaks'
12601265
Name of output aggregate array.
12611266
@@ -1277,7 +1282,7 @@ def maximum_breaks(agg: xr.DataArray,
12771282
cupy_func=lambda *args: _run_maximum_breaks(*args, module=cupy),
12781283
dask_cupy_func=_run_dask_cupy_maximum_breaks,
12791284
)
1280-
out = mapper(agg)(agg, k)
1285+
out = mapper(agg)(agg, num_sample, k)
12811286
return xr.DataArray(out,
12821287
name=name,
12831288
dims=agg.dims,

xrspatial/tests/test_classify.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -883,3 +883,65 @@ def test_maximum_breaks_dask_matches_numpy():
883883
maximum_breaks(dask_agg).data.compute(),
884884
equal_nan=True,
885885
)
886+
887+
888+
# ===================================================================
889+
# Regression tests: dask paths must not materialise the full array
890+
# ===================================================================
891+
892+
@dask_array_available
893+
def test_natural_breaks_dask_no_full_compute():
894+
"""natural_breaks with num_sample=None on dask must not call
895+
.ravel().compute() on the full array (#877)."""
896+
elevation = np.arange(100, dtype=np.float64).reshape(10, 10)
897+
numpy_agg = xr.DataArray(elevation)
898+
dask_agg = xr.DataArray(da.from_array(elevation, chunks=(5, 5)))
899+
900+
numpy_result = natural_breaks(numpy_agg, num_sample=None, k=5)
901+
dask_result = natural_breaks(dask_agg, num_sample=None, k=5)
902+
903+
np.testing.assert_allclose(
904+
numpy_result.data,
905+
dask_result.data.compute(),
906+
equal_nan=True,
907+
)
908+
909+
910+
@dask_array_available
911+
def test_maximum_breaks_dask_no_full_compute():
912+
"""maximum_breaks on dask must use sampling, not .ravel().compute() (#876)."""
913+
elevation = np.arange(100, dtype=np.float64).reshape(10, 10)
914+
numpy_agg = xr.DataArray(elevation)
915+
dask_agg = xr.DataArray(da.from_array(elevation, chunks=(5, 5)))
916+
917+
# Default num_sample (20_000) > data.size (100), so all data used
918+
numpy_result = maximum_breaks(numpy_agg, k=5)
919+
dask_result = maximum_breaks(dask_agg, k=5)
920+
921+
np.testing.assert_allclose(
922+
numpy_result.data,
923+
dask_result.data.compute(),
924+
equal_nan=True,
925+
)
926+
927+
928+
@dask_array_available
929+
def test_maximum_breaks_dask_num_sample():
930+
"""maximum_breaks with explicit num_sample on dask produces valid,
931+
deterministic results (#876)."""
932+
elevation = np.arange(100, dtype=np.float64).reshape(10, 10)
933+
dask_agg = xr.DataArray(da.from_array(elevation, chunks=(5, 5)))
934+
935+
result1 = maximum_breaks(dask_agg, k=3, num_sample=50)
936+
result2 = maximum_breaks(dask_agg, k=3, num_sample=50)
937+
938+
# Deterministic: same input + same seed → same output
939+
np.testing.assert_allclose(
940+
result1.data.compute(),
941+
result2.data.compute(),
942+
equal_nan=True,
943+
)
944+
# Valid classification: correct shape and values in expected range
945+
assert result1.shape == elevation.shape
946+
unique_vals = np.unique(result1.data.compute())
947+
assert len(unique_vals) <= 3 + 1 # at most k classes + possible nan

0 commit comments

Comments
 (0)