Skip to content

Commit babb72e

Browse files
authored
Fix _coords_to_transform for 3D (y,x,band) DataArrays (#1643) (#1648)
* Fix _coords_to_transform for 3D (y,x,band) DataArrays (#1643) _coords_to_transform read y/x coords via dims[-2:] which on a 3D (y, x, band) DataArray picked (x, band) instead of (y, x). to_geotiff and write_geotiff_gpu silently emitted a wrong GeoTransform on the fallback path when attrs['transform'] was absent (the round-tripped file used the band axis spacing as pixel_width). The helper now skips any trailing/leading dim named band/bands/channel and uses the two remaining spatial dims. 2D inputs and 3D (band, y, x) inputs are both handled. * Address Copilot review feedback on #1648 - Lift _BAND_DIM_NAMES to module scope and reuse at the three (band,y,x) remap sites in __init__.py to avoid drift between _coords_to_transform and the writer paths. - Reword _coords_to_transform docstring: filter is position-independent, not trailing/leading. - Drop unused os/tempfile imports from the regression test. - Replace `import cupy` guard with the repo's standard _gpu_available() pattern that also checks `cupy.cuda.is_available()` and swallows non-ImportError import failures. - Add parametrized helper coverage for 'bands' and 'channel' dim names.
1 parent 46f567f commit babb72e

2 files changed

Lines changed: 249 additions & 5 deletions

File tree

xrspatial/geotiff/__init__.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,12 @@
6868
_GPU_DEPRECATED_SENTINEL = object()
6969
_ON_GPU_FAILURE_SENTINEL = object()
7070

71+
# Names of dims that ``to_geotiff`` / ``write_geotiff_gpu`` treat as the
72+
# non-spatial band axis. Used both to remap ``(band, y, x)`` inputs to
73+
# ``(y, x, band)`` before writing and to skip the band axis when inferring
74+
# a GeoTransform from coords (see ``_coords_to_transform`` and issue #1643).
75+
_BAND_DIM_NAMES = ('band', 'bands', 'channel')
76+
7177

7278
def _wkt_to_epsg(wkt_or_proj: str) -> int | None:
7379
"""Try to extract an EPSG code from a WKT or PROJ string.
@@ -191,9 +197,34 @@ def _coords_to_transform(da: xr.DataArray) -> GeoTransform | None:
191197
on raster_type:
192198
- PixelIsArea (default): origin = center - half_pixel (edge of pixel 0)
193199
- PixelIsPoint: origin = center (center of pixel 0)
200+
201+
For 3D arrays the spatial dims are the two non-band dims. The helper
202+
filters out any dim named ``band`` / ``bands`` / ``channel`` (see
203+
``_BAND_DIM_NAMES``) regardless of position, so a ``(y, x, band)``,
204+
``(band, y, x)``, or ``(y, band, x)`` DataArray returns the y/x
205+
transform rather than picking up the band axis spacing as a pixel
206+
size. ``to_geotiff`` itself remaps ``(band, y, x)`` arrays to
207+
``(y, x, band)`` before writing pixel bytes, but it calls
208+
:func:`_coords_to_transform` against the original DataArray, so the
209+
helper must handle both layouts to keep the geo-transform consistent
210+
with the file's coord arrays. See issue #1643.
194211
"""
195-
ydim = da.dims[-2]
196-
xdim = da.dims[-1]
212+
if da.ndim == 3:
213+
# Drop the band-like dim and keep the two spatial dims in their
214+
# original (y, x) order. Position-based fallback covers the case
215+
# where none of the dims are named like a band axis.
216+
spatial = tuple(d for d in da.dims if d not in _BAND_DIM_NAMES)
217+
if len(spatial) == 2:
218+
ydim, xdim = spatial[0], spatial[1]
219+
else:
220+
# No identifiable band dim; fall back to dims[-2:] so the
221+
# original 2-D-style behaviour applies. This branch only
222+
# triggers for unusual 3D layouts callers built by hand.
223+
ydim = da.dims[-2]
224+
xdim = da.dims[-1]
225+
else:
226+
ydim = da.dims[-2]
227+
xdim = da.dims[-1]
197228

198229
if xdim not in da.coords or ydim not in da.coords:
199230
return None
@@ -1166,7 +1197,7 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path, *,
11661197
if hasattr(raw, 'dask') and not cog and not _path_is_file_like:
11671198
dask_arr = raw
11681199
# Handle band-first dimension order (band, y, x) -> (y, x, band)
1169-
if raw.ndim == 3 and data.dims[0] in ('band', 'bands', 'channel'):
1200+
if raw.ndim == 3 and data.dims[0] in _BAND_DIM_NAMES:
11701201
import dask.array as da
11711202
dask_arr = da.moveaxis(raw, 0, -1)
11721203
if dask_arr.ndim not in (2, 3):
@@ -1215,7 +1246,7 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path, *,
12151246
else:
12161247
arr = np.asarray(raw)
12171248
# Handle band-first dimension order (band, y, x) -> (y, x, band)
1218-
if arr.ndim == 3 and data.dims[0] in ('band', 'bands', 'channel'):
1249+
if arr.ndim == 3 and data.dims[0] in _BAND_DIM_NAMES:
12191250
arr = np.moveaxis(arr, 0, -1)
12201251
else:
12211252
if hasattr(data, 'get'):
@@ -2830,7 +2861,7 @@ def write_geotiff_gpu(data: xr.DataArray | cupy.ndarray | np.ndarray,
28302861
# this remap the writer treats arr.shape[2] as the band axis and
28312862
# produces a transposed file (issue #1580). The CPU writer does
28322863
# the same remap at the matching step in to_geotiff().
2833-
if arr.ndim == 3 and data.dims[0] in ('band', 'bands', 'channel'):
2864+
if arr.ndim == 3 and data.dims[0] in _BAND_DIM_NAMES:
28342865
arr = cupy.ascontiguousarray(cupy.moveaxis(arr, 0, -1))
28352866

28362867
# Prefer attrs['transform'] over the coord-derived transform: it
Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
"""Regression test for issue #1643.
2+
3+
``_coords_to_transform`` previously used ``dims[-2]`` and ``dims[-1]`` to
4+
look up y/x coords. On a 3D ``(y, x, band)`` DataArray that picked
5+
``x`` and ``band``, so ``to_geotiff`` silently wrote a wrong
6+
GeoTransform when ``attrs['transform']`` was absent. The helper now
7+
detects the band-like trailing/leading dim and uses the two spatial
8+
dims regardless of position.
9+
"""
10+
from __future__ import annotations
11+
12+
import importlib.util
13+
14+
import numpy as np
15+
import pytest
16+
import xarray as xr
17+
18+
from xrspatial.geotiff import _coords_to_transform, open_geotiff, to_geotiff
19+
20+
21+
def _gpu_available() -> bool:
22+
if importlib.util.find_spec("cupy") is None:
23+
return False
24+
try:
25+
import cupy
26+
return bool(cupy.cuda.is_available())
27+
except Exception:
28+
return False
29+
30+
31+
_HAS_GPU = _gpu_available()
32+
33+
34+
def _make_geo_da_3d(dims):
35+
"""3D DataArray with georeferenced y/x coords and a band axis."""
36+
shape = []
37+
for d in dims:
38+
if d in ('y',):
39+
shape.append(10)
40+
elif d in ('x',):
41+
shape.append(20)
42+
else:
43+
shape.append(3)
44+
arr = np.arange(int(np.prod(shape)), dtype=np.uint8).reshape(shape)
45+
coords = {
46+
'y': np.linspace(100.0, 200.0, 10),
47+
'x': np.linspace(500.0, 700.0, 20),
48+
'band': np.arange(3),
49+
}
50+
return xr.DataArray(arr, dims=list(dims), coords=coords)
51+
52+
53+
def test_coords_to_transform_yxband_returns_yx_spacing():
54+
"""3D (y, x, band) picks y/x spacing rather than (x, band) spacing."""
55+
da = _make_geo_da_3d(('y', 'x', 'band'))
56+
gt = _coords_to_transform(da)
57+
# y spacing = (200 - 100) / 9, x spacing = (700 - 500) / 19
58+
assert gt is not None
59+
np.testing.assert_allclose(gt.pixel_width, (700.0 - 500.0) / 19)
60+
np.testing.assert_allclose(gt.pixel_height, (200.0 - 100.0) / 9)
61+
62+
63+
def test_coords_to_transform_bandyx_returns_yx_spacing():
64+
"""3D (band, y, x) also returns the y/x transform."""
65+
da = _make_geo_da_3d(('band', 'y', 'x'))
66+
gt = _coords_to_transform(da)
67+
assert gt is not None
68+
np.testing.assert_allclose(gt.pixel_width, (700.0 - 500.0) / 19)
69+
np.testing.assert_allclose(gt.pixel_height, (200.0 - 100.0) / 9)
70+
71+
72+
@pytest.mark.parametrize('band_name', ['band', 'bands', 'channel'])
73+
def test_coords_to_transform_3d_band_name_variants(band_name):
74+
"""All recognized band-dim names (band, bands, channel) are filtered
75+
out when picking the y/x spatial dims."""
76+
arr = np.zeros((10, 20, 3), dtype=np.uint8)
77+
da = xr.DataArray(
78+
arr,
79+
dims=['y', 'x', band_name],
80+
coords={
81+
'y': np.linspace(100.0, 200.0, 10),
82+
'x': np.linspace(500.0, 700.0, 20),
83+
band_name: np.arange(3),
84+
},
85+
)
86+
gt = _coords_to_transform(da)
87+
assert gt is not None
88+
np.testing.assert_allclose(gt.pixel_width, (700.0 - 500.0) / 19)
89+
np.testing.assert_allclose(gt.pixel_height, (200.0 - 100.0) / 9)
90+
91+
92+
def test_coords_to_transform_2d_unchanged():
93+
"""2D (y, x) keeps its original behaviour."""
94+
da = xr.DataArray(
95+
np.zeros((10, 20), dtype=np.uint8),
96+
dims=['y', 'x'],
97+
coords={
98+
'y': np.linspace(100.0, 200.0, 10),
99+
'x': np.linspace(500.0, 700.0, 20),
100+
},
101+
)
102+
gt = _coords_to_transform(da)
103+
assert gt is not None
104+
np.testing.assert_allclose(gt.pixel_width, (700.0 - 500.0) / 19)
105+
np.testing.assert_allclose(gt.pixel_height, (200.0 - 100.0) / 9)
106+
107+
108+
def test_to_geotiff_roundtrip_3d_yxband_no_transform_attr(tmp_path):
109+
"""to_geotiff -> open_geotiff round-trip on 3D arrays preserves coords.
110+
111+
Before the fix the on-disk transform was derived from (x, band)
112+
spacing, so the round-tripped y/x coords had wrong pixel size and
113+
origin. After the fix the 3D output matches the 2D output.
114+
"""
115+
da_3d = _make_geo_da_3d(('y', 'x', 'band'))
116+
da_2d = xr.DataArray(
117+
np.zeros((10, 20), dtype=np.uint8),
118+
dims=['y', 'x'],
119+
coords={
120+
'y': np.linspace(100.0, 200.0, 10),
121+
'x': np.linspace(500.0, 700.0, 20),
122+
},
123+
)
124+
125+
p2 = str(tmp_path / 'roundtrip_1643_2d.tif')
126+
p3 = str(tmp_path / 'roundtrip_1643_3d.tif')
127+
to_geotiff(da_2d, p2)
128+
to_geotiff(da_3d, p3)
129+
130+
rt2 = open_geotiff(p2)
131+
rt3 = open_geotiff(p3)
132+
np.testing.assert_allclose(rt3.y.values, rt2.y.values)
133+
np.testing.assert_allclose(rt3.x.values, rt2.x.values)
134+
assert rt3.attrs.get('transform') == rt2.attrs.get('transform')
135+
136+
137+
def test_to_geotiff_roundtrip_3d_bandyx_no_transform_attr(tmp_path):
138+
"""(band, y, x) input round-trips with the correct transform.
139+
140+
``to_geotiff`` remaps a (band, y, x) input to (y, x, band) before
141+
writing, but ``_coords_to_transform`` runs against the original
142+
dim order. The fix handles both 3D layouts.
143+
"""
144+
da_3d = _make_geo_da_3d(('band', 'y', 'x'))
145+
da_2d = xr.DataArray(
146+
np.zeros((10, 20), dtype=np.uint8),
147+
dims=['y', 'x'],
148+
coords={
149+
'y': np.linspace(100.0, 200.0, 10),
150+
'x': np.linspace(500.0, 700.0, 20),
151+
},
152+
)
153+
154+
p2 = str(tmp_path / 'roundtrip_1643_2d_b.tif')
155+
p3 = str(tmp_path / 'roundtrip_1643_3d_bandfirst.tif')
156+
to_geotiff(da_2d, p2)
157+
to_geotiff(da_3d, p3)
158+
159+
rt2 = open_geotiff(p2)
160+
rt3 = open_geotiff(p3)
161+
np.testing.assert_allclose(rt3.y.values, rt2.y.values)
162+
np.testing.assert_allclose(rt3.x.values, rt2.x.values)
163+
164+
165+
def test_to_geotiff_3d_without_transform_attr_does_not_invent_unit_pixels(
166+
tmp_path):
167+
"""Regression sanity: the bad transform was pixel_width=1.0 (band
168+
axis spacing). Assert the round-tripped pixel_width is finite,
169+
non-unit, and matches the source x spacing.
170+
"""
171+
da = _make_geo_da_3d(('y', 'x', 'band'))
172+
p = str(tmp_path / 'roundtrip_1643_3d_not_unit.tif')
173+
to_geotiff(da, p)
174+
rt = open_geotiff(p)
175+
pw = abs(float(rt.x.values[1] - rt.x.values[0]))
176+
# Source x spacing is (700-500)/19 = ~10.526. The buggy path would
177+
# have produced pw=1.0 (the band axis spacing).
178+
assert pw > 1.5, (
179+
f"round-tripped pixel_width={pw} suggests the band-axis spacing "
180+
f"leaked into the GeoTransform; expected ~10.526")
181+
182+
183+
@pytest.mark.skipif(not _HAS_GPU, reason="cupy + CUDA required")
184+
def test_write_geotiff_gpu_roundtrip_3d_no_transform_attr(tmp_path):
185+
"""GPU writer shares ``_coords_to_transform`` with the CPU writer.
186+
187+
Same regression on the GPU path: a 3D ``(y, x, band)`` cupy
188+
DataArray without ``attrs['transform']`` would previously round-trip
189+
through a unit pixel-width transform.
190+
"""
191+
import cupy as cp
192+
193+
from xrspatial.geotiff import write_geotiff_gpu
194+
195+
np_arr = np.arange(10 * 20 * 3, dtype=np.uint8).reshape(10, 20, 3)
196+
da = xr.DataArray(
197+
cp.asarray(np_arr),
198+
dims=['y', 'x', 'band'],
199+
coords={
200+
'y': np.linspace(100.0, 200.0, 10),
201+
'x': np.linspace(500.0, 700.0, 20),
202+
'band': np.arange(3),
203+
},
204+
)
205+
p = str(tmp_path / 'roundtrip_1643_3d_gpu.tif')
206+
write_geotiff_gpu(da, p)
207+
rt = open_geotiff(p)
208+
pw = abs(float(rt.x.values[1] - rt.x.values[0]))
209+
assert pw > 1.5, (
210+
f"GPU writer round-tripped pixel_width={pw}; expected ~10.526")
211+
ph = abs(float(rt.y.values[1] - rt.y.values[0]))
212+
assert ph > 1.5, (
213+
f"GPU writer round-tripped pixel_height={ph}; expected ~11.111")

0 commit comments

Comments
 (0)