@@ -77,6 +77,16 @@ def _mean_dask_numpy(data, excludes, boundary='nan'):
7777 return out
7878
7979
80+ def _mean_dask_cupy (data , excludes , boundary = 'nan' ):
81+ data = data .astype (cupy .float32 )
82+ _func = partial (_mean_cupy , excludes = excludes )
83+ out = data .map_overlap (_func ,
84+ depth = (1 , 1 ),
85+ boundary = _boundary_to_dask (boundary , is_cupy = True ),
86+ meta = cupy .array (()))
87+ return out
88+
89+
8090@cuda .jit
8191def _mean_gpu (data , excludes , out ):
8292 # 1. Get coordinates: x is Column, y is Row
@@ -161,8 +171,7 @@ def _mean(data, excludes, boundary='nan'):
161171 numpy_func = partial (_mean_numpy_boundary , boundary = boundary ),
162172 cupy_func = _mean_cupy ,
163173 dask_func = partial (_mean_dask_numpy , boundary = boundary ),
164- dask_cupy_func = lambda * args : not_implemented_func (
165- * args , messages = 'mean() does not support dask with cupy backed DataArray.' ), # noqa
174+ dask_cupy_func = partial (_mean_dask_cupy , boundary = boundary ),
166175 )
167176 out = mapper (agg )(agg .data , excludes )
168177 return out
@@ -370,6 +379,22 @@ def _apply_dask_numpy(data, kernel, func, boundary='nan'):
370379 return out
371380
372381
382+ def _apply_cupy (data , kernel , func ):
383+ return _focal_stats_func_cupy (data .astype (cupy .float32 ), kernel , func )
384+
385+
386+ def _apply_dask_cupy (data , kernel , func , boundary = 'nan' ):
387+ data = data .astype (cupy .float32 )
388+ pad_h = kernel .shape [0 ] // 2
389+ pad_w = kernel .shape [1 ] // 2
390+ _func = partial (_focal_stats_func_cupy , kernel = kernel , func = func )
391+ out = data .map_overlap (_func ,
392+ depth = (pad_h , pad_w ),
393+ boundary = _boundary_to_dask (boundary , is_cupy = True ),
394+ meta = cupy .array (()))
395+ return out
396+
397+
373398def apply (raster , kernel , func = _calc_mean , name = 'focal_apply' , boundary = 'nan' ):
374399 """
375400 Returns custom function applied array using a user-created window.
@@ -378,11 +403,14 @@ def apply(raster, kernel, func=_calc_mean, name='focal_apply', boundary='nan'):
378403 ----------
379404 raster : xarray.DataArray
380405 2D array of input values to be filtered. Can be a NumPy backed,
381- or Dask with NumPy backed DataArray.
406+ CuPy backed, Dask with NumPy backed, or Dask with CuPy backed
407+ DataArray.
382408 kernel : numpy.ndarray
383409 2D array where values of 1 indicate the kernel.
384410 func : callable, default=xrspatial.focal._calc_mean
385411 Function which takes an input array and returns an array.
412+ For cupy and dask+cupy backends the function must be a
413+ ``@cuda.jit`` global kernel with signature ``(data, kernel, out)``.
386414 boundary : str, default='nan'
387415 How to handle edges where the kernel extends beyond the raster.
388416 ``'nan'`` -- fill missing neighbours with NaN (default).
@@ -496,11 +524,9 @@ def apply(raster, kernel, func=_calc_mean, name='focal_apply', boundary='nan'):
496524 # the function func must be a @ngjit
497525 mapper = ArrayTypeFunctionMapping (
498526 numpy_func = partial (_apply_numpy_boundary , boundary = boundary ),
499- cupy_func = lambda * args : not_implemented_func (
500- * args , messages = 'apply() does not support cupy backed DataArray.' ),
527+ cupy_func = _apply_cupy ,
501528 dask_func = partial (_apply_dask_numpy , boundary = boundary ),
502- dask_cupy_func = lambda * args : not_implemented_func (
503- * args , messages = 'apply() does not support dask with cupy backed DataArray.' ),
529+ dask_cupy_func = partial (_apply_dask_cupy , boundary = boundary ),
504530 )
505531 out = mapper (raster )(raster .data , kernel , func )
506532 result = DataArray (out ,
@@ -818,6 +844,32 @@ def _focal_stats_cupy(agg, kernel, stats_funcs):
818844 return stats
819845
820846
847+ def _focal_stats_dask_cupy (agg , kernel , stats_funcs , boundary = 'nan' ):
848+ _stats_cuda_mapper = dict (
849+ mean = _focal_mean_cuda , sum = _focal_sum_cuda ,
850+ range = _focal_range_cuda , max = _focal_max_cuda ,
851+ min = _focal_min_cuda , std = _focal_std_cuda , var = _focal_var_cuda ,
852+ )
853+ pad_h = kernel .shape [0 ] // 2
854+ pad_w = kernel .shape [1 ] // 2
855+ dask_bnd = _boundary_to_dask (boundary , is_cupy = True )
856+
857+ stats_aggs = []
858+ for stat_name in stats_funcs :
859+ cuda_kernel = _stats_cuda_mapper [stat_name ]
860+ _func = partial (_focal_stats_func_cupy , kernel = kernel , func = cuda_kernel )
861+ data = agg .data .astype (cupy .float32 )
862+ stats_data = data .map_overlap (
863+ _func , depth = (pad_h , pad_w ),
864+ boundary = dask_bnd , meta = cupy .array (()))
865+ stats_agg = xr .DataArray (
866+ stats_data , dims = agg .dims , coords = agg .coords , attrs = agg .attrs )
867+ stats_aggs .append (stats_agg )
868+ stats = xr .concat (stats_aggs ,
869+ pd .Index (stats_funcs , name = 'stats' , dtype = object ))
870+ return stats
871+
872+
821873def _focal_stats_cpu (agg , kernel , stats_funcs , boundary = 'nan' ):
822874 _function_mapping = {
823875 'mean' : _calc_mean ,
@@ -852,7 +904,8 @@ def focal_stats(agg,
852904 ----------
853905 agg : xarray.DataArray
854906 2D array of input values to be analysed. Can be a NumPy backed,
855- Cupy backed, or Dask with NumPy backed DataArray.
907+ CuPy backed, Dask with NumPy backed, or Dask with CuPy backed
908+ DataArray.
856909 kernel : numpy.array
857910 2D array where values of 1 indicate the kernel.
858911 stats_funcs: list of string
@@ -920,8 +973,7 @@ def focal_stats(agg,
920973 numpy_func = partial (_focal_stats_cpu , boundary = boundary ),
921974 cupy_func = _focal_stats_cupy ,
922975 dask_func = partial (_focal_stats_cpu , boundary = boundary ),
923- dask_cupy_func = lambda * args : not_implemented_func (
924- * args , messages = 'focal_stats() does not support dask with cupy backed DataArray.' ),
976+ dask_cupy_func = partial (_focal_stats_dask_cupy , boundary = boundary ),
925977 )
926978 result = mapper (agg )(agg , kernel , stats_funcs )
927979 return result
0 commit comments