@@ -102,11 +102,18 @@ def _coords_to_transform(da: xr.DataArray) -> GeoTransform | None:
102102def read_geotiff (source : str , * , window = None ,
103103 overview_level : int | None = None ,
104104 band : int | None = None ,
105- name : str | None = None ) -> xr .DataArray :
106- """Read a GeoTIFF or VRT file into an xarray.DataArray.
105+ name : str | None = None ,
106+ chunks : int | tuple | None = None ,
107+ gpu : bool = False ) -> xr .DataArray :
108+ """Read a GeoTIFF, COG, or VRT file into an xarray.DataArray.
107109
108- VRT files (.vrt extension) are automatically detected and assembled
109- from their source GeoTIFFs.
110+ Automatically dispatches to the best backend:
111+ - ``gpu=True``: GPU-accelerated read via nvCOMP (returns CuPy)
112+ - ``chunks=N``: Dask lazy read via windowed chunks
113+ - ``gpu=True, chunks=N``: Dask+CuPy for out-of-core GPU pipelines
114+ - Default: NumPy eager read
115+
116+ VRT files are auto-detected by extension.
110117
111118 Parameters
112119 ----------
@@ -115,20 +122,35 @@ def read_geotiff(source: str, *, window=None,
115122 window : tuple or None
116123 (row_start, col_start, row_stop, col_stop) for windowed reading.
117124 overview_level : int or None
118- Overview level to read (0 = full resolution). None reads full res .
119- band : int
120- Band index (0-based) for multi-band files .
125+ Overview level (0 = full resolution).
126+ band : int or None
127+ Band index (0-based). None returns all bands .
121128 name : str or None
122- Name for the DataArray. Defaults to filename stem.
129+ Name for the DataArray.
130+ chunks : int, tuple, or None
131+ Chunk size for Dask lazy reading.
132+ gpu : bool
133+ Use GPU-accelerated decompression (requires cupy + nvCOMP).
123134
124135 Returns
125136 -------
126137 xr.DataArray
127- 2D DataArray with y/x coordinates and geo attributes .
138+ NumPy, Dask, CuPy, or Dask+CuPy backed depending on options .
128139 """
129- # Auto-detect VRT files
140+ # VRT files
130141 if source .lower ().endswith ('.vrt' ):
131- return read_vrt (source , window = window , band = band , name = name )
142+ return read_vrt (source , window = window , band = band , name = name ,
143+ chunks = chunks , gpu = gpu )
144+
145+ # GPU path
146+ if gpu :
147+ return read_geotiff_gpu (source , overview_level = overview_level ,
148+ name = name , chunks = chunks )
149+
150+ # Dask path (CPU)
151+ if chunks is not None :
152+ return read_geotiff_dask (source , chunks = chunks ,
153+ overview_level = overview_level , name = name )
132154
133155 arr , geo_info = read_to_array (
134156 source , window = window ,
@@ -247,6 +269,23 @@ def read_geotiff(source: str, *, window=None,
247269 return da
248270
249271
272+ def _is_gpu_data (data ) -> bool :
273+ """Check if data is CuPy-backed (raw array or DataArray)."""
274+ try :
275+ import cupy
276+ _cupy_type = cupy .ndarray
277+ except ImportError :
278+ return False
279+
280+ if isinstance (data , xr .DataArray ):
281+ raw = data .data
282+ if hasattr (raw , 'compute' ):
283+ meta = getattr (raw , '_meta' , None )
284+ return isinstance (meta , _cupy_type )
285+ return isinstance (raw , _cupy_type )
286+ return isinstance (data , _cupy_type )
287+
288+
250289def write_geotiff (data : xr .DataArray | np .ndarray , path : str , * ,
251290 crs : int | str | None = None ,
252291 nodata = None ,
@@ -257,9 +296,17 @@ def write_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
257296 cog : bool = False ,
258297 overview_levels : list [int ] | None = None ,
259298 overview_resampling : str = 'mean' ,
260- bigtiff : bool | None = None ) -> None :
299+ bigtiff : bool | None = None ,
300+ gpu : bool | None = None ) -> None :
261301 """Write data as a GeoTIFF or Cloud Optimized GeoTIFF.
262302
303+ Automatically dispatches to GPU compression when:
304+ - ``gpu=True`` is passed, or
305+ - The input data is CuPy-backed (auto-detected)
306+
307+ GPU write uses nvCOMP batch compression (deflate/ZSTD) and keeps
308+ the array on device. Falls back to CPU if nvCOMP is not available.
309+
263310 Parameters
264311 ----------
265312 data : xr.DataArray or np.ndarray
@@ -287,7 +334,20 @@ def write_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
287334 overview_resampling : str
288335 Resampling method for overviews: 'mean' (default), 'nearest',
289336 'min', 'max', 'median', 'mode', or 'cubic'.
337+ gpu : bool or None
338+ Force GPU compression. None (default) auto-detects CuPy data.
290339 """
340+ # Auto-detect GPU data and dispatch to write_geotiff_gpu
341+ use_gpu = gpu if gpu is not None else _is_gpu_data (data )
342+ if use_gpu :
343+ try :
344+ write_geotiff_gpu (data , path , crs = crs , nodata = nodata ,
345+ compression = compression , tile_size = tile_size ,
346+ predictor = predictor )
347+ return
348+ except (ImportError , Exception ):
349+ pass # fall through to CPU path
350+
291351 geo_transform = None
292352 epsg = None
293353 raster_type = RASTER_PIXEL_IS_AREA
@@ -428,12 +488,9 @@ def read_geotiff_dask(source: str, *, chunks: int | tuple = 512,
428488 """
429489 import dask .array as da
430490
431- # VRT files: read eagerly (VRT mosaic isn't compatible with per-chunk
432- # windowed reads on the virtual dataset without a separate code path)
491+ # VRT files: delegate to read_vrt which handles chunks
433492 if source .lower ().endswith ('.vrt' ):
434- da_eager = read_vrt (source , name = name )
435- return da_eager .chunk ({'y' : chunks if isinstance (chunks , int ) else chunks [0 ],
436- 'x' : chunks if isinstance (chunks , int ) else chunks [1 ]})
493+ return read_vrt (source , name = name , chunks = chunks )
437494
438495 # First, do a metadata-only read to get shape, dtype, coords, attrs
439496 arr , geo_info = read_to_array (source , overview_level = overview_level )
@@ -807,7 +864,9 @@ def write_geotiff_gpu(data, path: str, *,
807864
808865def read_vrt (source : str , * , window = None ,
809866 band : int | None = None ,
810- name : str | None = None ) -> xr .DataArray :
867+ name : str | None = None ,
868+ chunks : int | tuple | None = None ,
869+ gpu : bool = False ) -> xr .DataArray :
811870 """Read a GDAL Virtual Raster Table (.vrt) into an xarray.DataArray.
812871
813872 The VRT's source GeoTIFFs are read via windowed reads and assembled
@@ -823,10 +882,16 @@ def read_vrt(source: str, *, window=None,
823882 Band index (0-based). None returns all bands.
824883 name : str or None
825884 Name for the DataArray.
885+ chunks : int, tuple, or None
886+ If set, return a Dask-chunked DataArray. int for square chunks,
887+ (row, col) tuple for rectangular.
888+ gpu : bool
889+ If True, return a CuPy-backed DataArray on GPU.
826890
827891 Returns
828892 -------
829893 xr.DataArray
894+ NumPy, Dask, CuPy, or Dask+CuPy backed depending on options.
830895 """
831896 from ._vrt import read_vrt as _read_vrt_internal
832897
@@ -854,27 +919,38 @@ def read_vrt(source: str, *, window=None,
854919 coords = {}
855920
856921 attrs = {}
857-
858- # CRS from VRT
859922 if vrt .crs_wkt :
860923 epsg = _wkt_to_epsg (vrt .crs_wkt )
861924 if epsg is not None :
862925 attrs ['crs' ] = epsg
863926 attrs ['crs_wkt' ] = vrt .crs_wkt
864-
865- # Nodata from first band
866927 if vrt .bands :
867928 nodata = vrt .bands [0 ].nodata
868929 if nodata is not None :
869930 attrs ['nodata' ] = nodata
870931
932+ # Transfer to GPU if requested
933+ if gpu :
934+ import cupy
935+ arr = cupy .asarray (arr )
936+
871937 if arr .ndim == 3 :
872938 dims = ['y' , 'x' , 'band' ]
873939 coords ['band' ] = np .arange (arr .shape [2 ])
874940 else :
875941 dims = ['y' , 'x' ]
876942
877- return xr .DataArray (arr , dims = dims , coords = coords , name = name , attrs = attrs )
943+ result = xr .DataArray (arr , dims = dims , coords = coords , name = name , attrs = attrs )
944+
945+ # Chunk for Dask (or Dask+CuPy if gpu=True)
946+ if chunks is not None :
947+ if isinstance (chunks , int ):
948+ chunk_dict = {'y' : chunks , 'x' : chunks }
949+ else :
950+ chunk_dict = {'y' : chunks [0 ], 'x' : chunks [1 ]}
951+ result = result .chunk (chunk_dict )
952+
953+ return result
878954
879955
880956def write_vrt (vrt_path : str , source_files : list [str ], ** kwargs ) -> str :
0 commit comments