Skip to content

Commit 80390d2

Browse files
committed
variable zonal stats using exactextract and automatic detection of method
1 parent e6a15bb commit 80390d2

3 files changed

Lines changed: 271 additions & 35 deletions

File tree

xvec/accessor.py

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
from .plotting import _plot
1717
from .utils import transform_geom
1818
from .zonal import (
19+
_get_method,
1920
_variable_zonal,
21+
_variable_zonal_exactextract,
2022
_zonal_stats_exactextract,
2123
_zonal_stats_iterative,
2224
_zonal_stats_rasterize,
@@ -983,7 +985,7 @@ def zonal_stats(
983985
stats: str | Callable | Sequence[str | Callable | tuple] = "mean",
984986
name: str = "geometry",
985987
index: bool | None = None,
986-
method: str = "rasterize",
988+
method: str | None = None,
987989
all_touched: bool = False,
988990
n_jobs: int = -1,
989991
nodata: Any = None,
@@ -1005,10 +1007,10 @@ def zonal_stats(
10051007
10061008
Parameters
10071009
----------
1008-
geometry : Sequence[shapely.Geometry]
1010+
geometry : Sequence[shapely.Geometry] | xr.DataArray
10091011
An arrray-like (1-D) of shapely geometries, like a numpy array or
1010-
:class:`geopandas.GeoSeries`. Polygon and LineString geometry types are
1011-
supported.
1012+
:class:`geopandas.GeoSeries` or xr.DataArray holding variable geometry.
1013+
Polygon and LineString geometry types are supported.
10121014
x_coords : Hashable
10131015
name of the coordinates containing ``x`` coordinates (i.e. the first value
10141016
in the coordinate pair encoding the vertex of the polygon)
@@ -1045,7 +1047,8 @@ def zonal_stats(
10451047
10461048
``"rasterize"``
10471049
uses :func:`rasterio.features.rasterize` and is faster, but can lead to
1048-
loss of information in case of small polygons or lines.
1050+
loss of information in case of small polygons or lines. Not supported
1051+
for zonal stats using variable geometry (n-D array of geometry).
10491052
10501053
``"iterate"``
10511054
iterates over geometries and uses
@@ -1057,7 +1060,8 @@ def zonal_stats(
10571060
that is covered by the polygon and uses
10581061
:func:`exactextract.exact_extract`.
10591062
1060-
The default is ``"rasterize"``.
1063+
The default is selected based on the availability of engines in the order
1064+
of priority 1. ``"exactextract"``, 2. ``"rasterize"`` 3. ``"iterate"``.
10611065
all_touched : bool, optional
10621066
If True, all pixels touched by geometries will be considered. If False, only
10631067
pixels whose center is within the polygon or that are selected by
@@ -1156,17 +1160,40 @@ def zonal_stats(
11561160
--------
11571161
extract_points : extraction of values for the raster object for points
11581162
"""
1163+
11591164
if isinstance(geometry, xr.DataArray) and len(geometry.dims) > 1:
1160-
return _variable_zonal(
1161-
self,
1162-
variable_geometry=geometry,
1163-
x_coords=x_coords,
1164-
y_coords=y_coords,
1165-
stats=stats,
1166-
all_touched=all_touched,
1167-
nodata=nodata,
1165+
if method is None:
1166+
method = _get_method(variable=True)
1167+
1168+
if method == "iterate":
1169+
return _variable_zonal(
1170+
self,
1171+
variable_geometry=geometry,
1172+
x_coords=x_coords,
1173+
y_coords=y_coords,
1174+
stats=stats,
1175+
all_touched=all_touched,
1176+
n_jobs=n_jobs,
1177+
nodata=nodata,
1178+
)
1179+
if method == "exactextract":
1180+
return _variable_zonal_exactextract(
1181+
self,
1182+
geometry=geometry,
1183+
x_coords=x_coords,
1184+
y_coords=y_coords,
1185+
stats=stats,
1186+
nodata=nodata,
1187+
strategy=strategy,
1188+
)
1189+
raise ValueError(
1190+
f"Method '{method}' is not supported for zonal statistics based on "
1191+
"variable geometry. Use one of `exactextract` or `iterate`."
11681192
)
11691193

1194+
if method is None:
1195+
method = _get_method(variable=False)
1196+
11701197
if method == "rasterize":
11711198
result = _zonal_stats_rasterize(
11721199
self,

xvec/tests/test_zonal_stats.py

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def glaciers():
3030
return glaciers, sentinel_2
3131

3232

33-
@pytest.mark.parametrize("method", ["rasterize", "iterate", "exactextract"])
33+
@pytest.mark.parametrize("method", [None, "rasterize", "iterate", "exactextract"])
3434
def test_structure(method):
3535
da = xr.DataArray(
3636
np.ones((10, 10, 5)),
@@ -45,7 +45,7 @@ def test_structure(method):
4545
polygon2 = shapely.geometry.Polygon([(6, 22), (9, 22), (9, 29), (6, 26)])
4646
polygons = gpd.GeoSeries([polygon1, polygon2], crs="EPSG:4326")
4747

48-
if method == "exactextract":
48+
if method in ["exactextract", None]:
4949
expected = xr.DataArray(
5050
np.array([[12.0] * 5, [16.5] * 5]),
5151
coords={
@@ -117,13 +117,13 @@ def test_match():
117117
xr.testing.assert_allclose(rasterize, iterate)
118118

119119

120-
@pytest.mark.parametrize("method", ["rasterize", "iterate", "exactextract"])
120+
@pytest.mark.parametrize("method", [None, "rasterize", "iterate", "exactextract"])
121121
def test_dataset(method):
122122
ds = xr.tutorial.open_dataset("eraint_uvz")
123123
world = gpd.read_file(geodatasets.get_path("naturalearth land"))
124124
result = ds.xvec.zonal_stats(world.geometry, "longitude", "latitude", method=method)
125125

126-
if method == "exactextract":
126+
if method in ["exactextract", None]:
127127
xr.testing.assert_allclose(
128128
xr.Dataset(
129129
{
@@ -147,7 +147,7 @@ def test_dataset(method):
147147
)
148148

149149

150-
@pytest.mark.parametrize("method", ["rasterize", "iterate", "exactextract"])
150+
@pytest.mark.parametrize("method", [None, "rasterize", "iterate", "exactextract"])
151151
def test_dataarray(method):
152152
ds = xr.tutorial.open_dataset("eraint_uvz")
153153
world = gpd.read_file(geodatasets.get_path("naturalearth land"))
@@ -157,13 +157,13 @@ def test_dataarray(method):
157157

158158
assert result.shape == (127, 2, 3)
159159
assert result.dims == ("geometry", "month", "level")
160-
if method == "exactextract":
160+
if method in ["exactextract", None]:
161161
assert result.mean() == pytest.approx(61625.53438858)
162162
else:
163163
assert result.mean() == pytest.approx(61367.76185577)
164164

165165

166-
@pytest.mark.parametrize("method", ["rasterize", "iterate", "exactextract"])
166+
@pytest.mark.parametrize("method", [None, "rasterize", "iterate", "exactextract"])
167167
def test_stat(method):
168168
ds = xr.tutorial.open_dataset("eraint_uvz")
169169
world = gpd.read_file(geodatasets.get_path("naturalearth land"))
@@ -174,7 +174,7 @@ def test_stat(method):
174174
median_ = ds.z.xvec.zonal_stats(
175175
world.geometry, "longitude", "latitude", method=method, stats="median"
176176
)
177-
if method == "exactextract":
177+
if method in ["exactextract", None]:
178178
quantile_ = ds.z.xvec.zonal_stats(
179179
world.geometry,
180180
"longitude",
@@ -192,7 +192,7 @@ def test_stat(method):
192192
q=0.2,
193193
)
194194

195-
if method == "exactextract":
195+
if method in ["exactextract", None]:
196196
assert mean_.mean() == pytest.approx(61625.53438858)
197197
assert median_.mean() == pytest.approx(61628.67168691)
198198
assert quantile_.mean() == pytest.approx(61540.75632235)
@@ -308,11 +308,11 @@ def test_callable(method):
308308
xr.testing.assert_identical(da_agg, da_std)
309309

310310

311-
@pytest.mark.parametrize("method", ["rasterize", "iterate", "exactextract"])
311+
@pytest.mark.parametrize("method", [None, "rasterize", "iterate", "exactextract"])
312312
def test_multiple(method):
313313
ds = xr.tutorial.open_dataset("eraint_uvz")
314314
world = gpd.read_file(geodatasets.get_path("naturalearth land"))
315-
if method == "exactextract":
315+
if method in ["exactextract", None]:
316316
result = ds.xvec.zonal_stats(
317317
world.geometry[:10].boundary,
318318
"longitude",
@@ -366,7 +366,7 @@ def test_multiple(method):
366366
).all()
367367

368368

369-
@pytest.mark.parametrize("method", ["rasterize", "iterate", "exactextract"])
369+
@pytest.mark.parametrize("method", [None, "rasterize", "iterate", "exactextract"])
370370
def test_invalid(method):
371371
ds = xr.tutorial.open_dataset("eraint_uvz")
372372
world = gpd.read_file(geodatasets.get_path("naturalearth land"))
@@ -394,7 +394,25 @@ def test_invalid(method):
394394
)
395395

396396

397-
def test_variable_geometry_multiple(glaciers):
397+
@pytest.mark.parametrize("method", [None, "iterate", "exactextract"])
398+
def test_variable_geometry_multiple(glaciers, method):
399+
da, sentinel_2 = glaciers
400+
401+
result = sentinel_2.xvec.zonal_stats(
402+
da.geometry,
403+
x_coords="x",
404+
y_coords="y",
405+
stats=[
406+
"mean",
407+
"sum",
408+
],
409+
method=method,
410+
)
411+
412+
assert result.sizes == {"year": 3, "name": 5, "zonal_statistics": 2, "band": 11}
413+
414+
415+
def test_variable_geometry_iterate_custom(glaciers):
398416
da, sentinel_2 = glaciers
399417

400418
result = sentinel_2.xvec.zonal_stats(
@@ -407,24 +425,31 @@ def test_variable_geometry_multiple(glaciers):
407425
("numpymean", np.nanmean),
408426
np.nanmean,
409427
],
428+
method="iterate",
410429
)
411430

412431
assert result.sizes == {"year": 3, "name": 5, "zonal_statistics": 4, "band": 11}
413-
assert result.statistics.mean() == 17067828
432+
assert result.mean() == 17067828
414433

415434

416-
def test_variable_geometry_single(glaciers):
435+
@pytest.mark.parametrize("method", [None, "iterate", "exactextract"])
436+
def test_variable_geometry_single(glaciers, method):
417437
da, sentinel_2 = glaciers
418438

419439
result = sentinel_2.xvec.zonal_stats(
420440
da.geometry,
421441
x_coords="x",
422442
y_coords="y",
423443
stats="mean",
444+
method=method,
424445
)
425446

426447
assert result.sizes == {"year": 3, "name": 5, "band": 11}
427-
assert result.statistics.mean() == 13168.585
448+
449+
if method in ("exactextract", None):
450+
assert result.mean() == pytest.approx(13168.076)
451+
else:
452+
assert result.mean() == 13168.585
428453

429454

430455
def test_exactextract_strategy():
@@ -458,7 +483,7 @@ def test_exactextract_strategy():
458483
)
459484

460485

461-
@pytest.mark.parametrize("method", ["rasterize", "iterate", "exactextract"])
486+
@pytest.mark.parametrize("method", [None, "rasterize", "iterate", "exactextract"])
462487
def test_nodata(method):
463488
ds = xr.tutorial.open_dataset("eraint_uvz")
464489
world = gpd.read_file(geodatasets.get_path("naturalearth land"))

0 commit comments

Comments
 (0)