Skip to content

Commit cffcc4e

Browse files
authored
Replace O(n⁴) regions() with scipy union-find, add dask/cupy backends (#898)
The old _area_connectivity algorithm did full-array scans inside the main pixel loop for conflict resolution, giving O(n⁴) worst case on an n×n raster. Replace it with scipy.ndimage.label (union-find, roughly linear in the number of pixels) run once per unique value, and add backend dispatch for dask, cupy, and dask+cupy arrays.
1 parent 05fc5f0 commit cffcc4e

File tree

4 files changed

+181
-146
lines changed

4 files changed

+181
-146
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ In the GIS world, rasters are used for representing continuous phenomena (e.g. e
238238
|:----------:|:------------|:----------------------:|:--------------------:|:-------------------:|:------:|
239239
| [Apply](xrspatial/zonal.py) | Applies a custom function to each zone in a classified raster | ✅️ | ✅️ | | |
240240
| [Crop](xrspatial/zonal.py) | Extracts the bounding rectangle of a specific zone | ✅️ | | | |
241-
| [Regions](xrspatial/zonal.py) | Identifies connected regions of non-zero cells | | | | |
241+
| [Regions](xrspatial/zonal.py) | Identifies connected regions of non-zero cells | ✅️ | ✅️ | ✅️ | ✅️ |
242242
| [Trim](xrspatial/zonal.py) | Removes nodata border rows and columns from a raster | ✅️ | | | |
243243
| [Zonal Statistics](xrspatial/zonal.py) | Computes summary statistics for a value raster within each zone | ✅️ | ✅️| | |
244244
| [Zonal Cross Tabulate](xrspatial/zonal.py) | Cross-tabulates agreement between two categorical rasters | ✅️ | ✅️| | |

setup.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ include_package_data = True
2121
install_requires =
2222
datashader >= 0.15.0
2323
numba
24+
scipy
2425
xarray
2526
numpy
2627
packages = find:

xrspatial/tests/test_zonal.py

Lines changed: 81 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -996,54 +996,123 @@ def create_test_arr(arr):
996996
return raster
997997

998998

999-
def test_regions_four_pixel_connectivity_int():
999+
def _make_regions_raster(arr, backend):
1000+
"""Create a test raster from *arr* for the given backend."""
1001+
raster = create_test_raster(arr, backend)
1002+
return raster
1003+
1004+
1005+
def _count_unique(raster_regions):
1006+
"""Count unique values in a regions result, computing dask if needed."""
1007+
data = raster_regions.data
1008+
if da is not None and isinstance(data, da.Array):
1009+
data = data.compute()
1010+
return len(np.unique(data))
1011+
1012+
1013+
@pytest.mark.parametrize("backend", ['numpy', 'dask+numpy'])
1014+
def test_regions_four_pixel_connectivity_int(backend):
10001015
arr = np.array([[0, 0, 0, 0],
10011016
[0, 4, 0, 0],
10021017
[1, 4, 4, 0],
10031018
[1, 1, 1, 0],
10041019
[0, 0, 0, 0]], dtype=np.int64)
1005-
raster = create_test_arr(arr)
1020+
raster = _make_regions_raster(arr, backend)
10061021
raster_regions = regions(raster, neighborhood=4)
1007-
assert len(np.unique(raster_regions.data)) == 3
1022+
assert _count_unique(raster_regions) == 3
10081023
assert raster.shape == raster_regions.shape
10091024

10101025

1011-
def test_regions_four_pixel_connectivity_float():
1026+
@pytest.mark.parametrize("backend", ['numpy', 'dask+numpy'])
1027+
def test_regions_four_pixel_connectivity_float(backend):
10121028
arr = np.array([[0, 0, 0, np.nan],
10131029
[0, 4, 0, 0],
10141030
[1, 4, 4, 0],
10151031
[1, 1, 1, 0],
10161032
[0, 0, 0, 0]], dtype=np.float64)
1017-
raster = create_test_arr(arr)
1033+
raster = _make_regions_raster(arr, backend)
10181034
raster_regions = regions(raster, neighborhood=4)
1019-
assert len(np.unique(raster_regions.data)) == 4
1035+
assert _count_unique(raster_regions) == 4
10201036
assert raster.shape == raster_regions.shape
10211037

10221038

1023-
def test_regions_eight_pixel_connectivity_int():
1039+
@pytest.mark.parametrize("backend", ['numpy', 'dask+numpy'])
1040+
def test_regions_eight_pixel_connectivity_int(backend):
10241041
arr = np.array([[1, 0, 0, 0],
10251042
[0, 1, 0, 0],
10261043
[0, 0, 1, 0],
10271044
[0, 0, 0, 1],
10281045
[0, 0, 0, 1]], dtype=np.int64)
1029-
raster = create_test_arr(arr)
1046+
raster = _make_regions_raster(arr, backend)
10301047
raster_regions = regions(raster, neighborhood=8)
1031-
assert len(np.unique(raster_regions.data)) == 2
1048+
assert _count_unique(raster_regions) == 2
10321049
assert raster.shape == raster_regions.shape
10331050

10341051

1035-
def test_regions_eight_pixel_connectivity_float():
1052+
@pytest.mark.parametrize("backend", ['numpy', 'dask+numpy'])
1053+
def test_regions_eight_pixel_connectivity_float(backend):
10361054
arr = np.array([[1, 0, 0, np.nan],
10371055
[0, 1, 0, 0],
10381056
[0, 0, 1, 0],
10391057
[0, 0, 0, 1],
10401058
[0, 0, 0, 1]], dtype=np.float64)
1041-
raster = create_test_arr(arr)
1059+
raster = _make_regions_raster(arr, backend)
10421060
raster_regions = regions(raster, neighborhood=8)
1043-
assert len(np.unique(raster_regions.data)) == 3
1061+
assert _count_unique(raster_regions) == 3
1062+
assert raster.shape == raster_regions.shape
1063+
1064+
1065+
@pytest.mark.parametrize("backend", ['numpy', 'dask+numpy'])
1066+
def test_regions_single_pixel(backend):
1067+
arr = np.array([[np.nan, np.nan],
1068+
[np.nan, 5.0]], dtype=np.float64)
1069+
raster = _make_regions_raster(arr, backend)
1070+
raster_regions = regions(raster, neighborhood=4)
1071+
data = raster_regions.data
1072+
if da is not None and isinstance(data, da.Array):
1073+
data = data.compute()
1074+
assert np.nansum(data > 0) == 1
1075+
assert raster.shape == raster_regions.shape
1076+
1077+
1078+
@pytest.mark.parametrize("backend", ['numpy', 'dask+numpy'])
1079+
def test_regions_all_same_value(backend):
1080+
arr = np.full((4, 4), 7.0, dtype=np.float64)
1081+
raster = _make_regions_raster(arr, backend)
1082+
raster_regions = regions(raster, neighborhood=4)
1083+
assert _count_unique(raster_regions) == 1
1084+
assert raster.shape == raster_regions.shape
1085+
1086+
1087+
@pytest.mark.parametrize("backend", ['numpy', 'dask+numpy'])
1088+
def test_regions_all_nan(backend):
1089+
arr = np.full((3, 3), np.nan, dtype=np.float64)
1090+
raster = _make_regions_raster(arr, backend)
1091+
raster_regions = regions(raster, neighborhood=4)
1092+
data = raster_regions.data
1093+
if da is not None and isinstance(data, da.Array):
1094+
data = data.compute()
1095+
assert np.all(np.isnan(data))
10441096
assert raster.shape == raster_regions.shape
10451097

10461098

1099+
@pytest.mark.parametrize("backend", ['numpy', 'dask+numpy'])
1100+
def test_regions_numpy_dask_match(backend):
1101+
"""Verify each backend yields the expected number of connected regions."""
1102+
arr = np.array([[1, 1, 0, 2],
1103+
[1, 1, 0, 2],
1104+
[0, 0, 0, 0],
1105+
[3, 3, 0, 3]], dtype=np.float64)
1106+
raster = _make_regions_raster(arr, backend)
1107+
result = regions(raster, neighborhood=4)
1108+
data = result.data
1109+
if da is not None and isinstance(data, da.Array):
1110+
data = data.compute()
1111+
# Five regions expected: one connected 0-region, one 1-region, one 2-region, and two separate 3-regions
1112+
assert _count_unique(result) == 5
1113+
assert result.shape == arr.shape
1114+
1115+
10471116
def test_trim():
10481117
arr = np.array([[0, 0, 0, 0],
10491118
[0, 4, 0, 0],

0 commit comments

Comments
 (0)