Skip to content

Commit 9dd54a8

Browse files
committed
Reduce dask graph size for large raster reprojection
Replace the O(N) delayed/from_delayed/block pattern in _reproject_dask and _merge_dask with da.map_blocks over a template array. This produces a single blockwise layer in the HighLevelGraph with O(1) metadata, so graph construction no longer scales with chunk count. Also adds empty-chunk skipping: precompute the source footprint in target CRS and skip chunks that fall outside it entirely. This avoids pyproj initialization and source data fetching for chunks that would just be filled with nodata. The streaming fallback threshold (previously 1GB graph metadata) is removed since graph metadata is now constant-size. Large in-memory arrays always go through dask when available.
1 parent a4b7507 commit 9dd54a8

File tree

2 files changed

+280
-79
lines changed

2 files changed

+280
-79
lines changed

xrspatial/reproject/__init__.py

Lines changed: 157 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -574,37 +574,31 @@ def reproject(
574574
else:
575575
is_cupy = is_cupy_array(data)
576576

577-
# For very large datasets, estimate whether a dask graph would fit
578-
# in memory. Each dask task uses ~1KB of graph metadata. If the
579-
# graph itself would exceed available memory, use a streaming
580-
# approach instead of dask (process tiles sequentially, no graph).
577+
# For large in-memory datasets, wrap in dask for chunked processing.
578+
# map_blocks generates an O(1) HighLevelGraph (single blockwise layer)
579+
# so graph metadata is no longer a concern -- the streaming fallback
580+
# is only needed when dask itself is unavailable.
581581
_use_streaming = False
582582
if not is_dask and not is_cupy:
583583
nbytes = src_shape[0] * src_shape[1] * data.dtype.itemsize
584584
if data.ndim == 3:
585585
nbytes *= data.shape[2]
586586
_OOM_THRESHOLD = 512 * 1024 * 1024 # 512 MB
587587
if nbytes > _OOM_THRESHOLD:
588-
# Estimate graph size for the output
589588
cs = chunk_size or 2048
590589
if isinstance(cs, int):
591590
cs = (cs, cs)
592-
n_out_chunks = (math.ceil(out_shape[0] / cs[0])
593-
* math.ceil(out_shape[1] / cs[1]))
594-
graph_bytes = n_out_chunks * 1024 # ~1KB per task
595-
596-
if graph_bytes > 1024 * 1024 * 1024: # > 1GB graph
597-
# Graph too large for dask -- use streaming
598-
_use_streaming = True
599-
else:
600-
# Graph fits -- use dask with large chunks
591+
try:
601592
import dask.array as _da
602593
data = _da.from_array(data, chunks=cs)
603594
raster = xr.DataArray(
604595
data, dims=raster.dims, coords=raster.coords,
605596
name=raster.name, attrs=raster.attrs,
606597
)
607598
is_dask = True
599+
except ImportError:
600+
# dask not available -- fall back to streaming
601+
_use_streaming = True
608602

609603
# Serialize CRS for pickle safety
610604
src_wkt = src_crs.to_wkt()
@@ -1125,51 +1119,121 @@ def _reproject_dask_cupy(
11251119
return result
11261120

11271121

1122+
def _source_footprint_in_target(src_bounds, src_wkt, tgt_wkt):
1123+
"""Compute an approximate bounding box of the source raster in target CRS.
1124+
1125+
Transforms corners and edge midpoints (12 points) to handle non-linear
1126+
projections. Returns ``(left, bottom, right, top)`` in target CRS, or
1127+
*None* if the transform fails (e.g. out-of-domain).
1128+
"""
1129+
try:
1130+
from ._crs_utils import _require_pyproj
1131+
pyproj = _require_pyproj()
1132+
src_crs = pyproj.CRS(src_wkt)
1133+
tgt_crs = pyproj.CRS(tgt_wkt)
1134+
transformer = pyproj.Transformer.from_crs(
1135+
src_crs, tgt_crs, always_xy=True
1136+
)
1137+
except Exception:
1138+
return None
1139+
1140+
sl, sb, sr, st = src_bounds
1141+
mx = (sl + sr) / 2
1142+
my = (sb + st) / 2
1143+
xs = [sl, mx, sr, sl, mx, sr, sl, mx, sr, sl, sr, mx]
1144+
ys = [sb, sb, sb, my, my, my, st, st, st, mx, mx, sb]
1145+
try:
1146+
tx, ty = transformer.transform(xs, ys)
1147+
tx = [v for v in tx if np.isfinite(v)]
1148+
ty = [v for v in ty if np.isfinite(v)]
1149+
if not tx or not ty:
1150+
return None
1151+
return (min(tx), min(ty), max(tx), max(ty))
1152+
except Exception:
1153+
return None
1154+
1155+
1156+
def _bounds_overlap(a, b):
1157+
"""Return True if bounding boxes *a* and *b* overlap."""
1158+
return a[0] < b[2] and a[2] > b[0] and a[1] < b[3] and a[3] > b[1]
1159+
1160+
1161+
def _reproject_block_adapter(
    block, block_info, source_data,
    src_bounds, src_shape, y_desc,
    src_wkt, tgt_wkt,
    out_bounds, out_shape,
    resampling, nodata, precision,
    is_cupy, src_footprint_tgt,
):
    """``map_blocks`` adapter for reprojection.

    Reads this chunk's position from *block_info*, derives its bounds in
    the output grid, and delegates to the per-chunk reprojection worker.
    Chunks lying entirely outside *src_footprint_tgt* (the precomputed
    source footprint in target CRS, or None if unknown) are filled with
    *nodata* without touching the source data at all.

    NOTE(review): the skip path always returns a NumPy array even when
    *is_cupy* is true -- confirm the cupy backend never routes through
    this adapter with a cupy meta.
    """
    location = block_info[0]['array-location']
    (row_start, row_end), (col_start, col_end) = location
    shape = (row_end - row_start, col_end - col_start)
    bounds = _chunk_bounds(out_bounds, out_shape,
                           row_start, row_end, col_start, col_end)

    # Fast path: no overlap with the source footprint -> pure nodata.
    if src_footprint_tgt is not None:
        if not _bounds_overlap(bounds, src_footprint_tgt):
            return np.full(shape, nodata, dtype=np.float64)

    if is_cupy:
        worker = _reproject_chunk_cupy
    else:
        worker = _reproject_chunk_numpy
    return worker(
        source_data, src_bounds, src_shape, y_desc,
        src_wkt, tgt_wkt,
        bounds, shape,
        resampling, nodata, precision,
    )
1191+
1192+
11281193
def _reproject_dask(
    raster, src_bounds, src_shape, y_desc,
    src_wkt, tgt_wkt,
    out_bounds, out_shape,
    resampling, nodata, precision,
    chunk_size, is_cupy,
):
    """Dask+NumPy backend: ``map_blocks`` over a template array.

    Builds an empty template with the requested chunk layout and maps the
    per-chunk reprojection adapter over it.  This produces a single
    ``blockwise`` layer in the HighLevelGraph (O(1) graph metadata)
    instead of O(N) ``dask.delayed`` nodes.
    """
    import dask.array as da

    row_chunks, col_chunks = _compute_chunk_layout(out_shape, chunk_size)

    # The source footprint in target CRS lets the adapter skip chunks
    # that cannot intersect any source data (None -> never skip).
    footprint = _source_footprint_in_target(src_bounds, src_wkt, tgt_wkt)

    template = da.empty(
        out_shape, dtype=np.float64, chunks=(row_chunks, col_chunks)
    )

    adapter_kwargs = dict(
        source_data=raster.data,
        src_bounds=src_bounds,
        src_shape=src_shape,
        y_desc=y_desc,
        src_wkt=src_wkt,
        tgt_wkt=tgt_wkt,
        out_bounds=out_bounds,
        out_shape=out_shape,
        resampling=resampling,
        nodata=nodata,
        precision=precision,
        is_cupy=is_cupy,
        src_footprint_tgt=footprint,
    )
    return da.map_blocks(
        _reproject_block_adapter,
        template,
        dtype=np.float64,
        meta=np.array((), dtype=np.float64),
        **adapter_kwargs,
    )
11731237

11741238

11751239
# ---------------------------------------------------------------------------
@@ -1434,37 +1498,49 @@ def _merge_inmemory(
14341498
return _merge_arrays_numpy(arrays, nodata, strategy)
14351499

14361500

1437-
def _merge_chunk_worker(
1501+
def _merge_block_adapter(
    block, block_info,
    raster_data_list, src_bounds_list, src_shape_list, y_desc_list,
    src_wkt_list, tgt_wkt,
    out_bounds, out_shape,
    resampling, nodata, strategy, precision,
    src_footprints_tgt,
):
    """``map_blocks`` adapter for merge.

    Reprojects each source raster whose footprint overlaps this chunk and
    merges the results with *strategy*.  Sources with a known footprint
    that misses the chunk are skipped entirely; if nothing overlaps, a
    nodata-filled chunk is returned.
    """
    (row_start, row_end), (col_start, col_end) = \
        block_info[0]['array-location']
    shape = (row_end - row_start, col_end - col_start)
    bounds = _chunk_bounds(out_bounds, out_shape,
                           row_start, row_end, col_start, col_end)

    sources = zip(raster_data_list, src_bounds_list, src_shape_list,
                  y_desc_list, src_wkt_list, src_footprints_tgt)
    pieces = []
    for data, s_bounds, s_shape, s_ydesc, s_wkt, footprint in sources:
        # Only reproject rasters whose (known) footprint hits this chunk.
        if footprint is not None and not _bounds_overlap(bounds, footprint):
            continue
        pieces.append(_reproject_chunk_numpy(
            data,
            s_bounds, s_shape, s_ydesc,
            s_wkt, tgt_wkt,
            bounds, shape,
            resampling, nodata, precision,
        ))

    if not pieces:
        return np.full(shape, nodata, dtype=np.float64)
    return _merge_arrays_numpy(pieces, nodata, strategy)
14551534

14561535

14571536
def _merge_dask(
14581537
raster_infos, tgt_wkt, out_bounds, out_shape,
14591538
resampling, nodata, strategy, chunk_size,
14601539
):
1461-
"""Dask merge backend."""
1462-
import dask
1540+
"""Dask merge backend using ``map_blocks``."""
14631541
import dask.array as da
14641542

14651543
row_chunks, col_chunks = _compute_chunk_layout(out_shape, chunk_size)
1466-
n_row = len(row_chunks)
1467-
n_col = len(col_chunks)
14681544

14691545
# Prepare lists for the worker
14701546
data_list = [info['raster'].data for info in raster_infos]
@@ -1473,30 +1549,32 @@ def _merge_dask(
14731549
ydesc_list = [info['y_desc'] for info in raster_infos]
14741550
wkt_list = [info['src_wkt'] for info in raster_infos]
14751551

1476-
dtype = np.float64
1477-
blocks = [[None] * n_col for _ in range(n_row)]
1552+
# Precompute source footprints in target CRS
1553+
footprints = [
1554+
_source_footprint_in_target(bounds_list[i], wkt_list[i], tgt_wkt)
1555+
for i in range(len(raster_infos))
1556+
]
14781557

1479-
row_offset = 0
1480-
for i in range(n_row):
1481-
col_offset = 0
1482-
for j in range(n_col):
1483-
rchunk = row_chunks[i]
1484-
cchunk = col_chunks[j]
1485-
cb = _chunk_bounds(
1486-
out_bounds, out_shape,
1487-
row_offset, row_offset + rchunk,
1488-
col_offset, col_offset + cchunk,
1489-
)
1490-
delayed_chunk = dask.delayed(_merge_chunk_worker)(
1491-
data_list, bounds_list, shape_list, ydesc_list,
1492-
wkt_list, tgt_wkt,
1493-
cb, (rchunk, cchunk),
1494-
resampling, nodata, strategy, 16,
1495-
)
1496-
blocks[i][j] = da.from_delayed(
1497-
delayed_chunk, shape=(rchunk, cchunk), dtype=dtype
1498-
)
1499-
col_offset += cchunk
1500-
row_offset += rchunk
1558+
template = da.empty(
1559+
out_shape, dtype=np.float64, chunks=(row_chunks, col_chunks)
1560+
)
15011561

1502-
return da.block(blocks)
1562+
return da.map_blocks(
1563+
_merge_block_adapter,
1564+
template,
1565+
raster_data_list=data_list,
1566+
src_bounds_list=bounds_list,
1567+
src_shape_list=shape_list,
1568+
y_desc_list=ydesc_list,
1569+
src_wkt_list=wkt_list,
1570+
tgt_wkt=tgt_wkt,
1571+
out_bounds=out_bounds,
1572+
out_shape=out_shape,
1573+
resampling=resampling,
1574+
nodata=nodata,
1575+
strategy=strategy,
1576+
precision=16,
1577+
src_footprints_tgt=footprints,
1578+
dtype=np.float64,
1579+
meta=np.array((), dtype=np.float64),
1580+
)

0 commit comments

Comments
 (0)