Skip to content

Commit dbb945f

Browse files
committed
Fixes #881: replace np.unique/np.isfinite with dask-safe helpers in zonal.py
np.unique(zones[np.isfinite(zones)]) silently materialises the full dask array into RAM, causing OOM on large rasters. Replace with da.unique() which reduces per-chunk and only .compute()s the tiny set of distinct zone IDs.
1 parent 5b9c830 commit dbb945f

File tree

2 files changed

+107
-8
lines changed

2 files changed

+107
-8
lines changed

xrspatial/tests/test_zonal.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1212,3 +1212,72 @@ def test_crop_nothing_to_crop():
12121212
assert result.shape == arr.shape
12131213
compare = arr == result.data
12141214
assert compare.all()
1215+
1216+
1217+
# ---------------------------------------------------------------------------
1218+
# Regression tests for #881: np.unique / np.isfinite must not materialise
1219+
# the full dask array.
1220+
# ---------------------------------------------------------------------------
1221+
1222+
@pytest.mark.skipif(not has_dask_array(), reason="dask.array not available")
def test_stats_does_not_materialise_dask_zones():
    """Regression test for #881: the dask path of ``stats()`` must keep zone
    discovery lazy and never hand a dask array to ``np.unique``."""
    from unittest import mock

    zone_data = np.array([[0, 0, 1, 1, 2, 2, 3, 3],
                          [0, 0, 1, 1, 2, 2, 3, 3],
                          [0, 0, 1, 1, 2, np.nan, 3, 3]])
    value_data = np.array([[0, 0, 1, 1, 2, 2, 3, np.inf],
                           [0, 0, 1, 1, 2, np.nan, 3, 0],
                           [np.inf, 0, 1, 1, 2, 2, 3, 3]])

    zones = xr.DataArray(da.from_array(zone_data, chunks=(3, 4)), dims=['y', 'x'])
    values = xr.DataArray(da.from_array(value_data, chunks=(3, 4)), dims=['y', 'x'])

    original_unique = np.unique

    def guarded_unique(a, *args, **kwargs):
        # Fail loudly if a dask array reaches np.unique: that would pull
        # the whole raster into RAM (the bug #881 guards against).
        if isinstance(a, da.Array):
            raise AssertionError("np.unique called with a dask array — would materialise")
        return original_unique(a, *args, **kwargs)

    with mock.patch("xrspatial.zonal.np.unique", side_effect=guarded_unique):
        result = stats(zones, values)

    # The dask backend may hand back a lazy frame; realise it before checking.
    if hasattr(result, 'compute'):
        result = result.compute()
    assert isinstance(result, pd.DataFrame)
    assert len(result) > 0
1252+
1253+
1254+
@pytest.mark.skipif(not has_dask_array(), reason="dask.array not available")
def test_crosstab_does_not_materialise_dask_zones():
    """Regression test for #881: the dask path of ``crosstab()`` must keep
    zone discovery lazy and never hand a dask array to ``np.unique``."""
    from unittest import mock

    zone_data = np.array([[0, 0, 1, 1, 2, 2, 3, 3],
                          [0, 0, 1, 1, 2, 2, 3, 3],
                          [0, 0, 1, 1, 2, np.nan, 3, 3]])
    value_data = np.array([[0, 0, 1, 1, 2, 2, 3, 3],
                           [0, 0, 1, 1, 2, np.nan, 3, 0],
                           [0, 0, 1, 1, 2, 2, 3, 3]])

    zones = xr.DataArray(da.from_array(zone_data, chunks=(3, 4)), dims=['y', 'x'])
    values = xr.DataArray(da.from_array(value_data, chunks=(3, 4)), dims=['y', 'x'])

    original_unique = np.unique

    def guarded_unique(a, *args, **kwargs):
        # Fail loudly if a dask array reaches np.unique: that would pull
        # the whole raster into RAM (the bug #881 guards against).
        if isinstance(a, da.Array):
            raise AssertionError("np.unique called with a dask array — would materialise")
        return original_unique(a, *args, **kwargs)

    with mock.patch("xrspatial.zonal.np.unique", side_effect=guarded_unique):
        result = crosstab(zones, values)

    # The dask backend may hand back a lazy frame; realise it before checking.
    if hasattr(result, 'compute'):
        result = result.compute()
    assert isinstance(result, pd.DataFrame)
    assert len(result) > 0

xrspatial/zonal.py

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,35 @@ class cupy(object):
4040
TOTAL_COUNT = '_total_count'
4141

4242

43+
def _unique_finite_zones(arr):
    """Return the sorted unique finite values of *arr*.

    Dask arrays are reduced chunk-by-chunk via ``da.unique`` so only the
    small set of distinct zone ids — never the full raster — is brought
    into memory.
    """
    is_dask = da is not None and isinstance(arr, da.Array)
    if not is_dask:
        # numpy path: mask out NaN/Inf first, then collapse to uniques
        finite_values = arr[np.isfinite(arr)]
        return np.unique(finite_values)
    # dask path: per-chunk unique reduction; the computed result is tiny
    # (one entry per distinct value), so filtering it eagerly is cheap.
    distinct = da.unique(arr).compute()
    return distinct[np.isfinite(distinct)]
53+
54+
55+
def _unique_finite_cats(arr, nodata_values):
    """Return sorted unique values of *arr*, dropping NaN, Inf, and
    *nodata_values* (when given).

    Dask-safe: uses ``da.unique`` so the full array is never materialised;
    only the small set of distinct values is computed.
    """
    def _keep(values):
        # exclude non-finite entries, plus the nodata sentinel if supplied
        mask = np.isfinite(values)
        if nodata_values is not None:
            mask &= values != nodata_values
        return mask

    if da is not None and isinstance(arr, da.Array):
        # per-chunk reduction; computed result holds one entry per
        # distinct value, so the post-filter runs on a tiny array
        cats = da.unique(arr).compute()
        return cats[_keep(cats)]
    return np.unique(arr[_keep(arr)])
70+
71+
4372
def _stats_count(data):
4473
if isinstance(data, np.ndarray):
4574
# numpy case
@@ -187,7 +216,7 @@ def _stats_dask_numpy(
187216
) -> pd.DataFrame:
188217

189218
# find ids for all zones
190-
unique_zones = np.unique(zones[np.isfinite(zones)])
219+
unique_zones = _unique_finite_zones(zones)
191220

192221
select_all_zones = False
193222
# select zones to do analysis
@@ -199,7 +228,10 @@ def _stats_dask_numpy(
199228
values_blocks = values.to_delayed().ravel()
200229

201230
stats_dict = {}
202-
stats_dict["zone"] = unique_zones # zone column
231+
stats_dict["zone"] = da.from_delayed( # zone column
232+
delayed(lambda x: x)(unique_zones),
233+
shape=(np.nan,), dtype=unique_zones.dtype,
234+
)
203235

204236
compute_sum_squares = False
205237
compute_sum = False
@@ -287,7 +319,7 @@ def _stats_numpy(
287319
) -> Union[pd.DataFrame, np.ndarray]:
288320

289321
# find ids for all zones
290-
unique_zones = np.unique(zones[np.isfinite(zones)])
322+
unique_zones = _unique_finite_zones(zones)
291323
# selected zones to do analysis
292324
if zone_ids is None:
293325
zone_ids = unique_zones
@@ -670,9 +702,7 @@ def stats(
670702
def _find_cats(values, cat_ids, nodata_values):
671703
if len(values.shape) == 2:
672704
# 2D case
673-
unique_cats = np.unique(values.data[
674-
np.isfinite(values.data) & (values.data != nodata_values)
675-
])
705+
unique_cats = _unique_finite_cats(values.data, nodata_values)
676706
else:
677707
# 3D case
678708
unique_cats = values[values.dims[0]].data
@@ -756,7 +786,7 @@ def _crosstab_numpy(
756786
) -> pd.DataFrame:
757787

758788
# find ids for all zones
759-
unique_zones = np.unique(zones[np.isfinite(zones)])
789+
unique_zones = _unique_finite_zones(zones)
760790
# selected zones to do analysis
761791
if zone_ids is None:
762792
zone_ids = unique_zones
@@ -894,7 +924,7 @@ def _crosstab_dask_numpy(
894924
agg: str,
895925
):
896926
# find ids for all zones
897-
unique_zones = np.unique(zones[np.isfinite(zones)])
927+
unique_zones = _unique_finite_zones(zones)
898928
if zone_ids is None:
899929
zone_ids = unique_zones
900930
else:

0 commit comments

Comments
 (0)