Skip to content

Commit 17c50c8

Browse files
committed
Fixes #134: add xr.Dataset as input type for appropriate modules
Add transparent Dataset support via two decorators in a new dataset_support module. Single-input functions (slope, aspect, curvature, hillshade, focal.mean, 10 classify functions, 3 proximity functions) iterate over data_vars and return a Dataset. Multi-input functions (10 multispectral indices) accept a Dataset with band-name kwargs. zonal.stats merges per-variable DataFrames with prefixed columns. Includes 18 new tests.
1 parent 623eba0 commit 17c50c8

File tree

11 files changed

+352
-4
lines changed

11 files changed

+352
-4
lines changed

xrspatial/aspect.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from xrspatial.utils import _extract_latlon_coords
2020
from xrspatial.utils import cuda_args
2121
from xrspatial.utils import ngjit
22+
from xrspatial.dataset_support import supports_dataset
2223

2324

2425
def _geodesic_cuda_dims(shape):
@@ -270,6 +271,7 @@ def _run_dask_cupy_geodesic(data, lat_2d, lon_2d, a2, b2, z_factor):
270271
# Public API
271272
# =====================================================================
272273

274+
@supports_dataset
273275
def aspect(agg: xr.DataArray,
274276
name: Optional[str] = 'aspect',
275277
method: str = 'planar',

xrspatial/classify.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class cupy(object):
2525
import numpy as np
2626

2727
from xrspatial.utils import ArrayTypeFunctionMapping, cuda_args, ngjit, not_implemented_func
28+
from xrspatial.dataset_support import supports_dataset
2829

2930

3031
@ngjit
@@ -83,6 +84,7 @@ def _run_dask_cupy_binary(data, values_cupy):
8384
return out
8485

8586

87+
@supports_dataset
8688
def binary(agg, values, name='binary'):
8789
"""
8890
Binarize a data array based on a set of values. Data that equals to a value in the set will be
@@ -266,6 +268,7 @@ def _bin(agg, bins, new_values):
266268
return out
267269

268270

271+
@supports_dataset
269272
def reclassify(agg: xr.DataArray,
270273
bins: List[int],
271274
new_values: List[int],
@@ -416,6 +419,7 @@ def _quantile(agg, k):
416419
return out
417420

418421

422+
@supports_dataset
419423
def quantile(agg: xr.DataArray,
420424
k: int = 4,
421425
name: Optional[str] = 'quantile') -> xr.DataArray:
@@ -723,6 +727,7 @@ def _run_dask_cupy_natural_break(agg, num_sample, k):
723727
return out
724728

725729

730+
@supports_dataset
726731
def natural_breaks(agg: xr.DataArray,
727732
num_sample: Optional[int] = 20000,
728733
name: Optional[str] = 'natural_breaks',
@@ -854,6 +859,7 @@ def _run_equal_interval(agg, k, module):
854859
return out
855860

856861

862+
@supports_dataset
857863
def equal_interval(agg: xr.DataArray,
858864
k: int = 5,
859865
name: Optional[str] = 'equal_interval') -> xr.DataArray:
@@ -952,6 +958,7 @@ def _run_std_mean(agg, module):
952958
return out
953959

954960

961+
@supports_dataset
955962
def std_mean(agg: xr.DataArray,
956963
name: Optional[str] = 'std_mean') -> xr.DataArray:
957964
"""
@@ -1044,6 +1051,7 @@ def _run_dask_head_tail_breaks(agg):
10441051
return out
10451052

10461053

1054+
@supports_dataset
10471055
def head_tail_breaks(agg: xr.DataArray,
10481056
name: Optional[str] = 'head_tail_breaks') -> xr.DataArray:
10491057
"""
@@ -1096,6 +1104,7 @@ def _run_dask_cupy_percentiles(data, pct):
10961104
return _run_percentiles(data_cpu, pct, da)
10971105

10981106

1107+
@supports_dataset
10991108
def percentiles(agg: xr.DataArray,
11001109
pct: Optional[List] = None,
11011110
name: Optional[str] = 'percentiles') -> xr.DataArray:
@@ -1212,6 +1221,7 @@ def _run_dask_cupy_maximum_breaks(agg, k):
12121221
return out
12131222

12141223

1224+
@supports_dataset
12151225
def maximum_breaks(agg: xr.DataArray,
12161226
k: int = 5,
12171227
name: Optional[str] = 'maximum_breaks') -> xr.DataArray:
@@ -1312,6 +1322,7 @@ def _run_dask_cupy_box_plot(agg, hinge):
13121322
return out
13131323

13141324

1325+
@supports_dataset
13151326
def box_plot(agg: xr.DataArray,
13161327
hinge: float = 1.5,
13171328
name: Optional[str] = 'box_plot') -> xr.DataArray:

xrspatial/curvature.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class cupy(object):
2525
from xrspatial.utils import cuda_args
2626
from xrspatial.utils import get_dataarray_resolution
2727
from xrspatial.utils import ngjit
28+
from xrspatial.dataset_support import supports_dataset
2829

2930

3031
@ngjit
@@ -107,6 +108,7 @@ def _run_dask_cupy(data: da.Array,
107108
return out
108109

109110

111+
@supports_dataset
110112
def curvature(agg: xr.DataArray,
111113
name: Optional[str] = 'curvature') -> xr.DataArray:
112114
"""

xrspatial/dataset_support.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
"""Decorators for transparent xr.Dataset support on xr.DataArray functions."""
2+
3+
from __future__ import annotations
4+
5+
import functools
6+
import inspect
7+
8+
import xarray as xr
9+
10+
11+
def supports_dataset(func):
    """Decorator that lets single-input DataArray functions accept a Dataset.

    When a Dataset is passed as the wrapped function's first argument
    (positionally or by that parameter's own keyword name), the function
    is called once per data variable and the results are collected into a
    new Dataset keyed by variable name and carrying the input Dataset's
    ``attrs``.  If the wrapped function has a ``name`` parameter, it is
    set to the variable name for each call, replacing any caller-supplied
    value.

    Any other input is forwarded to the wrapped function unchanged.
    """
    sig = inspect.signature(func)
    # The input parameter is not always called ``agg`` (some wrapped
    # functions call it ``raster``), so look its name up instead of
    # hard-coding it; otherwise keyword calls would stop working.
    first_param = next(iter(sig.parameters), None)
    has_name_param = 'name' in sig.parameters

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        agg = args[0] if args else kwargs.get(first_param)
        if not isinstance(agg, xr.Dataset):
            return func(*args, **kwargs)

        # Bind once so that positionally supplied arguments (including a
        # positional ``name``) can be overridden per variable without
        # triggering "got multiple values for argument" errors.
        bound = sig.bind(*args, **kwargs)
        results = {}
        for var_name in agg.data_vars:
            bound.arguments[first_param] = agg[var_name]
            if has_name_param:
                bound.arguments['name'] = var_name
            results[var_name] = func(*bound.args, **bound.kwargs)
        return xr.Dataset(results, attrs=agg.attrs)

    return wrapper
34+
35+
36+
def supports_dataset_bands(**band_param_map):
    """Decorator for multi-input functions that take separate band DataArrays.

    Enables passing a single Dataset with keyword arguments that map
    band aliases to Dataset variable names.

    ``band_param_map`` maps each alias keyword (e.g. ``nir``) to the name
    of the wrapped function's parameter that receives that band
    (e.g. ``nir_agg``).

    Example::

        @supports_dataset_bands(nir='nir_agg', red='red_agg')
        def ndvi(nir_agg, red_agg, name='ndvi'): ...

        # Enables:
        ndvi(ds, nir='band_8', red='band_4')

    When the first positional argument is a Dataset the wrapper raises:

    * ``TypeError`` if further positional arguments follow the Dataset,
      if a required alias keyword is missing, or if both an alias and the
      parameter it maps to are supplied.
    * ``ValueError`` if an alias names a variable not in the Dataset.

    Calls whose first positional argument is not a Dataset are forwarded
    to the wrapped function unchanged.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if not (args and isinstance(args[0], xr.Dataset)):
                return func(*args, **kwargs)

            func_name = getattr(func, '__name__', 'function')
            # The wrapped function is invoked with keywords only, so any
            # extra positional argument would otherwise vanish silently.
            if len(args) > 1:
                raise TypeError(
                    f"{func_name}() accepts only keyword arguments after "
                    f"the Dataset; got {len(args) - 1} extra positional "
                    f"argument(s)"
                )

            ds = args[0]
            func_kwargs = {}
            for alias, param in band_param_map.items():
                if alias not in kwargs:
                    raise TypeError(
                        f"'{alias}' keyword required when passing a Dataset"
                    )
                # Supplying both the alias and the real parameter is
                # ambiguous; refuse rather than let one override the other.
                if param != alias and param in kwargs:
                    raise TypeError(
                        f"{func_name}() got both '{alias}' and '{param}'; "
                        f"pass only '{alias}' when passing a Dataset"
                    )
                var_name = kwargs[alias]
                if var_name not in ds.data_vars:
                    raise ValueError(
                        f"'{var_name}' not in Dataset. "
                        f"Available: {list(ds.data_vars)}"
                    )
                func_kwargs[param] = ds[var_name]

            # Pass through remaining kwargs (name, soil_factor, etc.)
            for key, value in kwargs.items():
                if key not in band_param_map:
                    func_kwargs[key] = value
            return func(**func_kwargs)

        return wrapper

    return decorator

xrspatial/focal.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class cupy(object):
2929

3030
from xrspatial.convolution import convolve_2d, custom_kernel
3131
from xrspatial.utils import ArrayTypeFunctionMapping, cuda_args, ngjit, not_implemented_func
32+
from xrspatial.dataset_support import supports_dataset
3233

3334
# TODO: Make convolution more generic with numba first-class functions.
3435

@@ -158,6 +159,7 @@ def _mean(data, excludes):
158159
return out
159160

160161

162+
@supports_dataset
161163
def mean(agg, passes=1, excludes=[np.nan], name='mean'):
162164
"""
163165
Returns Mean filtered array using a 3x3 window.

xrspatial/hillshade.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from .gpu_rtx import has_rtx
1616
from .utils import calc_cuda_dims, has_cuda_and_cupy, is_cupy_array, is_cupy_backed
17+
from .dataset_support import supports_dataset
1718

1819

1920
def _run_numpy(data, azimuth=225, angle_altitude=25):
@@ -99,6 +100,7 @@ def _run_cupy(d_data, azimuth, angle_altitude):
99100
return output
100101

101102

103+
@supports_dataset
102104
def hillshade(agg: xr.DataArray,
103105
azimuth: int = 225,
104106
angle_altitude: int = 25,

xrspatial/multispectral.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from xrspatial.utils import (ArrayTypeFunctionMapping, cuda_args, ngjit, not_implemented_func,
1313
validate_arrays)
14+
from xrspatial.dataset_support import supports_dataset_bands
1415

1516
# 3rd-party
1617
try:
@@ -75,6 +76,7 @@ def _arvi_dask_cupy(nir_data, red_data, blue_data):
7576
return out
7677

7778

79+
@supports_dataset_bands(nir='nir_agg', red='red_agg', blue='blue_agg')
7880
def arvi(nir_agg: xr.DataArray,
7981
red_agg: xr.DataArray,
8082
blue_agg: xr.DataArray,
@@ -215,6 +217,7 @@ def _evi_dask_cupy(nir_data, red_data, blue_data, c1, c2, soil_factor, gain):
215217
return out
216218

217219

220+
@supports_dataset_bands(nir='nir_agg', red='red_agg', blue='blue_agg')
218221
def evi(nir_agg: xr.DataArray,
219222
red_agg: xr.DataArray,
220223
blue_agg: xr.DataArray,
@@ -374,6 +377,7 @@ def _gci_dask_cupy(nir_data, green_data):
374377
return out
375378

376379

380+
@supports_dataset_bands(nir='nir_agg', green='green_agg')
377381
def gci(nir_agg: xr.DataArray,
378382
green_agg: xr.DataArray,
379383
name='gci'):
@@ -451,6 +455,7 @@ def gci(nir_agg: xr.DataArray,
451455

452456

453457
# NBR ----------
458+
@supports_dataset_bands(nir='nir_agg', swir2='swir2_agg')
454459
def nbr(nir_agg: xr.DataArray,
455460
swir2_agg: xr.DataArray,
456461
name='nbr'):
@@ -529,6 +534,7 @@ def nbr(nir_agg: xr.DataArray,
529534
attrs=nir_agg.attrs)
530535

531536

537+
@supports_dataset_bands(swir1='swir1_agg', swir2='swir2_agg')
532538
def nbr2(swir1_agg: xr.DataArray,
533539
swir2_agg: xr.DataArray,
534540
name='nbr2'):
@@ -614,6 +620,7 @@ def nbr2(swir1_agg: xr.DataArray,
614620

615621

616622
# NDVI ----------
623+
@supports_dataset_bands(nir='nir_agg', red='red_agg')
617624
def ndvi(nir_agg: xr.DataArray,
618625
red_agg: xr.DataArray,
619626
name='ndvi'):
@@ -691,6 +698,7 @@ def ndvi(nir_agg: xr.DataArray,
691698

692699

693700
# NDMI ----------
701+
@supports_dataset_bands(nir='nir_agg', swir1='swir1_agg')
694702
def ndmi(nir_agg: xr.DataArray,
695703
swir1_agg: xr.DataArray,
696704
name='ndmi'):
@@ -874,6 +882,7 @@ def _savi_dask_cupy(nir_data, red_data, soil_factor):
874882

875883

876884
# SAVI ----------
885+
@supports_dataset_bands(nir='nir_agg', red='red_agg')
877886
def savi(nir_agg: xr.DataArray,
878887
red_agg: xr.DataArray,
879888
soil_factor: float = 1.0,
@@ -1006,6 +1015,7 @@ def _sipi_dask_cupy(nir_data, red_data, blue_data):
10061015
return out
10071016

10081017

1018+
@supports_dataset_bands(nir='nir_agg', red='red_agg', blue='blue_agg')
10091019
def sipi(nir_agg: xr.DataArray,
10101020
red_agg: xr.DataArray,
10111021
blue_agg: xr.DataArray,
@@ -1142,6 +1152,7 @@ def _ebbi_dask_cupy(red_data, swir_data, tir_data):
11421152
return out
11431153

11441154

1155+
@supports_dataset_bands(red='red_agg', swir='swir_agg', tir='tir_agg')
11451156
def ebbi(red_agg: xr.DataArray,
11461157
swir_agg: xr.DataArray,
11471158
tir_agg: xr.DataArray,

xrspatial/proximity.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from numba import prange
1111

1212
from xrspatial.utils import get_dataarray_resolution, ngjit
13+
from xrspatial.dataset_support import supports_dataset
1314

1415
EUCLIDEAN = 0
1516
GREAT_CIRCLE = 1
@@ -648,6 +649,7 @@ def _process_dask(raster, xs, ys):
648649

649650
# ported from
650651
# https://github.com/OSGeo/gdal/blob/master/gdal/alg/gdalproximity.cpp
652+
@supports_dataset
651653
def proximity(
652654
raster: xr.DataArray,
653655
x: str = "x",
@@ -783,6 +785,7 @@ def proximity(
783785
return result
784786

785787

788+
@supports_dataset
786789
def allocation(
787790
raster: xr.DataArray,
788791
x: str = "x",
@@ -915,6 +918,7 @@ def allocation(
915918
return result
916919

917920

921+
@supports_dataset
918922
def direction(
919923
raster: xr.DataArray,
920924
x: str = "x",

xrspatial/slope.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ class cupy(object):
2828
from xrspatial.utils import cuda_args
2929
from xrspatial.utils import get_dataarray_resolution
3030
from xrspatial.utils import ngjit
31+
from xrspatial.dataset_support import supports_dataset
3132

3233

3334
def _geodesic_cuda_dims(shape):
@@ -267,6 +268,7 @@ def _run_dask_cupy_geodesic(data, lat_2d, lon_2d, a2, b2, z_factor):
267268
# Public API
268269
# =====================================================================
269270

271+
@supports_dataset
270272
def slope(agg: xr.DataArray,
271273
name: str = 'slope',
272274
method: str = 'planar',

0 commit comments

Comments
 (0)