Skip to content

Commit 230573c

Browse files
committed
Unified API: read_geotiff/write_geotiff auto-dispatch CPU/GPU/Dask
read_geotiff and write_geotiff now dispatch to the correct backend automatically: read_geotiff('dem.tif') # NumPy (default) read_geotiff('dem.tif', gpu=True) # CuPy via nvCOMP read_geotiff('dem.tif', chunks=512) # Dask lazy read_geotiff('dem.tif', gpu=True, chunks=512) # Dask+CuPy write_geotiff(numpy_arr, 'out.tif') # CPU write write_geotiff(cupy_arr, 'out.tif') # auto-detects CuPy -> GPU write write_geotiff(data, 'out.tif', gpu=True) # force GPU write Auto-detection: write_geotiff checks isinstance(data, cupy.ndarray) to decide whether to use GPU compression. Falls back to CPU if cupy is not installed or nvCOMP fails. read_vrt also supports gpu= and chunks= parameters for all four backend combinations. Users no longer need to call read_geotiff_gpu/write_geotiff_gpu directly -- the main functions handle everything.
1 parent 4c53027 commit 230573c

File tree

2 files changed

+100
-24
lines changed

2 files changed

+100
-24
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ Native GeoTIFF and Cloud Optimized GeoTIFF reader/writer. No GDAL required.
144144
| [read_geotiff_gpu](xrspatial/geotiff/__init__.py) | GPU-native read (nvCOMP + GDS) | | ✅️ | ✅️ | |
145145
| [write_geotiff_gpu](xrspatial/geotiff/__init__.py) | GPU-native write (nvCOMP batch compress) | 🔄 | ✅️ | ✅️ | |
146146
| [read_geotiff_dask](xrspatial/geotiff/__init__.py) | Dask lazy read via windowed chunks | | ✅️ | | |
147-
| [read_vrt / write_vrt](xrspatial/geotiff/__init__.py) | Virtual Raster Table mosaic | ✅️ | ✅️ | | |
147+
| [read_vrt / write_vrt](xrspatial/geotiff/__init__.py) | Virtual Raster Table mosaic | ✅️ | ✅️ | ✅️ | |
148148
| [open_cog](xrspatial/geotiff/__init__.py) | HTTP range-request COG reader | ✅️ | | | ✅️ |
149149

150150
**Compression codecs:** Deflate, LZW (Numba JIT), ZSTD, PackBits, JPEG (Pillow), uncompressed

xrspatial/geotiff/__init__.py

Lines changed: 99 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -102,11 +102,18 @@ def _coords_to_transform(da: xr.DataArray) -> GeoTransform | None:
102102
def read_geotiff(source: str, *, window=None,
103103
overview_level: int | None = None,
104104
band: int | None = None,
105-
name: str | None = None) -> xr.DataArray:
106-
"""Read a GeoTIFF or VRT file into an xarray.DataArray.
105+
name: str | None = None,
106+
chunks: int | tuple | None = None,
107+
gpu: bool = False) -> xr.DataArray:
108+
"""Read a GeoTIFF, COG, or VRT file into an xarray.DataArray.
107109
108-
VRT files (.vrt extension) are automatically detected and assembled
109-
from their source GeoTIFFs.
110+
Automatically dispatches to the best backend:
111+
- ``gpu=True``: GPU-accelerated read via nvCOMP (returns CuPy)
112+
- ``chunks=N``: Dask lazy read via windowed chunks
113+
- ``gpu=True, chunks=N``: Dask+CuPy for out-of-core GPU pipelines
114+
- Default: NumPy eager read
115+
116+
VRT files are auto-detected by extension.
110117
111118
Parameters
112119
----------
@@ -115,20 +122,35 @@ def read_geotiff(source: str, *, window=None,
115122
window : tuple or None
116123
(row_start, col_start, row_stop, col_stop) for windowed reading.
117124
overview_level : int or None
118-
Overview level to read (0 = full resolution). None reads full res.
119-
band : int
120-
Band index (0-based) for multi-band files.
125+
Overview level (0 = full resolution).
126+
band : int or None
127+
Band index (0-based). None returns all bands.
121128
name : str or None
122-
Name for the DataArray. Defaults to filename stem.
129+
Name for the DataArray.
130+
chunks : int, tuple, or None
131+
Chunk size for Dask lazy reading.
132+
gpu : bool
133+
Use GPU-accelerated decompression (requires cupy + nvCOMP).
123134
124135
Returns
125136
-------
126137
xr.DataArray
127-
2D DataArray with y/x coordinates and geo attributes.
138+
NumPy, Dask, CuPy, or Dask+CuPy backed depending on options.
128139
"""
129-
# Auto-detect VRT files
140+
# VRT files
130141
if source.lower().endswith('.vrt'):
131-
return read_vrt(source, window=window, band=band, name=name)
142+
return read_vrt(source, window=window, band=band, name=name,
143+
chunks=chunks, gpu=gpu)
144+
145+
# GPU path
146+
if gpu:
147+
return read_geotiff_gpu(source, overview_level=overview_level,
148+
name=name, chunks=chunks)
149+
150+
# Dask path (CPU)
151+
if chunks is not None:
152+
return read_geotiff_dask(source, chunks=chunks,
153+
overview_level=overview_level, name=name)
132154

133155
arr, geo_info = read_to_array(
134156
source, window=window,
@@ -247,6 +269,23 @@ def read_geotiff(source: str, *, window=None,
247269
return da
248270

249271

272+
def _is_gpu_data(data) -> bool:
273+
"""Check if data is CuPy-backed (raw array or DataArray)."""
274+
try:
275+
import cupy
276+
_cupy_type = cupy.ndarray
277+
except ImportError:
278+
return False
279+
280+
if isinstance(data, xr.DataArray):
281+
raw = data.data
282+
if hasattr(raw, 'compute'):
283+
meta = getattr(raw, '_meta', None)
284+
return isinstance(meta, _cupy_type)
285+
return isinstance(raw, _cupy_type)
286+
return isinstance(data, _cupy_type)
287+
288+
250289
def write_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
251290
crs: int | str | None = None,
252291
nodata=None,
@@ -257,9 +296,17 @@ def write_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
257296
cog: bool = False,
258297
overview_levels: list[int] | None = None,
259298
overview_resampling: str = 'mean',
260-
bigtiff: bool | None = None) -> None:
299+
bigtiff: bool | None = None,
300+
gpu: bool | None = None) -> None:
261301
"""Write data as a GeoTIFF or Cloud Optimized GeoTIFF.
262302
303+
Automatically dispatches to GPU compression when:
304+
- ``gpu=True`` is passed, or
305+
- The input data is CuPy-backed (auto-detected)
306+
307+
GPU write uses nvCOMP batch compression (deflate/ZSTD) and keeps
308+
the array on device. Falls back to CPU if nvCOMP is not available.
309+
263310
Parameters
264311
----------
265312
data : xr.DataArray or np.ndarray
@@ -287,7 +334,20 @@ def write_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
287334
overview_resampling : str
288335
Resampling method for overviews: 'mean' (default), 'nearest',
289336
'min', 'max', 'median', 'mode', or 'cubic'.
337+
gpu : bool or None
338+
Force GPU compression. None (default) auto-detects CuPy data.
290339
"""
340+
# Auto-detect GPU data and dispatch to write_geotiff_gpu
341+
use_gpu = gpu if gpu is not None else _is_gpu_data(data)
342+
if use_gpu:
343+
try:
344+
write_geotiff_gpu(data, path, crs=crs, nodata=nodata,
345+
compression=compression, tile_size=tile_size,
346+
predictor=predictor)
347+
return
348+
except (ImportError, Exception):
349+
pass # fall through to CPU path
350+
291351
geo_transform = None
292352
epsg = None
293353
raster_type = RASTER_PIXEL_IS_AREA
@@ -428,12 +488,9 @@ def read_geotiff_dask(source: str, *, chunks: int | tuple = 512,
428488
"""
429489
import dask.array as da
430490

431-
# VRT files: read eagerly (VRT mosaic isn't compatible with per-chunk
432-
# windowed reads on the virtual dataset without a separate code path)
491+
# VRT files: delegate to read_vrt which handles chunks
433492
if source.lower().endswith('.vrt'):
434-
da_eager = read_vrt(source, name=name)
435-
return da_eager.chunk({'y': chunks if isinstance(chunks, int) else chunks[0],
436-
'x': chunks if isinstance(chunks, int) else chunks[1]})
493+
return read_vrt(source, name=name, chunks=chunks)
437494

438495
# First, do a metadata-only read to get shape, dtype, coords, attrs
439496
arr, geo_info = read_to_array(source, overview_level=overview_level)
@@ -807,7 +864,9 @@ def write_geotiff_gpu(data, path: str, *,
807864

808865
def read_vrt(source: str, *, window=None,
809866
band: int | None = None,
810-
name: str | None = None) -> xr.DataArray:
867+
name: str | None = None,
868+
chunks: int | tuple | None = None,
869+
gpu: bool = False) -> xr.DataArray:
811870
"""Read a GDAL Virtual Raster Table (.vrt) into an xarray.DataArray.
812871
813872
The VRT's source GeoTIFFs are read via windowed reads and assembled
@@ -823,10 +882,16 @@ def read_vrt(source: str, *, window=None,
823882
Band index (0-based). None returns all bands.
824883
name : str or None
825884
Name for the DataArray.
885+
chunks : int, tuple, or None
886+
If set, return a Dask-chunked DataArray. int for square chunks,
887+
(row, col) tuple for rectangular.
888+
gpu : bool
889+
If True, return a CuPy-backed DataArray on GPU.
826890
827891
Returns
828892
-------
829893
xr.DataArray
894+
NumPy, Dask, CuPy, or Dask+CuPy backed depending on options.
830895
"""
831896
from ._vrt import read_vrt as _read_vrt_internal
832897

@@ -854,27 +919,38 @@ def read_vrt(source: str, *, window=None,
854919
coords = {}
855920

856921
attrs = {}
857-
858-
# CRS from VRT
859922
if vrt.crs_wkt:
860923
epsg = _wkt_to_epsg(vrt.crs_wkt)
861924
if epsg is not None:
862925
attrs['crs'] = epsg
863926
attrs['crs_wkt'] = vrt.crs_wkt
864-
865-
# Nodata from first band
866927
if vrt.bands:
867928
nodata = vrt.bands[0].nodata
868929
if nodata is not None:
869930
attrs['nodata'] = nodata
870931

932+
# Transfer to GPU if requested
933+
if gpu:
934+
import cupy
935+
arr = cupy.asarray(arr)
936+
871937
if arr.ndim == 3:
872938
dims = ['y', 'x', 'band']
873939
coords['band'] = np.arange(arr.shape[2])
874940
else:
875941
dims = ['y', 'x']
876942

877-
return xr.DataArray(arr, dims=dims, coords=coords, name=name, attrs=attrs)
943+
result = xr.DataArray(arr, dims=dims, coords=coords, name=name, attrs=attrs)
944+
945+
# Chunk for Dask (or Dask+CuPy if gpu=True)
946+
if chunks is not None:
947+
if isinstance(chunks, int):
948+
chunk_dict = {'y': chunks, 'x': chunks}
949+
else:
950+
chunk_dict = {'y': chunks[0], 'x': chunks[1]}
951+
result = result.chunk(chunk_dict)
952+
953+
return result
878954

879955

880956
def write_vrt(vrt_path: str, source_files: list[str], **kwargs) -> str:

0 commit comments

Comments
 (0)