Skip to content

Commit d69d34f

Browse files
committed
Add GPU-accelerated TIFF reader via Numba CUDA
read_geotiff_gpu() decodes tiled GeoTIFFs on the GPU and returns a CuPy-backed DataArray that stays in device memory, so no CPU->GPU transfer is needed for downstream xrspatial GPU operations (slope, aspect, hillshade, etc.).

CUDA kernels implemented:
- LZW decode: one thread block per tile, LZW table in shared memory (20KB per block, fast on-chip SRAM)
- Predictor decode (pred=2): one thread per row, horizontal cumsum
- Float predictor (pred=3): one thread per row, byte-lane undiff + un-transpose
- Tile assembly: one thread per pixel, copies from the decompressed tile buffer to the output image

Supports LZW-compressed and uncompressed tiled TIFFs. Falls back to the CPU for unsupported compression types and for stripped files.

Output is a 100% pixel-exact match with the CPU reader on all tested files (USGS LZW+pred3 3612x3612, synthetic LZW tiled).

Performance: GPU LZW is comparable to CPU (~330ms vs 270ms for 3612x3612) because LZW is inherently sequential per stream. The value is in keeping data on the GPU for end-to-end pipelines, avoiding CPU->GPU transfer overhead.

Future work: a CUDA inflate (deflate) kernel would unlock the parallel-decompression win, since deflate tiles are much more common in COGs.
1 parent f6b374e commit d69d34f

File tree

2 files changed

+531
-1
lines changed

2 files changed

+531
-1
lines changed

xrspatial/geotiff/__init__.py

Lines changed: 133 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from ._writer import write
2222

2323
__all__ = ['read_geotiff', 'write_geotiff', 'open_cog', 'read_geotiff_dask',
24-
'read_vrt', 'write_vrt']
24+
'read_vrt', 'write_vrt', 'read_geotiff_gpu']
2525

2626

2727
def _wkt_to_epsg(wkt_or_proj: str) -> int | None:
@@ -510,6 +510,138 @@ def _read():
510510
return _read()
511511

512512

513+
def read_geotiff_gpu(source: str, *,
                     overview_level: int | None = None,
                     name: str | None = None) -> xr.DataArray:
    """Read a GeoTIFF with GPU-accelerated decompression via Numba CUDA.

    Tiled files are handed to ``gpu_decode_tiles`` and the result is
    returned as a CuPy-backed DataArray that stays in device memory, so
    no CPU->GPU transfer is needed for downstream xrspatial GPU
    operations.

    Supports LZW and uncompressed tiled TIFFs with predictor 1, 2, or 3.
    Stripped files, and tiled files for which the GPU decoder raises
    ``ValueError`` (e.g. an unsupported compression type), are read on
    the CPU with ``read_to_array`` and then copied to the GPU. Both
    paths produce the same dims, coords and attrs.

    Requires: cupy, numba with CUDA support.

    Parameters
    ----------
    source : str
        File path.
    overview_level : int or None
        Overview level (0 = full resolution). Values past the last IFD
        are clamped to the last IFD.
    name : str or None
        Name for the DataArray. Defaults to the file name without its
        directory and extension.

    Returns
    -------
    xr.DataArray
        CuPy-backed DataArray on GPU device, with dims ``('y', 'x')``
        or ``('y', 'x', 'band')`` for 3-D data.

    Raises
    ------
    ImportError
        If cupy is not installed.
    ValueError
        If the file contains no IFDs.
    """
    try:
        import cupy
    except ImportError as exc:
        raise ImportError(
            "cupy is required for GPU reads. "
            "Install it with: pip install cupy-cuda12x") from exc

    import os

    from ._reader import _FileSource
    from ._header import parse_header, parse_all_ifds
    from ._dtypes import tiff_dtype_to_numpy
    from ._geotags import extract_geo_info
    from ._gpu_decode import gpu_decode_tiles

    # Parse metadata on CPU. Everything that touches the file source
    # lives inside the try so the source is closed exactly once, even
    # when reading or parsing fails.
    src = _FileSource(source)
    decode_args = None  # stays None for stripped files (CPU fallback)
    try:
        data = src.read_all()
        header = parse_header(data)
        ifds = parse_all_ifds(data, header)

        if len(ifds) == 0:
            raise ValueError("No IFDs found in TIFF file")

        ifd_idx = 0
        if overview_level is not None:
            ifd_idx = min(overview_level, len(ifds) - 1)
        ifd = ifds[ifd_idx]

        bps = ifd.bits_per_sample
        if isinstance(bps, tuple):
            bps = bps[0]
        dtype = tiff_dtype_to_numpy(bps, ifd.sample_format)
        geo_info = extract_geo_info(ifd, data, header.byte_order)

        if ifd.is_tiled:
            # Copy each compressed tile out of the file buffer while the
            # source is still open.
            byte_counts = ifd.tile_byte_counts
            compressed_tiles = [
                bytes(data[offset:offset + byte_counts[i]])
                for i, offset in enumerate(ifd.tile_offsets)
            ]
            decode_args = (
                compressed_tiles,
                ifd.tile_width, ifd.tile_height, ifd.width, ifd.height,
                ifd.compression, ifd.predictor, dtype,
                ifd.samples_per_pixel,
            )
    finally:
        src.close()

    # GPU decode (tiled files only)
    arr_gpu = None
    if decode_args is not None:
        try:
            arr_gpu = gpu_decode_tiles(*decode_args)
        except ValueError:
            # Unsupported compression -- fall back to CPU then transfer
            arr_gpu = None

    if arr_gpu is None:
        # CPU fallback: stripped file or decoder rejected the tiles.
        arr_cpu, _ = read_to_array(source, overview_level=overview_level)
        arr_gpu = cupy.asarray(arr_cpu)

    # Build DataArray
    if name is None:
        name = os.path.splitext(os.path.basename(source))[0]

    # Size the coordinates from the array actually returned so they
    # always agree with its y/x dimensions, whichever path produced it.
    coords = _geo_to_coords(geo_info, arr_gpu.shape[0], arr_gpu.shape[1])

    attrs = {}
    if geo_info.crs_epsg is not None:
        attrs['crs'] = geo_info.crs_epsg
    if geo_info.crs_wkt is not None:
        attrs['crs_wkt'] = geo_info.crs_wkt

    if arr_gpu.ndim == 3:
        dims = ['y', 'x', 'band']
        coords['band'] = np.arange(arr_gpu.shape[2])
    else:
        dims = ['y', 'x']

    return xr.DataArray(arr_gpu, dims=dims, coords=coords,
                        name=name, attrs=attrs)
643+
644+
513645
def read_vrt(source: str, *, window=None,
514646
band: int | None = None,
515647
name: str | None = None) -> xr.DataArray:

0 commit comments

Comments
 (0)