@@ -129,38 +129,69 @@ def _diffuse_cupy(data, alpha_arr, steps, dt_over_dx2, boundary):
129129# ---- dask + numpy backend ----
130130
131131def _diffuse_chunk_numpy (chunk , alpha_chunk , steps , dt_over_dx2 , block_info = None ):
132- """Process a single dask chunk (already overlapped by 1 cell)."""
133- # The chunk arrives with 1-cell overlap on each side from map_overlap.
134- # We run steps iterations; for steps > 1 the boundary data is stale
135- # after the first step, but for typical usage (steps=1 per map_overlap
136- # call) this is correct. The public function wraps this in a loop.
132+ """Process a single dask chunk (already overlapped by 1 cell).
133+
134+ ``alpha_chunk`` may be a scalar (uniform diffusivity) or a 2-D array
135+ matching the overlapped chunk shape.
136+ """
137137 rows = chunk .shape [0 ] - 2
138138 cols = chunk .shape [1 ] - 2
139- interior_alpha = alpha_chunk [1 :- 1 , 1 :- 1 ]
139+ if np .ndim (alpha_chunk ) == 0 :
140+ # Scalar diffusivity — broadcast to interior shape
141+ interior_alpha = np .broadcast_to (float (alpha_chunk ),
142+ (rows , cols ))
143+ else :
144+ interior_alpha = alpha_chunk [1 :- 1 , 1 :- 1 ]
140145
141146 u = chunk .copy ()
142147 for _ in range (steps ):
143148 interior = _diffuse_step_numpy (u , interior_alpha , dt_over_dx2 , rows , cols )
144- # rebuild padded array from new interior for next iteration
145149 u [1 :- 1 , 1 :- 1 ] = interior
146150 return u
147151
148152
149- def _diffuse_dask_numpy (data , alpha_arr , steps , dt_over_dx2 , boundary ):
150- _func = partial (
151- _diffuse_chunk_numpy ,
152- alpha_chunk = alpha_arr ,
153- steps = 1 ,
154- dt_over_dx2 = dt_over_dx2 ,
155- )
153+ def _diffuse_dask_numpy (data , alpha , steps , dt_over_dx2 , boundary ):
154+ """Dask+numpy backend.
155+
156+ ``alpha`` is either a Python float (scalar diffusivity) or a dask
157+ array matching data's shape (spatially varying diffusivity).
158+ """
159+ if isinstance (alpha , (int , float , np .floating )):
160+ # Scalar: pass directly — no full-raster allocation, tiny closure.
161+ _func = partial (
162+ _diffuse_chunk_numpy ,
163+ alpha_chunk = float (alpha ),
164+ steps = 1 ,
165+ dt_over_dx2 = dt_over_dx2 ,
166+ )
167+ else :
168+ # Spatially varying: alpha is a dask array. map_overlap will
169+ # feed matching chunks automatically.
170+ _func = partial (
171+ _diffuse_chunk_numpy ,
172+ steps = 1 ,
173+ dt_over_dx2 = dt_over_dx2 ,
174+ )
156175 u = data .astype (np .float64 )
157176 for _ in range (steps ):
158- u = u .map_overlap (
159- _func ,
160- depth = (1 , 1 ),
161- boundary = _boundary_to_dask (boundary ),
162- meta = np .array (()),
163- )
177+ if isinstance (alpha , (int , float , np .floating )):
178+ u = u .map_overlap (
179+ _func ,
180+ depth = (1 , 1 ),
181+ boundary = _boundary_to_dask (boundary ),
182+ meta = np .array (()),
183+ )
184+ else :
185+ # Pass alpha as a second dask argument to map_overlap
186+ u = da .map_overlap (
187+ _diffuse_chunk_numpy ,
188+ u , alpha ,
189+ depth = (1 , 1 ),
190+ boundary = _boundary_to_dask (boundary ),
191+ meta = np .array (()),
192+ steps = 1 ,
193+ dt_over_dx2 = dt_over_dx2 ,
194+ )
164195 return u
165196
166197
@@ -244,21 +275,34 @@ def diffuse(
244275 _validate_scalar (steps , func_name = 'diffuse' , name = 'steps' , dtype = int , min_val = 1 )
245276 _validate_boundary (boundary )
246277
247- # resolve diffusivity to a numpy/cupy array matching agg
278+ # resolve diffusivity
279+ # - scalar: keep as float for dask paths (avoids full-raster allocation)
280+ # - DataArray: keep as .data (numpy/cupy/dask) for backend dispatch
248281 if isinstance (diffusivity , xr .DataArray ):
249282 _validate_raster (diffusivity , func_name = 'diffuse' , name = 'diffusivity' , ndim = 2 )
250283 if diffusivity .shape != agg .shape :
251284 raise ValueError (
252285 f"diffuse(): diffusivity shape { diffusivity .shape } "
253286 f"does not match agg shape { agg .shape } "
254287 )
255- alpha_arr = diffusivity .values .astype (np .float64 )
288+ alpha_scalar = None
289+ alpha_data = diffusivity .data # may be numpy, cupy, or dask
290+ # For numpy/cupy eager paths, materialize to numpy
291+ if da is not None and isinstance (alpha_data , da .Array ):
292+ alpha_arr_eager = None # deferred — only built if needed
293+ else :
294+ if hasattr (alpha_data , 'get' ):
295+ alpha_arr_eager = alpha_data .get ().astype (np .float64 )
296+ else :
297+ alpha_arr_eager = np .asarray (alpha_data , dtype = np .float64 )
256298 elif isinstance (diffusivity , (int , float )):
257299 if diffusivity <= 0 :
258300 raise ValueError (
259301 f"diffuse(): diffusivity must be > 0, got { diffusivity } "
260302 )
261- alpha_arr = np .full (agg .shape , float (diffusivity ), dtype = np .float64 )
303+ alpha_scalar = float (diffusivity )
304+ alpha_data = None
305+ alpha_arr_eager = np .full (agg .shape , alpha_scalar , dtype = np .float64 )
262306 else :
263307 raise TypeError (
264308 f"diffuse(): diffusivity must be a float or xr.DataArray, "
@@ -274,7 +318,13 @@ def diffuse(
274318 else :
275319 dx = 1.0
276320
277- alpha_max = float (np .nanmax (alpha_arr ))
321+ if alpha_scalar is not None :
322+ alpha_max = alpha_scalar
323+ elif alpha_arr_eager is not None :
324+ alpha_max = float (np .nanmax (alpha_arr_eager ))
325+ else :
326+ # dask DataArray diffusivity — compute max lazily
327+ alpha_max = float (da .nanmax (alpha_data ).compute ())
278328 if alpha_max <= 0 :
279329 raise ValueError ("diffuse(): all diffusivity values must be > 0" )
280330
@@ -287,18 +337,30 @@ def diffuse(
287337
288338 dt_over_dx2 = float (dt ) / (dx * dx )
289339
340+ # Build the alpha argument for each backend:
341+ # - numpy/cupy eager: always use alpha_arr_eager (full numpy array)
342+ # - dask: use alpha_scalar (float) or alpha_data (dask array)
343+ if alpha_arr_eager is None and alpha_data is not None :
344+ # Dask DataArray diffusivity, numpy path not yet built
345+ alpha_arr_eager = alpha_data .compute ()
346+ if hasattr (alpha_arr_eager , 'get' ):
347+ alpha_arr_eager = alpha_arr_eager .get ()
348+ alpha_arr_eager = np .asarray (alpha_arr_eager , dtype = np .float64 )
349+
350+ dask_alpha = alpha_scalar if alpha_scalar is not None else alpha_data
351+
290352 # dispatch to backend
291353 mapper = ArrayTypeFunctionMapping (
292- numpy_func = partial (_diffuse_numpy , alpha_arr = alpha_arr ,
354+ numpy_func = partial (_diffuse_numpy , alpha_arr = alpha_arr_eager ,
293355 steps = steps , dt_over_dx2 = dt_over_dx2 ,
294356 boundary = boundary ),
295- cupy_func = partial (_diffuse_cupy , alpha_arr = alpha_arr ,
357+ cupy_func = partial (_diffuse_cupy , alpha_arr = alpha_arr_eager ,
296358 steps = steps , dt_over_dx2 = dt_over_dx2 ,
297359 boundary = boundary ),
298- dask_func = partial (_diffuse_dask_numpy , alpha_arr = alpha_arr ,
360+ dask_func = partial (_diffuse_dask_numpy , alpha = dask_alpha ,
299361 steps = steps , dt_over_dx2 = dt_over_dx2 ,
300362 boundary = boundary ),
301- dask_cupy_func = partial (_diffuse_dask_cupy , alpha_arr = alpha_arr ,
363+ dask_cupy_func = partial (_diffuse_dask_cupy , alpha_arr = alpha_arr_eager ,
302364 steps = steps , dt_over_dx2 = dt_over_dx2 ,
303365 boundary = boundary ),
304366 )
0 commit comments