Skip to content

Commit 1e88041

Browse files
committed
Fix balanced_allocation OOM: lazy source extraction + memory guard (#1114)
- `_extract_sources` now uses `da.unique()` (a per-chunk reduction) for dask arrays, instead of materializing the full raster in memory just to discover the source IDs.
- Add a memory guard before computing the N cost surfaces: it estimates the total footprint as N * array_bytes plus overhead, and raises `MemoryError` if that estimate would exceed 80% of available RAM.
1 parent 4087176 commit 1e88041

File tree

1 file changed

+34
-6
lines changed

1 file changed

+34
-6
lines changed

xrspatial/balanced_allocation.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,14 +52,23 @@ def _as_numpy(arr):
5252

5353

5454
def _extract_sources(raster, target_values):
55-
"""Return sorted array of unique source IDs from the raster."""
56-
data = _to_numpy(raster.data)
55+
"""Return sorted array of unique source IDs from the raster.
56+
57+
For dask arrays, uses ``da.unique`` (per-chunk reduction) so the full
58+
raster is never pulled into RAM just to discover source IDs.
59+
"""
5760
if len(target_values) > 0:
5861
ids = np.asarray(target_values, dtype=np.float64)
59-
else:
60-
mask = np.isfinite(data) & (data != 0)
61-
ids = np.unique(data[mask])
62-
return ids[np.isfinite(ids)]
62+
return ids[np.isfinite(ids)]
63+
64+
data = raster.data
65+
if da is not None and isinstance(data, da.Array):
66+
uniq = da.unique(data).compute() # small result array
67+
mask = np.isfinite(uniq) & (uniq != 0)
68+
return np.sort(uniq[mask])
69+
data_np = _to_numpy(data)
70+
mask = np.isfinite(data_np) & (data_np != 0)
71+
return np.unique(data_np[mask])
6372

6473

6574
def _make_single_source_raster(raster, source_id):
@@ -297,6 +306,25 @@ def balanced_allocation(
297306
return xr.DataArray(out.astype(np.float32), coords=raster.coords,
298307
dims=raster.dims, attrs=raster.attrs)
299308

309+
# Memory guard: we hold N cost surfaces + friction simultaneously.
310+
# Estimate total footprint before doing any expensive work.
311+
array_bytes = np.prod(raster.shape) * 8 # float64
312+
# N cost surfaces + friction + allocation + stacked intermediate
313+
total_estimate = array_bytes * (n_sources + 3)
314+
try:
315+
from xrspatial.zonal import _available_memory_bytes
316+
avail = _available_memory_bytes()
317+
except ImportError:
318+
avail = 2 * 1024**3
319+
if total_estimate > 0.8 * avail:
320+
raise MemoryError(
321+
f"balanced_allocation with {n_sources} sources needs "
322+
f"~{total_estimate / 1e9:.1f} GB ({n_sources} cost surfaces "
323+
f"+ friction + intermediates) but only ~{avail / 1e9:.1f} GB "
324+
f"available. Reduce the number of sources, downsample the "
325+
f"raster, or increase available memory."
326+
)
327+
300328
# Step 1: compute per-source cost-distance surfaces
301329
cost_surfaces = [] # list of raw data arrays (numpy/cupy/dask)
302330
for sid in source_ids:

0 commit comments

Comments
 (0)