Skip to content

Commit 837eab3

Browse files
authored
Add scalar diffusion solver (#940) (#944)
New `diffuse()` function that runs explicit forward-Euler diffusion on a 2D raster. Supports uniform or spatially varying diffusivity, auto CFL time step, and all four backends (numpy, cupy, dask+numpy, dask+cupy).
1 parent 2688e56 commit 837eab3

File tree

8 files changed

+779
-0
lines changed

8 files changed

+779
-0
lines changed

README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,14 @@ In the GIS world, rasters are used for representing continuous phenomena (e.g. e
151151

152152
-------
153153

154+
### **Diffusion**
155+
156+
| Name | Description | NumPy xr.DataArray | Dask xr.DataArray | CuPy GPU xr.DataArray | Dask GPU xr.DataArray |
157+
|:----------:|:------------|:----------------------:|:--------------------:|:-------------------:|:------:|
158+
| [Diffuse](xrspatial/diffusion.py) | Runs explicit forward-Euler diffusion on a 2D scalar field | ✅️ | ✅️ | ✅️ | ✅️ |
159+
160+
-------
161+
154162
### **Focal**
155163

156164
| Name | Description | NumPy xr.DataArray | Dask xr.DataArray | CuPy GPU xr.DataArray | Dask GPU xr.DataArray |
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
.. _reference.diffusion:
2+
3+
*********
4+
Diffusion
5+
*********
6+
7+
Diffuse
8+
=======
9+
.. autosummary::
10+
:toctree: _autosummary
11+
12+
xrspatial.diffusion.diffuse

docs/source/reference/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ Reference
99

1010
classification
1111
dasymetric
12+
diffusion
1213
fire
1314
flood
1415
focal
Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"# Scalar Diffusion\n",
8+
"\n",
9+
"The `diffuse()` function models how a scalar field (temperature, concentration, humidity) spreads across a raster over time. It solves the 2D diffusion equation using an explicit forward-Euler scheme with a 5-point Laplacian stencil:\n",
10+
"\n",
11+
" du/dt = alpha * laplacian(u)\n",
12+
"\n",
13+
"You can use a uniform diffusivity (single float) or a spatially varying diffusivity raster. The solver auto-selects a stable time step when you don't provide one."
14+
]
15+
},
16+
{
17+
"cell_type": "code",
18+
"execution_count": null,
19+
"metadata": {},
20+
"outputs": [],
21+
"source": [
22+
"import numpy as np\n",
23+
"import xarray as xr\n",
24+
"import matplotlib.pyplot as plt\n",
25+
"\n",
26+
"from xrspatial.diffusion import diffuse"
27+
]
28+
},
29+
{
30+
"cell_type": "markdown",
31+
"metadata": {},
32+
"source": [
33+
"## 1. Point source diffusion\n",
34+
"\n",
35+
"Start with a single hot cell in the center of a cold field and watch it spread."
36+
]
37+
},
38+
{
39+
"cell_type": "code",
40+
"execution_count": null,
41+
"metadata": {},
42+
"outputs": [],
43+
"source": [
44+
"# Create a 51x51 grid with a hot spot at the center\n",
45+
"shape = (51, 51)\n",
46+
"data = np.zeros(shape)\n",
47+
"data[25, 25] = 100.0\n",
48+
"\n",
49+
"initial = xr.DataArray(\n",
50+
" data,\n",
51+
" dims=['y', 'x'],\n",
52+
" attrs={'res': (1.0, 1.0)},\n",
53+
")\n",
54+
"\n",
55+
"# Run diffusion for different step counts\n",
56+
"steps_list = [0, 10, 50, 200]\n",
57+
"results = {}\n",
58+
"for s in steps_list:\n",
59+
" if s == 0:\n",
60+
" results[s] = initial\n",
61+
" else:\n",
62+
" results[s] = diffuse(initial, diffusivity=1.0, steps=s, boundary='nearest')\n",
63+
"\n",
64+
"fig, axes = plt.subplots(1, 4, figsize=(16, 4))\n",
65+
"for ax, s in zip(axes, steps_list):\n",
66+
" im = ax.imshow(results[s].values, cmap='hot', vmin=0, vmax=10)\n",
67+
" ax.set_title(f'Step {s}')\n",
68+
" ax.axis('off')\n",
69+
"fig.colorbar(im, ax=axes, shrink=0.6, label='Temperature')\n",
70+
"plt.suptitle('Point source diffusion', y=1.02)\n",
71+
"plt.tight_layout()\n",
72+
"plt.show()"
73+
]
74+
},
75+
{
76+
"cell_type": "markdown",
77+
"metadata": {},
78+
"source": [
79+
"## 2. Boundary modes\n",
80+
"\n",
81+
"The `boundary` parameter controls what happens at the edges: `'nan'`, `'nearest'`, `'reflect'`, or `'wrap'`."
82+
]
83+
},
84+
{
85+
"cell_type": "code",
86+
"execution_count": null,
87+
"metadata": {},
88+
"outputs": [],
89+
"source": [
90+
"# Hot spot near the edge to highlight boundary behavior\n",
91+
"data_edge = np.zeros((31, 31))\n",
92+
"data_edge[2, 15] = 100.0\n",
93+
"\n",
94+
"edge_agg = xr.DataArray(\n",
95+
" data_edge, dims=['y', 'x'], attrs={'res': (1.0, 1.0)}\n",
96+
")\n",
97+
"\n",
98+
"boundaries = ['nan', 'nearest', 'reflect', 'wrap']\n",
99+
"fig, axes = plt.subplots(1, 4, figsize=(16, 4))\n",
100+
"for ax, bnd in zip(axes, boundaries):\n",
101+
" r = diffuse(edge_agg, diffusivity=1.0, steps=30, boundary=bnd)\n",
102+
" ax.imshow(r.values, cmap='hot', vmin=0, vmax=5)\n",
103+
" ax.set_title(f'boundary={bnd!r}')\n",
104+
" ax.axis('off')\n",
105+
"plt.suptitle('Edge behavior with different boundary modes', y=1.02)\n",
106+
"plt.tight_layout()\n",
107+
"plt.show()"
108+
]
109+
},
110+
{
111+
"cell_type": "markdown",
112+
"metadata": {},
113+
"source": [
114+
"## 3. Spatially varying diffusivity\n",
115+
"\n",
116+
"You can pass a DataArray for `diffusivity` to model materials with different thermal properties. Here we create a wall (low diffusivity) that partially blocks heat flow."
117+
]
118+
},
119+
{
120+
"cell_type": "code",
121+
"execution_count": null,
122+
"metadata": {},
123+
"outputs": [],
124+
"source": [
125+
"shape = (51, 51)\n",
126+
"data = np.zeros(shape)\n",
127+
"data[25, 10] = 100.0 # heat source on the left\n",
128+
"\n",
129+
"# Diffusivity field: mostly 1.0, with a vertical wall of low diffusivity\n",
130+
"alpha = np.ones(shape)\n",
131+
"alpha[:, 25] = 0.01 # thin wall in the middle\n",
132+
"\n",
133+
"field = xr.DataArray(data, dims=['y', 'x'], attrs={'res': (1.0, 1.0)})\n",
134+
"alpha_da = xr.DataArray(alpha, dims=['y', 'x'], attrs={'res': (1.0, 1.0)})\n",
135+
"\n",
136+
"result = diffuse(field, diffusivity=alpha_da, steps=300, boundary='nearest')\n",
137+
"\n",
138+
"fig, axes = plt.subplots(1, 2, figsize=(12, 5))\n",
139+
"axes[0].imshow(alpha, cmap='gray')\n",
140+
"axes[0].set_title('Diffusivity (dark = wall)')\n",
141+
"axes[0].axis('off')\n",
142+
"\n",
143+
"axes[1].imshow(result.values, cmap='hot', vmin=0)\n",
144+
"axes[1].set_title('Temperature after 300 steps')\n",
145+
"axes[1].axis('off')\n",
146+
"\n",
147+
"plt.tight_layout()\n",
148+
"plt.show()"
149+
]
150+
},
151+
{
152+
"cell_type": "markdown",
153+
"metadata": {},
154+
"source": [
155+
"## 4. Practical example: HVAC failure\n",
156+
"\n",
157+
"Simulate what happens when a cooling unit fails in a building floor plan. We start with a comfortable 22 C everywhere, set the failed zone to 35 C, and let heat diffuse."
158+
]
159+
},
160+
{
161+
"cell_type": "code",
162+
"execution_count": null,
163+
"metadata": {},
164+
"outputs": [],
165+
"source": [
166+
"shape = (61, 61)\n",
167+
"temp = np.full(shape, 22.0) # comfortable baseline\n",
168+
"\n",
169+
"# Failed zone: a 5x5 block near center heats up\n",
170+
"temp[28:33, 28:33] = 35.0\n",
171+
"\n",
172+
"# Walls (NaN = impassable)\n",
173+
"temp[20, 10:50] = np.nan # horizontal wall\n",
174+
"temp[40, 10:50] = np.nan # horizontal wall\n",
175+
"temp[20:41, 10] = np.nan # left wall\n",
176+
"temp[20:41, 50] = np.nan # right wall\n",
177+
"# Door opening\n",
178+
"temp[20, 29:32] = 22.0\n",
179+
"\n",
180+
"field = xr.DataArray(temp, dims=['y', 'x'], attrs={'res': (1.0, 1.0)})\n",
181+
"\n",
182+
"snapshots = [0, 50, 150, 400]\n",
183+
"fig, axes = plt.subplots(1, 4, figsize=(18, 4))\n",
184+
"for ax, s in zip(axes, snapshots):\n",
185+
" if s == 0:\n",
186+
" r = field\n",
187+
" else:\n",
188+
" r = diffuse(field, diffusivity=0.5, steps=s, boundary='nearest')\n",
189+
" im = ax.imshow(r.values, cmap='coolwarm', vmin=20, vmax=36)\n",
190+
" ax.set_title(f'Step {s}')\n",
191+
" ax.axis('off')\n",
192+
"fig.colorbar(im, ax=axes, shrink=0.6, label='Temperature (C)')\n",
193+
"plt.suptitle('Heat spread after HVAC failure', y=1.02)\n",
194+
"plt.tight_layout()\n",
195+
"plt.show()"
196+
]
197+
}
198+
],
199+
"metadata": {
200+
"kernelspec": {
201+
"display_name": "Python 3",
202+
"language": "python",
203+
"name": "python3"
204+
},
205+
"language_info": {
206+
"name": "python",
207+
"version": "3.11.0"
208+
}
209+
},
210+
"nbformat": 4,
211+
"nbformat_minor": 4
212+
}

xrspatial/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from xrspatial.classify import percentiles # noqa
1212
from xrspatial.classify import std_mean # noqa
1313
from xrspatial.diagnostics import diagnose # noqa
14+
from xrspatial.diffusion import diffuse # noqa
1415
from xrspatial.classify import equal_interval # noqa
1516
from xrspatial.classify import natural_breaks # noqa
1617
from xrspatial.classify import quantile # noqa

xrspatial/accessor.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,12 @@ def regions(self, **kwargs):
235235
from .zonal import regions
236236
return regions(self._obj, **kwargs)
237237

238+
# ---- Diffusion ----
239+
240+
def diffuse(self, **kwargs):
241+
from .diffusion import diffuse
242+
return diffuse(self._obj, **kwargs)
243+
238244
# ---- Dasymetric ----
239245

240246
def disaggregate(self, values, weight, **kwargs):
@@ -495,6 +501,12 @@ def focal_mean(self, **kwargs):
495501
from .focal import mean
496502
return mean(self._obj, **kwargs)
497503

504+
# ---- Diffusion ----
505+
506+
def diffuse(self, **kwargs):
507+
from .diffusion import diffuse
508+
return diffuse(self._obj, **kwargs)
509+
498510
# ---- Proximity ----
499511

500512
def proximity(self, **kwargs):

0 commit comments

Comments
 (0)