@@ -63,6 +63,75 @@ def _is_y_descending(raster):
6363 return float (y [0 ]) > float (y [- 1 ])
6464
6565
66+ # ---------------------------------------------------------------------------
67+ # Per-chunk coordinate transform
68+ # ---------------------------------------------------------------------------
69+
70+ def _transform_coords (transformer , chunk_bounds , chunk_shape ,
71+ transform_precision , src_crs = None , tgt_crs = None ):
72+ """Compute source CRS coordinates for every output pixel.
73+
74+ When *transform_precision* is 0, every pixel is transformed through
75+ pyproj exactly (same strategy as GDAL/rasterio). Otherwise an
76+ approximate bilinear control-grid interpolation is used.
77+
78+ For common CRS pairs (WGS84/NAD83 <-> UTM, WGS84 <-> Web Mercator),
79+ a Numba JIT fast path bypasses pyproj entirely for ~30x speedup.
80+
81+ Returns
82+ -------
83+ src_y, src_x : ndarray (height, width)
84+ """
85+ # Try Numba fast path for common projections
86+ if src_crs is not None and tgt_crs is not None :
87+ try :
88+ from ._projections import try_numba_transform
89+ result = try_numba_transform (
90+ src_crs , tgt_crs , chunk_bounds , chunk_shape ,
91+ )
92+ if result is not None :
93+ return result
94+ except Exception :
95+ pass # fall through to pyproj
96+
97+ height , width = chunk_shape
98+ left , bottom , right , top = chunk_bounds
99+ res_x = (right - left ) / width
100+ res_y = (top - bottom ) / height
101+
102+ if transform_precision == 0 :
103+ # Exact per-pixel transform via pyproj bulk API.
104+ # Process in row strips to keep memory bounded and improve
105+ # cache locality for large rasters.
106+ out_x_1d = left + (np .arange (width , dtype = np .float64 ) + 0.5 ) * res_x
107+ src_x_out = np .empty ((height , width ), dtype = np .float64 )
108+ src_y_out = np .empty ((height , width ), dtype = np .float64 )
109+ strip = 256
110+ for r0 in range (0 , height , strip ):
111+ r1 = min (r0 + strip , height )
112+ n_rows = r1 - r0
113+ out_y_strip = top - (np .arange (r0 , r1 , dtype = np .float64 ) + 0.5 ) * res_y
114+ # Broadcast to (n_rows, width) without allocating a full copy
115+ sx , sy = transformer .transform (
116+ np .tile (out_x_1d , n_rows ),
117+ np .repeat (out_y_strip , width ),
118+ )
119+ src_x_out [r0 :r1 ] = np .asarray (sx , dtype = np .float64 ).reshape (n_rows , width )
120+ src_y_out [r0 :r1 ] = np .asarray (sy , dtype = np .float64 ).reshape (n_rows , width )
121+ return src_y_out , src_x_out
122+
123+ # Approximate: bilinear interpolation on a coarse control grid.
124+ approx = ApproximateTransform (
125+ transformer , chunk_bounds , chunk_shape ,
126+ precision = transform_precision ,
127+ )
128+ row_grid = np .arange (height , dtype = np .float64 )[:, np .newaxis ]
129+ col_grid = np .arange (width , dtype = np .float64 )[np .newaxis , :]
130+ row_grid = np .broadcast_to (row_grid , (height , width ))
131+ col_grid = np .broadcast_to (col_grid , (height , width ))
132+ return approx (row_grid , col_grid )
133+
134+
66135# ---------------------------------------------------------------------------
67136# Per-chunk worker functions
68137# ---------------------------------------------------------------------------
@@ -89,20 +158,11 @@ def _reproject_chunk_numpy(
89158 tgt_crs , src_crs , always_xy = True
90159 )
91160
92- height , width = chunk_shape
93- approx = ApproximateTransform (
94- transformer , chunk_bounds_tuple , chunk_shape ,
95- precision = transform_precision ,
96- )
97-
98- # All output pixel positions (broadcast 1-D arrays to avoid HxW meshgrid)
99- row_grid = np .arange (height , dtype = np .float64 )[:, np .newaxis ]
100- col_grid = np .arange (width , dtype = np .float64 )[np .newaxis , :]
101- row_grid = np .broadcast_to (row_grid , (height , width ))
102- col_grid = np .broadcast_to (col_grid , (height , width ))
103-
104161 # Source CRS coordinates for each output pixel
105- src_y , src_x = approx (row_grid , col_grid )
162+ src_y , src_x = _transform_coords (
163+ transformer , chunk_bounds_tuple , chunk_shape , transform_precision ,
164+ src_crs = src_crs , tgt_crs = tgt_crs ,
165+ )
106166
107167 # Convert source CRS coordinates to source pixel coordinates
108168 src_left , src_bottom , src_right , src_top = source_bounds_tuple
@@ -170,35 +230,59 @@ def _reproject_chunk_cupy(
170230 tgt_crs , src_crs , always_xy = True
171231 )
172232
173- height , width = chunk_shape
174- approx = ApproximateTransform (
175- transformer , chunk_bounds_tuple , chunk_shape ,
176- precision = transform_precision ,
177- )
178-
179- row_grid = np . arange ( height , dtype = np . float64 )[:, np . newaxis ]
180- col_grid = np . arange ( width , dtype = np . float64 )[ np . newaxis , :]
181- row_grid = np . broadcast_to ( row_grid , ( height , width ))
182- col_grid = np . broadcast_to ( col_grid , ( height , width ))
233+ # Try CUDA transform first (keeps coordinates on-device)
234+ cuda_result = None
235+ if src_crs is not None and tgt_crs is not None :
236+ try :
237+ from . _projections_cuda import try_cuda_transform
238+ cuda_result = try_cuda_transform (
239+ src_crs , tgt_crs , chunk_bounds_tuple , chunk_shape ,
240+ )
241+ except Exception :
242+ pass
183243
184- # Control grid is on CPU
185- src_y , src_x = approx (row_grid , col_grid )
244+ if cuda_result is not None :
245+ src_y , src_x = cuda_result # cupy arrays
246+ src_left , src_bottom , src_right , src_top = source_bounds_tuple
247+ src_h , src_w = source_shape
248+ src_res_x = (src_right - src_left ) / src_w
249+ src_res_y = (src_top - src_bottom ) / src_h
250+ # Pixel coordinate math stays on GPU via cupy operators
251+ src_col_px = (src_x - src_left ) / src_res_x - 0.5
252+ if source_y_desc :
253+ src_row_px = (src_top - src_y ) / src_res_y - 0.5
254+ else :
255+ src_row_px = (src_y - src_bottom ) / src_res_y - 0.5
256+ # Need min/max on CPU for window selection
257+ r_min = int (cp .floor (cp .nanmin (src_row_px )).get ()) - 2
258+ r_max = int (cp .ceil (cp .nanmax (src_row_px )).get ()) + 3
259+ c_min = int (cp .floor (cp .nanmin (src_col_px )).get ()) - 2
260+ c_max = int (cp .ceil (cp .nanmax (src_col_px )).get ()) + 3
261+ # Convert to numpy for downstream resampling
262+ src_row_px = cp .asnumpy (src_row_px )
263+ src_col_px = cp .asnumpy (src_col_px )
264+ else :
265+ # CPU fallback (Numba JIT or pyproj)
266+ src_y , src_x = _transform_coords (
267+ transformer , chunk_bounds_tuple , chunk_shape , transform_precision ,
268+ src_crs = src_crs , tgt_crs = tgt_crs ,
269+ )
186270
187- src_left , src_bottom , src_right , src_top = source_bounds_tuple
188- src_h , src_w = source_shape
189- src_res_x = (src_right - src_left ) / src_w
190- src_res_y = (src_top - src_bottom ) / src_h
271+ src_left , src_bottom , src_right , src_top = source_bounds_tuple
272+ src_h , src_w = source_shape
273+ src_res_x = (src_right - src_left ) / src_w
274+ src_res_y = (src_top - src_bottom ) / src_h
191275
192- src_col_px = (src_x - src_left ) / src_res_x - 0.5
193- if source_y_desc :
194- src_row_px = (src_top - src_y ) / src_res_y - 0.5
195- else :
196- src_row_px = (src_y - src_bottom ) / src_res_y - 0.5
276+ src_col_px = (src_x - src_left ) / src_res_x - 0.5
277+ if source_y_desc :
278+ src_row_px = (src_top - src_y ) / src_res_y - 0.5
279+ else :
280+ src_row_px = (src_y - src_bottom ) / src_res_y - 0.5
197281
198- r_min = int (np .floor (np .nanmin (src_row_px ))) - 2
199- r_max = int (np .ceil (np .nanmax (src_row_px ))) + 3
200- c_min = int (np .floor (np .nanmin (src_col_px ))) - 2
201- c_max = int (np .ceil (np .nanmax (src_col_px ))) + 3
282+ r_min = int (np .floor (np .nanmin (src_row_px ))) - 2
283+ r_max = int (np .ceil (np .nanmax (src_row_px ))) + 3
284+ c_min = int (np .floor (np .nanmin (src_col_px ))) - 2
285+ c_max = int (np .ceil (np .nanmax (src_col_px ))) + 3
202286
203287 if r_min >= src_h or r_max <= 0 or c_min >= src_w or c_max <= 0 :
204288 return cp .full (chunk_shape , nodata , dtype = cp .float64 )
@@ -271,7 +355,9 @@ def reproject(
271355 nodata : float or None
272356 Nodata value. Auto-detected if None.
273357 transform_precision : int
274- Coarse grid subdivisions for approximate transform (default 16).
358+ Control-grid subdivisions for the coordinate transform (default 16).
359+ Higher values increase accuracy at the cost of more pyproj calls.
360+ Set to 0 for exact per-pixel transforms matching GDAL/rasterio.
275361 chunk_size : int or (int, int) or None
276362 Output chunk size for dask. Defaults to 512.
277363 name : str or None
0 commit comments