Skip to content

Commit acd7a5a

Browse files
committed
Add Numba JIT and CUDA projection kernels for reproject (#1045)
Ports six projections from the PROJ library to Numba (CPU) and Numba CUDA (GPU), bypassing pyproj for common CRS pairs: - Web Mercator (EPSG:3857) -- spherical, 3 lines per direction - Transverse Mercator / UTM (326xx, 327xx, 269xx) -- 6th-order Krueger series (Karney 2011), closed-form forward and inverse - Ellipsoidal Mercator (EPSG:3395) -- Newton inverse - Lambert Conformal Conic (e.g. EPSG:2154) -- Newton inverse - Albers Equal Area (e.g. EPSG:5070) -- authalic latitude series - Cylindrical Equal Area (e.g. EPSG:6933) -- authalic latitude series CPU Numba kernels are 6-9x faster than pyproj. CUDA kernels are 40-165x faster. Unsupported CRS pairs fall back to pyproj. _transform_coords now tries Numba first, then pyproj. The CuPy chunk worker tries CUDA first, keeping coordinates on-device.
1 parent 82d4798 commit acd7a5a

File tree

3 files changed

+1612
-39
lines changed

3 files changed

+1612
-39
lines changed

xrspatial/reproject/__init__.py

Lines changed: 125 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,75 @@ def _is_y_descending(raster):
6363
return float(y[0]) > float(y[-1])
6464

6565

66+
# ---------------------------------------------------------------------------
67+
# Per-chunk coordinate transform
68+
# ---------------------------------------------------------------------------
69+
70+
def _transform_coords(transformer, chunk_bounds, chunk_shape,
71+
transform_precision, src_crs=None, tgt_crs=None):
72+
"""Compute source CRS coordinates for every output pixel.
73+
74+
When *transform_precision* is 0, every pixel is transformed through
75+
pyproj exactly (same strategy as GDAL/rasterio). Otherwise an
76+
approximate bilinear control-grid interpolation is used.
77+
78+
For common CRS pairs (WGS84/NAD83 <-> UTM, WGS84 <-> Web Mercator),
79+
a Numba JIT fast path bypasses pyproj entirely for ~30x speedup.
80+
81+
Returns
82+
-------
83+
src_y, src_x : ndarray (height, width)
84+
"""
85+
# Try Numba fast path for common projections
86+
if src_crs is not None and tgt_crs is not None:
87+
try:
88+
from ._projections import try_numba_transform
89+
result = try_numba_transform(
90+
src_crs, tgt_crs, chunk_bounds, chunk_shape,
91+
)
92+
if result is not None:
93+
return result
94+
except Exception:
95+
pass # fall through to pyproj
96+
97+
height, width = chunk_shape
98+
left, bottom, right, top = chunk_bounds
99+
res_x = (right - left) / width
100+
res_y = (top - bottom) / height
101+
102+
if transform_precision == 0:
103+
# Exact per-pixel transform via pyproj bulk API.
104+
# Process in row strips to keep memory bounded and improve
105+
# cache locality for large rasters.
106+
out_x_1d = left + (np.arange(width, dtype=np.float64) + 0.5) * res_x
107+
src_x_out = np.empty((height, width), dtype=np.float64)
108+
src_y_out = np.empty((height, width), dtype=np.float64)
109+
strip = 256
110+
for r0 in range(0, height, strip):
111+
r1 = min(r0 + strip, height)
112+
n_rows = r1 - r0
113+
out_y_strip = top - (np.arange(r0, r1, dtype=np.float64) + 0.5) * res_y
114+
# Broadcast to (n_rows, width) without allocating a full copy
115+
sx, sy = transformer.transform(
116+
np.tile(out_x_1d, n_rows),
117+
np.repeat(out_y_strip, width),
118+
)
119+
src_x_out[r0:r1] = np.asarray(sx, dtype=np.float64).reshape(n_rows, width)
120+
src_y_out[r0:r1] = np.asarray(sy, dtype=np.float64).reshape(n_rows, width)
121+
return src_y_out, src_x_out
122+
123+
# Approximate: bilinear interpolation on a coarse control grid.
124+
approx = ApproximateTransform(
125+
transformer, chunk_bounds, chunk_shape,
126+
precision=transform_precision,
127+
)
128+
row_grid = np.arange(height, dtype=np.float64)[:, np.newaxis]
129+
col_grid = np.arange(width, dtype=np.float64)[np.newaxis, :]
130+
row_grid = np.broadcast_to(row_grid, (height, width))
131+
col_grid = np.broadcast_to(col_grid, (height, width))
132+
return approx(row_grid, col_grid)
133+
134+
66135
# ---------------------------------------------------------------------------
67136
# Per-chunk worker functions
68137
# ---------------------------------------------------------------------------
@@ -89,20 +158,11 @@ def _reproject_chunk_numpy(
89158
tgt_crs, src_crs, always_xy=True
90159
)
91160

92-
height, width = chunk_shape
93-
approx = ApproximateTransform(
94-
transformer, chunk_bounds_tuple, chunk_shape,
95-
precision=transform_precision,
96-
)
97-
98-
# All output pixel positions (broadcast 1-D arrays to avoid HxW meshgrid)
99-
row_grid = np.arange(height, dtype=np.float64)[:, np.newaxis]
100-
col_grid = np.arange(width, dtype=np.float64)[np.newaxis, :]
101-
row_grid = np.broadcast_to(row_grid, (height, width))
102-
col_grid = np.broadcast_to(col_grid, (height, width))
103-
104161
# Source CRS coordinates for each output pixel
105-
src_y, src_x = approx(row_grid, col_grid)
162+
src_y, src_x = _transform_coords(
163+
transformer, chunk_bounds_tuple, chunk_shape, transform_precision,
164+
src_crs=src_crs, tgt_crs=tgt_crs,
165+
)
106166

107167
# Convert source CRS coordinates to source pixel coordinates
108168
src_left, src_bottom, src_right, src_top = source_bounds_tuple
@@ -170,35 +230,59 @@ def _reproject_chunk_cupy(
170230
tgt_crs, src_crs, always_xy=True
171231
)
172232

173-
height, width = chunk_shape
174-
approx = ApproximateTransform(
175-
transformer, chunk_bounds_tuple, chunk_shape,
176-
precision=transform_precision,
177-
)
178-
179-
row_grid = np.arange(height, dtype=np.float64)[:, np.newaxis]
180-
col_grid = np.arange(width, dtype=np.float64)[np.newaxis, :]
181-
row_grid = np.broadcast_to(row_grid, (height, width))
182-
col_grid = np.broadcast_to(col_grid, (height, width))
233+
# Try CUDA transform first (keeps coordinates on-device)
234+
cuda_result = None
235+
if src_crs is not None and tgt_crs is not None:
236+
try:
237+
from ._projections_cuda import try_cuda_transform
238+
cuda_result = try_cuda_transform(
239+
src_crs, tgt_crs, chunk_bounds_tuple, chunk_shape,
240+
)
241+
except Exception:
242+
pass
183243

184-
# Control grid is on CPU
185-
src_y, src_x = approx(row_grid, col_grid)
244+
if cuda_result is not None:
245+
src_y, src_x = cuda_result # cupy arrays
246+
src_left, src_bottom, src_right, src_top = source_bounds_tuple
247+
src_h, src_w = source_shape
248+
src_res_x = (src_right - src_left) / src_w
249+
src_res_y = (src_top - src_bottom) / src_h
250+
# Pixel coordinate math stays on GPU via cupy operators
251+
src_col_px = (src_x - src_left) / src_res_x - 0.5
252+
if source_y_desc:
253+
src_row_px = (src_top - src_y) / src_res_y - 0.5
254+
else:
255+
src_row_px = (src_y - src_bottom) / src_res_y - 0.5
256+
# Need min/max on CPU for window selection
257+
r_min = int(cp.floor(cp.nanmin(src_row_px)).get()) - 2
258+
r_max = int(cp.ceil(cp.nanmax(src_row_px)).get()) + 3
259+
c_min = int(cp.floor(cp.nanmin(src_col_px)).get()) - 2
260+
c_max = int(cp.ceil(cp.nanmax(src_col_px)).get()) + 3
261+
# Convert to numpy for downstream resampling
262+
src_row_px = cp.asnumpy(src_row_px)
263+
src_col_px = cp.asnumpy(src_col_px)
264+
else:
265+
# CPU fallback (Numba JIT or pyproj)
266+
src_y, src_x = _transform_coords(
267+
transformer, chunk_bounds_tuple, chunk_shape, transform_precision,
268+
src_crs=src_crs, tgt_crs=tgt_crs,
269+
)
186270

187-
src_left, src_bottom, src_right, src_top = source_bounds_tuple
188-
src_h, src_w = source_shape
189-
src_res_x = (src_right - src_left) / src_w
190-
src_res_y = (src_top - src_bottom) / src_h
271+
src_left, src_bottom, src_right, src_top = source_bounds_tuple
272+
src_h, src_w = source_shape
273+
src_res_x = (src_right - src_left) / src_w
274+
src_res_y = (src_top - src_bottom) / src_h
191275

192-
src_col_px = (src_x - src_left) / src_res_x - 0.5
193-
if source_y_desc:
194-
src_row_px = (src_top - src_y) / src_res_y - 0.5
195-
else:
196-
src_row_px = (src_y - src_bottom) / src_res_y - 0.5
276+
src_col_px = (src_x - src_left) / src_res_x - 0.5
277+
if source_y_desc:
278+
src_row_px = (src_top - src_y) / src_res_y - 0.5
279+
else:
280+
src_row_px = (src_y - src_bottom) / src_res_y - 0.5
197281

198-
r_min = int(np.floor(np.nanmin(src_row_px))) - 2
199-
r_max = int(np.ceil(np.nanmax(src_row_px))) + 3
200-
c_min = int(np.floor(np.nanmin(src_col_px))) - 2
201-
c_max = int(np.ceil(np.nanmax(src_col_px))) + 3
282+
r_min = int(np.floor(np.nanmin(src_row_px))) - 2
283+
r_max = int(np.ceil(np.nanmax(src_row_px))) + 3
284+
c_min = int(np.floor(np.nanmin(src_col_px))) - 2
285+
c_max = int(np.ceil(np.nanmax(src_col_px))) + 3
202286

203287
if r_min >= src_h or r_max <= 0 or c_min >= src_w or c_max <= 0:
204288
return cp.full(chunk_shape, nodata, dtype=cp.float64)
@@ -271,7 +355,9 @@ def reproject(
271355
nodata : float or None
272356
Nodata value. Auto-detected if None.
273357
transform_precision : int
274-
Coarse grid subdivisions for approximate transform (default 16).
358+
Control-grid subdivisions for the coordinate transform (default 16).
359+
Higher values increase accuracy at the cost of more pyproj calls.
360+
Set to 0 for exact per-pixel transforms matching GDAL/rasterio.
275361
chunk_size : int or (int, int) or None
276362
Output chunk size for dask. Defaults to 512.
277363
name : str or None

0 commit comments

Comments
 (0)