Skip to content

Commit 7a4f70f

Browse files
committed
Use Numba scatter-point transform for grid boundary estimation (#1057)
1 parent 8b92b4f commit 7a4f70f

File tree

3 files changed

+88
-17
lines changed

3 files changed

+88
-17
lines changed

xrspatial/reproject/_grid.py

Lines changed: 46 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,38 @@
44
import numpy as np
55

66

7+
def _transform_boundary(source_crs, target_crs, xs, ys):
8+
"""Transform coordinate arrays, preferring Numba fast path over pyproj.
9+
10+
Parameters
11+
----------
12+
source_crs, target_crs : CRS-like
13+
Source and target coordinate reference systems.
14+
xs, ys : ndarray
15+
1-D arrays of x and y coordinates in *source_crs*.
16+
17+
Returns
18+
-------
19+
tx, ty : ndarray
20+
Transformed coordinates as numpy arrays.
21+
"""
22+
from ._projections import transform_points
23+
24+
result = transform_points(source_crs, target_crs, xs, ys)
25+
if result is not None:
26+
return result
27+
28+
# Fall back to pyproj
29+
from ._crs_utils import _require_pyproj
30+
31+
pyproj = _require_pyproj()
32+
transformer = pyproj.Transformer.from_crs(
33+
source_crs, target_crs, always_xy=True
34+
)
35+
tx, ty = transformer.transform(xs, ys)
36+
return np.asarray(tx), np.asarray(ty)
37+
38+
739
def _compute_output_grid(source_bounds, source_shape, source_crs, target_crs,
840
resolution=None, bounds=None, width=None, height=None):
941
"""Compute the output raster grid parameters.
@@ -14,7 +46,7 @@ def _compute_output_grid(source_bounds, source_shape, source_crs, target_crs,
1446
(left, bottom, right, top) in source CRS.
1547
source_shape : tuple
1648
(height, width) of source raster.
17-
source_crs, target_crs : pyproj.CRS
49+
source_crs, target_crs : CRS-like
1850
Source and target coordinate reference systems.
1951
resolution : float or tuple or None
2052
Target resolution. If tuple, (x_res, y_res).
@@ -27,13 +59,6 @@ def _compute_output_grid(source_bounds, source_shape, source_crs, target_crs,
2759
-------
2860
dict with keys: bounds, shape, res_x, res_y
2961
"""
30-
from ._crs_utils import _require_pyproj
31-
32-
pyproj = _require_pyproj()
33-
transformer = pyproj.Transformer.from_crs(
34-
source_crs, target_crs, always_xy=True
35-
)
36-
3762
if bounds is None:
3863
# Transform source corners and edges to target CRS
3964
src_left, src_bottom, src_right, src_top = source_bounds
@@ -76,7 +101,7 @@ def _compute_output_grid(source_bounds, source_shape, source_crs, target_crs,
76101
ixx, iyy = np.meshgrid(ix, iy)
77102
xs = np.concatenate([edge_xs, ixx.ravel()])
78103
ys = np.concatenate([edge_ys, iyy.ravel()])
79-
tx, ty = transformer.transform(xs, ys)
104+
tx, ty = _transform_boundary(source_crs, target_crs, xs, ys)
80105
tx = np.asarray(tx)
81106
ty = np.asarray(ty)
82107
# Filter out inf/nan from failed transforms
@@ -110,7 +135,9 @@ def _compute_output_grid(source_bounds, source_shape, source_crs, target_crs,
110135
ix = np.linspace(src_left, src_right, n_dense)
111136
iy = np.linspace(src_bottom, src_top, n_dense)
112137
ixx, iyy = np.meshgrid(ix, iy)
113-
itx, ity = transformer.transform(ixx.ravel(), iyy.ravel())
138+
itx, ity = _transform_boundary(
139+
source_crs, target_crs, ixx.ravel(), iyy.ravel()
140+
)
114141
itx = np.asarray(itx)
115142
ity = np.asarray(ity)
116143
ivalid = np.isfinite(itx) & np.isfinite(ity)
@@ -150,13 +177,15 @@ def _compute_output_grid(source_bounds, source_shape, source_crs, target_crs,
150177
src_res_y = (src_top - src_bottom) / src_h
151178
center_x = (src_left + src_right) / 2
152179
center_y = (src_bottom + src_top) / 2
153-
tc_x, tc_y = transformer.transform(center_x, center_y)
154-
# Step along x only
155-
tx_x, tx_y = transformer.transform(center_x + src_res_x, center_y)
156-
dx = np.hypot(float(tx_x) - float(tc_x), float(tx_y) - float(tc_y))
157-
# Step along y only
158-
ty_x, ty_y = transformer.transform(center_x, center_y + src_res_y)
159-
dy = np.hypot(float(ty_x) - float(tc_x), float(ty_y) - float(tc_y))
180+
# Batch the three resolution-estimation points into one call
181+
pts_x = np.array([center_x, center_x + src_res_x, center_x])
182+
pts_y = np.array([center_y, center_y, center_y + src_res_y])
183+
tp_x, tp_y = _transform_boundary(source_crs, target_crs, pts_x, pts_y)
184+
tc_x, tc_y = float(tp_x[0]), float(tp_y[0])
185+
tx_x, tx_y = float(tp_x[1]), float(tp_y[1])
186+
ty_x, ty_y = float(tp_x[2]), float(tp_y[2])
187+
dx = np.hypot(tx_x - tc_x, tx_y - tc_y)
188+
dy = np.hypot(ty_x - tc_x, ty_y - tc_y)
160189
if dx == 0 or dy == 0:
161190
res_x = (right - left) / src_w
162191
res_y = (top - bottom) / src_h

xrspatial/reproject/_lite_crs.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,24 @@ def to_wkt(self) -> str:
357357
@staticmethod
358358
def _wkt_parameters(entry: dict) -> list[str]:
359359
"""Build WKT PARAMETER[] entries from a proj dict."""
360+
# For UTM entries, expand zone into explicit TM parameters so that
361+
# parsers (including pyproj) get the correct central meridian and
362+
# scale factor rather than defaulting to 0 / 1.
363+
if entry.get("proj") == "utm" and "zone" in entry:
364+
zone = entry["zone"]
365+
lon_0 = zone * 6 - 183
366+
k_0 = 0.9996
367+
lat_0 = 0
368+
x_0 = entry.get("x_0", 500000)
369+
y_0 = entry.get("y_0", 0)
370+
return [
371+
f'PARAMETER["latitude_of_origin",{lat_0}]',
372+
f'PARAMETER["central_meridian",{lon_0}]',
373+
f'PARAMETER["scale_factor",{k_0}]',
374+
f'PARAMETER["false_easting",{x_0}]',
375+
f'PARAMETER["false_northing",{y_0}]',
376+
]
377+
360378
# Map from proj keys to WKT parameter names
361379
key_map = {
362380
"lat_0": "latitude_of_origin",

xrspatial/tests/test_lite_crs.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,30 @@ def test_crs_from_wkt_lite(self):
214214
assert result.to_epsg() == 4326
215215

216216

217+
# -----------------------------------------------------------------------
218+
# Grid computation with lite CRS (no pyproj needed)
219+
# -----------------------------------------------------------------------
220+
class TestGridWithoutPyproj:
221+
def test_compute_output_grid_with_lite_crs(self):
222+
from xrspatial.reproject._grid import _compute_output_grid
223+
from xrspatial.reproject._lite_crs import CRS
224+
225+
src_crs = CRS(4326)
226+
tgt_crs = CRS(32632)
227+
source_bounds = (6.0, 47.0, 12.0, 55.0)
228+
source_shape = (64, 64)
229+
grid = _compute_output_grid(
230+
source_bounds, source_shape, src_crs, tgt_crs
231+
)
232+
assert 'bounds' in grid
233+
assert 'shape' in grid
234+
h, w = grid['shape']
235+
assert h > 0 and w > 0
236+
left, bottom, right, top = grid['bounds']
237+
assert right > left
238+
assert top > bottom
239+
240+
217241
# -----------------------------------------------------------------------
218242
# Validate against pyproj (skipped when pyproj not installed)
219243
# -----------------------------------------------------------------------

0 commit comments

Comments
 (0)