Skip to content

Commit 3e918c9

Browse files
committed
Fix cubic NaN, add merge tests, validate grids, improve docs (#1045)
Cubic NaN handling: - When any of the 16 Catmull-Rom neighbors is NaN, falls back to bilinear with weight renormalization instead of returning nodata. Eliminates the one-pixel nodata halo around NaN regions that cubic resampling previously produced. Merge strategy tests: - Added end-to-end tests for last, max, min strategies (were only tested at the internal _merge_arrays_numpy level). Datum grid validation: - load_grid() now validates band shapes match, grid is >= 2x2, and bounds are sensible. Invalid grids return None (Helmert fallback) instead of producing garbage. Documentation: - reproject() and merge() docstrings now note output CRS is WKT format in attrs['crs'], and merge() documents CRS selection when target_crs=None.
1 parent e7e49f4 commit 3e918c9

4 files changed

Lines changed: 116 additions & 4 deletions

File tree

xrspatial/reproject/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,7 @@ def reproject(
429429
Returns
430430
-------
431431
xr.DataArray
432+
The output's ``attrs['crs']`` is always stored in WKT format.
432433
"""
433434
from ._crs_utils import _require_pyproj
434435

@@ -805,7 +806,8 @@ def merge(
805806
rasters : list of xr.DataArray
806807
Input rasters to merge.
807808
target_crs : optional
808-
Target CRS. Defaults to the CRS of the first raster.
809+
Target CRS. When None, the CRS of the first raster in the list
810+
is used.
809811
resolution : float or (float, float) or None
810812
Output resolution in target CRS units.
811813
bounds : (left, bottom, right, top) or None
@@ -824,6 +826,7 @@ def merge(
824826
Returns
825827
-------
826828
xr.DataArray
829+
The output's ``attrs['crs']`` is always stored in WKT format.
827830
"""
828831
from ._crs_utils import _require_pyproj
829832

xrspatial/reproject/_datum_grids.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,13 @@ def load_grid(grid_key):
169169
b = ds.bounds
170170
bounds = (b.left, b.bottom, b.right, b.top)
171171
h, w = ds.height, ds.width
172+
# Validate grid shape and bounds
173+
if dlat.shape != dlon.shape:
174+
return None
175+
if h < 2 or w < 2:
176+
return None
177+
if b.left >= b.right or b.bottom >= b.top:
178+
return None
172179
# Compute resolution from bounds and shape (avoids ds.res ordering ambiguity)
173180
res_x = (b.right - b.left) / w if w > 1 else 0.25
174181
res_y = (b.top - b.bottom) / h if h > 1 else 0.25
@@ -191,10 +198,19 @@ def load_grid(grid_key):
191198
else:
192199
return None
193200

201+
# Validate grid shape and bounds
202+
if dlat.shape != dlon.shape:
203+
return None
204+
if dlat.shape[0] < 2 or dlat.shape[1] < 2:
205+
return None
206+
194207
y_coords = da.coords['y'].values
195208
x_coords = da.coords['x'].values
196209
bounds = (float(x_coords[0]), float(y_coords[-1]),
197210
float(x_coords[-1]), float(y_coords[0]))
211+
left, bottom, right, top = bounds
212+
if left >= right or bottom >= top:
213+
return None
198214
res_x = abs(float(x_coords[1] - x_coords[0])) if len(x_coords) > 1 else 0.25
199215
res_y = abs(float(y_coords[1] - y_coords[0])) if len(y_coords) > 1 else 0.25
200216
return dlat, dlon, bounds, (res_x, res_y)

xrspatial/reproject/_interpolate.py

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,11 @@ def _resample_nearest_jit(src, row_coords, col_coords, nodata):
6767

6868
@njit(nogil=True, cache=True)
6969
def _resample_cubic_jit(src, row_coords, col_coords, nodata):
70-
"""Catmull-Rom cubic resampling with NaN propagation.
70+
"""Catmull-Rom cubic resampling with NaN-aware fallback to bilinear.
7171
7272
Separable: interpolate 4 row-slices along columns, then combine
73-
along rows. Handles NaN inline (no second pass needed).
73+
along rows. When any of the 16 neighbors is NaN, falls back to
74+
bilinear with weight renormalization (matching GDAL behavior).
7475
"""
7576
h_out, w_out = row_coords.shape
7677
sh, sw = src.shape
@@ -145,7 +146,51 @@ def _resample_cubic_jit(src, row_coords, col_coords, nodata):
145146
else:
146147
val += rv * wr3
147148

148-
out[i, j] = nodata if has_nan else val
149+
if not has_nan:
150+
out[i, j] = val
151+
else:
152+
# Fall back to bilinear with weight renormalization
153+
r1 = r0 + 1
154+
c1 = c0 + 1
155+
dr = r - r0
156+
dc = c - c0
157+
158+
w00 = (1.0 - dr) * (1.0 - dc)
159+
w01 = (1.0 - dr) * dc
160+
w10 = dr * (1.0 - dc)
161+
w11 = dr * dc
162+
163+
accum = 0.0
164+
wsum = 0.0
165+
166+
if 0 <= r0 < sh and 0 <= c0 < sw:
167+
v = src[r0, c0]
168+
if v == v:
169+
accum += w00 * v
170+
wsum += w00
171+
172+
if 0 <= r0 < sh and 0 <= c1 < sw:
173+
v = src[r0, c1]
174+
if v == v:
175+
accum += w01 * v
176+
wsum += w01
177+
178+
if 0 <= r1 < sh and 0 <= c0 < sw:
179+
v = src[r1, c0]
180+
if v == v:
181+
accum += w10 * v
182+
wsum += w10
183+
184+
if 0 <= r1 < sh and 0 <= c1 < sw:
185+
v = src[r1, c1]
186+
if v == v:
187+
accum += w11 * v
188+
wsum += w11
189+
190+
if wsum > 1e-10:
191+
out[i, j] = accum / wsum
192+
else:
193+
out[i, j] = nodata
149194
return out
150195

151196

xrspatial/tests/test_reproject.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -577,6 +577,54 @@ def test_merge_invalid_strategy(self):
577577
with pytest.raises(ValueError, match="strategy"):
578578
merge([raster], strategy='median')
579579

580+
def test_merge_strategy_last(self):
581+
"""merge() with strategy='last' uses the last valid value."""
582+
from xrspatial.reproject import merge
583+
a = _make_raster(
584+
np.full((16, 16), 10.0), x_range=(-5, 5), y_range=(-5, 5)
585+
)
586+
b = _make_raster(
587+
np.full((16, 16), 20.0), x_range=(-5, 5), y_range=(-5, 5)
588+
)
589+
result = merge([a, b], strategy='last', resolution=1.0)
590+
vals = result.values
591+
interior = vals[2:-2, 2:-2]
592+
valid = ~np.isnan(interior) & (interior != 0)
593+
if valid.any():
594+
np.testing.assert_allclose(interior[valid], 20.0, atol=1.0)
595+
596+
def test_merge_strategy_max(self):
597+
"""merge() with strategy='max' takes the maximum."""
598+
from xrspatial.reproject import merge
599+
a = _make_raster(
600+
np.full((16, 16), 10.0), x_range=(-5, 5), y_range=(-5, 5)
601+
)
602+
b = _make_raster(
603+
np.full((16, 16), 20.0), x_range=(-5, 5), y_range=(-5, 5)
604+
)
605+
result = merge([a, b], strategy='max', resolution=1.0)
606+
vals = result.values
607+
interior = vals[2:-2, 2:-2]
608+
valid = ~np.isnan(interior) & (interior != 0)
609+
if valid.any():
610+
np.testing.assert_allclose(interior[valid], 20.0, atol=1.0)
611+
612+
def test_merge_strategy_min(self):
613+
"""merge() with strategy='min' takes the minimum."""
614+
from xrspatial.reproject import merge
615+
a = _make_raster(
616+
np.full((16, 16), 10.0), x_range=(-5, 5), y_range=(-5, 5)
617+
)
618+
b = _make_raster(
619+
np.full((16, 16), 20.0), x_range=(-5, 5), y_range=(-5, 5)
620+
)
621+
result = merge([a, b], strategy='min', resolution=1.0)
622+
vals = result.values
623+
interior = vals[2:-2, 2:-2]
624+
valid = ~np.isnan(interior) & (interior != 0)
625+
if valid.any():
626+
np.testing.assert_allclose(interior[valid], 10.0, atol=1.0)
627+
580628
@pytest.mark.skipif(not HAS_DASK, reason="dask required")
581629
def test_merge_dask(self):
582630
from xrspatial.reproject import merge

0 commit comments

Comments
 (0)