Skip to content

Commit 1b2754b

Browse files
committed
Add k-d tree dask path for unbounded proximity to avoid single-chunk rechunk
The existing dask proximity path rechunked the entire raster into one chunk when max_distance was unbounded, defeating dask's out-of-core purpose. This adds a two-phase scipy.spatial.cKDTree approach: Phase 1 streams chunks to collect target coordinates, Phase 2 queries the tree per-chunk via map_blocks. Exact results, memory proportional to the number of targets rather than the raster size. The k-d tree path is used only in PROXIMITY mode with the EUCLIDEAN or MANHATTAN metric and an effectively unbounded max_distance (max_distance >= max_possible_distance); the GREAT_CIRCLE metric and the ALLOCATION/DIRECTION modes fall back to the existing single-chunk path. Gracefully degrades to the existing path when scipy is absent.
1 parent 868c843 commit 1b2754b

File tree

3 files changed

+292
-10
lines changed

3 files changed

+292
-10
lines changed

setup.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ tests =
6363
pyarrow
6464
pytest
6565
pytest-cov
66+
scipy
6667

6768

6869
[flake8]

xrspatial/proximity.py

Lines changed: 99 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
1+
from functools import partial
12
from math import sqrt
23

34
try:
45
import dask.array as da
56
except ImportError:
67
da = None
78

9+
try:
10+
from scipy.spatial import cKDTree
11+
except ImportError:
12+
cKDTree = None
13+
814
import numpy as np
915
import xarray as xr
1016
from numba import prange
@@ -398,6 +404,79 @@ def _process_proximity_line(
398404
return
399405

400406

407+
def _kdtree_chunk_fn(block, y_coords_1d, x_coords_1d,
408+
tree, block_info, max_distance, p):
409+
"""Query k-d tree for nearest target distance for every pixel in block."""
410+
if block_info is None or block_info == []:
411+
return np.full(block.shape, np.nan, dtype=np.float32)
412+
413+
y_start = block_info[0]['array-location'][0][0]
414+
x_start = block_info[0]['array-location'][1][0]
415+
h, w = block.shape
416+
417+
chunk_ys = y_coords_1d[y_start:y_start + h]
418+
chunk_xs = x_coords_1d[x_start:x_start + w]
419+
yy, xx = np.meshgrid(chunk_ys, chunk_xs, indexing='ij')
420+
query_pts = np.column_stack([yy.ravel(), xx.ravel()])
421+
422+
dists, _ = tree.query(query_pts, p=p,
423+
distance_upper_bound=max_distance)
424+
dists = dists.reshape(h, w).astype(np.float32)
425+
dists[dists == np.inf] = np.nan
426+
return dists
427+
428+
429+
def _process_dask_kdtree(raster, x_coords, y_coords,
                         target_values, max_distance, distance_metric):
    """Compute proximity for a dask-backed raster with a two-phase k-d tree.

    Phase 1 computes the raster one chunk at a time and records the
    ``(y, x)`` coordinates of every target pixel. Phase 2 builds a single
    ``cKDTree`` from those coordinates and queries it lazily, chunk by
    chunk, through ``da.map_blocks``.

    A pixel is a target when it is finite and either non-zero (empty
    ``target_values``) or a member of ``target_values``.

    Returns a lazy ``float32`` dask array with the raster's chunking; it is
    all-NaN when the raster contains no target pixel.
    """
    # Minkowski order for the tree query: 2 for Euclidean, 1 otherwise
    # (Manhattan).
    minkowski_p = 1 if distance_metric != EUCLIDEAN else 2

    dask_data = raster.data
    row_chunks, col_chunks = dask_data.chunks
    any_nonzero_is_target = len(target_values) == 0

    # Phase 1: visit chunks in row-major order, materialising one at a time.
    found = []
    row_start = 0
    for block_row, n_rows in enumerate(row_chunks):
        col_start = 0
        for block_col, n_cols in enumerate(col_chunks):
            values = dask_data.blocks[block_row, block_col].compute()
            finite = np.isfinite(values)
            if any_nonzero_is_target:
                is_target = finite & (values != 0)
            else:
                is_target = np.isin(values, target_values) & finite
            local_rows, local_cols = np.nonzero(is_target)
            if local_rows.size > 0:
                # Translate chunk-local indices to full-raster coordinates.
                found.append(np.column_stack([
                    y_coords[row_start + local_rows],
                    x_coords[col_start + local_cols],
                ]))
            col_start += n_cols
        row_start += n_rows

    if not found:
        return da.full(raster.shape, np.nan, dtype=np.float32,
                       chunks=dask_data.chunks)

    tree = cKDTree(np.concatenate(found))

    # Phase 2: lazily query the shared tree for every chunk.
    if np.isfinite(max_distance):
        upper_bound = max_distance
    else:
        upper_bound = np.inf
    query_block = partial(
        _kdtree_chunk_fn,
        y_coords_1d=y_coords,
        x_coords_1d=x_coords,
        tree=tree,
        max_distance=upper_bound,
        p=minkowski_p,
    )
    return da.map_blocks(
        query_block,
        dask_data,
        dtype=np.float32,
        meta=np.array((), dtype=np.float32),
    )
478+
479+
401480
def _process(
402481
raster,
403482
x,
@@ -633,16 +712,26 @@ def _process_dask(raster, xs, ys):
633712
result = _process_numpy(raster.data, xs, ys)
634713

635714
elif da is not None and isinstance(raster.data, da.Array):
636-
# dask case - create coordinate arrays as dask arrays directly
637-
# This avoids materializing the full arrays in memory
638-
# Convert 1D coords to dask arrays first
639-
x_coords_da = da.from_array(x_coords, chunks=x_coords.shape[0])
640-
y_coords_da = da.from_array(y_coords, chunks=y_coords.shape[0])
641-
xs = da.tile(x_coords_da, (raster.shape[0], 1))
642-
ys = da.repeat(y_coords_da, raster.shape[1]).reshape(raster.shape)
643-
xs = xs.rechunk(raster.chunks)
644-
ys = ys.rechunk(raster.chunks)
645-
result = _process_dask(raster, xs, ys)
715+
use_kdtree = (
716+
cKDTree is not None
717+
and process_mode == PROXIMITY
718+
and distance_metric in (EUCLIDEAN, MANHATTAN)
719+
and max_distance >= max_possible_distance
720+
)
721+
if use_kdtree:
722+
result = _process_dask_kdtree(
723+
raster, x_coords, y_coords,
724+
target_values, max_distance, distance_metric,
725+
)
726+
else:
727+
# Existing path: build 2D coordinate arrays as dask arrays
728+
x_coords_da = da.from_array(x_coords, chunks=x_coords.shape[0])
729+
y_coords_da = da.from_array(y_coords, chunks=y_coords.shape[0])
730+
xs = da.tile(x_coords_da, (raster.shape[0], 1))
731+
ys = da.repeat(y_coords_da, raster.shape[1]).reshape(raster.shape)
732+
xs = xs.rechunk(raster.chunks)
733+
ys = ys.rechunk(raster.chunks)
734+
result = _process_dask(raster, xs, ys)
646735

647736
return result
648737

xrspatial/tests/test_proximity.py

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from unittest.mock import patch
2+
13
try:
24
import dask.array as da
35
except ImportError:
@@ -350,3 +352,193 @@ def tracking_repeat(a, repeats, axis=None):
350352
assert computed.data[90, 100] == 0.0
351353
# Check that non-target pixels have positive distance
352354
assert computed.data[0, 0] > 0.0
355+
356+
357+
def _make_kdtree_raster(height=20, width=30, chunks=(10, 15)):
    """Build a small dask-backed raster containing three target pixels.

    The raster is all zeros except for the values 1, 2 and 3 placed at
    (3, 5), (12, 20) and (18, 2). ``lon`` runs from 0 to 29 and ``lat``
    from 19 down to 0.
    """
    values = np.zeros((height, width), dtype=np.float64)
    for (row, col), value in (((3, 5), 1.0), ((12, 20), 2.0), ((18, 2), 3.0)):
        values[row, col] = value

    return xr.DataArray(
        da.from_array(values, chunks=chunks),
        dims=['lat', 'lon'],
        coords={
            'lon': np.linspace(0, 29, width),
            'lat': np.linspace(19, 0, height),
        },
    )
370+
371+
372+
@pytest.mark.skipif(da is None, reason="dask is not installed")
@pytest.mark.parametrize("metric", ["EUCLIDEAN", "MANHATTAN"])
def test_proximity_dask_kdtree_matches_numpy(metric):
    """Dask and numpy backends agree on an unbounded proximity query."""
    lazy_raster = _make_kdtree_raster()
    # Same raster, but backed by an in-memory numpy array.
    eager_raster = lazy_raster.copy()
    eager_raster.data = lazy_raster.data.compute()

    call_kwargs = dict(x='lon', y='lat', distance_metric=metric)
    expected = proximity(eager_raster, **call_kwargs)
    actual = proximity(lazy_raster, **call_kwargs)

    # The dask result must stay lazy until values are requested.
    assert isinstance(actual.data, da.Array)
    np.testing.assert_allclose(
        actual.values, expected.values, rtol=1e-5, equal_nan=True,
    )
389+
390+
391+
@pytest.mark.skipif(da is None, reason="dask is not installed")
def test_proximity_dask_kdtree_no_large_arrays():
    """Building the dask proximity result must not tile/repeat full rasters.

    ``np.tile`` and ``np.repeat`` are wrapped while ``proximity`` is called;
    any call producing an array with at least ``height * width`` elements is
    recorded and fails the test.
    """
    height, width = 100, 120
    full_size = height * width

    values = np.zeros((height, width), dtype=np.float64)
    values[10, 10] = 1.0
    values[50, 60] = 2.0

    raster = xr.DataArray(values, dims=['lat', 'lon'])
    raster['lon'] = np.linspace(0, 119, width)
    raster['lat'] = np.linspace(99, 0, height)
    raster.data = da.from_array(values, chunks=(25, 30))

    # Keep references to the real functions before they are patched.
    real_tile, real_repeat = np.tile, np.repeat
    large_numpy_created = []

    def tracking_tile(A, reps):
        out = real_tile(A, reps)
        if out.size >= full_size:
            large_numpy_created.append(('tile', out.shape))
        return out

    def tracking_repeat(a, repeats, axis=None):
        out = real_repeat(a, repeats, axis=axis)
        if out.size >= full_size:
            large_numpy_created.append(('repeat', out.shape))
        return out

    with patch.object(np, 'tile', tracking_tile), \
            patch.object(np, 'repeat', tracking_repeat):
        result = proximity(raster, x='lon', y='lat')

    assert len(large_numpy_created) == 0, (
        f"Large numpy arrays created: {large_numpy_created}"
    )
    assert isinstance(result.data, da.Array)
430+
431+
432+
@pytest.mark.skipif(da is None, reason="dask is not installed")
def test_proximity_dask_kdtree_with_target_values():
    """Dask and numpy backends agree when ``target_values`` is given."""
    lazy_raster = _make_kdtree_raster()
    eager_raster = lazy_raster.copy()
    eager_raster.data = lazy_raster.data.compute()

    # Only pixels holding 2 or 3 count as targets; the pixel holding 1 does not.
    wanted = [2, 3]
    expected = proximity(eager_raster, x='lon', y='lat', target_values=wanted)
    actual = proximity(lazy_raster, x='lon', y='lat', target_values=wanted)

    assert isinstance(actual.data, da.Array)
    np.testing.assert_allclose(
        actual.values, expected.values, rtol=1e-5, equal_nan=True,
    )
449+
450+
451+
@pytest.mark.skipif(da is None, reason="dask is not installed")
def test_proximity_dask_kdtree_no_targets():
    """An all-zero raster has no targets, so every output pixel is NaN."""
    zeros = np.zeros((10, 10), dtype=np.float64)
    axis = np.arange(10, dtype=np.float64)

    raster = xr.DataArray(zeros, dims=['lat', 'lon'])
    raster['lon'] = axis
    raster['lat'] = axis[::-1]
    raster.data = da.from_array(zeros, chunks=(5, 5))

    result = proximity(raster, x='lon', y='lat')

    assert isinstance(result.data, da.Array)
    assert np.all(np.isnan(result.values))
466+
467+
468+
@pytest.mark.skipif(da is None, reason="dask is not installed")
def test_proximity_dask_kdtree_max_distance():
    """Dask and numpy backends agree when ``max_distance`` is bounded.

    NOTE(review): the k-d tree path is gated on
    ``max_distance >= max_possible_distance``; a bound of 5.0 on this
    raster presumably fails that gate, in which case this test exercises
    the pre-existing bounded dask path rather than the tree query —
    confirm against ``max_possible_distance``.
    """
    lazy_raster = _make_kdtree_raster()
    eager_raster = lazy_raster.copy()
    eager_raster.data = lazy_raster.data.compute()

    limit = 5.0
    expected = proximity(eager_raster, x='lon', y='lat', max_distance=limit)
    actual = proximity(lazy_raster, x='lon', y='lat', max_distance=limit)

    np.testing.assert_allclose(
        actual.values, expected.values, rtol=1e-5, equal_nan=True,
    )
484+
485+
486+
@pytest.mark.skipif(da is None, reason="dask is not installed")
def test_proximity_dask_kdtree_fallback_no_scipy():
    """With ``cKDTree`` set to None, proximity still returns a dask result."""
    import sys
    prox_mod = sys.modules['xrspatial.proximity']

    height, width = 8, 10
    values = np.zeros((height, width), dtype=np.float64)
    values[2, 3] = 1.0
    values[6, 8] = 2.0

    raster = xr.DataArray(values, dims=['lat', 'lon'])
    raster['lon'] = np.linspace(0, 9, width)
    raster['lat'] = np.linspace(7, 0, height)
    raster.data = da.from_array(values, chunks=(4, 5))

    # Simulate a missing scipy; the patch restores the attribute on exit,
    # even if an assertion below fails.
    with patch.object(prox_mod, 'cKDTree', None):
        result = proximity(raster, x='lon', y='lat')
        assert isinstance(result.data, da.Array)
        # A target pixel is at distance zero from itself.
        computed = result.values
        assert computed[2, 3] == 0.0
513+
514+
515+
@pytest.mark.skipif(da is None, reason="dask is not installed")
def test_proximity_dask_kdtree_fallback_great_circle():
    """The GREAT_CIRCLE metric must not be routed through the k-d tree path."""
    import sys
    prox_mod = sys.modules['xrspatial.proximity']

    height, width = 8, 10
    values = np.zeros((height, width), dtype=np.float64)
    values[2, 3] = 1.0

    raster = xr.DataArray(values, dims=['lat', 'lon'])
    raster['lon'] = np.linspace(-10, 10, width)
    raster['lat'] = np.linspace(10, -10, height)
    raster.data = da.from_array(values, chunks=(4, 5))

    # Wrap the k-d tree entry point so any call to it is recorded while
    # still delegating to the real implementation.
    recorded_calls = []
    real_kdtree_path = prox_mod._process_dask_kdtree

    def recording_wrapper(*args, **kwargs):
        recorded_calls.append(True)
        return real_kdtree_path(*args, **kwargs)

    with patch.object(prox_mod, '_process_dask_kdtree', recording_wrapper):
        result = proximity(raster, x='lon', y='lat',
                           distance_metric='GREAT_CIRCLE')

    assert not recorded_calls, "k-d tree path should not be used for GREAT_CIRCLE"
    assert isinstance(result.data, da.Array)

0 commit comments

Comments
 (0)