1111
1212from __future__ import annotations
1313
14+ import warnings
1415from collections import defaultdict
1516from typing import Sequence
1617
@@ -36,40 +37,164 @@ class cupy:
3637 has_cuda_and_cupy ,
3738 is_cupy_array ,
3839 is_dask_cupy ,
40+ ngjit ,
3941)
4042
43+ _MAX_ITERATIONS = 50
44+
45+
46+ # ---------------------------------------------------------------------------
47+ # Numba union-find labeling
48+ # ---------------------------------------------------------------------------
49+
50+
@ngjit
def _uf_find(parent, x):
    """Return the representative (root) of *x*, compressing with path halving."""
    root = x
    while parent[root] != root:
        # Point each visited node at its grandparent, shortening the chain.
        parent[root] = parent[parent[root]]
        root = parent[root]
    return root
58+
59+
@ngjit
def _uf_union(parent, rank, a, b):
    """Merge the sets containing *a* and *b* using union by rank."""
    root_a = _uf_find(parent, a)
    root_b = _uf_find(parent, b)
    if root_a == root_b:
        return
    # Ensure root_a is the higher-rank root so it absorbs the other tree.
    if rank[root_a] < rank[root_b]:
        root_a, root_b = root_b, root_a
    parent[root_b] = root_a
    # Equal ranks: the surviving root's tree grew one level deeper.
    if rank[root_a] == rank[root_b]:
        rank[root_a] += 1
74+
75+
@ngjit
def _label_connected(data, valid, neighborhood):
    """Single-pass connected-component labeling via union-find.

    Labels connected regions of same-value pixels in one O(n) pass,
    replacing the previous approach of calling ``scipy.ndimage.label``
    once per unique raster value.

    Uses int32 indices internally, so the raster must have fewer than
    ~2.1 billion pixels (roughly 46 000 x 46 000).

    Returns
    -------
    region_map : ndarray of int32 (2D)
        Each pixel mapped to its region id (0 = nodata).
    region_val : ndarray of float64 (1D)
        Original raster value for each region id.
    n_regions : int
        Total number of regions + 1 (length of *region_val*).
    """
    nrows = data.shape[0]
    ncols = data.shape[1]
    total = nrows * ncols
    parent = np.arange(total, dtype=np.int32)
    rank = np.zeros(total, dtype=np.int32)

    # Pass 1: union each valid pixel with its already-visited neighbors
    # (left and up for 4-connectivity; both upper diagonals for 8).
    for row in range(nrows):
        base = row * ncols
        for col in range(ncols):
            if not valid[row, col]:
                continue
            here = base + col
            v = data[row, col]

            if col > 0 and valid[row, col - 1] and data[row, col - 1] == v:
                _uf_union(parent, rank, here, here - 1)
            if row > 0 and valid[row - 1, col] and data[row - 1, col] == v:
                _uf_union(parent, rank, here, here - ncols)

            if neighborhood == 8:
                if (
                    row > 0
                    and col > 0
                    and valid[row - 1, col - 1]
                    and data[row - 1, col - 1] == v
                ):
                    _uf_union(parent, rank, here, here - ncols - 1)
                if (
                    row > 0
                    and col + 1 < ncols
                    and valid[row - 1, col + 1]
                    and data[row - 1, col + 1] == v
                ):
                    _uf_union(parent, rank, here, here - ncols + 1)

    # Pass 2: map each root to a contiguous id (first-seen order) and
    # record the raster value represented by that region.
    flat_labels = np.zeros(total, dtype=np.int32)
    root_label = np.zeros(total, dtype=np.int32)
    values = np.full(total + 1, np.nan, dtype=np.float64)
    label_count = 1  # id 0 is reserved for nodata

    for pix in range(total):
        row = pix // ncols
        col = pix % ncols
        if not valid[row, col]:
            continue
        rep = _uf_find(parent, pix)
        if root_label[rep] == 0:
            root_label[rep] = label_count
            values[label_count] = data[row, col]
            label_count += 1
        flat_labels[pix] = root_label[rep]

    return flat_labels.reshape(nrows, ncols), values[:label_count], label_count
152+
41153
42154# ---------------------------------------------------------------------------
43155# Adjacency helpers
44156# ---------------------------------------------------------------------------
45157
46158
47159def _build_adjacency (region_map , neighborhood ):
48- """Build a region adjacency dict from a labeled map using vectorized shifts.
160+ """Build a region adjacency dict from a labeled map.
161+
162+ Encodes each (lo, hi) region pair as a single int64 so
163+ deduplication uses fast 1-D ``np.unique`` instead of the slower
164+ ``np.unique(axis=0)`` on 2-D pair arrays.
49165
50166 Returns ``{region_id: set_of_neighbor_ids}``.
51167 """
52- adjacency : dict [int , set [int ]] = defaultdict (set )
168+ max_id = np .int64 (region_map .max ()) + 1
169+ encoded_parts : list [np .ndarray ] = []
53170
54- def _add_pairs (a , b ):
171+ def _collect (a , b ):
55172 mask = (a > 0 ) & (b > 0 ) & (a != b )
56173 if not mask .any ():
57174 return
58- pairs = np .unique (
59- np .column_stack ([a [mask ].ravel (), b [mask ].ravel ()]), axis = 0
60- )
61- for x , y in pairs :
62- adjacency [int (x )].add (int (y ))
63- adjacency [int (y )].add (int (x ))
175+ am = a [mask ].ravel ().astype (np .int64 )
176+ bm = b [mask ].ravel ().astype (np .int64 )
177+ lo = np .minimum (am , bm )
178+ hi = np .maximum (am , bm )
179+ encoded_parts .append (lo * max_id + hi )
180+
181+ _collect (region_map [:- 1 , :], region_map [1 :, :])
182+ _collect (region_map [:, :- 1 ], region_map [:, 1 :])
183+ if neighborhood == 8 :
184+ _collect (region_map [:- 1 , :- 1 ], region_map [1 :, 1 :])
185+ _collect (region_map [:- 1 , 1 :], region_map [1 :, :- 1 ])
186+
187+ adjacency : dict [int , set [int ]] = defaultdict (set )
188+ if not encoded_parts :
189+ return adjacency
64190
65- # 4-connected directions (rook )
66- _add_pairs ( region_map [: - 1 , :], region_map [ 1 :, :]) # vertical
67- _add_pairs ( region_map [:, : - 1 ], region_map [:, 1 :]) # horizontal
191+ encoded = np . unique ( np . concatenate ( encoded_parts ) )
192+ lo_arr = encoded // max_id
193+ hi_arr = encoded % max_id
68194
69- # 8-connected adds diagonals (queen)
70- if neighborhood == 8 :
71- _add_pairs (region_map [:- 1 , :- 1 ], region_map [1 :, 1 :]) # SE
72- _add_pairs (region_map [:- 1 , 1 :], region_map [1 :, :- 1 ]) # SW
195+ for a , b in zip (lo_arr .tolist (), hi_arr .tolist ()):
196+ adjacency [a ].add (b )
197+ adjacency [b ].add (a )
73198
74199 return adjacency
75200
@@ -79,54 +204,16 @@ def _add_pairs(a, b):
79204# ---------------------------------------------------------------------------
80205
81206
82- def _label_all_regions (result , valid , structure ):
83- """Label connected components per unique value.
84-
85- Returns
86- -------
87- region_map : ndarray of int32
88- Each pixel mapped to its region id (0 = nodata).
89- region_val : ndarray of float64
90- Original raster value for each region id.
91- n_total : int
92- Total number of regions + 1 (length of *region_val*).
93- """
94- from scipy .ndimage import label
95-
96- unique_vals = np .unique (result [valid ])
97- region_map = np .zeros (result .shape , dtype = np .int32 )
98- region_val_list : list [float ] = [np .nan ] # id 0 = nodata
99- uid = 1
100-
101- for v in unique_vals :
102- mask = (result == v ) & valid
103- labeled , n_features = label (mask , structure = structure )
104- if n_features > 0 :
105- nonzero = labeled > 0
106- region_map [nonzero ] = labeled [nonzero ] + (uid - 1 )
107- region_val_list .extend ([float (v )] * n_features )
108- uid += n_features
109-
110- region_val = np .array (region_val_list , dtype = np .float64 )
111- return region_map , region_val , uid
112-
113-
114207def _sieve_numpy (data , threshold , neighborhood , skip_values ):
115208 """Replace connected regions smaller than *threshold* pixels."""
116- structure = (
117- np .array ([[0 , 1 , 0 ], [1 , 1 , 1 ], [0 , 1 , 0 ]])
118- if neighborhood == 4
119- else np .ones ((3 , 3 ), dtype = int )
120- )
121-
122209 result = data .astype (np .float64 , copy = True )
123210 is_float = np .issubdtype (data .dtype , np .floating )
124211 valid = ~ np .isnan (result ) if is_float else np .ones (result .shape , dtype = bool )
125212 skip_set = set (skip_values ) if skip_values is not None else set ()
126213
127- for _ in range (50 ): # convergence limit
128- region_map , region_val , uid = _label_all_regions (
129- result , valid , structure
214+ for _ in range (_MAX_ITERATIONS ):
215+ region_map , region_val , uid = _label_connected (
216+ result , valid , neighborhood
130217 )
131218 region_size = np .bincount (
132219 region_map .ravel (), minlength = uid
@@ -140,7 +227,7 @@ def _sieve_numpy(data, threshold, neighborhood, skip_values):
140227 and region_val [rid ] not in skip_set
141228 ]
142229 if not small_ids :
143- break
230+ return result , True
144231
145232 adjacency = _build_adjacency (region_map , neighborhood )
146233
@@ -176,9 +263,9 @@ def _sieve_numpy(data, threshold, neighborhood, skip_values):
176263 merged_any = True
177264
178265 if not merged_any :
179- break
266+ return result , True
180267
181- return result
268+ return result , False
182269
183270
184271# ---------------------------------------------------------------------------
def _sieve_cupy(data, threshold, neighborhood, skip_values):
    """CuPy backend: transfer to CPU, sieve, transfer back."""
    import cupy as cp

    # Sieving runs on the host; pull the device array down first.
    host = data.get()
    sieved, converged = _sieve_numpy(host, threshold, neighborhood, skip_values)
    return cp.asarray(sieved), converged
195284
196285
197286# ---------------------------------------------------------------------------
@@ -231,8 +320,10 @@ def _sieve_dask(data, threshold, neighborhood, skip_values):
231320 )
232321
233322 np_data = data .compute ()
234- result = _sieve_numpy (np_data , threshold , neighborhood , skip_values )
235- return da .from_array (result , chunks = data .chunks )
323+ result , converged = _sieve_numpy (
324+ np_data , threshold , neighborhood , skip_values
325+ )
326+ return da .from_array (result , chunks = data .chunks ), converged
236327
237328
238329def _sieve_dask_cupy (data , threshold , neighborhood , skip_values ):
@@ -254,8 +345,10 @@ def _sieve_dask_cupy(data, threshold, neighborhood, skip_values):
254345 pass
255346
256347 cp_data = data .compute ()
257- result = _sieve_cupy (cp_data , threshold , neighborhood , skip_values )
258- return da .from_array (result , chunks = data .chunks )
348+ result , converged = _sieve_cupy (
349+ cp_data , threshold , neighborhood , skip_values
350+ )
351+ return da .from_array (result , chunks = data .chunks ), converged
259352
260353
261354# ---------------------------------------------------------------------------
@@ -349,21 +442,35 @@ def sieve(
349442 data = raster .data
350443
351444 if isinstance (data , np .ndarray ):
352- out = _sieve_numpy (data , threshold , neighborhood , skip_values )
445+ out , converged = _sieve_numpy (
446+ data , threshold , neighborhood , skip_values
447+ )
353448 elif has_cuda_and_cupy () and is_cupy_array (data ):
354- out = _sieve_cupy (data , threshold , neighborhood , skip_values )
449+ out , converged = _sieve_cupy (
450+ data , threshold , neighborhood , skip_values
451+ )
355452 elif da is not None and isinstance (data , da .Array ):
356453 if is_dask_cupy (raster ):
357- out = _sieve_dask_cupy (
454+ out , converged = _sieve_dask_cupy (
358455 data , threshold , neighborhood , skip_values
359456 )
360457 else :
361- out = _sieve_dask (data , threshold , neighborhood , skip_values )
458+ out , converged = _sieve_dask (
459+ data , threshold , neighborhood , skip_values
460+ )
362461 else :
363462 raise TypeError (
364463 f"Unsupported array type { type (data ).__name__ } for sieve()"
365464 )
366465
466+ if not converged :
467+ warnings .warn (
468+ f"sieve() did not converge after { _MAX_ITERATIONS } iterations. "
469+ f"The result may still contain regions smaller than "
470+ f"threshold={ threshold } ." ,
471+ stacklevel = 2 ,
472+ )
473+
367474 return DataArray (
368475 out ,
369476 name = name ,
0 commit comments