Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .claude/commands/rockout.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ through all seven steps below. The prompt is: $ARGUMENTS
2. Pick labels from the repo's existing set. Always include the type label
(`enhancement`, `bug`, or `proposal`). Add topical labels when they fit
(e.g. `gpu`, `performance`, `focal tools`, `hydrology`, etc.).
3. Draft the title and body. Use the repo's issue templates as structure guides:
3. Draft the title and body. Use the repo's issue templates as structure guides
(skip the "Author of Proposal" field -- GitHub already shows the author):
- Enhancement/proposal: follow `.github/ISSUE_TEMPLATE/feature-proposal.md`
- Bug: follow `.github/ISSUE_TEMPLATE/bug_report.md`
4. **Run the body text through the `/humanizer` skill** before creating the issue
Expand Down
1 change: 0 additions & 1 deletion .github/ISSUE_TEMPLATE/feature-proposal.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ assignees: ''

---

**Author of Proposal:**
## Reason or Problem
Describe what the need for this new feature is or what problem this new feature will address.
## Proposal
Expand Down
164 changes: 84 additions & 80 deletions xrspatial/sieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@

Given a categorical raster and a pixel-count threshold, replaces
connected regions smaller than the threshold with the value of
their largest spatial neighbor. Pairs with classification functions
(``natural_breaks``, ``reclassify``, etc.) and ``polygonize`` for
cleaning results before vectorization.
their largest spatial neighbor that is already at or above the
threshold. Matches the single-pass semantics of GDAL's
``GDALSieveFilter`` / ``rasterio.features.sieve``.

Pairs with classification functions (``natural_breaks``,
``reclassify``, etc.) and ``polygonize`` for cleaning results
before vectorization.

Supports all four backends: numpy, cupy, dask+numpy, dask+cupy.
"""

from __future__ import annotations

import warnings
from collections import defaultdict
from typing import Sequence

Expand Down Expand Up @@ -40,7 +43,6 @@ class cupy:
ngjit,
)

_MAX_ITERATIONS = 50


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -205,67 +207,75 @@ def _collect(a, b):


def _sieve_numpy(data, threshold, neighborhood, skip_values):
"""Replace connected regions smaller than *threshold* pixels."""
"""Single-pass sieve matching GDAL's ``GDALSieveFilter`` semantics.

A small region is only merged into a neighbor whose size is
**>= threshold**. If no such neighbor exists the region stays.
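Regions are processed smallest-first with in-place size and
adjacency updates. A merge target is always already >= threshold,
so an earlier merge never lifts a region over the threshold; what
it can do is make that large region adjacent to a later candidate
within the same pass.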
"""
result = data.astype(np.float64, copy=True)
is_float = np.issubdtype(data.dtype, np.floating)
valid = ~np.isnan(result) if is_float else np.ones(result.shape, dtype=bool)
skip_set = set(skip_values) if skip_values is not None else set()

for _ in range(_MAX_ITERATIONS):
region_map, region_val, uid = _label_connected(
result, valid, neighborhood
)
region_size = np.bincount(
region_map.ravel(), minlength=uid
).astype(np.int64)

# Identify small regions eligible for merging
small_ids = [
rid
for rid in range(1, uid)
if region_size[rid] < threshold
and region_val[rid] not in skip_set
]
if not small_ids:
return result, True

adjacency = _build_adjacency(region_map, neighborhood)

# Process smallest regions first so they merge into larger neighbors
small_ids.sort(key=lambda r: region_size[r])

merged_any = False
for rid in small_ids:
if region_size[rid] == 0 or region_size[rid] >= threshold:
continue
region_map, region_val, uid = _label_connected(
result, valid, neighborhood
)
region_size = np.bincount(
region_map.ravel(), minlength=uid
).astype(np.int64)

small_ids = [
rid
for rid in range(1, uid)
if region_size[rid] < threshold
and region_val[rid] not in skip_set
]
if not small_ids:
return result

adjacency = _build_adjacency(region_map, neighborhood)

# Process smallest regions first: each merge rewires adjacency, so a
# later candidate can become adjacent to an above-threshold region.
small_ids.sort(key=lambda r: region_size[r])

for rid in small_ids:
if region_size[rid] == 0 or region_size[rid] >= threshold:
continue

neighbors = adjacency.get(rid)
if not neighbors:
continue # surrounded by nodata only
neighbors = adjacency.get(rid)
if not neighbors:
continue # surrounded by nodata only

largest_nid = max(neighbors, key=lambda n: region_size[n])
mask = region_map == rid
result[mask] = region_val[largest_nid]
# Only merge into a neighbor that is already >= threshold.
valid_neighbors = [
n for n in neighbors if region_size[n] >= threshold
]
if not valid_neighbors:
continue

# Update tracking in place
region_map[mask] = largest_nid
region_size[largest_nid] += region_size[rid]
region_size[rid] = 0
largest_nid = max(valid_neighbors, key=lambda n: region_size[n])
mask = region_map == rid
result[mask] = region_val[largest_nid]

for n in neighbors:
if n != largest_nid:
adjacency[n].discard(rid)
adjacency[n].add(largest_nid)
adjacency.setdefault(largest_nid, set()).add(n)
if largest_nid in adjacency:
adjacency[largest_nid].discard(rid)
del adjacency[rid]
merged_any = True
# Update tracking in place
region_map[mask] = largest_nid
region_size[largest_nid] += region_size[rid]
region_size[rid] = 0

if not merged_any:
return result, True
for n in neighbors:
if n != largest_nid:
adjacency[n].discard(rid)
adjacency[n].add(largest_nid)
adjacency.setdefault(largest_nid, set()).add(n)
if largest_nid in adjacency:
adjacency[largest_nid].discard(rid)
del adjacency[rid]

return result, False
return result


# ---------------------------------------------------------------------------
Expand All @@ -277,10 +287,10 @@ def _sieve_cupy(data, threshold, neighborhood, skip_values):
"""CuPy backend: transfer to CPU, sieve, transfer back."""
import cupy as cp

np_result, converged = _sieve_numpy(
np_result = _sieve_numpy(
data.get(), threshold, neighborhood, skip_values
)
return cp.asarray(np_result), converged
return cp.asarray(np_result)


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -320,10 +330,10 @@ def _sieve_dask(data, threshold, neighborhood, skip_values):
)

np_data = data.compute()
result, converged = _sieve_numpy(
result = _sieve_numpy(
np_data, threshold, neighborhood, skip_values
)
return da.from_array(result, chunks=data.chunks), converged
return da.from_array(result, chunks=data.chunks)


def _sieve_dask_cupy(data, threshold, neighborhood, skip_values):
Expand All @@ -345,10 +355,10 @@ def _sieve_dask_cupy(data, threshold, neighborhood, skip_values):
pass

cp_data = data.compute()
result, converged = _sieve_cupy(
result = _sieve_cupy(
cp_data, threshold, neighborhood, skip_values
)
return da.from_array(result, chunks=data.chunks), converged
return da.from_array(result, chunks=data.chunks)


# ---------------------------------------------------------------------------
Expand All @@ -367,7 +377,10 @@ def sieve(

Identifies connected components of same-value pixels and replaces
regions smaller than *threshold* pixels with the value of their
largest spatial neighbor. NaN pixels are always preserved.
largest spatial neighbor that is already at or above *threshold*.
Regions whose only neighbors are also below *threshold* at the time
they are processed are left unchanged, matching GDAL's single-pass
semantics. NaN pixels are always preserved.

Parameters
----------
Expand Down Expand Up @@ -417,6 +430,11 @@ def sieve(

Notes
-----
Uses single-pass semantics matching GDAL's ``GDALSieveFilter``.
A small region is only merged into a neighbor whose current size
is >= *threshold*. If no such neighbor exists the region is left
unchanged.

This is a global operation: for dask-backed arrays the entire raster
is computed into memory before sieving. Connected-component labeling
cannot be performed on individual chunks because regions may span
Expand All @@ -442,35 +460,21 @@ def sieve(
data = raster.data

if isinstance(data, np.ndarray):
out, converged = _sieve_numpy(
data, threshold, neighborhood, skip_values
)
out = _sieve_numpy(data, threshold, neighborhood, skip_values)
elif has_cuda_and_cupy() and is_cupy_array(data):
out, converged = _sieve_cupy(
data, threshold, neighborhood, skip_values
)
out = _sieve_cupy(data, threshold, neighborhood, skip_values)
elif da is not None and isinstance(data, da.Array):
if is_dask_cupy(raster):
out, converged = _sieve_dask_cupy(
out = _sieve_dask_cupy(
data, threshold, neighborhood, skip_values
)
else:
out, converged = _sieve_dask(
data, threshold, neighborhood, skip_values
)
out = _sieve_dask(data, threshold, neighborhood, skip_values)
else:
raise TypeError(
f"Unsupported array type {type(data).__name__} for sieve()"
)

if not converged:
warnings.warn(
f"sieve() did not converge after {_MAX_ITERATIONS} iterations. "
f"The result may still contain regions smaller than "
f"threshold={threshold}.",
stacklevel=2,
)

return DataArray(
out,
name=name,
Expand Down
40 changes: 18 additions & 22 deletions xrspatial/tests/test_sieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,12 @@ def test_sieve_four_connectivity(backend):
dtype=np.float64,
)
raster = _make_raster(arr, backend)
# With 4-connectivity: each 1 and 2 forms its own 1-pixel region
# except center which is 1 pixel. All regions are size 1.
# threshold=2 should merge them all.
# With 4-connectivity each pixel is its own 1-pixel region.
# All regions are below threshold=2 and no neighbor is >= 2,
# so nothing merges (GDAL single-pass semantics).
result = sieve(raster, threshold=2, neighborhood=4)
data = _to_numpy(result)
# All pixels should end up the same value (merged into one)
assert len(np.unique(data)) == 1
np.testing.assert_array_equal(data, arr)


@pytest.mark.parametrize("backend", ["numpy", "dask+numpy"])
Expand Down Expand Up @@ -425,29 +424,26 @@ def test_sieve_numpy_dask_match():


# ---------------------------------------------------------------------------
# Convergence warning
# Single-pass: small regions with no above-threshold neighbor stay
# ---------------------------------------------------------------------------


def test_sieve_convergence_warning():
"""Should warn when the iteration limit is reached."""
from unittest.mock import patch

# Create a raster where merging is artificially stalled by
# patching _MAX_ITERATIONS to 0 so the loop never runs.
def test_sieve_small_region_no_large_neighbor():
"""A small region whose only neighbors are also small stays unchanged."""
arr = np.array(
[
[1, 1, 1],
[1, 2, 1],
[1, 1, 1],
[1, 1, 2, 2],
[1, 1, 2, 2],
[3, 3, 4, 4],
[3, 3, 4, 4],
],
dtype=np.float64,
)
raster = _make_raster(arr, "numpy")

with patch("xrspatial.sieve._MAX_ITERATIONS", 0):
with pytest.warns(UserWarning, match="did not converge"):
sieve(raster, threshold=2)
# All regions are size 4, threshold=5: no neighbor is >= 5.
result = sieve(raster, threshold=5)
data = _to_numpy(result)
np.testing.assert_array_equal(data, arr)


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -486,7 +482,7 @@ def test_sieve_noisy_classification(backend):

@pytest.mark.parametrize("backend", ["numpy", "dask+numpy"])
def test_sieve_many_small_regions(backend):
"""Checkerboard produces maximum region count; sieve should unify."""
"""Checkerboard: all regions size 1, no neighbor >= threshold."""
# 20x20 checkerboard: every pixel is its own 1-pixel region
arr = np.zeros((20, 20), dtype=np.float64)
arr[::2, ::2] = 1
Expand All @@ -498,5 +494,5 @@ def test_sieve_many_small_regions(backend):
data = _to_numpy(result)

# With 4-connectivity every pixel is isolated (size 1).
# threshold=2 forces all to merge. Result should be uniform.
assert len(np.unique(data)) == 1
# No neighbor is >= threshold=2, so nothing merges (GDAL semantics).
np.testing.assert_array_equal(data, arr)
Loading
Loading