Skip to content

Commit a370d56

Browse files
authored
Add 3D multi-band support to focal functions (#924)
* Add 3D multi-band support to focal functions (mean, apply, focal_stats, hotspots) Slices along the first dimension, applies the 2D focal op per band, and concatenates results. Fixes _apply_per_band parameter name collision where `func` was passed both positionally and as a keyword argument. Removes stale validation tests that expected 3D rejection. * Mark GPU CPU-fallback backends in feature matrix Flag 17 cells across classify, polygonize, and hydrology sections where CuPy or Dask+CuPy backends accept GPU input but convert to numpy internally rather than running native GPU kernels.
1 parent 4b599fc commit a370d56

File tree

4 files changed

+166
-26
lines changed

4 files changed

+166
-26
lines changed

README.md

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -139,13 +139,13 @@ In the GIS world, rasters are used for representing continuous phenomena (e.g. e
139139

140140
| Name | Description | NumPy xr.DataArray | Dask xr.DataArray | CuPy GPU xr.DataArray | Dask GPU xr.DataArray |
141141
|:----------:|:------------|:----------------------:|:--------------------:|:-------------------:|:------:|
142-
| [Box Plot](xrspatial/classify.py) | Classifies values into bins based on box plot quartile boundaries | ✅️ ||| |
142+
| [Box Plot](xrspatial/classify.py) | Classifies values into bins based on box plot quartile boundaries | ✅️ ||| 🔄 |
143143
| [Equal Interval](xrspatial/classify.py) | Divides the value range into equal-width bins | ✅️ ||||
144-
| [Head/Tail Breaks](xrspatial/classify.py) | Classifies heavy-tailed distributions using recursive mean splitting | ✅️ || | |
145-
| [Maximum Breaks](xrspatial/classify.py) | Finds natural groupings by maximizing differences between sorted values | ✅️ || | |
146-
| [Natural Breaks](xrspatial/classify.py) | Optimizes class boundaries to minimize within-class variance (Jenks) | ✅️ || | |
147-
| [Percentiles](xrspatial/classify.py) | Assigns classes based on user-defined percentile breakpoints | ✅️ ||| |
148-
| [Quantile](xrspatial/classify.py) | Distributes values into classes with equal observation counts | ✅️ ||| |
144+
| [Head/Tail Breaks](xrspatial/classify.py) | Classifies heavy-tailed distributions using recursive mean splitting | ✅️ || 🔄 | 🔄 |
145+
| [Maximum Breaks](xrspatial/classify.py) | Finds natural groupings by maximizing differences between sorted values | ✅️ || 🔄 | 🔄 |
146+
| [Natural Breaks](xrspatial/classify.py) | Optimizes class boundaries to minimize within-class variance (Jenks) | ✅️ || 🔄 | 🔄 |
147+
| [Percentiles](xrspatial/classify.py) | Assigns classes based on user-defined percentile breakpoints | ✅️ ||| 🔄 |
148+
| [Quantile](xrspatial/classify.py) | Distributes values into classes with equal observation counts | ✅️ ||| 🔄 |
149149
| [Reclassify](xrspatial/classify.py) | Remaps pixel values to new classes using a user-defined lookup | ✅️ ||||
150150
| [Std Mean](xrspatial/classify.py) | Classifies values by standard deviation intervals from the mean | ✅️ ||||
151151

@@ -216,7 +216,7 @@ In the GIS world, rasters are used for representing continuous phenomena (e.g. e
216216

217217
| Name | Description | NumPy xr.DataArray | Dask xr.DataArray | CuPy GPU xr.DataArray | Dask GPU xr.DataArray |
218218
|:-----|:------------|:------------------:|:-----------------:|:---------------------:|:---------------------:|
219-
| [Polygonize](xrspatial/polygonize.py) | Converts contiguous regions of equal value into vector polygons | ✅️ | ✅️ | ✅️ | ✅️ |
219+
| [Polygonize](xrspatial/polygonize.py) | Converts contiguous regions of equal value into vector polygons | ✅️ | ✅️ | ✅️ | 🔄 |
220220

221221
--------
222222

@@ -245,13 +245,13 @@ In the GIS world, rasters are used for representing continuous phenomena (e.g. e
245245
|:----------:|:------------|:----------------------:|:--------------------:|:-------------------:|:------:|
246246
| [Flow Direction (D8)](xrspatial/flow_direction.py) | Computes D8 flow direction from each cell toward the steepest downhill neighbor | ✅️ | ✅️ | ✅️ | ✅️ |
247247
| [Flow Direction (Dinf)](xrspatial/flow_direction_dinf.py) | Computes D-infinity flow direction as a continuous angle toward the steepest downslope facet | ✅️ | ✅️ | ✅️ | ✅️ |
248-
| [Flow Accumulation (D8)](xrspatial/flow_accumulation.py) | Counts upstream cells draining through each cell in a D8 flow direction grid | ✅️ | ✅️ | ✅️ | ✅️ |
249-
| [Watershed](xrspatial/watershed.py) | Labels each cell with the pour point it drains to via D8 flow direction | ✅️ | ✅️ | ✅️ | ✅️ |
250-
| [Basins](xrspatial/watershed.py) | Delineates drainage basins by labeling each cell with its outlet ID | ✅️ | ✅️ | ✅️ | ✅️ |
251-
| [Stream Order](xrspatial/stream_order.py) | Assigns Strahler or Shreve stream order to cells in a drainage network | ✅️ | ✅️ | ✅️ | ✅️ |
252-
| [Stream Link](xrspatial/stream_link.py) | Assigns unique IDs to each stream segment between junctions | ✅️ | ✅️ | ✅️ | ✅️ |
253-
| [Snap Pour Point](xrspatial/snap_pour_point.py) | Snaps pour points to the highest-accumulation cell within a search radius | ✅️ | ✅️ | ✅️ | ✅️ |
254-
| [Flow Path](xrspatial/flow_path.py) | Traces downstream flow paths from start points through a D8 direction grid | ✅️ | ✅️ | ✅️ | ✅️ |
248+
| [Flow Accumulation (D8)](xrspatial/flow_accumulation.py) | Counts upstream cells draining through each cell in a D8 flow direction grid | ✅️ | ✅️ | ✅️ | 🔄 |
249+
| [Watershed](xrspatial/watershed.py) | Labels each cell with the pour point it drains to via D8 flow direction | ✅️ | ✅️ | ✅️ | 🔄 |
250+
| [Basins](xrspatial/watershed.py) | Delineates drainage basins by labeling each cell with its outlet ID | ✅️ | ✅️ | ✅️ | 🔄 |
251+
| [Stream Order](xrspatial/stream_order.py) | Assigns Strahler or Shreve stream order to cells in a drainage network | ✅️ | ✅️ | ✅️ | 🔄 |
252+
| [Stream Link](xrspatial/stream_link.py) | Assigns unique IDs to each stream segment between junctions | ✅️ | ✅️ | ✅️ | 🔄 |
253+
| [Snap Pour Point](xrspatial/snap_pour_point.py) | Snaps pour points to the highest-accumulation cell within a search radius | ✅️ | ✅️ | 🔄 | 🔄 |
254+
| [Flow Path](xrspatial/flow_path.py) | Traces downstream flow paths from start points through a D8 direction grid | ✅️ | ✅️ | 🔄 | 🔄 |
255255

256256
-----------
257257

xrspatial/focal.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,21 @@ class cupy(object):
3535
cuda_args, ngjit, not_implemented_func)
3636
from xrspatial.dataset_support import supports_dataset
3737

38+
39+
def _apply_per_band(band_func, agg, *args, **kwargs):
40+
"""Apply a 2D focal function independently to each band of a 3D array.
41+
42+
Slices along the first dimension, calls *band_func* on each 2D slice,
43+
and stacks the results back together.
44+
"""
45+
band_dim = agg.dims[0]
46+
slices = []
47+
for i in range(agg.sizes[band_dim]):
48+
band = agg.isel({band_dim: i})
49+
slices.append(band_func(band, *args, **kwargs))
50+
return xr.concat(slices, dim=band_dim)
51+
52+
3853
# TODO: Make convolution more generic with numba first-class functions.
3954

4055

@@ -281,9 +296,14 @@ def mean(agg, passes=1, excludes=[np.nan], name='mean', boundary='nan'):
281296
Dimensions without coordinates: dim_0, dim_1
282297
"""
283298

284-
_validate_raster(agg, func_name='mean', name='agg')
299+
_validate_raster(agg, func_name='mean', name='agg', ndim=(2, 3))
285300
_validate_scalar(passes, func_name='mean', name='passes', dtype=int, min_val=1)
286301
_validate_boundary(boundary)
302+
303+
if agg.ndim == 3:
304+
return _apply_per_band(mean, agg, passes=passes, excludes=excludes,
305+
name=name, boundary=boundary)
306+
287307
out = agg.data.astype(float)
288308
for i in range(passes):
289309
out = _mean(out, tuple(excludes), boundary)
@@ -512,7 +532,11 @@ def apply(raster, kernel, func=_calc_mean, name='focal_apply', boundary='nan'):
512532
[2. , 2. , 2. , 1.5]])
513533
Dimensions without coordinates: y, x
514534
"""
515-
_validate_raster(raster, func_name='apply', name='raster')
535+
_validate_raster(raster, func_name='apply', name='raster', ndim=(2, 3))
536+
537+
if raster.ndim == 3:
538+
return _apply_per_band(apply, raster, kernel=kernel, func=func,
539+
name=name, boundary=boundary)
516540

517541
# Validate the kernel
518542
kernel = custom_kernel(kernel)
@@ -957,7 +981,11 @@ def focal_stats(agg,
957981
* stats (stats) object 'min' 'sum'
958982
Dimensions without coordinates: dim_0, dim_1
959983
"""
960-
_validate_raster(agg, func_name='focal_stats', name='agg')
984+
_validate_raster(agg, func_name='focal_stats', name='agg', ndim=(2, 3))
985+
986+
if agg.ndim == 3:
987+
return _apply_per_band(focal_stats, agg, kernel=kernel,
988+
stats_funcs=stats_funcs, boundary=boundary)
961989

962990
# Validate the kernel
963991
kernel = custom_kernel(kernel)
@@ -1237,7 +1265,11 @@ def hotspots(raster, kernel, boundary='nan'):
12371265
Dimensions without coordinates: dim_0, dim_1
12381266
"""
12391267

1240-
_validate_raster(raster, func_name='hotspots', name='raster')
1268+
_validate_raster(raster, func_name='hotspots', name='raster', ndim=(2, 3))
1269+
1270+
if raster.ndim == 3:
1271+
return _apply_per_band(hotspots, raster, kernel=kernel,
1272+
boundary=boundary)
12411273

12421274
_validate_boundary(boundary)
12431275

xrspatial/tests/test_focal.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -756,3 +756,119 @@ def test_convolution_2d_boundary_no_nan(boundary):
756756
assert not np.any(np.isnan(da_result.data.compute()))
757757
np.testing.assert_allclose(
758758
np_result.data, da_result.data.compute(), equal_nan=True, rtol=1e-5)
759+
760+
761+
# --- 3D (multi-band) focal tests ---
762+
763+
764+
@pytest.fixture
765+
def rgb_data():
766+
rng = np.random.default_rng(123)
767+
return rng.random((3, 12, 14)).astype(np.float64)
768+
769+
770+
def test_mean_3d_numpy(rgb_data):
771+
agg = xr.DataArray(rgb_data, dims=['band', 'y', 'x'])
772+
result = mean(agg)
773+
assert result.shape == (3, 12, 14)
774+
assert result.dims == ('band', 'y', 'x')
775+
for i in range(3):
776+
band_result = mean(agg.isel(band=i))
777+
np.testing.assert_allclose(result.isel(band=i).data, band_result.data)
778+
779+
780+
@dask_array_available
781+
def test_mean_3d_dask(rgb_data):
782+
dask_data = da.from_array(rgb_data, chunks=(1, 6, 7))
783+
agg = xr.DataArray(dask_data, dims=['band', 'y', 'x'])
784+
result = mean(agg)
785+
assert result.shape == (3, 12, 14)
786+
# compare against numpy per-band
787+
numpy_agg = xr.DataArray(rgb_data, dims=['band', 'y', 'x'])
788+
numpy_result = mean(numpy_agg)
789+
np.testing.assert_allclose(
790+
result.data.compute(), numpy_result.data, equal_nan=True, rtol=1e-5)
791+
792+
793+
def test_apply_3d_numpy(rgb_data):
794+
kernel = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]])
795+
agg = xr.DataArray(rgb_data, dims=['band', 'y', 'x'])
796+
result = apply(agg, kernel)
797+
assert result.shape == (3, 12, 14)
798+
assert result.dims == ('band', 'y', 'x')
799+
for i in range(3):
800+
band_result = apply(agg.isel(band=i), kernel)
801+
np.testing.assert_allclose(result.isel(band=i).data, band_result.data)
802+
803+
804+
@dask_array_available
805+
def test_apply_3d_dask(rgb_data):
806+
kernel = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]])
807+
dask_data = da.from_array(rgb_data, chunks=(1, 6, 7))
808+
agg = xr.DataArray(dask_data, dims=['band', 'y', 'x'])
809+
result = apply(agg, kernel)
810+
assert result.shape == (3, 12, 14)
811+
numpy_agg = xr.DataArray(rgb_data, dims=['band', 'y', 'x'])
812+
numpy_result = apply(numpy_agg, kernel)
813+
np.testing.assert_allclose(
814+
result.data.compute(), numpy_result.data, equal_nan=True, rtol=1e-5)
815+
816+
817+
def test_focal_stats_3d_numpy(rgb_data):
818+
kernel = custom_kernel(np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]]))
819+
stats = ['mean', 'max']
820+
agg = xr.DataArray(rgb_data, dims=['band', 'y', 'x'])
821+
result = focal_stats(agg, kernel, stats_funcs=stats)
822+
# 3D input -> 4D output: (band, stats, y, x)
823+
assert result.shape == (3, 2, 12, 14)
824+
for i in range(3):
825+
band_result = focal_stats(agg.isel(band=i), kernel, stats_funcs=stats)
826+
np.testing.assert_allclose(
827+
result.isel(band=i).data, band_result.data, equal_nan=True)
828+
829+
830+
@dask_array_available
831+
def test_focal_stats_3d_dask(rgb_data):
832+
kernel = custom_kernel(np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]]))
833+
stats = ['mean', 'max']
834+
dask_data = da.from_array(rgb_data, chunks=(1, 6, 7))
835+
agg = xr.DataArray(dask_data, dims=['band', 'y', 'x'])
836+
result = focal_stats(agg, kernel, stats_funcs=stats)
837+
assert result.shape == (3, 2, 12, 14)
838+
numpy_agg = xr.DataArray(rgb_data, dims=['band', 'y', 'x'])
839+
numpy_result = focal_stats(numpy_agg, kernel, stats_funcs=stats)
840+
np.testing.assert_allclose(
841+
result.data.compute(), numpy_result.data, equal_nan=True, rtol=1e-5)
842+
843+
844+
def test_hotspots_3d_numpy():
845+
rng = np.random.default_rng(42)
846+
data_2d = rng.standard_normal((10, 12)).astype(np.float64)
847+
# stack 3 copies with different scales to avoid zero-std bands
848+
data_3d = np.stack([data_2d, data_2d * 2, data_2d * 0.5])
849+
kernel = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]], dtype=np.float64)
850+
agg = xr.DataArray(data_3d, dims=['band', 'y', 'x'])
851+
result = hotspots(agg, kernel)
852+
assert result.shape == (3, 10, 12)
853+
assert result.dims == ('band', 'y', 'x')
854+
for i in range(3):
855+
band_result = hotspots(agg.isel(band=i), kernel)
856+
np.testing.assert_array_equal(result.isel(band=i).data, band_result.data)
857+
858+
859+
@dask_array_available
860+
def test_hotspots_3d_dask():
861+
rng = np.random.default_rng(42)
862+
data_2d = rng.standard_normal((10, 12)).astype(np.float64)
863+
data_3d = np.stack([data_2d, data_2d * 2, data_2d * 0.5])
864+
kernel = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]], dtype=np.float64)
865+
# numpy reference
866+
numpy_agg = xr.DataArray(data_3d, dims=['band', 'y', 'x'])
867+
numpy_result = hotspots(numpy_agg, kernel)
868+
# dask
869+
dask_data = da.from_array(data_3d, chunks=(1, 5, 6))
870+
dask_agg = xr.DataArray(dask_data, dims=['band', 'y', 'x'])
871+
dask_result = hotspots(dask_agg, kernel)
872+
assert dask_result.shape == (3, 10, 12)
873+
np.testing.assert_array_equal(
874+
dask_result.data.compute(), numpy_result.data)

xrspatial/tests/test_validation.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -224,14 +224,6 @@ def test_surface_rejects_3d(self, func, args):
224224
with pytest.raises(ValueError, match="2D"):
225225
func(self._agg_3d, *args)
226226

227-
def test_mean_rejects_3d(self):
228-
with pytest.raises(ValueError, match="2D"):
229-
mean(self._agg_3d)
230-
231-
def test_focal_apply_rejects_3d(self):
232-
with pytest.raises(ValueError, match="2D"):
233-
focal_apply(self._agg_3d, _kernel_3x3)
234-
235227
@pytest.mark.parametrize('func', [proximity, allocation, direction])
236228
def test_proximity_rejects_3d(self, func):
237229
with pytest.raises(ValueError, match="2D"):

0 commit comments

Comments
 (0)