Skip to content

Commit 1455f45

Browse files
committed
Add native GPU kernel for zonal.apply() with automatic CPU fallback
JIT-compile the user's scalar function as a CUDA device function and run it inside a CUDA kernel, avoiding the GPU→CPU→GPU round-trip. Non-compilable functions (e.g. using dict, str) automatically fall back to the existing CPU path via exception handling.
1 parent 6c6fa8a commit 1455f45

3 files changed

Lines changed: 259 additions & 30 deletions

File tree

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -236,10 +236,10 @@ In the GIS world, rasters are used for representing continuous phenomena (e.g. e
236236

237237
| Name | Description | NumPy xr.DataArray | Dask xr.DataArray | CuPy GPU xr.DataArray | Dask GPU xr.DataArray |
238238
|:----------:|:------------|:----------------------:|:--------------------:|:-------------------:|:------:|
239-
| [Apply](xrspatial/zonal.py) | Applies a custom function to each zone in a classified raster | ✅️ | ✅️ | 🔄 | 🔄 |
240-
| [Crop](xrspatial/zonal.py) | Extracts the bounding rectangle of a specific zone | ✅️ | 🔄 | 🔄 | 🔄 |
239+
| [Apply](xrspatial/zonal.py) | Applies a custom function to each zone in a classified raster | ✅️ | ✅️ | ✅️ | ✅️ |
240+
| [Crop](xrspatial/zonal.py) | Extracts the bounding rectangle of a specific zone | ✅️ | ✅️ | ✅️ | ✅️ |
241241
| [Regions](xrspatial/zonal.py) | Identifies connected regions of non-zero cells | ✅️ | ✅️ | ✅️ | ✅️ |
242-
| [Trim](xrspatial/zonal.py) | Removes nodata border rows and columns from a raster | ✅️ | 🔄 | 🔄 | 🔄 |
242+
| [Trim](xrspatial/zonal.py) | Removes nodata border rows and columns from a raster | ✅️ | ✅️ | ✅️ | ✅️ |
243243
| [Zonal Statistics](xrspatial/zonal.py) | Computes summary statistics for a value raster within each zone | ✅️ | ✅️| ✅️ | 🔄 |
244244
| [Zonal Cross Tabulate](xrspatial/zonal.py) | Cross-tabulates agreement between two categorical rasters | ✅️ | ✅️| 🔄 | 🔄 |
245245

xrspatial/tests/test_dask_cupy_gaps.py

Lines changed: 109 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
"""Tests for dask+cupy backends: perlin, terrain, crosstab, trim, crop."""
1+
"""Tests for dask+cupy backends: perlin, terrain, crosstab, trim, crop, apply."""
22

33
import numpy as np
44
import xarray as xr
55

66
from xrspatial import generate_terrain, perlin
77
from xrspatial.tests.general_checks import cuda_and_cupy_available, dask_array_available
88
from xrspatial.utils import has_cuda_and_cupy
9-
from xrspatial.zonal import crop, trim
9+
from xrspatial.zonal import apply, crop, trim
1010

1111

1212
def _make_raster(shape=(50, 50), backend='numpy', chunks=(10, 10)):
@@ -184,6 +184,18 @@ def test_trim_dask():
184184
np.testing.assert_array_equal(result.data.compute(), _TRIM_EXPECTED)
185185

186186

187+
@dask_array_available
188+
def test_trim_dask_lazy():
189+
"""trim() on a dask DataArray returns a dask-backed result (not computed)."""
190+
import dask.array as da
191+
192+
raster = xr.DataArray(
193+
da.from_array(_TRIM_ARR, chunks=(3, 2)), dims=['y', 'x'],
194+
)
195+
result = trim(raster, values=(0,))
196+
assert isinstance(result.data, da.Array)
197+
198+
187199
@cuda_and_cupy_available
188200
def test_trim_cupy():
189201
import cupy
@@ -239,6 +251,18 @@ def test_crop_dask():
239251
np.testing.assert_array_equal(result.data.compute(), _CROP_EXPECTED)
240252

241253

254+
@dask_array_available
255+
def test_crop_dask_lazy():
256+
"""crop() on a dask DataArray returns a dask-backed result (not computed)."""
257+
import dask.array as da
258+
259+
raster = xr.DataArray(
260+
da.from_array(_CROP_ARR, chunks=(3, 2)), dims=['y', 'x'],
261+
)
262+
result = crop(raster, raster, zones_ids=(1, 3))
263+
assert isinstance(result.data, da.Array)
264+
265+
242266
@cuda_and_cupy_available
243267
def test_crop_cupy():
244268
import cupy
@@ -262,3 +286,86 @@ def test_crop_dask_cupy():
262286
computed = result.data.compute()
263287
assert isinstance(computed, cupy.ndarray)
264288
np.testing.assert_array_equal(computed.get(), _CROP_EXPECTED)
289+
290+
291+
# ---- apply: cupy, dask+cupy, fallback ----
292+
293+
_APPLY_ZONES = np.array([
294+
[1, 1, 0, 2],
295+
[1, 1, 0, 2],
296+
[3, 3, 3, 2],
297+
], dtype=np.int64)
298+
299+
_APPLY_VALUES = np.array([
300+
[10.0, 20.0, 30.0, 40.0],
301+
[50.0, 60.0, 70.0, 80.0],
302+
[90.0, 100.0, 110.0, 120.0],
303+
], dtype=np.float64)
304+
305+
306+
def _double(x):
307+
return x * 2
308+
309+
310+
@cuda_and_cupy_available
311+
def test_apply_cupy():
312+
import cupy
313+
314+
zones_np = xr.DataArray(_APPLY_ZONES, dims=['y', 'x'])
315+
values_np = xr.DataArray(_APPLY_VALUES, dims=['y', 'x'])
316+
result_np = apply(zones_np, values_np, _double)
317+
318+
zones_cupy = xr.DataArray(cupy.asarray(_APPLY_ZONES), dims=['y', 'x'])
319+
values_cupy = xr.DataArray(cupy.asarray(_APPLY_VALUES), dims=['y', 'x'])
320+
result_cupy = apply(zones_cupy, values_cupy, _double)
321+
322+
assert isinstance(result_cupy.data, cupy.ndarray)
323+
np.testing.assert_allclose(result_cupy.data.get(), result_np.values)
324+
325+
326+
@cuda_and_cupy_available
327+
@dask_array_available
328+
def test_apply_dask_cupy():
329+
import cupy
330+
import dask.array as da
331+
332+
zones_np = xr.DataArray(_APPLY_ZONES, dims=['y', 'x'])
333+
values_np = xr.DataArray(_APPLY_VALUES, dims=['y', 'x'])
334+
result_np = apply(zones_np, values_np, _double)
335+
336+
zones_gpu = cupy.asarray(_APPLY_ZONES)
337+
values_gpu = cupy.asarray(_APPLY_VALUES)
338+
zones_dask = xr.DataArray(
339+
da.from_array(zones_gpu, chunks=(2, 2)), dims=['y', 'x'],
340+
)
341+
values_dask = xr.DataArray(
342+
da.from_array(values_gpu, chunks=(2, 2)), dims=['y', 'x'],
343+
)
344+
result = apply(zones_dask, values_dask, _double)
345+
346+
assert isinstance(result.data, da.Array)
347+
computed = result.data.compute()
348+
assert isinstance(computed, cupy.ndarray)
349+
np.testing.assert_allclose(computed.get(), result_np.values)
350+
351+
352+
@cuda_and_cupy_available
353+
def test_apply_cupy_fallback():
354+
"""A func that CUDA can't compile still works via CPU fallback."""
355+
import cupy
356+
357+
lookup = {10.0: 100.0, 50.0: 500.0}
358+
359+
def _dict_func(x):
360+
return lookup.get(x, x)
361+
362+
zones_np = xr.DataArray(_APPLY_ZONES, dims=['y', 'x'])
363+
values_np = xr.DataArray(_APPLY_VALUES, dims=['y', 'x'])
364+
result_np = apply(zones_np, values_np, _dict_func)
365+
366+
zones_cupy = xr.DataArray(cupy.asarray(_APPLY_ZONES), dims=['y', 'x'])
367+
values_cupy = xr.DataArray(cupy.asarray(_APPLY_VALUES), dims=['y', 'x'])
368+
result_cupy = apply(zones_cupy, values_cupy, _dict_func)
369+
370+
assert isinstance(result_cupy.data, cupy.ndarray)
371+
np.testing.assert_allclose(result_cupy.data.get(), result_np.values)

xrspatial/zonal.py

Lines changed: 147 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88

99
# 3rd-party
1010
try:
11+
import dask
1112
import dask.array as da
1213
except ImportError:
14+
dask = None
1315
da = None
1416

1517
try:
@@ -35,7 +37,7 @@ class cupy(object):
3537

3638
# local modules
3739
from xrspatial.utils import (
38-
ArrayTypeFunctionMapping, _validate_raster, has_cuda_and_cupy,
40+
ArrayTypeFunctionMapping, _validate_raster, cuda_args, has_cuda_and_cupy,
3941
is_cupy_array, is_dask_cupy,
4042
ngjit, not_implemented_func, validate_arrays,
4143
)
@@ -1232,9 +1234,51 @@ def _apply_numpy(zones_data, values_data, func, nodata):
12321234
return out
12331235

12341236

1237+
def _make_apply_kernel(func):
1238+
"""Build a CUDA kernel that applies *func* element-wise."""
1239+
from numba import cuda as nb_cuda
1240+
1241+
device_func = nb_cuda.jit(device=True)(func)
1242+
1243+
@nb_cuda.jit
1244+
def _kernel(zones, values, out, nodata_val, has_nodata):
1245+
y, x = nb_cuda.grid(2)
1246+
if y < zones.shape[0] and x < zones.shape[1]:
1247+
if has_nodata and zones[y, x] == nodata_val:
1248+
return
1249+
out[y, x] = device_func(values[y, x])
1250+
1251+
return _kernel
1252+
1253+
1254+
def _apply_cupy_gpu(zones_data, values_data, kernel, nodata):
1255+
"""Run the CUDA apply kernel on cupy arrays."""
1256+
out = values_data.copy()
1257+
has_nodata = nodata is not None
1258+
nodata_val = nodata if has_nodata else 0
1259+
1260+
griddim, blockdim = cuda_args(values_data.shape[:2])
1261+
1262+
if values_data.ndim == 2:
1263+
kernel[griddim, blockdim](
1264+
zones_data, values_data, out, nodata_val, has_nodata,
1265+
)
1266+
else:
1267+
for k in range(values_data.shape[2]):
1268+
kernel[griddim, blockdim](
1269+
zones_data, values_data[:, :, k], out[:, :, k],
1270+
nodata_val, has_nodata,
1271+
)
1272+
return out
1273+
1274+
12351275
def _apply_cupy(zones_data, values_data, func, nodata):
1236-
result_np = _apply_numpy(zones_data.get(), values_data.get(), func, nodata)
1237-
return cupy.asarray(result_np)
1276+
try:
1277+
kernel = _make_apply_kernel(func)
1278+
return _apply_cupy_gpu(zones_data, values_data, kernel, nodata)
1279+
except Exception:
1280+
result_np = _apply_numpy(zones_data.get(), values_data.get(), func, nodata)
1281+
return cupy.asarray(result_np)
12381282

12391283

12401284
def _apply_dask_numpy(zones_data, values_data, func, nodata):
@@ -1258,16 +1302,43 @@ def _chunk_fn(zones_chunk, values_chunk):
12581302

12591303

12601304
def _apply_dask_cupy(zones_data, values_data, func, nodata):
1261-
zones_cpu = zones_data.map_blocks(
1262-
lambda x: x.get(), dtype=zones_data.dtype, meta=np.array(()),
1263-
)
1264-
values_cpu = values_data.map_blocks(
1265-
lambda x: x.get(), dtype=values_data.dtype, meta=np.array(()),
1266-
)
1267-
result = _apply_dask_numpy(zones_cpu, values_cpu, func, nodata)
1268-
return result.map_blocks(
1269-
cupy.asarray, dtype=result.dtype, meta=cupy.array(()),
1270-
)
1305+
# Try GPU: build kernel once, reuse across all chunks
1306+
try:
1307+
kernel = _make_apply_kernel(func)
1308+
gpu_ok = True
1309+
except Exception:
1310+
gpu_ok = False
1311+
1312+
if gpu_ok:
1313+
def _chunk_fn(zones_chunk, values_chunk):
1314+
try:
1315+
return _apply_cupy_gpu(zones_chunk, values_chunk, kernel, nodata)
1316+
except Exception:
1317+
result_np = _apply_numpy(
1318+
zones_chunk.get(), values_chunk.get(), func, nodata,
1319+
)
1320+
return cupy.asarray(result_np)
1321+
else:
1322+
def _chunk_fn(zones_chunk, values_chunk):
1323+
result_np = _apply_numpy(
1324+
zones_chunk.get(), values_chunk.get(), func, nodata,
1325+
)
1326+
return cupy.asarray(result_np)
1327+
1328+
if values_data.ndim == 2:
1329+
return da.map_blocks(
1330+
_chunk_fn, zones_data, values_data,
1331+
dtype=values_data.dtype, meta=cupy.array(()),
1332+
)
1333+
else:
1334+
layers = []
1335+
for k in range(values_data.shape[2]):
1336+
layer = values_data[:, :, k].rechunk(zones_data.chunks)
1337+
layers.append(da.map_blocks(
1338+
_chunk_fn, zones_data, layer,
1339+
dtype=values_data.dtype, meta=cupy.array(()),
1340+
))
1341+
return da.stack(layers, axis=2)
12711342

12721343

12731344
def apply(
@@ -1783,6 +1854,35 @@ def _trim(data, excludes):
17831854
return top, bottom, left, right
17841855

17851856

1857+
def _trim_bounds_dask(data, excludes):
1858+
"""Find trim bounds using lazy dask reductions (O(rows+cols) memory)."""
1859+
excluded = da.zeros_like(data, dtype=bool)
1860+
for v in excludes:
1861+
if isinstance(v, float) and np.isnan(v):
1862+
excluded = excluded | da.isnan(data)
1863+
else:
1864+
excluded = excluded | (data == v)
1865+
1866+
all_excl_rows = excluded.all(axis=1)
1867+
all_excl_cols = excluded.all(axis=0)
1868+
row_mask, col_mask = dask.compute(all_excl_rows, all_excl_cols)
1869+
1870+
# dask+cupy computes to cupy arrays; move to numpy for np.where
1871+
if is_cupy_array(row_mask):
1872+
row_mask = row_mask.get()
1873+
if is_cupy_array(col_mask):
1874+
col_mask = col_mask.get()
1875+
1876+
data_rows = np.where(~np.asarray(row_mask))[0]
1877+
data_cols = np.where(~np.asarray(col_mask))[0]
1878+
1879+
if len(data_rows) == 0 or len(data_cols) == 0:
1880+
return 0, -1, 0, -1 # empty slice
1881+
1882+
return (int(data_rows[0]), int(data_rows[-1]),
1883+
int(data_cols[0]), int(data_cols[-1]))
1884+
1885+
17861886
def trim(
17871887
raster: xr.DataArray,
17881888
values: Union[list, tuple] = (np.nan,),
@@ -1891,15 +1991,13 @@ def trim(
18911991
_validate_raster(raster, func_name='trim', name='raster', ndim=2)
18921992

18931993
data = raster.data
1894-
# _trim needs element access; materialise to numpy for non-numpy backends
1895-
if is_cupy_array(data):
1896-
data = data.get()
1897-
elif has_dask_array() and isinstance(data, da.Array):
1898-
data = data.compute()
1994+
if has_dask_array() and isinstance(data, da.Array):
1995+
top, bottom, left, right = _trim_bounds_dask(data, values)
1996+
else:
18991997
if is_cupy_array(data):
19001998
data = data.get()
1999+
top, bottom, left, right = _trim(data, values)
19012000

1902-
top, bottom, left, right = _trim(data, values)
19032001
arr = raster[top: bottom + 1, left: right + 1]
19042002
arr.name = name
19052003
return arr
@@ -2003,6 +2101,32 @@ def _crop(data, values):
20032101
return top, bottom, left, right
20042102

20052103

2104+
def _crop_bounds_dask(data, target_values):
2105+
"""Find crop bounds using lazy dask reductions (O(rows+cols) memory)."""
2106+
matched = da.zeros_like(data, dtype=bool)
2107+
for v in target_values:
2108+
matched = matched | (data == v)
2109+
2110+
any_match_rows = matched.any(axis=1)
2111+
any_match_cols = matched.any(axis=0)
2112+
row_mask, col_mask = dask.compute(any_match_rows, any_match_cols)
2113+
2114+
# dask+cupy computes to cupy arrays; move to numpy for np.where
2115+
if is_cupy_array(row_mask):
2116+
row_mask = row_mask.get()
2117+
if is_cupy_array(col_mask):
2118+
col_mask = col_mask.get()
2119+
2120+
match_rows = np.where(np.asarray(row_mask))[0]
2121+
match_cols = np.where(np.asarray(col_mask))[0]
2122+
2123+
if len(match_rows) == 0 or len(match_cols) == 0:
2124+
return 0, data.shape[0] - 1, 0, data.shape[1] - 1
2125+
2126+
return (int(match_rows[0]), int(match_rows[-1]),
2127+
int(match_cols[0]), int(match_cols[-1]))
2128+
2129+
20062130
def crop(
20072131
zones: xr.DataArray,
20082132
values: xr.DataArray,
@@ -2123,15 +2247,13 @@ def crop(
21232247
_validate_raster(values, func_name='crop', name='values', ndim=2)
21242248

21252249
data = zones.data
2126-
# _crop is @ngjit; materialise to numpy for non-numpy backends
2127-
if is_cupy_array(data):
2128-
data = data.get()
2129-
elif has_dask_array() and isinstance(data, da.Array):
2130-
data = data.compute()
2250+
if has_dask_array() and isinstance(data, da.Array):
2251+
top, bottom, left, right = _crop_bounds_dask(data, zones_ids)
2252+
else:
21312253
if is_cupy_array(data):
21322254
data = data.get()
2255+
top, bottom, left, right = _crop(data, zones_ids)
21332256

2134-
top, bottom, left, right = _crop(data, zones_ids)
21352257
arr = values[top: bottom + 1, left: right + 1]
21362258
arr.name = name
21372259
return arr

0 commit comments

Comments
 (0)