Skip to content

Commit 61f5676

Browse files
authored
Fixes #774: suppress dd.concat unknown divisions warning in zonal_stats (#857)
* Fixes #392: document and test 3D time-series zonal stats via Dataset

  Add a docstring example and tests showing how to compute zonal statistics on 3D time-series DataArrays by converting them to a Dataset with `.to_dataset(dim='time')` and passing the result to `stats()`.

* Fixes #774: suppress dd.concat unknown divisions warning in zonal_stats
1 parent 0ef76c8 commit 61f5676

File tree

2 files changed

+40
-2
lines changed

2 files changed

+40
-2
lines changed

xrspatial/tests/test_zonal.py

Lines changed: 38 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1162,6 +1162,44 @@ def test_crop():
11621162
assert compare.all()
11631163

11641164

1165+
@pytest.mark.skipif(not dask_array_available(), reason="Requires Dask")
def test_dask_zonal_stats_no_concat_warnings():
    """Regression test for #774: dd.concat should not warn about unknown divisions."""
    import warnings

    # Both rasters share one chunk layout: two column-wise chunks of shape (3, 2).
    chunk_shape = (3, 2)
    zone_grid = np.array([[0, 0, 1, 1],
                          [0, 0, 1, 1],
                          [2, 2, 3, 3]])
    value_grid = np.array([[1, 2, 3, 4],
                           [5, 6, 7, 8],
                           [9, 10, 11, 12]], dtype=float)

    zones = xr.DataArray(da.from_array(zone_grid, chunks=chunk_shape), dims=['y', 'x'])
    values = xr.DataArray(da.from_array(value_grid, chunks=chunk_shape), dims=['y', 'x'])

    # First call: all zones, which exercises the column-wise concat of the
    # per-statistic arrays. Second call: filtered zone_ids, which additionally
    # exercises the row-wise concat of the selected rows.
    call_variants = ({}, {'zone_ids': [0, 3]})

    with warnings.catch_warnings(record=True) as recorded:
        # Record every warning so none is hidden by a default "once" filter.
        warnings.simplefilter("always")

        for extra_kwargs in call_variants:
            result = stats(zones=zones, values=values, **extra_kwargs)
            assert isinstance(result, dd.DataFrame)
            # Force evaluation so warnings raised at compute time are recorded too.
            result.compute()

    offending_messages = [
        str(entry.message)
        for entry in recorded
        if "unknown divisions" in str(entry.message).lower()
    ]
    assert offending_messages == [], (
        f"Expected no 'unknown divisions' warnings, got: "
        f"{offending_messages}"
    )
1201+
1202+
11651203
def test_crop_nothing_to_crop():
11661204
arr = np.array([[0, 4, 0, 3],
11671205
[0, 4, 4, 3],

xrspatial/zonal.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -259,7 +259,7 @@ def _stats_dask_numpy(
259259
)
260260

261261
# generate dask dataframe
262-
stats_df = dd.concat([dd.from_dask_array(s) for s in stats_dict.values()], axis=1)
262+
stats_df = dd.concat([dd.from_dask_array(s) for s in stats_dict.values()], axis=1, ignore_unknown_divisions=True)
263263
# name columns
264264
stats_df.columns = stats_dict.keys()
265265
# select columns (only include stats that were actually computed)
@@ -272,7 +272,7 @@ def _stats_dask_numpy(
272272
for index, row in stats_df.iterrows():
273273
if row['zone'] in zone_ids:
274274
selected_rows.append(stats_df.loc[index])
275-
stats_df = dd.concat(selected_rows)
275+
stats_df = dd.concat(selected_rows, ignore_unknown_divisions=True)
276276

277277
return stats_df
278278

0 commit comments

Comments
 (0)