@@ -44,6 +44,13 @@ def _to_numpy(arr):
4444 return np .asarray (arr )
4545
4646
47+ def _as_numpy (arr ):
48+ """Convert a computed array (numpy or cupy) to numpy."""
49+ if hasattr (arr , 'get' ):
50+ return arr .get ()
51+ return np .asarray (arr )
52+
53+
4754def _extract_sources (raster , target_values ):
4855 """Return sorted array of unique source IDs from the raster."""
4956 data = _to_numpy (raster .data )
@@ -116,7 +123,7 @@ def _allocate_from_costs(cost_stack, source_ids, fill_value=np.nan):
116123 # Replace NaN with inf for argmin
117124 stacked_clean = da .where (da .isnan (stacked ), np .inf , stacked )
118125 best_idx = da .argmin (stacked_clean , axis = 0 ).compute ()
119- best_idx = np . asarray (best_idx )
126+ best_idx = _as_numpy (best_idx )
120127 elif hasattr (first , 'get' ): # cupy
121128 import cupy as cp
122129 stacked = cp .stack (cost_stack , axis = 0 )
@@ -133,7 +140,7 @@ def _allocate_from_costs(cost_stack, source_ids, fill_value=np.nan):
133140 # Mark cells that are unreachable from all sources
134141 if da is not None and isinstance (first , da .Array ):
135142 all_nan = da .all (da .isnan (da .stack (cost_stack , axis = 0 )), axis = 0 )
136- all_nan = np . asarray (all_nan .compute ())
143+ all_nan = _as_numpy (all_nan .compute ())
137144 elif hasattr (first , 'get' ):
138145 import cupy as cp
139146 all_nan = cp .asnumpy (
@@ -164,7 +171,7 @@ def _allocate_biased(cost_stack, biases, source_ids, fill_value=np.nan):
164171 layers .append (layer )
165172 stacked = da .stack (layers , axis = 0 )
166173 best_idx = da .argmin (stacked , axis = 0 ).compute ()
167- best_idx = np . asarray (best_idx )
174+ best_idx = _as_numpy (best_idx )
168175 elif hasattr (first , 'get' ):
169176 import cupy as cp
170177 layers = []
@@ -188,7 +195,7 @@ def _allocate_biased(cost_stack, biases, source_ids, fill_value=np.nan):
188195 # Mark unreachable cells
189196 if da is not None and isinstance (first , da .Array ):
190197 all_nan = da .all (da .isnan (da .stack (cost_stack , axis = 0 )), axis = 0 )
191- all_nan = np . asarray (all_nan .compute ())
198+ all_nan = _as_numpy (all_nan .compute ())
192199 elif hasattr (first , 'get' ):
193200 import cupy as cp
194201 all_nan = cp .asnumpy (
0 commit comments