88
99# 3rd-party
1010try :
11+ import dask
1112 import dask .array as da
1213except ImportError :
14+ dask = None
1315 da = None
1416
1517try :
@@ -35,7 +37,7 @@ class cupy(object):
3537
3638# local modules
3739from xrspatial .utils import (
38- ArrayTypeFunctionMapping , _validate_raster , has_cuda_and_cupy ,
40+ ArrayTypeFunctionMapping , _validate_raster , cuda_args , has_cuda_and_cupy ,
3941 is_cupy_array , is_dask_cupy ,
4042 ngjit , not_implemented_func , validate_arrays ,
4143)
@@ -1232,9 +1234,51 @@ def _apply_numpy(zones_data, values_data, func, nodata):
12321234 return out
12331235
12341236
1237+ def _make_apply_kernel (func ):
1238+ """Build a CUDA kernel that applies *func* element-wise."""
1239+ from numba import cuda as nb_cuda
1240+
1241+ device_func = nb_cuda .jit (device = True )(func )
1242+
1243+ @nb_cuda .jit
1244+ def _kernel (zones , values , out , nodata_val , has_nodata ):
1245+ y , x = nb_cuda .grid (2 )
1246+ if y < zones .shape [0 ] and x < zones .shape [1 ]:
1247+ if has_nodata and zones [y , x ] == nodata_val :
1248+ return
1249+ out [y , x ] = device_func (values [y , x ])
1250+
1251+ return _kernel
1252+
1253+
1254+ def _apply_cupy_gpu (zones_data , values_data , kernel , nodata ):
1255+ """Run the CUDA apply kernel on cupy arrays."""
1256+ out = values_data .copy ()
1257+ has_nodata = nodata is not None
1258+ nodata_val = nodata if has_nodata else 0
1259+
1260+ griddim , blockdim = cuda_args (values_data .shape [:2 ])
1261+
1262+ if values_data .ndim == 2 :
1263+ kernel [griddim , blockdim ](
1264+ zones_data , values_data , out , nodata_val , has_nodata ,
1265+ )
1266+ else :
1267+ for k in range (values_data .shape [2 ]):
1268+ kernel [griddim , blockdim ](
1269+ zones_data , values_data [:, :, k ], out [:, :, k ],
1270+ nodata_val , has_nodata ,
1271+ )
1272+ return out
1273+
1274+
12351275def _apply_cupy (zones_data , values_data , func , nodata ):
1236- result_np = _apply_numpy (zones_data .get (), values_data .get (), func , nodata )
1237- return cupy .asarray (result_np )
1276+ try :
1277+ kernel = _make_apply_kernel (func )
1278+ return _apply_cupy_gpu (zones_data , values_data , kernel , nodata )
1279+ except Exception :
1280+ result_np = _apply_numpy (zones_data .get (), values_data .get (), func , nodata )
1281+ return cupy .asarray (result_np )
12381282
12391283
12401284def _apply_dask_numpy (zones_data , values_data , func , nodata ):
@@ -1258,16 +1302,43 @@ def _chunk_fn(zones_chunk, values_chunk):
12581302
12591303
12601304def _apply_dask_cupy (zones_data , values_data , func , nodata ):
1261- zones_cpu = zones_data .map_blocks (
1262- lambda x : x .get (), dtype = zones_data .dtype , meta = np .array (()),
1263- )
1264- values_cpu = values_data .map_blocks (
1265- lambda x : x .get (), dtype = values_data .dtype , meta = np .array (()),
1266- )
1267- result = _apply_dask_numpy (zones_cpu , values_cpu , func , nodata )
1268- return result .map_blocks (
1269- cupy .asarray , dtype = result .dtype , meta = cupy .array (()),
1270- )
1305+ # Try GPU: build kernel once, reuse across all chunks
1306+ try :
1307+ kernel = _make_apply_kernel (func )
1308+ gpu_ok = True
1309+ except Exception :
1310+ gpu_ok = False
1311+
1312+ if gpu_ok :
1313+ def _chunk_fn (zones_chunk , values_chunk ):
1314+ try :
1315+ return _apply_cupy_gpu (zones_chunk , values_chunk , kernel , nodata )
1316+ except Exception :
1317+ result_np = _apply_numpy (
1318+ zones_chunk .get (), values_chunk .get (), func , nodata ,
1319+ )
1320+ return cupy .asarray (result_np )
1321+ else :
1322+ def _chunk_fn (zones_chunk , values_chunk ):
1323+ result_np = _apply_numpy (
1324+ zones_chunk .get (), values_chunk .get (), func , nodata ,
1325+ )
1326+ return cupy .asarray (result_np )
1327+
1328+ if values_data .ndim == 2 :
1329+ return da .map_blocks (
1330+ _chunk_fn , zones_data , values_data ,
1331+ dtype = values_data .dtype , meta = cupy .array (()),
1332+ )
1333+ else :
1334+ layers = []
1335+ for k in range (values_data .shape [2 ]):
1336+ layer = values_data [:, :, k ].rechunk (zones_data .chunks )
1337+ layers .append (da .map_blocks (
1338+ _chunk_fn , zones_data , layer ,
1339+ dtype = values_data .dtype , meta = cupy .array (()),
1340+ ))
1341+ return da .stack (layers , axis = 2 )
12711342
12721343
12731344def apply (
@@ -1783,6 +1854,35 @@ def _trim(data, excludes):
17831854 return top , bottom , left , right
17841855
17851856
1857+ def _trim_bounds_dask (data , excludes ):
1858+ """Find trim bounds using lazy dask reductions (O(rows+cols) memory)."""
1859+ excluded = da .zeros_like (data , dtype = bool )
1860+ for v in excludes :
1861+ if isinstance (v , float ) and np .isnan (v ):
1862+ excluded = excluded | da .isnan (data )
1863+ else :
1864+ excluded = excluded | (data == v )
1865+
1866+ all_excl_rows = excluded .all (axis = 1 )
1867+ all_excl_cols = excluded .all (axis = 0 )
1868+ row_mask , col_mask = dask .compute (all_excl_rows , all_excl_cols )
1869+
1870+ # dask+cupy computes to cupy arrays; move to numpy for np.where
1871+ if is_cupy_array (row_mask ):
1872+ row_mask = row_mask .get ()
1873+ if is_cupy_array (col_mask ):
1874+ col_mask = col_mask .get ()
1875+
1876+ data_rows = np .where (~ np .asarray (row_mask ))[0 ]
1877+ data_cols = np .where (~ np .asarray (col_mask ))[0 ]
1878+
1879+ if len (data_rows ) == 0 or len (data_cols ) == 0 :
1880+ return 0 , - 1 , 0 , - 1 # empty slice
1881+
1882+ return (int (data_rows [0 ]), int (data_rows [- 1 ]),
1883+ int (data_cols [0 ]), int (data_cols [- 1 ]))
1884+
1885+
17861886def trim (
17871887 raster : xr .DataArray ,
17881888 values : Union [list , tuple ] = (np .nan ,),
@@ -1891,15 +1991,13 @@ def trim(
18911991 _validate_raster (raster , func_name = 'trim' , name = 'raster' , ndim = 2 )
18921992
18931993 data = raster .data
1894- # _trim needs element access; materialise to numpy for non-numpy backends
1895- if is_cupy_array (data ):
1896- data = data .get ()
1897- elif has_dask_array () and isinstance (data , da .Array ):
1898- data = data .compute ()
1994+ if has_dask_array () and isinstance (data , da .Array ):
1995+ top , bottom , left , right = _trim_bounds_dask (data , values )
1996+ else :
18991997 if is_cupy_array (data ):
19001998 data = data .get ()
1999+ top , bottom , left , right = _trim (data , values )
19012000
1902- top , bottom , left , right = _trim (data , values )
19032001 arr = raster [top : bottom + 1 , left : right + 1 ]
19042002 arr .name = name
19052003 return arr
@@ -2003,6 +2101,32 @@ def _crop(data, values):
20032101 return top , bottom , left , right
20042102
20052103
2104+ def _crop_bounds_dask (data , target_values ):
2105+ """Find crop bounds using lazy dask reductions (O(rows+cols) memory)."""
2106+ matched = da .zeros_like (data , dtype = bool )
2107+ for v in target_values :
2108+ matched = matched | (data == v )
2109+
2110+ any_match_rows = matched .any (axis = 1 )
2111+ any_match_cols = matched .any (axis = 0 )
2112+ row_mask , col_mask = dask .compute (any_match_rows , any_match_cols )
2113+
2114+ # dask+cupy computes to cupy arrays; move to numpy for np.where
2115+ if is_cupy_array (row_mask ):
2116+ row_mask = row_mask .get ()
2117+ if is_cupy_array (col_mask ):
2118+ col_mask = col_mask .get ()
2119+
2120+ match_rows = np .where (np .asarray (row_mask ))[0 ]
2121+ match_cols = np .where (np .asarray (col_mask ))[0 ]
2122+
2123+ if len (match_rows ) == 0 or len (match_cols ) == 0 :
2124+ return 0 , data .shape [0 ] - 1 , 0 , data .shape [1 ] - 1
2125+
2126+ return (int (match_rows [0 ]), int (match_rows [- 1 ]),
2127+ int (match_cols [0 ]), int (match_cols [- 1 ]))
2128+
2129+
20062130def crop (
20072131 zones : xr .DataArray ,
20082132 values : xr .DataArray ,
@@ -2123,15 +2247,13 @@ def crop(
21232247 _validate_raster (values , func_name = 'crop' , name = 'values' , ndim = 2 )
21242248
21252249 data = zones .data
2126- # _crop is @ngjit; materialise to numpy for non-numpy backends
2127- if is_cupy_array (data ):
2128- data = data .get ()
2129- elif has_dask_array () and isinstance (data , da .Array ):
2130- data = data .compute ()
2250+ if has_dask_array () and isinstance (data , da .Array ):
2251+ top , bottom , left , right = _crop_bounds_dask (data , zones_ids )
2252+ else :
21312253 if is_cupy_array (data ):
21322254 data = data .get ()
2255+ top , bottom , left , right = _crop (data , zones_ids )
21332256
2134- top , bottom , left , right = _crop (data , zones_ids )
21352257 arr = values [top : bottom + 1 , left : right + 1 ]
21362258 arr .name = name
21372259 return arr
0 commit comments