55from typing import Callable
66
77import numpy as np
8+ import pandas as pd
89import shapely
910import xarray as xr
1011
1112
13+ def _agg_rasterize (groups , stats , ** kwargs ):
14+ if isinstance (stats , str ):
15+ return getattr (groups , stats )(** kwargs )
16+ return groups .reduce (stats , keep_attrs = True , ** kwargs )
17+
18+
19+ def _agg_iterate (masked , stats , x_coords , y_coords , ** kwargs ):
20+ if isinstance (stats , str ):
21+ return getattr (masked , stats )(
22+ dim = (y_coords , x_coords ), keep_attrs = True , ** kwargs
23+ )
24+ return masked .reduce (stats , dim = (y_coords , x_coords ), keep_attrs = True , ** kwargs )
25+
26+
1227def _zonal_stats_rasterize (
1328 acc ,
1429 geometry : Sequence [shapely .Geometry ],
1530 x_coords : Hashable ,
1631 y_coords : Hashable ,
17- stats : str | Callable = "mean" ,
32+ stats : str | Callable | Sequence [ str | Callable | tuple ] = "mean" ,
1833 name : str = "geometry" ,
1934 all_touched : bool = False ,
2035 ** kwargs ,
@@ -47,10 +62,31 @@ def _zonal_stats_rasterize(
4762 all_touched = all_touched ,
4863 )
4964 groups = acc ._obj .groupby (xr .DataArray (labels , dims = (y_coords , x_coords )))
50- if isinstance (stats , str ):
51- agg = getattr (groups , stats )(** kwargs )
65+
66+ if pd .api .types .is_list_like (stats ):
67+ agg = {}
68+ for stat in stats :
69+ if isinstance (stat , str ):
70+ agg [stat ] = _agg_rasterize (groups , stat , ** kwargs )
71+ elif callable (stat ):
72+ agg [stat .__name__ ] = _agg_rasterize (groups , stat , ** kwargs )
73+ elif isinstance (stat , tuple ):
74+ kws = stat [2 ] if len (stat ) == 3 else {}
75+ agg [stat [0 ]] = _agg_rasterize (groups , stat [1 ], ** kws )
76+ else :
77+ raise ValueError (f"{ stat } is not a valid aggregation." )
78+
79+ agg = xr .concat (
80+ agg .values (),
81+ dim = xr .DataArray (
82+ list (agg .keys ()), name = "zonal_statistics" , dims = "zonal_statistics"
83+ ),
84+ )
85+ elif isinstance (stats , str ) or callable (stats ):
86+ agg = _agg_rasterize (groups , stats , ** kwargs )
5287 else :
53- agg = groups .reduce (stats , keep_attrs = True , ** kwargs )
88+ raise ValueError (f"{ stats } is not a valid aggregation." )
89+
5490 vec_cube = (
5591 agg .reindex (group = range (len (geometry )))
5692 .assign_coords (group = geometry )
@@ -68,7 +104,7 @@ def _zonal_stats_iterative(
68104 geometry : Sequence [shapely .Geometry ],
69105 x_coords : Hashable ,
70106 y_coords : Hashable ,
71- stats : str | Callable = "mean" ,
107+ stats : str | Callable | Sequence [ str | Callable | tuple ] = "mean" ,
72108 name : str = "geometry" ,
73109 all_touched : bool = False ,
74110 n_jobs : int = - 1 ,
@@ -168,7 +204,7 @@ def _agg_geom(
168204 trans ,
169205 x_coords : str = None ,
170206 y_coords : str = None ,
171- stats : str | Callable = "mean" ,
207+ stats : str | Callable | Sequence [ str | Callable | tuple ] = "mean" ,
172208 all_touched = False ,
173209 ** kwargs ,
174210):
@@ -216,14 +252,31 @@ def _agg_geom(
216252 all_touched = all_touched ,
217253 )
218254 masked = acc ._obj .where (xr .DataArray (mask , dims = (y_coords , x_coords )))
219- if isinstance (stats , str ):
220- result = getattr (masked , stats )(
221- dim = (y_coords , x_coords ), keep_attrs = True , ** kwargs
255+ if pd .api .types .is_list_like (stats ):
256+ agg = {}
257+ for stat in stats :
258+ if isinstance (stat , str ):
259+ agg [stat ] = _agg_iterate (masked , stat , x_coords , y_coords , ** kwargs )
260+ elif callable (stat ):
261+ agg [stat .__name__ ] = _agg_iterate (
262+ masked , stat , x_coords , y_coords , ** kwargs
263+ )
264+ elif isinstance (stat , tuple ):
265+ kws = stat [2 ] if len (stat ) == 3 else {}
266+ agg [stat [0 ]] = _agg_iterate (masked , stat [1 ], x_coords , y_coords , ** kws )
267+ else :
268+ raise ValueError (f"{ stat } is not a valid aggregation." )
269+
270+ result = xr .concat (
271+ agg .values (),
272+ dim = xr .DataArray (
273+ list (agg .keys ()), name = "zonal_statistics" , dims = "zonal_statistics"
274+ ),
222275 )
276+ elif isinstance (stats , str ) or callable (stats ):
277+ result = _agg_iterate (masked , stats , x_coords , y_coords , ** kwargs )
223278 else :
224- result = masked .reduce (
225- stats , dim = (y_coords , x_coords ), keep_attrs = True , ** kwargs
226- )
279+ raise ValueError (f"{ stats } is not a valid aggregation." )
227280
228281 del mask
229282 gc .collect ()
0 commit comments