Skip to content

Commit 4a860bc

Browse files
committed
Add tests for least_cost_corridor (#965)
Covers symmetry, optimal path minimum, absolute and relative thresholding, precomputed surfaces, multi-source pairwise, barriers, unreachable sources, single-cell input, and input validation. All parametrized across numpy and dask+numpy backends.
1 parent d21ef13 commit 4a860bc

File tree

1 file changed

+376
-0
lines changed

1 file changed

+376
-0
lines changed

xrspatial/tests/test_corridor.py

Lines changed: 376 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,376 @@
1+
"""Tests for xrspatial.corridor.least_cost_corridor."""
2+
3+
try:
4+
import dask.array as da
5+
except ImportError:
6+
da = None
7+
8+
import numpy as np
9+
import pytest
10+
import xarray as xr
11+
12+
from xrspatial.corridor import least_cost_corridor
13+
from xrspatial.cost_distance import cost_distance
14+
from xrspatial.utils import has_cuda_and_cupy
15+
16+
17+
def _make_raster(data, backend="numpy", chunks=(3, 3)):
18+
"""Build a DataArray with y/x coords, optionally dask/cupy-backed."""
19+
h, w = data.shape
20+
raster = xr.DataArray(
21+
data.astype(np.float64),
22+
dims=["y", "x"],
23+
attrs={"res": (1.0, 1.0)},
24+
)
25+
raster["y"] = np.arange(h, dtype=np.float64)
26+
raster["x"] = np.arange(w, dtype=np.float64)
27+
if "dask" in backend and da is not None:
28+
raster.data = da.from_array(raster.data, chunks=chunks)
29+
if "cupy" in backend and has_cuda_and_cupy():
30+
import cupy
31+
32+
if isinstance(raster.data, da.Array):
33+
raster.data = raster.data.map_blocks(cupy.asarray)
34+
else:
35+
raster.data = cupy.asarray(raster.data)
36+
return raster
37+
38+
39+
def _compute(arr):
40+
"""Extract numpy data from DataArray (works for numpy, dask, or cupy)."""
41+
if da is not None and isinstance(arr.data, da.Array):
42+
val = arr.data.compute()
43+
if hasattr(val, "get"):
44+
return val.get()
45+
return val
46+
if hasattr(arr.data, "get"):
47+
return arr.data.get()
48+
return arr.data
49+
50+
51+
# -----------------------------------------------------------------------
52+
# Basic corridor correctness
53+
# -----------------------------------------------------------------------
54+
55+
56+
@pytest.mark.parametrize(
57+
"backend", ["numpy", "dask+numpy", "cupy", "dask+cupy"]
58+
)
59+
def test_basic_corridor_symmetry(backend):
60+
"""Corridor between two sources on uniform friction is symmetric."""
61+
n = 7
62+
friction_data = np.ones((n, n))
63+
64+
src_a = np.zeros((n, n))
65+
src_a[3, 0] = 1.0 # left edge
66+
67+
src_b = np.zeros((n, n))
68+
src_b[3, 6] = 1.0 # right edge
69+
70+
friction = _make_raster(friction_data, backend=backend, chunks=(7, 7))
71+
sa = _make_raster(src_a, backend=backend, chunks=(7, 7))
72+
sb = _make_raster(src_b, backend=backend, chunks=(7, 7))
73+
74+
result = least_cost_corridor(friction, sa, sb)
75+
out = _compute(result)
76+
77+
# Minimum corridor cost should be 0 (after normalization)
78+
assert np.nanmin(out) == pytest.approx(0.0, abs=1e-5)
79+
80+
# Corridor should be symmetric about the vertical midline
81+
np.testing.assert_allclose(out[:, :3], out[:, -1:-4:-1], atol=1e-5)
82+
83+
84+
@pytest.mark.parametrize(
85+
"backend", ["numpy", "dask+numpy", "cupy", "dask+cupy"]
86+
)
87+
def test_corridor_minimum_on_optimal_path(backend):
88+
"""Cells on the optimal path between sources have corridor value 0."""
89+
n = 5
90+
friction_data = np.ones((n, n))
91+
92+
src_a = np.zeros((n, n))
93+
src_a[2, 0] = 1.0
94+
95+
src_b = np.zeros((n, n))
96+
src_b[2, 4] = 1.0
97+
98+
friction = _make_raster(friction_data, backend=backend, chunks=(5, 5))
99+
sa = _make_raster(src_a, backend=backend, chunks=(5, 5))
100+
sb = _make_raster(src_b, backend=backend, chunks=(5, 5))
101+
102+
result = least_cost_corridor(friction, sa, sb)
103+
out = _compute(result)
104+
105+
# The middle row (row 2) should be the optimal path on uniform friction.
106+
# All cells on row 2 should have the minimum corridor value (0).
107+
for col in range(n):
108+
assert out[2, col] == pytest.approx(0.0, abs=1e-5)
109+
110+
111+
# -----------------------------------------------------------------------
112+
# Threshold tests
113+
# -----------------------------------------------------------------------
114+
115+
116+
@pytest.mark.parametrize(
117+
"backend", ["numpy", "dask+numpy", "cupy", "dask+cupy"]
118+
)
119+
def test_absolute_threshold(backend):
120+
"""Absolute threshold masks cells with normalized cost > threshold."""
121+
n = 7
122+
friction_data = np.ones((n, n))
123+
124+
src_a = np.zeros((n, n))
125+
src_a[3, 0] = 1.0
126+
127+
src_b = np.zeros((n, n))
128+
src_b[3, 6] = 1.0
129+
130+
friction = _make_raster(friction_data, backend=backend, chunks=(7, 7))
131+
sa = _make_raster(src_a, backend=backend, chunks=(7, 7))
132+
sb = _make_raster(src_b, backend=backend, chunks=(7, 7))
133+
134+
result = least_cost_corridor(friction, sa, sb, threshold=0.5)
135+
out = _compute(result)
136+
137+
# Cells with normalized cost > 0.5 should be NaN
138+
assert np.all(np.isnan(out) | (out <= 0.5 + 1e-5))
139+
140+
# The optimal path (row 3) should not be masked
141+
for col in range(n):
142+
assert np.isfinite(out[3, col])
143+
144+
145+
@pytest.mark.parametrize(
146+
"backend", ["numpy", "dask+numpy", "cupy", "dask+cupy"]
147+
)
148+
def test_relative_threshold(backend):
149+
"""Relative threshold uses fraction of minimum corridor cost."""
150+
n = 7
151+
friction_data = np.ones((n, n))
152+
153+
src_a = np.zeros((n, n))
154+
src_a[3, 0] = 1.0
155+
156+
src_b = np.zeros((n, n))
157+
src_b[3, 6] = 1.0
158+
159+
friction = _make_raster(friction_data, backend=backend, chunks=(7, 7))
160+
sa = _make_raster(src_a, backend=backend, chunks=(7, 7))
161+
sb = _make_raster(src_b, backend=backend, chunks=(7, 7))
162+
163+
# No threshold -- get full corridor
164+
full = least_cost_corridor(friction, sa, sb)
165+
full_out = _compute(full)
166+
167+
# Relative threshold of 50%
168+
result = least_cost_corridor(
169+
friction, sa, sb, threshold=0.5, relative=True
170+
)
171+
out = _compute(result)
172+
173+
# Count finite cells -- threshold version should have fewer
174+
assert np.sum(np.isfinite(out)) < np.sum(np.isfinite(full_out))
175+
176+
# Optimal path cells should survive
177+
for col in range(n):
178+
assert np.isfinite(out[3, col])
179+
180+
181+
# -----------------------------------------------------------------------
182+
# Precomputed cost-distance surfaces
183+
# -----------------------------------------------------------------------
184+
185+
186+
def test_precomputed_matches_regular():
187+
"""Precomputed=True with manual cost_distance matches default path."""
188+
n = 7
189+
friction_data = np.ones((n, n))
190+
191+
src_a = np.zeros((n, n))
192+
src_a[3, 0] = 1.0
193+
194+
src_b = np.zeros((n, n))
195+
src_b[3, 6] = 1.0
196+
197+
friction = _make_raster(friction_data)
198+
sa = _make_raster(src_a)
199+
sb = _make_raster(src_b)
200+
201+
# Regular path
202+
result_regular = least_cost_corridor(friction, sa, sb)
203+
204+
# Precomputed path
205+
cd_a = cost_distance(sa, friction)
206+
cd_b = cost_distance(sb, friction)
207+
result_precomputed = least_cost_corridor(
208+
friction, cd_a, cd_b, precomputed=True
209+
)
210+
211+
np.testing.assert_allclose(
212+
_compute(result_regular),
213+
_compute(result_precomputed),
214+
atol=1e-5,
215+
)
216+
217+
218+
# -----------------------------------------------------------------------
219+
# Multi-source pairwise
220+
# -----------------------------------------------------------------------
221+
222+
223+
def test_pairwise_corridor():
224+
"""Pairwise mode with 3 sources returns Dataset with 3 corridors."""
225+
n = 7
226+
friction_data = np.ones((n, n))
227+
228+
sources = []
229+
for r, c in [(0, 0), (0, 6), (6, 3)]:
230+
s = np.zeros((n, n))
231+
s[r, c] = 1.0
232+
sources.append(_make_raster(s))
233+
234+
friction = _make_raster(friction_data)
235+
236+
result = least_cost_corridor(
237+
friction, sources=sources, pairwise=True
238+
)
239+
240+
assert isinstance(result, xr.Dataset)
241+
assert set(result.data_vars) == {
242+
"corridor_0_1",
243+
"corridor_0_2",
244+
"corridor_1_2",
245+
}
246+
247+
# Each corridor should have minimum 0
248+
for name in result.data_vars:
249+
out = _compute(result[name])
250+
assert np.nanmin(out) == pytest.approx(0.0, abs=1e-5)
251+
252+
253+
def test_pairwise_two_sources_returns_dataset():
254+
"""Pairwise=True with exactly 2 sources still returns a Dataset."""
255+
n = 5
256+
friction_data = np.ones((n, n))
257+
258+
s0 = np.zeros((n, n))
259+
s0[0, 0] = 1.0
260+
s1 = np.zeros((n, n))
261+
s1[4, 4] = 1.0
262+
263+
friction = _make_raster(friction_data)
264+
result = least_cost_corridor(
265+
friction,
266+
sources=[_make_raster(s0), _make_raster(s1)],
267+
pairwise=True,
268+
)
269+
270+
assert isinstance(result, xr.Dataset)
271+
assert "corridor_0_1" in result.data_vars
272+
273+
274+
# -----------------------------------------------------------------------
275+
# NaN / barrier handling
276+
# -----------------------------------------------------------------------
277+
278+
279+
@pytest.mark.parametrize(
280+
"backend", ["numpy", "dask+numpy", "cupy", "dask+cupy"]
281+
)
282+
def test_barrier_blocks_corridor(backend):
283+
"""NaN barrier between sources makes certain cells unreachable."""
284+
n = 7
285+
friction_data = np.ones((n, n))
286+
# Wall of NaN except a gap at row 3
287+
friction_data[:3, 3] = np.nan
288+
friction_data[4:, 3] = np.nan
289+
290+
src_a = np.zeros((n, n))
291+
src_a[3, 0] = 1.0
292+
293+
src_b = np.zeros((n, n))
294+
src_b[3, 6] = 1.0
295+
296+
friction = _make_raster(friction_data, backend=backend, chunks=(7, 7))
297+
sa = _make_raster(src_a, backend=backend, chunks=(7, 7))
298+
sb = _make_raster(src_b, backend=backend, chunks=(7, 7))
299+
300+
result = least_cost_corridor(friction, sa, sb)
301+
out = _compute(result)
302+
303+
# The gap row should still be reachable
304+
assert np.isfinite(out[3, 3])
305+
306+
307+
@pytest.mark.parametrize(
308+
"backend", ["numpy", "dask+numpy", "cupy", "dask+cupy"]
309+
)
310+
def test_unreachable_sources(backend):
311+
"""Full barrier between sources produces all-NaN corridor."""
312+
n = 5
313+
friction_data = np.ones((n, n))
314+
friction_data[:, 2] = np.nan # impenetrable wall
315+
316+
src_a = np.zeros((n, n))
317+
src_a[2, 0] = 1.0
318+
319+
src_b = np.zeros((n, n))
320+
src_b[2, 4] = 1.0
321+
322+
friction = _make_raster(friction_data, backend=backend, chunks=(5, 5))
323+
sa = _make_raster(src_a, backend=backend, chunks=(5, 5))
324+
sb = _make_raster(src_b, backend=backend, chunks=(5, 5))
325+
326+
result = least_cost_corridor(friction, sa, sb)
327+
out = _compute(result)
328+
329+
assert np.all(np.isnan(out))
330+
331+
332+
# -----------------------------------------------------------------------
333+
# Edge cases and validation
334+
# -----------------------------------------------------------------------
335+
336+
337+
def test_single_cell_raster():
338+
"""1x1 raster where both sources are the same cell."""
339+
friction = _make_raster(np.ones((1, 1)))
340+
src = _make_raster(np.ones((1, 1)))
341+
342+
result = least_cost_corridor(friction, src, src)
343+
out = _compute(result)
344+
345+
assert out[0, 0] == pytest.approx(0.0, abs=1e-5)
346+
347+
348+
def test_missing_sources_raises():
349+
"""Omitting both source_a/source_b and sources raises ValueError."""
350+
friction = _make_raster(np.ones((3, 3)))
351+
with pytest.raises(ValueError, match="source_a and source_b are required"):
352+
least_cost_corridor(friction)
353+
354+
355+
def test_both_source_modes_raises():
356+
"""Providing source_a/source_b AND sources raises ValueError."""
357+
friction = _make_raster(np.ones((3, 3)))
358+
src = _make_raster(np.ones((3, 3)))
359+
with pytest.raises(ValueError, match="not both"):
360+
least_cost_corridor(friction, src, src, sources=[src, src])
361+
362+
363+
def test_negative_threshold_raises():
364+
"""Negative threshold raises ValueError."""
365+
friction = _make_raster(np.ones((3, 3)))
366+
src = _make_raster(np.ones((3, 3)))
367+
with pytest.raises(ValueError, match="non-negative"):
368+
least_cost_corridor(friction, src, src, threshold=-1.0)
369+
370+
371+
def test_single_source_in_list_raises():
372+
"""sources with fewer than 2 entries raises ValueError."""
373+
friction = _make_raster(np.ones((3, 3)))
374+
src = _make_raster(np.ones((3, 3)))
375+
with pytest.raises(ValueError, match="at least 2"):
376+
least_cost_corridor(friction, sources=[src])

0 commit comments

Comments
 (0)