Skip to content

Commit 8c2b75a

Browse files
authored
Fixes #885: add dask+cupy backends for focal tools (#896)
* Fixes #885: add dask+cupy backends for focal mean, apply, and focal_stats Wire existing cupy CUDA kernels through map_overlap for dask+cupy inputs so GPU-cluster users no longer need to .compute() first, avoiding OOM on large datasets. * Fixes #883: remove unnecessary .copy() from broadcast views in _extract_latlon_coords All downstream consumers are read-only, so broadcast views suffice. Avoids materializing two full (H, W) float64 arrays (~16 bytes/pixel).
1 parent fd78352 commit 8c2b75a

File tree

4 files changed

+151
-24
lines changed

4 files changed

+151
-24
lines changed

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,10 +155,10 @@ In the GIS world, rasters are used for representing continuous phenomena (e.g. e
155155

156156
| Name | Description | NumPy xr.DataArray | Dask xr.DataArray | CuPy GPU xr.DataArray | Dask GPU xr.DataArray |
157157
|:----------:|:------------|:----------------------:|:--------------------:|:-------------------:|:------:|
158-
| [Apply](xrspatial/focal.py) | Applies a custom function over a sliding neighborhood window | ✅️ | ✅️ | | |
158+
| [Apply](xrspatial/focal.py) | Applies a custom function over a sliding neighborhood window | ✅️ | ✅️ | ✅️ | ✅️ |
159159
| [Hotspots](xrspatial/focal.py) | Identifies statistically significant spatial clusters using Getis-Ord Gi* | ✅️ | ✅️ | ✅️ | |
160-
| [Mean](xrspatial/focal.py) | Computes the mean value within a sliding neighborhood window | ✅️ | ✅️ | ✅️ | |
161-
| [Focal Statistics](xrspatial/focal.py) | Computes summary statistics over a sliding neighborhood window | ✅️ | ✅️ | ✅️ | |
160+
| [Mean](xrspatial/focal.py) | Computes the mean value within a sliding neighborhood window | ✅️ | ✅️ | ✅️ | ✅️ |
161+
| [Focal Statistics](xrspatial/focal.py) | Computes summary statistics over a sliding neighborhood window | ✅️ | ✅️ | ✅️ | ✅️ |
162162

163163
-------
164164

xrspatial/focal.py

Lines changed: 62 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,16 @@ def _mean_dask_numpy(data, excludes, boundary='nan'):
7777
return out
7878

7979

80+
def _mean_dask_cupy(data, excludes, boundary='nan'):
81+
data = data.astype(cupy.float32)
82+
_func = partial(_mean_cupy, excludes=excludes)
83+
out = data.map_overlap(_func,
84+
depth=(1, 1),
85+
boundary=_boundary_to_dask(boundary, is_cupy=True),
86+
meta=cupy.array(()))
87+
return out
88+
89+
8090
@cuda.jit
8191
def _mean_gpu(data, excludes, out):
8292
# 1. Get coordinates: x is Column, y is Row
@@ -161,8 +171,7 @@ def _mean(data, excludes, boundary='nan'):
161171
numpy_func=partial(_mean_numpy_boundary, boundary=boundary),
162172
cupy_func=_mean_cupy,
163173
dask_func=partial(_mean_dask_numpy, boundary=boundary),
164-
dask_cupy_func=lambda *args: not_implemented_func(
165-
*args, messages='mean() does not support dask with cupy backed DataArray.'), # noqa
174+
dask_cupy_func=partial(_mean_dask_cupy, boundary=boundary),
166175
)
167176
out = mapper(agg)(agg.data, excludes)
168177
return out
@@ -370,6 +379,22 @@ def _apply_dask_numpy(data, kernel, func, boundary='nan'):
370379
return out
371380

372381

382+
def _apply_cupy(data, kernel, func):
383+
return _focal_stats_func_cupy(data.astype(cupy.float32), kernel, func)
384+
385+
386+
def _apply_dask_cupy(data, kernel, func, boundary='nan'):
387+
data = data.astype(cupy.float32)
388+
pad_h = kernel.shape[0] // 2
389+
pad_w = kernel.shape[1] // 2
390+
_func = partial(_focal_stats_func_cupy, kernel=kernel, func=func)
391+
out = data.map_overlap(_func,
392+
depth=(pad_h, pad_w),
393+
boundary=_boundary_to_dask(boundary, is_cupy=True),
394+
meta=cupy.array(()))
395+
return out
396+
397+
373398
def apply(raster, kernel, func=_calc_mean, name='focal_apply', boundary='nan'):
374399
"""
375400
Returns custom function applied array using a user-created window.
@@ -378,11 +403,14 @@ def apply(raster, kernel, func=_calc_mean, name='focal_apply', boundary='nan'):
378403
----------
379404
raster : xarray.DataArray
380405
2D array of input values to be filtered. Can be a NumPy backed,
381-
or Dask with NumPy backed DataArray.
406+
CuPy backed, Dask with NumPy backed, or Dask with CuPy backed
407+
DataArray.
382408
kernel : numpy.ndarray
383409
2D array where values of 1 indicate the kernel.
384410
func : callable, default=xrspatial.focal._calc_mean
385411
Function which takes an input array and returns an array.
412+
For cupy and dask+cupy backends the function must be a
413+
``@cuda.jit`` global kernel with signature ``(data, kernel, out)``.
386414
boundary : str, default='nan'
387415
How to handle edges where the kernel extends beyond the raster.
388416
``'nan'`` -- fill missing neighbours with NaN (default).
@@ -496,11 +524,9 @@ def apply(raster, kernel, func=_calc_mean, name='focal_apply', boundary='nan'):
496524
# the function func must be a @ngjit
497525
mapper = ArrayTypeFunctionMapping(
498526
numpy_func=partial(_apply_numpy_boundary, boundary=boundary),
499-
cupy_func=lambda *args: not_implemented_func(
500-
*args, messages='apply() does not support cupy backed DataArray.'),
527+
cupy_func=_apply_cupy,
501528
dask_func=partial(_apply_dask_numpy, boundary=boundary),
502-
dask_cupy_func=lambda *args: not_implemented_func(
503-
*args, messages='apply() does not support dask with cupy backed DataArray.'),
529+
dask_cupy_func=partial(_apply_dask_cupy, boundary=boundary),
504530
)
505531
out = mapper(raster)(raster.data, kernel, func)
506532
result = DataArray(out,
@@ -818,6 +844,32 @@ def _focal_stats_cupy(agg, kernel, stats_funcs):
818844
return stats
819845

820846

847+
def _focal_stats_dask_cupy(agg, kernel, stats_funcs, boundary='nan'):
848+
_stats_cuda_mapper = dict(
849+
mean=_focal_mean_cuda, sum=_focal_sum_cuda,
850+
range=_focal_range_cuda, max=_focal_max_cuda,
851+
min=_focal_min_cuda, std=_focal_std_cuda, var=_focal_var_cuda,
852+
)
853+
pad_h = kernel.shape[0] // 2
854+
pad_w = kernel.shape[1] // 2
855+
dask_bnd = _boundary_to_dask(boundary, is_cupy=True)
856+
857+
stats_aggs = []
858+
for stat_name in stats_funcs:
859+
cuda_kernel = _stats_cuda_mapper[stat_name]
860+
_func = partial(_focal_stats_func_cupy, kernel=kernel, func=cuda_kernel)
861+
data = agg.data.astype(cupy.float32)
862+
stats_data = data.map_overlap(
863+
_func, depth=(pad_h, pad_w),
864+
boundary=dask_bnd, meta=cupy.array(()))
865+
stats_agg = xr.DataArray(
866+
stats_data, dims=agg.dims, coords=agg.coords, attrs=agg.attrs)
867+
stats_aggs.append(stats_agg)
868+
stats = xr.concat(stats_aggs,
869+
pd.Index(stats_funcs, name='stats', dtype=object))
870+
return stats
871+
872+
821873
def _focal_stats_cpu(agg, kernel, stats_funcs, boundary='nan'):
822874
_function_mapping = {
823875
'mean': _calc_mean,
@@ -852,7 +904,8 @@ def focal_stats(agg,
852904
----------
853905
agg : xarray.DataArray
854906
2D array of input values to be analysed. Can be a NumPy backed,
855-
Cupy backed, or Dask with NumPy backed DataArray.
907+
CuPy backed, Dask with NumPy backed, or Dask with CuPy backed
908+
DataArray.
856909
kernel : numpy.array
857910
2D array where values of 1 indicate the kernel.
858911
stats_funcs: list of string
@@ -920,8 +973,7 @@ def focal_stats(agg,
920973
numpy_func=partial(_focal_stats_cpu, boundary=boundary),
921974
cupy_func=_focal_stats_cupy,
922975
dask_func=partial(_focal_stats_cpu, boundary=boundary),
923-
dask_cupy_func=lambda *args: not_implemented_func(
924-
*args, messages='focal_stats() does not support dask with cupy backed DataArray.'),
976+
dask_cupy_func=partial(_focal_stats_dask_cupy, boundary=boundary),
925977
)
926978
result = mapper(agg)(agg, kernel, stats_funcs)
927979
return result

xrspatial/tests/test_focal.py

Lines changed: 84 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -93,22 +93,24 @@ def test_mean_transfer_function_gpu_equals_cpu():
9393

9494
@dask_array_available
9595
@cuda_and_cupy_available
96-
def test_mean_transfer_dask_gpu_raise_not_implemented():
96+
def test_mean_transfer_function_dask_gpu():
9797

9898
import cupy
9999

100-
# cupy case
101-
cupy_agg = xr.DataArray(cupy.asarray(data_random))
102-
cupy_mean = mean(cupy_agg)
103-
general_output_checks(cupy_agg, cupy_mean)
100+
# numpy reference
101+
numpy_agg = xr.DataArray(data_random)
102+
numpy_mean = mean(numpy_agg)
104103

105-
# dask + cupy case not implemented
104+
# dask + cupy case
106105
dask_cupy_agg = xr.DataArray(
107106
da.from_array(cupy.asarray(data_random), chunks=(3, 3))
108107
)
109-
with pytest.raises(NotImplementedError) as e_info:
110-
mean(dask_cupy_agg)
111-
assert e_info
108+
dask_cupy_mean = mean(dask_cupy_agg)
109+
general_output_checks(dask_cupy_agg, dask_cupy_mean)
110+
111+
np.testing.assert_allclose(
112+
numpy_mean.data, dask_cupy_mean.data.compute().get(),
113+
equal_nan=True, rtol=1e-4)
112114

113115

114116
@pytest.fixture
@@ -351,6 +353,53 @@ def test_apply_dask_numpy(data_apply):
351353
general_output_checks(dask_numpy_agg, dask_numpy_apply, expected_result)
352354

353355

356+
@cuda_and_cupy_available
357+
def test_apply_cupy(data_apply):
358+
from xrspatial.focal import _focal_mean_cuda
359+
360+
data, kernel, expected_result_zero = data_apply
361+
# numpy reference using _calc_mean
362+
numpy_agg = create_test_raster(data)
363+
numpy_apply = apply(numpy_agg, kernel)
364+
365+
# cupy case with equivalent CUDA kernel
366+
cupy_agg = create_test_raster(data, backend='cupy')
367+
cupy_apply = apply(cupy_agg, kernel, _focal_mean_cuda)
368+
general_output_checks(cupy_agg, cupy_apply)
369+
370+
np.testing.assert_allclose(
371+
numpy_apply.data, cupy_apply.data.get(),
372+
equal_nan=True, rtol=1e-4)
373+
374+
375+
@dask_array_available
376+
@cuda_and_cupy_available
377+
def test_apply_dask_cupy():
378+
from xrspatial.focal import _focal_mean_cuda
379+
380+
# Use a larger array so chunk interiors are meaningful
381+
rng = np.random.default_rng(42)
382+
data = rng.random((20, 24)).astype(np.float64)
383+
kernel = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]])
384+
385+
# cupy reference (same CUDA kernel)
386+
cupy_agg = create_test_raster(data, backend='cupy')
387+
cupy_apply = apply(cupy_agg, kernel, _focal_mean_cuda)
388+
389+
# dask + cupy case
390+
dask_cupy_agg = create_test_raster(data, backend='dask+cupy', chunks=(10, 12))
391+
dask_cupy_apply = apply(dask_cupy_agg, kernel, _focal_mean_cuda)
392+
general_output_checks(dask_cupy_agg, dask_cupy_apply, verify_attrs=False)
393+
394+
# Compare interior (boundary='nan' causes edge differences between
395+
# cupy single-GPU bounds-clamping and dask map_overlap NaN-padding)
396+
pad = kernel.shape[0] // 2
397+
np.testing.assert_allclose(
398+
cupy_apply.data[pad:-pad, pad:-pad].get(),
399+
dask_cupy_apply.data[pad:-pad, pad:-pad].compute().get(),
400+
equal_nan=True, rtol=1e-4)
401+
402+
354403
@pytest.fixture
355404
def data_focal_stats():
356405
data = np.arange(16).reshape(4, 4)
@@ -424,6 +473,32 @@ def test_focal_stats_gpu(data_focal_stats):
424473
)
425474

426475

476+
@dask_array_available
477+
@cuda_and_cupy_available
478+
def test_focal_stats_dask_cupy():
479+
# Use larger data so chunk interiors are meaningful
480+
rng = np.random.default_rng(42)
481+
data = rng.random((20, 24)).astype(np.float64)
482+
kernel = custom_kernel(np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]]))
483+
484+
# cupy reference
485+
cupy_agg = create_test_raster(data, backend='cupy')
486+
cupy_focalstats = focal_stats(cupy_agg, kernel)
487+
488+
# dask + cupy case
489+
dask_cupy_agg = create_test_raster(data, backend='dask+cupy', chunks=(10, 12))
490+
dask_cupy_focalstats = focal_stats(dask_cupy_agg, kernel)
491+
assert dask_cupy_focalstats.ndim == 3
492+
493+
# Compare interior (boundary='nan' causes edge differences between
494+
# cupy single-GPU bounds-clamping and dask map_overlap NaN-padding)
495+
pad = kernel.shape[0] // 2
496+
np.testing.assert_allclose(
497+
cupy_focalstats.data[:, pad:-pad, pad:-pad].get(),
498+
dask_cupy_focalstats.data[:, pad:-pad, pad:-pad].compute().get(),
499+
equal_nan=True, rtol=1e-4)
500+
501+
427502
@pytest.fixture
428503
def data_hotspots():
429504
data = np.asarray([

xrspatial/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -704,9 +704,9 @@ def _extract_latlon_coords(agg: xr.DataArray):
704704
if lat_vals.ndim == 1 and lon_vals.ndim == 1:
705705
# Regular grid: broadcast to 2-D
706706
lat_2d = np.broadcast_to(lat_vals[:, np.newaxis],
707-
(agg.sizes[dim_y], agg.sizes[dim_x])).copy()
707+
(agg.sizes[dim_y], agg.sizes[dim_x]))
708708
lon_2d = np.broadcast_to(lon_vals[np.newaxis, :],
709-
(agg.sizes[dim_y], agg.sizes[dim_x])).copy()
709+
(agg.sizes[dim_y], agg.sizes[dim_x]))
710710
elif lat_vals.ndim == 2 and lon_vals.ndim == 2:
711711
lat_2d = lat_vals
712712
lon_2d = lon_vals

0 commit comments

Comments
 (0)