Skip to content

Commit f165756

Browse files
authored
geotiff: validate VRT writer source compatibility (#1733) (#1741)
1 parent 5af16e1 commit f165756

2 files changed

Lines changed: 266 additions & 5 deletions

File tree

xrspatial/geotiff/_vrt.py

Lines changed: 78 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -783,7 +783,10 @@ def write_vrt(vrt_path: str, source_files: list[str], *,
783783
"""Generate a VRT file that mosaics multiple GeoTIFF tiles.
784784
785785
Each source file is placed in the virtual raster based on its
786-
geo transform. Files must share the same CRS and pixel size.
786+
geo transform. All sources must share the same pixel size, dtype
787+
(sample format + bits-per-sample), band count, and CRS. Mismatches
788+
raise ``ValueError`` rather than producing a misplaced or mis-typed
789+
mosaic.
787790
788791
Parameters
789792
----------
@@ -796,10 +799,12 @@ def write_vrt(vrt_path: str, source_files: list[str], *,
796799
crs_wkt : str or None
797800
CRS as WKT string. If None, taken from the first source.
798801
nodata : float, int, or None
799-
NoData value. If None, taken from the first source. Integer
800-
sentinels (e.g. ``65535`` for uint16, ``-9999`` for int32) are
801-
accepted so the surface lines up with the ``nodata`` kwarg on
802-
``to_geotiff`` and ``write_geotiff_gpu``.
802+
NoData value applied to every band of the mosaic. Caller-supplied
803+
value takes precedence; when ``None``, the first source's
804+
per-band nodata is used. Integer sentinels (e.g. ``65535`` for
805+
uint16, ``-9999`` for int32) are accepted so the surface lines up
806+
with the ``nodata`` kwarg on ``to_geotiff`` and
807+
``write_geotiff_gpu``.
803808
804809
Returns
805810
-------
@@ -849,6 +854,74 @@ def write_vrt(vrt_path: str, source_files: list[str], *,
849854
res_x = first['transform'].pixel_width
850855
res_y = first['transform'].pixel_height
851856

857+
# Enforce the docstring contract: every source must agree with the
858+
# first on pixel size, sample format + bits-per-sample (i.e. dtype),
859+
# band count, and CRS WKT. Without this, build_vrt would silently
860+
# produce a syntactically valid VRT that misplaces or mis-types data
861+
# downstream (issue #1733).
862+
#
863+
# Pixel size is compared with a small relative tolerance: TIFFs
864+
# written by different tools occasionally round the GeoTransform
865+
# slightly, and the existing mosaic-extent math rounds the final
866+
# raster size anyway. Sample format + bps must match exactly because
867+
# the VRT dtype is taken from ``first`` and applied to every band.
868+
_PIXEL_SIZE_RTOL = 1e-6
869+
870+
def _pixel_size_mismatch(a: float, b: float) -> bool:
871+
# Use a relative tolerance on the magnitude rather than a direct
872+
# ``a == b`` check so that minor float rounding between tools is
873+
# ignored. ``pixel_height`` is negative for the common north-up
874+
# case; a flipped sign would change the magnitude ratio above
875+
# the tolerance and so is correctly treated as a mismatch (which
876+
# is the desired behavior, since opposite-sign sources do not
877+
# stack into a valid VRT).
878+
if a == b:
879+
return False
880+
denom = max(abs(a), abs(b))
881+
if denom == 0.0:
882+
return abs(a - b) > 0.0
883+
return abs(a - b) / denom > _PIXEL_SIZE_RTOL
884+
885+
first_crs = first.get('crs_wkt')
886+
for m in sources_meta[1:]:
887+
t = m['transform']
888+
if _pixel_size_mismatch(t.pixel_width, res_x) \
889+
or _pixel_size_mismatch(t.pixel_height, res_y):
890+
raise ValueError(
891+
f"VRT source {m['path']!r} has pixel size "
892+
f"({t.pixel_width}, {t.pixel_height}) which does not "
893+
f"match the first source ({res_x}, {res_y}). All sources "
894+
f"in a VRT must share the same pixel size."
895+
)
896+
if m['sample_format'] != first['sample_format'] \
897+
or m['bps'] != first['bps']:
898+
raise ValueError(
899+
f"VRT source {m['path']!r} has sample_format="
900+
f"{m['sample_format']}, bps={m['bps']} which does not "
901+
f"match the first source (sample_format="
902+
f"{first['sample_format']}, bps={first['bps']}). All "
903+
f"sources in a VRT must share the same dtype."
904+
)
905+
if m['bands'] != first['bands']:
906+
raise ValueError(
907+
f"VRT source {m['path']!r} has {m['bands']} band(s) "
908+
f"which does not match the first source "
909+
f"({first['bands']} band(s)). All sources in a VRT "
910+
f"must share the same band count."
911+
)
912+
m_crs = m.get('crs_wkt')
913+
# Treat asymmetric CRS (one set, one missing/empty) as a
914+
# mismatch too. If we only flagged when both were set, a source
915+
# that lost its CRS during a re-write would silently inherit the
916+
# first source's CRS in the VRT header and could end up tagged
917+
# with the wrong projection.
918+
if (first_crs or m_crs) and m_crs != first_crs:
919+
raise ValueError(
920+
f"VRT source {m['path']!r} has CRS WKT that does not "
921+
f"match the first source. All sources in a VRT must "
922+
f"share the same CRS."
923+
)
924+
852925
# Compute the bounding box of all sources
853926
all_x0, all_y0, all_x1, all_y1 = [], [], [], []
854927
for m in sources_meta:
Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
"""Regression tests for issue #1733.
2+
3+
``write_vrt`` previously trusted the first source for resolution,
4+
sample format + bps (dtype), band count, and CRS. A mismatched source
5+
would silently produce a VRT that placed pixels incorrectly or
6+
re-interpreted bytes as the wrong dtype downstream.
7+
8+
These tests assert that ``write_vrt`` now rejects mismatched sources
9+
with a clear ``ValueError`` covering each of those properties, and
10+
still accepts sources that match within a small float tolerance on
11+
pixel size.
12+
"""
13+
from __future__ import annotations
14+
15+
import os
16+
import uuid
17+
18+
import numpy as np
19+
import pytest
20+
import xarray as xr
21+
22+
from xrspatial.geotiff import to_geotiff
23+
from xrspatial.geotiff._vrt import write_vrt
24+
25+
26+
def _unique_dir(tmp_path, label: str) -> str:
27+
d = tmp_path / f"vrt_1733_{label}_{uuid.uuid4().hex[:8]}"
28+
d.mkdir()
29+
return str(d)
30+
31+
32+
def _write_tif(path: str, *, h: int, w: int, dtype, bands: int = 1,
               px: float = 1.0, py: float = -1.0,
               origin_x: float = 0.0, origin_y: float = 100.0,
               crs: int | None = 4326) -> None:
    """Write a small synthetic raster to ``path`` via ``to_geotiff``.

    The pixel values are a running ``arange`` of the requested ``dtype``.
    A single-band raster has dims ``(y, x)``; any other band count adds a
    trailing ``band`` dimension. Coordinates are pixel centres derived
    from the origin and the pixel size (``px``, ``py``). When ``crs`` is
    not ``None`` it is stored in the DataArray's ``attrs['crs']``.
    """
    single_band = bands == 1
    if single_band:
        shape = (h, w)
        dims = ['y', 'x']
        count = h * w
    else:
        shape = (h, w, bands)
        dims = ['y', 'x', 'band']
        count = h * w * bands
    data = np.arange(count, dtype=dtype).reshape(*shape)

    # Pixel-centre coordinates: half a pixel in from the origin corner.
    coords = {
        'y': origin_y + (np.arange(h) + 0.5) * py,
        'x': origin_x + (np.arange(w) + 0.5) * px,
    }
    attrs = {'crs': crs} if crs is not None else {}

    raster = xr.DataArray(data, dims=dims, coords=coords, attrs=attrs)
    to_geotiff(raster, path, compression='none')
50+
51+
52+
def test_mismatched_pixel_size_raises(tmp_path):
    """A second tile with a different pixel size must be rejected."""
    workdir = _unique_dir(tmp_path, "px")
    first_tif, second_tif, vrt_out = (
        os.path.join(workdir, name)
        for name in ("a.tif", "b.tif", "out.vrt")
    )
    _write_tif(first_tif, h=4, w=4, dtype=np.float32, px=1.0, py=-1.0)
    # The second tile starts where the first one ends, so only the
    # resolution (2.0 vs 1.0) is in conflict.
    _write_tif(second_tif, h=4, w=4, dtype=np.float32,
               px=2.0, py=-2.0, origin_x=4.0)
    with pytest.raises(ValueError, match="pixel size"):
        write_vrt(vrt_out, [first_tif, second_tif])
64+
65+
66+
def test_mismatched_dtype_raises(tmp_path):
    """A float32 tile and an int16 tile cannot share one VRT."""
    workdir = _unique_dir(tmp_path, "dtype")
    float_tif = os.path.join(workdir, "a.tif")
    int_tif = os.path.join(workdir, "b.tif")
    vrt_out = os.path.join(workdir, "out.vrt")
    _write_tif(float_tif, h=4, w=4, dtype=np.float32)
    _write_tif(int_tif, h=4, w=4, dtype=np.int16, origin_x=4.0)
    with pytest.raises(ValueError, match="dtype|sample_format|bps"):
        write_vrt(vrt_out, [float_tif, int_tif])
75+
76+
77+
def test_mismatched_band_count_raises(tmp_path):
    """A 1-band tile and a 3-band tile must not be mosaicked together."""
    workdir = _unique_dir(tmp_path, "bands")
    mono_tif = os.path.join(workdir, "a.tif")
    rgb_tif = os.path.join(workdir, "b.tif")
    vrt_out = os.path.join(workdir, "out.vrt")
    _write_tif(mono_tif, h=4, w=4, dtype=np.float32, bands=1)
    _write_tif(rgb_tif, h=4, w=4, dtype=np.float32, bands=3, origin_x=4.0)
    with pytest.raises(ValueError, match="band count"):
        write_vrt(vrt_out, [mono_tif, rgb_tif])
86+
87+
88+
def test_compatible_sources_succeed(tmp_path):
    """Two fully compatible, adjacent tiles produce a VRT on disk."""
    workdir = _unique_dir(tmp_path, "ok")
    tiles = [os.path.join(workdir, name) for name in ("a.tif", "b.tif")]
    left_tif, right_tif = tiles
    _write_tif(left_tif, h=4, w=4, dtype=np.float32)
    _write_tif(right_tif, h=4, w=4, dtype=np.float32, origin_x=4.0)
    vrt_out = os.path.join(workdir, "out.vrt")
    write_vrt(vrt_out, tiles)
    assert os.path.exists(vrt_out)
97+
98+
99+
def test_pixel_size_within_tolerance_accepted(tmp_path):
    """Sub-tolerance pixel-size drift between tiles is not an error."""
    workdir = _unique_dir(tmp_path, "tol")
    exact_tif = os.path.join(workdir, "a.tif")
    drifted_tif = os.path.join(workdir, "b.tif")
    vrt_out = os.path.join(workdir, "out.vrt")
    _write_tif(exact_tif, h=4, w=4, dtype=np.float32, px=1.0, py=-1.0)
    # A 1e-10 relative drift is orders of magnitude under the 1e-6
    # relative tolerance.
    _write_tif(drifted_tif, h=4, w=4, dtype=np.float32,
               px=1.0 + 1e-10, py=-1.0, origin_x=4.0)
    write_vrt(vrt_out, [exact_tif, drifted_tif])
    assert os.path.exists(vrt_out)
110+
111+
112+
def test_single_source_still_works(tmp_path):
    """A lone source has nothing to disagree with and must still work."""
    workdir = _unique_dir(tmp_path, "one")
    only_tif = os.path.join(workdir, "a.tif")
    vrt_out = os.path.join(workdir, "out.vrt")
    _write_tif(only_tif, h=4, w=4, dtype=np.float32)
    write_vrt(vrt_out, [only_tif])
    assert os.path.exists(vrt_out)
119+
120+
121+
def test_mismatched_crs_raises(tmp_path):
    """Sources tagged with two different CRS values must be rejected.

    Accepting them would stamp the first source's CRS on the VRT and
    silently misproject the second source.
    """
    workdir = _unique_dir(tmp_path, "crs_diff")
    first_tif, second_tif, vrt_out = (
        os.path.join(workdir, name)
        for name in ("a.tif", "b.tif", "out.vrt")
    )
    _write_tif(first_tif, h=4, w=4, dtype=np.float32, crs=4326)
    _write_tif(second_tif, h=4, w=4, dtype=np.float32,
               origin_x=4.0, crs=3857)
    with pytest.raises(ValueError, match="CRS"):
        write_vrt(vrt_out, [first_tif, second_tif])
133+
134+
135+
def test_asymmetric_crs_raises_first_set_second_missing(tmp_path):
    """CRS on the first source only is treated as a mismatch.

    Otherwise the VRT would carry the first source's CRS even though the
    second source may come from a different or unknown projection.
    """
    workdir = _unique_dir(tmp_path, "crs_first")
    tagged_tif = os.path.join(workdir, "a.tif")
    untagged_tif = os.path.join(workdir, "b.tif")
    vrt_out = os.path.join(workdir, "out.vrt")
    _write_tif(tagged_tif, h=4, w=4, dtype=np.float32, crs=4326)
    _write_tif(untagged_tif, h=4, w=4, dtype=np.float32,
               origin_x=4.0, crs=None)
    with pytest.raises(ValueError, match="CRS"):
        write_vrt(vrt_out, [tagged_tif, untagged_tif])
148+
149+
150+
def test_asymmetric_crs_raises_first_missing_second_set(tmp_path):
    """CRS on the second source only is treated as a mismatch.

    Mirror image of the first-set/second-missing case: a guard that only
    fires when both sides carry a CRS would let this through and emit an
    untagged VRT despite one source having a known projection.
    """
    workdir = _unique_dir(tmp_path, "crs_second")
    untagged_tif = os.path.join(workdir, "a.tif")
    tagged_tif = os.path.join(workdir, "b.tif")
    vrt_out = os.path.join(workdir, "out.vrt")
    _write_tif(untagged_tif, h=4, w=4, dtype=np.float32, crs=None)
    _write_tif(tagged_tif, h=4, w=4, dtype=np.float32,
               origin_x=4.0, crs=4326)
    with pytest.raises(ValueError, match="CRS"):
        write_vrt(vrt_out, [untagged_tif, tagged_tif])
163+
164+
165+
def test_matching_crs_succeeds(tmp_path):
    """Identical CRS on both sources is accepted.

    Guards against an equality check that is too aggressive.
    """
    workdir = _unique_dir(tmp_path, "crs_match")
    tiles = [os.path.join(workdir, name) for name in ("a.tif", "b.tif")]
    left_tif, right_tif = tiles
    _write_tif(left_tif, h=4, w=4, dtype=np.float32, crs=4326)
    _write_tif(right_tif, h=4, w=4, dtype=np.float32,
               origin_x=4.0, crs=4326)
    vrt_out = os.path.join(workdir, "out.vrt")
    write_vrt(vrt_out, tiles)
    assert os.path.exists(vrt_out)
176+
177+
178+
def test_both_missing_crs_succeeds(tmp_path):
    """Two sources without any CRS are accepted.

    With no CRS on either side the VRT is simply left untagged, so there
    is nothing that could be mis-tagged and no error is expected.
    """
    workdir = _unique_dir(tmp_path, "crs_both_missing")
    tiles = [os.path.join(workdir, name) for name in ("a.tif", "b.tif")]
    left_tif, right_tif = tiles
    _write_tif(left_tif, h=4, w=4, dtype=np.float32, crs=None)
    _write_tif(right_tif, h=4, w=4, dtype=np.float32,
               origin_x=4.0, crs=None)
    vrt_out = os.path.join(workdir, "out.vrt")
    write_vrt(vrt_out, tiles)
    assert os.path.exists(vrt_out)

0 commit comments

Comments
 (0)