Skip to content

Commit f4c9fcc

Browse files
authored
Add CuPy and Dask+CuPy backends for kriging (#951) (#960)
* Add CuPy and Dask+CuPy backends for kriging (#951) Replace the GPU stubs with working implementations. Variogram fitting and matrix inversion stay on CPU (small matrices). The per-pixel prediction runs on GPU via vectorized CuPy operations, matching the existing numpy algorithm. Variogram model functions now auto-detect the array module so they work with both numpy and cupy inputs. * Add GPU test coverage for kriging (#951) Tests for CuPy and Dask+CuPy kriging backends covering prediction accuracy, all three variogram models, and variance output. Adds _to_numpy() helper and cupy/dask_cupy backends to _make_template(). * Update README feature matrix for kriging GPU support (#951) * Add kriging user guide notebook (#951) Covers basic interpolation, variance output, variogram model comparison, and a soil pH mapping example. * Rename kriging notebook to 19_Kriging.ipynb (#951)
1 parent 64e59b2 commit f4c9fcc

File tree

4 files changed

+423
-12
lines changed

4 files changed

+423
-12
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ In the GIS world, rasters are used for representing continuous phenomena (e.g. e
315315
| Name | Description | NumPy xr.DataArray | Dask xr.DataArray | CuPy GPU xr.DataArray | Dask GPU xr.DataArray |
316316
|:----------:|:------------|:----------------------:|:--------------------:|:-------------------:|:------:|
317317
| [IDW](xrspatial/interpolate/_idw.py) | Inverse Distance Weighting from scattered points to a raster grid | ✅️ | ✅️ | ✅️ | ✅️ |
318-
| [Kriging](xrspatial/interpolate/_kriging.py) | Ordinary Kriging with automatic variogram fitting (spherical, exponential, gaussian) | ✅️ | ✅️ | | |
318+
| [Kriging](xrspatial/interpolate/_kriging.py) | Ordinary Kriging with automatic variogram fitting (spherical, exponential, gaussian) | ✅️ | ✅️ | ✅️ | ✅️ |
319319
| [Spline](xrspatial/interpolate/_spline.py) | Thin Plate Spline interpolation with optional smoothing | ✅️ | ✅️ | ✅️ | ✅️ |
320320

321321
-----------
Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "cell-0",
6+
"metadata": {},
7+
"source": [
8+
"# Kriging Interpolation\n",
9+
"\n",
10+
"The `kriging()` function performs Ordinary Kriging, a geostatistical interpolation method that produces optimal, unbiased predictions from scattered point observations. Unlike IDW, kriging accounts for the spatial correlation structure of the data through a variogram model.\n",
11+
"\n",
12+
"Key features:\n",
13+
"- Automatic experimental variogram computation and model fitting\n",
14+
"- Three variogram models: spherical, exponential, gaussian\n",
15+
"- Optional kriging variance (prediction uncertainty) output\n",
16+
"- All four backends: NumPy, Dask, CuPy, Dask+CuPy"
17+
]
18+
},
19+
{
20+
"cell_type": "code",
21+
"execution_count": null,
22+
"id": "cell-1",
23+
"metadata": {},
24+
"outputs": [],
25+
"source": [
26+
"import numpy as np\n",
27+
"import xarray as xr\n",
28+
"import matplotlib.pyplot as plt\n",
29+
"\n",
30+
"from xrspatial.interpolate import kriging"
31+
]
32+
},
33+
{
34+
"cell_type": "markdown",
35+
"id": "cell-2",
36+
"metadata": {},
37+
"source": [
38+
"## 1. Basic interpolation from point observations\n",
39+
"\n",
40+
"Generate scattered sample points from a known surface and use kriging to reconstruct the full field."
41+
]
42+
},
43+
{
44+
"cell_type": "code",
45+
"execution_count": null,
46+
"id": "cell-3",
47+
"metadata": {},
48+
"outputs": [],
49+
"source": [
50+
"# True surface: z = sin(x) * cos(y)\n",
51+
"rng = np.random.RandomState(42)\n",
52+
"n_pts = 40\n",
53+
"x_pts = rng.uniform(0, 6, n_pts)\n",
54+
"y_pts = rng.uniform(0, 6, n_pts)\n",
55+
"z_pts = np.sin(x_pts) * np.cos(y_pts) + rng.normal(0, 0.05, n_pts)\n",
56+
"\n",
57+
"# Output grid\n",
58+
"x_grid = np.linspace(0, 6, 60)\n",
59+
"y_grid = np.linspace(0, 6, 60)\n",
60+
"template = xr.DataArray(\n",
61+
" np.zeros((len(y_grid), len(x_grid))),\n",
62+
" dims=['y', 'x'],\n",
63+
" coords={'y': y_grid, 'x': x_grid},\n",
64+
")\n",
65+
"\n",
66+
"result = kriging(x_pts, y_pts, z_pts, template)\n",
67+
"\n",
68+
"# True surface for comparison\n",
69+
"gx, gy = np.meshgrid(x_grid, y_grid)\n",
70+
"true_surface = np.sin(gx) * np.cos(gy)\n",
71+
"\n",
72+
"fig, axes = plt.subplots(1, 2, figsize=(12, 5))\n",
73+
"im0 = axes[0].imshow(true_surface, extent=[0, 6, 6, 0], cmap='viridis')\n",
74+
"axes[0].scatter(x_pts, y_pts, c='red', s=15, edgecolors='k', linewidth=0.5)\n",
75+
"axes[0].set_title('True surface + sample points')\n",
76+
"fig.colorbar(im0, ax=axes[0], shrink=0.7)\n",
77+
"\n",
78+
"im1 = axes[1].imshow(result.values, extent=[0, 6, 6, 0], cmap='viridis')\n",
79+
"axes[1].set_title('Kriging prediction')\n",
80+
"fig.colorbar(im1, ax=axes[1], shrink=0.7)\n",
81+
"\n",
82+
"plt.tight_layout()\n",
83+
"plt.show()"
84+
]
85+
},
86+
{
87+
"cell_type": "markdown",
88+
"id": "cell-4",
89+
"metadata": {},
90+
"source": [
91+
"## 2. Kriging variance\n",
92+
"\n",
93+
"Set `return_variance=True` to get prediction uncertainty. Variance is low near observed points and higher in data-sparse regions."
94+
]
95+
},
96+
{
97+
"cell_type": "code",
98+
"execution_count": null,
99+
"id": "cell-5",
100+
"metadata": {},
101+
"outputs": [],
102+
"source": [
103+
"pred, var = kriging(x_pts, y_pts, z_pts, template, return_variance=True)\n",
104+
"\n",
105+
"fig, axes = plt.subplots(1, 2, figsize=(12, 5))\n",
106+
"im0 = axes[0].imshow(pred.values, extent=[0, 6, 6, 0], cmap='viridis')\n",
107+
"axes[0].scatter(x_pts, y_pts, c='red', s=15, edgecolors='k', linewidth=0.5)\n",
108+
"axes[0].set_title('Prediction')\n",
109+
"fig.colorbar(im0, ax=axes[0], shrink=0.7)\n",
110+
"\n",
111+
"im1 = axes[1].imshow(var.values, extent=[0, 6, 6, 0], cmap='magma')\n",
112+
"axes[1].scatter(x_pts, y_pts, c='cyan', s=15, edgecolors='k', linewidth=0.5)\n",
113+
"axes[1].set_title('Kriging variance')\n",
114+
"fig.colorbar(im1, ax=axes[1], shrink=0.7)\n",
115+
"\n",
116+
"plt.tight_layout()\n",
117+
"plt.show()"
118+
]
119+
},
120+
{
121+
"cell_type": "markdown",
122+
"id": "cell-6",
123+
"metadata": {},
124+
"source": [
125+
"## 3. Variogram model comparison\n",
126+
"\n",
127+
"The `variogram_model` parameter controls the spatial correlation model. Different models produce subtly different predictions."
128+
]
129+
},
130+
{
131+
"cell_type": "code",
132+
"execution_count": null,
133+
"id": "cell-7",
134+
"metadata": {},
135+
"outputs": [],
136+
"source": [
137+
"models = ['spherical', 'exponential', 'gaussian']\n",
138+
"fig, axes = plt.subplots(1, 3, figsize=(15, 4))\n",
139+
"\n",
140+
"for ax, model in zip(axes, models):\n",
141+
" r = kriging(x_pts, y_pts, z_pts, template, variogram_model=model)\n",
142+
" im = ax.imshow(r.values, extent=[0, 6, 6, 0], cmap='viridis')\n",
143+
" ax.set_title(f'{model}')\n",
144+
" fig.colorbar(im, ax=ax, shrink=0.7)\n",
145+
"\n",
146+
"plt.suptitle('Variogram model comparison', y=1.02)\n",
147+
"plt.tight_layout()\n",
148+
"plt.show()"
149+
]
150+
},
151+
{
152+
"cell_type": "markdown",
153+
"id": "cell-8",
154+
"metadata": {},
155+
"source": [
156+
"## 4. Practical example: soil property mapping\n",
157+
"\n",
158+
"Simulate soil pH measurements at random field locations and produce a continuous map with uncertainty."
159+
]
160+
},
161+
{
162+
"cell_type": "code",
163+
"execution_count": null,
164+
"id": "cell-9",
165+
"metadata": {},
166+
"outputs": [],
167+
"source": [
168+
"# Simulated soil pH: smooth trend + spatially correlated noise\n",
169+
"rng = np.random.RandomState(7)\n",
170+
"n_samples = 50\n",
171+
"x_soil = rng.uniform(0, 100, n_samples) # meters\n",
172+
"y_soil = rng.uniform(0, 100, n_samples)\n",
173+
"\n",
174+
"# Trend: pH increases toward the northeast\n",
175+
"ph = 5.5 + 0.015 * x_soil + 0.010 * y_soil + rng.normal(0, 0.3, n_samples)\n",
176+
"\n",
177+
"# Dense prediction grid\n",
178+
"xg = np.linspace(0, 100, 80)\n",
179+
"yg = np.linspace(0, 100, 80)\n",
180+
"template_soil = xr.DataArray(\n",
181+
" np.zeros((len(yg), len(xg))),\n",
182+
" dims=['y', 'x'],\n",
183+
" coords={'y': yg, 'x': xg},\n",
184+
")\n",
185+
"\n",
186+
"ph_pred, ph_var = kriging(\n",
187+
" x_soil, y_soil, ph, template_soil,\n",
188+
" variogram_model='spherical', return_variance=True,\n",
189+
")\n",
190+
"\n",
191+
"fig, axes = plt.subplots(1, 2, figsize=(13, 5))\n",
192+
"\n",
193+
"im0 = axes[0].imshow(\n",
194+
" ph_pred.values, extent=[0, 100, 100, 0],\n",
195+
" cmap='RdYlGn', vmin=5, vmax=8,\n",
196+
")\n",
197+
"axes[0].scatter(x_soil, y_soil, c=ph, cmap='RdYlGn', vmin=5, vmax=8,\n",
198+
" s=30, edgecolors='k', linewidth=0.5)\n",
199+
"axes[0].set_title('Predicted soil pH')\n",
200+
"axes[0].set_xlabel('East (m)')\n",
201+
"axes[0].set_ylabel('North (m)')\n",
202+
"fig.colorbar(im0, ax=axes[0], shrink=0.7, label='pH')\n",
203+
"\n",
204+
"im1 = axes[1].imshow(\n",
205+
" ph_var.values, extent=[0, 100, 100, 0], cmap='magma',\n",
206+
")\n",
207+
"axes[1].scatter(x_soil, y_soil, c='cyan', s=15, edgecolors='k', linewidth=0.5)\n",
208+
"axes[1].set_title('Prediction variance')\n",
209+
"axes[1].set_xlabel('East (m)')\n",
210+
"fig.colorbar(im1, ax=axes[1], shrink=0.7, label='Variance')\n",
211+
"\n",
212+
"plt.tight_layout()\n",
213+
"plt.show()"
214+
]
215+
}
216+
],
217+
"metadata": {
218+
"kernelspec": {
219+
"display_name": "Python 3",
220+
"language": "python",
221+
"name": "python3"
222+
},
223+
"language_info": {
224+
"name": "python",
225+
"version": "3.10.0"
226+
}
227+
},
228+
"nbformat": 4,
229+
"nbformat_minor": 5
230+
}

xrspatial/interpolate/_kriging.py

Lines changed: 102 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,24 +26,34 @@
2626
da = None
2727

2828

29+
def _get_xp(arr):
30+
"""Return the array module (numpy or cupy) for *arr*."""
31+
if cupy is not None and isinstance(arr, cupy.ndarray):
32+
return cupy
33+
return np
34+
35+
2936
# ---------------------------------------------------------------------------
3037
# Variogram models
3138
# ---------------------------------------------------------------------------
3239

3340
def _spherical(h, c0, c, a):
34-
return np.where(
41+
xp = _get_xp(h)
42+
return xp.where(
3543
h < a,
3644
c0 + c * (1.5 * h / a - 0.5 * (h / a) ** 3),
3745
c0 + c,
3846
)
3947

4048

4149
def _exponential(h, c0, c, a):
42-
return c0 + c * (1.0 - np.exp(-3.0 * h / a))
50+
xp = _get_xp(h)
51+
return c0 + c * (1.0 - xp.exp(-3.0 * h / a))
4352

4453

4554
def _gaussian(h, c0, c, a):
46-
return c0 + c * (1.0 - np.exp(-3.0 * (h / a) ** 2))
55+
xp = _get_xp(h)
56+
return c0 + c * (1.0 - xp.exp(-3.0 * (h / a) ** 2))
4757

4858

4959
_VARIOGRAM_MODELS = {
@@ -230,15 +240,97 @@ def _chunk_var(block, block_info=None):
230240

231241

232242
# ---------------------------------------------------------------------------
233-
# GPU stubs
243+
# CuPy prediction
234244
# ---------------------------------------------------------------------------
235245

236-
def _kriging_gpu_not_impl(*args, **kwargs):
237-
raise NotImplementedError(
238-
"kriging(): GPU (CuPy) backend is not supported. "
239-
"Use numpy or dask+numpy backend."
246+
def _kriging_predict_cupy(x_pts, y_pts, z_pts, x_grid, y_grid,
247+
vario_func, K_inv, return_variance):
248+
"""Vectorised kriging prediction on GPU via CuPy."""
249+
n = len(x_pts)
250+
251+
x_gpu = cupy.asarray(x_pts)
252+
y_gpu = cupy.asarray(y_pts)
253+
z_gpu = cupy.asarray(z_pts)
254+
xg_gpu = cupy.asarray(x_grid)
255+
yg_gpu = cupy.asarray(y_grid)
256+
K_inv_gpu = cupy.asarray(K_inv)
257+
258+
gx, gy = cupy.meshgrid(xg_gpu, yg_gpu)
259+
gx_flat = gx.ravel()
260+
gy_flat = gy.ravel()
261+
262+
dx = gx_flat[:, None] - x_gpu[None, :]
263+
dy = gy_flat[:, None] - y_gpu[None, :]
264+
dists = cupy.sqrt(dx ** 2 + dy ** 2)
265+
266+
k0 = cupy.empty((len(gx_flat), n + 1), dtype=np.float64)
267+
k0[:, :n] = vario_func(dists)
268+
k0[:, n] = 1.0
269+
270+
w = k0 @ K_inv_gpu
271+
272+
prediction = (w[:, :n] * z_gpu[None, :]).sum(axis=1)
273+
prediction = prediction.reshape(len(y_grid), len(x_grid))
274+
275+
variance = None
276+
if return_variance:
277+
variance = cupy.sum(w * k0, axis=1)
278+
variance = variance.reshape(len(y_grid), len(x_grid))
279+
280+
return prediction, variance
281+
282+
283+
# ---------------------------------------------------------------------------
284+
# CuPy backend wrapper
285+
# ---------------------------------------------------------------------------
286+
287+
def _kriging_cupy(x_pts, y_pts, z_pts, x_grid, y_grid,
288+
vario_func, K_inv, return_variance, template_data):
289+
return _kriging_predict_cupy(x_pts, y_pts, z_pts, x_grid, y_grid,
290+
vario_func, K_inv, return_variance)
291+
292+
293+
# ---------------------------------------------------------------------------
294+
# Dask + CuPy backend
295+
# ---------------------------------------------------------------------------
296+
297+
def _kriging_dask_cupy(x_pts, y_pts, z_pts, x_grid, y_grid,
298+
vario_func, K_inv, return_variance, template_data):
299+
300+
def _chunk_pred(block, block_info=None):
301+
if block_info is None:
302+
return block
303+
loc = block_info[0]['array-location']
304+
y_sl = y_grid[loc[0][0]:loc[0][1]]
305+
x_sl = x_grid[loc[1][0]:loc[1][1]]
306+
pred, _ = _kriging_predict_cupy(x_pts, y_pts, z_pts, x_sl, y_sl,
307+
vario_func, K_inv, False)
308+
return pred
309+
310+
prediction = da.map_blocks(
311+
_chunk_pred, template_data, dtype=np.float64,
312+
meta=cupy.array((), dtype=np.float64),
240313
)
241314

315+
variance = None
316+
if return_variance:
317+
def _chunk_var(block, block_info=None):
318+
if block_info is None:
319+
return block
320+
loc = block_info[0]['array-location']
321+
y_sl = y_grid[loc[0][0]:loc[0][1]]
322+
x_sl = x_grid[loc[1][0]:loc[1][1]]
323+
_, var = _kriging_predict_cupy(x_pts, y_pts, z_pts, x_sl, y_sl,
324+
vario_func, K_inv, True)
325+
return var
326+
327+
variance = da.map_blocks(
328+
_chunk_var, template_data, dtype=np.float64,
329+
meta=cupy.array((), dtype=np.float64),
330+
)
331+
332+
return prediction, variance
333+
242334

243335
# ---------------------------------------------------------------------------
244336
# Public API
@@ -320,9 +412,9 @@ def vario_func(h):
320412

321413
mapper = ArrayTypeFunctionMapping(
322414
numpy_func=_kriging_numpy,
323-
cupy_func=_kriging_gpu_not_impl,
415+
cupy_func=_kriging_cupy,
324416
dask_func=_kriging_dask_numpy,
325-
dask_cupy_func=_kriging_gpu_not_impl,
417+
dask_cupy_func=_kriging_dask_cupy,
326418
)
327419

328420
pred_arr, var_arr = mapper(template)(

0 commit comments

Comments
 (0)