Skip to content

Commit b1ed372

Browse files
committed
Add GPU-accelerated GeoTIFF write via nvCOMP batch compress
write_geotiff_gpu() compresses tiles on the GPU and writes a valid GeoTIFF. The CuPy array stays on device throughout -- only the compressed bytes transfer to CPU for file assembly.

GPU pipeline:
CuPy array → tile extraction (CUDA kernel) → predictor encode (CUDA kernel) → nvCOMP batch compress → CPU file assembly

CUDA kernels added:
- _extract_tiles_kernel: image → per-tile buffers (1 thread/pixel)
- _predictor_encode_kernel: horizontal differencing (1 thread/row)
- _fp_predictor_encode_kernel: float predictor (1 thread/row)
- _nvcomp_batch_compress: deflate + ZSTD via nvCOMP C API

Deflate write performance (tiled 256, A6000):
- 2048x2048: GPU 135ms vs CPU 424ms = 3.1x faster
- 4096x4096: GPU 302ms vs CPU 1678ms = 5.6x faster
- 8192x8192: GPU 1114ms vs CPU 6837ms = 6.1x faster

GPU deflate is also 1.5-1.8x faster than rioxarray/GDAL at 4K+. All round-trips verified pixel-exact (deflate, ZSTD, uncompressed).
1 parent ce64901 commit b1ed372

File tree

2 files changed

+432
-1
lines changed

2 files changed

+432
-1
lines changed

xrspatial/geotiff/__init__.py

Lines changed: 111 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from ._writer import write
2222

2323
__all__ = ['read_geotiff', 'write_geotiff', 'open_cog', 'read_geotiff_dask',
24-
'read_vrt', 'write_vrt', 'read_geotiff_gpu']
24+
'read_vrt', 'write_vrt', 'read_geotiff_gpu', 'write_geotiff_gpu']
2525

2626

2727
def _wkt_to_epsg(wkt_or_proj: str) -> int | None:
@@ -661,6 +661,116 @@ def read_geotiff_gpu(source: str, *,
661661
name=name, attrs=attrs)
662662

663663

664+
def write_geotiff_gpu(data, path: str, *,
                      crs: int | str | None = None,
                      nodata=None,
                      compression: str = 'zstd',
                      tile_size: int = 256,
                      predictor: bool = False) -> None:
    """Write a CuPy-backed raster as a tiled GeoTIFF with GPU compression.

    Tiles are extracted and compressed on the GPU via nvCOMP, then
    assembled into a TIFF file on CPU.  The CuPy array stays on device
    throughout compression -- only the compressed bytes transfer to CPU
    for file writing.

    Falls back to CPU compression if nvCOMP is not available.

    Parameters
    ----------
    data : xr.DataArray (CuPy-backed) or cupy.ndarray
        2D raster, or 3D raster whose third axis is read as the sample
        count.  A NumPy-backed DataArray (or any other array-like) is
        copied to the GPU first.
    path : str
        Output file path.
    crs : int, str, or None
        EPSG code or WKT string.  When omitted and ``data`` is a
        DataArray, ``data.attrs['crs']`` is used instead.
    nodata : float, int, or None
        NoData value.  When omitted and ``data`` is a DataArray,
        ``data.attrs['nodata']`` is used instead.
    compression : str
        'zstd' (default, fastest on GPU), 'deflate', or 'none'.
    tile_size : int
        Tile size in pixels (default 256).  Tiles are square.
    predictor : bool
        Apply horizontal differencing predictor.

    Raises
    ------
    ImportError
        If cupy is not installed.
    ValueError
        If ``tile_size`` is not positive, or the array is not 2D or 3D.
    """
    try:
        import cupy
    except ImportError as err:
        # Chain the original failure so the underlying reason stays visible.
        raise ImportError("cupy is required for GPU writes") from err

    from ._gpu_decode import gpu_compress_tiles
    from ._writer import _assemble_tiff, _compression_tag, _write_bytes

    # Reject a meaningless tile size here instead of deep inside the
    # tiling / compression code.
    if tile_size <= 0:
        raise ValueError(f"tile_size must be positive, got {tile_size}")

    # Extract array and metadata
    geo_transform = None
    epsg = None
    raster_type = 1  # default; switched to RASTER_PIXEL_IS_POINT below

    if isinstance(crs, int):
        epsg = crs
    elif isinstance(crs, str):
        epsg = _wkt_to_epsg(crs)

    if isinstance(data, xr.DataArray):
        arr = data.data  # keep as cupy
        if not hasattr(arr, 'get'):
            # No .get() -> not a CuPy array; send the values to the GPU.
            arr = cupy.asarray(data.values)

        geo_transform = _coords_to_transform(data)
        if epsg is None:
            attr_crs = data.attrs.get('crs')
            # Normalise a string CRS stored in attrs the same way an
            # explicit string ``crs`` argument is normalised above.
            if isinstance(attr_crs, str):
                attr_crs = _wkt_to_epsg(attr_crs)
            epsg = attr_crs
        if nodata is None:
            nodata = data.attrs.get('nodata')
        if data.attrs.get('raster_type') == 'point':
            raster_type = RASTER_PIXEL_IS_POINT
    else:
        # Objects exposing ``.device`` are taken to be on the GPU already.
        arr = cupy.asarray(data) if not hasattr(data, 'device') else data

    if arr.ndim not in (2, 3):
        raise ValueError(f"Expected 2D or 3D array, got {arr.ndim}D")

    height, width = arr.shape[:2]
    samples = arr.shape[2] if arr.ndim == 3 else 1
    np_dtype = np.dtype(str(arr.dtype))  # cupy dtype -> numpy dtype

    comp_tag = _compression_tag(compression)
    pred_val = 2 if predictor else 1

    # GPU compress
    compressed_tiles = gpu_compress_tiles(
        arr, tile_size, tile_size, width, height,
        comp_tag, pred_val, np_dtype, samples)

    # Build offset/bytecount lists: each tile starts where the previous
    # one ended, relative to the start of the tile data.
    rel_offsets = []
    byte_counts = []
    offset = 0
    for tile in compressed_tiles:
        tile_len = len(tile)
        rel_offsets.append(offset)
        byte_counts.append(tile_len)
        offset += tile_len

    # Assemble TIFF on CPU (only metadata + compressed bytes)
    # _assemble_tiff needs an array in parts[0] to detect samples_per_pixel
    stub_shape = (1, 1, samples) if samples > 1 else (1, 1)
    shape_stub = np.empty(stub_shape, dtype=np_dtype)
    parts = [(shape_stub, width, height, rel_offsets, byte_counts,
              compressed_tiles)]

    file_bytes = _assemble_tiff(
        width, height, np_dtype, comp_tag, predictor, True, tile_size,
        parts, geo_transform, epsg, nodata, is_cog=False,
        raster_type=raster_type)

    _write_bytes(file_bytes, path)
772+
773+
664774
def read_vrt(source: str, *, window=None,
665775
band: int | None = None,
666776
name: str | None = None) -> xr.DataArray:

0 commit comments

Comments
 (0)