5151# Source geometry helpers
5252# ---------------------------------------------------------------------------
5353
54+ _Y_NAMES = {'y' , 'lat' , 'latitude' , 'Y' , 'Lat' , 'Latitude' }
55+ _X_NAMES = {'x' , 'lon' , 'longitude' , 'X' , 'Lon' , 'Longitude' }
56+
57+
58+ def _find_spatial_dims (raster ):
59+ """Find the y and x dimension names, handling multi-band rasters.
60+
61+ Returns (ydim, xdim). Checks dim names first, falls back to
62+ assuming the last two non-band dims are spatial.
63+ """
64+ dims = raster .dims
65+ ydim = xdim = None
66+ for d in dims :
67+ if d in _Y_NAMES :
68+ ydim = d
69+ elif d in _X_NAMES :
70+ xdim = d
71+ if ydim is not None and xdim is not None :
72+ return ydim , xdim
73+ # Fallback: last two dims
74+ return dims [- 2 ], dims [- 1 ]
75+
76+
5477def _source_bounds (raster ):
5578 """Extract (left, bottom, right, top) from a DataArray's coordinates."""
56- ydim = raster .dims [- 2 ]
57- xdim = raster .dims [- 1 ]
79+ ydim , xdim = _find_spatial_dims (raster )
5880 y = raster .coords [ydim ].values
5981 x = raster .coords [xdim ].values
6082 # Compute pixel-edge bounds from pixel-center coords
@@ -77,7 +99,7 @@ def _source_bounds(raster):
7799
78100def _is_y_descending (raster ):
79101 """Check if Y axis goes from top (large) to bottom (small)."""
80- ydim = raster . dims [ - 2 ]
102+ ydim , _ = _find_spatial_dims ( raster )
81103 y = raster .coords [ydim ].values
82104 if len (y ) < 2 :
83105 return True
@@ -240,17 +262,36 @@ def _reproject_chunk_numpy(
240262 window = window .compute ()
241263 window = np .asarray (window )
242264 orig_dtype = window .dtype
265+
266+ # Adjust coordinates relative to window
267+ local_row = src_row_px - r_min_clip
268+ local_col = src_col_px - c_min_clip
269+
270+ # Multi-band: reproject each band separately, share coordinates
271+ if window .ndim == 3 :
272+ n_bands = window .shape [2 ]
273+ bands = []
274+ for b in range (n_bands ):
275+ band_data = window [:, :, b ].astype (np .float64 )
276+ if not np .isnan (nodata ):
277+ band_data = band_data .copy ()
278+ band_data [band_data == nodata ] = np .nan
279+ band_result = _resample_numpy (band_data , local_row , local_col ,
280+ resampling = resampling , nodata = nodata )
281+ if np .issubdtype (orig_dtype , np .integer ):
282+ info = np .iinfo (orig_dtype )
283+ band_result = np .clip (np .round (band_result ), info .min , info .max ).astype (orig_dtype )
284+ bands .append (band_result )
285+ return np .stack (bands , axis = - 1 )
286+
287+ # Single-band path
243288 window = window .astype (np .float64 )
244289
245290 # Convert sentinel nodata to NaN so numba kernels can detect it
246291 if not np .isnan (nodata ):
247292 window = window .copy ()
248293 window [window == nodata ] = np .nan
249294
250- # Adjust coordinates relative to window
251- local_row = src_row_px - r_min_clip
252- local_col = src_col_px - c_min_clip
253-
254295 result = _resample_numpy (window , local_row , local_col ,
255296 resampling = resampling , nodata = nodata )
256297
@@ -482,7 +523,8 @@ def reproject(
482523
483524 # Source geometry
484525 src_bounds = _source_bounds (raster )
485- src_shape = (raster .sizes [raster .dims [- 2 ]], raster .sizes [raster .dims [- 1 ]])
526+ _ydim , _xdim = _find_spatial_dims (raster )
527+ src_shape = (raster .sizes [_ydim ], raster .sizes [_xdim ])
486528 y_desc = _is_y_descending (raster )
487529
488530 # Compute output grid
@@ -560,18 +602,31 @@ def reproject(
560602 tgt_crs_wkt = tgt_wkt ,
561603 )
562604
563- ydim = raster .dims [- 2 ]
564- xdim = raster .dims [- 1 ]
605+ ydim , xdim = _find_spatial_dims (raster )
565606 out_attrs = {
566607 'crs' : tgt_wkt ,
567608 'nodata' : nd ,
568609 }
569610 if tgt_vertical_crs is not None :
570611 out_attrs ['vertical_crs' ] = tgt_vertical_crs
612+
613+ # Handle multi-band output (3D result from multi-band source)
614+ if result_data .ndim == 3 :
615+ # Find the band dimension name from the source
616+ band_dims = [d for d in raster .dims if d not in (ydim , xdim )]
617+ band_dim = band_dims [0 ] if band_dims else 'band'
618+ out_dims = [ydim , xdim , band_dim ]
619+ out_coords = {ydim : y_coords , xdim : x_coords }
620+ if band_dim in raster .coords :
621+ out_coords [band_dim ] = raster .coords [band_dim ]
622+ else :
623+ out_dims = [ydim , xdim ]
624+ out_coords = {ydim : y_coords , xdim : x_coords }
625+
571626 result = xr .DataArray (
572627 result_data ,
573- dims = [ ydim , xdim ] ,
574- coords = { ydim : y_coords , xdim : x_coords } ,
628+ dims = out_dims ,
629+ coords = out_coords ,
575630 name = name or raster .name ,
576631 attrs = out_attrs ,
577632 )
0 commit comments