Skip to content

Commit 25309d8

Browse files
authored
Add interpolation tools: IDW, Kriging, and Spline (#932) (#934)
Three new functions for converting scattered point observations (x, y, z arrays) into gridded xarray DataArrays on a user-specified template grid. - IDW: all-points numba JIT kernel and k-nearest via scipy cKDTree, with CUDA kernel for GPU; all 4 backends - Kriging: ordinary kriging with automatic variogram fitting (spherical, exponential, gaussian models); numpy and dask backends - Spline: thin plate spline with CPU system solve and parallelised grid evaluation; all 4 backends Also adds .xrs accessor methods and 25 tests covering correctness, cross-backend consistency, validation, and edge cases.
1 parent 19f9539 commit 25309d8

File tree

9 files changed

+1315
-0
lines changed

9 files changed

+1315
-0
lines changed

README.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,16 @@ In the GIS world, rasters are used for representing continuous phenomena (e.g. e
282282

283283
-----------
284284

285+
### **Interpolation**
286+
287+
| Name | Description | NumPy xr.DataArray | Dask xr.DataArray | CuPy GPU xr.DataArray | Dask GPU xr.DataArray |
288+
|:----------:|:------------|:----------------------:|:--------------------:|:-------------------:|:------:|
289+
| [IDW](xrspatial/interpolate/_idw.py) | Inverse Distance Weighting from scattered points to a raster grid | ✅️ | ✅️ | ✅️ | ✅️ |
290+
| [Kriging](xrspatial/interpolate/_kriging.py) | Ordinary Kriging with automatic variogram fitting (spherical, exponential, gaussian) | ✅️ | ✅️ | | |
291+
| [Spline](xrspatial/interpolate/_spline.py) | Thin Plate Spline interpolation with optional smoothing | ✅️ | ✅️ | ✅️ | ✅️ |
292+
293+
-----------
294+
285295
### **Zonal**
286296

287297
| Name | Description | NumPy xr.DataArray | Dask xr.DataArray | CuPy GPU xr.DataArray | Dask GPU xr.DataArray |

xrspatial/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
from xrspatial.emerging_hotspots import emerging_hotspots # noqa
1717
from xrspatial.erosion import erode # noqa
1818
from xrspatial.fill import fill # noqa
19+
from xrspatial.interpolate import idw # noqa
20+
from xrspatial.interpolate import kriging # noqa
21+
from xrspatial.interpolate import spline # noqa
1922
from xrspatial.fire import burn_severity_class # noqa
2023
from xrspatial.fire import dnbr # noqa
2124
from xrspatial.fire import fireline_intensity # noqa

xrspatial/accessor.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,20 @@ def mahalanobis(self, other_bands, **kwargs):
251251
from .mahalanobis import mahalanobis
252252
return mahalanobis([self._obj] + list(other_bands), **kwargs)
253253

254+
# ---- Interpolation ----
255+
256+
def idw(self, x, y, z, **kwargs):
    """Run IDW interpolation using this accessor's object as the template.

    ``x``, ``y`` and ``z`` are forwarded unchanged, ``self._obj`` is
    passed as the ``template`` argument, and any extra keyword
    arguments go straight to ``interpolate.idw``.
    """
    # Imported lazily, inside the method, exactly like the original.
    from .interpolate import idw as _interpolate_idw
    template = self._obj
    return _interpolate_idw(x, y, z, template, **kwargs)
259+
260+
def kriging(self, x, y, z, **kwargs):
    """Run kriging using this accessor's object as the template.

    ``x``, ``y`` and ``z`` are forwarded unchanged, ``self._obj`` is
    passed as the ``template`` argument, and any extra keyword
    arguments go straight to ``interpolate.kriging``.
    """
    # Imported lazily, inside the method, exactly like the original.
    from .interpolate import kriging as _interpolate_kriging
    template = self._obj
    return _interpolate_kriging(x, y, z, template, **kwargs)
263+
264+
def spline(self, x, y, z, **kwargs):
    """Run spline interpolation using this accessor's object as template.

    ``x``, ``y`` and ``z`` are forwarded unchanged, ``self._obj`` is
    passed as the ``template`` argument, and any extra keyword
    arguments go straight to ``interpolate.spline``.
    """
    # Imported lazily, inside the method, exactly like the original.
    from .interpolate import spline as _interpolate_spline
    template = self._obj
    return _interpolate_spline(x, y, z, template, **kwargs)
267+
254268
# ---- Raster to vector ----
255269

256270
def polygonize(self, **kwargs):

xrspatial/interpolate/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Interpolation tools for scattered-point-to-raster conversion."""
2+
3+
from xrspatial.interpolate._idw import idw # noqa: F401
4+
from xrspatial.interpolate._kriging import kriging # noqa: F401
5+
from xrspatial.interpolate._spline import spline # noqa: F401

xrspatial/interpolate/_idw.py

Lines changed: 289 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,289 @@
1+
"""Inverse Distance Weighting (IDW) interpolation."""
2+
3+
from __future__ import annotations
4+
5+
import math
6+
7+
import numpy as np
8+
import xarray as xr
9+
from numba import cuda
10+
11+
from xrspatial.utils import (
12+
ArrayTypeFunctionMapping,
13+
_validate_raster,
14+
_validate_scalar,
15+
cuda_args,
16+
ngjit,
17+
)
18+
19+
from ._validation import extract_grid_coords, validate_points
20+
21+
try:
22+
import cupy
23+
except ImportError:
24+
cupy = None
25+
26+
try:
27+
import dask.array as da
28+
except ImportError:
29+
da = None
30+
31+
32+
# ---------------------------------------------------------------------------
33+
# CPU all-points kernel (numba JIT)
34+
# ---------------------------------------------------------------------------
35+
36+
@ngjit
def _idw_cpu_allpoints(x_pts, y_pts, z_pts, n_pts,
                       x_grid, y_grid, power, fill_value):
    """All-points IDW evaluated at every node of a rectilinear grid.

    Parameters
    ----------
    x_pts, y_pts, z_pts : array
        Sample coordinates and values; only the first ``n_pts`` entries
        are read.
    n_pts : int
        Number of samples to use.
    x_grid, y_grid : array
        1-D coordinates of the grid columns and rows.
    power : float
        Exponent applied to the distance when computing weights.
    fill_value : float
        Written where the accumulated weight is not positive.

    Returns
    -------
    array of float64, shape ``(y_grid.shape[0], x_grid.shape[0])``
    """
    n_rows = y_grid.shape[0]
    n_cols = x_grid.shape[0]
    result = np.empty((n_rows, n_cols), dtype=np.float64)

    for row in range(n_rows):
        node_y = y_grid[row]
        for col in range(n_cols):
            node_x = x_grid[col]
            weight_total = 0.0
            weighted_z = 0.0
            # Index of the first sample lying exactly on this node
            # (-1 while none has been found).
            hit = -1

            for idx in range(n_pts):
                off_x = node_x - x_pts[idx]
                off_y = node_y - y_pts[idx]
                sq_dist = off_x * off_x + off_y * off_y
                if sq_dist == 0.0:
                    # A coincident sample wins outright; stop scanning.
                    hit = idx
                    break
                # Weight is 1 / d**power, written in terms of d**2.
                weight = 1.0 / (sq_dist ** (power * 0.5))
                weight_total += weight
                weighted_z += weight * z_pts[idx]

            if hit >= 0:
                result[row, col] = z_pts[hit]
            elif weight_total > 0.0:
                result[row, col] = weighted_z / weight_total
            else:
                result[row, col] = fill_value

    return result
72+
73+
74+
# ---------------------------------------------------------------------------
75+
# CPU k-nearest (scipy cKDTree)
76+
# ---------------------------------------------------------------------------
77+
78+
def _idw_knearest_numpy(x_pts, y_pts, z_pts, x_grid, y_grid,
                        power, k, fill_value):
    """k-nearest-neighbour IDW on the CPU via ``scipy.spatial.cKDTree``.

    Parameters
    ----------
    x_pts, y_pts, z_pts : np.ndarray
        1-D coordinates and values of the input samples.
    x_grid, y_grid : np.ndarray
        1-D coordinates of the grid columns and rows.
    power : float
        Exponent applied to the distance when computing weights.
    k : int
        Number of nearest samples used per grid node. Values larger
        than the number of samples are reduced to that number.
    fill_value : float
        Value written where the total weight is not positive.

    Returns
    -------
    np.ndarray
        float64 array of shape ``(len(y_grid), len(x_grid))``.
    """
    from scipy.spatial import cKDTree

    # Asking the tree for more neighbours than it holds yields padding
    # entries whose index equals the sample count, which would be out of
    # range for ``z_pts``. Cap k instead.
    n_pts = len(x_pts)
    if 0 < n_pts < k:
        k = n_pts

    tree = cKDTree(np.column_stack([x_pts, y_pts]))

    gx, gy = np.meshgrid(x_grid, y_grid)
    query_pts = np.column_stack([gx.ravel(), gy.ravel()])
    dists, indices = tree.query(query_pts, k=k)

    # scipy drops the neighbour axis when k == 1; restore it so the
    # rest of the function can always reduce over axis 1.
    if k == 1:
        dists = dists[:, np.newaxis]
        indices = indices[:, np.newaxis]

    exact = dists == 0.0
    # Zero distances are replaced by 1.0 before the division so they
    # never divide by zero; their weights are overwritten below.
    with np.errstate(divide='ignore', over='ignore'):
        weights = 1.0 / (np.where(exact, 1.0, dists) ** power)

    # A tiny but non-zero distance can overflow to an infinite weight,
    # which would turn the weighted mean into inf / inf = NaN. Such a
    # sample is effectively coincident with the node, so treat it as an
    # exact hit.
    exact |= np.isinf(weights)

    # Rows with an exact hit use only the coincident sample(s).
    has_exact = np.any(exact, axis=1)
    weights[has_exact] = np.where(exact[has_exact], 1.0, 0.0)

    z_vals = z_pts[indices]
    wz = np.sum(weights * z_vals, axis=1)
    w_total = np.sum(weights, axis=1)

    # Divide only where the total weight is positive; everything else
    # keeps ``fill_value`` (avoids a divide-by-zero warning).
    usable = w_total > 0
    result = np.full(w_total.shape, fill_value, dtype=np.float64)
    result[usable] = wz[usable] / w_total[usable]
    return result.reshape(len(y_grid), len(x_grid))
105+
106+
107+
# ---------------------------------------------------------------------------
108+
# Numpy backend
109+
# ---------------------------------------------------------------------------
110+
111+
def _idw_numpy(x_pts, y_pts, z_pts, x_grid, y_grid,
               power, k, fill_value, template_data):
    """NumPy backend: pick the all-points or the k-nearest implementation.

    ``k is None`` selects the all-points kernel; any other value selects
    the cKDTree-based k-nearest routine. ``template_data`` is accepted so
    every backend shares one signature and is not used here.
    """
    if k is None:
        n_samples = len(x_pts)
        return _idw_cpu_allpoints(x_pts, y_pts, z_pts, n_samples,
                                  x_grid, y_grid, power, fill_value)

    return _idw_knearest_numpy(x_pts, y_pts, z_pts, x_grid, y_grid,
                               power, k, fill_value)
118+
119+
120+
# ---------------------------------------------------------------------------
121+
# CUDA kernel (all-points only)
122+
# ---------------------------------------------------------------------------
123+
124+
@cuda.jit
def _idw_cuda_kernel(x_pts, y_pts, z_pts, n_pts,
                     x_grid, y_grid, power, fill_value, out):
    """CUDA kernel computing all-points IDW for one output cell per thread.

    Each thread takes its 2-D grid position ``(row, col)`` and, when that
    position lies inside ``out``, writes the interpolated value for grid
    node ``(y_grid[row], x_grid[col])`` into ``out[row, col]``. Only the
    first ``n_pts`` samples are read.
    """
    row, col = cuda.grid(2)
    # Threads outside the output array have nothing to do.
    if row >= out.shape[0] or col >= out.shape[1]:
        return

    node_x = x_grid[col]
    node_y = y_grid[row]
    weight_total = 0.0
    weighted_z = 0.0
    # Index of the first sample lying exactly on this node
    # (-1 while none has been found).
    hit = -1

    for idx in range(n_pts):
        off_x = node_x - x_pts[idx]
        off_y = node_y - y_pts[idx]
        sq_dist = off_x * off_x + off_y * off_y
        if sq_dist == 0.0:
            # A coincident sample wins outright; stop scanning.
            hit = idx
            break
        # Weight is 1 / d**power, written in terms of d**2.
        weight = 1.0 / (sq_dist ** (power * 0.5))
        weight_total += weight
        weighted_z += weight * z_pts[idx]

    if hit >= 0:
        out[row, col] = z_pts[hit]
    elif weight_total > 0.0:
        out[row, col] = weighted_z / weight_total
    else:
        out[row, col] = fill_value
154+
155+
156+
# ---------------------------------------------------------------------------
157+
# CuPy backend
158+
# ---------------------------------------------------------------------------
159+
160+
def _idw_cupy(x_pts, y_pts, z_pts, x_grid, y_grid,
              power, k, fill_value, template_data):
    """CuPy backend: all-points IDW computed by the CUDA kernel.

    Raises
    ------
    NotImplementedError
        If ``k`` is not ``None``; k-nearest mode has no GPU path.

    Returns
    -------
    cupy.ndarray
        float64 array of shape ``(len(y_grid), len(x_grid))``.
    """
    if k is not None:
        raise NotImplementedError(
            "idw(): k-nearest mode is not supported on GPU. "
            "Use k=None for all-points IDW on GPU, or use a "
            "numpy/dask+numpy backend."
        )

    # Copy the samples and grid axes to the device.
    pts_x = cupy.asarray(x_pts)
    pts_y = cupy.asarray(y_pts)
    pts_z = cupy.asarray(z_pts)
    axis_x = cupy.asarray(x_grid)
    axis_y = cupy.asarray(y_grid)

    shape = (len(y_grid), len(x_grid))
    result = cupy.full(shape, fill_value, dtype=np.float64)

    grid_dim, block_dim = cuda_args(shape)
    launch = _idw_cuda_kernel[grid_dim, block_dim]
    launch(pts_x, pts_y, pts_z, len(x_pts),
           axis_x, axis_y, power, fill_value, result)
    return result
184+
185+
186+
# ---------------------------------------------------------------------------
187+
# Dask + numpy backend
188+
# ---------------------------------------------------------------------------
189+
190+
def _idw_dask_numpy(x_pts, y_pts, z_pts, x_grid, y_grid,
                    power, k, fill_value, template_data):
    """Dask+NumPy backend: run the NumPy IDW once per template chunk.

    The template's chunk layout decides how the grid is split; each
    chunk's position within the full array selects the matching slices
    of ``y_grid`` and ``x_grid``. The chunk's own values are not read.
    """

    def _chunk(block, block_info=None):
        # Without location information there is nothing to compute.
        if block_info is None:
            return block
        # 'array-location' holds one (start, stop) pair per axis.
        location = block_info[0]['array-location']
        rows = slice(location[0][0], location[0][1])
        cols = slice(location[1][0], location[1][1])
        return _idw_numpy(x_pts, y_pts, z_pts, x_grid[cols], y_grid[rows],
                          power, k, fill_value, None)

    return da.map_blocks(_chunk, template_data, dtype=np.float64)
203+
204+
205+
# ---------------------------------------------------------------------------
206+
# Dask + cupy backend
207+
# ---------------------------------------------------------------------------
208+
209+
def _idw_dask_cupy(x_pts, y_pts, z_pts, x_grid, y_grid,
                   power, k, fill_value, template_data):
    """Dask+CuPy backend: run the CuPy IDW once per template chunk.

    Raises
    ------
    NotImplementedError
        If ``k`` is not ``None``; k-nearest mode has no GPU path.
    """
    if k is not None:
        raise NotImplementedError(
            "idw(): k-nearest mode is not supported on GPU."
        )

    def _chunk(block, block_info=None):
        # Without location information there is nothing to compute.
        if block_info is None:
            return block
        # 'array-location' holds one (start, stop) pair per axis.
        location = block_info[0]['array-location']
        rows = slice(location[0][0], location[0][1])
        cols = slice(location[1][0], location[1][1])
        return _idw_cupy(x_pts, y_pts, z_pts, x_grid[cols], y_grid[rows],
                         power, None, fill_value, None)

    # An empty cupy array as ``meta`` marks the result as cupy-backed.
    gpu_meta = cupy.array((), dtype=np.float64)
    return da.map_blocks(_chunk, template_data, dtype=np.float64,
                         meta=gpu_meta)
229+
230+
231+
# ---------------------------------------------------------------------------
232+
# Public API
233+
# ---------------------------------------------------------------------------
234+
235+
def idw(x, y, z, template, power=2.0, k=None,
        fill_value=np.nan, name='idw'):
    """Interpolate scattered points onto a grid by inverse distance weighting.

    Parameters
    ----------
    x, y, z : array-like
        Coordinates and values of the scattered input points.
    template : xr.DataArray
        2-D DataArray whose grid defines the output raster. Its coords,
        dims and attrs are reused for the result, and its array type
        selects the backend.
    power : float, default 2.0
        Distance weighting exponent; must be greater than zero.
    k : int or None, default None
        Number of nearest neighbours. ``None`` uses every point; an
        integer (at least 1, capped at the number of points) uses
        ``scipy.spatial.cKDTree`` and is only available on the CPU
        backends.
    fill_value : float, default np.nan
        Value for pixels with zero total weight.
    name : str, default 'idw'
        Name of the output DataArray.

    Returns
    -------
    xr.DataArray
        Interpolated raster on the template's grid.
    """
    # Validate inputs in the same order as before so the first failing
    # check raises the same error.
    _validate_raster(template, func_name='idw', name='template')
    x_arr, y_arr, z_arr = validate_points(x, y, z, func_name='idw')
    _validate_scalar(power, func_name='idw', name='power',
                     min_val=0.0, min_exclusive=True)

    if k is not None:
        _validate_scalar(k, func_name='idw', name='k',
                         dtype=int, min_val=1)
        # Cannot use more neighbours than there are points.
        k = min(k, len(x_arr))

    x_grid, y_grid = extract_grid_coords(template, func_name='idw')

    # Choose the implementation matching the template's array type.
    dispatch = ArrayTypeFunctionMapping(
        numpy_func=_idw_numpy,
        cupy_func=_idw_cupy,
        dask_func=_idw_dask_numpy,
        dask_cupy_func=_idw_dask_cupy,
    )
    backend = dispatch(template)
    data = backend(x_arr, y_arr, z_arr, x_grid, y_grid,
                   power, k, fill_value, template.data)

    return xr.DataArray(
        data,
        name=name,
        coords=template.coords,
        dims=template.dims,
        attrs=template.attrs,
    )

0 commit comments

Comments
 (0)