88import xarray
99
1010from .gpu_rtx import has_rtx
11- from .utils import has_cuda_and_cupy , has_dask_array , is_cupy_array , ngjit
11+ from .utils import (has_cuda_and_cupy , has_dask_array , is_cupy_array ,
12+ is_cupy_backed , is_dask_cupy , ngjit )
1213
1314E_ROW_ID = 0
1415E_COL_ID = 1
@@ -1688,6 +1689,11 @@ def viewshed(raster: xarray.DataArray,
16881689
16891690 """
16901691
1692+ # --- max_distance: extract spatial window for any backend ---
1693+ if max_distance is not None :
1694+ return _viewshed_windowed (raster , x , y , observer_elev , target_elev ,
1695+ max_distance )
1696+
16911697 if isinstance (raster .data , np .ndarray ):
16921698 return _viewshed_cpu (raster , x , y , observer_elev , target_elev )
16931699
@@ -1705,8 +1711,7 @@ def viewshed(raster: xarray.DataArray,
17051711 elif has_dask_array ():
17061712 import dask .array as da
17071713 if isinstance (raster .data , da .Array ):
1708- return _viewshed_dask (raster , x , y , observer_elev , target_elev ,
1709- max_distance )
1714+ return _viewshed_dask (raster , x , y , observer_elev , target_elev )
17101715
17111716 raise TypeError (f"Unsupported raster array type: { type (raster .data )} " )
17121717
@@ -1964,10 +1969,15 @@ def _viewshed_distance_sweep(dask_data, H, W, obs_r, obs_c,
19641969 return visibility
19651970
19661971
1967- def _viewshed_dask (raster , x , y , observer_elev , target_elev , max_distance ):
1968- """Dask-backed viewshed with three-tier strategy."""
1969- import dask . array as da
1972+ def _viewshed_windowed (raster , x , y , observer_elev , target_elev ,
1973+ max_distance ):
1974+ """Run viewshed on a spatial window around the observer.
19701975
1976+ Works for any backend: numpy, cupy, dask+numpy, dask+cupy. The window
1977+ is extracted via xarray slicing, computed to an in-memory array, then
1978+ dispatched to the appropriate single-array backend. The result is
1979+ embedded in a full-size INVISIBLE output.
1980+ """
19711981 height , width = raster .shape
19721982 y_coords = raster .indexes .get ('y' ).values
19731983 x_coords = raster .indexes .get ('x' ).values
@@ -1988,47 +1998,114 @@ def _viewshed_dask(raster, x, y, observer_elev, target_elev, max_distance):
19881998 x_range = (x_coords [0 ], x_coords [- 1 ])
19891999 ew_res = (x_range [1 ] - x_range [0 ]) / (width - 1 ) if width > 1 else 1.0
19902000 ns_res = (y_range [1 ] - y_range [0 ]) / (height - 1 ) if height > 1 else 1.0
2001+ cell_size = max (abs (ew_res ), abs (ns_res ))
2002+ radius_cells = int (np .ceil (max_distance / cell_size ))
19912003
1992- # --- Tier A: max_distance → extract window, run CPU R2 ---
1993- if max_distance is not None :
1994- cell_size = max (abs (ew_res ), abs (ns_res ))
1995- radius_cells = int (np .ceil (max_distance / cell_size ))
1996-
1997- r_lo = max (0 , obs_r - radius_cells )
1998- r_hi = min (height , obs_r + radius_cells + 1 )
1999- c_lo = max (0 , obs_c - radius_cells )
2000- c_hi = min (width , obs_c + radius_cells + 1 )
2004+ r_lo = max (0 , obs_r - radius_cells )
2005+ r_hi = min (height , obs_r + radius_cells + 1 )
2006+ c_lo = max (0 , obs_c - radius_cells )
2007+ c_hi = min (width , obs_c + radius_cells + 1 )
20012008
2002- window = raster .isel (y = slice (r_lo , r_hi ), x = slice (c_lo , c_hi ))
2003- window_np = window .copy ()
2004- window_np .data = window .data .compute ()
2009+ window = raster .isel (y = slice (r_lo , r_hi ), x = slice (c_lo , c_hi ))
20052010
2011+ # Materialise to in-memory array (numpy or cupy)
2012+ is_cupy = has_cuda_and_cupy () and (
2013+ is_cupy_array (raster .data ) or is_cupy_backed (raster ))
2014+ if has_dask_array ():
2015+ import dask .array as da
2016+ if isinstance (window .data , da .Array ):
2017+ window = window .copy ()
2018+ window .data = window .data .compute ()
2019+
2020+ if is_cupy and has_rtx ():
2021+ import cupy as cp
2022+ if not is_cupy_array (window .data ):
2023+ window .data = cp .asarray (window .data )
2024+ from .gpu_rtx .viewshed import viewshed_gpu
2025+ local_result = viewshed_gpu (
2026+ window , x , y , observer_elev , target_elev )
2027+ else :
2028+ if is_cupy :
2029+ import cupy as cp
2030+ window .data = cp .asnumpy (window .data )
2031+ elif not isinstance (window .data , np .ndarray ):
2032+ window .data = np .asarray (window .data )
20062033 local_result = _viewshed_cpu (
2007- window_np ,
2008- x = x , y = y ,
2009- observer_elev = observer_elev ,
2010- target_elev = target_elev ,
2011- )
2034+ window , x , y , observer_elev , target_elev )
20122035
2013- # Embed in full-size INVISIBLE array
2036+ # Embed in full-size INVISIBLE output, preserving array type
2037+ if is_cupy and has_rtx ():
2038+ import cupy as cp
2039+ full_vis = cp .full ((height , width ), INVISIBLE , dtype = np .float64 )
2040+ full_vis [r_lo :r_hi , c_lo :c_hi ] = local_result .data
2041+ else :
2042+ local_vals = local_result .values
20142043 full_vis = np .full ((height , width ), INVISIBLE , dtype = np .float64 )
2015- full_vis [r_lo :r_hi , c_lo :c_hi ] = local_result .values
2016- vis_da = da .from_array (full_vis , chunks = raster .data .chunks )
2017- return xarray .DataArray (vis_da , coords = raster .coords ,
2018- dims = raster .dims , attrs = raster .attrs )
2044+ full_vis [r_lo :r_hi , c_lo :c_hi ] = local_vals
2045+
2046+ # Wrap in the same array type as the input
2047+ if has_dask_array () and isinstance (raster .data , da .Array ):
2048+ full_vis = da .from_array (full_vis , chunks = raster .data .chunks )
2049+
2050+ return xarray .DataArray (full_vis , coords = raster .coords ,
2051+ dims = raster .dims , attrs = raster .attrs )
20192052
2020- # --- Tier B: full grid fits in memory → compute and run CPU R2 ---
2053+
2054+ def _viewshed_dask (raster , x , y , observer_elev , target_elev ):
2055+ """Dask-backed viewshed (no max_distance — handled by caller).
2056+
2057+ Two-tier strategy:
2058+ - Tier B: grid fits in memory → compute and run exact R2 (CPU or GPU).
2059+ - Tier C: out-of-core horizon-profile distance sweep.
2060+ """
2061+ import dask .array as da
2062+
2063+ height , width = raster .shape
2064+ y_coords = raster .indexes .get ('y' ).values
2065+ x_coords = raster .indexes .get ('x' ).values
2066+
2067+ if not (x_coords .min () <= x <= x_coords .max ()):
2068+ raise ValueError ("x argument outside of raster x_range" )
2069+ if not (y_coords .min () <= y <= y_coords .max ()):
2070+ raise ValueError ("y argument outside of raster y_range" )
2071+
2072+ selection = raster .sel (x = [x ], y = [y ], method = 'nearest' )
2073+ x = selection .x .values [0 ]
2074+ y = selection .y .values [0 ]
2075+
2076+ obs_r = int (np .where (y_coords == y )[0 ][0 ])
2077+ obs_c = int (np .where (x_coords == x )[0 ][0 ])
2078+
2079+ y_range = (y_coords [0 ], y_coords [- 1 ])
2080+ x_range = (x_coords [0 ], x_coords [- 1 ])
2081+ ew_res = (x_range [1 ] - x_range [0 ]) / (width - 1 ) if width > 1 else 1.0
2082+ ns_res = (y_range [1 ] - y_range [0 ]) / (height - 1 ) if height > 1 else 1.0
2083+
2084+ cupy_backed = is_dask_cupy (raster )
2085+
2086+ # --- Tier B: full grid fits in memory → compute and run exact algo ---
20212087 r2_bytes = 280 * height * width
20222088 avail = _available_memory_bytes ()
20232089 if r2_bytes < 0.5 * avail :
2024- raster_np = raster .copy ()
2025- raster_np .data = raster .data .compute ()
2026- result = _viewshed_cpu (raster_np , x , y , observer_elev , target_elev )
2027- vis_da = da .from_array (result .values , chunks = raster .data .chunks )
2090+ raster_mem = raster .copy ()
2091+ raster_mem .data = raster .data .compute ()
2092+ if cupy_backed and has_rtx ():
2093+ from .gpu_rtx .viewshed import viewshed_gpu
2094+ result = viewshed_gpu (raster_mem , x , y ,
2095+ observer_elev , target_elev )
2096+ else :
2097+ if cupy_backed :
2098+ import cupy as cp
2099+ raster_mem .data = cp .asnumpy (raster_mem .data )
2100+ result = _viewshed_cpu (raster_mem , x , y ,
2101+ observer_elev , target_elev )
2102+ result_np = result .values if isinstance (result .data , np .ndarray ) \
2103+ else result .data .get ()
2104+ vis_da = da .from_array (result_np , chunks = raster .data .chunks )
20282105 return xarray .DataArray (vis_da , coords = raster .coords ,
20292106 dims = raster .dims , attrs = raster .attrs )
20302107
2031- # --- Tier C: out-of-core distance sweep ---
2108+ # --- Tier C: out-of-core distance sweep (CPU only) ---
20322109 output_bytes = height * width * 8
20332110 if output_bytes > 0.8 * avail :
20342111 raise MemoryError (
@@ -2037,22 +2114,29 @@ def _viewshed_dask(raster, x, y, observer_elev, target_elev, max_distance):
20372114 f"Use max_distance to limit the analysis area."
20382115 )
20392116
2040- obs_elev_val = raster .data .blocks [
2041- _chunk_index_for (_chunk_offsets (raster .data .chunks [0 ]), obs_r ),
2042- _chunk_index_for (_chunk_offsets (raster .data .chunks [1 ]), obs_c ),
2117+ # For dask+cupy, chunks compute to cupy arrays — cache needs numpy
2118+ dask_data = raster .data
2119+ if cupy_backed :
2120+ dask_data = dask_data .map_blocks (
2121+ lambda block : block .get (), dtype = np .float64 ,
2122+ meta = np .array (()))
2123+
2124+ obs_elev_val = dask_data .blocks [
2125+ _chunk_index_for (_chunk_offsets (dask_data .chunks [0 ]), obs_r ),
2126+ _chunk_index_for (_chunk_offsets (dask_data .chunks [1 ]), obs_c ),
20432127 ].compute ()
2044- local_r = obs_r - int (_chunk_offsets (raster . data .chunks [0 ])[
2045- _chunk_index_for (_chunk_offsets (raster . data .chunks [0 ]), obs_r )])
2046- local_c = obs_c - int (_chunk_offsets (raster . data .chunks [1 ])[
2047- _chunk_index_for (_chunk_offsets (raster . data .chunks [1 ]), obs_c )])
2128+ local_r = obs_r - int (_chunk_offsets (dask_data .chunks [0 ])[
2129+ _chunk_index_for (_chunk_offsets (dask_data .chunks [0 ]), obs_r )])
2130+ local_c = obs_c - int (_chunk_offsets (dask_data .chunks [1 ])[
2131+ _chunk_index_for (_chunk_offsets (dask_data .chunks [1 ]), obs_c )])
20482132 terrain_elev = float (obs_elev_val [local_r , local_c ])
20492133 vp_elev = terrain_elev + observer_elev
20502134
20512135 visibility = _viewshed_distance_sweep (
2052- raster . data , height , width , obs_r , obs_c ,
2136+ dask_data , height , width , obs_r , obs_c ,
20532137 vp_elev , target_elev , ew_res , ns_res ,
2054- raster . data . chunks [0 ], raster . data .chunks [1 ],
2055- max_distance ,
2138+ dask_data . chunks [0 ], dask_data .chunks [1 ],
2139+ None ,
20562140 )
20572141
20582142 vis_da = da .from_array (visibility , chunks = raster .data .chunks )
0 commit comments