|
| 1 | +import warnings |
1 | 2 | from functools import partial |
2 | 3 | from math import sqrt |
3 | 4 |
|
|
15 | 16 | import xarray as xr |
16 | 17 | from numba import prange |
17 | 18 |
|
| 19 | +from xrspatial.pathfinding import _available_memory_bytes |
18 | 20 | from xrspatial.utils import get_dataarray_resolution, ngjit |
19 | 21 | from xrspatial.dataset_support import supports_dataset |
20 | 22 |
|
@@ -426,55 +428,323 @@ def _kdtree_chunk_fn(block, y_coords_1d, x_coords_1d, |
426 | 428 | return dists |
427 | 429 |
|
428 | 430 |
|
429 | | -def _process_dask_kdtree(raster, x_coords, y_coords, |
430 | | - target_values, max_distance, distance_metric): |
431 | | - """Two-phase k-d tree proximity for unbounded dask arrays.""" |
432 | | - p = 2 if distance_metric == EUCLIDEAN else 1 # Manhattan: p=1 |
| 431 | +def _target_mask(chunk_data, target_values): |
| 432 | + """Boolean mask of target pixels in *chunk_data*.""" |
| 433 | + if len(target_values) == 0: |
| 434 | + return np.isfinite(chunk_data) & (chunk_data != 0) |
| 435 | + return np.isin(chunk_data, target_values) & np.isfinite(chunk_data) |
433 | 436 |
|
434 | | - # Phase 1: stream through chunks to collect target coordinates |
435 | | - target_list = [] |
436 | | - chunks_y, chunks_x = raster.data.chunks |
437 | | - y_offset = 0 |
438 | | - for iy, cy in enumerate(chunks_y): |
439 | | - x_offset = 0 |
440 | | - for ix, cx in enumerate(chunks_x): |
| 437 | + |
| 438 | +def _stream_target_counts(raster, target_values, y_coords, x_coords, |
| 439 | + chunks_y, chunks_x): |
| 440 | + """Stream all dask chunks, counting targets per chunk. |
| 441 | +
|
| 442 | + Caches per-chunk coordinate arrays within a 25% memory budget to |
| 443 | + reduce re-reads in later phases. |
| 444 | +
|
| 445 | + Returns |
| 446 | + ------- |
| 447 | + target_counts : ndarray, shape (n_tile_y, n_tile_x), dtype int64 |
| 448 | + total_targets : int |
| 449 | + coords_cache : dict (iy, ix) -> ndarray shape (N, 2) |
| 450 | + """ |
| 451 | + n_tile_y = len(chunks_y) |
| 452 | + n_tile_x = len(chunks_x) |
| 453 | + target_counts = np.zeros((n_tile_y, n_tile_x), dtype=np.int64) |
| 454 | + coords_cache = {} |
| 455 | + cache_bytes = 0 |
| 456 | + budget = int(0.25 * _available_memory_bytes()) |
| 457 | + |
| 458 | + y_offsets = np.zeros(n_tile_y + 1, dtype=np.int64) |
| 459 | + np.cumsum(chunks_y, out=y_offsets[1:]) |
| 460 | + x_offsets = np.zeros(n_tile_x + 1, dtype=np.int64) |
| 461 | + np.cumsum(chunks_x, out=x_offsets[1:]) |
| 462 | + |
| 463 | + for iy in range(n_tile_y): |
| 464 | + for ix in range(n_tile_x): |
441 | 465 | chunk_data = raster.data.blocks[iy, ix].compute() |
442 | | - if len(target_values) == 0: |
443 | | - mask = np.isfinite(chunk_data) & (chunk_data != 0) |
444 | | - else: |
445 | | - mask = np.isin(chunk_data, target_values) & np.isfinite(chunk_data) |
| 466 | + mask = _target_mask(chunk_data, target_values) |
446 | 467 | rows, cols = np.where(mask) |
447 | | - if len(rows) > 0: |
| 468 | + n = len(rows) |
| 469 | + target_counts[iy, ix] = n |
| 470 | + if n > 0: |
448 | 471 | coords = np.column_stack([ |
449 | | - y_coords[y_offset + rows], |
450 | | - x_coords[x_offset + cols], |
| 472 | + y_coords[y_offsets[iy] + rows], |
| 473 | + x_coords[x_offsets[ix] + cols], |
451 | 474 | ]) |
452 | | - target_list.append(coords) |
453 | | - x_offset += cx |
454 | | - y_offset += cy |
| 475 | + entry_bytes = coords.nbytes |
| 476 | + if cache_bytes + entry_bytes <= budget: |
| 477 | + coords_cache[(iy, ix)] = coords |
| 478 | + cache_bytes += entry_bytes |
455 | 479 |
|
456 | | - if len(target_list) == 0: |
457 | | - return da.full(raster.shape, np.nan, dtype=np.float32, |
458 | | - chunks=raster.data.chunks) |
| 480 | + total_targets = int(target_counts.sum()) |
| 481 | + return target_counts, total_targets, coords_cache |
| 482 | + |
| 483 | + |
| 484 | +def _chunk_offsets(chunks): |
| 485 | + """Return cumulative offset array of length len(chunks)+1.""" |
| 486 | + offsets = np.zeros(len(chunks) + 1, dtype=np.int64) |
| 487 | + np.cumsum(chunks, out=offsets[1:]) |
| 488 | + return offsets |
| 489 | + |
| 490 | + |
| 491 | +def _collect_region_targets(raster, jy_lo, jy_hi, jx_lo, jx_hi, |
| 492 | + target_values, target_counts, |
| 493 | + y_coords, x_coords, |
| 494 | + y_offsets, x_offsets, coords_cache): |
| 495 | + """Collect target (y, x) coords from chunks in [jy_lo:jy_hi, jx_lo:jx_hi]. |
| 496 | +
|
| 497 | + Uses cache where available, re-reads uncached chunks via .compute(). |
| 498 | + Returns ndarray shape (N, 2) or None if no targets in region. |
| 499 | + """ |
| 500 | + parts = [] |
| 501 | + for iy in range(jy_lo, jy_hi): |
| 502 | + for ix in range(jx_lo, jx_hi): |
| 503 | + if target_counts[iy, ix] == 0: |
| 504 | + continue |
| 505 | + if (iy, ix) in coords_cache: |
| 506 | + parts.append(coords_cache[(iy, ix)]) |
| 507 | + else: |
| 508 | + chunk_data = raster.data.blocks[iy, ix].compute() |
| 509 | + mask = _target_mask(chunk_data, target_values) |
| 510 | + rows, cols = np.where(mask) |
| 511 | + if len(rows) > 0: |
| 512 | + coords = np.column_stack([ |
| 513 | + y_coords[y_offsets[iy] + rows], |
| 514 | + x_coords[x_offsets[ix] + cols], |
| 515 | + ]) |
| 516 | + parts.append(coords) |
| 517 | + if not parts: |
| 518 | + return None |
| 519 | + return np.concatenate(parts) |
| 520 | + |
| 521 | + |
| 522 | +def _min_boundary_distance(iy, ix, y_coords, x_coords, |
| 523 | + y_offsets, x_offsets, |
| 524 | + jy_lo, jy_hi, jx_lo, jx_hi, |
| 525 | + n_tile_y, n_tile_x): |
| 526 | + """Lower bound on distance from any pixel in chunk (iy, ix) to any point |
| 527 | + outside the search region [jy_lo:jy_hi, jx_lo:jx_hi]. |
| 528 | +
|
| 529 | + For each of the 4 sides where the search region doesn't reach the raster |
| 530 | + edge, compute the gap between the chunk's edge pixel coordinate and the |
| 531 | + first pixel outside the search region. The minimum of these gaps is |
| 532 | + a valid lower bound for both L1 and L2 norms. |
| 533 | +
|
| 534 | + Returns float (inf if search covers the full raster). |
| 535 | + """ |
| 536 | + gaps = [] |
| 537 | + |
| 538 | + # Top boundary |
| 539 | + if jy_lo > 0: |
| 540 | + # chunk's top-edge row in pixel space |
| 541 | + chunk_top_row = y_offsets[iy] |
| 542 | + # first row outside region (above) |
| 543 | + outside_row = y_offsets[jy_lo] - 1 |
| 544 | + gap = abs(float(y_coords[chunk_top_row]) - float(y_coords[outside_row])) |
| 545 | + gaps.append(gap) |
| 546 | + |
| 547 | + # Bottom boundary |
| 548 | + if jy_hi < n_tile_y: |
| 549 | + chunk_bot_row = y_offsets[iy + 1] - 1 |
| 550 | + outside_row = y_offsets[jy_hi] |
| 551 | + gap = abs(float(y_coords[chunk_bot_row]) - float(y_coords[outside_row])) |
| 552 | + gaps.append(gap) |
| 553 | + |
| 554 | + # Left boundary |
| 555 | + if jx_lo > 0: |
| 556 | + chunk_left_col = x_offsets[ix] |
| 557 | + outside_col = x_offsets[jx_lo] - 1 |
| 558 | + gap = abs(float(x_coords[chunk_left_col]) - float(x_coords[outside_col])) |
| 559 | + gaps.append(gap) |
| 560 | + |
| 561 | + # Right boundary |
| 562 | + if jx_hi < n_tile_x: |
| 563 | + chunk_right_col = x_offsets[ix + 1] - 1 |
| 564 | + outside_col = x_offsets[jx_hi] |
| 565 | + gap = abs(float(x_coords[chunk_right_col]) - float(x_coords[outside_col])) |
| 566 | + gaps.append(gap) |
| 567 | + |
| 568 | + return min(gaps) if gaps else np.inf |
| 569 | + |
| 570 | + |
| 571 | +def _tiled_chunk_proximity(raster, iy, ix, y_coords, x_coords, |
| 572 | + y_offsets, x_offsets, |
| 573 | + target_values, target_counts, |
| 574 | + coords_cache, max_distance, p, |
| 575 | + n_tile_y, n_tile_x): |
| 576 | + """Expanding-ring local KDTree for one output chunk. |
| 577 | +
|
| 578 | + Returns ndarray shape (h, w), dtype float32. |
| 579 | + """ |
| 580 | + h = int(y_offsets[iy + 1] - y_offsets[iy]) |
| 581 | + w = int(x_offsets[ix + 1] - x_offsets[ix]) |
| 582 | + |
| 583 | + # Build query points for this chunk |
| 584 | + chunk_ys = y_coords[y_offsets[iy]:y_offsets[iy + 1]] |
| 585 | + chunk_xs = x_coords[x_offsets[ix]:x_offsets[ix + 1]] |
| 586 | + yy, xx = np.meshgrid(chunk_ys, chunk_xs, indexing='ij') |
| 587 | + query_pts = np.column_stack([yy.ravel(), xx.ravel()]) |
| 588 | + |
| 589 | + ring = 0 |
| 590 | + while True: |
| 591 | + jy_lo = max(iy - ring, 0) |
| 592 | + jy_hi = min(iy + 1 + ring, n_tile_y) |
| 593 | + jx_lo = max(ix - ring, 0) |
| 594 | + jx_hi = min(ix + 1 + ring, n_tile_x) |
| 595 | + |
| 596 | + covers_full = (jy_lo == 0 and jy_hi == n_tile_y |
| 597 | + and jx_lo == 0 and jx_hi == n_tile_x) |
| 598 | + |
| 599 | + target_coords = _collect_region_targets( |
| 600 | + raster, jy_lo, jy_hi, jx_lo, jx_hi, |
| 601 | + target_values, target_counts, |
| 602 | + y_coords, x_coords, y_offsets, x_offsets, coords_cache, |
| 603 | + ) |
| 604 | + |
| 605 | + if target_coords is None: |
| 606 | + if covers_full: |
| 607 | + # No targets in entire raster |
| 608 | + return np.full((h, w), np.nan, dtype=np.float32) |
| 609 | + ring += 1 |
| 610 | + continue |
| 611 | + |
| 612 | + tree = cKDTree(target_coords) |
| 613 | + ub = max_distance if np.isfinite(max_distance) else np.inf |
| 614 | + dists, _ = tree.query(query_pts, p=p, distance_upper_bound=ub) |
| 615 | + dists = dists.reshape(h, w).astype(np.float32) |
| 616 | + dists[dists == np.inf] = np.nan |
| 617 | + |
| 618 | + if covers_full: |
| 619 | + return dists |
| 620 | + |
| 621 | + # Validate: max_nearest_dist < min_boundary_distance |
| 622 | + max_nearest = np.nanmax(dists) if not np.all(np.isnan(dists)) else 0.0 |
| 623 | + min_bd = _min_boundary_distance( |
| 624 | + iy, ix, y_coords, x_coords, y_offsets, x_offsets, |
| 625 | + jy_lo, jy_hi, jx_lo, jx_hi, n_tile_y, n_tile_x, |
| 626 | + ) |
| 627 | + if max_nearest < min_bd: |
| 628 | + return dists |
| 629 | + |
| 630 | + ring += 1 |
| 631 | + |
| 632 | + |
| 633 | +def _build_tiled_kdtree(raster, y_coords, x_coords, target_values, |
| 634 | + max_distance, p, target_counts, coords_cache, |
| 635 | + chunks_y, chunks_x): |
| 636 | + """Tiled (eager) KDTree proximity — memory-safe fallback.""" |
| 637 | + H, W = raster.shape |
| 638 | + result_bytes = H * W * 4 # float32 |
| 639 | + avail = _available_memory_bytes() |
| 640 | + if result_bytes > 0.8 * avail: |
| 641 | + raise MemoryError( |
| 642 | + f"Proximity result array ({H}x{W}, {result_bytes / 1e9:.1f} GB) " |
| 643 | + f"exceeds 80% of available memory ({avail / 1e9:.1f} GB)." |
| 644 | + ) |
| 645 | + |
| 646 | + warnings.warn( |
| 647 | + "proximity: target coordinates exceed 50% of available memory; " |
| 648 | + "using tiled KDTree fallback (slower but memory-safe).", |
| 649 | + ResourceWarning, |
| 650 | + stacklevel=4, |
| 651 | + ) |
| 652 | + |
| 653 | + n_tile_y = len(chunks_y) |
| 654 | + n_tile_x = len(chunks_x) |
| 655 | + y_offsets = _chunk_offsets(chunks_y) |
| 656 | + x_offsets = _chunk_offsets(chunks_x) |
| 657 | + |
| 658 | + result = np.full((H, W), np.nan, dtype=np.float32) |
| 659 | + |
| 660 | + for iy in range(n_tile_y): |
| 661 | + for ix in range(n_tile_x): |
| 662 | + chunk_result = _tiled_chunk_proximity( |
| 663 | + raster, iy, ix, y_coords, x_coords, |
| 664 | + y_offsets, x_offsets, |
| 665 | + target_values, target_counts, coords_cache, |
| 666 | + max_distance, p, n_tile_y, n_tile_x, |
| 667 | + ) |
| 668 | + r0 = int(y_offsets[iy]) |
| 669 | + r1 = int(y_offsets[iy + 1]) |
| 670 | + c0 = int(x_offsets[ix]) |
| 671 | + c1 = int(x_offsets[ix + 1]) |
| 672 | + result[r0:r1, c0:c1] = chunk_result |
| 673 | + |
| 674 | + return da.from_array(result, chunks=raster.data.chunks) |
| 675 | + |
| 676 | + |
| 677 | +def _build_global_kdtree(raster, y_coords, x_coords, target_values, |
| 678 | + max_distance, p, target_counts, coords_cache, |
| 679 | + chunks_y, chunks_x): |
| 680 | + """Global KDTree proximity — fast, lazy via da.map_blocks.""" |
| 681 | + n_tile_y = len(chunks_y) |
| 682 | + n_tile_x = len(chunks_x) |
| 683 | + y_offsets = _chunk_offsets(chunks_y) |
| 684 | + x_offsets = _chunk_offsets(chunks_x) |
| 685 | + |
| 686 | + target_coords = _collect_region_targets( |
| 687 | + raster, 0, n_tile_y, 0, n_tile_x, |
| 688 | + target_values, target_counts, |
| 689 | + y_coords, x_coords, y_offsets, x_offsets, coords_cache, |
| 690 | + ) |
459 | 691 |
|
460 | | - target_coords = np.concatenate(target_list) |
461 | 692 | tree = cKDTree(target_coords) |
462 | 693 |
|
463 | | - # Phase 2: query tree per chunk via map_blocks |
464 | | - chunk_fn = partial(_kdtree_chunk_fn, |
465 | | - y_coords_1d=y_coords, |
466 | | - x_coords_1d=x_coords, |
467 | | - tree=tree, |
468 | | - max_distance=max_distance if np.isfinite(max_distance) else np.inf, |
469 | | - p=p) |
| 694 | + chunk_fn = partial( |
| 695 | + _kdtree_chunk_fn, |
| 696 | + y_coords_1d=y_coords, |
| 697 | + x_coords_1d=x_coords, |
| 698 | + tree=tree, |
| 699 | + max_distance=max_distance if np.isfinite(max_distance) else np.inf, |
| 700 | + p=p, |
| 701 | + ) |
470 | 702 |
|
471 | | - result = da.map_blocks( |
| 703 | + return da.map_blocks( |
472 | 704 | chunk_fn, |
473 | 705 | raster.data, |
474 | 706 | dtype=np.float32, |
475 | 707 | meta=np.array((), dtype=np.float32), |
476 | 708 | ) |
477 | | - return result |
| 709 | + |
| 710 | + |
| 711 | +def _process_dask_kdtree(raster, x_coords, y_coords, |
| 712 | + target_values, max_distance, distance_metric): |
| 713 | + """Memory-guarded k-d tree proximity for dask arrays. |
| 714 | +
|
| 715 | + Phase 0: stream through chunks counting targets (with caching). |
| 716 | + Then choose global tree (fast, lazy) or tiled tree (memory-safe, eager) |
| 717 | + based on estimated memory usage. |
| 718 | + """ |
| 719 | + p = 2 if distance_metric == EUCLIDEAN else 1 # Manhattan: p=1 |
| 720 | + |
| 721 | + chunks_y, chunks_x = raster.data.chunks |
| 722 | + |
| 723 | + # Phase 0: streaming count pass |
| 724 | + target_counts, total_targets, coords_cache = _stream_target_counts( |
| 725 | + raster, target_values, y_coords, x_coords, chunks_y, chunks_x, |
| 726 | + ) |
| 727 | + |
| 728 | + if total_targets == 0: |
| 729 | + return da.full(raster.shape, np.nan, dtype=np.float32, |
| 730 | + chunks=raster.data.chunks) |
| 731 | + |
| 732 | + # Memory decision: 16 bytes per coord pair + ~32 bytes tree overhead |
| 733 | + estimate = total_targets * 48 |
| 734 | + avail = _available_memory_bytes() |
| 735 | + |
| 736 | + if estimate < 0.5 * avail: |
| 737 | + return _build_global_kdtree( |
| 738 | + raster, y_coords, x_coords, target_values, |
| 739 | + max_distance, p, target_counts, coords_cache, |
| 740 | + chunks_y, chunks_x, |
| 741 | + ) |
| 742 | + else: |
| 743 | + return _build_tiled_kdtree( |
| 744 | + raster, y_coords, x_coords, target_values, |
| 745 | + max_distance, p, target_counts, coords_cache, |
| 746 | + chunks_y, chunks_x, |
| 747 | + ) |
478 | 748 |
|
479 | 749 |
|
480 | 750 | def _process( |
|
0 commit comments