Skip to content

Commit 26bde73

Browse files
authored
Fix dask OOM in visibility and viewshed modules (#1167)
* Fix OOM in visibility module for dask-backed rasters _extract_transect was calling .compute() on the full dask array just to read a handful of transect cells. Now uses vindex fancy indexing so only the relevant chunks are materialized. cumulative_viewshed was allocating a full-size np.zeros count array and calling .values on each viewshed result, forcing materialization every iteration. Now accumulates lazily with da.zeros and dask array addition when the input is dask-backed. * Tighten viewshed Tier B memory estimate and avoid needless copy The dask Tier B memory guard underestimated peak usage at 280 bytes/pixel. Actual peak during lexsort reaches ~360 bytes/pixel (sorted + unsorted event_list coexist) plus 8 bytes/pixel for the computed raster. Updated estimate to 368 bytes/pixel to prevent borderline OOM. Also use astype(copy=False) to skip the float64 copy when data is already float64.
1 parent dbbf813 commit 26bde73

File tree

2 files changed

+28
-9
lines changed

2 files changed

+28
-9
lines changed

xrspatial/viewshed.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1560,7 +1560,7 @@ def _viewshed_cpu(
15601560
num_events = 3 * (n_rows * n_cols - 1)
15611561
event_list = np.zeros((num_events, 7), dtype=np.float64)
15621562

1563-
raster.data = raster.data.astype(np.float64)
1563+
raster.data = raster.data.astype(np.float64, copy=False)
15641564

15651565
_init_event_list(event_list=event_list, raster=raster.data,
15661566
vp_row=viewpoint_row, vp_col=viewpoint_col,
@@ -2167,7 +2167,9 @@ def _viewshed_dask(raster, x, y, observer_elev, target_elev):
21672167
cupy_backed = is_dask_cupy(raster)
21682168

21692169
# --- Tier B: full grid fits in memory → compute and run exact algo ---
2170-
r2_bytes = 280 * height * width
2170+
# Peak memory: event_list sort needs 2x 168*H*W + raster 8*H*W +
2171+
# visibility_grid 8*H*W ≈ 360 bytes/pixel, plus the computed raster.
2172+
r2_bytes = 360 * height * width + 8 * height * width # working + raster
21712173
avail = _available_memory_bytes()
21722174
if r2_bytes < 0.5 * avail:
21732175
raster_mem = raster.copy()

xrspatial/visibility.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,9 @@ def _extract_transect(raster, cells):
6161
if has_dask_array():
6262
import dask.array as da
6363
if isinstance(data, da.Array):
64-
data = data.compute()
64+
# Only compute the needed cells, not the entire array
65+
elevations = data.vindex[rows, cols].compute().astype(np.float64)
66+
return elevations, x_coords, y_coords
6567
if has_cuda_and_cupy() and is_cupy_array(data):
6668
data = data.get()
6769

@@ -217,7 +219,16 @@ def cumulative_viewshed(
217219
if not observers:
218220
raise ValueError("observers list must not be empty")
219221

220-
count = np.zeros(raster.shape, dtype=np.int32)
222+
# Detect dask backend to keep accumulation lazy
223+
_is_dask = False
224+
if has_dask_array():
225+
import dask.array as da
226+
_is_dask = isinstance(raster.data, da.Array)
227+
228+
if _is_dask:
229+
count = da.zeros(raster.shape, dtype=np.int32, chunks=raster.data.chunks)
230+
else:
231+
count = np.zeros(raster.shape, dtype=np.int32)
221232

222233
for obs in observers:
223234
ox = obs['x']
@@ -229,11 +240,17 @@ def cumulative_viewshed(
229240
vs = viewshed(raster, x=ox, y=oy, observer_elev=oe,
230241
target_elev=te, max_distance=md)
231242

232-
vs_np = vs.values
233-
count += (vs_np != INVISIBLE).astype(np.int32)
234-
235-
return xarray.DataArray(count, coords=raster.coords,
236-
dims=raster.dims, attrs=raster.attrs)
243+
vs_data = vs.data
244+
if _is_dask and not isinstance(vs_data, da.Array):
245+
vs_data = da.from_array(vs_data, chunks=raster.data.chunks)
246+
count = count + (vs_data != INVISIBLE).astype(np.int32)
247+
248+
result = xarray.DataArray(count, coords=raster.coords,
249+
dims=raster.dims, attrs=raster.attrs)
250+
if _is_dask:
251+
chunk_dict = dict(zip(raster.dims, raster.data.chunks))
252+
result = result.chunk(chunk_dict)
253+
return result
237254

238255

239256
def visibility_frequency(

0 commit comments

Comments
 (0)