Skip to content

Commit 2771fca

Browse files
committed
Wire chunk functions and merge helper to use lite CRS (#1057)
1 parent 7a4f70f commit 2771fca

File tree

2 files changed

+69
-41
lines changed

2 files changed

+69
-41
lines changed

xrspatial/reproject/__init__.py

Lines changed: 50 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -192,11 +192,10 @@ def _reproject_chunk_numpy(
192192
Called inside ``dask.delayed`` for the dask path, or directly for numpy.
193193
CRS objects are passed as WKT strings for pickle safety.
194194
"""
195-
from ._crs_utils import _require_pyproj
195+
from ._crs_utils import _crs_from_wkt
196196

197-
pyproj = _require_pyproj()
198-
src_crs = pyproj.CRS.from_wkt(src_wkt)
199-
tgt_crs = pyproj.CRS.from_wkt(tgt_wkt)
197+
src_crs = _crs_from_wkt(src_wkt)
198+
tgt_crs = _crs_from_wkt(tgt_wkt)
200199

201200
# Try Numba fast path first (avoids creating pyproj Transformer)
202201
numba_result = None
@@ -212,6 +211,8 @@ def _reproject_chunk_numpy(
212211
src_y, src_x = numba_result
213212
else:
214213
# Fallback: create pyproj Transformer (expensive)
214+
from ._crs_utils import _require_pyproj
215+
pyproj = _require_pyproj()
215216
transformer = pyproj.Transformer.from_crs(
216217
tgt_crs, src_crs, always_xy=True
217218
)
@@ -321,15 +322,10 @@ def _reproject_chunk_cupy(
321322
"""CuPy variant of ``_reproject_chunk_numpy``."""
322323
import cupy as cp
323324

324-
from ._crs_utils import _require_pyproj
325+
from ._crs_utils import _crs_from_wkt
325326

326-
pyproj = _require_pyproj()
327-
src_crs = pyproj.CRS.from_wkt(src_wkt)
328-
tgt_crs = pyproj.CRS.from_wkt(tgt_wkt)
329-
330-
transformer = pyproj.Transformer.from_crs(
331-
tgt_crs, src_crs, always_xy=True
332-
)
327+
src_crs = _crs_from_wkt(src_wkt)
328+
tgt_crs = _crs_from_wkt(tgt_wkt)
333329

334330
# Try CUDA transform first (keeps coordinates on-device)
335331
cuda_result = None
@@ -371,6 +367,11 @@ def _reproject_chunk_cupy(
371367
_use_native_cuda = True
372368
else:
373369
# CPU fallback (Numba JIT or pyproj)
370+
from ._crs_utils import _require_pyproj
371+
pyproj = _require_pyproj()
372+
transformer = pyproj.Transformer.from_crs(
373+
tgt_crs, src_crs, always_xy=True
374+
)
374375
src_y, src_x = _transform_coords(
375376
transformer, chunk_bounds_tuple, chunk_shape, transform_precision,
376377
src_crs=src_crs, tgt_crs=tgt_crs,
@@ -513,16 +514,13 @@ def reproject(
513514
If vertical transformation was applied, ``attrs['vertical_crs']``
514515
records the target vertical datum.
515516
"""
516-
from ._crs_utils import _require_pyproj
517-
518517
if not isinstance(raster, xr.DataArray):
519518
raise TypeError(
520519
f"reproject(): raster must be an xr.DataArray, "
521520
f"got {type(raster).__name__}"
522521
)
523522

524523
_validate_resampling(resampling)
525-
_require_pyproj()
526524

527525
# Resolve CRS
528526
src_crs = _resolve_crs(source_crs)
@@ -984,11 +982,10 @@ def _reproject_dask_cupy(
984982
"""
985983
import cupy as cp
986984

987-
from ._crs_utils import _require_pyproj
985+
from ._crs_utils import _crs_from_wkt
988986

989-
pyproj = _require_pyproj()
990-
src_crs = pyproj.CRS.from_wkt(src_wkt)
991-
tgt_crs = pyproj.CRS.from_wkt(tgt_wkt)
987+
src_crs = _crs_from_wkt(src_wkt)
988+
tgt_crs = _crs_from_wkt(tgt_wkt)
992989

993990
# Use larger chunks for GPU to amortize kernel launch overhead
994991
gpu_chunk = chunk_size or 2048
@@ -1048,6 +1045,8 @@ def _reproject_dask_cupy(
10481045
c_max = int(np.ceil(c_max_val)) + 3
10491046
else:
10501047
# CPU fallback for this chunk
1048+
from ._crs_utils import _require_pyproj
1049+
pyproj = _require_pyproj()
10511050
transformer = pyproj.Transformer.from_crs(
10521051
tgt_crs, src_crs, always_xy=True
10531052
)
@@ -1120,30 +1119,44 @@ def _reproject_dask_cupy(
11201119

11211120

11221121
def _source_footprint_in_target(src_bounds, src_wkt, tgt_wkt):
1123-
"""Compute an approximate bounding box of the source raster in target CRS.
1124-
1125-
Transforms corners and edge midpoints (12 points) to handle non-linear
1126-
projections. Returns ``(left, bottom, right, top)`` in target CRS, or
1127-
*None* if the transform fails (e.g. out-of-domain).
1128-
"""
1122+
"""Compute approximate bounding box of source raster in target CRS."""
11291123
try:
1130-
from ._crs_utils import _require_pyproj
1131-
pyproj = _require_pyproj()
1132-
src_crs = pyproj.CRS(src_wkt)
1133-
tgt_crs = pyproj.CRS(tgt_wkt)
1134-
transformer = pyproj.Transformer.from_crs(
1135-
src_crs, tgt_crs, always_xy=True
1136-
)
1124+
from ._crs_utils import _crs_from_wkt, _resolve_crs
1125+
try:
1126+
src_crs = _crs_from_wkt(src_wkt)
1127+
except Exception:
1128+
src_crs = _resolve_crs(src_wkt)
1129+
try:
1130+
tgt_crs = _crs_from_wkt(tgt_wkt)
1131+
except Exception:
1132+
tgt_crs = _resolve_crs(tgt_wkt)
11371133
except Exception:
11381134
return None
11391135

11401136
sl, sb, sr, st = src_bounds
11411137
mx = (sl + sr) / 2
11421138
my = (sb + st) / 2
1143-
xs = [sl, mx, sr, sl, mx, sr, sl, mx, sr, sl, sr, mx]
1144-
ys = [sb, sb, sb, my, my, my, st, st, st, mx, mx, sb]
1139+
xs = np.array([sl, mx, sr, sl, mx, sr, sl, mx, sr, sl, sr, mx])
1140+
ys = np.array([sb, sb, sb, my, my, my, st, st, st, mx, mx, sb])
1141+
11451142
try:
1146-
tx, ty = transformer.transform(xs, ys)
1143+
from ._projections import transform_points
1144+
result = transform_points(src_crs, tgt_crs, xs, ys)
1145+
if result is not None:
1146+
tx, ty = result
1147+
tx = [v for v in tx if np.isfinite(v)]
1148+
ty = [v for v in ty if np.isfinite(v)]
1149+
if not tx or not ty:
1150+
return None
1151+
return (min(tx), min(ty), max(tx), max(ty))
1152+
except (ImportError, ModuleNotFoundError):
1153+
pass
1154+
1155+
try:
1156+
from ._crs_utils import _require_pyproj
1157+
pyproj = _require_pyproj()
1158+
transformer = pyproj.Transformer.from_crs(src_crs, tgt_crs, always_xy=True)
1159+
tx, ty = transformer.transform(xs.tolist(), ys.tolist())
11471160
tx = [v for v in tx if np.isfinite(v)]
11481161
ty = [v for v in ty if np.isfinite(v)]
11491162
if not tx or not ty:
@@ -1298,14 +1311,11 @@ def merge(
12981311
-------
12991312
xr.DataArray
13001313
"""
1301-
from ._crs_utils import _require_pyproj
1302-
13031314
if not rasters:
13041315
raise ValueError("merge(): rasters list must not be empty")
13051316

13061317
_validate_resampling(resampling)
13071318
_validate_strategy(strategy)
1308-
pyproj = _require_pyproj()
13091319

13101320
# Resolve target CRS
13111321
tgt_crs = _resolve_crs(target_crs)
@@ -1485,9 +1495,8 @@ def _merge_inmemory(
14851495
Detects same-CRS tiles and uses fast direct placement instead
14861496
of reprojection.
14871497
"""
1488-
from ._crs_utils import _require_pyproj
1489-
pyproj = _require_pyproj()
1490-
tgt_crs = pyproj.CRS.from_wkt(tgt_wkt)
1498+
from ._crs_utils import _crs_from_wkt
1499+
tgt_crs = _crs_from_wkt(tgt_wkt)
14911500

14921501
arrays = []
14931502
for info in raster_infos:

xrspatial/tests/test_reproject.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1313,3 +1313,22 @@ def test_bounds_overlap(self):
13131313
assert _bounds_overlap(a, (0, 0, 10, 10)) # identical
13141314
assert not _bounds_overlap(a, (11, 0, 20, 10)) # no overlap x
13151315
assert not _bounds_overlap(a, (0, 11, 10, 20)) # no overlap y
1316+
1317+
1318+
class TestReprojWithLiteCRS:
1319+
def test_reproject_wgs84_to_utm_with_lite_crs(self):
1320+
import xarray as xr
1321+
from xrspatial.reproject import reproject
1322+
import numpy as np
1323+
h, w = 32, 32
1324+
y = np.linspace(49, 47, h)
1325+
x = np.linspace(8, 10, w)
1326+
data = np.random.default_rng(42).random((h, w))
1327+
raster = xr.DataArray(
1328+
data, dims=['y', 'x'],
1329+
coords={'y': y, 'x': x},
1330+
attrs={'crs': 4326},
1331+
)
1332+
result = reproject(raster, target_crs=32632)
1333+
assert result.attrs['crs'] is not None
1334+
assert result.shape[0] > 0 and result.shape[1] > 0

0 commit comments

Comments
 (0)