Skip to content

Commit 4c53027

Browse files
committed
Enable Dask+CuPy for GPU read and write
read_geotiff_gpu: - New chunks= parameter returns a Dask+CuPy DataArray - read_geotiff_gpu('dem.tif', chunks=512) decompresses on GPU then chunks the result for out-of-core GPU pipelines write_geotiff_gpu: - Accepts Dask+CuPy DataArrays (.compute() then compress on GPU) - Accepts Dask+NumPy DataArrays (.compute() then transfer to GPU) - Accepts raw CuPy, numpy, or list inputs All 7 input combinations verified: read_geotiff_gpu -> CuPy DataArray (existing) read_geotiff_gpu(chunks=N) -> Dask+CuPy DataArray (new) write_geotiff_gpu(cupy_array) (existing) write_geotiff_gpu(cupy_DataArray) (existing) write_geotiff_gpu(dask_cupy_DataArray) (new) write_geotiff_gpu(numpy_array) (auto-transfer) write_geotiff_gpu(dask_numpy_DataArray) (auto-compute+transfer) Also fixed write_geotiff CuPy fallback for raw arrays and Dask+CuPy DataArrays (compute then .get() to numpy).
1 parent 9cca00b commit 4c53027

File tree

2 files changed

+35
-13
lines changed

2 files changed

+35
-13
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,8 @@ Native GeoTIFF and Cloud Optimized GeoTIFF reader/writer. No GDAL required.
141141
|:-----|:------------|:-----:|:----:|:--------:|:-----:|
142142
| [read_geotiff](xrspatial/geotiff/__init__.py) | Read GeoTIFF / COG / VRT to DataArray | ✅️ | ✅️ | ✅️ | ✅️ |
143143
| [write_geotiff](xrspatial/geotiff/__init__.py) | Write DataArray as GeoTIFF / COG | ✅️ | ✅️ | 🔄 | ✅️ |
144-
| [read_geotiff_gpu](xrspatial/geotiff/__init__.py) | GPU-native read (nvCOMP + GDS) | | | ✅️ | |
145-
| [write_geotiff_gpu](xrspatial/geotiff/__init__.py) | GPU-native write (nvCOMP batch compress) | | | ✅️ | |
144+
| [read_geotiff_gpu](xrspatial/geotiff/__init__.py) | GPU-native read (nvCOMP + GDS) | | ✅️ | ✅️ | |
145+
| [write_geotiff_gpu](xrspatial/geotiff/__init__.py) | GPU-native write (nvCOMP batch compress) | 🔄 | ✅️ | ✅️ | |
146146
| [read_geotiff_dask](xrspatial/geotiff/__init__.py) | Dask lazy read via windowed chunks | | ✅️ | | |
147147
| [read_vrt / write_vrt](xrspatial/geotiff/__init__.py) | Virtual Raster Table mosaic | ✅️ | ✅️ | | |
148148
| [open_cog](xrspatial/geotiff/__init__.py) | HTTP range-request COG reader | ✅️ | | | ✅️ |

xrspatial/geotiff/__init__.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -524,15 +524,16 @@ def _read():
524524

525525
def read_geotiff_gpu(source: str, *,
526526
overview_level: int | None = None,
527-
name: str | None = None) -> xr.DataArray:
527+
name: str | None = None,
528+
chunks: int | tuple | None = None) -> xr.DataArray:
528529
"""Read a GeoTIFF with GPU-accelerated decompression via Numba CUDA.
529530
530531
Decompresses all tiles in parallel on the GPU and returns a
531532
CuPy-backed DataArray that stays on device memory. No CPU->GPU
532533
transfer needed for downstream xrspatial GPU operations.
533534
534-
Supports LZW and uncompressed tiled TIFFs with predictor 1, 2, or 3.
535-
For unsupported compression types, falls back to CPU.
535+
With ``chunks=``, returns a Dask+CuPy DataArray for out-of-core
536+
GPU pipelines.
536537
537538
Requires: cupy, numba with CUDA support.
538539
@@ -542,6 +543,9 @@ def read_geotiff_gpu(source: str, *,
542543
File path.
543544
overview_level : int or None
544545
Overview level (0 = full resolution).
546+
chunks : int, tuple, or None
547+
If set, return a Dask-chunked CuPy DataArray. int for square
548+
chunks, (row, col) tuple for rectangular.
545549
name : str or None
546550
Name for the DataArray.
547551
@@ -669,8 +673,17 @@ def read_geotiff_gpu(source: str, *,
669673
else:
670674
dims = ['y', 'x']
671675

672-
return xr.DataArray(arr_gpu, dims=dims, coords=coords,
673-
name=name, attrs=attrs)
676+
result = xr.DataArray(arr_gpu, dims=dims, coords=coords,
677+
name=name, attrs=attrs)
678+
679+
if chunks is not None:
680+
if isinstance(chunks, int):
681+
chunk_dict = {'y': chunks, 'x': chunks}
682+
else:
683+
chunk_dict = {'y': chunks[0], 'x': chunks[1]}
684+
result = result.chunk(chunk_dict)
685+
686+
return result
674687

675688

676689
def write_geotiff_gpu(data, path: str, *,
@@ -728,13 +741,15 @@ def write_geotiff_gpu(data, path: str, *,
728741
epsg = _wkt_to_epsg(crs)
729742

730743
if isinstance(data, xr.DataArray):
731-
arr = data.data # keep as cupy
744+
arr = data.data
745+
# Handle Dask arrays: compute to materialize
746+
if hasattr(arr, 'compute'):
747+
arr = arr.compute()
748+
# Now arr should be CuPy or numpy
732749
if hasattr(arr, 'get'):
733-
# It's a CuPy array
734-
pass
750+
pass # CuPy array, already on GPU
735751
else:
736-
# Numpy DataArray -- send to GPU
737-
arr = cupy.asarray(data.values)
752+
arr = cupy.asarray(np.asarray(arr)) # numpy -> GPU
738753

739754
geo_transform = _coords_to_transform(data)
740755
if epsg is None:
@@ -744,7 +759,14 @@ def write_geotiff_gpu(data, path: str, *,
744759
if data.attrs.get('raster_type') == 'point':
745760
raster_type = RASTER_PIXEL_IS_POINT
746761
else:
747-
arr = cupy.asarray(data) if not hasattr(data, 'device') else data
762+
if hasattr(data, 'compute'):
763+
data = data.compute() # Dask -> CuPy or numpy
764+
if hasattr(data, 'device'):
765+
arr = data # already CuPy
766+
elif hasattr(data, 'get'):
767+
arr = data # CuPy
768+
else:
769+
arr = cupy.asarray(np.asarray(data)) # numpy/list -> GPU
748770

749771
if arr.ndim not in (2, 3):
750772
raise ValueError(f"Expected 2D or 3D array, got {arr.ndim}D")

0 commit comments

Comments
 (0)