Skip to content

Commit 4091b6a

Browse files
committed
Fix OOM in dask max_distance: build output lazily with _dask_embed_window
_viewshed_windowed was allocating np.full((H, W), ...) for the output before wrapping it as a dask array, which caused an instant OOM on a 30TB input even with max_distance set. Now, for dask inputs, the output is built chunk-by-chunk: chunks that overlap the max_distance window get a concrete numpy block, and all other chunks are lazy da.full blocks that consume no memory until materialized. Adds test_viewshed_dask_max_distance_lazy_output, which creates a 100k x 100k (80GB) dask raster and verifies that the output stays lazy.
1 parent 00c97a3 commit 4091b6a

File tree

2 files changed

+92
-5
lines changed

2 files changed

+92
-5
lines changed

xrspatial/tests/test_viewshed.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,29 @@ def test_viewshed_dask_distance_sweep():
263263
assert (result > INVISIBLE).all()
264264

265265

266+
def test_viewshed_dask_max_distance_lazy_output():
    """A max_distance viewshed over a huge dask raster must stay lazy.

    The input is a 100k x 100k float64 grid (80 GB if it were ever
    materialized) assembled from a single lazy 1000 x 1000 tile of zeros,
    so the test only passes quickly if the result is itself a dask array
    rather than an in-memory grid. Only a 21 x 21 patch around the
    observer is actually computed.
    """
    n_rows, n_cols = 100_000, 100_000  # 80 GB if materialized

    # One lazy tile of zeros; defining it allocates no data.
    tile = da.zeros((1000, 1000), chunks=(1000, 1000), dtype=np.float64)
    # Repeat the tile 100 x 100 times through dask to reach 100k x 100k.
    elevation = da.tile(tile, (100, 100))

    x_coords = np.arange(n_cols, dtype=float)
    y_coords = np.arange(n_rows, dtype=float)
    raster = xa.DataArray(
        elevation,
        coords={"x": x_coords, "y": y_coords},
        dims=["y", "x"],
    )

    visibility = viewshed(
        raster,
        x=50000.0,
        y=50000.0,
        observer_elev=5,
        max_distance=10.0,
    )

    # The result must still be a lazy dask array with the full grid shape.
    assert isinstance(visibility.data, da.Array)
    assert visibility.shape == (n_rows, n_cols)

    # Materialize just the neighbourhood of the observer to check values.
    around_observer = slice(49990, 50011)
    patch = visibility.isel(y=around_observer, x=around_observer).values
    assert patch[10, 10] == 180.0  # observer cell
    assert (patch > INVISIBLE).all()  # flat terrain, all visible
287+
288+
266289
def test_viewshed_numpy_max_distance():
267290
"""max_distance should work on plain numpy raster too."""
268291
ny, nx = 20, 20

xrspatial/viewshed.py

Lines changed: 69 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1720,6 +1720,63 @@ def viewshed(raster: xarray.DataArray,
17201720
# Dask backend helpers
17211721
# ---------------------------------------------------------------------------
17221722

1723+
def _dask_embed_window(window_np, H, W, r_lo, r_hi, c_lo, c_hi, chunks):
    """Place a small numpy result inside a full-size lazy INVISIBLE grid.

    The output is assembled one chunk at a time following ``chunks``.
    A chunk that intersects the window ``[r_lo, r_hi) x [c_lo, c_hi)``
    becomes a concrete numpy block filled with ``INVISIBLE`` and then
    overwritten with the matching part of ``window_np``. Every other chunk
    is a ``dask.array.full`` block, so it holds no data until computed.

    Parameters
    ----------
    window_np : numpy.ndarray
        Result values for the window; indexed relative to ``(r_lo, c_lo)``.
    H, W : int
        Full grid height and width. Accepted for the caller's convenience;
        the body derives all sizes from ``chunks`` and does not read them.
    r_lo, r_hi, c_lo, c_hi : int
        Row and column bounds of the window in full-grid indices
        (upper bounds exclusive).
    chunks : tuple of tuple of int
        Per-axis chunk sizes, rows first and columns second.

    Returns
    -------
    dask.array.Array
        Float64 array built by concatenating the per-chunk blocks.
    """
    import dask.array as da

    # NOTE(review): _chunk_offsets is assumed to return the cumulative
    # chunk boundaries (one more entry than there are chunks) -- confirm.
    row_edges = _chunk_offsets(chunks[0])
    col_edges = _chunk_offsets(chunks[1])
    row_chunk_count = len(chunks[0])
    col_chunk_count = len(chunks[1])

    strips = []
    for yi in range(row_chunk_count):
        top, bottom = int(row_edges[yi]), int(row_edges[yi + 1])
        strip = []
        for xi in range(col_chunk_count):
            left, right = int(col_edges[xi]), int(col_edges[xi + 1])
            block_shape = (bottom - top, right - left)

            # Intersection of this chunk with the result window.
            hit_r0 = max(top, r_lo)
            hit_r1 = min(bottom, r_hi)
            hit_c0 = max(left, c_lo)
            hit_c1 = min(right, c_hi)

            if not (hit_r0 < hit_r1 and hit_c0 < hit_c1):
                # Chunk lies entirely outside the window: keep it lazy.
                strip.append(da.full(block_shape, INVISIBLE,
                                     dtype=np.float64, chunks=block_shape))
                continue

            # Chunk touches the window: build it eagerly and paste the
            # overlapping values, translating to block-local and
            # window-local indices respectively.
            tile = np.full(block_shape, INVISIBLE, dtype=np.float64)
            in_tile = (slice(hit_r0 - top, hit_r1 - top),
                       slice(hit_c0 - left, hit_c1 - left))
            in_window = (slice(hit_r0 - r_lo, hit_r1 - r_lo),
                         slice(hit_c0 - c_lo, hit_c1 - c_lo))
            tile[in_tile] = window_np[in_window]
            strip.append(da.from_delayed(_identity_delayed(tile),
                                         shape=block_shape,
                                         dtype=np.float64))
        strips.append(da.concatenate(strip, axis=1))
    return da.concatenate(strips, axis=0)
1772+
1773+
1774+
def _identity_delayed(x):
    """Return a dask delayed call that yields *x* unchanged.

    This gives an already-computed value the delayed form that
    ``da.from_delayed`` expects.
    """
    from dask import delayed

    passthrough = delayed(lambda v: v)
    return passthrough(x)
1778+
1779+
17231780
def _available_memory_bytes():
17241781
"""Best-effort estimate of available memory in bytes."""
17251782
try:
@@ -2034,7 +2091,18 @@ def _viewshed_windowed(raster, x, y, observer_elev, target_elev,
20342091
window, x, y, observer_elev, target_elev)
20352092

20362093
# Embed in full-size INVISIBLE output, preserving array type
2037-
if is_cupy and has_rtx():
2094+
is_dask = has_dask_array() and isinstance(raster.data, da.Array)
2095+
2096+
if is_dask:
2097+
# Build output lazily to avoid allocating the full grid in memory.
2098+
# The window result is a small numpy array; the surrounding region
2099+
# is filled with INVISIBLE via dask.array.full.
2100+
local_vals = local_result.values if isinstance(
2101+
local_result.data, np.ndarray) else local_result.data.get()
2102+
full_vis = _dask_embed_window(
2103+
local_vals, height, width, r_lo, r_hi, c_lo, c_hi,
2104+
raster.data.chunks)
2105+
elif is_cupy and has_rtx():
20382106
import cupy as cp
20392107
full_vis = cp.full((height, width), INVISIBLE, dtype=np.float64)
20402108
full_vis[r_lo:r_hi, c_lo:c_hi] = local_result.data
@@ -2043,10 +2111,6 @@ def _viewshed_windowed(raster, x, y, observer_elev, target_elev,
20432111
full_vis = np.full((height, width), INVISIBLE, dtype=np.float64)
20442112
full_vis[r_lo:r_hi, c_lo:c_hi] = local_vals
20452113

2046-
# Wrap in the same array type as the input
2047-
if has_dask_array() and isinstance(raster.data, da.Array):
2048-
full_vis = da.from_array(full_vis, chunks=raster.data.chunks)
2049-
20502114
return xarray.DataArray(full_vis, coords=raster.coords,
20512115
dims=raster.dims, attrs=raster.attrs)
20522116

0 commit comments

Comments
 (0)