Skip to content

Commit e86d3ec

Browse files
authored
ENH: multiple aggregations at once in zonal_stats (#56)
* ENH: multiple aggregations at once in zonal_stats * expand docstring * fix typing
1 parent 07a8e92 commit e86d3ec

3 files changed

Lines changed: 146 additions & 33 deletions

File tree

xvec/accessor.py

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -924,7 +924,7 @@ def zonal_stats(
924924
geometry: Sequence[shapely.Geometry],
925925
x_coords: Hashable,
926926
y_coords: Hashable,
927-
stats: str | Callable = "mean",
927+
stats: str | Callable | Sequence[str | Callable | tuple] = "mean",
928928
name: Hashable = "geometry",
929929
index: bool = None,
930930
method: str = "rasterize",
@@ -949,36 +949,39 @@ def zonal_stats(
949949
y_coords : Hashable
950950
name of the coordinates containing ``y`` coordinates (i.e. the second value
951951
in the coordinate pair encoding the vertex of the polygon)
952-
stats : string | Callable
952+
stats : string | Callable | Sequence[str | Callable | tuple]
953953
Spatial aggregation statistic method, by default "mean". Any of the
954954
aggregations available as :class:`xarray.DataArray` or
955955
:class:`xarray.DataArrayGroupBy` methods like
956956
:meth:`~xarray.DataArray.mean`, :meth:`~xarray.DataArray.min`,
957-
:meth:`~xarray.DataArray.max`, or :meth:`~xarray.DataArray.quantile`
958-
are available. Alternatively, you can pass a ``Callable`` supported
959-
by :meth:`~xarray.DataArray.reduce`.
957+
:meth:`~xarray.DataArray.max`, or :meth:`~xarray.DataArray.quantile` are
958+
available. Alternatively, you can pass a ``Callable`` supported by
959+
:meth:`~xarray.DataArray.reduce` or a list with ``strings``, ``callables``
960+
or ``tuples`` in a ``(name, func, {kwargs})`` format, where ``func`` can be
961+
a string or a callable and the ``kwargs`` dictionary is optional.
960962
name : Hashable, optional
961963
Name of the dimension that will hold the ``geometry``, by default "geometry"
962964
index : bool, optional
963-
If ``geometry`` is a :class:`~geopandas.GeoSeries`, ``index=True`` will attach its index as another
964-
coordinate to the geometry dimension in the resulting object. If
965-
``index=None``, the index will be stored if the `geometry.index` is a named
966-
or non-default index. If ``index=False``, it will never be stored. This is
967-
useful as an attribute link between the resulting array and the GeoPandas
968-
object from which the geometry is sourced.
965+
If ``geometry`` is a :class:`~geopandas.GeoSeries`, ``index=True`` will
966+
attach its index as another coordinate to the geometry dimension in the
967+
resulting object. If ``index=None``, the index will be stored if the
968+
`geometry.index` is a named or non-default index. If ``index=False``, it
969+
will never be stored. This is useful as an attribute link between the
970+
resulting array and the GeoPandas object from which the geometry is sourced.
969971
method : str, optional
970972
The method of data extraction. The default is ``"rasterize"``, which uses
971-
:func:`rasterio.features.rasterize` and is faster, but can lead to loss
972-
of information in case of small polygons or lines. Other option is ``"iterate"``, which
973-
iterates over geometries and uses :func:`rasterio.features.geometry_mask`.
973+
:func:`rasterio.features.rasterize` and is faster, but can lead to loss of
974+
information in case of small polygons or lines. The other option is
975+
``"iterate"``, which iterates over geometries and uses
976+
:func:`rasterio.features.geometry_mask`.
974977
all_touched : bool, optional
975978
If True, all pixels touched by geometries will be considered. If False, only
976979
pixels whose center is within the polygon or that are selected by
977980
Bresenham’s line algorithm will be considered.
978981
n_jobs : int, optional
979982
Number of parallel threads to use. It is recommended to set this to the
980-
number of physical cores of the CPU. ``-1`` uses all available cores. Applies
981-
only if ``method="iterate"``.
983+
number of physical cores of the CPU. ``-1`` uses all available cores.
984+
Applies only if ``method="iterate"``.
982985
**kwargs : optional
983986
Keyword arguments to be passed to the aggregation function
984987
(e.g., ``Dataset.quantile(**kwargs)``).
@@ -990,8 +993,6 @@ def zonal_stats(
990993
the :class:`GeometryIndex` of ``geometry``.
991994
992995
"""
993-
# TODO: allow multiple stats at the same time (concat along a new axis),
994-
# TODO: possibly as a list of tuples to include names?
995996
if method == "rasterize":
996997
result = _zonal_stats_rasterize(
997998
self,
@@ -1033,9 +1034,7 @@ def zonal_stats(
10331034
result = result.assign_coords({index_name: (name, geometry.index)})
10341035

10351036
# standardize the shape - each method comes with a different one
1036-
return result.transpose(
1037-
name, *tuple(d for d in self._obj.dims if d not in [x_coords, y_coords])
1038-
)
1037+
return result.transpose(name, ...)
10391038

10401039
def extract_points(
10411040
self,

xvec/tests/test_zonal_stats.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,3 +235,64 @@ def test_callable(method):
235235
world.geometry, "longitude", "latitude", method=method, stats="std"
236236
)
237237
xr.testing.assert_identical(da_agg, da_std)
238+
239+
240+
@pytest.mark.parametrize("method", ["rasterize", "iterate"])
def test_multiple(method):
    """Several aggregations passed as a list end up on ``zonal_statistics``."""
    # One entry per accepted form: plain strings, a (name, func, kwargs)
    # tuple, a (name, func) tuple and a bare callable.
    requested = [
        "mean",
        "sum",
        ("quantile", "quantile", {"q": [0.1, 0.2, 0.3]}),
        ("numpymean", np.nanmean),
        np.nanmean,
    ]
    expected_dims = ["level", "zonal_statistics", "geometry", "month", "quantile"]
    # the bare callable is labelled by its ``__name__``
    expected_labels = ["mean", "sum", "quantile", "numpymean", "nanmean"]

    ds = xr.tutorial.open_dataset("eraint_uvz")
    world = gpd.read_file(geodatasets.get_path("naturalearth land"))
    boundaries = world.geometry[:10].boundary

    result = ds.xvec.zonal_stats(
        boundaries,
        "longitude",
        "latitude",
        stats=requested,
        method=method,
        n_jobs=1,
    )

    assert sorted(result.dims) == sorted(expected_dims)
    assert (result.zonal_statistics == expected_labels).all()
271+
272+
273+
@pytest.mark.parametrize("method", ["rasterize", "iterate"])
def test_invalid(method):
    """Unsupported ``stats`` values raise ``ValueError`` naming the bad value."""
    ds = xr.tutorial.open_dataset("eraint_uvz")
    world = gpd.read_file(geodatasets.get_path("naturalearth land"))
    positional = (world.geometry[:10].boundary, "longitude", "latitude")
    common = {"method": method, "n_jobs": 1}

    # a nested list inside the list of stats is not a str, callable or tuple
    with pytest.raises(ValueError, match=r"\['gorilla'\] is not a valid aggregation."):
        ds.xvec.zonal_stats(*positional, stats=["mean", ["gorilla"]], **common)

    # a scalar that is neither a string nor a callable
    with pytest.raises(ValueError, match="3 is not a valid aggregation."):
        ds.xvec.zonal_stats(*positional, stats=3, **common)

xvec/zonal.py

Lines changed: 65 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,31 @@
55
from typing import Callable
66

77
import numpy as np
8+
import pandas as pd
89
import shapely
910
import xarray as xr
1011

1112

13+
def _agg_rasterize(groups, stats, **kwargs):
    """Apply a single aggregation to a grouped object.

    Parameters
    ----------
    groups : object
        Grouped object exposing named aggregation methods and ``reduce``.
    stats : str | Callable
        Name of a method of ``groups``, or a callable handed to
        ``groups.reduce``.
    **kwargs : optional
        Keyword arguments forwarded to the aggregation.

    Returns
    -------
    The result of the aggregation call.
    """
    if not isinstance(stats, str):
        # callables go through ``reduce``, which is asked to keep attributes
        return groups.reduce(stats, keep_attrs=True, **kwargs)

    # strings are resolved to the method of the same name
    aggregate = getattr(groups, stats)
    return aggregate(**kwargs)
17+
18+
19+
def _agg_iterate(masked, stats, x_coords, y_coords, **kwargs):
    """Apply a single aggregation to a masked object over its spatial dims.

    Parameters
    ----------
    masked : object
        Masked object exposing named aggregation methods and ``reduce``.
    stats : str | Callable
        Name of a method of ``masked``, or a callable handed to
        ``masked.reduce``.
    x_coords, y_coords : Hashable
        Names of the two dimensions that are aggregated away.
    **kwargs : optional
        Keyword arguments forwarded to the aggregation.

    Returns
    -------
    The result of the aggregation call.
    """
    # both branches reduce over (y, x) and ask to keep attributes
    spatial_dims = (y_coords, x_coords)

    if not isinstance(stats, str):
        return masked.reduce(stats, dim=spatial_dims, keep_attrs=True, **kwargs)

    aggregate = getattr(masked, stats)
    return aggregate(dim=spatial_dims, keep_attrs=True, **kwargs)
25+
26+
1227
def _zonal_stats_rasterize(
1328
acc,
1429
geometry: Sequence[shapely.Geometry],
1530
x_coords: Hashable,
1631
y_coords: Hashable,
17-
stats: str | Callable = "mean",
32+
stats: str | Callable | Sequence[str | Callable | tuple] = "mean",
1833
name: str = "geometry",
1934
all_touched: bool = False,
2035
**kwargs,
@@ -47,10 +62,31 @@ def _zonal_stats_rasterize(
4762
all_touched=all_touched,
4863
)
4964
groups = acc._obj.groupby(xr.DataArray(labels, dims=(y_coords, x_coords)))
50-
if isinstance(stats, str):
51-
agg = getattr(groups, stats)(**kwargs)
65+
66+
if pd.api.types.is_list_like(stats):
67+
agg = {}
68+
for stat in stats:
69+
if isinstance(stat, str):
70+
agg[stat] = _agg_rasterize(groups, stat, **kwargs)
71+
elif callable(stat):
72+
agg[stat.__name__] = _agg_rasterize(groups, stat, **kwargs)
73+
elif isinstance(stat, tuple):
74+
kws = stat[2] if len(stat) == 3 else {}
75+
agg[stat[0]] = _agg_rasterize(groups, stat[1], **kws)
76+
else:
77+
raise ValueError(f"{stat} is not a valid aggregation.")
78+
79+
agg = xr.concat(
80+
agg.values(),
81+
dim=xr.DataArray(
82+
list(agg.keys()), name="zonal_statistics", dims="zonal_statistics"
83+
),
84+
)
85+
elif isinstance(stats, str) or callable(stats):
86+
agg = _agg_rasterize(groups, stats, **kwargs)
5287
else:
53-
agg = groups.reduce(stats, keep_attrs=True, **kwargs)
88+
raise ValueError(f"{stats} is not a valid aggregation.")
89+
5490
vec_cube = (
5591
agg.reindex(group=range(len(geometry)))
5692
.assign_coords(group=geometry)
@@ -68,7 +104,7 @@ def _zonal_stats_iterative(
68104
geometry: Sequence[shapely.Geometry],
69105
x_coords: Hashable,
70106
y_coords: Hashable,
71-
stats: str | Callable = "mean",
107+
stats: str | Callable | Sequence[str | Callable | tuple] = "mean",
72108
name: str = "geometry",
73109
all_touched: bool = False,
74110
n_jobs: int = -1,
@@ -168,7 +204,7 @@ def _agg_geom(
168204
trans,
169205
x_coords: str = None,
170206
y_coords: str = None,
171-
stats: str | Callable = "mean",
207+
stats: str | Callable | Sequence[str | Callable | tuple] = "mean",
172208
all_touched=False,
173209
**kwargs,
174210
):
@@ -216,14 +252,31 @@ def _agg_geom(
216252
all_touched=all_touched,
217253
)
218254
masked = acc._obj.where(xr.DataArray(mask, dims=(y_coords, x_coords)))
219-
if isinstance(stats, str):
220-
result = getattr(masked, stats)(
221-
dim=(y_coords, x_coords), keep_attrs=True, **kwargs
255+
if pd.api.types.is_list_like(stats):
256+
agg = {}
257+
for stat in stats:
258+
if isinstance(stat, str):
259+
agg[stat] = _agg_iterate(masked, stat, x_coords, y_coords, **kwargs)
260+
elif callable(stat):
261+
agg[stat.__name__] = _agg_iterate(
262+
masked, stat, x_coords, y_coords, **kwargs
263+
)
264+
elif isinstance(stat, tuple):
265+
kws = stat[2] if len(stat) == 3 else {}
266+
agg[stat[0]] = _agg_iterate(masked, stat[1], x_coords, y_coords, **kws)
267+
else:
268+
raise ValueError(f"{stat} is not a valid aggregation.")
269+
270+
result = xr.concat(
271+
agg.values(),
272+
dim=xr.DataArray(
273+
list(agg.keys()), name="zonal_statistics", dims="zonal_statistics"
274+
),
222275
)
276+
elif isinstance(stats, str) or callable(stats):
277+
result = _agg_iterate(masked, stats, x_coords, y_coords, **kwargs)
223278
else:
224-
result = masked.reduce(
225-
stats, dim=(y_coords, x_coords), keep_attrs=True, **kwargs
226-
)
279+
raise ValueError(f"{stats} is not a valid aggregation.")
227280

228281
del mask
229282
gc.collect()

0 commit comments

Comments
 (0)