Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
338 changes: 304 additions & 34 deletions xrspatial/proximity.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from functools import partial
from math import sqrt

Expand All @@ -15,6 +16,7 @@
import xarray as xr
from numba import prange

from xrspatial.pathfinding import _available_memory_bytes
from xrspatial.utils import get_dataarray_resolution, ngjit
from xrspatial.dataset_support import supports_dataset

Expand Down Expand Up @@ -426,55 +428,323 @@ def _kdtree_chunk_fn(block, y_coords_1d, x_coords_1d,
return dists


def _process_dask_kdtree(raster, x_coords, y_coords,
target_values, max_distance, distance_metric):
"""Two-phase k-d tree proximity for unbounded dask arrays."""
p = 2 if distance_metric == EUCLIDEAN else 1 # Manhattan: p=1
def _target_mask(chunk_data, target_values):
"""Boolean mask of target pixels in *chunk_data*."""
if len(target_values) == 0:
return np.isfinite(chunk_data) & (chunk_data != 0)
return np.isin(chunk_data, target_values) & np.isfinite(chunk_data)

# Phase 1: stream through chunks to collect target coordinates
target_list = []
chunks_y, chunks_x = raster.data.chunks
y_offset = 0
for iy, cy in enumerate(chunks_y):
x_offset = 0
for ix, cx in enumerate(chunks_x):

def _stream_target_counts(raster, target_values, y_coords, x_coords,
                          chunks_y, chunks_x):
    """Stream all dask chunks, counting targets per chunk.

    Caches per-chunk coordinate arrays within a 25% memory budget to
    reduce re-reads in later phases.

    Returns
    -------
    target_counts : ndarray, shape (n_tile_y, n_tile_x), dtype int64
    total_targets : int
    coords_cache : dict (iy, ix) -> ndarray shape (N, 2)
    """
    n_tile_y = len(chunks_y)
    n_tile_x = len(chunks_x)
    target_counts = np.zeros((n_tile_y, n_tile_x), dtype=np.int64)
    coords_cache = {}
    cache_bytes = 0
    # Cache at most a quarter of currently-available memory.
    budget = int(0.25 * _available_memory_bytes())

    y_offsets = _chunk_offsets(chunks_y)
    x_offsets = _chunk_offsets(chunks_x)

    for iy in range(n_tile_y):
        for ix in range(n_tile_x):
            chunk_data = raster.data.blocks[iy, ix].compute()
            mask = _target_mask(chunk_data, target_values)
            rows, cols = np.where(mask)
            n = len(rows)
            target_counts[iy, ix] = n
            if n > 0:
                # Convert chunk-local indices to global (y, x) coords.
                coords = np.column_stack([
                    y_coords[y_offsets[iy] + rows],
                    x_coords[x_offsets[ix] + cols],
                ])
                entry_bytes = coords.nbytes
                if cache_bytes + entry_bytes <= budget:
                    coords_cache[(iy, ix)] = coords
                    cache_bytes += entry_bytes

    total_targets = int(target_counts.sum())
    return target_counts, total_targets, coords_cache


def _chunk_offsets(chunks):
"""Return cumulative offset array of length len(chunks)+1."""
offsets = np.zeros(len(chunks) + 1, dtype=np.int64)
np.cumsum(chunks, out=offsets[1:])
return offsets


def _collect_region_targets(raster, jy_lo, jy_hi, jx_lo, jx_hi,
target_values, target_counts,
y_coords, x_coords,
y_offsets, x_offsets, coords_cache):
"""Collect target (y, x) coords from chunks in [jy_lo:jy_hi, jx_lo:jx_hi].

Uses cache where available, re-reads uncached chunks via .compute().
Returns ndarray shape (N, 2) or None if no targets in region.
"""
parts = []
for iy in range(jy_lo, jy_hi):
for ix in range(jx_lo, jx_hi):
if target_counts[iy, ix] == 0:
continue
if (iy, ix) in coords_cache:
parts.append(coords_cache[(iy, ix)])
else:
chunk_data = raster.data.blocks[iy, ix].compute()
mask = _target_mask(chunk_data, target_values)
rows, cols = np.where(mask)
if len(rows) > 0:
coords = np.column_stack([
y_coords[y_offsets[iy] + rows],
x_coords[x_offsets[ix] + cols],
])
parts.append(coords)
if not parts:
return None
return np.concatenate(parts)


def _min_boundary_distance(iy, ix, y_coords, x_coords,
y_offsets, x_offsets,
jy_lo, jy_hi, jx_lo, jx_hi,
n_tile_y, n_tile_x):
"""Lower bound on distance from any pixel in chunk (iy, ix) to any point
outside the search region [jy_lo:jy_hi, jx_lo:jx_hi].

For each of the 4 sides where the search region doesn't reach the raster
edge, compute the gap between the chunk's edge pixel coordinate and the
first pixel outside the search region. The minimum of these gaps is
a valid lower bound for both L1 and L2 norms.

Returns float (inf if search covers the full raster).
"""
gaps = []

# Top boundary
if jy_lo > 0:
# chunk's top-edge row in pixel space
chunk_top_row = y_offsets[iy]
# first row outside region (above)
outside_row = y_offsets[jy_lo] - 1
gap = abs(float(y_coords[chunk_top_row]) - float(y_coords[outside_row]))
gaps.append(gap)

# Bottom boundary
if jy_hi < n_tile_y:
chunk_bot_row = y_offsets[iy + 1] - 1
outside_row = y_offsets[jy_hi]
gap = abs(float(y_coords[chunk_bot_row]) - float(y_coords[outside_row]))
gaps.append(gap)

# Left boundary
if jx_lo > 0:
chunk_left_col = x_offsets[ix]
outside_col = x_offsets[jx_lo] - 1
gap = abs(float(x_coords[chunk_left_col]) - float(x_coords[outside_col]))
gaps.append(gap)

# Right boundary
if jx_hi < n_tile_x:
chunk_right_col = x_offsets[ix + 1] - 1
outside_col = x_offsets[jx_hi]
gap = abs(float(x_coords[chunk_right_col]) - float(x_coords[outside_col]))
gaps.append(gap)

return min(gaps) if gaps else np.inf


def _tiled_chunk_proximity(raster, iy, ix, y_coords, x_coords,
                           y_offsets, x_offsets,
                           target_values, target_counts,
                           coords_cache, max_distance, p,
                           n_tile_y, n_tile_x):
    """Expanding-ring local KDTree for one output chunk.

    Starting from the chunk itself, grow a square ring of chunks, build
    a KDTree over the targets inside the ring, and stop once every
    pixel's answer is provably final with respect to targets outside
    the ring.

    Returns
    -------
    ndarray, shape (h, w), dtype float32
        Nearest-target distance per pixel; NaN where no target lies
        within ``max_distance``.
    """
    h = int(y_offsets[iy + 1] - y_offsets[iy])
    w = int(x_offsets[ix + 1] - x_offsets[ix])

    # Build query points for this chunk (every pixel coordinate pair).
    chunk_ys = y_coords[y_offsets[iy]:y_offsets[iy + 1]]
    chunk_xs = x_coords[x_offsets[ix]:x_offsets[ix + 1]]
    yy, xx = np.meshgrid(chunk_ys, chunk_xs, indexing='ij')
    query_pts = np.column_stack([yy.ravel(), xx.ravel()])

    # Hoisted out of the loop: the KDTree query cutoff never changes.
    ub = max_distance if np.isfinite(max_distance) else np.inf

    ring = 0
    while True:
        jy_lo = max(iy - ring, 0)
        jy_hi = min(iy + 1 + ring, n_tile_y)
        jx_lo = max(ix - ring, 0)
        jx_hi = min(ix + 1 + ring, n_tile_x)

        covers_full = (jy_lo == 0 and jy_hi == n_tile_y
                       and jx_lo == 0 and jx_hi == n_tile_x)

        target_coords = _collect_region_targets(
            raster, jy_lo, jy_hi, jx_lo, jx_hi,
            target_values, target_counts,
            y_coords, x_coords, y_offsets, x_offsets, coords_cache,
        )

        if target_coords is None:
            if covers_full:
                # No targets in entire raster
                return np.full((h, w), np.nan, dtype=np.float32)
            ring += 1
            continue

        tree = cKDTree(target_coords)
        dists, _ = tree.query(query_pts, p=p, distance_upper_bound=ub)
        dists = dists.reshape(h, w).astype(np.float32)
        dists[dists == np.inf] = np.nan

        if covers_full:
            return dists

        min_bd = _min_boundary_distance(
            iy, ix, y_coords, x_coords, y_offsets, x_offsets,
            jy_lo, jy_hi, jx_lo, jx_hi, n_tile_y, n_tile_x,
        )

        nan_mask = np.isnan(dists)
        all_nan = bool(nan_mask.all())
        max_nearest = 0.0 if all_nan else float(np.nanmax(dists))

        # A finite answer is final when it beats the closest possible
        # outside target. A NaN pixel (query truncated at max_distance)
        # is only final when even the closest possible outside target
        # would exceed max_distance; otherwise an unseen target just
        # outside the ring could still be within range, so keep
        # expanding. (The original nanmax-only check wrongly finalized
        # such NaN pixels.)
        nan_resolved = (not nan_mask.any()) or (min_bd > ub)
        if max_nearest < min_bd and nan_resolved:
            return dists

        ring += 1


def _build_tiled_kdtree(raster, y_coords, x_coords, target_values,
                        max_distance, p, target_counts, coords_cache,
                        chunks_y, chunks_x):
    """Tiled (eager) KDTree proximity — memory-safe fallback.

    Computes the result chunk-by-chunk with local expanding-ring trees
    so the full target set never has to live in memory at once.

    Raises
    ------
    MemoryError
        If even the float32 result array would not fit in memory.
    """
    H, W = raster.shape
    result_bytes = H * W * 4  # float32 output array
    avail = _available_memory_bytes()
    if result_bytes > 0.8 * avail:
        raise MemoryError(
            f"Proximity result array ({H}x{W}, {result_bytes / 1e9:.1f} GB) "
            f"exceeds 80% of available memory ({avail / 1e9:.1f} GB)."
        )

    warnings.warn(
        "proximity: target coordinates exceed 50% of available memory; "
        "using tiled KDTree fallback (slower but memory-safe).",
        ResourceWarning,
        stacklevel=4,
    )

    n_tile_y, n_tile_x = len(chunks_y), len(chunks_x)
    y_offsets = _chunk_offsets(chunks_y)
    x_offsets = _chunk_offsets(chunks_x)

    result = np.full((H, W), np.nan, dtype=np.float32)

    for tile_y in range(n_tile_y):
        row_lo = int(y_offsets[tile_y])
        row_hi = int(y_offsets[tile_y + 1])
        for tile_x in range(n_tile_x):
            col_lo = int(x_offsets[tile_x])
            col_hi = int(x_offsets[tile_x + 1])
            result[row_lo:row_hi, col_lo:col_hi] = _tiled_chunk_proximity(
                raster, tile_y, tile_x, y_coords, x_coords,
                y_offsets, x_offsets,
                target_values, target_counts, coords_cache,
                max_distance, p, n_tile_y, n_tile_x,
            )

    # Hand the eager result back as a dask array with matching chunks.
    return da.from_array(result, chunks=raster.data.chunks)


def _build_global_kdtree(raster, y_coords, x_coords, target_values,
                         max_distance, p, target_counts, coords_cache,
                         chunks_y, chunks_x):
    """Global KDTree proximity — fast, lazy via da.map_blocks.

    Builds one tree over every target in the raster (the caller has
    already verified at least one target exists and that the tree fits
    in memory), then queries it lazily per output chunk.
    """
    n_tile_y = len(chunks_y)
    n_tile_x = len(chunks_x)
    y_offsets = _chunk_offsets(chunks_y)
    x_offsets = _chunk_offsets(chunks_x)

    # Full-raster region: gathers every target coordinate, reusing the
    # streaming pass's cache where possible.
    target_coords = _collect_region_targets(
        raster, 0, n_tile_y, 0, n_tile_x,
        target_values, target_counts,
        y_coords, x_coords, y_offsets, x_offsets, coords_cache,
    )

    tree = cKDTree(target_coords)

    chunk_fn = partial(
        _kdtree_chunk_fn,
        y_coords_1d=y_coords,
        x_coords_1d=x_coords,
        tree=tree,
        max_distance=max_distance if np.isfinite(max_distance) else np.inf,
        p=p,
    )

    return da.map_blocks(
        chunk_fn,
        raster.data,
        dtype=np.float32,
        meta=np.array((), dtype=np.float32),
    )


def _process_dask_kdtree(raster, x_coords, y_coords,
                         target_values, max_distance, distance_metric):
    """Memory-guarded k-d tree proximity for dask arrays.

    Phase 0 streams through the chunks once, counting targets (and
    caching coordinates where the budget allows). Based on the
    estimated footprint of a single global tree, either the fast lazy
    path or the memory-safe tiled path is chosen.
    """
    # Minkowski order for cKDTree: Euclidean -> 2, Manhattan -> 1.
    p = 2 if distance_metric == EUCLIDEAN else 1

    chunks_y, chunks_x = raster.data.chunks

    # Phase 0: streaming count pass.
    target_counts, total_targets, coords_cache = _stream_target_counts(
        raster, target_values, y_coords, x_coords, chunks_y, chunks_x,
    )

    if total_targets == 0:
        # No targets anywhere: every pixel is NaN.
        return da.full(raster.shape, np.nan, dtype=np.float32,
                       chunks=raster.data.chunks)

    # 16 bytes per (y, x) float64 pair plus ~32 bytes tree overhead.
    estimate = total_targets * 48
    builder = (_build_global_kdtree
               if estimate < 0.5 * _available_memory_bytes()
               else _build_tiled_kdtree)
    return builder(
        raster, y_coords, x_coords, target_values,
        max_distance, p, target_counts, coords_cache,
        chunks_y, chunks_x,
    )


def _process(
Expand Down
Loading