22
33import gc
44from collections .abc import Hashable , Sequence
5+ from typing import Callable
56
67import numpy as np
78import shapely
1011
1112def _zonal_stats_rasterize (
1213 acc ,
13- polygons : Sequence [shapely .Geometry ],
14+ geometry : Sequence [shapely .Geometry ],
1415 x_coords : Hashable ,
1516 y_coords : Hashable ,
16- stats : str = "mean" ,
17+ stats : str | Callable = "mean" ,
1718 name : str = "geometry" ,
1819 all_touched : bool = False ,
1920 ** kwargs ,
2021):
2122 try :
22- import rasterio # noqa: F401
23+ import rasterio
2324 import rioxarray # noqa: F401
2425 except ImportError as err :
2526 raise ImportError (
@@ -28,15 +29,15 @@ def _zonal_stats_rasterize(
2829 "'pip install rioxarray'."
2930 ) from err
3031
31- if hasattr (polygons , "crs" ):
32- crs = polygons .crs
32+ if hasattr (geometry , "crs" ):
33+ crs = geometry .crs
3334 else :
3435 crs = None
3536
3637 transform = acc ._obj .rio .transform ()
3738
3839 labels = rasterio .features .rasterize (
39- zip (polygons , range (len (polygons ))),
40+ zip (geometry , range (len (geometry ))),
4041 out_shape = (
4142 acc ._obj [y_coords ].shape [0 ],
4243 acc ._obj [x_coords ].shape [0 ],
@@ -46,10 +47,13 @@ def _zonal_stats_rasterize(
4647 all_touched = all_touched ,
4748 )
4849 groups = acc ._obj .groupby (xr .DataArray (labels , dims = (y_coords , x_coords )))
49- agg = getattr (groups , stats )(** kwargs )
50+ if isinstance (stats , str ):
51+ agg = getattr (groups , stats )(** kwargs )
52+ else :
53+ agg = groups .reduce (stats , keep_attrs = True , ** kwargs )
5054 vec_cube = (
51- agg .reindex (group = range (len (polygons )))
52- .assign_coords (group = polygons )
55+ agg .reindex (group = range (len (geometry )))
56+ .assign_coords (group = geometry )
5357 .rename (group = name )
5458 ).xvec .set_geom_indexes (name , crs = crs )
5559
@@ -61,23 +65,23 @@ def _zonal_stats_rasterize(
6165
6266def _zonal_stats_iterative (
6367 acc ,
64- polygons : Sequence [shapely .Geometry ],
68+ geometry : Sequence [shapely .Geometry ],
6569 x_coords : Hashable ,
6670 y_coords : Hashable ,
67- stats : str = "mean" ,
71+ stats : str | Callable = "mean" ,
6872 name : str = "geometry" ,
6973 all_touched : bool = False ,
7074 n_jobs : int = - 1 ,
7175 ** kwargs ,
7276):
7377 """Extract the values from a dataset indexed by a set of geometries
7478
75- The CRS of the raster and that of polygons need to be equal.
79+ The CRS of the raster and that of geometry need to be equal.
7680 Xvec does not verify their equality.
7781
7882 Parameters
7983 ----------
80- polygons : Sequence[shapely.Geometry]
84+ geometry : Sequence[shapely.Geometry]
8185 An arrray-like (1-D) of shapely geometries, like a numpy array or
8286 :class:`geopandas.GeoSeries`.
8387 x_coords : Hashable
@@ -87,10 +91,14 @@ def _zonal_stats_iterative(
8791 name of the coordinates containing ``y`` coordinates (i.e. the second value
8892 in the coordinate pair encoding the vertex of the polygon)
8993 stats : Hashable
90- Spatial aggregation statistic method, by default "mean". It supports the
91- following statistcs: ['mean', 'median', 'min', 'max', 'sum']
94+ Spatial aggregation statistic method, by default "mean". Any of the
95+ aggregations available as DataArray or DataArrayGroupBy like
96+ :meth:`~xarray.DataArray.mean`, :meth:`~xarray.DataArray.min`,
97+ :meth:`~xarray.DataArray.max`, or :meth:`~xarray.DataArray.quantile`,
98+ methods are available. Alternatively, you can pass a ``Callable`` supported
99+ by :meth:`~xarray.DataArray.reduce`.
92100 name : Hashable, optional
93- Name of the dimension that will hold the ``polygons ``, by default "geometry"
101+ Name of the dimension that will hold the ``geometry ``, by default "geometry"
94102 all_touched : bool, optional
95103 If True, all pixels touched by geometries will be considered. If False, only
96104 pixels whose center is within the polygon or that are selected by
@@ -140,14 +148,14 @@ def _zonal_stats_iterative(
140148 all_touched = all_touched ,
141149 ** kwargs ,
142150 )
143- for geom in polygons
151+ for geom in geometry
144152 )
145- if hasattr (polygons , "crs" ):
146- crs = polygons .crs
153+ if hasattr (geometry , "crs" ):
154+ crs = geometry .crs
147155 else :
148156 crs = None
149157 vec_cube = xr .concat (
150- zonal , dim = xr .DataArray (polygons , name = name , dims = name )
158+ zonal , dim = xr .DataArray (geometry , name = name , dims = name )
151159 ).xvec .set_geom_indexes (name , crs = crs )
152160 gc .collect ()
153161
@@ -160,7 +168,7 @@ def _agg_geom(
160168 trans ,
161169 x_coords : str = None ,
162170 y_coords : str = None ,
163- stats : str = "mean" ,
171+ stats : str | Callable = "mean" ,
164172 all_touched = False ,
165173 ** kwargs ,
166174):
@@ -207,9 +215,15 @@ def _agg_geom(
207215 invert = True ,
208216 all_touched = all_touched ,
209217 )
210- result = getattr (
211- acc ._obj .where (xr .DataArray (mask , dims = (y_coords , x_coords ))), stats
212- )(dim = (y_coords , x_coords ), keep_attrs = True , ** kwargs )
218+ masked = acc ._obj .where (xr .DataArray (mask , dims = (y_coords , x_coords )))
219+ if isinstance (stats , str ):
220+ result = getattr (masked , stats )(
221+ dim = (y_coords , x_coords ), keep_attrs = True , ** kwargs
222+ )
223+ else :
224+ result = masked .reduce (
225+ stats , dim = (y_coords , x_coords ), keep_attrs = True , ** kwargs
226+ )
213227
214228 del mask
215229 gc .collect ()
0 commit comments