Skip to content

Commit d05d9b7

Browse files
authored
Fuse hypsometric_integral dask path to a single graph evaluation (#1212)
_hi_dask_numpy did two blocking dask.compute() calls (_unique_finite_zones at one step, _hi_reduce at the next), so the caller paid for two full input scans before the lazy map_blocks output was even returned. _hi_reduce also used np.stack to gather the per-block partials into an (n_blocks, n_zones, 4) array on the scheduler; at 240k blocks * 1000 zones that is ~7.7 GB resident in a single scheduler task. Have each block discover its own local unique zones and return a dict mapping zone id -> (min, max, sum, count). _hi_reduce stream-merges the partial dicts into a global hi_lookup so scheduler peak memory scales with the number of distinct zones, not n_blocks * n_zones. The up-front _unique_finite_zones pass is gone and the whole dask path collapses to a single graph evaluation.
1 parent 7fa9e04 commit d05d9b7

1 file changed

Lines changed: 60 additions & 37 deletions

File tree

xrspatial/zonal.py

Lines changed: 60 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1440,65 +1440,88 @@ def _hi_cupy(zones_data, values_data, nodata):
14401440

14411441

14421442
@delayed
1443-
def _hi_block_stats(z_block, v_block, uzones):
1444-
"""Per-chunk: return (n_zones, 4) array of [min, max, sum, count]."""
1445-
result = np.full((len(uzones), 4), np.nan, dtype=np.float64)
1446-
result[:, 3] = 0 # count starts at 0
1447-
for i, z in enumerate(uzones):
1448-
mask = (z_block == z) & np.isfinite(v_block)
1449-
if not np.any(mask):
1443+
def _hi_block_stats(z_block, v_block, nodata):
1444+
"""Per-chunk: return dict mapping local zone IDs to (min, max, sum, count).
1445+
1446+
Each block discovers its own zones, so the driver never has to compute
1447+
a global unique-zone set up front. Sparse zones (geographic) stay sparse
1448+
in the returned dict instead of being padded to a full (n_zones, 4) array.
1449+
"""
1450+
finite_v = np.isfinite(v_block)
1451+
finite_z = np.isfinite(z_block)
1452+
valid = finite_z & finite_v
1453+
if not np.any(valid):
1454+
return {}
1455+
1456+
z_valid = z_block[valid]
1457+
v_valid = v_block[valid]
1458+
uzones = np.unique(z_valid)
1459+
1460+
result = {}
1461+
for z in uzones:
1462+
if nodata is not None and z == nodata:
14501463
continue
1451-
vals = v_block[mask]
1452-
result[i, 0] = vals.min()
1453-
result[i, 1] = vals.max()
1454-
result[i, 2] = vals.sum()
1455-
result[i, 3] = len(vals)
1464+
mask = z_valid == z
1465+
vals = v_valid[mask]
1466+
if vals.size == 0:
1467+
continue
1468+
result[z.item() if hasattr(z, 'item') else z] = (
1469+
float(vals.min()),
1470+
float(vals.max()),
1471+
float(vals.sum()),
1472+
int(vals.size),
1473+
)
14561474
return result
14571475

14581476

14591477
@delayed
1460-
def _hi_reduce(partials_list, uzones):
1461-
"""Reduce per-block stats to global per-zone HI lookup dict."""
1462-
stacked = np.stack(partials_list) # (n_blocks, n_zones, 4)
1463-
g_min = np.nanmin(stacked[:, :, 0], axis=0)
1464-
g_max = np.nanmax(stacked[:, :, 1], axis=0)
1465-
g_sum = np.nansum(stacked[:, :, 2], axis=0)
1466-
g_count = np.nansum(stacked[:, :, 3], axis=0)
1478+
def _hi_reduce(partials_list):
1479+
"""Stream-merge per-block dicts into global hi_lookup.
1480+
1481+
Scheduler peak memory is O(n_zones) for the merged dict, rather than
1482+
O(n_blocks * n_zones) from a stacked array. Per-block partials arrive
1483+
as a Python list but are iterated once and can be released.
1484+
"""
1485+
merged = {}
1486+
for partial in partials_list:
1487+
for z, (mn, mx, s, c) in partial.items():
1488+
if z in merged:
1489+
om, oM, os_, oc = merged[z]
1490+
merged[z] = (min(om, mn), max(oM, mx), os_ + s, oc + c)
1491+
else:
1492+
merged[z] = (mn, mx, s, c)
14671493

14681494
hi_lookup = {}
1469-
for i, z in enumerate(uzones):
1470-
if g_count[i] == 0 or g_max[i] == g_min[i]:
1495+
for z, (mn, mx, s, c) in merged.items():
1496+
if c == 0 or mx == mn:
14711497
hi_lookup[z] = np.nan
14721498
else:
1473-
mean = g_sum[i] / g_count[i]
1474-
hi_lookup[z] = (mean - g_min[i]) / (g_max[i] - g_min[i])
1499+
hi_lookup[z] = (s / c - mn) / (mx - mn)
14751500
return hi_lookup
14761501

14771502

14781503
def _hi_dask_numpy(zones_data, values_data, nodata):
1479-
"""Dask+numpy backend for hypsometric integral."""
1480-
# Step 1: find all unique zones across all chunks
1481-
unique_zones = _unique_finite_zones(zones_data)
1482-
if nodata is not None:
1483-
unique_zones = unique_zones[unique_zones != nodata]
1484-
1485-
if len(unique_zones) == 0:
1486-
return da.full(values_data.shape, np.nan, dtype=np.float64,
1487-
chunks=values_data.chunks)
1504+
"""Dask+numpy backend for hypsometric integral.
14881505
1489-
# Step 2: per-block aggregation -> global reduce
1506+
Single graph evaluation: each block computes its local (zone -> stats)
1507+
dict, then a streaming reduce merges them into a lookup table, then
1508+
map_blocks paints the result. No up-front `_unique_finite_zones`
1509+
compute and no O(n_blocks * n_zones) scheduler-side stack.
1510+
"""
14901511
zones_blocks = zones_data.to_delayed().ravel()
14911512
values_blocks = values_data.to_delayed().ravel()
14921513

14931514
partials = [
1494-
_hi_block_stats(zb, vb, unique_zones)
1515+
_hi_block_stats(zb, vb, nodata)
14951516
for zb, vb in zip(zones_blocks, values_blocks)
14961517
]
14971518

1498-
# Compute the HI lookup eagerly so map_blocks can use it as a parameter.
1499-
hi_lookup = dask.compute(_hi_reduce(partials, unique_zones))[0]
1519+
hi_lookup = dask.compute(_hi_reduce(partials))[0]
1520+
1521+
if not hi_lookup:
1522+
return da.full(values_data.shape, np.nan, dtype=np.float64,
1523+
chunks=values_data.chunks)
15001524

1501-
# Step 3: paint back using map_blocks (preserves chunk structure)
15021525
def _paint(zones_chunk, values_chunk, hi_map):
15031526
out = np.full(zones_chunk.shape, np.nan, dtype=np.float64)
15041527
for z, hi_val in hi_map.items():

0 commit comments

Comments
 (0)