Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 78 additions & 5 deletions xrspatial/geotiff/_vrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,7 +783,10 @@ def write_vrt(vrt_path: str, source_files: list[str], *,
"""Generate a VRT file that mosaics multiple GeoTIFF tiles.

Each source file is placed in the virtual raster based on its
geo transform. Files must share the same CRS and pixel size.
geo transform. All sources must share the same pixel size, dtype
(sample format + bits-per-sample), band count, and CRS. Mismatches
raise ``ValueError`` rather than producing a misplaced or mis-typed
mosaic.

Parameters
----------
Expand All @@ -796,10 +799,12 @@ def write_vrt(vrt_path: str, source_files: list[str], *,
crs_wkt : str or None
CRS as WKT string. If None, taken from the first source.
nodata : float, int, or None
NoData value. If None, taken from the first source. Integer
sentinels (e.g. ``65535`` for uint16, ``-9999`` for int32) are
accepted so the surface lines up with the ``nodata`` kwarg on
``to_geotiff`` and ``write_geotiff_gpu``.
NoData value applied to every band of the mosaic. Caller-supplied
value takes precedence; when ``None``, the first source's
per-band nodata is used. Integer sentinels (e.g. ``65535`` for
uint16, ``-9999`` for int32) are accepted so the surface lines up
with the ``nodata`` kwarg on ``to_geotiff`` and
``write_geotiff_gpu``.

Returns
-------
Expand Down Expand Up @@ -849,6 +854,74 @@ def write_vrt(vrt_path: str, source_files: list[str], *,
res_x = first['transform'].pixel_width
res_y = first['transform'].pixel_height

# Enforce the docstring contract: every source must agree with the
# first on pixel size, sample format + bits-per-sample (i.e. dtype),
# band count, and CRS WKT. Without this, build_vrt would silently
# produce a syntactically valid VRT that misplaces or mis-types data
# downstream (issue #1733).
#
# Pixel size is compared with a small relative tolerance: TIFFs
# written by different tools occasionally round the GeoTransform
# slightly, and the existing mosaic-extent math rounds the final
# raster size anyway. Sample format + bps must match exactly because
# the VRT dtype is taken from ``first`` and applied to every band.
_PIXEL_SIZE_RTOL = 1e-6

def _pixel_size_mismatch(a: float, b: float) -> bool:
    """Return True when two pixel-size components differ beyond tolerance.

    The comparison is relative (``_PIXEL_SIZE_RTOL`` of the larger
    magnitude) so that minor float rounding between writers is ignored.
    No absolute tolerance is applied, so zero only matches zero.

    A sign flip (e.g. ``pixel_height`` of ``-1.0`` vs ``1.0``) yields a
    relative difference of 2 and is therefore reported as a mismatch.

    Non-finite values are handled by ``math.isclose``: NaN is never
    close to anything and an infinity is only close to itself, so a
    NaN or a lone infinite pixel size is reported as a mismatch
    instead of slipping through a ``NaN > tolerance`` comparison
    (which is always False).

    Parameters
    ----------
    a, b : float
        The two pixel-size components to compare.

    Returns
    -------
    bool
        True if the values disagree, False if they match within the
        relative tolerance.
    """
    # Function-scope import keeps this helper self-contained.
    import math

    return not math.isclose(a, b, rel_tol=_PIXEL_SIZE_RTOL, abs_tol=0.0)

first_crs = first.get('crs_wkt')
for m in sources_meta[1:]:
t = m['transform']
if _pixel_size_mismatch(t.pixel_width, res_x) \
or _pixel_size_mismatch(t.pixel_height, res_y):
raise ValueError(
f"VRT source {m['path']!r} has pixel size "
f"({t.pixel_width}, {t.pixel_height}) which does not "
f"match the first source ({res_x}, {res_y}). All sources "
f"in a VRT must share the same pixel size."
)
if m['sample_format'] != first['sample_format'] \
or m['bps'] != first['bps']:
raise ValueError(
f"VRT source {m['path']!r} has sample_format="
f"{m['sample_format']}, bps={m['bps']} which does not "
f"match the first source (sample_format="
f"{first['sample_format']}, bps={first['bps']}). All "
f"sources in a VRT must share the same dtype."
)
if m['bands'] != first['bands']:
raise ValueError(
f"VRT source {m['path']!r} has {m['bands']} band(s) "
f"which does not match the first source "
f"({first['bands']} band(s)). All sources in a VRT "
f"must share the same band count."
)
m_crs = m.get('crs_wkt')
# Treat asymmetric CRS (one set, one missing/empty) as a
# mismatch too. If we only flagged when both were set, a source
# that lost its CRS during a re-write would silently inherit the
# first source's CRS in the VRT header and could end up tagged
# with the wrong projection.
if (first_crs or m_crs) and m_crs != first_crs:
raise ValueError(
f"VRT source {m['path']!r} has CRS WKT that does not "
f"match the first source. All sources in a VRT must "
f"share the same CRS."
)
Comment on lines +912 to +923
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in ff56e14: changed the guard from `first_crs and m_crs and m_crs != first_crs` to `(first_crs or m_crs) and m_crs != first_crs`. Now an asymmetric setting (exactly one of the two sources has a CRS) is also treated as a mismatch and raises `ValueError`. Added regression tests for both directions of the asymmetric case, plus the two-different-non-empty-CRS case, in `test_vrt_writer_source_compat_1733.py`.


# Compute the bounding box of all sources
all_x0, all_y0, all_x1, all_y1 = [], [], [], []
for m in sources_meta:
Expand Down
188 changes: 188 additions & 0 deletions xrspatial/geotiff/tests/test_vrt_writer_source_compat_1733.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
"""Regression tests for issue #1733.

``write_vrt`` previously trusted the first source for resolution,
sample format + bps (dtype), band count, and CRS. A mismatched source
would silently produce a VRT that placed pixels incorrectly or
re-interpreted bytes as the wrong dtype downstream.

These tests assert that ``write_vrt`` now rejects mismatched sources
with a clear ``ValueError`` covering each of those properties, and
still accepts sources that match within a small float tolerance on
pixel size.
"""
from __future__ import annotations

import os
import uuid

import numpy as np
import pytest
import xarray as xr

from xrspatial.geotiff import to_geotiff
from xrspatial.geotiff._vrt import write_vrt


def _unique_dir(tmp_path, label: str) -> str:
    """Create a fresh, uniquely named sub-directory and return its path.

    The directory is named ``vrt_1733_<label>_<8 hex chars>`` and is
    created directly under ``tmp_path``. The result is returned as a
    plain string rather than a path object.
    """
    suffix = uuid.uuid4().hex[:8]
    target = tmp_path / f"vrt_1733_{label}_{suffix}"
    target.mkdir()
    return str(target)


def _write_tif(path: str, *, h: int, w: int, dtype, bands: int = 1,
               px: float = 1.0, py: float = -1.0,
               origin_x: float = 0.0, origin_y: float = 100.0,
               crs: int | None = 4326) -> None:
    """Write a small synthetic GeoTIFF fixture to ``path``.

    The raster is filled with consecutive integers cast to ``dtype``.
    With ``bands == 1`` the array is 2-D ``(y, x)``; otherwise it is
    3-D ``(y, x, band)``. Coordinates are offset by half a pixel from
    ``origin_x`` / ``origin_y`` and step by ``px`` / ``py``. When
    ``crs`` is ``None`` no ``crs`` attribute is attached.
    """
    single_band = bands == 1
    shape = (h, w) if single_band else (h, w, bands)
    dims = ['y', 'x'] if single_band else ['y', 'x', 'band']
    count = h * w if single_band else h * w * bands
    pixels = np.arange(count, dtype=dtype).reshape(shape)

    coords = {
        'y': origin_y + (np.arange(h) + 0.5) * py,
        'x': origin_x + (np.arange(w) + 0.5) * px,
    }
    attrs = {} if crs is None else {'crs': crs}

    to_geotiff(
        xr.DataArray(pixels, dims=dims, coords=coords, attrs=attrs),
        path,
        compression='none',
    )


def test_mismatched_pixel_size_raises(tmp_path):
# Regression check: tile ``a`` uses a 1.0 / -1.0 pixel size and tile ``b``
# uses 2.0 / -2.0, so write_vrt must raise a ValueError whose message
# mentions "pixel size".
d = _unique_dir(tmp_path, "px")
a = os.path.join(d, "a.tif")
b = os.path.join(d, "b.tif")
_write_tif(a, h=4, w=4, dtype=np.float32, px=1.0, py=-1.0)
Comment on lines +52 to +56
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in ff56e14: added five new test cases in test_vrt_writer_source_compat_1733.py covering the CRS code path. test_mismatched_crs_raises checks two different non-empty CRS values, test_asymmetric_crs_raises_first_set_second_missing and test_asymmetric_crs_raises_first_missing_second_set cover both directions of the asymmetric case, and test_matching_crs_succeeds plus test_both_missing_crs_succeeds defend against an overly aggressive check. All 11 tests in the file pass.

# Place b adjacent so the geometry would otherwise work, but the
# pixel size disagrees.
_write_tif(b, h=4, w=4, dtype=np.float32, px=2.0, py=-2.0,
origin_x=4.0)
vrt = os.path.join(d, "out.vrt")
with pytest.raises(ValueError, match="pixel size"):
write_vrt(vrt, [a, b])


def test_mismatched_dtype_raises(tmp_path):
    """A float32 tile paired with an int16 tile must be rejected."""
    workdir = _unique_dir(tmp_path, "dtype")
    first, second, out_vrt = (
        os.path.join(workdir, name)
        for name in ("a.tif", "b.tif", "out.vrt")
    )
    _write_tif(first, h=4, w=4, dtype=np.float32)
    _write_tif(second, h=4, w=4, dtype=np.int16, origin_x=4.0)
    # Any of the three keywords in the message is accepted.
    with pytest.raises(ValueError, match="dtype|sample_format|bps"):
        write_vrt(out_vrt, [first, second])


def test_mismatched_band_count_raises(tmp_path):
    """A 1-band tile paired with a 3-band tile must be rejected."""
    workdir = _unique_dir(tmp_path, "bands")
    first, second, out_vrt = (
        os.path.join(workdir, name)
        for name in ("a.tif", "b.tif", "out.vrt")
    )
    _write_tif(first, h=4, w=4, dtype=np.float32, bands=1)
    _write_tif(second, h=4, w=4, dtype=np.float32, bands=3, origin_x=4.0)
    with pytest.raises(ValueError, match="band count"):
        write_vrt(out_vrt, [first, second])


def test_compatible_sources_succeed(tmp_path):
    """Two tiles written with identical settings produce a VRT file."""
    workdir = _unique_dir(tmp_path, "ok")
    first, second, out_vrt = (
        os.path.join(workdir, name)
        for name in ("a.tif", "b.tif", "out.vrt")
    )
    _write_tif(first, h=4, w=4, dtype=np.float32)
    _write_tif(second, h=4, w=4, dtype=np.float32, origin_x=4.0)
    write_vrt(out_vrt, [first, second])
    assert os.path.exists(out_vrt)


def test_pixel_size_within_tolerance_accepted(tmp_path):
    """A tiny pixel-width drift between tiles must not be rejected."""
    workdir = _unique_dir(tmp_path, "tol")
    first, second, out_vrt = (
        os.path.join(workdir, name)
        for name in ("a.tif", "b.tif", "out.vrt")
    )
    _write_tif(first, h=4, w=4, dtype=np.float32, px=1.0, py=-1.0)
    # The drift is far smaller than the 1e-6 relative tolerance.
    drifted_px = 1.0 + 1e-10
    _write_tif(second, h=4, w=4, dtype=np.float32,
               px=drifted_px, py=-1.0, origin_x=4.0)
    write_vrt(out_vrt, [first, second])
    assert os.path.exists(out_vrt)


def test_single_source_still_works(tmp_path):
    """With a single source there is nothing to compare; a VRT is written."""
    workdir = _unique_dir(tmp_path, "one")
    only_tile = os.path.join(workdir, "a.tif")
    out_vrt = os.path.join(workdir, "out.vrt")
    _write_tif(only_tile, h=4, w=4, dtype=np.float32)
    write_vrt(out_vrt, [only_tile])
    assert os.path.exists(out_vrt)


def test_mismatched_crs_raises(tmp_path):
    """Tiles carrying two different CRS values (4326 vs 3857) are rejected.

    Without the check the VRT would take the first tile's CRS and
    silently misproject the second.
    """
    workdir = _unique_dir(tmp_path, "crs_diff")
    first, second, out_vrt = (
        os.path.join(workdir, name)
        for name in ("a.tif", "b.tif", "out.vrt")
    )
    _write_tif(first, h=4, w=4, dtype=np.float32, crs=4326)
    _write_tif(second, h=4, w=4, dtype=np.float32, origin_x=4.0, crs=3857)
    with pytest.raises(ValueError, match="CRS"):
        write_vrt(out_vrt, [first, second])


def test_asymmetric_crs_raises_first_set_second_missing(tmp_path):
    """First tile has a CRS, second has none: must be rejected.

    Otherwise the VRT would be tagged with the first tile's CRS even
    though the second tile's projection is unknown.
    """
    workdir = _unique_dir(tmp_path, "crs_first")
    first, second, out_vrt = (
        os.path.join(workdir, name)
        for name in ("a.tif", "b.tif", "out.vrt")
    )
    _write_tif(first, h=4, w=4, dtype=np.float32, crs=4326)
    _write_tif(second, h=4, w=4, dtype=np.float32, origin_x=4.0, crs=None)
    with pytest.raises(ValueError, match="CRS"):
        write_vrt(out_vrt, [first, second])


def test_asymmetric_crs_raises_first_missing_second_set(tmp_path):
    """Mirror case: first tile has no CRS, second has one: must be rejected.

    A guard that fires only when both sides are set would let this
    through and emit an untagged VRT despite one tile carrying a
    known projection.
    """
    workdir = _unique_dir(tmp_path, "crs_second")
    first, second, out_vrt = (
        os.path.join(workdir, name)
        for name in ("a.tif", "b.tif", "out.vrt")
    )
    _write_tif(first, h=4, w=4, dtype=np.float32, crs=None)
    _write_tif(second, h=4, w=4, dtype=np.float32, origin_x=4.0, crs=4326)
    with pytest.raises(ValueError, match="CRS"):
        write_vrt(out_vrt, [first, second])


def test_matching_crs_succeeds(tmp_path):
    """Two tiles with the same CRS (4326) are accepted.

    Guards against an equality check that is too aggressive.
    """
    workdir = _unique_dir(tmp_path, "crs_match")
    first, second, out_vrt = (
        os.path.join(workdir, name)
        for name in ("a.tif", "b.tif", "out.vrt")
    )
    _write_tif(first, h=4, w=4, dtype=np.float32, crs=4326)
    _write_tif(second, h=4, w=4, dtype=np.float32, origin_x=4.0, crs=4326)
    write_vrt(out_vrt, [first, second])
    assert os.path.exists(out_vrt)


def test_both_missing_crs_succeeds(tmp_path):
    """Two tiles that both lack a CRS are accepted.

    With no CRS on either side there is nothing to mis-tag, so the
    call must not raise.
    """
    workdir = _unique_dir(tmp_path, "crs_both_missing")
    first, second, out_vrt = (
        os.path.join(workdir, name)
        for name in ("a.tif", "b.tif", "out.vrt")
    )
    _write_tif(first, h=4, w=4, dtype=np.float32, crs=None)
    _write_tif(second, h=4, w=4, dtype=np.float32, origin_x=4.0, crs=None)
    write_vrt(out_vrt, [first, second])
    assert os.path.exists(out_vrt)
Loading