Skip to content

Commit 521e9f8

Browse files
authored
fixed cupy failures in balanced_allocation and corridor (#985)
1 parent 001db21 commit 521e9f8

File tree

2 files changed

+36
-6
lines changed

2 files changed

+36
-6
lines changed

xrspatial/balanced_allocation.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,13 @@ def _to_numpy(arr):
4444
return np.asarray(arr)
4545

4646

47+
def _as_numpy(arr):
48+
"""Convert a computed array (numpy or cupy) to numpy."""
49+
if hasattr(arr, 'get'):
50+
return arr.get()
51+
return np.asarray(arr)
52+
53+
4754
def _extract_sources(raster, target_values):
4855
"""Return sorted array of unique source IDs from the raster."""
4956
data = _to_numpy(raster.data)
@@ -116,7 +123,7 @@ def _allocate_from_costs(cost_stack, source_ids, fill_value=np.nan):
116123
# Replace NaN with inf for argmin
117124
stacked_clean = da.where(da.isnan(stacked), np.inf, stacked)
118125
best_idx = da.argmin(stacked_clean, axis=0).compute()
119-
best_idx = np.asarray(best_idx)
126+
best_idx = _as_numpy(best_idx)
120127
elif hasattr(first, 'get'): # cupy
121128
import cupy as cp
122129
stacked = cp.stack(cost_stack, axis=0)
@@ -133,7 +140,7 @@ def _allocate_from_costs(cost_stack, source_ids, fill_value=np.nan):
133140
# Mark cells that are unreachable from all sources
134141
if da is not None and isinstance(first, da.Array):
135142
all_nan = da.all(da.isnan(da.stack(cost_stack, axis=0)), axis=0)
136-
all_nan = np.asarray(all_nan.compute())
143+
all_nan = _as_numpy(all_nan.compute())
137144
elif hasattr(first, 'get'):
138145
import cupy as cp
139146
all_nan = cp.asnumpy(
@@ -164,7 +171,7 @@ def _allocate_biased(cost_stack, biases, source_ids, fill_value=np.nan):
164171
layers.append(layer)
165172
stacked = da.stack(layers, axis=0)
166173
best_idx = da.argmin(stacked, axis=0).compute()
167-
best_idx = np.asarray(best_idx)
174+
best_idx = _as_numpy(best_idx)
168175
elif hasattr(first, 'get'):
169176
import cupy as cp
170177
layers = []
@@ -188,7 +195,7 @@ def _allocate_biased(cost_stack, biases, source_ids, fill_value=np.nan):
188195
# Mark unreachable cells
189196
if da is not None and isinstance(first, da.Array):
190197
all_nan = da.all(da.isnan(da.stack(cost_stack, axis=0)), axis=0)
191-
all_nan = np.asarray(all_nan.compute())
198+
all_nan = _as_numpy(all_nan.compute())
192199
elif hasattr(first, 'get'):
193200
import cupy as cp
194201
all_nan = cp.asnumpy(

xrspatial/corridor.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,16 @@
2828
from xrspatial.utils import _validate_raster
2929

3030

31+
def _scalar_to_float(da_scalar):
32+
"""Extract a Python float from a scalar DataArray (numpy/cupy/dask)."""
33+
data = da_scalar.data
34+
if hasattr(data, 'compute'):
35+
data = data.compute()
36+
if hasattr(data, 'get'):
37+
data = data.get()
38+
return float(data)
39+
40+
3141
def _compute_corridor(
3242
cd_a: xr.DataArray,
3343
cd_b: xr.DataArray,
@@ -36,7 +46,7 @@ def _compute_corridor(
3646
) -> xr.DataArray:
3747
"""Sum two cost-distance surfaces, normalize, and optionally threshold."""
3848
corridor = cd_a + cd_b
39-
corridor_min = float(corridor.min())
49+
corridor_min = _scalar_to_float(corridor.min())
4050

4151
if not np.isfinite(corridor_min):
4252
# Sources are mutually unreachable -- return all-NaN.
@@ -49,7 +59,20 @@ def _compute_corridor(
4959
cutoff = threshold * corridor_min
5060
else:
5161
cutoff = threshold
52-
normalized = normalized.where(normalized <= cutoff)
62+
data = normalized.data
63+
try:
64+
import dask.array as _da
65+
if isinstance(data, _da.Array):
66+
data = _da.where(data <= cutoff, data, np.nan)
67+
else:
68+
raise ImportError
69+
except ImportError:
70+
if hasattr(data, 'get'): # cupy
71+
import cupy as cp
72+
data = cp.where(data <= cutoff, data, cp.nan)
73+
else:
74+
data = np.where(data <= cutoff, data, np.nan)
75+
normalized = normalized.copy(data=data)
5376

5477
return normalized
5578

0 commit comments

Comments
 (0)