Skip to content

Commit 07a8e92

Browse files
authored
ENH: support callable as stats in zonal_stats (#55)
* ENH: support callable as stats in zonal_stats * polygons -> geometry * single thread to get proper coverage
1 parent 18ee579 commit 07a8e92

3 files changed

Lines changed: 95 additions & 50 deletions

File tree

xvec/accessor.py

Lines changed: 31 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import warnings
44
from collections.abc import Hashable, Mapping, Sequence
5-
from typing import Any
5+
from typing import Any, Callable
66

77
import numpy as np
88
import pandas as pd
@@ -921,10 +921,10 @@ def to_geodataframe(
921921

922922
def zonal_stats(
923923
self,
924-
polygons: Sequence[shapely.Geometry],
924+
geometry: Sequence[shapely.Geometry],
925925
x_coords: Hashable,
926926
y_coords: Hashable,
927-
stats: str = "mean",
927+
stats: str | Callable = "mean",
928928
name: Hashable = "geometry",
929929
index: bool = None,
930930
method: str = "rasterize",
@@ -934,37 +934,43 @@ def zonal_stats(
934934
):
935935
"""Extract the values from a dataset indexed by a set of geometries
936936
937-
The CRS of the raster and that of polygons need to be equal.
937+
The CRS of the raster and that of geometry need to be equal.
938938
Xvec does not verify their equality.
939939
940940
Parameters
941941
----------
942-
polygons : Sequence[shapely.Geometry]
942+
geometry : Sequence[shapely.Geometry]
943943
An array-like (1-D) of shapely geometries, like a numpy array or
944-
:class:`geopandas.GeoSeries`.
944+
:class:`geopandas.GeoSeries`. Polygon and LineString geometry types are
945+
supported.
945946
x_coords : Hashable
946947
name of the coordinates containing ``x`` coordinates (i.e. the first value
947948
in the coordinate pair encoding the vertex of the polygon)
948949
y_coords : Hashable
949950
name of the coordinates containing ``y`` coordinates (i.e. the second value
950951
in the coordinate pair encoding the vertex of the polygon)
951-
stats : string
952-
Spatial aggregation statistic method, by default "mean". It supports the
953-
following statistcs: ['mean', 'median', 'min', 'max', 'sum']
952+
stats : string | Callable
953+
Spatial aggregation statistic method, by default "mean". Any of the
954+
aggregations available as :class:`xarray.DataArray` or
955+
:class:`xarray.DataArrayGroupBy` methods like
956+
:meth:`~xarray.DataArray.mean`, :meth:`~xarray.DataArray.min`,
957+
:meth:`~xarray.DataArray.max`, or :meth:`~xarray.DataArray.quantile`
958+
are available. Alternatively, you can pass a ``Callable`` supported
959+
by :meth:`~xarray.DataArray.reduce`.
954960
name : Hashable, optional
955-
Name of the dimension that will hold the ``polygons``, by default "geometry"
961+
Name of the dimension that will hold the ``geometry``, by default "geometry"
956962
index : bool, optional
957-
If `polygons` is a GeoSeries, ``index=True`` will attach its index as another
963+
If ``geometry`` is a :class:`~geopandas.GeoSeries`, ``index=True`` will attach its index as another
958964
coordinate to the geometry dimension in the resulting object. If
959-
``index=None``, the index will be stored if the `polygons.index` is a named
965+
``index=None``, the index will be stored if the `geometry.index` is a named
960966
or non-default index. If ``index=False``, it will never be stored. This is
961967
useful as an attribute link between the resulting array and the GeoPandas
962-
object from which the polygons are sourced.
968+
object from which the geometry is sourced.
963969
method : str, optional
964970
The method of data extraction. The default is ``"rasterize"``, which uses
965971
:func:`rasterio.features.rasterize` and is faster, but can lead to loss
966-
of information in case of small polygons. Other option is ``"iterate"``, which
967-
iterates over polygons and uses :func:`rasterio.features.geometry_mask`.
972+
of information in case of small polygons or lines. The other option is ``"iterate"``, which
973+
iterates over geometries and uses :func:`rasterio.features.geometry_mask`.
968974
all_touched : bool, optional
969975
If True, all pixels touched by geometries will be considered. If False, only
970976
pixels whose center is within the polygon or that are selected by
@@ -975,22 +981,21 @@ def zonal_stats(
975981
only if ``method="iterate"``.
976982
**kwargs : optional
977983
Keyword arguments to be passed to the aggregation function
978-
(e.g., ``Dataset.mean(**kwargs)``).
984+
(e.g., ``Dataset.quantile(**kwargs)``).
979985
980986
Returns
981987
-------
982-
Dataset
988+
Dataset or DataArray
983989
A subset of the original object with N-1 dimensions indexed by
984-
the the GeometryIndex.
990+
the :class:`GeometryIndex` of ``geometry``.
985991
986992
"""
987993
# TODO: allow multiple stats at the same time (concat along a new axis),
988994
# TODO: possibly as a list of tuples to include names?
989-
# TODO: allow callable in stat (via .reduce())
990995
if method == "rasterize":
991996
result = _zonal_stats_rasterize(
992997
self,
993-
polygons=polygons,
998+
geometry=geometry,
994999
x_coords=x_coords,
9951000
y_coords=y_coords,
9961001
stats=stats,
@@ -1001,7 +1006,7 @@ def zonal_stats(
10011006
elif method == "iterate":
10021007
result = _zonal_stats_iterative(
10031008
self,
1004-
polygons=polygons,
1009+
geometry=geometry,
10051010
x_coords=x_coords,
10061011
y_coords=y_coords,
10071012
stats=stats,
@@ -1017,15 +1022,15 @@ def zonal_stats(
10171022
)
10181023

10191024
# save the index as a data variable
1020-
if isinstance(polygons, pd.Series):
1025+
if isinstance(geometry, pd.Series):
10211026
if index is None:
1022-
if polygons.index.name is not None or not polygons.index.equals(
1023-
pd.RangeIndex(0, len(polygons))
1027+
if geometry.index.name is not None or not geometry.index.equals(
1028+
pd.RangeIndex(0, len(geometry))
10241029
):
10251030
index = True
10261031
if index:
1027-
index_name = polygons.index.name if polygons.index.name else "index"
1028-
result = result.assign_coords({index_name: (name, polygons.index)})
1032+
index_name = geometry.index.name if geometry.index.name else "index"
1033+
result = result.assign_coords({index_name: (name, geometry.index)})
10291034

10301035
# standardize the shape - each method comes with a different one
10311036
return result.transpose(

xvec/tests/test_zonal_stats.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,3 +209,29 @@ def test_crs(method):
209209

210210
actual = da.xvec.zonal_stats(polygons, "x", "y", stats="sum", method=method)
211211
xr.testing.assert_identical(actual, expected)
212+
213+
214+
@pytest.mark.parametrize("method", ["rasterize", "iterate"])
215+
def test_callable(method):
216+
ds = xr.tutorial.open_dataset("eraint_uvz")
217+
world = gpd.read_file(geodatasets.get_path("naturalearth land"))
218+
ds_agg = ds.xvec.zonal_stats(
219+
world.geometry, "longitude", "latitude", method=method, stats=np.nanstd
220+
)
221+
ds_std = ds.xvec.zonal_stats(
222+
world.geometry, "longitude", "latitude", method=method, stats="std"
223+
)
224+
xr.testing.assert_identical(ds_agg, ds_std)
225+
226+
da_agg = ds.z.xvec.zonal_stats(
227+
world.geometry,
228+
"longitude",
229+
"latitude",
230+
method=method,
231+
stats=np.nanstd,
232+
n_jobs=1,
233+
)
234+
da_std = ds.z.xvec.zonal_stats(
235+
world.geometry, "longitude", "latitude", method=method, stats="std"
236+
)
237+
xr.testing.assert_identical(da_agg, da_std)

xvec/zonal.py

Lines changed: 38 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import gc
44
from collections.abc import Hashable, Sequence
5+
from typing import Callable
56

67
import numpy as np
78
import shapely
@@ -10,16 +11,16 @@
1011

1112
def _zonal_stats_rasterize(
1213
acc,
13-
polygons: Sequence[shapely.Geometry],
14+
geometry: Sequence[shapely.Geometry],
1415
x_coords: Hashable,
1516
y_coords: Hashable,
16-
stats: str = "mean",
17+
stats: str | Callable = "mean",
1718
name: str = "geometry",
1819
all_touched: bool = False,
1920
**kwargs,
2021
):
2122
try:
22-
import rasterio # noqa: F401
23+
import rasterio
2324
import rioxarray # noqa: F401
2425
except ImportError as err:
2526
raise ImportError(
@@ -28,15 +29,15 @@ def _zonal_stats_rasterize(
2829
"'pip install rioxarray'."
2930
) from err
3031

31-
if hasattr(polygons, "crs"):
32-
crs = polygons.crs
32+
if hasattr(geometry, "crs"):
33+
crs = geometry.crs
3334
else:
3435
crs = None
3536

3637
transform = acc._obj.rio.transform()
3738

3839
labels = rasterio.features.rasterize(
39-
zip(polygons, range(len(polygons))),
40+
zip(geometry, range(len(geometry))),
4041
out_shape=(
4142
acc._obj[y_coords].shape[0],
4243
acc._obj[x_coords].shape[0],
@@ -46,10 +47,13 @@ def _zonal_stats_rasterize(
4647
all_touched=all_touched,
4748
)
4849
groups = acc._obj.groupby(xr.DataArray(labels, dims=(y_coords, x_coords)))
49-
agg = getattr(groups, stats)(**kwargs)
50+
if isinstance(stats, str):
51+
agg = getattr(groups, stats)(**kwargs)
52+
else:
53+
agg = groups.reduce(stats, keep_attrs=True, **kwargs)
5054
vec_cube = (
51-
agg.reindex(group=range(len(polygons)))
52-
.assign_coords(group=polygons)
55+
agg.reindex(group=range(len(geometry)))
56+
.assign_coords(group=geometry)
5357
.rename(group=name)
5458
).xvec.set_geom_indexes(name, crs=crs)
5559

@@ -61,23 +65,23 @@ def _zonal_stats_rasterize(
6165

6266
def _zonal_stats_iterative(
6367
acc,
64-
polygons: Sequence[shapely.Geometry],
68+
geometry: Sequence[shapely.Geometry],
6569
x_coords: Hashable,
6670
y_coords: Hashable,
67-
stats: str = "mean",
71+
stats: str | Callable = "mean",
6872
name: str = "geometry",
6973
all_touched: bool = False,
7074
n_jobs: int = -1,
7175
**kwargs,
7276
):
7377
"""Extract the values from a dataset indexed by a set of geometries
7478
75-
The CRS of the raster and that of polygons need to be equal.
79+
The CRS of the raster and that of geometry need to be equal.
7680
Xvec does not verify their equality.
7781
7882
Parameters
7983
----------
80-
polygons : Sequence[shapely.Geometry]
84+
geometry : Sequence[shapely.Geometry]
8185
An array-like (1-D) of shapely geometries, like a numpy array or
8286
:class:`geopandas.GeoSeries`.
8387
x_coords : Hashable
@@ -87,10 +91,14 @@ def _zonal_stats_iterative(
8791
name of the coordinates containing ``y`` coordinates (i.e. the second value
8892
in the coordinate pair encoding the vertex of the polygon)
8993
stats : string | Callable
90-
Spatial aggregation statistic method, by default "mean". It supports the
91-
following statistcs: ['mean', 'median', 'min', 'max', 'sum']
94+
Spatial aggregation statistic method, by default "mean". Any of the
95+
aggregations available as DataArray or DataArrayGroupBy methods like
96+
:meth:`~xarray.DataArray.mean`, :meth:`~xarray.DataArray.min`,
97+
:meth:`~xarray.DataArray.max`, or :meth:`~xarray.DataArray.quantile`
98+
are available. Alternatively, you can pass a ``Callable`` supported
99+
by :meth:`~xarray.DataArray.reduce`.
92100
name : Hashable, optional
93-
Name of the dimension that will hold the ``polygons``, by default "geometry"
101+
Name of the dimension that will hold the ``geometry``, by default "geometry"
94102
all_touched : bool, optional
95103
If True, all pixels touched by geometries will be considered. If False, only
96104
pixels whose center is within the polygon or that are selected by
@@ -140,14 +148,14 @@ def _zonal_stats_iterative(
140148
all_touched=all_touched,
141149
**kwargs,
142150
)
143-
for geom in polygons
151+
for geom in geometry
144152
)
145-
if hasattr(polygons, "crs"):
146-
crs = polygons.crs
153+
if hasattr(geometry, "crs"):
154+
crs = geometry.crs
147155
else:
148156
crs = None
149157
vec_cube = xr.concat(
150-
zonal, dim=xr.DataArray(polygons, name=name, dims=name)
158+
zonal, dim=xr.DataArray(geometry, name=name, dims=name)
151159
).xvec.set_geom_indexes(name, crs=crs)
152160
gc.collect()
153161

@@ -160,7 +168,7 @@ def _agg_geom(
160168
trans,
161169
x_coords: str = None,
162170
y_coords: str = None,
163-
stats: str = "mean",
171+
stats: str | Callable = "mean",
164172
all_touched=False,
165173
**kwargs,
166174
):
@@ -207,9 +215,15 @@ def _agg_geom(
207215
invert=True,
208216
all_touched=all_touched,
209217
)
210-
result = getattr(
211-
acc._obj.where(xr.DataArray(mask, dims=(y_coords, x_coords))), stats
212-
)(dim=(y_coords, x_coords), keep_attrs=True, **kwargs)
218+
masked = acc._obj.where(xr.DataArray(mask, dims=(y_coords, x_coords)))
219+
if isinstance(stats, str):
220+
result = getattr(masked, stats)(
221+
dim=(y_coords, x_coords), keep_attrs=True, **kwargs
222+
)
223+
else:
224+
result = masked.reduce(
225+
stats, dim=(y_coords, x_coords), keep_attrs=True, **kwargs
226+
)
213227

214228
del mask
215229
gc.collect()

0 commit comments

Comments
 (0)