diff --git a/xrspatial/proximity.py b/xrspatial/proximity.py
index b2c281e6..40738d62 100644
--- a/xrspatial/proximity.py
+++ b/xrspatial/proximity.py
@@ -1,3 +1,4 @@
+import warnings
 from functools import partial
 from math import sqrt
 
@@ -15,6 +16,7 @@
 import xarray as xr
 from numba import prange
 
+from xrspatial.pathfinding import _available_memory_bytes
 from xrspatial.utils import get_dataarray_resolution, ngjit
 from xrspatial.dataset_support import supports_dataset
 
@@ -426,55 +428,323 @@ def _kdtree_chunk_fn(block, y_coords_1d, x_coords_1d,
     return dists
 
 
-def _process_dask_kdtree(raster, x_coords, y_coords,
-                         target_values, max_distance, distance_metric):
-    """Two-phase k-d tree proximity for unbounded dask arrays."""
-    p = 2 if distance_metric == EUCLIDEAN else 1  # Manhattan: p=1
+def _target_mask(chunk_data, target_values):
+    """Boolean mask of target pixels in *chunk_data*."""
+    if len(target_values) == 0:
+        return np.isfinite(chunk_data) & (chunk_data != 0)
+    return np.isin(chunk_data, target_values) & np.isfinite(chunk_data)
 
-    # Phase 1: stream through chunks to collect target coordinates
-    target_list = []
-    chunks_y, chunks_x = raster.data.chunks
-    y_offset = 0
-    for iy, cy in enumerate(chunks_y):
-        x_offset = 0
-        for ix, cx in enumerate(chunks_x):
+
+def _stream_target_counts(raster, target_values, y_coords, x_coords,
+                          chunks_y, chunks_x):
+    """Stream all dask chunks, counting targets per chunk.
+
+    Caches per-chunk coordinate arrays within a 25% memory budget to
+    reduce re-reads in later phases.
+
+    Returns
+    -------
+    target_counts : ndarray, shape (n_tile_y, n_tile_x), dtype int64
+    total_targets : int
+    coords_cache : dict (iy, ix) -> ndarray shape (N, 2)
+    """
+    n_tile_y = len(chunks_y)
+    n_tile_x = len(chunks_x)
+    target_counts = np.zeros((n_tile_y, n_tile_x), dtype=np.int64)
+    coords_cache = {}
+    cache_bytes = 0
+    budget = int(0.25 * _available_memory_bytes())
+
+    y_offsets = np.zeros(n_tile_y + 1, dtype=np.int64)
+    np.cumsum(chunks_y, out=y_offsets[1:])
+    x_offsets = np.zeros(n_tile_x + 1, dtype=np.int64)
+    np.cumsum(chunks_x, out=x_offsets[1:])
+
+    for iy in range(n_tile_y):
+        for ix in range(n_tile_x):
             chunk_data = raster.data.blocks[iy, ix].compute()
-            if len(target_values) == 0:
-                mask = np.isfinite(chunk_data) & (chunk_data != 0)
-            else:
-                mask = np.isin(chunk_data, target_values) & np.isfinite(chunk_data)
+            mask = _target_mask(chunk_data, target_values)
             rows, cols = np.where(mask)
-            if len(rows) > 0:
+            n = len(rows)
+            target_counts[iy, ix] = n
+            if n > 0:
                 coords = np.column_stack([
-                    y_coords[y_offset + rows],
-                    x_coords[x_offset + cols],
+                    y_coords[y_offsets[iy] + rows],
+                    x_coords[x_offsets[ix] + cols],
                 ])
-                target_list.append(coords)
-            x_offset += cx
-        y_offset += cy
+                entry_bytes = coords.nbytes
+                if cache_bytes + entry_bytes <= budget:
+                    coords_cache[(iy, ix)] = coords
+                    cache_bytes += entry_bytes
 
-    if len(target_list) == 0:
-        return da.full(raster.shape, np.nan, dtype=np.float32,
-                       chunks=raster.data.chunks)
+    total_targets = int(target_counts.sum())
+    return target_counts, total_targets, coords_cache
+
+
+def _chunk_offsets(chunks):
+    """Return cumulative offset array of length len(chunks)+1."""
+    offsets = np.zeros(len(chunks) + 1, dtype=np.int64)
+    np.cumsum(chunks, out=offsets[1:])
+    return offsets
+
+
+def _collect_region_targets(raster, jy_lo, jy_hi, jx_lo, jx_hi,
+                            target_values, target_counts,
+                            y_coords, x_coords,
+                            y_offsets, x_offsets, coords_cache):
+    """Collect target (y, x) coords from chunks in [jy_lo:jy_hi, jx_lo:jx_hi].
+
+    Uses cache where available, re-reads uncached chunks via .compute().
+    Returns ndarray shape (N, 2) or None if no targets in region.
+    """
+    parts = []
+    for iy in range(jy_lo, jy_hi):
+        for ix in range(jx_lo, jx_hi):
+            if target_counts[iy, ix] == 0:
+                continue
+            if (iy, ix) in coords_cache:
+                parts.append(coords_cache[(iy, ix)])
+            else:
+                chunk_data = raster.data.blocks[iy, ix].compute()
+                mask = _target_mask(chunk_data, target_values)
+                rows, cols = np.where(mask)
+                if len(rows) > 0:
+                    coords = np.column_stack([
+                        y_coords[y_offsets[iy] + rows],
+                        x_coords[x_offsets[ix] + cols],
+                    ])
+                    parts.append(coords)
+    if not parts:
+        return None
+    return np.concatenate(parts)
+
+
+def _min_boundary_distance(iy, ix, y_coords, x_coords,
+                           y_offsets, x_offsets,
+                           jy_lo, jy_hi, jx_lo, jx_hi,
+                           n_tile_y, n_tile_x):
+    """Lower bound on distance from any pixel in chunk (iy, ix) to any point
+    outside the search region [jy_lo:jy_hi, jx_lo:jx_hi].
+
+    For each of the 4 sides where the search region doesn't reach the raster
+    edge, compute the gap between the chunk's edge pixel coordinate and the
+    first pixel outside the search region. The minimum of these gaps is
+    a valid lower bound for both L1 and L2 norms.
+
+    Returns float (inf if search covers the full raster).
+    """
+    gaps = []
+
+    # Top boundary
+    if jy_lo > 0:
+        # chunk's top-edge row in pixel space
+        chunk_top_row = y_offsets[iy]
+        # first row outside region (above)
+        outside_row = y_offsets[jy_lo] - 1
+        gap = abs(float(y_coords[chunk_top_row]) - float(y_coords[outside_row]))
+        gaps.append(gap)
+
+    # Bottom boundary
+    if jy_hi < n_tile_y:
+        chunk_bot_row = y_offsets[iy + 1] - 1
+        outside_row = y_offsets[jy_hi]
+        gap = abs(float(y_coords[chunk_bot_row]) - float(y_coords[outside_row]))
+        gaps.append(gap)
+
+    # Left boundary
+    if jx_lo > 0:
+        chunk_left_col = x_offsets[ix]
+        outside_col = x_offsets[jx_lo] - 1
+        gap = abs(float(x_coords[chunk_left_col]) - float(x_coords[outside_col]))
+        gaps.append(gap)
+
+    # Right boundary
+    if jx_hi < n_tile_x:
+        chunk_right_col = x_offsets[ix + 1] - 1
+        outside_col = x_offsets[jx_hi]
+        gap = abs(float(x_coords[chunk_right_col]) - float(x_coords[outside_col]))
+        gaps.append(gap)
+
+    return min(gaps) if gaps else np.inf
+
+
+def _tiled_chunk_proximity(raster, iy, ix, y_coords, x_coords,
+                           y_offsets, x_offsets,
+                           target_values, target_counts,
+                           coords_cache, max_distance, p,
+                           n_tile_y, n_tile_x):
+    """Expanding-ring local KDTree for one output chunk.
+
+    Returns ndarray shape (h, w), dtype float32.
+    """
+    h = int(y_offsets[iy + 1] - y_offsets[iy])
+    w = int(x_offsets[ix + 1] - x_offsets[ix])
+
+    # Build query points for this chunk
+    chunk_ys = y_coords[y_offsets[iy]:y_offsets[iy + 1]]
+    chunk_xs = x_coords[x_offsets[ix]:x_offsets[ix + 1]]
+    yy, xx = np.meshgrid(chunk_ys, chunk_xs, indexing='ij')
+    query_pts = np.column_stack([yy.ravel(), xx.ravel()])
+
+    ring = 0
+    while True:
+        jy_lo = max(iy - ring, 0)
+        jy_hi = min(iy + 1 + ring, n_tile_y)
+        jx_lo = max(ix - ring, 0)
+        jx_hi = min(ix + 1 + ring, n_tile_x)
+
+        covers_full = (jy_lo == 0 and jy_hi == n_tile_y
+                       and jx_lo == 0 and jx_hi == n_tile_x)
+
+        target_coords = _collect_region_targets(
+            raster, jy_lo, jy_hi, jx_lo, jx_hi,
+            target_values, target_counts,
+            y_coords, x_coords, y_offsets, x_offsets, coords_cache,
+        )
+
+        if target_coords is None:
+            if covers_full:
+                # No targets in entire raster
+                return np.full((h, w), np.nan, dtype=np.float32)
+            ring += 1
+            continue
+
+        tree = cKDTree(target_coords)
+        ub = max_distance if np.isfinite(max_distance) else np.inf
+        dists, _ = tree.query(query_pts, p=p, distance_upper_bound=ub)
+        dists = dists.reshape(h, w).astype(np.float32)
+        dists[dists == np.inf] = np.nan
+
+        if covers_full:
+            return dists
+
+        # Validate: a NaN pixel may still have an unseen target within ub
+        max_nearest = ub if np.isnan(dists).any() else np.max(dists)
+        min_bd = _min_boundary_distance(
+            iy, ix, y_coords, x_coords, y_offsets, x_offsets,
+            jy_lo, jy_hi, jx_lo, jx_hi, n_tile_y, n_tile_x,
+        )
+        if max_nearest < min_bd:
+            return dists
+
+        ring += 1
+
+
+def _build_tiled_kdtree(raster, y_coords, x_coords, target_values,
+                        max_distance, p, target_counts, coords_cache,
+                        chunks_y, chunks_x):
+    """Tiled (eager) KDTree proximity — memory-safe fallback."""
+    H, W = raster.shape
+    result_bytes = H * W * 4  # float32
+    avail = _available_memory_bytes()
+    if result_bytes > 0.8 * avail:
+        raise MemoryError(
+            f"Proximity result array ({H}x{W}, {result_bytes / 1e9:.1f} GB) "
+            f"exceeds 80% of available memory ({avail / 1e9:.1f} GB)."
+        )
+
+    warnings.warn(
+        "proximity: target coordinates exceed 50% of available memory; "
+        "using tiled KDTree fallback (slower but memory-safe).",
+        ResourceWarning,
+        stacklevel=4,
+    )
+
+    n_tile_y = len(chunks_y)
+    n_tile_x = len(chunks_x)
+    y_offsets = _chunk_offsets(chunks_y)
+    x_offsets = _chunk_offsets(chunks_x)
+
+    result = np.full((H, W), np.nan, dtype=np.float32)
+
+    for iy in range(n_tile_y):
+        for ix in range(n_tile_x):
+            chunk_result = _tiled_chunk_proximity(
+                raster, iy, ix, y_coords, x_coords,
+                y_offsets, x_offsets,
+                target_values, target_counts, coords_cache,
+                max_distance, p, n_tile_y, n_tile_x,
+            )
+            r0 = int(y_offsets[iy])
+            r1 = int(y_offsets[iy + 1])
+            c0 = int(x_offsets[ix])
+            c1 = int(x_offsets[ix + 1])
+            result[r0:r1, c0:c1] = chunk_result
+
+    return da.from_array(result, chunks=raster.data.chunks)
+
+
+def _build_global_kdtree(raster, y_coords, x_coords, target_values,
+                         max_distance, p, target_counts, coords_cache,
+                         chunks_y, chunks_x):
+    """Global KDTree proximity — fast, lazy via da.map_blocks."""
+    n_tile_y = len(chunks_y)
+    n_tile_x = len(chunks_x)
+    y_offsets = _chunk_offsets(chunks_y)
+    x_offsets = _chunk_offsets(chunks_x)
+
+    target_coords = _collect_region_targets(
+        raster, 0, n_tile_y, 0, n_tile_x,
+        target_values, target_counts,
+        y_coords, x_coords, y_offsets, x_offsets, coords_cache,
+    )
 
-    target_coords = np.concatenate(target_list)
     tree = cKDTree(target_coords)
 
-    # Phase 2: query tree per chunk via map_blocks
-    chunk_fn = partial(_kdtree_chunk_fn,
-                       y_coords_1d=y_coords,
-                       x_coords_1d=x_coords,
-                       tree=tree,
-                       max_distance=max_distance if np.isfinite(max_distance) else np.inf,
-                       p=p)
+    chunk_fn = partial(
+        _kdtree_chunk_fn,
+        y_coords_1d=y_coords,
+        x_coords_1d=x_coords,
+        tree=tree,
+        max_distance=max_distance if np.isfinite(max_distance) else np.inf,
+        p=p,
+    )
 
-    result = da.map_blocks(
+    return da.map_blocks(
         chunk_fn,
         raster.data,
         dtype=np.float32,
         meta=np.array((), dtype=np.float32),
     )
-    return result
+
+
+def _process_dask_kdtree(raster, x_coords, y_coords,
+                         target_values, max_distance, distance_metric):
+    """Memory-guarded k-d tree proximity for dask arrays.
+
+    Phase 0: stream through chunks counting targets (with caching).
+    Then choose global tree (fast, lazy) or tiled tree (memory-safe, eager)
+    based on estimated memory usage.
+    """
+    p = 2 if distance_metric == EUCLIDEAN else 1  # Manhattan: p=1
+
+    chunks_y, chunks_x = raster.data.chunks
+
+    # Phase 0: streaming count pass
+    target_counts, total_targets, coords_cache = _stream_target_counts(
+        raster, target_values, y_coords, x_coords, chunks_y, chunks_x,
+    )
+
+    if total_targets == 0:
+        return da.full(raster.shape, np.nan, dtype=np.float32,
+                       chunks=raster.data.chunks)
+
+    # Memory decision: 16 bytes per coord pair + ~32 bytes tree overhead
+    estimate = total_targets * 48
+    avail = _available_memory_bytes()
+
+    if estimate < 0.5 * avail:
+        return _build_global_kdtree(
+            raster, y_coords, x_coords, target_values,
+            max_distance, p, target_counts, coords_cache,
+            chunks_y, chunks_x,
+        )
+    else:
+        return _build_tiled_kdtree(
+            raster, y_coords, x_coords, target_values,
+            max_distance, p, target_counts, coords_cache,
+            chunks_y, chunks_x,
+        )
 
 
 def _process(
diff --git a/xrspatial/tests/test_proximity.py b/xrspatial/tests/test_proximity.py
index 47cd3282..7c36b803 100644
--- a/xrspatial/tests/test_proximity.py
+++ b/xrspatial/tests/test_proximity.py
@@ -542,3 +542,172 @@ def spy(*args, **kwargs):
 
     assert len(kdtree_called) == 0, "k-d tree path should not be used for GREAT_CIRCLE"
     assert isinstance(result.data, da.Array)
+
+
+# ---------------------------------------------------------------------------
+# Tiled KDTree fallback tests (memory-guarded path)
+# ---------------------------------------------------------------------------
+
+def _force_tiled_proximity(raster, **kwargs):
+    """Run proximity with _available_memory_bytes mocked to force tiled path.
+
+    Uses a counter-based side_effect:
+      call 1 (_stream_target_counts cache budget): returns 1 → tiny cache
+      call 2 (_process_dask_kdtree decision): returns 1 → forces tiled
+      call 3+ (_build_tiled_kdtree result check): returns 10 GB → passes guard
+    """
+    call_count = [0]
+
+    def _small_then_large():
+        call_count[0] += 1
+        if call_count[0] <= 2:
+            return 1
+        return 10 * 1024 ** 3
+
+    with patch('xrspatial.proximity._available_memory_bytes',
+               side_effect=_small_then_large):
+        return proximity(raster, **kwargs)
+
+
+@pytest.mark.skipif(da is None, reason="dask is not installed")
+def test_proximity_dask_kdtree_tiled_matches_numpy():
+    """Dense raster forced through tiled path must match numpy baseline."""
+    height, width = 20, 30
+    rng = np.random.RandomState(42)
+    data = rng.choice([0.0, 1.0, 2.0], size=(height, width), p=[0.3, 0.4, 0.3])
+    _lon = np.linspace(0, 29, width)
+    _lat = np.linspace(19, 0, height)
+    raster = xr.DataArray(data, dims=['lat', 'lon'])
+    raster['lon'] = _lon
+    raster['lat'] = _lat
+
+    numpy_result = proximity(raster, x='lon', y='lat')
+
+    raster.data = da.from_array(data, chunks=(5, 10))
+    dask_result = _force_tiled_proximity(raster, x='lon', y='lat')
+
+    assert isinstance(dask_result.data, da.Array)
+    np.testing.assert_allclose(
+        dask_result.values, numpy_result.values, rtol=1e-5, equal_nan=True,
+    )
+
+
+@pytest.mark.skipif(da is None, reason="dask is not installed")
+def test_proximity_dask_kdtree_tiled_manhattan():
+    """Tiled path with MANHATTAN metric matches numpy."""
+    height, width = 16, 20
+    rng = np.random.RandomState(99)
+    data = rng.choice([0.0, 1.0, 2.0], size=(height, width), p=[0.3, 0.4, 0.3])
+    _lon = np.linspace(0, 19, width)
+    _lat = np.linspace(15, 0, height)
+    raster = xr.DataArray(data, dims=['lat', 'lon'])
+    raster['lon'] = _lon
+    raster['lat'] = _lat
+
+    numpy_result = proximity(raster, x='lon', y='lat',
+                             distance_metric='MANHATTAN')
+
+    raster.data = da.from_array(data, chunks=(4, 5))
+    dask_result = _force_tiled_proximity(raster, x='lon', y='lat',
+                                         distance_metric='MANHATTAN')
+
+    assert isinstance(dask_result.data, da.Array)
+    np.testing.assert_allclose(
+        dask_result.values, numpy_result.values, rtol=1e-5, equal_nan=True,
+    )
+
+
+@pytest.mark.skipif(da is None, reason="dask is not installed")
+def test_proximity_dask_kdtree_tiled_single_target():
+    """One target in a corner, many chunks → exercises max ring expansion."""
+    height, width = 20, 20
+    data = np.zeros((height, width), dtype=np.float64)
+    data[0, 0] = 1.0
+
+    _lon = np.linspace(0, 19, width)
+    _lat = np.linspace(19, 0, height)
+    raster = xr.DataArray(data, dims=['lat', 'lon'])
+    raster['lon'] = _lon
+    raster['lat'] = _lat
+
+    numpy_result = proximity(raster, x='lon', y='lat')
+
+    raster.data = da.from_array(data, chunks=(5, 5))
+    dask_result = _force_tiled_proximity(raster, x='lon', y='lat')
+
+    assert isinstance(dask_result.data, da.Array)
+    np.testing.assert_allclose(
+        dask_result.values, numpy_result.values, rtol=1e-5, equal_nan=True,
+    )
+
+
+@pytest.mark.skipif(da is None, reason="dask is not installed")
+def test_proximity_dask_kdtree_tiled_all_targets():
+    """Every pixel is a target → result should be all zeros."""
+    height, width = 12, 12
+    data = np.ones((height, width), dtype=np.float64)
+    _lon = np.linspace(0, 11, width)
+    _lat = np.linspace(11, 0, height)
+    raster = xr.DataArray(data, dims=['lat', 'lon'])
+    raster['lon'] = _lon
+    raster['lat'] = _lat
+
+    raster.data = da.from_array(data, chunks=(4, 4))
+    dask_result = _force_tiled_proximity(raster, x='lon', y='lat')
+
+    assert isinstance(dask_result.data, da.Array)
+    np.testing.assert_allclose(dask_result.values, 0.0)
+
+
+@pytest.mark.skipif(da is None, reason="dask is not installed")
+def test_proximity_dask_kdtree_tiled_no_targets():
+    """No targets, forced tiled path → all NaN."""
+    data = np.zeros((10, 10), dtype=np.float64)
+    _lon = np.arange(10, dtype=np.float64)
+    _lat = np.arange(10, dtype=np.float64)[::-1]
+    raster = xr.DataArray(data, dims=['lat', 'lon'])
+    raster['lon'] = _lon
+    raster['lat'] = _lat
+    raster.data = da.from_array(data, chunks=(5, 5))
+
+    # Even with tiny memory, zero targets should return early (all NaN)
+    dask_result = _force_tiled_proximity(raster, x='lon', y='lat')
+    assert isinstance(dask_result.data, da.Array)
+    assert np.all(np.isnan(dask_result.values))
+
+
+@pytest.mark.skipif(da is None, reason="dask is not installed")
+def test_proximity_dask_kdtree_tiled_warns():
+    """Verify ResourceWarning fires when tiled fallback is selected."""
+    height, width = 10, 10
+    data = np.zeros((height, width), dtype=np.float64)
+    data[5, 5] = 1.0
+    _lon = np.linspace(0, 9, width)
+    _lat = np.linspace(9, 0, height)
+    raster = xr.DataArray(data, dims=['lat', 'lon'])
+    raster['lon'] = _lon
+    raster['lat'] = _lat
+    raster.data = da.from_array(data, chunks=(5, 5))
+
+    with pytest.warns(ResourceWarning, match="tiled KDTree fallback"):
+        _force_tiled_proximity(raster, x='lon', y='lat')
+
+
+@pytest.mark.skipif(da is None, reason="dask is not installed")
+def test_proximity_dask_kdtree_global_uses_cache():
+    """Global path still works correctly after Phase 0 restructure."""
+    raster = _make_kdtree_raster()
+    numpy_raster = raster.copy()
+    numpy_raster.data = raster.data.compute()
+
+    numpy_result = proximity(numpy_raster, x='lon', y='lat')
+
+    # Global path (default): _available_memory_bytes returns large value
+    with patch('xrspatial.proximity._available_memory_bytes',
+               return_value=10 * 1024**3):
+        dask_result = proximity(raster, x='lon', y='lat')
+
+    assert isinstance(dask_result.data, da.Array)
+    np.testing.assert_allclose(
+        dask_result.values, numpy_result.values, rtol=1e-5, equal_nan=True,
+    )