Skip to content

Commit a2c9019

Browse files
committed
Speed up sieve labeling and add convergence warning (#1162)
Replace per-value scipy.ndimage.label calls with a single-pass numba union-find for connected-component labeling. Rework the adjacency builder to use 1-D np.unique on encoded int64 pairs instead of the slower np.unique(axis=0) on 2-D arrays. Add a UserWarning when the 50-iteration merge loop doesn't converge.
1 parent 2e310e8 commit a2c9019

File tree

2 files changed

+232
-57
lines changed

2 files changed

+232
-57
lines changed

xrspatial/sieve.py

Lines changed: 152 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from __future__ import annotations
1313

14+
import warnings
1415
from collections import defaultdict
1516
from typing import Sequence
1617

@@ -36,40 +37,161 @@ class cupy:
3637
has_cuda_and_cupy,
3738
is_cupy_array,
3839
is_dask_cupy,
40+
ngjit,
3941
)
4042

43+
_MAX_ITERATIONS = 50
44+
45+
46+
# ---------------------------------------------------------------------------
47+
# Numba union-find labeling
48+
# ---------------------------------------------------------------------------
49+
50+
51+
@ngjit
def _uf_find(parent, x):
    """Return the representative root of element *x*.

    Uses path halving: each visited node is re-pointed at its
    grandparent, flattening the tree for later lookups.
    """
    while True:
        p = parent[x]
        if p == x:
            return x
        # Halve the path, then step to the (new) parent.
        parent[x] = parent[p]
        x = parent[x]
58+
59+
60+
@ngjit
def _uf_union(parent, rank, a, b):
    """Merge the disjoint sets containing *a* and *b* (union by rank)."""
    root_a = _uf_find(parent, a)
    root_b = _uf_find(parent, b)
    if root_a == root_b:
        return
    # Ensure root_a is the deeper (or equal) tree, then attach the
    # shallower root_b beneath it.
    if rank[root_a] < rank[root_b]:
        root_a, root_b = root_b, root_a
    parent[root_b] = root_a
    # Equal ranks: the merged tree grows one level deeper.
    if rank[root_a] == rank[root_b]:
        rank[root_a] += 1
74+
75+
76+
@ngjit
def _label_connected(data, valid, neighborhood):
    """Single-pass connected-component labeling via union-find.

    Labels connected regions of same-value pixels in one O(n) pass,
    replacing the previous approach of calling ``scipy.ndimage.label``
    once per unique raster value.

    Parameters
    ----------
    data : 2D ndarray
        Raster values; pixels compare equal with ``==``.
    valid : 2D ndarray of bool
        False marks nodata pixels, which receive region id 0.
    neighborhood : int
        4 for rook connectivity, 8 to also join diagonals.

    Returns
    -------
    region_map : ndarray of int32 (2D)
        Each pixel mapped to its region id (0 = nodata).
    region_val : ndarray of float64 (1D)
        Original raster value for each region id.
    n_regions : int
        Total number of regions + 1 (length of *region_val*).
    """
    rows = data.shape[0]
    cols = data.shape[1]
    n = rows * cols
    # One union-find slot per pixel, keyed by flat index r*cols + c.
    # int32 assumes n < 2**31 pixels — TODO confirm for very large rasters.
    parent = np.arange(n, dtype=np.int32)
    rank = np.zeros(n, dtype=np.int32)

    # Raster scan: each pixel only unions with neighbors that were
    # already visited (left, up, and the two upper diagonals), so every
    # adjacency is considered exactly once.
    for r in range(rows):
        for c in range(cols):
            if not valid[r, c]:
                continue
            idx = r * cols + c
            val = data[r, c]

            # Check left (already visited)
            if c > 0 and valid[r, c - 1] and data[r, c - 1] == val:
                _uf_union(parent, rank, idx, idx - 1)
            # Check up (already visited)
            if r > 0 and valid[r - 1, c] and data[r - 1, c] == val:
                _uf_union(parent, rank, idx, (r - 1) * cols + c)

            if neighborhood == 8:
                # Upper-left diagonal
                if (
                    r > 0
                    and c > 0
                    and valid[r - 1, c - 1]
                    and data[r - 1, c - 1] == val
                ):
                    _uf_union(parent, rank, idx, (r - 1) * cols + (c - 1))
                # Upper-right diagonal
                if (
                    r > 0
                    and c + 1 < cols
                    and valid[r - 1, c + 1]
                    and data[r - 1, c + 1] == val
                ):
                    _uf_union(parent, rank, idx, (r - 1) * cols + (c + 1))

    # Assign contiguous region IDs
    region_map_flat = np.zeros(n, dtype=np.int32)
    # root_to_id[root] == 0 means "not assigned yet"; real ids start at 1,
    # so 0 is a safe sentinel even for root index 0.
    root_to_id = np.zeros(n, dtype=np.int32)
    # At most n regions, ids 1..n -> buffer needs n+1 slots (slot 0 = nodata).
    region_val_buf = np.full(n + 1, np.nan, dtype=np.float64)
    next_id = 1

    # Second pass: ids are handed out in flat scan order, so labeling is
    # deterministic for a given input.
    for i in range(n):
        r = i // cols
        c = i % cols
        if not valid[r, c]:
            continue
        root = _uf_find(parent, i)
        if root_to_id[root] == 0:
            root_to_id[root] = next_id
            region_val_buf[next_id] = data[r, c]
            next_id += 1
        region_map_flat[i] = root_to_id[root]

    region_map = region_map_flat.reshape(rows, cols)
    return region_map, region_val_buf[:next_id], next_id
149+
41150

42151
# ---------------------------------------------------------------------------
43152
# Adjacency helpers
44153
# ---------------------------------------------------------------------------
45154

46155

47156
def _build_adjacency(region_map, neighborhood):
48-
"""Build a region adjacency dict from a labeled map using vectorized shifts.
157+
"""Build a region adjacency dict from a labeled map.
158+
159+
Encodes each (lo, hi) region pair as a single int64 so
160+
deduplication uses fast 1-D ``np.unique`` instead of the slower
161+
``np.unique(axis=0)`` on 2-D pair arrays.
49162
50163
Returns ``{region_id: set_of_neighbor_ids}``.
51164
"""
52-
adjacency: dict[int, set[int]] = defaultdict(set)
165+
max_id = np.int64(region_map.max()) + 1
166+
encoded_parts: list[np.ndarray] = []
53167

54-
def _add_pairs(a, b):
168+
def _collect(a, b):
55169
mask = (a > 0) & (b > 0) & (a != b)
56170
if not mask.any():
57171
return
58-
pairs = np.unique(
59-
np.column_stack([a[mask].ravel(), b[mask].ravel()]), axis=0
60-
)
61-
for x, y in pairs:
62-
adjacency[int(x)].add(int(y))
63-
adjacency[int(y)].add(int(x))
172+
am = a[mask].ravel().astype(np.int64)
173+
bm = b[mask].ravel().astype(np.int64)
174+
lo = np.minimum(am, bm)
175+
hi = np.maximum(am, bm)
176+
encoded_parts.append(lo * max_id + hi)
177+
178+
_collect(region_map[:-1, :], region_map[1:, :])
179+
_collect(region_map[:, :-1], region_map[:, 1:])
180+
if neighborhood == 8:
181+
_collect(region_map[:-1, :-1], region_map[1:, 1:])
182+
_collect(region_map[:-1, 1:], region_map[1:, :-1])
183+
184+
adjacency: dict[int, set[int]] = defaultdict(set)
185+
if not encoded_parts:
186+
return adjacency
64187

65-
# 4-connected directions (rook)
66-
_add_pairs(region_map[:-1, :], region_map[1:, :]) # vertical
67-
_add_pairs(region_map[:, :-1], region_map[:, 1:]) # horizontal
188+
encoded = np.unique(np.concatenate(encoded_parts))
189+
lo_arr = encoded // max_id
190+
hi_arr = encoded % max_id
68191

69-
# 8-connected adds diagonals (queen)
70-
if neighborhood == 8:
71-
_add_pairs(region_map[:-1, :-1], region_map[1:, 1:]) # SE
72-
_add_pairs(region_map[:-1, 1:], region_map[1:, :-1]) # SW
192+
for a, b in zip(lo_arr.tolist(), hi_arr.tolist()):
193+
adjacency[a].add(b)
194+
adjacency[b].add(a)
73195

74196
return adjacency
75197

@@ -79,54 +201,17 @@ def _add_pairs(a, b):
79201
# ---------------------------------------------------------------------------
80202

81203

82-
def _label_all_regions(result, valid, structure):
83-
"""Label connected components per unique value.
84-
85-
Returns
86-
-------
87-
region_map : ndarray of int32
88-
Each pixel mapped to its region id (0 = nodata).
89-
region_val : ndarray of float64
90-
Original raster value for each region id.
91-
n_total : int
92-
Total number of regions + 1 (length of *region_val*).
93-
"""
94-
from scipy.ndimage import label
95-
96-
unique_vals = np.unique(result[valid])
97-
region_map = np.zeros(result.shape, dtype=np.int32)
98-
region_val_list: list[float] = [np.nan] # id 0 = nodata
99-
uid = 1
100-
101-
for v in unique_vals:
102-
mask = (result == v) & valid
103-
labeled, n_features = label(mask, structure=structure)
104-
if n_features > 0:
105-
nonzero = labeled > 0
106-
region_map[nonzero] = labeled[nonzero] + (uid - 1)
107-
region_val_list.extend([float(v)] * n_features)
108-
uid += n_features
109-
110-
region_val = np.array(region_val_list, dtype=np.float64)
111-
return region_map, region_val, uid
112-
113-
114204
def _sieve_numpy(data, threshold, neighborhood, skip_values):
115205
"""Replace connected regions smaller than *threshold* pixels."""
116-
structure = (
117-
np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]])
118-
if neighborhood == 4
119-
else np.ones((3, 3), dtype=int)
120-
)
121-
122206
result = data.astype(np.float64, copy=True)
123207
is_float = np.issubdtype(data.dtype, np.floating)
124208
valid = ~np.isnan(result) if is_float else np.ones(result.shape, dtype=bool)
125209
skip_set = set(skip_values) if skip_values is not None else set()
126210

127-
for _ in range(50): # convergence limit
128-
region_map, region_val, uid = _label_all_regions(
129-
result, valid, structure
211+
converged = False
212+
for _ in range(_MAX_ITERATIONS):
213+
region_map, region_val, uid = _label_connected(
214+
result, valid, neighborhood
130215
)
131216
region_size = np.bincount(
132217
region_map.ravel(), minlength=uid
@@ -140,6 +225,7 @@ def _sieve_numpy(data, threshold, neighborhood, skip_values):
140225
and region_val[rid] not in skip_set
141226
]
142227
if not small_ids:
228+
converged = True
143229
break
144230

145231
adjacency = _build_adjacency(region_map, neighborhood)
@@ -176,8 +262,17 @@ def _sieve_numpy(data, threshold, neighborhood, skip_values):
176262
merged_any = True
177263

178264
if not merged_any:
265+
converged = True
179266
break
180267

268+
if not converged:
269+
warnings.warn(
270+
f"sieve() did not converge after {_MAX_ITERATIONS} iterations. "
271+
f"The result may still contain regions smaller than "
272+
f"threshold={threshold}.",
273+
stacklevel=3,
274+
)
275+
181276
return result
182277

183278

xrspatial/tests/test_sieve.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,3 +422,83 @@ def test_sieve_numpy_dask_match():
422422
dk_result = _to_numpy(sieve(dk_raster, threshold=3))
423423

424424
np.testing.assert_array_equal(np_result, dk_result)
425+
426+
427+
# ---------------------------------------------------------------------------
428+
# Convergence warning
429+
# ---------------------------------------------------------------------------
430+
431+
432+
def test_sieve_convergence_warning():
    """sieve() should emit a UserWarning when the merge loop reaches
    its iteration limit without converging.

    Fix: dropped the unused ``from xrspatial.sieve import
    _MAX_ITERATIONS`` import — the limit is patched by dotted name, so
    the import was dead code.
    """
    from unittest.mock import patch

    # Small raster whose center pixel is a 1-pixel region, so at least
    # one merge pass would normally be needed.
    arr = np.array(
        [
            [1, 1, 1],
            [1, 2, 1],
            [1, 1, 1],
        ],
        dtype=np.float64,
    )
    raster = _make_raster(arr, "numpy")

    # Force non-convergence: with the limit patched to 0 the merge loop
    # never runs, the small region survives, and the warning must fire.
    with patch("xrspatial.sieve._MAX_ITERATIONS", 0):
        with pytest.warns(UserWarning, match="did not converge"):
            sieve(raster, threshold=2)
453+
454+
455+
# ---------------------------------------------------------------------------
456+
# Larger synthetic rasters
457+
# ---------------------------------------------------------------------------
458+
459+
460+
@pytest.mark.parametrize("backend", ["numpy", "dask+numpy"])
def test_sieve_noisy_classification(backend):
    """Sieve a noisy 100x100 classification with known outcome."""
    rng = np.random.RandomState(1162)

    # Clean 4-class quadrant raster: blow a 2x2 class grid up to
    # 100x100 with a Kronecker product.
    base = np.kron(np.array([[1.0, 2.0], [3.0, 4.0]]), np.ones((50, 50)))

    # Corrupt ~5 % of the pixels with random class labels
    # (salt-and-pepper noise).
    noise_mask = rng.random((100, 100)) < 0.05
    noise_vals = rng.choice([1.0, 2.0, 3.0, 4.0], size=(100, 100))
    noisy = np.where(noise_mask, noise_vals, base)

    raster = _make_raster(noisy, backend)
    data = _to_numpy(sieve(raster, threshold=10))

    # After sieving, the isolated noise pixels should be gone: each
    # quadrant interior (away from the seams) must be uniform again.
    assert np.all(data[5:45, 5:45] == 1.0)
    assert np.all(data[5:45, 55:95] == 2.0)
    assert np.all(data[55:95, 5:45] == 3.0)
    assert np.all(data[55:95, 55:95] == 4.0)
487+
488+
489+
@pytest.mark.parametrize("backend", ["numpy", "dask+numpy"])
def test_sieve_many_small_regions(backend):
    """Checkerboard produces maximum region count; sieve should unify."""
    # 20x20 checkerboard of 1s and 2s built from coordinate parity:
    # with 4-connectivity every pixel is its own single-pixel region.
    rows, cols = np.indices((20, 20))
    arr = np.where((rows + cols) % 2 == 0, 1.0, 2.0)

    raster = _make_raster(arr, backend)
    data = _to_numpy(sieve(raster, threshold=2, neighborhood=4))

    # Every region has size 1 < threshold=2, so merging repeats until
    # the whole raster collapses to one value.
    assert len(np.unique(data)) == 1

0 commit comments

Comments
 (0)