Skip to content

Commit 9f8fb38

Browse files
committed
Rename GeoTIFF API to xarray conventions (#1047)
open_geotiff replaces read_geotiff, to_geotiff replaces write_geotiff. Adds .xrs.to_geotiff() accessor on DataArray and Dataset, and .xrs.open_geotiff() on Dataset for spatially-windowed reads.
1 parent 66fc110 commit 9f8fb38

File tree

6 files changed

+289
-165
lines changed

6 files changed

+289
-165
lines changed

xrspatial/accessor.py

Lines changed: 86 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,14 @@ def __init__(self, obj):
2626
def plot(self, **kwargs):
2727
"""Plot the DataArray, using an embedded TIFF colormap if present.
2828
29-
For palette/indexed-color GeoTIFFs (read via ``read_geotiff``),
29+
For palette/indexed-color GeoTIFFs (read via ``open_geotiff``),
3030
the TIFF's color table is applied automatically with correct
3131
normalization. For all other DataArrays, falls through to the
3232
standard ``da.plot()``.
3333
3434
Usage::
3535
36-
da = read_geotiff('landcover.tif')
36+
da = open_geotiff('landcover.tif')
3737
da.xrs.plot() # palette colors used automatically
3838
"""
3939
import numpy as np
@@ -460,6 +460,18 @@ def rasterize(self, geometries, **kwargs):
460460
from .rasterize import rasterize
461461
return rasterize(geometries, like=self._obj, **kwargs)
462462

463+
# ---- GeoTIFF I/O ----
464+
465+
def to_geotiff(self, path, **kwargs):
466+
"""Write this DataArray as a GeoTIFF.
467+
468+
Equivalent to ``to_geotiff(da, path, **kwargs)``.
469+
470+
See :func:`xrspatial.geotiff.to_geotiff` for full parameter docs.
471+
"""
472+
from .geotiff import to_geotiff
473+
return to_geotiff(self._obj, path, **kwargs)
474+
463475

464476
@xr.register_dataset_accessor("xrs")
465477
class XrsSpatialDatasetAccessor:
@@ -776,3 +788,75 @@ def rasterize(self, geometries, **kwargs):
776788
"Dataset has no 2D variable with 'y' and 'x' dimensions "
777789
"to use as rasterize template"
778790
)
791+
792+
# ---- GeoTIFF I/O ----
793+
794+
def to_geotiff(self, path, var=None, **kwargs):
795+
"""Write a Dataset variable as a GeoTIFF.
796+
797+
Parameters
798+
----------
799+
path : str
800+
Output file path.
801+
var : str or None
802+
Variable name to write. If None, uses the first 2D variable
803+
with y/x dimensions.
804+
**kwargs
805+
Passed to :func:`xrspatial.geotiff.to_geotiff`.
806+
"""
807+
from .geotiff import to_geotiff
808+
ds = self._obj
809+
if var is not None:
810+
return to_geotiff(ds[var], path, **kwargs)
811+
for v in ds.data_vars:
812+
da = ds[v]
813+
if da.ndim >= 2 and 'y' in da.dims and 'x' in da.dims:
814+
return to_geotiff(da, path, **kwargs)
815+
raise ValueError(
816+
"Dataset has no variable with 'y' and 'x' dimensions to write"
817+
)
818+
819+
def open_geotiff(self, source, **kwargs):
820+
"""Read a GeoTIFF windowed to this Dataset's spatial extent.
821+
822+
Uses the Dataset's y/x coordinates to compute a pixel window,
823+
then reads only that region from the file.
824+
825+
Parameters
826+
----------
827+
source : str
828+
File path to the GeoTIFF.
829+
**kwargs
830+
Passed to :func:`xrspatial.geotiff.open_geotiff` (except
831+
``window``, which is computed automatically).
832+
833+
Returns
834+
-------
835+
xr.DataArray
836+
The windowed portion of the GeoTIFF.
837+
"""
838+
from .geotiff import open_geotiff, _read_geo_info, _extent_to_window
839+
ds = self._obj
840+
if 'y' not in ds.coords or 'x' not in ds.coords:
841+
raise ValueError(
842+
"Dataset must have 'y' and 'x' coordinates to compute "
843+
"a spatial window"
844+
)
845+
y = ds.coords['y'].values
846+
x = ds.coords['x'].values
847+
y_min, y_max = float(y.min()), float(y.max())
848+
x_min, x_max = float(x.min()), float(x.max())
849+
850+
geo_info, file_h, file_w = _read_geo_info(source)
851+
t = geo_info.transform
852+
853+
# Expand extent by half a pixel so we capture edge pixels
854+
y_min -= abs(t.pixel_height) * 0.5
855+
y_max += abs(t.pixel_height) * 0.5
856+
x_min -= abs(t.pixel_width) * 0.5
857+
x_max += abs(t.pixel_width) * 0.5
858+
859+
window = _extent_to_window(t, file_h, file_w,
860+
y_min, y_max, x_min, x_max)
861+
kwargs.pop('window', None)
862+
return open_geotiff(source, window=window, **kwargs)

xrspatial/geotiff/__init__.py

Lines changed: 66 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
55
Public API
66
----------
7-
read_geotiff(source, ...)
7+
open_geotiff(source, ...)
88
Read a GeoTIFF file to an xarray.DataArray.
9-
write_geotiff(data, path, ...)
9+
to_geotiff(data, path, ...)
1010
Write an xarray.DataArray as a GeoTIFF or COG.
11-
open_cog(url, ...)
12-
Read a Cloud Optimized GeoTIFF from an HTTP URL.
11+
write_vrt(vrt_path, source_files, ...)
12+
Generate a VRT mosaic XML from a list of GeoTIFF files.
1313
"""
1414
from __future__ import annotations
1515

@@ -20,7 +20,7 @@
2020
from ._reader import read_to_array
2121
from ._writer import write
2222

23-
__all__ = ['read_geotiff', 'write_geotiff', 'write_vrt']
23+
__all__ = ['open_geotiff', 'to_geotiff', 'write_vrt']
2424

2525

2626
def _wkt_to_epsg(wkt_or_proj: str) -> int | None:
@@ -98,7 +98,55 @@ def _coords_to_transform(da: xr.DataArray) -> GeoTransform | None:
9898
)
9999

100100

101-
def read_geotiff(source: str, *, window=None,
101+
def _read_geo_info(source: str):
102+
"""Read only the geographic metadata and image dimensions from a GeoTIFF.
103+
104+
Returns (geo_info, height, width) without reading pixel data.
105+
"""
106+
from ._geotags import extract_geo_info
107+
from ._header import parse_all_ifds, parse_header
108+
109+
with open(source, 'rb') as f:
110+
import mmap
111+
data = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
112+
try:
113+
header = parse_header(data)
114+
ifds = parse_all_ifds(data, header)
115+
ifd = ifds[0]
116+
geo_info = extract_geo_info(ifd, data, header.byte_order)
117+
return geo_info, ifd.height, ifd.width
118+
finally:
119+
data.close()
120+
121+
122+
def _extent_to_window(transform, file_height, file_width,
123+
y_min, y_max, x_min, x_max):
124+
"""Convert geographic extent to pixel window (row_start, col_start, row_stop, col_stop).
125+
126+
Clamps to file bounds.
127+
"""
128+
# Pixel coords from geographic coords
129+
col_start = (x_min - transform.origin_x) / transform.pixel_width
130+
col_stop = (x_max - transform.origin_x) / transform.pixel_width
131+
132+
row_start = (y_max - transform.origin_y) / transform.pixel_height
133+
row_stop = (y_min - transform.origin_y) / transform.pixel_height
134+
135+
# pixel_height is typically negative, so row_start/row_stop may be swapped
136+
if row_start > row_stop:
137+
row_start, row_stop = row_stop, row_start
138+
if col_start > col_stop:
139+
col_start, col_stop = col_stop, col_start
140+
141+
row_start = max(0, int(np.floor(row_start)))
142+
col_start = max(0, int(np.floor(col_start)))
143+
row_stop = min(file_height, int(np.ceil(row_stop)))
144+
col_stop = min(file_width, int(np.ceil(col_stop)))
145+
146+
return (row_start, col_start, row_stop, col_stop)
147+
148+
149+
def open_geotiff(source: str, *, window=None,
102150
overview_level: int | None = None,
103151
band: int | None = None,
104152
name: str | None = None,
@@ -285,18 +333,18 @@ def _is_gpu_data(data) -> bool:
285333
return isinstance(data, _cupy_type)
286334

287335

288-
def write_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
289-
crs: int | str | None = None,
290-
nodata=None,
291-
compression: str = 'deflate',
292-
tiled: bool = True,
293-
tile_size: int = 256,
294-
predictor: bool = False,
295-
cog: bool = False,
296-
overview_levels: list[int] | None = None,
297-
overview_resampling: str = 'mean',
298-
bigtiff: bool | None = None,
299-
gpu: bool | None = None) -> None:
336+
def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
337+
crs: int | str | None = None,
338+
nodata=None,
339+
compression: str = 'deflate',
340+
tiled: bool = True,
341+
tile_size: int = 256,
342+
predictor: bool = False,
343+
cog: bool = False,
344+
overview_levels: list[int] | None = None,
345+
overview_resampling: str = 'mean',
346+
bigtiff: bool | None = None,
347+
gpu: bool | None = None) -> None:
300348
"""Write data as a GeoTIFF or Cloud Optimized GeoTIFF.
301349
302350
Automatically dispatches to GPU compression when:
@@ -442,14 +490,6 @@ def write_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
442490
)
443491

444492

445-
def open_cog(url: str, **kwargs) -> xr.DataArray:
446-
"""Deprecated: use ``read_geotiff(url, ...)`` instead.
447-
448-
read_geotiff handles HTTP URLs, cloud URIs, and local files.
449-
"""
450-
return read_geotiff(url, **kwargs)
451-
452-
453493
def read_geotiff_dask(source: str, *, chunks: int | tuple = 512,
454494
overview_level: int | None = None,
455495
name: str | None = None) -> xr.DataArray:

xrspatial/geotiff/tests/bench_vs_rioxarray.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,12 @@ def _fmt_ms(seconds):
3636
def check_consistency(path):
3737
"""Compare pixel values and geo metadata between the two readers."""
3838
import rioxarray # noqa: F401
39-
from xrspatial.geotiff import read_geotiff
39+
from xrspatial.geotiff import open_geotiff
4040

4141
rio_da = xr.open_dataarray(path, engine='rasterio')
4242
rio_arr = rio_da.squeeze('band').values.astype(np.float64)
4343

44-
our_da = read_geotiff(path)
44+
our_da = open_geotiff(path)
4545
our_arr = our_da.values.astype(np.float64)
4646

4747
# Shape
@@ -96,7 +96,7 @@ def check_consistency(path):
9696
def bench_read(path, runs=10):
9797
"""Benchmark read performance."""
9898
import rioxarray # noqa: F401
99-
from xrspatial.geotiff import read_geotiff
99+
from xrspatial.geotiff import open_geotiff
100100

101101
def rio_read():
102102
da = xr.open_dataarray(path, engine='rasterio')
@@ -105,7 +105,7 @@ def rio_read():
105105
return da
106106

107107
def our_read():
108-
return read_geotiff(path)
108+
return open_geotiff(path)
109109

110110
rio_time, _ = _timer(rio_read, warmup=2, runs=runs)
111111
our_time, _ = _timer(our_read, warmup=2, runs=runs)
@@ -120,7 +120,7 @@ def our_read():
120120
def bench_write(shape=(512, 512), compression='deflate', runs=5):
121121
"""Benchmark write performance."""
122122
import rioxarray # noqa: F401
123-
from xrspatial.geotiff import write_geotiff
123+
from xrspatial.geotiff import to_geotiff
124124
from xrspatial.geotiff._geotags import GeoTransform
125125

126126
rng = np.random.RandomState(42)
@@ -150,7 +150,7 @@ def rio_write():
150150

151151
def our_write():
152152
p = os.path.join(tmpdir, 'our_out.tif')
153-
write_geotiff(da_ours, p, compression=compression, tiled=False)
153+
to_geotiff(da_ours, p, compression=compression, tiled=False)
154154
return os.path.getsize(p)
155155

156156
rio_time, rio_size = _timer(rio_write, warmup=1, runs=runs)
@@ -166,7 +166,7 @@ def our_write():
166166
def bench_round_trip(shape=(256, 256), compression='deflate'):
167167
"""Write with our module, read back with rioxarray, and vice versa."""
168168
import rioxarray # noqa: F401
169-
from xrspatial.geotiff import read_geotiff, write_geotiff
169+
from xrspatial.geotiff import open_geotiff, to_geotiff
170170

171171
rng = np.random.RandomState(99)
172172
arr = rng.rand(*shape).astype(np.float32)
@@ -179,7 +179,7 @@ def bench_round_trip(shape=(256, 256), compression='deflate'):
179179
our_path = os.path.join(tmpdir, 'ours.tif')
180180
da_ours = xr.DataArray(arr, dims=['y', 'x'], coords={'y': y, 'x': x},
181181
attrs={'crs': 4326})
182-
write_geotiff(da_ours, our_path, compression=compression, tiled=False)
182+
to_geotiff(da_ours, our_path, compression=compression, tiled=False)
183183

184184
rio_da = xr.open_dataarray(our_path, engine='rasterio')
185185
rio_arr = rio_da.squeeze('band').values if 'band' in rio_da.dims else rio_da.values
@@ -198,7 +198,7 @@ def bench_round_trip(shape=(256, 256), compression='deflate'):
198198
else:
199199
da_rio.rio.to_raster(rio_path)
200200

201-
our_da = read_geotiff(rio_path)
201+
our_da = open_geotiff(rio_path)
202202
our_arr = our_da.values
203203

204204
diff2 = float(np.nanmax(np.abs(arr - our_arr)))

xrspatial/geotiff/tests/test_cog.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pytest
66
import xarray as xr
77

8-
from xrspatial.geotiff import read_geotiff, write_geotiff
8+
from xrspatial.geotiff import open_geotiff, to_geotiff
99
from xrspatial.geotiff._header import parse_header, parse_all_ifds
1010
from xrspatial.geotiff._writer import write
1111
from xrspatial.geotiff._geotags import GeoTransform, extract_geo_info
@@ -83,52 +83,52 @@ def test_read_write_round_trip(self, tmp_path):
8383
)
8484

8585
path = str(tmp_path / 'round_trip.tif')
86-
write_geotiff(da, path, compression='deflate', tiled=False)
86+
to_geotiff(da, path, compression='deflate', tiled=False)
8787

88-
result = read_geotiff(path)
88+
result = open_geotiff(path)
8989
np.testing.assert_array_almost_equal(result.values, data, decimal=5)
9090
assert result.attrs.get('crs') == 4326
9191

92-
def test_read_geotiff_name(self, tmp_path):
92+
def test_open_geotiff_name(self, tmp_path):
9393
"""DataArray name defaults to filename stem."""
9494
arr = np.zeros((4, 4), dtype=np.float32)
9595
path = str(tmp_path / 'myfile.tif')
9696
write(arr, path, compression='none', tiled=False)
9797

98-
da = read_geotiff(path)
98+
da = open_geotiff(path)
9999
assert da.name == 'myfile'
100100

101-
def test_read_geotiff_custom_name(self, tmp_path):
101+
def test_open_geotiff_custom_name(self, tmp_path):
102102
arr = np.zeros((4, 4), dtype=np.float32)
103103
path = str(tmp_path / 'test.tif')
104104
write(arr, path, compression='none', tiled=False)
105105

106-
da = read_geotiff(path, name='custom')
106+
da = open_geotiff(path, name='custom')
107107
assert da.name == 'custom'
108108

109109
def test_write_numpy_array(self, tmp_path):
110-
"""write_geotiff should accept raw numpy arrays too."""
110+
"""to_geotiff should accept raw numpy arrays too."""
111111
arr = np.arange(16, dtype=np.float32).reshape(4, 4)
112112
path = str(tmp_path / 'numpy.tif')
113-
write_geotiff(arr, path, compression='none')
113+
to_geotiff(arr, path, compression='none')
114114

115-
result = read_geotiff(path)
115+
result = open_geotiff(path)
116116
np.testing.assert_array_equal(result.values, arr)
117117

118118
def test_write_3d_rgb(self, tmp_path):
119119
"""3D arrays (height, width, bands) should write multi-band."""
120120
arr = np.zeros((4, 4, 3), dtype=np.uint8)
121121
arr[:, :, 0] = 255 # red channel
122122
path = str(tmp_path / 'rgb.tif')
123-
write_geotiff(arr, path, compression='none')
123+
to_geotiff(arr, path, compression='none')
124124

125-
result = read_geotiff(path)
125+
result = open_geotiff(path)
126126
np.testing.assert_array_equal(result.values, arr)
127127

128128
def test_write_rejects_4d(self, tmp_path):
129129
arr = np.zeros((2, 3, 4, 4), dtype=np.float32)
130130
with pytest.raises(ValueError, match="Expected 2D or 3D"):
131-
write_geotiff(arr, str(tmp_path / 'bad.tif'))
131+
to_geotiff(arr, str(tmp_path / 'bad.tif'))
132132

133133

134134
def read_to_array_local(path):

0 commit comments

Comments
 (0)