Skip to content

Commit 3b56af8

Browse files
authored
Speed up sieve labeling and add convergence warning (#1163)
* Speed up sieve labeling and add convergence warning (#1162) Replace per-value scipy.ndimage.label calls with a single-pass numba union-find for connected-component labeling. Rework the adjacency builder to use 1-D np.unique on encoded int64 pairs instead of the slower np.unique(axis=0) on 2-D arrays. Add a UserWarning when the 50-iteration merge loop doesn't converge. * Address code review: fix warning stacklevel, clean up test (#1162) Move the convergence warning from _sieve_numpy into sieve() so stacklevel=2 always points at the caller regardless of backend. Return (result, converged) tuple from all backend functions. Remove unused import in test_sieve_convergence_warning. Add int32 limit note to _label_connected docstring.
1 parent e05e7eb commit 3b56af8

File tree

2 files changed

+255
-70
lines changed

2 files changed

+255
-70
lines changed

xrspatial/sieve.py

Lines changed: 177 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from __future__ import annotations
1313

14+
import warnings
1415
from collections import defaultdict
1516
from typing import Sequence
1617

@@ -36,40 +37,164 @@ class cupy:
3637
has_cuda_and_cupy,
3738
is_cupy_array,
3839
is_dask_cupy,
40+
ngjit,
3941
)
4042

43+
_MAX_ITERATIONS = 50
44+
45+
46+
# ---------------------------------------------------------------------------
47+
# Numba union-find labeling
48+
# ---------------------------------------------------------------------------
49+
50+
51+
@ngjit
def _uf_find(parent, x):
    """Return the representative (root) of *x*, compressing with path halving.

    Mutates *parent* in place: each visited node is re-pointed at its
    grandparent, halving the remaining path for future lookups.
    """
    node = x
    while True:
        p = parent[node]
        if p == node:
            return node
        # Path halving: hop node up to its grandparent before stepping.
        parent[node] = parent[p]
        node = parent[node]
58+
59+
60+
@ngjit
def _uf_union(parent, rank, a, b):
    """Merge the disjoint sets containing *a* and *b* using union by rank.

    The root with the larger rank absorbs the other; on a rank tie the
    first root becomes the parent and its rank grows by one.
    """
    root_a = _uf_find(parent, a)
    root_b = _uf_find(parent, b)
    if root_a == root_b:
        return  # already in the same set
    if rank[root_a] < rank[root_b]:
        # Swap so root_a always has rank >= root_b.
        root_a, root_b = root_b, root_a
    parent[root_b] = root_a
    if rank[root_a] == rank[root_b]:
        rank[root_a] += 1
74+
75+
76+
@ngjit
def _label_connected(data, valid, neighborhood):
    """Label connected regions of equal-valued pixels in a single pass.

    Uses union-find instead of one ``scipy.ndimage.label`` call per
    unique raster value, so the whole raster is labeled in O(n).

    Internally indices are int32, so the raster must hold fewer than
    ~2.1 billion pixels (roughly 46 000 x 46 000).

    Returns
    -------
    region_map : ndarray of int32 (2D)
        Each pixel mapped to its region id (0 = nodata).
    region_val : ndarray of float64 (1D)
        Original raster value for each region id.
    n_regions : int
        Total number of regions + 1 (length of *region_val*).
    """
    rows, cols = data.shape
    n = rows * cols
    parent = np.arange(n, dtype=np.int32)
    rank = np.zeros(n, dtype=np.int32)

    # First pass: union each pixel with its already-visited neighbors
    # (west, north, and — for 8-connectivity — the two upper diagonals).
    for r in range(rows):
        base = r * cols
        for c in range(cols):
            if not valid[r, c]:
                continue
            here = base + c
            v = data[r, c]
            # West neighbor (same row, already scanned).
            if c > 0 and valid[r, c - 1] and data[r, c - 1] == v:
                _uf_union(parent, rank, here, here - 1)
            if r > 0:
                up = here - cols
                # North neighbor.
                if valid[r - 1, c] and data[r - 1, c] == v:
                    _uf_union(parent, rank, here, up)
                if neighborhood == 8:
                    # North-west diagonal.
                    if c > 0 and valid[r - 1, c - 1] and data[r - 1, c - 1] == v:
                        _uf_union(parent, rank, here, up - 1)
                    # North-east diagonal.
                    if c + 1 < cols and valid[r - 1, c + 1] and data[r - 1, c + 1] == v:
                        _uf_union(parent, rank, here, up + 1)

    # Second pass: assign contiguous ids in row-major order of first
    # appearance, and record each region's original raster value.
    labels = np.zeros(n, dtype=np.int32)
    root_label = np.zeros(n, dtype=np.int32)
    values = np.full(n + 1, np.nan, dtype=np.float64)
    count = 1
    for r in range(rows):
        base = r * cols
        for c in range(cols):
            if not valid[r, c]:
                continue
            i = base + c
            root = _uf_find(parent, i)
            if root_label[root] == 0:
                root_label[root] = count
                values[count] = data[r, c]
                count += 1
            labels[i] = root_label[root]

    return labels.reshape(rows, cols), values[:count], count
152+
41153

42154
# ---------------------------------------------------------------------------
43155
# Adjacency helpers
44156
# ---------------------------------------------------------------------------
45157

46158

47159
def _build_adjacency(region_map, neighborhood):
48-
"""Build a region adjacency dict from a labeled map using vectorized shifts.
160+
"""Build a region adjacency dict from a labeled map.
161+
162+
Encodes each (lo, hi) region pair as a single int64 so
163+
deduplication uses fast 1-D ``np.unique`` instead of the slower
164+
``np.unique(axis=0)`` on 2-D pair arrays.
49165
50166
Returns ``{region_id: set_of_neighbor_ids}``.
51167
"""
52-
adjacency: dict[int, set[int]] = defaultdict(set)
168+
max_id = np.int64(region_map.max()) + 1
169+
encoded_parts: list[np.ndarray] = []
53170

54-
def _add_pairs(a, b):
171+
def _collect(a, b):
55172
mask = (a > 0) & (b > 0) & (a != b)
56173
if not mask.any():
57174
return
58-
pairs = np.unique(
59-
np.column_stack([a[mask].ravel(), b[mask].ravel()]), axis=0
60-
)
61-
for x, y in pairs:
62-
adjacency[int(x)].add(int(y))
63-
adjacency[int(y)].add(int(x))
175+
am = a[mask].ravel().astype(np.int64)
176+
bm = b[mask].ravel().astype(np.int64)
177+
lo = np.minimum(am, bm)
178+
hi = np.maximum(am, bm)
179+
encoded_parts.append(lo * max_id + hi)
180+
181+
_collect(region_map[:-1, :], region_map[1:, :])
182+
_collect(region_map[:, :-1], region_map[:, 1:])
183+
if neighborhood == 8:
184+
_collect(region_map[:-1, :-1], region_map[1:, 1:])
185+
_collect(region_map[:-1, 1:], region_map[1:, :-1])
186+
187+
adjacency: dict[int, set[int]] = defaultdict(set)
188+
if not encoded_parts:
189+
return adjacency
64190

65-
# 4-connected directions (rook)
66-
_add_pairs(region_map[:-1, :], region_map[1:, :]) # vertical
67-
_add_pairs(region_map[:, :-1], region_map[:, 1:]) # horizontal
191+
encoded = np.unique(np.concatenate(encoded_parts))
192+
lo_arr = encoded // max_id
193+
hi_arr = encoded % max_id
68194

69-
# 8-connected adds diagonals (queen)
70-
if neighborhood == 8:
71-
_add_pairs(region_map[:-1, :-1], region_map[1:, 1:]) # SE
72-
_add_pairs(region_map[:-1, 1:], region_map[1:, :-1]) # SW
195+
for a, b in zip(lo_arr.tolist(), hi_arr.tolist()):
196+
adjacency[a].add(b)
197+
adjacency[b].add(a)
73198

74199
return adjacency
75200

@@ -79,54 +204,16 @@ def _add_pairs(a, b):
79204
# ---------------------------------------------------------------------------
80205

81206

82-
def _label_all_regions(result, valid, structure):
83-
"""Label connected components per unique value.
84-
85-
Returns
86-
-------
87-
region_map : ndarray of int32
88-
Each pixel mapped to its region id (0 = nodata).
89-
region_val : ndarray of float64
90-
Original raster value for each region id.
91-
n_total : int
92-
Total number of regions + 1 (length of *region_val*).
93-
"""
94-
from scipy.ndimage import label
95-
96-
unique_vals = np.unique(result[valid])
97-
region_map = np.zeros(result.shape, dtype=np.int32)
98-
region_val_list: list[float] = [np.nan] # id 0 = nodata
99-
uid = 1
100-
101-
for v in unique_vals:
102-
mask = (result == v) & valid
103-
labeled, n_features = label(mask, structure=structure)
104-
if n_features > 0:
105-
nonzero = labeled > 0
106-
region_map[nonzero] = labeled[nonzero] + (uid - 1)
107-
region_val_list.extend([float(v)] * n_features)
108-
uid += n_features
109-
110-
region_val = np.array(region_val_list, dtype=np.float64)
111-
return region_map, region_val, uid
112-
113-
114207
def _sieve_numpy(data, threshold, neighborhood, skip_values):
115208
"""Replace connected regions smaller than *threshold* pixels."""
116-
structure = (
117-
np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]])
118-
if neighborhood == 4
119-
else np.ones((3, 3), dtype=int)
120-
)
121-
122209
result = data.astype(np.float64, copy=True)
123210
is_float = np.issubdtype(data.dtype, np.floating)
124211
valid = ~np.isnan(result) if is_float else np.ones(result.shape, dtype=bool)
125212
skip_set = set(skip_values) if skip_values is not None else set()
126213

127-
for _ in range(50): # convergence limit
128-
region_map, region_val, uid = _label_all_regions(
129-
result, valid, structure
214+
for _ in range(_MAX_ITERATIONS):
215+
region_map, region_val, uid = _label_connected(
216+
result, valid, neighborhood
130217
)
131218
region_size = np.bincount(
132219
region_map.ravel(), minlength=uid
@@ -140,7 +227,7 @@ def _sieve_numpy(data, threshold, neighborhood, skip_values):
140227
and region_val[rid] not in skip_set
141228
]
142229
if not small_ids:
143-
break
230+
return result, True
144231

145232
adjacency = _build_adjacency(region_map, neighborhood)
146233

@@ -176,9 +263,9 @@ def _sieve_numpy(data, threshold, neighborhood, skip_values):
176263
merged_any = True
177264

178265
if not merged_any:
179-
break
266+
return result, True
180267

181-
return result
268+
return result, False
182269

183270

184271
# ---------------------------------------------------------------------------
@@ -190,8 +277,10 @@ def _sieve_cupy(data, threshold, neighborhood, skip_values):
190277
"""CuPy backend: transfer to CPU, sieve, transfer back."""
191278
import cupy as cp
192279

193-
np_result = _sieve_numpy(data.get(), threshold, neighborhood, skip_values)
194-
return cp.asarray(np_result)
280+
np_result, converged = _sieve_numpy(
281+
data.get(), threshold, neighborhood, skip_values
282+
)
283+
return cp.asarray(np_result), converged
195284

196285

197286
# ---------------------------------------------------------------------------
@@ -231,8 +320,10 @@ def _sieve_dask(data, threshold, neighborhood, skip_values):
231320
)
232321

233322
np_data = data.compute()
234-
result = _sieve_numpy(np_data, threshold, neighborhood, skip_values)
235-
return da.from_array(result, chunks=data.chunks)
323+
result, converged = _sieve_numpy(
324+
np_data, threshold, neighborhood, skip_values
325+
)
326+
return da.from_array(result, chunks=data.chunks), converged
236327

237328

238329
def _sieve_dask_cupy(data, threshold, neighborhood, skip_values):
@@ -254,8 +345,10 @@ def _sieve_dask_cupy(data, threshold, neighborhood, skip_values):
254345
pass
255346

256347
cp_data = data.compute()
257-
result = _sieve_cupy(cp_data, threshold, neighborhood, skip_values)
258-
return da.from_array(result, chunks=data.chunks)
348+
result, converged = _sieve_cupy(
349+
cp_data, threshold, neighborhood, skip_values
350+
)
351+
return da.from_array(result, chunks=data.chunks), converged
259352

260353

261354
# ---------------------------------------------------------------------------
@@ -349,21 +442,35 @@ def sieve(
349442
data = raster.data
350443

351444
if isinstance(data, np.ndarray):
352-
out = _sieve_numpy(data, threshold, neighborhood, skip_values)
445+
out, converged = _sieve_numpy(
446+
data, threshold, neighborhood, skip_values
447+
)
353448
elif has_cuda_and_cupy() and is_cupy_array(data):
354-
out = _sieve_cupy(data, threshold, neighborhood, skip_values)
449+
out, converged = _sieve_cupy(
450+
data, threshold, neighborhood, skip_values
451+
)
355452
elif da is not None and isinstance(data, da.Array):
356453
if is_dask_cupy(raster):
357-
out = _sieve_dask_cupy(
454+
out, converged = _sieve_dask_cupy(
358455
data, threshold, neighborhood, skip_values
359456
)
360457
else:
361-
out = _sieve_dask(data, threshold, neighborhood, skip_values)
458+
out, converged = _sieve_dask(
459+
data, threshold, neighborhood, skip_values
460+
)
362461
else:
363462
raise TypeError(
364463
f"Unsupported array type {type(data).__name__} for sieve()"
365464
)
366465

466+
if not converged:
467+
warnings.warn(
468+
f"sieve() did not converge after {_MAX_ITERATIONS} iterations. "
469+
f"The result may still contain regions smaller than "
470+
f"threshold={threshold}.",
471+
stacklevel=2,
472+
)
473+
367474
return DataArray(
368475
out,
369476
name=name,

0 commit comments

Comments
 (0)