Skip to content

Commit 5b9c830

Browse files
authored
Fixes #879: add memory guard and tiled KDTree fallback to proximity (#892)
The dask KDTree path in proximity() accumulated all target coordinates in memory before building a global cKDTree, causing OOM on dense rasters. Replace the unbounded accumulation with a two-phase approach:
- Phase 0: stream chunks counting targets, caching coordinates within a 25% memory budget.
- Memory decision: if the global tree estimate is < 50% of available RAM, build a single cKDTree (fast, lazy via da.map_blocks); otherwise fall back to a tiled expanding-ring KDTree (eager but memory-safe).

The tiled path builds local trees per output chunk, expanding the search ring until a boundary-distance validation proves correctness. A ResourceWarning is emitted when the tiled fallback activates.
1 parent 107f043 commit 5b9c830

File tree

2 files changed

+473
-34
lines changed

2 files changed

+473
-34
lines changed

xrspatial/proximity.py

Lines changed: 304 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from functools import partial
23
from math import sqrt
34

@@ -15,6 +16,7 @@
1516
import xarray as xr
1617
from numba import prange
1718

19+
from xrspatial.pathfinding import _available_memory_bytes
1820
from xrspatial.utils import get_dataarray_resolution, ngjit
1921
from xrspatial.dataset_support import supports_dataset
2022

@@ -426,55 +428,323 @@ def _kdtree_chunk_fn(block, y_coords_1d, x_coords_1d,
426428
return dists
427429

428430

429-
def _process_dask_kdtree(raster, x_coords, y_coords,
430-
target_values, max_distance, distance_metric):
431-
"""Two-phase k-d tree proximity for unbounded dask arrays."""
432-
p = 2 if distance_metric == EUCLIDEAN else 1 # Manhattan: p=1
431+
def _target_mask(chunk_data, target_values):
    """Return a boolean array flagging the target pixels of *chunk_data*.

    When *target_values* is empty, every finite, non-zero pixel counts as a
    target.  Otherwise a pixel is a target when it is finite and its value
    appears in *target_values*.
    """
    finite = np.isfinite(chunk_data)
    # ``len(...)`` rather than truthiness: target_values may be an ndarray,
    # whose truth value is ambiguous.
    if len(target_values) == 0:
        wanted = chunk_data != 0
    else:
        wanted = np.isin(chunk_data, target_values)
    return wanted & finite
433436

434-
# Phase 1: stream through chunks to collect target coordinates
435-
target_list = []
436-
chunks_y, chunks_x = raster.data.chunks
437-
y_offset = 0
438-
for iy, cy in enumerate(chunks_y):
439-
x_offset = 0
440-
for ix, cx in enumerate(chunks_x):
437+
438+
def _stream_target_counts(raster, target_values, y_coords, x_coords,
                          chunks_y, chunks_x):
    """Stream all dask chunks, counting targets per chunk.

    Each chunk is computed once, its target pixels are located with
    ``_target_mask`` and their (y, x) coordinates are looked up in
    *y_coords* / *x_coords*.  Per-chunk coordinate arrays are cached while
    their cumulative size stays within 25% of ``_available_memory_bytes()``,
    to reduce re-reads in later phases.

    Parameters
    ----------
    raster : object exposing ``.data`` as a chunked dask array
    target_values : sequence
        Pixel values treated as targets; empty means "finite and non-zero".
    y_coords, x_coords : 1-D arrays
        Coordinate of every raster row / column.
    chunks_y, chunks_x : sequence of int
        Chunk sizes along each axis.

    Returns
    -------
    target_counts : ndarray, shape (n_tile_y, n_tile_x), dtype int64
    total_targets : int
    coords_cache : dict (iy, ix) -> ndarray shape (N, 2)
    """
    n_tile_y = len(chunks_y)
    n_tile_x = len(chunks_x)
    target_counts = np.zeros((n_tile_y, n_tile_x), dtype=np.int64)
    coords_cache = {}
    cache_bytes = 0
    budget = int(0.25 * _available_memory_bytes())

    # Pixel offset of each chunk's first row / column; shared helper keeps
    # this in step with the offsets used by the later phases.
    y_offsets = _chunk_offsets(chunks_y)
    x_offsets = _chunk_offsets(chunks_x)

    for iy in range(n_tile_y):
        for ix in range(n_tile_x):
            chunk_data = raster.data.blocks[iy, ix].compute()
            rows, cols = np.where(_target_mask(chunk_data, target_values))
            n = len(rows)
            target_counts[iy, ix] = n
            if n == 0:
                continue

            # rows/cols are chunk-local; shift them to raster indices
            # before looking up the coordinates.
            coords = np.column_stack([
                y_coords[y_offsets[iy] + rows],
                x_coords[x_offsets[ix] + cols],
            ])
            entry_bytes = coords.nbytes
            # Entries that would overflow the budget are simply not cached;
            # later phases recompute them from the raster on demand.
            if cache_bytes + entry_bytes <= budget:
                coords_cache[(iy, ix)] = coords
                cache_bytes += entry_bytes

    total_targets = int(target_counts.sum())
    return target_counts, total_targets, coords_cache
482+
483+
484+
def _chunk_offsets(chunks):
    """Return the int64 array of cumulative chunk boundaries.

    The result has ``len(chunks) + 1`` entries: entry ``i`` is the index of
    the first element of chunk ``i`` and the last entry is the total length.
    """
    boundaries = np.zeros(len(chunks) + 1, dtype=np.int64)
    # Leading zero stays in place; the running totals fill the rest.
    boundaries[1:] = np.cumsum(chunks, dtype=np.int64)
    return boundaries
489+
490+
491+
def _collect_region_targets(raster, jy_lo, jy_hi, jx_lo, jx_hi,
                            target_values, target_counts,
                            y_coords, x_coords,
                            y_offsets, x_offsets, coords_cache):
    """Gather target (y, x) coordinates for the chunk window
    ``[jy_lo:jy_hi, jx_lo:jx_hi]``.

    Chunks whose entry in *target_counts* is zero are skipped.  For the
    others the cached coordinate array is used when present; otherwise the
    chunk is recomputed from *raster* and its targets are located again.

    Returns
    -------
    ndarray of shape (N, 2), or ``None`` when the window holds no targets.
    """
    gathered = []
    for tile_y in range(jy_lo, jy_hi):
        for tile_x in range(jx_lo, jx_hi):
            if target_counts[tile_y, tile_x] == 0:
                continue

            key = (tile_y, tile_x)
            if key in coords_cache:
                gathered.append(coords_cache[key])
                continue

            # Cache miss: re-read the chunk and rebuild its coordinates.
            tile = raster.data.blocks[tile_y, tile_x].compute()
            rows, cols = np.where(_target_mask(tile, target_values))
            if len(rows) == 0:
                continue
            gathered.append(np.column_stack([
                y_coords[y_offsets[tile_y] + rows],
                x_coords[x_offsets[tile_x] + cols],
            ]))

    return np.concatenate(gathered) if gathered else None
520+
521+
522+
def _min_boundary_distance(iy, ix, y_coords, x_coords,
                           y_offsets, x_offsets,
                           jy_lo, jy_hi, jx_lo, jx_hi,
                           n_tile_y, n_tile_x):
    """Lower bound on the distance from any pixel of chunk (iy, ix) to any
    point lying outside the search region [jy_lo:jy_hi, jx_lo:jx_hi].

    A side of the region is "open" when it does not reach the raster edge.
    For each open side the gap is the absolute coordinate difference between
    the chunk's edge pixel on that side and the first pixel beyond the
    region.  A point outside the region is beyond at least one open side, so
    its offset along that axis is at least that side's gap; the smallest gap
    therefore bounds both the L1 and the L2 distance from below.

    Returns
    -------
    float
        The smallest gap, or ``inf`` when the region covers the full raster.
    """
    # (side is open?, coordinate axis, chunk edge pixel, first pixel outside)
    sides = (
        # top
        (jy_lo > 0, y_coords, y_offsets[iy], y_offsets[jy_lo] - 1),
        # bottom
        (jy_hi < n_tile_y, y_coords, y_offsets[iy + 1] - 1, y_offsets[jy_hi]),
        # left
        (jx_lo > 0, x_coords, x_offsets[ix], x_offsets[jx_lo] - 1),
        # right
        (jx_hi < n_tile_x, x_coords, x_offsets[ix + 1] - 1, x_offsets[jx_hi]),
    )
    # Coordinates are only indexed for open sides; on a closed side the
    # "outside" pixel index does not refer to a real pixel.
    gaps = [
        abs(float(axis[edge_px]) - float(axis[outside_px]))
        for is_open, axis, edge_px, outside_px in sides
        if is_open
    ]
    return min(gaps) if gaps else np.inf
569+
570+
571+
def _tiled_chunk_proximity(raster, iy, ix, y_coords, x_coords,
                           y_offsets, x_offsets,
                           target_values, target_counts,
                           coords_cache, max_distance, p,
                           n_tile_y, n_tile_x):
    """Expanding-ring local KDTree for one output chunk.

    Starting with the chunk itself (ring 0), the search region grows by one
    chunk in every direction per iteration.  For each region a cKDTree is
    built from the targets it contains and queried for every pixel of chunk
    (iy, ix).  The loop stops as soon as the region covers the whole raster
    or the result is proven final: no target outside the region can be
    closer than what was found, nor lie within *max_distance* of a pixel
    that found nothing.

    Parameters
    ----------
    raster : object exposing ``.data`` as a chunked dask array
    iy, ix : int
        Chunk indices of the output tile.
    y_coords, x_coords : 1-D arrays
        Coordinate of every raster row / column.
    y_offsets, x_offsets : 1-D int arrays
        Cumulative chunk boundaries along each axis.
    target_values, target_counts, coords_cache
        Passed through to ``_collect_region_targets``.
    max_distance : float
        Search radius; non-finite values mean "unbounded".
    p : int
        Minkowski norm passed to ``cKDTree.query``.
    n_tile_y, n_tile_x : int
        Number of chunks along each axis.

    Returns
    -------
    ndarray, shape (h, w), dtype float32
        Nearest-target distance per pixel; NaN where no target was found
        within the search radius.
    """
    h = int(y_offsets[iy + 1] - y_offsets[iy])
    w = int(x_offsets[ix + 1] - x_offsets[ix])

    # Build query points for this chunk
    chunk_ys = y_coords[y_offsets[iy]:y_offsets[iy + 1]]
    chunk_xs = x_coords[x_offsets[ix]:x_offsets[ix + 1]]
    yy, xx = np.meshgrid(chunk_ys, chunk_xs, indexing='ij')
    query_pts = np.column_stack([yy.ravel(), xx.ravel()])

    ub = max_distance if np.isfinite(max_distance) else np.inf

    ring = 0
    while True:
        jy_lo = max(iy - ring, 0)
        jy_hi = min(iy + 1 + ring, n_tile_y)
        jx_lo = max(ix - ring, 0)
        jx_hi = min(ix + 1 + ring, n_tile_x)

        covers_full = (jy_lo == 0 and jy_hi == n_tile_y
                       and jx_lo == 0 and jx_hi == n_tile_x)

        target_coords = _collect_region_targets(
            raster, jy_lo, jy_hi, jx_lo, jx_hi,
            target_values, target_counts,
            y_coords, x_coords, y_offsets, x_offsets, coords_cache,
        )

        if target_coords is None:
            # No targets in this region: every pixel is (so far) a miss.
            dists = np.full((h, w), np.nan, dtype=np.float32)
        else:
            tree = cKDTree(target_coords)
            dists, _ = tree.query(query_pts, p=p, distance_upper_bound=ub)
            dists = dists.reshape(h, w).astype(np.float32)
            # cKDTree reports "no neighbour within the bound" as inf.
            dists[dists == np.inf] = np.nan

        if covers_full:
            # Nothing lies outside the region, so the result is final.
            return dists

        min_bd = _min_boundary_distance(
            iy, ix, y_coords, x_coords, y_offsets, x_offsets,
            jy_lo, jy_hi, jx_lo, jx_hi, n_tile_y, n_tile_x,
        )

        # Radius the region must be proven to cover.  A pixel with a hit
        # only needs its own distance covered.  A pixel without a hit (NaN)
        # could still have a target anywhere within the search radius, so
        # the whole radius must be covered; ignoring those pixels would
        # miss targets just outside the region but within max_distance.
        if np.isnan(dists).any():
            required = ub
        else:
            required = float(dists.max()) if dists.size else 0.0

        if required < min_bd:
            return dists

        ring += 1
631+
632+
633+
def _build_tiled_kdtree(raster, y_coords, x_coords, target_values,
                        max_distance, p, target_counts, coords_cache,
                        chunks_y, chunks_x):
    """Eager, tile-by-tile KDTree proximity used as the memory-safe fallback.

    The full float32 result is allocated as a NumPy array, filled one chunk
    at a time by ``_tiled_chunk_proximity``, and finally wrapped in a dask
    array with the raster's chunking.

    Raises
    ------
    MemoryError
        If the result array alone would exceed 80% of the memory reported
        by ``_available_memory_bytes()``.

    Warns
    -----
    ResourceWarning
        Always, to signal that the slower tiled path is in use.
    """
    H, W = raster.shape
    result_bytes = H * W * 4  # float32
    avail = _available_memory_bytes()
    if result_bytes > 0.8 * avail:
        raise MemoryError(
            f"Proximity result array ({H}x{W}, {result_bytes / 1e9:.1f} GB) "
            f"exceeds 80% of available memory ({avail / 1e9:.1f} GB)."
        )

    warnings.warn(
        "proximity: target coordinates exceed 50% of available memory; "
        "using tiled KDTree fallback (slower but memory-safe).",
        ResourceWarning,
        stacklevel=4,
    )

    tiles_down, tiles_across = len(chunks_y), len(chunks_x)
    row_edges = _chunk_offsets(chunks_y)
    col_edges = _chunk_offsets(chunks_x)

    out = np.full((H, W), np.nan, dtype=np.float32)

    for ty in range(tiles_down):
        for tx in range(tiles_across):
            tile = _tiled_chunk_proximity(
                raster, ty, tx, y_coords, x_coords,
                row_edges, col_edges,
                target_values, target_counts, coords_cache,
                max_distance, p, tiles_down, tiles_across,
            )
            # Paste the tile into its place in the full-size result.
            row_span = slice(int(row_edges[ty]), int(row_edges[ty + 1]))
            col_span = slice(int(col_edges[tx]), int(col_edges[tx + 1]))
            out[row_span, col_span] = tile

    return da.from_array(out, chunks=raster.data.chunks)
675+
676+
677+
def _build_global_kdtree(raster, y_coords, x_coords, target_values,
                         max_distance, p, target_counts, coords_cache,
                         chunks_y, chunks_x):
    """Global KDTree proximity — fast, lazy via da.map_blocks.

    All target coordinates of the raster are gathered (from the cache where
    possible) into a single cKDTree, which is then queried per chunk by
    ``_kdtree_chunk_fn`` through ``da.map_blocks``.

    Returns
    -------
    dask array, dtype float32
        An all-NaN array with the raster's shape and chunks if no target
        coordinates were collected; otherwise the lazy ``map_blocks``
        result.
    """
    n_tile_y = len(chunks_y)
    n_tile_x = len(chunks_x)
    y_offsets = _chunk_offsets(chunks_y)
    x_offsets = _chunk_offsets(chunks_x)

    target_coords = _collect_region_targets(
        raster, 0, n_tile_y, 0, n_tile_x,
        target_values, target_counts,
        y_coords, x_coords, y_offsets, x_offsets, coords_cache,
    )

    # _collect_region_targets returns None when the region has no targets;
    # cKDTree cannot be built from None, so answer "no target anywhere".
    if target_coords is None:
        return da.full(raster.shape, np.nan, dtype=np.float32,
                       chunks=raster.data.chunks)

    tree = cKDTree(target_coords)

    chunk_fn = partial(
        _kdtree_chunk_fn,
        y_coords_1d=y_coords,
        x_coords_1d=x_coords,
        tree=tree,
        # A non-finite max_distance means "unbounded" for the query.
        max_distance=max_distance if np.isfinite(max_distance) else np.inf,
        p=p,
    )

    return da.map_blocks(
        chunk_fn,
        raster.data,
        dtype=np.float32,
        meta=np.array((), dtype=np.float32),
    )
477-
return result
709+
710+
711+
def _process_dask_kdtree(raster, x_coords, y_coords,
                         target_values, max_distance, distance_metric):
    """Memory-guarded k-d tree proximity for dask arrays.

    A streaming pass first counts the targets in every chunk (caching their
    coordinates where the budget allows).  With no targets at all, an
    all-NaN dask array is returned.  Otherwise the size of a global tree is
    estimated and compared with the available memory: below 50% the global
    tree is used (fast, lazy), else the tiled tree (memory-safe, eager).
    """
    chunks_y, chunks_x = raster.data.chunks
    minkowski_p = 2 if distance_metric == EUCLIDEAN else 1  # Manhattan: p=1

    # Phase 0: streaming count pass
    counts, n_targets, cache = _stream_target_counts(
        raster, target_values, y_coords, x_coords, chunks_y, chunks_x,
    )

    if n_targets == 0:
        return da.full(raster.shape, np.nan, dtype=np.float32,
                       chunks=raster.data.chunks)

    # Memory decision: 16 bytes per coord pair + ~32 bytes tree overhead
    tree_bytes = n_targets * 48
    fits_in_memory = tree_bytes < 0.5 * _available_memory_bytes()
    build = _build_global_kdtree if fits_in_memory else _build_tiled_kdtree

    return build(
        raster, y_coords, x_coords, target_values,
        max_distance, minkowski_p, counts, cache,
        chunks_y, chunks_x,
    )
478748

479749

480750
def _process(

0 commit comments

Comments
 (0)