Skip to content

Commit f1c39ee

Browse files
committed
Add sieve filter to remove small raster clumps (#1149)
Implements a sieve() function that identifies connected components of same-value pixels and replaces regions smaller than a threshold with the value of their largest spatial neighbor. Supports 4- and 8-connectivity, selective sieving via skip_values, and all four backends (numpy, cupy via CPU fallback, dask+numpy, dask+cupy).
1 parent cb5bc02 commit f1c39ee

File tree

3 files changed

+798
-0
lines changed

3 files changed

+798
-0
lines changed

xrspatial/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@
104104
from xrspatial.hydro import stream_link_d8, stream_link_dinf, stream_link_mfd # noqa
105105
from xrspatial.hydro import stream_order # noqa: unified wrapper
106106
from xrspatial.hydro import stream_order_d8, stream_order_dinf, stream_order_mfd # noqa
107+
from xrspatial.sieve import sieve # noqa
107108
from xrspatial.sky_view_factor import sky_view_factor # noqa
108109
from xrspatial.slope import slope # noqa
109110
from xrspatial.surface_distance import surface_allocation # noqa

xrspatial/sieve.py

Lines changed: 373 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,373 @@
1+
"""Sieve filter for removing small raster clumps.
2+
3+
Given a categorical raster and a pixel-count threshold, replaces
4+
connected regions smaller than the threshold with the value of
5+
their largest spatial neighbor. Pairs with classification functions
6+
(``natural_breaks``, ``reclassify``, etc.) and ``polygonize`` for
7+
cleaning results before vectorization.
8+
9+
Supports all four backends: numpy, cupy, dask+numpy, dask+cupy.
10+
"""
11+
12+
from __future__ import annotations
13+
14+
from collections import defaultdict
15+
from typing import Sequence
16+
17+
import numpy as np
18+
import xarray as xr
19+
from xarray import DataArray
20+
21+
# Optional GPU support: if cupy is missing, install a stub class whose
# ``ndarray`` attribute is falsy so isinstance-style backend checks
# simply never match.
try:
    import cupy
except ImportError:

    class cupy:
        # Sentinel stand-in for the real module; ``cupy.ndarray`` is False.
        ndarray = False


# Optional dask support; ``da is None`` marks it unavailable downstream.
try:
    import dask.array as da
except ImportError:
    da = None
33+
34+
from xrspatial.utils import (
35+
_validate_raster,
36+
has_cuda_and_cupy,
37+
is_cupy_array,
38+
is_dask_cupy,
39+
)
40+
41+
42+
# ---------------------------------------------------------------------------
43+
# Adjacency helpers
44+
# ---------------------------------------------------------------------------
45+
46+
47+
def _build_adjacency(region_map, neighborhood):
48+
"""Build a region adjacency dict from a labeled map using vectorized shifts.
49+
50+
Returns ``{region_id: set_of_neighbor_ids}``.
51+
"""
52+
adjacency: dict[int, set[int]] = defaultdict(set)
53+
54+
def _add_pairs(a, b):
55+
mask = (a > 0) & (b > 0) & (a != b)
56+
if not mask.any():
57+
return
58+
pairs = np.unique(
59+
np.column_stack([a[mask].ravel(), b[mask].ravel()]), axis=0
60+
)
61+
for x, y in pairs:
62+
adjacency[int(x)].add(int(y))
63+
adjacency[int(y)].add(int(x))
64+
65+
# 4-connected directions (rook)
66+
_add_pairs(region_map[:-1, :], region_map[1:, :]) # vertical
67+
_add_pairs(region_map[:, :-1], region_map[:, 1:]) # horizontal
68+
69+
# 8-connected adds diagonals (queen)
70+
if neighborhood == 8:
71+
_add_pairs(region_map[:-1, :-1], region_map[1:, 1:]) # SE
72+
_add_pairs(region_map[:-1, 1:], region_map[1:, :-1]) # SW
73+
74+
return adjacency
75+
76+
77+
# ---------------------------------------------------------------------------
78+
# numpy backend
79+
# ---------------------------------------------------------------------------
80+
81+
82+
def _label_all_regions(result, valid, structure):
83+
"""Label connected components per unique value.
84+
85+
Returns
86+
-------
87+
region_map : ndarray of int32
88+
Each pixel mapped to its region id (0 = nodata).
89+
region_val : ndarray of float64
90+
Original raster value for each region id.
91+
n_total : int
92+
Total number of regions + 1 (length of *region_val*).
93+
"""
94+
from scipy.ndimage import label
95+
96+
unique_vals = np.unique(result[valid])
97+
region_map = np.zeros(result.shape, dtype=np.int32)
98+
region_val_list: list[float] = [np.nan] # id 0 = nodata
99+
uid = 1
100+
101+
for v in unique_vals:
102+
mask = (result == v) & valid
103+
labeled, n_features = label(mask, structure=structure)
104+
if n_features > 0:
105+
nonzero = labeled > 0
106+
region_map[nonzero] = labeled[nonzero] + (uid - 1)
107+
region_val_list.extend([float(v)] * n_features)
108+
uid += n_features
109+
110+
region_val = np.array(region_val_list, dtype=np.float64)
111+
return region_map, region_val, uid
112+
113+
114+
def _sieve_numpy(data, threshold, neighborhood, skip_values):
    """Replace connected regions smaller than *threshold* pixels.

    Iteratively labels same-value regions, then merges each
    sub-threshold region into its largest spatial neighbor. Repeats —
    relabeling from scratch each pass — until no eligible small region
    remains, capped at 50 passes.

    Parameters
    ----------
    data : ndarray
        2D raster; the result is always a float64 copy.
    threshold : int
        Minimum surviving region size in pixels.
    neighborhood : int
        4 (rook) or 8 (queen) connectivity.
    skip_values : sequence of float or None
        Raster values whose regions are never merged away.

    Returns
    -------
    ndarray of float64
        Sieved copy of *data*; NaN (nodata) pixels are untouched.
    """
    # scipy.ndimage.label structuring element: cross for rook moves,
    # full 3x3 block for queen moves.
    structure = (
        np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]])
        if neighborhood == 4
        else np.ones((3, 3), dtype=int)
    )

    result = data.astype(np.float64, copy=True)
    is_float = np.issubdtype(data.dtype, np.floating)
    # NaN marks nodata for float input; integer rasters are fully valid.
    valid = ~np.isnan(result) if is_float else np.ones(result.shape, dtype=bool)
    skip_set = set(skip_values) if skip_values is not None else set()

    for _ in range(50):  # convergence limit: cap on full relabel passes
        region_map, region_val, uid = _label_all_regions(
            result, valid, structure
        )
        # Pixel count per region id (index 0 counts nodata pixels).
        region_size = np.bincount(
            region_map.ravel(), minlength=uid
        ).astype(np.int64)

        # Identify small regions eligible for merging
        small_ids = [
            rid
            for rid in range(1, uid)
            if region_size[rid] < threshold
            and region_val[rid] not in skip_set
        ]
        if not small_ids:
            break

        adjacency = _build_adjacency(region_map, neighborhood)

        # Process smallest regions first so they merge into larger neighbors
        small_ids.sort(key=lambda r: region_size[r])

        merged_any = False
        for rid in small_ids:
            # Skip regions already absorbed (size 0) or grown past the
            # threshold by earlier merges within this same pass.
            if region_size[rid] == 0 or region_size[rid] >= threshold:
                continue

            neighbors = adjacency.get(rid)
            if not neighbors:
                continue  # surrounded by nodata only

            # Merge into whichever neighbor is currently largest.
            largest_nid = max(neighbors, key=lambda n: region_size[n])
            mask = region_map == rid
            result[mask] = region_val[largest_nid]

            # Update tracking in place
            region_map[mask] = largest_nid
            region_size[largest_nid] += region_size[rid]
            region_size[rid] = 0

            # Rewire the adjacency graph: everything that touched rid now
            # touches largest_nid, and rid disappears from the graph.
            for n in neighbors:
                if n != largest_nid:
                    adjacency[n].discard(rid)
                    adjacency[n].add(largest_nid)
                    adjacency.setdefault(largest_nid, set()).add(n)
            if largest_nid in adjacency:
                adjacency[largest_nid].discard(rid)
            del adjacency[rid]
            merged_any = True

        if not merged_any:
            break

    return result
182+
183+
184+
# ---------------------------------------------------------------------------
185+
# cupy backend (CPU fallback – merge logic is serial)
186+
# ---------------------------------------------------------------------------
187+
188+
189+
def _sieve_cupy(data, threshold, neighborhood, skip_values):
    """CuPy backend.

    The merge step is inherently serial, so the array is moved to host
    memory, sieved with the numpy backend, and copied back to the GPU.
    """
    import cupy as cp

    host_result = _sieve_numpy(data.get(), threshold, neighborhood, skip_values)
    return cp.asarray(host_result)
195+
196+
197+
# ---------------------------------------------------------------------------
198+
# dask backends
199+
# ---------------------------------------------------------------------------
200+
201+
202+
def _available_memory_bytes():
203+
"""Best-effort estimate of available memory in bytes."""
204+
try:
205+
with open("/proc/meminfo", "r") as f:
206+
for line in f:
207+
if line.startswith("MemAvailable:"):
208+
return int(line.split()[1]) * 1024
209+
except (OSError, ValueError, IndexError):
210+
pass
211+
try:
212+
import psutil
213+
214+
return psutil.virtual_memory().available
215+
except (ImportError, AttributeError):
216+
pass
217+
return 2 * 1024**3
218+
219+
220+
def _sieve_dask(data, threshold, neighborhood, skip_values):
    """Dask+numpy backend: materialize, sieve on CPU, re-wrap as dask.

    Labeling is a global operation, so the full raster must be pulled
    into memory before sieving; the result is re-chunked to match the
    input's chunk layout.
    """
    budget = _available_memory_bytes()
    # NOTE(review): the x5 factor presumably budgets for the float64 copy
    # and labeling scratch arrays — confirm if the estimate drifts.
    needed = np.prod(data.shape) * data.dtype.itemsize
    if needed * 5 > 0.5 * budget:
        raise MemoryError(
            f"sieve() needs the full array in memory "
            f"(~{needed * 5 / 1e9:.1f} GB) but only "
            f"~{budget / 1e9:.1f} GB is available. Connected-component "
            f"labeling is a global operation that cannot be chunked. "
            f"Consider downsampling or tiling the input manually."
        )

    sieved = _sieve_numpy(data.compute(), threshold, neighborhood, skip_values)
    return da.from_array(sieved, chunks=data.chunks)
236+
237+
238+
def _sieve_dask_cupy(data, threshold, neighborhood, skip_values):
    """Dask+CuPy backend: materialize on GPU, sieve via CPU fallback."""
    needed = np.prod(data.shape) * data.dtype.itemsize
    # Pre-flight GPU memory check. Only ImportError/AttributeError are
    # swallowed — a raised MemoryError propagates to the caller.
    try:
        import cupy as cp

        free_gpu, _total = cp.cuda.Device().mem_info
        if needed * 5 > 0.5 * free_gpu:
            raise MemoryError(
                f"sieve() needs the full array on GPU "
                f"(~{needed * 5 / 1e9:.1f} GB) but only "
                f"~{free_gpu / 1e9:.1f} GB free. Connected-component "
                f"labeling is a global operation that cannot be chunked. "
                f"Consider downsampling or tiling the input manually."
            )
    except (ImportError, AttributeError):
        pass

    gpu_data = data.compute()
    sieved = _sieve_cupy(gpu_data, threshold, neighborhood, skip_values)
    return da.from_array(sieved, chunks=data.chunks)
259+
260+
261+
# ---------------------------------------------------------------------------
262+
# Public API
263+
# ---------------------------------------------------------------------------
264+
265+
266+
def sieve(
    raster: xr.DataArray,
    threshold: int = 10,
    neighborhood: int = 4,
    skip_values: Sequence[float] | None = None,
    name: str = "sieve",
) -> xr.DataArray:
    """Remove small connected regions from a classified raster.

    Finds connected components of equal-valued pixels and overwrites
    every region smaller than *threshold* pixels with the value of its
    largest spatial neighbor. NaN pixels are always preserved.

    Parameters
    ----------
    raster : xr.DataArray
        2D classified or categorical raster.
    threshold : int, default=10
        Minimum region size in pixels; smaller regions are replaced
        by their largest neighbor's value.
    neighborhood : int, default=4
        Pixel connectivity: 4 (rook) or 8 (queen).
    skip_values : sequence of float, optional
        Category values whose regions are never replaced, whatever
        their size. Such regions can still absorb neighboring small
        regions as merge targets.
    name : str, default='sieve'
        Output DataArray name.

    Returns
    -------
    xr.DataArray
        Sieved raster with the same shape, dims, coords, and attrs.

    Examples
    --------
    .. sourcecode:: python

        >>> import numpy as np
        >>> import xarray as xr
        >>> from xrspatial.sieve import sieve

        >>> # Classified raster with salt-and-pepper noise
        >>> arr = np.array([[1, 1, 1, 2, 2],
        ...                 [1, 3, 1, 2, 2],
        ...                 [1, 1, 1, 2, 2],
        ...                 [2, 2, 2, 2, 2],
        ...                 [2, 2, 2, 2, 2]], dtype=np.float64)
        >>> raster = xr.DataArray(arr, dims=['y', 'x'])

        >>> # Remove regions smaller than 2 pixels
        >>> result = sieve(raster, threshold=2)
        >>> print(result.values)
        [[1. 1. 1. 2. 2.]
         [1. 1. 1. 2. 2.]
         [1. 1. 1. 2. 2.]
         [2. 2. 2. 2. 2.]
         [2. 2. 2. 2. 2.]]

    Notes
    -----
    This is a global operation: for dask-backed arrays the entire
    raster is computed into memory before sieving, because regions may
    span chunk boundaries and component labeling cannot be done
    chunk-wise.

    The CuPy backends fall back to the CPU for the merge step, which
    is inherently serial.

    See Also
    --------
    xrspatial.zonal.regions : Connected-component labeling.
    xrspatial.classify.natural_breaks : Classification that may produce
        noisy output suitable for sieving.
    """
    _validate_raster(raster, func_name="sieve", name="raster", ndim=2)

    if neighborhood not in (4, 8):
        raise ValueError("`neighborhood` must be 4 or 8")
    if not isinstance(threshold, (int, np.integer)) or threshold < 1:
        raise ValueError("`threshold` must be a positive integer")

    data = raster.data
    args = (data, threshold, neighborhood, skip_values)

    if isinstance(data, np.ndarray):
        sieved = _sieve_numpy(*args)
    elif has_cuda_and_cupy() and is_cupy_array(data):
        sieved = _sieve_cupy(*args)
    elif da is not None and isinstance(data, da.Array):
        # NOTE(review): is_dask_cupy receives the DataArray while the other
        # predicates receive the bare array — assumed intentional per
        # xrspatial.utils; confirm against that helper's signature.
        backend = _sieve_dask_cupy if is_dask_cupy(raster) else _sieve_dask
        sieved = backend(*args)
    else:
        raise TypeError(
            f"Unsupported array type {type(data).__name__} for sieve()"
        )

    return DataArray(
        sieved,
        name=name,
        dims=raster.dims,
        coords=raster.coords,
        attrs=raster.attrs,
    )

0 commit comments

Comments
 (0)