Skip to content

Commit 1920f9a

Browse files
committed
Fix OOM in visibility module for dask-backed rasters
_extract_transect was calling .compute() on the full dask array just to read a handful of transect cells. Now uses vindex fancy indexing so only the relevant chunks are materialized. cumulative_viewshed was allocating a full-size np.zeros count array and calling .values on each viewshed result, forcing materialization every iteration. Now accumulates lazily with da.zeros and dask array addition when the input is dask-backed.
1 parent dbbf813 commit 1920f9a

File tree

1 file changed

+24
-7
lines changed

1 file changed

+24
-7
lines changed

xrspatial/visibility.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,9 @@ def _extract_transect(raster, cells):
6161
if has_dask_array():
6262
import dask.array as da
6363
if isinstance(data, da.Array):
64-
data = data.compute()
64+
# Only compute the needed cells, not the entire array
65+
elevations = data.vindex[rows, cols].compute().astype(np.float64)
66+
return elevations, x_coords, y_coords
6567
if has_cuda_and_cupy() and is_cupy_array(data):
6668
data = data.get()
6769

@@ -217,7 +219,16 @@ def cumulative_viewshed(
217219
if not observers:
218220
raise ValueError("observers list must not be empty")
219221

220-
count = np.zeros(raster.shape, dtype=np.int32)
222+
# Detect dask backend to keep accumulation lazy
223+
_is_dask = False
224+
if has_dask_array():
225+
import dask.array as da
226+
_is_dask = isinstance(raster.data, da.Array)
227+
228+
if _is_dask:
229+
count = da.zeros(raster.shape, dtype=np.int32, chunks=raster.data.chunks)
230+
else:
231+
count = np.zeros(raster.shape, dtype=np.int32)
221232

222233
for obs in observers:
223234
ox = obs['x']
@@ -229,11 +240,17 @@ def cumulative_viewshed(
229240
vs = viewshed(raster, x=ox, y=oy, observer_elev=oe,
230241
target_elev=te, max_distance=md)
231242

232-
vs_np = vs.values
233-
count += (vs_np != INVISIBLE).astype(np.int32)
234-
235-
return xarray.DataArray(count, coords=raster.coords,
236-
dims=raster.dims, attrs=raster.attrs)
243+
vs_data = vs.data
244+
if _is_dask and not isinstance(vs_data, da.Array):
245+
vs_data = da.from_array(vs_data, chunks=raster.data.chunks)
246+
count = count + (vs_data != INVISIBLE).astype(np.int32)
247+
248+
result = xarray.DataArray(count, coords=raster.coords,
249+
dims=raster.dims, attrs=raster.attrs)
250+
if _is_dask:
251+
chunk_dict = dict(zip(raster.dims, raster.data.chunks))
252+
result = result.chunk(chunk_dict)
253+
return result
237254

238255

239256
def visibility_frequency(

0 commit comments

Comments
 (0)