Skip to content

Commit 74a6da9

Browse files
authored
Fix diffusion dask OOM: pass scalar diffusivity directly to chunks (#1117)
* Add sweep-performance design spec Parallel subagent triage + ralph-loop workflow for auditing all xrspatial modules for performance bottlenecks, OOM risk under 30TB dask workloads, and backend-specific anti-patterns. * Add sweep-performance implementation plan 7 tasks covering command scaffold, module scoring, parallel subagent dispatch, report merging, ralph-loop generation, and smoke tests. * Add sweep-performance slash command * Fix diffusion dask path: pass scalar diffusivity directly (#1116) For scalar diffusivity, the dask chunk function now receives the float value directly instead of a full-raster numpy array captured in every task closure. This eliminates the O(H*W) eager allocation and the per-task serialization overhead. For DataArray diffusivity, the dask path passes the dask array as a second argument to map_overlap so each chunk gets only its own slice.
1 parent ca8825a commit 74a6da9

File tree

1 file changed

+90
-28
lines changed

1 file changed

+90
-28
lines changed

xrspatial/diffusion.py

Lines changed: 90 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -129,38 +129,69 @@ def _diffuse_cupy(data, alpha_arr, steps, dt_over_dx2, boundary):
129129
# ---- dask + numpy backend ----
130130

131131
def _diffuse_chunk_numpy(chunk, alpha_chunk, steps, dt_over_dx2, block_info=None):
    """Run ``steps`` diffusion iterations on one overlapped dask chunk.

    ``alpha_chunk`` may be a scalar (uniform diffusivity) or a 2-D array
    matching the overlapped chunk shape.  ``chunk`` arrives with a 1-cell
    halo on every side, supplied by ``map_overlap``.
    """
    n_rows = chunk.shape[0] - 2
    n_cols = chunk.shape[1] - 2
    if np.ndim(alpha_chunk) == 0:
        # Uniform diffusivity: expand the scalar to the interior shape.
        inner_alpha = np.broadcast_to(float(alpha_chunk), (n_rows, n_cols))
    else:
        # Spatially varying diffusivity: drop the 1-cell halo.
        inner_alpha = alpha_chunk[1:-1, 1:-1]

    result = chunk.copy()
    for _ in range(steps):
        # Write the freshly computed interior back so the next iteration
        # sees the updated values (halo stays as supplied).
        result[1:-1, 1:-1] = _diffuse_step_numpy(
            result, inner_alpha, dt_over_dx2, n_rows, n_cols)
    return result

148152

149-
def _diffuse_dask_numpy(data, alpha, steps, dt_over_dx2, boundary):
    """Dask+numpy backend.

    ``alpha`` is either a Python float (scalar diffusivity) or a dask
    array matching ``data``'s shape (spatially varying diffusivity).

    For a scalar, the float is bound directly into the chunk function —
    no full-raster allocation and only a tiny closure serialized per
    task.  For a dask array, ``alpha`` is passed as a second argument to
    ``da.map_overlap`` so each task receives only its own overlapped
    slice of the diffusivity raster.
    """
    # Hoist loop invariants: the scalar test and the boundary-spec
    # conversion are identical for every iteration.
    is_scalar = isinstance(alpha, (int, float, np.floating))
    dask_boundary = _boundary_to_dask(boundary)

    u = data.astype(np.float64)
    if is_scalar:
        # Bind the chunk kwargs once; one map_overlap per diffusion step
        # so halos are refreshed between steps.
        _func = partial(
            _diffuse_chunk_numpy,
            alpha_chunk=float(alpha),
            steps=1,
            dt_over_dx2=dt_over_dx2,
        )
        for _ in range(steps):
            u = u.map_overlap(
                _func,
                depth=(1, 1),
                boundary=dask_boundary,
                meta=np.array(()),
            )
    else:
        # Spatially varying: map over both rasters so alpha is chunked
        # (and overlapped) alongside the data.
        for _ in range(steps):
            u = da.map_overlap(
                _diffuse_chunk_numpy,
                u, alpha,
                depth=(1, 1),
                boundary=dask_boundary,
                meta=np.array(()),
                steps=1,
                dt_over_dx2=dt_over_dx2,
            )
    return u
165196

166197

@@ -244,21 +275,34 @@ def diffuse(
244275
_validate_scalar(steps, func_name='diffuse', name='steps', dtype=int, min_val=1)
245276
_validate_boundary(boundary)
246277

247-
# resolve diffusivity to a numpy/cupy array matching agg
278+
# resolve diffusivity
279+
# - scalar: keep as float for dask paths (avoids full-raster allocation)
280+
# - DataArray: keep as .data (numpy/cupy/dask) for backend dispatch
248281
if isinstance(diffusivity, xr.DataArray):
249282
_validate_raster(diffusivity, func_name='diffuse', name='diffusivity', ndim=2)
250283
if diffusivity.shape != agg.shape:
251284
raise ValueError(
252285
f"diffuse(): diffusivity shape {diffusivity.shape} "
253286
f"does not match agg shape {agg.shape}"
254287
)
255-
alpha_arr = diffusivity.values.astype(np.float64)
288+
alpha_scalar = None
289+
alpha_data = diffusivity.data # may be numpy, cupy, or dask
290+
# For numpy/cupy eager paths, materialize to numpy
291+
if da is not None and isinstance(alpha_data, da.Array):
292+
alpha_arr_eager = None # deferred — only built if needed
293+
else:
294+
if hasattr(alpha_data, 'get'):
295+
alpha_arr_eager = alpha_data.get().astype(np.float64)
296+
else:
297+
alpha_arr_eager = np.asarray(alpha_data, dtype=np.float64)
256298
elif isinstance(diffusivity, (int, float)):
257299
if diffusivity <= 0:
258300
raise ValueError(
259301
f"diffuse(): diffusivity must be > 0, got {diffusivity}"
260302
)
261-
alpha_arr = np.full(agg.shape, float(diffusivity), dtype=np.float64)
303+
alpha_scalar = float(diffusivity)
304+
alpha_data = None
305+
alpha_arr_eager = np.full(agg.shape, alpha_scalar, dtype=np.float64)
262306
else:
263307
raise TypeError(
264308
f"diffuse(): diffusivity must be a float or xr.DataArray, "
@@ -274,7 +318,13 @@ def diffuse(
274318
else:
275319
dx = 1.0
276320

277-
alpha_max = float(np.nanmax(alpha_arr))
321+
if alpha_scalar is not None:
322+
alpha_max = alpha_scalar
323+
elif alpha_arr_eager is not None:
324+
alpha_max = float(np.nanmax(alpha_arr_eager))
325+
else:
326+
# dask DataArray diffusivity — compute max lazily
327+
alpha_max = float(da.nanmax(alpha_data).compute())
278328
if alpha_max <= 0:
279329
raise ValueError("diffuse(): all diffusivity values must be > 0")
280330

@@ -287,18 +337,30 @@ def diffuse(
287337

288338
dt_over_dx2 = float(dt) / (dx * dx)
289339

340+
# Build the alpha argument for each backend:
341+
# - numpy/cupy eager: always use alpha_arr_eager (full numpy array)
342+
# - dask: use alpha_scalar (float) or alpha_data (dask array)
343+
if alpha_arr_eager is None and alpha_data is not None:
344+
# Dask DataArray diffusivity, numpy path not yet built
345+
alpha_arr_eager = alpha_data.compute()
346+
if hasattr(alpha_arr_eager, 'get'):
347+
alpha_arr_eager = alpha_arr_eager.get()
348+
alpha_arr_eager = np.asarray(alpha_arr_eager, dtype=np.float64)
349+
350+
dask_alpha = alpha_scalar if alpha_scalar is not None else alpha_data
351+
290352
# dispatch to backend
291353
mapper = ArrayTypeFunctionMapping(
292-
numpy_func=partial(_diffuse_numpy, alpha_arr=alpha_arr,
354+
numpy_func=partial(_diffuse_numpy, alpha_arr=alpha_arr_eager,
293355
steps=steps, dt_over_dx2=dt_over_dx2,
294356
boundary=boundary),
295-
cupy_func=partial(_diffuse_cupy, alpha_arr=alpha_arr,
357+
cupy_func=partial(_diffuse_cupy, alpha_arr=alpha_arr_eager,
296358
steps=steps, dt_over_dx2=dt_over_dx2,
297359
boundary=boundary),
298-
dask_func=partial(_diffuse_dask_numpy, alpha_arr=alpha_arr,
360+
dask_func=partial(_diffuse_dask_numpy, alpha=dask_alpha,
299361
steps=steps, dt_over_dx2=dt_over_dx2,
300362
boundary=boundary),
301-
dask_cupy_func=partial(_diffuse_dask_cupy, alpha_arr=alpha_arr,
363+
dask_cupy_func=partial(_diffuse_dask_cupy, alpha_arr=alpha_arr_eager,
302364
steps=steps, dt_over_dx2=dt_over_dx2,
303365
boundary=boundary),
304366
)

0 commit comments

Comments
 (0)