@@ -1440,65 +1440,88 @@ def _hi_cupy(zones_data, values_data, nodata):
 
 
 @delayed
-def _hi_block_stats(z_block, v_block, uzones):
-    """Per-chunk: return (n_zones, 4) array of [min, max, sum, count]."""
-    result = np.full((len(uzones), 4), np.nan, dtype=np.float64)
-    result[:, 3] = 0  # count starts at 0
-    for i, z in enumerate(uzones):
-        mask = (z_block == z) & np.isfinite(v_block)
-        if not np.any(mask):
+def _hi_block_stats(z_block, v_block, nodata):
+    """Per-chunk: return a dict mapping local zone IDs to (min, max, sum, count).
+
+    Each block discovers its own zones, so the driver never has to compute
+    a global unique-zone set up front. Sparse zone IDs (common in geographic
+    data) stay sparse in the returned dict instead of being padded out to a
+    full (n_zones, 4) array.
+    """
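+    # Drop pixels where either the zone label or the value is non-finite.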
+    finite_v = np.isfinite(v_block)
+    finite_z = np.isfinite(z_block)
+    valid = finite_z & finite_v
+    if not np.any(valid):
+        return {}
+
+    z_valid = z_block[valid]
+    v_valid = v_block[valid]
+    uzones = np.unique(z_valid)
+
+    result = {}
+    for z in uzones:
+        if nodata is not None and z == nodata:
             continue
-        vals = v_block[mask]
-        result[i, 0] = vals.min()
-        result[i, 1] = vals.max()
-        result[i, 2] = vals.sum()
-        result[i, 3] = len(vals)
+        mask = z_valid == z
+        vals = v_valid[mask]
+        if vals.size == 0:
+            continue
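+        # Keys go in as plain Python scalars (via .item()) so merged dicts stay numpy-free.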
+        result[z.item() if hasattr(z, 'item') else z] = (
+            float(vals.min()),
+            float(vals.max()),
+            float(vals.sum()),
+            int(vals.size),
+        )
     return result
 
 
 @delayed
-def _hi_reduce(partials_list, uzones):
-    """Reduce per-block stats to global per-zone HI lookup dict."""
-    stacked = np.stack(partials_list)  # (n_blocks, n_zones, 4)
-    g_min = np.nanmin(stacked[:, :, 0], axis=0)
-    g_max = np.nanmax(stacked[:, :, 1], axis=0)
-    g_sum = np.nansum(stacked[:, :, 2], axis=0)
-    g_count = np.nansum(stacked[:, :, 3], axis=0)
+def _hi_reduce(partials_list):
+    """Stream-merge per-block dicts into the global hi_lookup.
+
+    Scheduler peak memory is O(n_zones) for the merged dict, rather than
+    O(n_blocks * n_zones) for a stacked array. Per-block partials arrive
+    as a Python list but are iterated once and can be released.
+    """
+    merged = {}
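+    # min, max, sum, and count all merge associatively, so block arrival order never matters.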
+    for partial in partials_list:
+        for z, (mn, mx, s, c) in partial.items():
+            if z in merged:
+                om, oM, os_, oc = merged[z]
+                merged[z] = (min(om, mn), max(oM, mx), os_ + s, oc + c)
+            else:
+                merged[z] = (mn, mx, s, c)
 
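+    # Per-zone hypsometric integral: HI = (mean - min) / (max - min).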
     hi_lookup = {}
-    for i, z in enumerate(uzones):
-        if g_count[i] == 0 or g_max[i] == g_min[i]:
+    for z, (mn, mx, s, c) in merged.items():
+        if c == 0 or mx == mn:
             hi_lookup[z] = np.nan
         else:
-            mean = g_sum[i] / g_count[i]
-            hi_lookup[z] = (mean - g_min[i]) / (g_max[i] - g_min[i])
+            hi_lookup[z] = (s / c - mn) / (mx - mn)
     return hi_lookup
 
 
 def _hi_dask_numpy(zones_data, values_data, nodata):
-    """Dask+numpy backend for hypsometric integral."""
-    # Step 1: find all unique zones across all chunks
-    unique_zones = _unique_finite_zones(zones_data)
-    if nodata is not None:
-        unique_zones = unique_zones[unique_zones != nodata]
-
-    if len(unique_zones) == 0:
-        return da.full(values_data.shape, np.nan, dtype=np.float64,
-                       chunks=values_data.chunks)
+    """Dask+numpy backend for hypsometric integral.
 
-    # Step 2: per-block aggregation -> global reduce
+    One eager graph evaluation: each block computes its local (zone -> stats)
+    dict, a streaming reduce merges them into a lookup table, and map_blocks
+    paints the result. No up-front `_unique_finite_zones` compute and no
+    O(n_blocks * n_zones) scheduler-side stack.
+    """
     zones_blocks = zones_data.to_delayed().ravel()
     values_blocks = values_data.to_delayed().ravel()
 
     partials = [
-        _hi_block_stats(zb, vb, unique_zones)
+        _hi_block_stats(zb, vb, nodata)
         for zb, vb in zip(zones_blocks, values_blocks)
     ]
 
-    # Compute the HI lookup eagerly so map_blocks can use it as a parameter.
-    hi_lookup = dask.compute(_hi_reduce(partials, unique_zones))[0]
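+    # The only eager compute: block stats and the streaming merge run in one graph pass.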
+    hi_lookup = dask.compute(_hi_reduce(partials))[0]
+
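+    # No zones survived (all nodata or non-finite): short-circuit to an all-NaN raster.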
+    if not hi_lookup:
+        return da.full(values_data.shape, np.nan, dtype=np.float64,
+                       chunks=values_data.chunks)
 
-    # Step 3: paint back using map_blocks (preserves chunk structure)
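+    # Paint per-zone HI values back onto the grid via map_blocks, preserving chunks.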
     def _paint(zones_chunk, values_chunk, hi_map):
         out = np.full(zones_chunk.shape, np.nan, dtype=np.float64)
         for z, hi_val in hi_map.items():
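
A quick sanity sketch of the merge semantics (standalone, not part of the diff):
`block_stats` below mirrors the per-block logic of `_hi_block_stats` without the
`@delayed` wrapper, and the fold is the same one `_hi_reduce` performs. All names
and toy values here are illustrative.

    import numpy as np

    def block_stats(z_block, v_block, nodata=None):
        # Per-block logic of _hi_block_stats, undecorated for local testing.
        valid = np.isfinite(z_block) & np.isfinite(v_block)
        z_valid, v_valid = z_block[valid], v_block[valid]
        out = {}
        for z in np.unique(z_valid):
            if nodata is not None and z == nodata:
                continue
            vals = v_valid[z_valid == z]
            out[z.item()] = (float(vals.min()), float(vals.max()),
                             float(vals.sum()), int(vals.size))
        return out

    # Zone 1 spans both blocks; zone 2 appears only in the second.
    z1 = np.array([[1., 1.], [1., np.nan]])
    v1 = np.array([[10., 20.], [30., 99.]])
    z2 = np.array([[1., 2.], [2., 2.]])
    v2 = np.array([[40., 5.], [5., 5.]])

    merged = block_stats(z1, v1)
    for z, (mn, mx, s, c) in block_stats(z2, v2).items():
        if z in merged:
            om, oM, os_, oc = merged[z]
            merged[z] = (min(om, mn), max(oM, mx), os_ + s, oc + c)
        else:
            merged[z] = (mn, mx, s, c)

    print(merged)
    # Zone 1 -> (10.0, 40.0, 100.0, 4): HI = (100/4 - 10) / (40 - 10) = 0.5
    # Zone 2 -> max == min, so _hi_reduce's flat-zone guard maps it to NaN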