1010import xarray as xr
1111
1212from xrspatial .cost_distance import cost_distance
13+ from xrspatial .tests .general_checks import cuda_and_cupy_available
14+ from xrspatial .utils import has_cuda_and_cupy , has_dask_array
1315
1416
1517def _make_raster (data , backend = 'numpy' , chunks = (3 , 3 )):
16- """Build a DataArray with y/x coords, optionally dask-backed."""
18+ """Build a DataArray with y/x coords, optionally dask/cupy -backed."""
1719 h , w = data .shape
1820 raster = xr .DataArray (
1921 data .astype (np .float64 ),
@@ -24,13 +26,24 @@ def _make_raster(data, backend='numpy', chunks=(3, 3)):
2426 raster ['x' ] = np .arange (w , dtype = np .float64 )
2527 if 'dask' in backend and da is not None :
2628 raster .data = da .from_array (raster .data , chunks = chunks )
29+ if 'cupy' in backend and has_cuda_and_cupy ():
30+ import cupy
31+ if isinstance (raster .data , da .Array ):
32+ raster .data = raster .data .map_blocks (cupy .asarray )
33+ else :
34+ raster .data = cupy .asarray (raster .data )
2735 return raster
2836
2937
3038def _compute (arr ):
31- """Extract numpy data from DataArray (works for numpy or dask )."""
39+ """Extract numpy data from DataArray (works for numpy, dask, or cupy )."""
3240 if da is not None and isinstance (arr .data , da .Array ):
33- return arr .values
41+ val = arr .data .compute ()
42+ if hasattr (val , 'get' ):
43+ return val .get ()
44+ return val
45+ if hasattr (arr .data , 'get' ):
46+ return arr .data .get ()
3447 return arr .data
3548
3649
@@ -400,3 +413,91 @@ def test_source_on_impassable_cell(backend):
400413
401414 # Everything should be NaN — the only source is on impassable terrain
402415 assert np .all (np .isnan (out ))
416+
417+
418+ # -----------------------------------------------------------------------
419+ # CuPy GPU spill-to-CPU tests
420+ # -----------------------------------------------------------------------
421+
422+ @cuda_and_cupy_available
423+ def test_cupy_matches_numpy ():
424+ """CuPy (CPU fallback) path should produce identical results to numpy."""
425+ np .random .seed (42 )
426+ source = np .zeros ((7 , 7 ))
427+ source [3 , 3 ] = 1.0
428+
429+ friction_data = np .random .uniform (0.5 , 5.0 , (7 , 7 ))
430+
431+ result_np = _compute (cost_distance (
432+ _make_raster (source , backend = 'numpy' ),
433+ _make_raster (friction_data , backend = 'numpy' ),
434+ ))
435+ result_cupy = _compute (cost_distance (
436+ _make_raster (source , backend = 'cupy' ),
437+ _make_raster (friction_data , backend = 'cupy' ),
438+ ))
439+
440+ np .testing .assert_allclose (result_cupy , result_np , equal_nan = True , atol = 1e-5 )
441+
442+
443+ @cuda_and_cupy_available
444+ def test_cupy_max_cost ():
445+ """CuPy path respects max_cost truncation."""
446+ source = np .zeros ((1 , 10 ))
447+ source [0 , 0 ] = 1.0
448+ friction_data = np .ones ((1 , 10 ))
449+
450+ result = _compute (cost_distance (
451+ _make_raster (source , backend = 'cupy' ),
452+ _make_raster (friction_data , backend = 'cupy' ),
453+ max_cost = 3.5 ,
454+ ))
455+
456+ np .testing .assert_allclose (result [0 , 3 ], 3.0 , atol = 1e-5 )
457+ assert np .isnan (result [0 , 4 ])
458+
459+
460+ @cuda_and_cupy_available
461+ def test_cupy_returns_cupy_array ():
462+ """Result should be CuPy-backed when input is CuPy-backed."""
463+ import cupy
464+ from xrspatial .utils import is_cupy_array
465+
466+ source = np .zeros ((3 , 3 ))
467+ source [1 , 1 ] = 1.0
468+ friction_data = np .ones ((3 , 3 ))
469+
470+ result = cost_distance (
471+ _make_raster (source , backend = 'cupy' ),
472+ _make_raster (friction_data , backend = 'cupy' ),
473+ )
474+ assert is_cupy_array (result .data )
475+
476+
477+ # -----------------------------------------------------------------------
478+ # Dask + CuPy GPU spill-to-CPU tests
479+ # -----------------------------------------------------------------------
480+
481+ @cuda_and_cupy_available
482+ @pytest .mark .skipif (not has_dask_array (), reason = "Requires dask.Array" )
483+ def test_dask_cupy_matches_numpy ():
484+ """Dask+CuPy (CPU fallback) should produce identical results to numpy."""
485+ np .random .seed (42 )
486+ source = np .zeros ((10 , 12 ))
487+ source [2 , 3 ] = 1.0
488+ source [7 , 9 ] = 2.0
489+
490+ friction_data = np .random .uniform (0.5 , 5.0 , (10 , 12 ))
491+
492+ result_np = _compute (cost_distance (
493+ _make_raster (source , backend = 'numpy' ),
494+ _make_raster (friction_data , backend = 'numpy' ),
495+ max_cost = 20.0 ,
496+ ))
497+ result_dc = _compute (cost_distance (
498+ _make_raster (source , backend = 'dask+cupy' , chunks = (5 , 6 )),
499+ _make_raster (friction_data , backend = 'dask+cupy' , chunks = (5 , 6 )),
500+ max_cost = 20.0 ,
501+ ))
502+
503+ np .testing .assert_allclose (result_dc , result_np , equal_nan = True , atol = 1e-5 )
0 commit comments