1111
1212from __future__ import annotations
1313
14+ import warnings
1415from collections import defaultdict
1516from typing import Sequence
1617
@@ -36,40 +37,161 @@ class cupy:
3637 has_cuda_and_cupy ,
3738 is_cupy_array ,
3839 is_dask_cupy ,
40+ ngjit ,
3941)
4042
# Upper bound on the number of label/merge passes ``_sieve_numpy`` performs;
# if every pass is used without converging, a warning is emitted.
_MAX_ITERATIONS = 50
44+
45+
46+ # ---------------------------------------------------------------------------
47+ # Numba union-find labeling
48+ # ---------------------------------------------------------------------------
49+
50+
@ngjit
def _uf_find(parent, x):
    """Return the root of the set that contains *x*.

    Follows ``parent`` links until it reaches an element that is its own
    parent.  Every node visited on the way is re-pointed at its
    grandparent (path halving), so ``parent`` is modified in place.
    """
    while True:
        up = parent[x]
        if up == x:
            return x
        # Path halving: link x to its grandparent, then continue from there.
        grand = parent[up]
        parent[x] = grand
        x = grand
58+
59+
@ngjit
def _uf_union(parent, rank, a, b):
    """Merge the sets containing *a* and *b* using union by rank.

    The root with the smaller rank is attached beneath the other one.
    When both ranks are equal, the root of *b* is attached beneath the
    root of *a* and the rank of the surviving root grows by one.
    ``parent`` and ``rank`` are updated in place; nothing is returned.
    """
    keep = _uf_find(parent, a)
    drop = _uf_find(parent, b)
    if keep == drop:
        return

    # Make ``keep`` the root whose rank is at least as large as ``drop``'s.
    if rank[keep] < rank[drop]:
        keep, drop = drop, keep

    tied = rank[keep] == rank[drop]
    parent[drop] = keep
    if tied:
        rank[keep] += 1
74+
75+
@ngjit
def _label_connected(data, valid, neighborhood):
    """Label connected regions of equal-valued pixels with union-find.

    The raster is scanned once in row-major order.  Each valid pixel is
    unioned with its already-visited neighbours (left and up, plus the
    two upper diagonals when ``neighborhood == 8``) that are valid and
    hold the same value.  A second scan then assigns contiguous region
    ids, starting at 1, in order of first appearance.

    Parameters
    ----------
    data : 2D ndarray
        Raster values.
    valid : 2D ndarray of bool
        Pixels to label; pixels where this is False keep region id 0.
    neighborhood : int
        ``8`` enables diagonal connectivity; any other value uses only
        the left/up neighbours.

    Returns
    -------
    region_map : ndarray of int32 (2D)
        Each pixel mapped to its region id (0 = nodata).
    region_val : ndarray of float64 (1D)
        Original raster value for each region id (index 0 is NaN).
    n_regions : int
        Total number of regions + 1 (length of *region_val*).
    """
    n_rows = data.shape[0]
    n_cols = data.shape[1]
    n_cells = n_rows * n_cols
    parent = np.arange(n_cells, dtype=np.int32)
    rank = np.zeros(n_cells, dtype=np.int32)
    queen = neighborhood == 8

    # Pass 1: merge every valid pixel with matching, already-visited neighbours.
    for r in range(n_rows):
        for c in range(n_cols):
            if not valid[r, c]:
                continue
            here = r * n_cols + c
            val = data[r, c]

            # Left neighbour.
            if c > 0 and valid[r, c - 1] and data[r, c - 1] == val:
                _uf_union(parent, rank, here, here - 1)

            # Everything else lives in the previous row.
            if r == 0:
                continue
            above = here - n_cols

            if valid[r - 1, c] and data[r - 1, c] == val:
                _uf_union(parent, rank, here, above)

            if not queen:
                continue

            # Upper-left diagonal.
            if c > 0 and valid[r - 1, c - 1] and data[r - 1, c - 1] == val:
                _uf_union(parent, rank, here, above - 1)
            # Upper-right diagonal.
            if (
                c + 1 < n_cols
                and valid[r - 1, c + 1]
                and data[r - 1, c + 1] == val
            ):
                _uf_union(parent, rank, here, above + 1)

    # Pass 2: map each root to a contiguous id; 0 is reserved for nodata.
    labels = np.zeros(n_cells, dtype=np.int32)
    id_of_root = np.zeros(n_cells, dtype=np.int32)
    values = np.full(n_cells + 1, np.nan, dtype=np.float64)
    next_id = 1

    for r in range(n_rows):
        for c in range(n_cols):
            if not valid[r, c]:
                continue
            cell = r * n_cols + c
            root = _uf_find(parent, cell)
            if id_of_root[root] == 0:
                id_of_root[root] = next_id
                values[next_id] = data[r, c]
                next_id += 1
            labels[cell] = id_of_root[root]

    return labels.reshape(n_rows, n_cols), values[:next_id], next_id
149+
41150
42151# ---------------------------------------------------------------------------
43152# Adjacency helpers
44153# ---------------------------------------------------------------------------
45154
46155
47156def _build_adjacency (region_map , neighborhood ):
48- """Build a region adjacency dict from a labeled map using vectorized shifts.
157+ """Build a region adjacency dict from a labeled map.
158+
159+ Encodes each (lo, hi) region pair as a single int64 so
160+ deduplication uses fast 1-D ``np.unique`` instead of the slower
161+ ``np.unique(axis=0)`` on 2-D pair arrays.
49162
50163 Returns ``{region_id: set_of_neighbor_ids}``.
51164 """
52- adjacency : dict [int , set [int ]] = defaultdict (set )
165+ max_id = np .int64 (region_map .max ()) + 1
166+ encoded_parts : list [np .ndarray ] = []
53167
54- def _add_pairs (a , b ):
168+ def _collect (a , b ):
55169 mask = (a > 0 ) & (b > 0 ) & (a != b )
56170 if not mask .any ():
57171 return
58- pairs = np .unique (
59- np .column_stack ([a [mask ].ravel (), b [mask ].ravel ()]), axis = 0
60- )
61- for x , y in pairs :
62- adjacency [int (x )].add (int (y ))
63- adjacency [int (y )].add (int (x ))
172+ am = a [mask ].ravel ().astype (np .int64 )
173+ bm = b [mask ].ravel ().astype (np .int64 )
174+ lo = np .minimum (am , bm )
175+ hi = np .maximum (am , bm )
176+ encoded_parts .append (lo * max_id + hi )
177+
178+ _collect (region_map [:- 1 , :], region_map [1 :, :])
179+ _collect (region_map [:, :- 1 ], region_map [:, 1 :])
180+ if neighborhood == 8 :
181+ _collect (region_map [:- 1 , :- 1 ], region_map [1 :, 1 :])
182+ _collect (region_map [:- 1 , 1 :], region_map [1 :, :- 1 ])
183+
184+ adjacency : dict [int , set [int ]] = defaultdict (set )
185+ if not encoded_parts :
186+ return adjacency
64187
65- # 4-connected directions (rook )
66- _add_pairs ( region_map [: - 1 , :], region_map [ 1 :, :]) # vertical
67- _add_pairs ( region_map [:, : - 1 ], region_map [:, 1 :]) # horizontal
188+ encoded = np . unique ( np . concatenate ( encoded_parts ) )
189+ lo_arr = encoded // max_id
190+ hi_arr = encoded % max_id
68191
69- # 8-connected adds diagonals (queen)
70- if neighborhood == 8 :
71- _add_pairs (region_map [:- 1 , :- 1 ], region_map [1 :, 1 :]) # SE
72- _add_pairs (region_map [:- 1 , 1 :], region_map [1 :, :- 1 ]) # SW
192+ for a , b in zip (lo_arr .tolist (), hi_arr .tolist ()):
193+ adjacency [a ].add (b )
194+ adjacency [b ].add (a )
73195
74196 return adjacency
75197
@@ -79,54 +201,17 @@ def _add_pairs(a, b):
79201# ---------------------------------------------------------------------------
80202
81203
82- def _label_all_regions (result , valid , structure ):
83- """Label connected components per unique value.
84-
85- Returns
86- -------
87- region_map : ndarray of int32
88- Each pixel mapped to its region id (0 = nodata).
89- region_val : ndarray of float64
90- Original raster value for each region id.
91- n_total : int
92- Total number of regions + 1 (length of *region_val*).
93- """
94- from scipy .ndimage import label
95-
96- unique_vals = np .unique (result [valid ])
97- region_map = np .zeros (result .shape , dtype = np .int32 )
98- region_val_list : list [float ] = [np .nan ] # id 0 = nodata
99- uid = 1
100-
101- for v in unique_vals :
102- mask = (result == v ) & valid
103- labeled , n_features = label (mask , structure = structure )
104- if n_features > 0 :
105- nonzero = labeled > 0
106- region_map [nonzero ] = labeled [nonzero ] + (uid - 1 )
107- region_val_list .extend ([float (v )] * n_features )
108- uid += n_features
109-
110- region_val = np .array (region_val_list , dtype = np .float64 )
111- return region_map , region_val , uid
112-
113-
114204def _sieve_numpy (data , threshold , neighborhood , skip_values ):
115205 """Replace connected regions smaller than *threshold* pixels."""
116- structure = (
117- np .array ([[0 , 1 , 0 ], [1 , 1 , 1 ], [0 , 1 , 0 ]])
118- if neighborhood == 4
119- else np .ones ((3 , 3 ), dtype = int )
120- )
121-
122206 result = data .astype (np .float64 , copy = True )
123207 is_float = np .issubdtype (data .dtype , np .floating )
124208 valid = ~ np .isnan (result ) if is_float else np .ones (result .shape , dtype = bool )
125209 skip_set = set (skip_values ) if skip_values is not None else set ()
126210
127- for _ in range (50 ): # convergence limit
128- region_map , region_val , uid = _label_all_regions (
129- result , valid , structure
211+ converged = False
212+ for _ in range (_MAX_ITERATIONS ):
213+ region_map , region_val , uid = _label_connected (
214+ result , valid , neighborhood
130215 )
131216 region_size = np .bincount (
132217 region_map .ravel (), minlength = uid
@@ -140,6 +225,7 @@ def _sieve_numpy(data, threshold, neighborhood, skip_values):
140225 and region_val [rid ] not in skip_set
141226 ]
142227 if not small_ids :
228+ converged = True
143229 break
144230
145231 adjacency = _build_adjacency (region_map , neighborhood )
@@ -176,8 +262,17 @@ def _sieve_numpy(data, threshold, neighborhood, skip_values):
176262 merged_any = True
177263
178264 if not merged_any :
265+ converged = True
179266 break
180267
268+ if not converged :
269+ warnings .warn (
270+ f"sieve() did not converge after { _MAX_ITERATIONS } iterations. "
271+ f"The result may still contain regions smaller than "
272+ f"threshold={ threshold } ." ,
273+ stacklevel = 3 ,
274+ )
275+
181276 return result
182277
183278
0 commit comments