Skip to content

Commit be91527

Browse files
committed
Add snap pour point module with all four backends
Moves each pour point to the highest flow-accumulation cell within a circular search radius so that watershed delineation starts from the actual drainage channel. Dask backend extracts sparse pour points chunk-by-chunk (map_blocks flag pass + selective load) to keep memory bounded regardless of grid size.
1 parent 1fdfc89 commit be91527

5 files changed

Lines changed: 639 additions & 0 deletions

File tree

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,7 @@ In the GIS world, rasters are used for representing continuous phenomena (e.g. e
249249
| [Basins](xrspatial/watershed.py) | Delineates drainage basins by labeling each cell with its outlet ID | ✅️ | ✅️ | ✅️ | ✅️ |
250250
| [Stream Order](xrspatial/stream_order.py) | Assigns Strahler or Shreve stream order to cells in a drainage network | ✅️ | ✅️ | ✅️ | ✅️ |
251251
| [Stream Link](xrspatial/stream_link.py) | Assigns unique IDs to each stream segment between junctions | ✅️ | ✅️ | ✅️ | ✅️ |
252+
| [Snap Pour Point](xrspatial/snap_pour_point.py) | Snaps pour points to the highest-accumulation cell within a search radius | ✅️ | ✅️ | ✅️ | ✅️ |
252253

253254
-----------
254255

xrspatial/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from xrspatial.proximity import manhattan_distance # noqa
3636
from xrspatial.proximity import proximity # noqa
3737
from xrspatial.sink import sink # noqa
38+
from xrspatial.snap_pour_point import snap_pour_point # noqa
3839
from xrspatial.stream_link import stream_link # noqa
3940
from xrspatial.stream_order import stream_order # noqa
4041
from xrspatial.slope import slope # noqa

xrspatial/accessor.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,10 @@ def stream_link(self, flow_accum, **kwargs):
9191
from .stream_link import stream_link
9292
return stream_link(self._obj, flow_accum, **kwargs)
9393

94+
def snap_pour_point(self, pour_points, **kwargs):
    """Snap pour points using this raster as the flow-accumulation grid.

    Thin accessor wrapper: delegates to
    :func:`xrspatial.snap_pour_point.snap_pour_point` with ``self._obj``
    as ``flow_accum``; see that function for full parameter docs.
    """
    from .snap_pour_point import snap_pour_point
    return snap_pour_point(self._obj, pour_points, **kwargs)
97+
9498
def viewshed(self, x, y, **kwargs):
9599
from .viewshed import viewshed
96100
return viewshed(self._obj, x, y, **kwargs)
@@ -325,6 +329,10 @@ def stream_link(self, flow_accum, **kwargs):
325329
from .stream_link import stream_link
326330
return stream_link(self._obj, flow_accum, **kwargs)
327331

332+
def snap_pour_point(self, pour_points, **kwargs):
    """Snap pour points using this raster as the flow-accumulation grid.

    Thin accessor wrapper: delegates to
    :func:`xrspatial.snap_pour_point.snap_pour_point` with ``self._obj``
    as ``flow_accum``; see that function for full parameter docs.
    """
    from .snap_pour_point import snap_pour_point
    return snap_pour_point(self._obj, pour_points, **kwargs)
335+
328336
# ---- Classification ----
329337

330338
def natural_breaks(self, **kwargs):

xrspatial/snap_pour_point.py

Lines changed: 325 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,325 @@
1+
"""Snap pour points to the highest-accumulation cell within a search radius.
2+
3+
Users typically place pour points manually, but these often land a cell or
4+
two off from the actual drainage channel. This module moves each pour point
5+
to the highest flow-accumulation cell within a circular search radius so
6+
that subsequent ``watershed()`` calls delineate correctly.
7+
8+
Algorithm
9+
---------
10+
For each non-NaN cell in ``pour_points``:
11+
1. Search all cells within ``search_radius`` pixels (Euclidean distance).
12+
2. Among valid (non-NaN) ``flow_accum`` cells in that radius, find the one
13+
with maximum accumulation.
14+
3. Move the pour point label to that cell.
15+
16+
If multiple pour points snap to the same cell, the last one in raster-scan
17+
order wins (deterministic across all backends).
18+
"""
19+
20+
from __future__ import annotations
21+
22+
import numpy as np
23+
import xarray as xr
24+
25+
try:
26+
import cupy
27+
except ImportError:
28+
class cupy: # type: ignore[no-redef]
29+
ndarray = False
30+
31+
try:
32+
import dask.array as da
33+
except ImportError:
34+
da = None
35+
36+
from xrspatial.utils import (
37+
_validate_raster,
38+
has_cuda_and_cupy,
39+
is_cupy_array,
40+
is_dask_cupy,
41+
ngjit,
42+
)
43+
from xrspatial.dataset_support import supports_dataset
44+
45+
46+
# =====================================================================
47+
# CPU kernel
48+
# =====================================================================
49+
50+
@ngjit
def _snap_pour_point_cpu(flow_accum, pour_points, search_radius, H, W):
    """Move each labelled pour point to the max-accumulation cell within
    ``search_radius`` pixels (Euclidean) and write its label into a fresh
    NaN-filled (H, W) float64 grid, which is returned.
    """
    snapped = np.empty((H, W), dtype=np.float64)
    snapped[:] = np.nan
    max_dist_sq = search_radius * search_radius

    for row in range(H):
        for col in range(W):
            label = pour_points[row, col]
            if label != label:  # NaN: no pour point in this cell
                continue

            # Seed the search with the point's own cell; a NaN
            # accumulation there acts as -inf so any valid neighbour
            # can beat it.
            tgt_r, tgt_c = row, col
            centre = flow_accum[row, col]
            if centre == centre:  # not NaN
                top = centre
            else:
                top = -1e308  # ~-inf

            # Clamp the square bounding box to the grid edges.
            lo_r = row - search_radius
            if lo_r < 0:
                lo_r = 0
            hi_r = row + search_radius
            if hi_r >= H:
                hi_r = H - 1
            lo_c = col - search_radius
            if lo_c < 0:
                lo_c = 0
            hi_c = col + search_radius
            if hi_c >= W:
                hi_c = W - 1

            for rr in range(lo_r, hi_r + 1):
                dy = rr - row
                for cc in range(lo_c, hi_c + 1):
                    dx = cc - col
                    # Skip box corners outside the circular radius.
                    if dy * dy + dx * dx > max_dist_sq:
                        continue
                    candidate = flow_accum[rr, cc]
                    if candidate != candidate:  # NaN cell
                        continue
                    if candidate > top:
                        top = candidate
                        tgt_r, tgt_c = rr, cc

            # If two points snap to the same cell, the later one in
            # raster-scan order overwrites (deterministic tie-break).
            snapped[tgt_r, tgt_c] = label

    return snapped
103+
104+
105+
# =====================================================================
106+
# CuPy backend
107+
# =====================================================================
108+
109+
def _snap_pour_point_cupy(flow_accum_data, pour_points_data, search_radius):
    """CuPy: convert to numpy, run CPU kernel, convert back."""
    import cupy as cp

    def _to_host(arr):
        # Device arrays expose .get(); anything else goes through numpy.
        host = arr.get() if hasattr(arr, 'get') else np.asarray(arr)
        return host.astype(np.float64)

    fa_host = _to_host(flow_accum_data)
    pp_host = _to_host(pour_points_data)
    rows, cols = fa_host.shape
    snapped = _snap_pour_point_cpu(fa_host, pp_host, search_radius,
                                   rows, cols)
    return cp.asarray(snapped)
120+
121+
122+
# =====================================================================
123+
# Dask backend
124+
# =====================================================================
125+
126+
def _snap_pour_point_dask(flow_accum_data, pour_points_data, search_radius):
    """Dask: extract sparse pour points chunk-by-chunk, windowed search, lazy assembly.

    Pour points are sparse (typically < 100 in a multi-million-cell raster).
    We never materialize the full pour_points grid: a ``map_blocks`` pass
    reduces each chunk to a 1-byte flag, then only the (few) flagged chunks
    are loaded to extract coordinates. Small windows of ``flow_accum`` are
    sliced for each pour point, and the output is assembled lazily.

    Parameters
    ----------
    flow_accum_data : dask.array.Array
        2D flow-accumulation grid (lazy).
    pour_points_data : dask.array.Array
        2D grid whose non-NaN cells mark pour points; values are labels.
    search_radius : int
        Maximum Euclidean search distance in pixels.

    Returns
    -------
    dask.array.Array
        Lazy float64 grid, chunked like ``flow_accum_data``, with labels
        at snapped positions and NaN elsewhere.
    """
    H, W = flow_accum_data.shape
    # Per-axis chunk sizes of the pour-point grid, used to map block
    # indices back to global (row, col) offsets in Phase 2.
    chunks_y = pour_points_data.chunks[0]
    chunks_x = pour_points_data.chunks[1]

    # --- Phase 1: identify which chunks contain pour points --------
    # Single dask pass; each chunk is reduced to a scalar flag.
    # The scheduler parallelizes reads and releases each chunk after
    # the reduction, so peak memory is bounded by thread count × chunk size.
    def _has_pp(block):
        # (1, 1) int8 result keeps the reduced array 2D so block
        # indices line up with the chunk grid.
        return np.array(
            [[np.any(~np.isnan(np.asarray(block))).item()]],
            dtype=np.int8,
        )

    flags = da.map_blocks(
        _has_pp, pour_points_data,
        dtype=np.int8,
        chunks=tuple((1,) * len(c) for c in pour_points_data.chunks),
    ).compute()  # tiny array: one byte per chunk

    # --- Phase 2: load only flagged chunks, extract coordinates ----
    points = []  # list of (global_row, global_col, label)
    row_off = 0
    for iy, cy in enumerate(chunks_y):
        col_off = 0
        for ix, cx in enumerate(chunks_x):
            if flags[iy, ix]:
                chunk = np.asarray(
                    pour_points_data.blocks[iy, ix].compute(),
                    dtype=np.float64,
                )
                rs, cs = np.where(~np.isnan(chunk))
                for k in range(len(rs)):
                    # Translate chunk-local indices to global ones.
                    points.append((
                        row_off + int(rs[k]),
                        col_off + int(cs[k]),
                        float(chunk[rs[k], cs[k]]),
                    ))
            col_off += cx
        row_off += cy

    # --- Phase 3: snap each pour point via windowed search ---------
    snapped = []  # list of (snap_r, snap_c, label)
    radius_sq = search_radius * search_radius

    for r, c, label in points:
        # Bounding box clamped to the grid, same as the CPU kernel.
        r_lo = max(0, r - search_radius)
        r_hi = min(H - 1, r + search_radius)
        c_lo = max(0, c - search_radius)
        c_hi = min(W - 1, c + search_radius)

        # Small window; dask handles cross-chunk slicing
        window = np.asarray(
            flow_accum_data[r_lo:r_hi + 1, c_lo:c_hi + 1].compute(),
            dtype=np.float64,
        )

        # Seed with the point's own cell; NaN there acts as -inf so
        # any valid neighbour can win.
        best_r, best_c = r, c
        fa_val = window[r - r_lo, c - c_lo]
        best_accum = fa_val if not np.isnan(fa_val) else -np.inf

        for wr in range(window.shape[0]):
            for wc in range(window.shape[1]):
                nr = r_lo + wr
                nc = c_lo + wc
                dr = nr - r
                dc = nc - c
                # Enforce the circular (not square) search radius.
                if dr * dr + dc * dc > radius_sq:
                    continue
                fa_n = window[wr, wc]
                if np.isnan(fa_n):
                    continue
                if fa_n > best_accum:
                    best_accum = fa_n
                    best_r = nr
                    best_c = nc

        snapped.append((best_r, best_c, label))

    # --- Phase 4: lazy output assembly via map_blocks --------------
    # Flatten the snapped points into arrays; empty arrays keep the
    # assembly closure valid when there are no pour points at all.
    snap_rows = np.array([s[0] for s in snapped], dtype=np.int64) if snapped else np.array([], dtype=np.int64)
    snap_cols = np.array([s[1] for s in snapped], dtype=np.int64) if snapped else np.array([], dtype=np.int64)
    snap_labels = np.array([s[2] for s in snapped], dtype=np.float64) if snapped else np.array([], dtype=np.float64)

    # Local aliases captured by the assembly closure below.
    _snap_rows = snap_rows
    _snap_cols = snap_cols
    _snap_labels = snap_labels

    def _assemble_block(block, block_info=None):
        # block_info may be absent during dask's meta inference pass;
        # return an all-NaN block of the right shape in that case.
        if block_info is None or 0 not in block_info:
            return np.full(block.shape, np.nan, dtype=np.float64)
        # 'array-location' gives this block's global [start, stop) span.
        row_start, row_end = block_info[0]['array-location'][0]
        col_start, col_end = block_info[0]['array-location'][1]
        h, w = block.shape
        out = np.full((h, w), np.nan, dtype=np.float64)
        # Write only the points that fall inside this block; later
        # points overwrite earlier ones (last-wins, as documented).
        for k in range(len(_snap_rows)):
            sr = _snap_rows[k]
            sc = _snap_cols[k]
            if row_start <= sr < row_end and col_start <= sc < col_end:
                out[sr - row_start, sc - col_start] = _snap_labels[k]
        return out

    # Dummy array only supplies shape/chunking; its zeros are ignored.
    dummy = da.zeros((H, W), chunks=flow_accum_data.chunks, dtype=np.float64)
    return da.map_blocks(
        _assemble_block, dummy,
        dtype=np.float64,
        meta=np.array((), dtype=np.float64),
    )
243+
244+
245+
# =====================================================================
246+
# Dask+CuPy backend
247+
# =====================================================================
248+
249+
def _snap_pour_point_dask_cupy(flow_accum_data, pour_points_data, search_radius):
    """Dask+CuPy: convert cupy chunks to numpy, run dask path, convert back."""
    import cupy as cp

    def _to_host_blocks(arr):
        # Pull each cupy chunk to host memory lazily, preserving the graph.
        return arr.map_blocks(
            lambda b: b.get(), dtype=arr.dtype,
            meta=np.array((), dtype=arr.dtype),
        )

    host_result = _snap_pour_point_dask(
        _to_host_blocks(flow_accum_data),
        _to_host_blocks(pour_points_data),
        search_radius,
    )
    # Push the snapped grid back onto the device chunk-by-chunk.
    return host_result.map_blocks(
        cp.asarray, dtype=host_result.dtype,
        meta=cp.array((), dtype=host_result.dtype),
    )
267+
268+
269+
# =====================================================================
270+
# Public API
271+
# =====================================================================
272+
273+
@supports_dataset
def snap_pour_point(flow_accum: xr.DataArray,
                    pour_points: xr.DataArray,
                    search_radius: int = 5,
                    name: str = 'snapped_pour_points') -> xr.DataArray:
    """Snap pour points to the highest-accumulation cell within a radius.

    Parameters
    ----------
    flow_accum : xarray.DataArray or xr.Dataset
        2D flow accumulation grid.
    pour_points : xarray.DataArray
        2D raster where non-NaN cells mark pour points (same format
        as ``watershed()`` expects). Values are preserved as labels.
        Must have the same shape as ``flow_accum``.
    search_radius : int, default 5
        Maximum search distance in pixels (Euclidean). Must be
        non-negative.
    name : str, default 'snapped_pour_points'
        Name of output DataArray.

    Returns
    -------
    xarray.DataArray or xr.Dataset
        Same-shape grid with pour point labels moved to their snapped
        locations. Non-pour-point cells are NaN.

    Raises
    ------
    ValueError
        If ``search_radius`` is negative, or if ``pour_points`` does
        not match the shape of ``flow_accum``.
    TypeError
        If the underlying array type of ``flow_accum`` is not one of
        numpy, cupy, dask, or dask-wrapped cupy.
    """
    _validate_raster(flow_accum, func_name='snap_pour_point', name='flow_accum')

    # A negative radius would silently produce an empty search window
    # in the kernels; fail fast with a clear message instead.
    if search_radius < 0:
        raise ValueError(
            f'search_radius must be non-negative, got {search_radius}')

    # The kernels index both grids with the same (row, col); a shape
    # mismatch would otherwise surface as an opaque IndexError.
    if flow_accum.shape != pour_points.shape:
        raise ValueError(
            'flow_accum and pour_points must have the same shape: '
            f'{flow_accum.shape} vs {pour_points.shape}')

    fa_data = flow_accum.data
    pp_data = pour_points.data

    if isinstance(fa_data, np.ndarray):
        # NumPy backend: run the numba kernel directly.
        fa = fa_data.astype(np.float64)
        pp = np.asarray(pp_data, dtype=np.float64)
        H, W = fa.shape
        out = _snap_pour_point_cpu(fa, pp, search_radius, H, W)

    elif has_cuda_and_cupy() and is_cupy_array(fa_data):
        out = _snap_pour_point_cupy(fa_data, pp_data, search_radius)

    # Dask-wrapped cupy must be checked before the generic dask branch,
    # since such arrays are also dask Arrays.
    elif has_cuda_and_cupy() and is_dask_cupy(flow_accum):
        out = _snap_pour_point_dask_cupy(fa_data, pp_data, search_radius)

    elif da is not None and isinstance(fa_data, da.Array):
        out = _snap_pour_point_dask(fa_data, pp_data, search_radius)

    else:
        raise TypeError(f"Unsupported array type: {type(fa_data)}")

    # Preserve the input's coords/dims/attrs so the result aligns with
    # the source raster for downstream watershed() calls.
    return xr.DataArray(out,
                        name=name,
                        coords=flow_accum.coords,
                        dims=flow_accum.dims,
                        attrs=flow_accum.attrs)

0 commit comments

Comments
 (0)