# Source-compatibility checks for ``write_vrt`` in
# xrspatial/geotiff/_vrt.py (issue #1733).
#
# Contract documented on ``write_vrt`` by this change:
#   * Each source file is placed in the virtual raster based on its geo
#     transform. All sources must share the same pixel size, dtype
#     (sample format + bits-per-sample), band count, and CRS. Mismatches
#     raise ``ValueError`` rather than producing a misplaced or mis-typed
#     mosaic.
#   * ``nodata`` (float, int, or None) is applied to every band of the
#     mosaic. A caller-supplied value takes precedence; when ``None``,
#     the first source's per-band nodata is used. Integer sentinels
#     (e.g. ``65535`` for uint16, ``-9999`` for int32) are accepted so
#     the surface lines up with the ``nodata`` kwarg on ``to_geotiff``
#     and ``write_geotiff_gpu``.

# Pixel size is compared with a small relative tolerance: TIFFs written
# by different tools occasionally round the GeoTransform slightly, and
# the mosaic-extent math rounds the final raster size anyway.
_PIXEL_SIZE_RTOL = 1e-6


def _pixel_size_mismatch(a: float, b: float) -> bool:
    """Return True when two pixel-size components differ beyond tolerance.

    The comparison is relative to the larger magnitude, so minor float
    rounding between tools is ignored. ``pixel_height`` is negative for
    the common north-up case; a flipped sign is far outside the
    tolerance and is therefore reported as a mismatch, which is the
    desired behavior since opposite-sign sources do not stack into a
    valid VRT.

    Non-finite values are never considered a match for a finite value,
    and NaN never matches anything (including another NaN). A hand-rolled
    ``abs(a - b) / max(abs(a), abs(b)) > rtol`` test evaluates to
    ``nan > rtol`` (False) in those cases and would let a corrupt pixel
    size slip through; ``math.isclose`` handles them explicitly.
    """
    # Local import keeps this block self-contained; the lookup is a cheap
    # ``sys.modules`` hit after the first call.
    import math

    # abs_tol=0.0 keeps the check purely relative, so 0.0 only matches 0.0.
    return not math.isclose(a, b, rel_tol=_PIXEL_SIZE_RTOL, abs_tol=0.0)


def _check_sources_match_first(first: dict, others) -> None:
    """Raise ``ValueError`` if any source disagrees with the first source.

    Enforces the ``write_vrt`` contract: every source must agree with the
    first on pixel size, sample format + bits-per-sample (i.e. dtype),
    band count, and CRS WKT. Without this, a syntactically valid VRT
    could be produced that misplaces or mis-types data downstream.

    Intended call site, right after ``write_vrt`` has selected ``first``::

        _check_sources_match_first(first, sources_meta[1:])

    Parameters
    ----------
    first : dict
        Metadata of the reference source. Reads ``first['transform']``
        (an object exposing ``pixel_width`` and ``pixel_height``),
        ``first['sample_format']``, ``first['bps']``, ``first['bands']``
        and the optional ``first['crs_wkt']``.
    others : iterable of dict
        Metadata of the remaining sources, with the same keys plus
        ``'path'``, which is used in the error message.

    Raises
    ------
    ValueError
        On the first source whose pixel size, dtype, band count, or CRS
        WKT does not match ``first``. Checks run in that order.
    """
    res_x = first['transform'].pixel_width
    res_y = first['transform'].pixel_height
    first_crs = first.get('crs_wkt')

    for m in others:
        t = m['transform']
        if (_pixel_size_mismatch(t.pixel_width, res_x)
                or _pixel_size_mismatch(t.pixel_height, res_y)):
            raise ValueError(
                f"VRT source {m['path']!r} has pixel size "
                f"({t.pixel_width}, {t.pixel_height}) which does not "
                f"match the first source ({res_x}, {res_y}). All sources "
                f"in a VRT must share the same pixel size."
            )

        # Sample format + bps must match exactly because the VRT dtype is
        # taken from ``first`` and applied to every band.
        if (m['sample_format'] != first['sample_format']
                or m['bps'] != first['bps']):
            raise ValueError(
                f"VRT source {m['path']!r} has sample_format="
                f"{m['sample_format']}, bps={m['bps']} which does not "
                f"match the first source (sample_format="
                f"{first['sample_format']}, bps={first['bps']}). All "
                f"sources in a VRT must share the same dtype."
            )

        if m['bands'] != first['bands']:
            raise ValueError(
                f"VRT source {m['path']!r} has {m['bands']} band(s) "
                f"which does not match the first source "
                f"({first['bands']} band(s)). All sources in a VRT "
                f"must share the same band count."
            )

        m_crs = m.get('crs_wkt')
        # Treat asymmetric CRS (one set, one missing/empty) as a mismatch
        # too. If only the both-set case were flagged, a source that lost
        # its CRS during a re-write would silently inherit the first
        # source's CRS in the VRT header and could end up tagged with the
        # wrong projection. Two missing/empty values are accepted.
        # NOTE(review): this is a textual comparison of the WKT strings;
        # equivalent CRSs serialized differently will be rejected.
        if (first_crs or m_crs) and m_crs != first_crs:
            raise ValueError(
                f"VRT source {m['path']!r} has CRS WKT that does not "
                f"match the first source. All sources in a VRT must "
                f"share the same CRS."
            )
+""" +from __future__ import annotations + +import os +import uuid + +import numpy as np +import pytest +import xarray as xr + +from xrspatial.geotiff import to_geotiff +from xrspatial.geotiff._vrt import write_vrt + + +def _unique_dir(tmp_path, label: str) -> str: + d = tmp_path / f"vrt_1733_{label}_{uuid.uuid4().hex[:8]}" + d.mkdir() + return str(d) + + +def _write_tif(path: str, *, h: int, w: int, dtype, bands: int = 1, + px: float = 1.0, py: float = -1.0, + origin_x: float = 0.0, origin_y: float = 100.0, + crs: int | None = 4326) -> None: + if bands == 1: + arr = np.arange(h * w, dtype=dtype).reshape(h, w) + dims = ['y', 'x'] + else: + arr = np.arange(h * w * bands, dtype=dtype).reshape(h, w, bands) + dims = ['y', 'x', 'band'] + y = origin_y + (np.arange(h) + 0.5) * py + x = origin_x + (np.arange(w) + 0.5) * px + coords = {'y': y, 'x': x} + attrs = {} + if crs is not None: + attrs['crs'] = crs + da = xr.DataArray(arr, dims=dims, coords=coords, attrs=attrs) + to_geotiff(da, path, compression='none') + + +def test_mismatched_pixel_size_raises(tmp_path): + d = _unique_dir(tmp_path, "px") + a = os.path.join(d, "a.tif") + b = os.path.join(d, "b.tif") + _write_tif(a, h=4, w=4, dtype=np.float32, px=1.0, py=-1.0) + # Place b adjacent so the geometry would otherwise work, but the + # pixel size disagrees. 
+ _write_tif(b, h=4, w=4, dtype=np.float32, px=2.0, py=-2.0, + origin_x=4.0) + vrt = os.path.join(d, "out.vrt") + with pytest.raises(ValueError, match="pixel size"): + write_vrt(vrt, [a, b]) + + +def test_mismatched_dtype_raises(tmp_path): + d = _unique_dir(tmp_path, "dtype") + a = os.path.join(d, "a.tif") + b = os.path.join(d, "b.tif") + _write_tif(a, h=4, w=4, dtype=np.float32) + _write_tif(b, h=4, w=4, dtype=np.int16, origin_x=4.0) + vrt = os.path.join(d, "out.vrt") + with pytest.raises(ValueError, match="dtype|sample_format|bps"): + write_vrt(vrt, [a, b]) + + +def test_mismatched_band_count_raises(tmp_path): + d = _unique_dir(tmp_path, "bands") + a = os.path.join(d, "a.tif") + b = os.path.join(d, "b.tif") + _write_tif(a, h=4, w=4, dtype=np.float32, bands=1) + _write_tif(b, h=4, w=4, dtype=np.float32, bands=3, origin_x=4.0) + vrt = os.path.join(d, "out.vrt") + with pytest.raises(ValueError, match="band count"): + write_vrt(vrt, [a, b]) + + +def test_compatible_sources_succeed(tmp_path): + d = _unique_dir(tmp_path, "ok") + a = os.path.join(d, "a.tif") + b = os.path.join(d, "b.tif") + _write_tif(a, h=4, w=4, dtype=np.float32) + _write_tif(b, h=4, w=4, dtype=np.float32, origin_x=4.0) + vrt = os.path.join(d, "out.vrt") + write_vrt(vrt, [a, b]) + assert os.path.exists(vrt) + + +def test_pixel_size_within_tolerance_accepted(tmp_path): + d = _unique_dir(tmp_path, "tol") + a = os.path.join(d, "a.tif") + b = os.path.join(d, "b.tif") + _write_tif(a, h=4, w=4, dtype=np.float32, px=1.0, py=-1.0) + # Drift well below the 1e-6 relative tolerance. 
+ _write_tif(b, h=4, w=4, dtype=np.float32, + px=1.0 + 1e-10, py=-1.0, origin_x=4.0) + vrt = os.path.join(d, "out.vrt") + write_vrt(vrt, [a, b]) + assert os.path.exists(vrt) + + +def test_single_source_still_works(tmp_path): + d = _unique_dir(tmp_path, "one") + a = os.path.join(d, "a.tif") + _write_tif(a, h=4, w=4, dtype=np.float32) + vrt = os.path.join(d, "out.vrt") + write_vrt(vrt, [a]) + assert os.path.exists(vrt) + + +def test_mismatched_crs_raises(tmp_path): + # Two sources with different non-empty CRS values must be rejected, + # otherwise the VRT would inherit the first source's CRS and silently + # misproject the second. + d = _unique_dir(tmp_path, "crs_diff") + a = os.path.join(d, "a.tif") + b = os.path.join(d, "b.tif") + _write_tif(a, h=4, w=4, dtype=np.float32, crs=4326) + _write_tif(b, h=4, w=4, dtype=np.float32, origin_x=4.0, crs=3857) + vrt = os.path.join(d, "out.vrt") + with pytest.raises(ValueError, match="CRS"): + write_vrt(vrt, [a, b]) + + +def test_asymmetric_crs_raises_first_set_second_missing(tmp_path): + # First source has a CRS, second is written without one. The VRT + # would otherwise be tagged with the first source's CRS, which can + # misplace data when the second source actually came from a + # different (or unknown) projection. + d = _unique_dir(tmp_path, "crs_first") + a = os.path.join(d, "a.tif") + b = os.path.join(d, "b.tif") + _write_tif(a, h=4, w=4, dtype=np.float32, crs=4326) + _write_tif(b, h=4, w=4, dtype=np.float32, origin_x=4.0, crs=None) + vrt = os.path.join(d, "out.vrt") + with pytest.raises(ValueError, match="CRS"): + write_vrt(vrt, [a, b]) + + +def test_asymmetric_crs_raises_first_missing_second_set(tmp_path): + # Symmetric case: first source missing a CRS, second has one. The + # earlier guard only triggered when both sides were set, so this + # would have silently produced an untagged VRT despite one source + # carrying a known projection. 
+ d = _unique_dir(tmp_path, "crs_second") + a = os.path.join(d, "a.tif") + b = os.path.join(d, "b.tif") + _write_tif(a, h=4, w=4, dtype=np.float32, crs=None) + _write_tif(b, h=4, w=4, dtype=np.float32, origin_x=4.0, crs=4326) + vrt = os.path.join(d, "out.vrt") + with pytest.raises(ValueError, match="CRS"): + write_vrt(vrt, [a, b]) + + +def test_matching_crs_succeeds(tmp_path): + # Sanity check: two sources with the same CRS should still be + # accepted (defends against an overly aggressive equality check). + d = _unique_dir(tmp_path, "crs_match") + a = os.path.join(d, "a.tif") + b = os.path.join(d, "b.tif") + _write_tif(a, h=4, w=4, dtype=np.float32, crs=4326) + _write_tif(b, h=4, w=4, dtype=np.float32, origin_x=4.0, crs=4326) + vrt = os.path.join(d, "out.vrt") + write_vrt(vrt, [a, b]) + assert os.path.exists(vrt) + + +def test_both_missing_crs_succeeds(tmp_path): + # If neither source has a CRS, the VRT just won't be tagged with one + # and there's nothing to mis-tag. This must not raise. + d = _unique_dir(tmp_path, "crs_both_missing") + a = os.path.join(d, "a.tif") + b = os.path.join(d, "b.tif") + _write_tif(a, h=4, w=4, dtype=np.float32, crs=None) + _write_tif(b, h=4, w=4, dtype=np.float32, origin_x=4.0, crs=None) + vrt = os.path.join(d, "out.vrt") + write_vrt(vrt, [a, b]) + assert os.path.exists(vrt)