Skip to content

Commit 0bb5437

Browse files
committed
Support multi-band (RGB/RGBA) raster reprojection (#1045)
Multi-band rasters (y, x, band) now reproject correctly: - Each band is reprojected independently using shared coordinates (coordinate transform computed once, reused for all bands) - Output preserves the band dimension name and coordinates - Works with any dtype (float32, uint8 with clamping, etc.) - Custom band dim names (e.g. 'channel') preserved Also fixed spatial dimension detection to use name-based lookup (_find_spatial_dims) instead of hardcoded dims[-2]/dims[-1], which failed for 3D rasters where the band dim was last. Previously crashed with TypingError on 3D input.
1 parent ec87f1a commit 0bb5437

File tree

1 file changed

+67
-12
lines changed

1 file changed

+67
-12
lines changed

xrspatial/reproject/__init__.py

Lines changed: 67 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,32 @@
5151
# Source geometry helpers
5252
# ---------------------------------------------------------------------------
5353

54+
_Y_NAMES = {'y', 'lat', 'latitude', 'Y', 'Lat', 'Latitude'}
55+
_X_NAMES = {'x', 'lon', 'longitude', 'X', 'Lon', 'Longitude'}
56+
57+
58+
def _find_spatial_dims(raster):
59+
"""Find the y and x dimension names, handling multi-band rasters.
60+
61+
Returns (ydim, xdim). Checks dim names first, falls back to
62+
assuming the last two non-band dims are spatial.
63+
"""
64+
dims = raster.dims
65+
ydim = xdim = None
66+
for d in dims:
67+
if d in _Y_NAMES:
68+
ydim = d
69+
elif d in _X_NAMES:
70+
xdim = d
71+
if ydim is not None and xdim is not None:
72+
return ydim, xdim
73+
# Fallback: last two dims
74+
return dims[-2], dims[-1]
75+
76+
5477
def _source_bounds(raster):
5578
"""Extract (left, bottom, right, top) from a DataArray's coordinates."""
56-
ydim = raster.dims[-2]
57-
xdim = raster.dims[-1]
79+
ydim, xdim = _find_spatial_dims(raster)
5880
y = raster.coords[ydim].values
5981
x = raster.coords[xdim].values
6082
# Compute pixel-edge bounds from pixel-center coords
@@ -77,7 +99,7 @@ def _source_bounds(raster):
7799

78100
def _is_y_descending(raster):
79101
"""Check if Y axis goes from top (large) to bottom (small)."""
80-
ydim = raster.dims[-2]
102+
ydim, _ = _find_spatial_dims(raster)
81103
y = raster.coords[ydim].values
82104
if len(y) < 2:
83105
return True
@@ -240,17 +262,36 @@ def _reproject_chunk_numpy(
240262
window = window.compute()
241263
window = np.asarray(window)
242264
orig_dtype = window.dtype
265+
266+
# Adjust coordinates relative to window
267+
local_row = src_row_px - r_min_clip
268+
local_col = src_col_px - c_min_clip
269+
270+
# Multi-band: reproject each band separately, share coordinates
271+
if window.ndim == 3:
272+
n_bands = window.shape[2]
273+
bands = []
274+
for b in range(n_bands):
275+
band_data = window[:, :, b].astype(np.float64)
276+
if not np.isnan(nodata):
277+
band_data = band_data.copy()
278+
band_data[band_data == nodata] = np.nan
279+
band_result = _resample_numpy(band_data, local_row, local_col,
280+
resampling=resampling, nodata=nodata)
281+
if np.issubdtype(orig_dtype, np.integer):
282+
info = np.iinfo(orig_dtype)
283+
band_result = np.clip(np.round(band_result), info.min, info.max).astype(orig_dtype)
284+
bands.append(band_result)
285+
return np.stack(bands, axis=-1)
286+
287+
# Single-band path
243288
window = window.astype(np.float64)
244289

245290
# Convert sentinel nodata to NaN so numba kernels can detect it
246291
if not np.isnan(nodata):
247292
window = window.copy()
248293
window[window == nodata] = np.nan
249294

250-
# Adjust coordinates relative to window
251-
local_row = src_row_px - r_min_clip
252-
local_col = src_col_px - c_min_clip
253-
254295
result = _resample_numpy(window, local_row, local_col,
255296
resampling=resampling, nodata=nodata)
256297

@@ -482,7 +523,8 @@ def reproject(
482523

483524
# Source geometry
484525
src_bounds = _source_bounds(raster)
485-
src_shape = (raster.sizes[raster.dims[-2]], raster.sizes[raster.dims[-1]])
526+
_ydim, _xdim = _find_spatial_dims(raster)
527+
src_shape = (raster.sizes[_ydim], raster.sizes[_xdim])
486528
y_desc = _is_y_descending(raster)
487529

488530
# Compute output grid
@@ -560,18 +602,31 @@ def reproject(
560602
tgt_crs_wkt=tgt_wkt,
561603
)
562604

563-
ydim = raster.dims[-2]
564-
xdim = raster.dims[-1]
605+
ydim, xdim = _find_spatial_dims(raster)
565606
out_attrs = {
566607
'crs': tgt_wkt,
567608
'nodata': nd,
568609
}
569610
if tgt_vertical_crs is not None:
570611
out_attrs['vertical_crs'] = tgt_vertical_crs
612+
613+
# Handle multi-band output (3D result from multi-band source)
614+
if result_data.ndim == 3:
615+
# Find the band dimension name from the source
616+
band_dims = [d for d in raster.dims if d not in (ydim, xdim)]
617+
band_dim = band_dims[0] if band_dims else 'band'
618+
out_dims = [ydim, xdim, band_dim]
619+
out_coords = {ydim: y_coords, xdim: x_coords}
620+
if band_dim in raster.coords:
621+
out_coords[band_dim] = raster.coords[band_dim]
622+
else:
623+
out_dims = [ydim, xdim]
624+
out_coords = {ydim: y_coords, xdim: x_coords}
625+
571626
result = xr.DataArray(
572627
result_data,
573-
dims=[ydim, xdim],
574-
coords={ydim: y_coords, xdim: x_coords},
628+
dims=out_dims,
629+
coords=out_coords,
575630
name=name or raster.name,
576631
attrs=out_attrs,
577632
)

0 commit comments

Comments
 (0)