Skip to content

Commit 4a3791c

Browse files
committed
Fix 8 remaining gaps for production readiness
1. Band-first DataArray (CRITICAL): write_geotiff now detects (band, y, x) dimension order and transposes to (y, x, band). Prevents silent data corruption from rasterio-style arrays. 2. HTTP COG sub-byte support (CRITICAL): the COG HTTP reader now routes through _decode_strip_or_tile like the local readers, so 1-bit/4-bit/12-bit COGs over HTTP work correctly. 3. Dask VRT support (USEFUL): read_geotiff_dask detects .vrt files and reads eagerly then chunks, since VRT windowed reads need the virtual dataset's source layout. 4. VRT writer (USEFUL): write_vrt() generates a VRT XML file from multiple source GeoTIFFs, computing the mosaic layout from their geo transforms. Supports relative paths and CRS/nodata. 5. ExtraSamples tag (USEFUL): RGBA writes now include tag 338 with value 2 (unassociated alpha). Multi-band with >3 bands also gets ExtraSamples for bands beyond RGB. 6. MinIsWhite (USEFUL): photometric=0 (MinIsWhite) single-band files are now inverted on read so 0=black, 255=white. Integer values are inverted via max-value, floats via negation. 7. Post-write validation (POLISH): after writing, the header bytes are parsed to verify the output is a valid TIFF. Emits a warning if the header is corrupt. 8. Float16/bool auto-promotion (POLISH): float16 arrays are promoted to float32, bool arrays to uint8, instead of raising ValueError. 275 tests passing. 7 new tests for the fixes plus updated edge case tests.
1 parent af14140 commit 4a3791c

File tree

7 files changed

+425
-22
lines changed

7 files changed

+425
-22
lines changed

xrspatial/geotiff/__init__.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from ._writer import write
2222

2323
__all__ = ['read_geotiff', 'write_geotiff', 'open_cog', 'read_geotiff_dask',
24-
'read_vrt']
24+
'read_vrt', 'write_vrt']
2525

2626

2727
def _wkt_to_epsg(wkt_or_proj: str) -> int | None:
@@ -305,6 +305,9 @@ def write_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
305305

306306
if isinstance(data, xr.DataArray):
307307
arr = data.values
308+
# Handle band-first dimension order (band, y, x) -> (y, x, band)
309+
if arr.ndim == 3 and data.dims[0] in ('band', 'bands', 'channel'):
310+
arr = np.moveaxis(arr, 0, -1)
308311
if geo_transform is None:
309312
geo_transform = _coords_to_transform(data)
310313
if epsg is None and crs is None:
@@ -340,6 +343,12 @@ def write_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
340343
if arr.ndim not in (2, 3):
341344
raise ValueError(f"Expected 2D or 3D array, got {arr.ndim}D")
342345

346+
# Auto-promote unsupported dtypes
347+
if arr.dtype == np.float16:
348+
arr = arr.astype(np.float32)
349+
elif arr.dtype == np.bool_:
350+
arr = arr.astype(np.uint8)
351+
343352
write(
344353
arr, path,
345354
geo_transform=geo_transform,
@@ -407,6 +416,13 @@ def read_geotiff_dask(source: str, *, chunks: int | tuple = 512,
407416
"""
408417
import dask.array as da
409418

419+
# VRT files: read eagerly (VRT mosaic isn't compatible with per-chunk
420+
# windowed reads on the virtual dataset without a separate code path)
421+
if source.lower().endswith('.vrt'):
422+
da_eager = read_vrt(source, name=name)
423+
return da_eager.chunk({'y': chunks if isinstance(chunks, int) else chunks[0],
424+
'x': chunks if isinstance(chunks, int) else chunks[1]})
425+
410426
# First, do a metadata-only read to get shape, dtype, coords, attrs
411427
arr, geo_info = read_to_array(source, overview_level=overview_level)
412428
full_h, full_w = arr.shape[:2]
@@ -566,6 +582,27 @@ def read_vrt(source: str, *, window=None,
566582
return xr.DataArray(arr, dims=dims, coords=coords, name=name, attrs=attrs)
567583

568584

585+
def write_vrt(vrt_path: str, source_files: list[str], **kwargs) -> str:
586+
"""Generate a VRT file that mosaics multiple GeoTIFF tiles.
587+
588+
Parameters
589+
----------
590+
vrt_path : str
591+
Output .vrt file path.
592+
source_files : list of str
593+
Paths to the source GeoTIFF files.
594+
**kwargs
595+
relative, crs_wkt, nodata -- see _vrt.write_vrt.
596+
597+
Returns
598+
-------
599+
str
600+
Path to the written VRT file.
601+
"""
602+
from ._vrt import write_vrt as _write_vrt_internal
603+
return _write_vrt_internal(vrt_path, source_files, **kwargs)
604+
605+
569606
def plot_geotiff(da: xr.DataArray, **kwargs):
570607
"""Plot a DataArray using its embedded colormap if present.
571608

xrspatial/geotiff/_header.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
TAG_TILE_OFFSETS = 324
3535
TAG_TILE_BYTE_COUNTS = 325
3636
TAG_COLORMAP = 320
37+
TAG_EXTRA_SAMPLES = 338
3738
TAG_SAMPLE_FORMAT = 339
3839
TAG_GDAL_METADATA = 42112
3940
TAG_GDAL_NODATA = 42113

xrspatial/geotiff/_reader.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,7 @@ def _read_cog_http(url: str, overview_level: int | None = None,
585585
compression = ifd.compression
586586
pred = ifd.predictor
587587
bytes_per_sample = bps // 8
588+
is_sub_byte = bps in SUB_BYTE_BPS
588589

589590
offsets = ifd.tile_offsets
590591
byte_counts = ifd.tile_byte_counts
@@ -609,22 +610,10 @@ def _read_cog_http(url: str, overview_level: int | None = None,
609610
continue
610611

611612
tile_data = source.read_range(off, bc)
612-
expected = tw * th * samples * bytes_per_sample
613-
chunk = decompress(tile_data, compression, expected,
614-
width=tw, height=th, samples=samples)
615-
616-
if pred in (2, 3):
617-
if not chunk.flags.writeable:
618-
chunk = chunk.copy()
619-
chunk = _apply_predictor(chunk, pred, tw, th, bytes_per_sample * samples)
620-
621-
file_dtype = dtype.newbyteorder(header.byte_order)
622-
if samples > 1:
623-
tile_pixels = chunk.view(file_dtype).reshape(th, tw, samples)
624-
else:
625-
tile_pixels = chunk.view(file_dtype).reshape(th, tw)
626-
if file_dtype.byteorder not in ('=', '|', _NATIVE_ORDER):
627-
tile_pixels = tile_pixels.astype(dtype)
613+
tile_pixels = _decode_strip_or_tile(
614+
tile_data, compression, tw, th, samples,
615+
bps, bytes_per_sample, is_sub_byte, dtype, pred,
616+
byte_order=header.byte_order)
628617

629618
# Place tile
630619
y0 = tr * th
@@ -699,6 +688,13 @@ def read_to_array(source: str, *, window=None, overview_level: int | None = None
699688
# For multi-band with band selection, extract single band
700689
if arr.ndim == 3 and ifd.samples_per_pixel > 1 and band is not None:
701690
arr = arr[:, :, band]
691+
692+
# MinIsWhite (photometric=0): invert single-band grayscale values
693+
if ifd.photometric == 0 and ifd.samples_per_pixel == 1:
694+
if arr.dtype.kind == 'u':
695+
arr = np.iinfo(arr.dtype).max - arr
696+
elif arr.dtype.kind == 'f':
697+
arr = -arr
702698
finally:
703699
src.close()
704700

xrspatial/geotiff/_vrt.py

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,3 +316,166 @@ def read_vrt(vrt_path: str, *, window=None,
316316
band_idx] = src_arr[:actual_h, :actual_w]
317317

318318
return result, vrt
319+
320+
321+
# ---------------------------------------------------------------------------
322+
# VRT writer
323+
# ---------------------------------------------------------------------------
324+
325+
_NP_TO_VRT_DTYPE = {v: k for k, v in _DTYPE_MAP.items()}
326+
327+
328+
def write_vrt(vrt_path: str, source_files: list[str], *,
329+
relative: bool = True,
330+
crs_wkt: str | None = None,
331+
nodata: float | None = None) -> str:
332+
"""Generate a VRT file that mosaics multiple GeoTIFF tiles.
333+
334+
Each source file is placed in the virtual raster based on its
335+
geo transform. Files must share the same CRS and pixel size.
336+
337+
Parameters
338+
----------
339+
vrt_path : str
340+
Output .vrt file path.
341+
source_files : list of str
342+
Paths to the source GeoTIFF files.
343+
relative : bool
344+
Store source paths relative to the VRT file.
345+
crs_wkt : str or None
346+
CRS as WKT string. If None, taken from the first source.
347+
nodata : float or None
348+
NoData value. If None, taken from the first source.
349+
350+
Returns
351+
-------
352+
str
353+
Path to the written VRT file.
354+
"""
355+
from ._reader import read_to_array
356+
from ._header import parse_header, parse_all_ifds
357+
from ._geotags import extract_geo_info
358+
from ._reader import _FileSource
359+
360+
if not source_files:
361+
raise ValueError("source_files must not be empty")
362+
363+
# Read metadata from all sources
364+
sources_meta = []
365+
for src_path in source_files:
366+
src = _FileSource(src_path)
367+
data = src.read_all()
368+
header = parse_header(data)
369+
ifds = parse_all_ifds(data, header)
370+
ifd = ifds[0]
371+
geo = extract_geo_info(ifd, data, header.byte_order)
372+
src.close()
373+
374+
bps = ifd.bits_per_sample
375+
if isinstance(bps, tuple):
376+
bps = bps[0]
377+
378+
sources_meta.append({
379+
'path': src_path,
380+
'width': ifd.width,
381+
'height': ifd.height,
382+
'bands': ifd.samples_per_pixel,
383+
'dtype': np.dtype(_DTYPE_MAP.get(
384+
{v: k for k, v in _DTYPE_MAP.items()}.get(
385+
np.dtype(f'{"f" if ifd.sample_format == 3 else ("i" if ifd.sample_format == 2 else "u")}{bps // 8}').type,
386+
'Float32'),
387+
np.float32)),
388+
'bps': bps,
389+
'sample_format': ifd.sample_format,
390+
'transform': geo.transform,
391+
'crs_wkt': geo.crs_wkt,
392+
'nodata': geo.nodata,
393+
})
394+
395+
first = sources_meta[0]
396+
res_x = first['transform'].pixel_width
397+
res_y = first['transform'].pixel_height
398+
399+
# Compute the bounding box of all sources
400+
all_x0, all_y0, all_x1, all_y1 = [], [], [], []
401+
for m in sources_meta:
402+
t = m['transform']
403+
x0 = t.origin_x
404+
y0 = t.origin_y
405+
x1 = x0 + m['width'] * t.pixel_width
406+
y1 = y0 + m['height'] * t.pixel_height
407+
all_x0.append(min(x0, x1))
408+
all_y0.append(min(y0, y1))
409+
all_x1.append(max(x0, x1))
410+
all_y1.append(max(y0, y1))
411+
412+
mosaic_x0 = min(all_x0)
413+
mosaic_y_top = max(all_y1) # top edge (y increases upward in geo)
414+
mosaic_x1 = max(all_x1)
415+
mosaic_y_bottom = min(all_y0)
416+
417+
total_w = int(round((mosaic_x1 - mosaic_x0) / abs(res_x)))
418+
total_h = int(round((mosaic_y_top - mosaic_y_bottom) / abs(res_y)))
419+
420+
# Determine VRT dtype
421+
sf = first['sample_format']
422+
bps = first['bps']
423+
if sf == 3:
424+
vrt_dtype_name = 'Float64' if bps == 64 else 'Float32'
425+
elif sf == 2:
426+
vrt_dtype_name = {8: 'Int8', 16: 'Int16', 32: 'Int32'}.get(bps, 'Int32')
427+
else:
428+
vrt_dtype_name = {8: 'Byte', 16: 'UInt16', 32: 'UInt32'}.get(bps, 'Byte')
429+
430+
srs = crs_wkt or first.get('crs_wkt') or ''
431+
nd = nodata if nodata is not None else first.get('nodata')
432+
433+
vrt_dir = os.path.dirname(os.path.abspath(vrt_path))
434+
n_bands = first['bands']
435+
436+
# Build XML
437+
lines = [f'<VRTDataset rasterXSize="{total_w}" rasterYSize="{total_h}">']
438+
if srs:
439+
lines.append(f' <SRS>{srs}</SRS>')
440+
lines.append(f' <GeoTransform>{mosaic_x0}, {res_x}, 0.0, '
441+
f'{mosaic_y_top}, 0.0, {res_y}</GeoTransform>')
442+
443+
for band_num in range(1, n_bands + 1):
444+
lines.append(f' <VRTRasterBand dataType="{vrt_dtype_name}" band="{band_num}">')
445+
if nd is not None:
446+
lines.append(f' <NoDataValue>{nd}</NoDataValue>')
447+
448+
for m in sources_meta:
449+
t = m['transform']
450+
# Pixel offset in the virtual raster
451+
dst_x_off = int(round((t.origin_x - mosaic_x0) / abs(res_x)))
452+
dst_y_off = int(round((mosaic_y_top - t.origin_y) / abs(res_y)))
453+
454+
fname = m['path']
455+
rel_attr = '0'
456+
if relative:
457+
try:
458+
fname = os.path.relpath(fname, vrt_dir)
459+
rel_attr = '1'
460+
except ValueError:
461+
pass # different drives on Windows
462+
463+
lines.append(' <SimpleSource>')
464+
lines.append(f' <SourceFilename relativeToVRT="{rel_attr}">'
465+
f'{fname}</SourceFilename>')
466+
lines.append(f' <SourceBand>{band_num}</SourceBand>')
467+
lines.append(f' <SrcRect xOff="0" yOff="0" '
468+
f'xSize="{m["width"]}" ySize="{m["height"]}"/>')
469+
lines.append(f' <DstRect xOff="{dst_x_off}" yOff="{dst_y_off}" '
470+
f'xSize="{m["width"]}" ySize="{m["height"]}"/>')
471+
lines.append(' </SimpleSource>')
472+
473+
lines.append(' </VRTRasterBand>')
474+
475+
lines.append('</VRTDataset>')
476+
477+
xml = '\n'.join(lines) + '\n'
478+
with open(vrt_path, 'w') as f:
479+
f.write(xml)
480+
481+
return vrt_path

xrspatial/geotiff/_writer.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
TAG_TILE_LENGTH,
5151
TAG_TILE_OFFSETS,
5252
TAG_TILE_BYTE_COUNTS,
53+
TAG_EXTRA_SAMPLES,
5354
TAG_PREDICTOR,
5455
TAG_GDAL_METADATA,
5556
)
@@ -483,6 +484,18 @@ def _assemble_tiff(width: int, height: int, dtype: np.dtype,
483484
else:
484485
tags.append((TAG_SAMPLE_FORMAT, SHORT, 1, sample_format))
485486

487+
# ExtraSamples: for bands beyond what Photometric accounts for
488+
# Photometric=2 (RGB) accounts for 3 bands; any extra are alpha/other
489+
if photometric == 2 and samples_per_pixel > 3:
490+
n_extra = samples_per_pixel - 3
491+
# 2 = unassociated alpha for the first extra, 0 = unspecified for rest
492+
extra_vals = [2] + [0] * (n_extra - 1)
493+
tags.append((TAG_EXTRA_SAMPLES, SHORT, n_extra, extra_vals))
494+
elif photometric == 1 and samples_per_pixel > 1:
495+
n_extra = samples_per_pixel - 1
496+
extra_vals = [0] * n_extra # unspecified
497+
tags.append((TAG_EXTRA_SAMPLES, SHORT, n_extra, extra_vals))
498+
486499
if pred_val != 1:
487500
tags.append((TAG_PREDICTOR, SHORT, 1, pred_val))
488501

@@ -814,6 +827,14 @@ def write(data: np.ndarray, path: str, *,
814827

815828
_write_bytes(file_bytes, path)
816829

830+
# Post-write validation: verify the header is parseable
831+
from ._header import parse_header as _ph
832+
try:
833+
_ph(file_bytes[:16])
834+
except Exception as e:
835+
import warnings
836+
warnings.warn(f"Written file may be corrupt: {e}", stacklevel=2)
837+
817838

818839
def _is_fsspec_uri(path: str) -> bool:
819840
"""Check if a path is a fsspec-compatible URI."""

xrspatial/geotiff/tests/test_edge_cases.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,14 @@ def test_complex_dtype(self, tmp_path):
5757
with pytest.raises(ValueError, match="Unsupported numpy dtype"):
5858
write_geotiff(arr, str(tmp_path / 'bad.tif'))
5959

60-
def test_bool_dtype(self, tmp_path):
61-
arr = np.ones((4, 4), dtype=bool)
62-
with pytest.raises(ValueError, match="Unsupported numpy dtype"):
63-
write_geotiff(arr, str(tmp_path / 'bad.tif'))
60+
def test_bool_dtype_auto_promoted(self, tmp_path):
61+
"""Bool arrays are auto-promoted to uint8."""
62+
arr = np.array([[True, False], [False, True]])
63+
path = str(tmp_path / 'bool.tif')
64+
write_geotiff(arr, path, compression='none')
65+
66+
result = read_geotiff(path)
67+
np.testing.assert_array_equal(result.values, arr.astype(np.uint8))
6468

6569

6670
# -----------------------------------------------------------------------

0 commit comments

Comments
 (0)