@@ -52,14 +52,23 @@ def _as_numpy(arr):
5252
5353
5454def _extract_sources (raster , target_values ):
55- """Return sorted array of unique source IDs from the raster."""
56- data = _to_numpy (raster .data )
55+ """Return sorted array of unique source IDs from the raster.
56+
57+ For dask arrays, uses ``da.unique`` (per-chunk reduction) so the full
58+ raster is never pulled into RAM just to discover source IDs.
59+ """
5760 if len (target_values ) > 0 :
5861 ids = np .asarray (target_values , dtype = np .float64 )
59- else :
60- mask = np .isfinite (data ) & (data != 0 )
61- ids = np .unique (data [mask ])
62- return ids [np .isfinite (ids )]
62+ return ids [np .isfinite (ids )]
63+
64+ data = raster .data
65+ if da is not None and isinstance (data , da .Array ):
66+ uniq = da .unique (data ).compute () # small result array
67+ mask = np .isfinite (uniq ) & (uniq != 0 )
68+ return np .sort (uniq [mask ])
69+ data_np = _to_numpy (data )
70+ mask = np .isfinite (data_np ) & (data_np != 0 )
71+ return np .unique (data_np [mask ])
6372
6473
6574def _make_single_source_raster (raster , source_id ):
@@ -297,6 +306,25 @@ def balanced_allocation(
297306 return xr .DataArray (out .astype (np .float32 ), coords = raster .coords ,
298307 dims = raster .dims , attrs = raster .attrs )
299308
309+ # Memory guard: we hold N cost surfaces + friction simultaneously.
310+ # Estimate total footprint before doing any expensive work.
311+ array_bytes = np .prod (raster .shape ) * 8 # float64
312+ # N cost surfaces + friction + allocation + stacked intermediate
313+ total_estimate = array_bytes * (n_sources + 3 )
314+ try :
315+ from xrspatial .zonal import _available_memory_bytes
316+ avail = _available_memory_bytes ()
317+ except ImportError :
318+ avail = 2 * 1024 ** 3
319+ if total_estimate > 0.8 * avail :
320+ raise MemoryError (
321+ f"balanced_allocation with { n_sources } sources needs "
322+ f"~{ total_estimate / 1e9 :.1f} GB ({ n_sources } cost surfaces "
323+ f"+ friction + intermediates) but only ~{ avail / 1e9 :.1f} GB "
324+ f"available. Reduce the number of sources, downsample the "
325+ f"raster, or increase available memory."
326+ )
327+
300328 # Step 1: compute per-source cost-distance surfaces
301329 cost_surfaces = [] # list of raw data arrays (numpy/cupy/dask)
302330 for sid in source_ids :
0 commit comments