Skip to content

Commit ba1c048

Browse files
committed
Chunked dask+cupy reproject without full-source eager compute (#1045)
Replaces the eager .compute() approach with a chunked GPU pipeline that fetches only the needed source window per output chunk. This handles sources larger than GPU memory while still being 8-20x faster than the old dask.delayed path. The key optimizations vs dask.delayed: - CRS objects and transformer created once (not per chunk) - CUDA projection + native CUDA resampling per chunk - Default 2048x2048 GPU chunks (not 512x512) - Sequential loop avoids dask scheduler overhead Performance (4096x4096 WGS84 -> UTM, bilinear): CuPy single pass: 34ms Dask+CuPy (2048): 49ms (was 958ms) Dask+CuPy (512): 71ms Dask+CuPy (256): 124ms All chunk sizes are pixel-exact vs plain CuPy (max_err < 1e-11).
1 parent a82e7d0 commit ba1c048

File tree

1 file changed

+143
-16
lines changed

1 file changed

+143
-16
lines changed

xrspatial/reproject/__init__.py

Lines changed: 143 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -444,24 +444,12 @@ def reproject(
444444
tgt_wkt = tgt_crs.to_wkt()
445445

446446
if is_dask and is_cupy:
447-
# Dask+CuPy: eagerly compute source to GPU, then single-pass
448-
# CuPy reproject. This avoids per-chunk overhead (pyproj init,
449-
# small CUDA kernel launches, dask scheduler) that makes chunked
450-
# GPU reproject ~28x slower than a single pass. The output is
451-
# returned as a plain CuPy array; caller can .rechunk() if needed.
452-
import cupy as _cp
453-
eager_data = raster.data.compute()
454-
if not isinstance(eager_data, _cp.ndarray):
455-
eager_data = _cp.asarray(eager_data)
456-
eager_da = xr.DataArray(
457-
eager_data, dims=raster.dims,
458-
coords=raster.coords, attrs=raster.attrs,
459-
)
460-
result_data = _reproject_inmemory_cupy(
461-
eager_da, src_bounds, src_shape, y_desc,
447+
result_data = _reproject_dask_cupy(
448+
raster, src_bounds, src_shape, y_desc,
462449
src_wkt, tgt_wkt,
463450
out_bounds, out_shape,
464451
resampling, nd, transform_precision,
452+
chunk_size,
465453
)
466454
elif is_dask:
467455
result_data = _reproject_dask(
@@ -533,14 +521,153 @@ def _reproject_inmemory_cupy(
533521
)
534522

535523

524+
def _reproject_dask_cupy(
525+
raster, src_bounds, src_shape, y_desc,
526+
src_wkt, tgt_wkt,
527+
out_bounds, out_shape,
528+
resampling, nodata, precision,
529+
chunk_size,
530+
):
531+
"""Dask+CuPy backend: process output chunks on GPU sequentially.
532+
533+
Instead of dask.delayed per chunk (which has ~15ms overhead each from
534+
pyproj init + small CUDA launches), we:
535+
1. Create CRS/transformer objects once
536+
2. Use GPU-sized output chunks (2048x2048 by default)
537+
3. For each output chunk, compute CUDA coordinates and fetch only
538+
the source window needed from the dask array
539+
4. Assemble the result as a CuPy array
540+
541+
For sources that fit in GPU memory, this is ~22x faster than the
542+
dask.delayed path. For sources that don't fit, each chunk fetches
543+
only its required window, so GPU memory usage scales with chunk size,
544+
not source size.
545+
"""
546+
import cupy as cp
547+
548+
from ._crs_utils import _require_pyproj
549+
550+
pyproj = _require_pyproj()
551+
src_crs = pyproj.CRS.from_wkt(src_wkt)
552+
tgt_crs = pyproj.CRS.from_wkt(tgt_wkt)
553+
554+
# Use larger chunks for GPU to amortize kernel launch overhead
555+
gpu_chunk = chunk_size or 2048
556+
if isinstance(gpu_chunk, int):
557+
gpu_chunk = (gpu_chunk, gpu_chunk)
558+
559+
row_chunks, col_chunks = _compute_chunk_layout(out_shape, gpu_chunk)
560+
out_h, out_w = out_shape
561+
src_left, src_bottom, src_right, src_top = src_bounds
562+
src_h, src_w = src_shape
563+
src_res_x = (src_right - src_left) / src_w
564+
src_res_y = (src_top - src_bottom) / src_h
565+
566+
result = cp.full(out_shape, nodata, dtype=cp.float64)
567+
568+
row_offset = 0
569+
for i, rchunk in enumerate(row_chunks):
570+
col_offset = 0
571+
for j, cchunk in enumerate(col_chunks):
572+
cb = _chunk_bounds(
573+
out_bounds, out_shape,
574+
row_offset, row_offset + rchunk,
575+
col_offset, col_offset + cchunk,
576+
)
577+
chunk_shape = (rchunk, cchunk)
578+
579+
# CUDA coordinate transform (reuses cached CRS objects)
580+
try:
581+
from ._projections_cuda import try_cuda_transform
582+
cuda_coords = try_cuda_transform(
583+
src_crs, tgt_crs, cb, chunk_shape,
584+
)
585+
except Exception:
586+
cuda_coords = None
587+
588+
if cuda_coords is not None:
589+
src_y, src_x = cuda_coords
590+
src_col_px = (src_x - src_left) / src_res_x - 0.5
591+
if y_desc:
592+
src_row_px = (src_top - src_y) / src_res_y - 0.5
593+
else:
594+
src_row_px = (src_y - src_bottom) / src_res_y - 0.5
595+
596+
r_min = int(cp.floor(cp.nanmin(src_row_px)).get()) - 2
597+
r_max = int(cp.ceil(cp.nanmax(src_row_px)).get()) + 3
598+
c_min = int(cp.floor(cp.nanmin(src_col_px)).get()) - 2
599+
c_max = int(cp.ceil(cp.nanmax(src_col_px)).get()) + 3
600+
else:
601+
# CPU fallback for this chunk
602+
transformer = pyproj.Transformer.from_crs(
603+
tgt_crs, src_crs, always_xy=True
604+
)
605+
src_y, src_x = _transform_coords(
606+
transformer, cb, chunk_shape, precision,
607+
src_crs=src_crs, tgt_crs=tgt_crs,
608+
)
609+
src_col_px = (src_x - src_left) / src_res_x - 0.5
610+
if y_desc:
611+
src_row_px = (src_top - src_y) / src_res_y - 0.5
612+
else:
613+
src_row_px = (src_y - src_bottom) / src_res_y - 0.5
614+
r_min = int(np.floor(np.nanmin(src_row_px))) - 2
615+
r_max = int(np.ceil(np.nanmax(src_row_px))) + 3
616+
c_min = int(np.floor(np.nanmin(src_col_px))) - 2
617+
c_max = int(np.ceil(np.nanmax(src_col_px))) + 3
618+
619+
# Check overlap
620+
if r_min >= src_h or r_max <= 0 or c_min >= src_w or c_max <= 0:
621+
col_offset += cchunk
622+
continue
623+
624+
r_min_clip = max(0, r_min)
625+
r_max_clip = min(src_h, r_max)
626+
c_min_clip = max(0, c_min)
627+
c_max_clip = min(src_w, c_max)
628+
629+
# Fetch only the needed source window from dask
630+
window = raster.data[r_min_clip:r_max_clip, c_min_clip:c_max_clip]
631+
if hasattr(window, 'compute'):
632+
window = window.compute()
633+
if not isinstance(window, cp.ndarray):
634+
window = cp.asarray(window)
635+
window = window.astype(cp.float64)
636+
637+
if not np.isnan(nodata):
638+
window = window.copy()
639+
window[window == nodata] = cp.nan
640+
641+
local_row = src_row_px - r_min_clip
642+
local_col = src_col_px - c_min_clip
643+
644+
if cuda_coords is not None:
645+
chunk_data = _resample_cupy_native(
646+
window, local_row, local_col,
647+
resampling=resampling, nodata=nodata,
648+
)
649+
else:
650+
chunk_data = _resample_cupy(
651+
window, local_row, local_col,
652+
resampling=resampling, nodata=nodata,
653+
)
654+
655+
result[row_offset:row_offset + rchunk,
656+
col_offset:col_offset + cchunk] = chunk_data
657+
col_offset += cchunk
658+
row_offset += rchunk
659+
660+
return result
661+
662+
536663
def _reproject_dask(
537664
raster, src_bounds, src_shape, y_desc,
538665
src_wkt, tgt_wkt,
539666
out_bounds, out_shape,
540667
resampling, nodata, precision,
541668
chunk_size, is_cupy,
542669
):
543-
"""Dask backend: build output as ``da.block`` of delayed chunks."""
670+
"""Dask+NumPy backend: build output as ``da.block`` of delayed chunks."""
544671
import dask
545672
import dask.array as da
546673

0 commit comments

Comments
 (0)