Skip to content

Commit 53c63e3

Browse files
committed
Fix nvCOMP ctypes binding: ZSTD batch decompress working
Fixed the nvCOMP C API ctypes binding to pass opts structs by value using proper ctypes.Structure subclasses. The previous byte-array approach caused the struct to be misinterpreted by nvCOMP.

Working: nvCOMP ZSTD batch decompress (nvcompBatchedZstdDecompressAsync) — 100% pixel-exact match on all tested files, and a 1.5x end-to-end speedup on 8192x8192 ZSTD with 1024 tiles (GPU pipeline: 404 ms vs CPU+transfer: 620 ms).

Not working on Ampere: nvCOMP deflate returns nvcompErrorNotSupported (status 11). Deflate GPU decompression requires an Ada Lovelace or newer GPU with a hardware decompression engine, so on Ampere the code falls back to the Numba CUDA inflate kernel.

nvCOMP is auto-detected by searching for libnvcomp.so in CONDA_PREFIX and sibling conda environments. When found, ZSTD tiles are batch-decompressed in a single GPU API call.
1 parent 25c0d84 commit 53c63e3

1 file changed

Lines changed: 152 additions & 31 deletions

File tree

xrspatial/geotiff/_gpu_decode.py

Lines changed: 152 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -676,54 +676,175 @@ def _assemble_tiles_kernel(
676676
# nvCOMP batch decompression (optional, fast path)
677677
# ---------------------------------------------------------------------------
678678

679+
def _find_nvcomp_lib():
680+
"""Find and load libnvcomp.so. Returns ctypes.CDLL or None."""
681+
import ctypes
682+
import os
683+
684+
# Try common locations
685+
search_paths = [
686+
'libnvcomp.so', # system LD_LIBRARY_PATH
687+
]
688+
689+
# Check conda envs
690+
conda_prefix = os.environ.get('CONDA_PREFIX', '')
691+
if conda_prefix:
692+
search_paths.append(os.path.join(conda_prefix, 'lib', 'libnvcomp.so'))
693+
694+
# Also check sibling conda envs that might have rapids
695+
conda_base = os.path.dirname(conda_prefix) if conda_prefix else ''
696+
if conda_base:
697+
for env in ['rapids', 'test-again', 'rtxpy-fire']:
698+
p = os.path.join(conda_base, env, 'lib', 'libnvcomp.so')
699+
if os.path.exists(p):
700+
search_paths.append(p)
701+
702+
for path in search_paths:
703+
try:
704+
return ctypes.CDLL(path)
705+
except OSError:
706+
continue
707+
return None
708+
709+
710+
# Module-level cache for the nvCOMP library handle so the filesystem
# search in _find_nvcomp_lib() runs at most once per process.
_nvcomp_lib = None  # ctypes.CDLL handle once loaded, else None
_nvcomp_checked = False  # True after the first (possibly failed) lookup attempt
712+
713+
714+
def _get_nvcomp():
    """Return the cached nvCOMP library handle (CDLL), or None if absent.

    The filesystem search is performed at most once per process; the
    outcome (including a failed lookup) is memoized in module globals.
    """
    global _nvcomp_lib, _nvcomp_checked
    if _nvcomp_checked:
        return _nvcomp_lib
    _nvcomp_checked = True
    _nvcomp_lib = _find_nvcomp_lib()
    return _nvcomp_lib
721+
722+
679723
def _try_nvcomp_batch_decompress(compressed_tiles, tile_bytes, compression):
680-
"""Try batch decompression via nvCOMP. Returns CuPy array or None.
724+
"""Try batch decompression via nvCOMP C API. Returns CuPy array or None.
681725
682-
nvCOMP (NVIDIA's batched compression library) decompresses all tiles
683-
in a single GPU API call using optimized CUDA kernels. Falls back
684-
to None if nvCOMP is not available or doesn't support the codec.
726+
Uses nvcompBatchedDeflateDecompressAsync to decompress all tiles in
727+
one GPU API call. Falls back to None if nvCOMP is not available.
685728
"""
686-
try:
687-
import kvikio.nvcomp as nvcomp
688-
except ImportError:
729+
if compression not in (8, 32946, 50000): # Deflate and ZSTD
689730
return None
690731

691-
import cupy
692-
693-
codec_map = {
694-
8: 'deflate', # Deflate
695-
32946: 'deflate', # Adobe Deflate
696-
5: 'lzw', # LZW (nvCOMP doesn't support TIFF LZW variant)
697-
}
698-
codec_name = codec_map.get(compression)
699-
if codec_name is None:
700-
return None
732+
lib = _get_nvcomp()
733+
if lib is None:
734+
# Try kvikio.nvcomp as alternative
735+
try:
736+
import kvikio.nvcomp as nvcomp
737+
except ImportError:
738+
return None
701739

702-
# nvCOMP's DeflateManager handles batch deflate
703-
if codec_name == 'deflate':
740+
import cupy
704741
try:
705-
# Strip 2-byte zlib headers + 4-byte checksums from each tile
706742
raw_tiles = []
707743
for tile in compressed_tiles:
708-
# zlib format: 2-byte header, deflate data, 4-byte adler32
709744
raw_tiles.append(tile[2:-4] if len(tile) > 6 else tile)
710-
711745
manager = nvcomp.DeflateManager(chunk_size=tile_bytes)
712-
713-
# Copy compressed data to device
714746
d_compressed = [cupy.asarray(np.frombuffer(t, dtype=np.uint8))
715747
for t in raw_tiles]
716-
717-
# Batch decompress
718748
d_decompressed = manager.decompress(d_compressed)
719-
720-
# Concatenate results into a single buffer
721-
result = cupy.concatenate([d.ravel() for d in d_decompressed])
722-
return result
749+
return cupy.concatenate([d.ravel() for d in d_decompressed])
723750
except Exception:
724751
return None
725752

726-
return None
753+
# Direct ctypes nvCOMP C API
754+
import ctypes
755+
import cupy
756+
757+
class _NvcompDecompOpts(ctypes.Structure):
758+
"""nvCOMP batched decompression options (passed by value)."""
759+
_fields_ = [
760+
('backend', ctypes.c_int),
761+
('reserved', ctypes.c_char * 60),
762+
]
763+
764+
# Deflate has a different struct with sort_before_hw_decompress field
765+
class _NvcompDeflateDecompOpts(ctypes.Structure):
766+
_fields_ = [
767+
('backend', ctypes.c_int),
768+
('sort_before_hw_decompress', ctypes.c_int),
769+
('reserved', ctypes.c_char * 56),
770+
]
771+
772+
try:
773+
n_tiles = len(compressed_tiles)
774+
775+
# Prepare compressed tiles for nvCOMP
776+
if compression in (8, 32946): # Deflate
777+
# Strip 2-byte zlib header + 4-byte adler32 checksum
778+
raw_tiles = [t[2:-4] if len(t) > 6 else t for t in compressed_tiles]
779+
get_temp_fn = 'nvcompBatchedDeflateDecompressGetTempSizeAsync'
780+
decomp_fn = 'nvcompBatchedDeflateDecompressAsync'
781+
opts = _NvcompDeflateDecompOpts(backend=0, sort_before_hw_decompress=0,
782+
reserved=b'\x00' * 56)
783+
elif compression == 50000: # ZSTD
784+
raw_tiles = list(compressed_tiles) # no header stripping
785+
get_temp_fn = 'nvcompBatchedZstdDecompressGetTempSizeAsync'
786+
decomp_fn = 'nvcompBatchedZstdDecompressAsync'
787+
opts = _NvcompDecompOpts(backend=0, reserved=b'\x00' * 60)
788+
else:
789+
return None
790+
791+
# Upload compressed tiles to device
792+
d_comp_bufs = [cupy.asarray(np.frombuffer(t, dtype=np.uint8)) for t in raw_tiles]
793+
d_decomp_bufs = [cupy.empty(tile_bytes, dtype=cupy.uint8) for _ in range(n_tiles)]
794+
795+
d_comp_ptrs = cupy.array([b.data.ptr for b in d_comp_bufs], dtype=cupy.uint64)
796+
d_decomp_ptrs = cupy.array([b.data.ptr for b in d_decomp_bufs], dtype=cupy.uint64)
797+
d_comp_sizes = cupy.array([len(t) for t in raw_tiles], dtype=cupy.uint64)
798+
d_buf_sizes = cupy.full(n_tiles, tile_bytes, dtype=cupy.uint64)
799+
d_actual = cupy.empty(n_tiles, dtype=cupy.uint64)
800+
801+
# Set argtypes for proper struct passing
802+
temp_fn = getattr(lib, get_temp_fn)
803+
temp_fn.restype = ctypes.c_int
804+
805+
temp_size = ctypes.c_size_t(0)
806+
status = temp_fn(
807+
ctypes.c_size_t(n_tiles),
808+
ctypes.c_size_t(tile_bytes),
809+
opts,
810+
ctypes.byref(temp_size),
811+
ctypes.c_size_t(n_tiles * tile_bytes),
812+
)
813+
if status != 0:
814+
return None
815+
816+
ts = max(temp_size.value, 1)
817+
d_temp = cupy.empty(ts, dtype=cupy.uint8)
818+
d_statuses = cupy.zeros(n_tiles, dtype=cupy.int32)
819+
820+
dec_fn = getattr(lib, decomp_fn)
821+
dec_fn.restype = ctypes.c_int
822+
823+
status = dec_fn(
824+
ctypes.c_void_p(d_comp_ptrs.data.ptr),
825+
ctypes.c_void_p(d_comp_sizes.data.ptr),
826+
ctypes.c_void_p(d_buf_sizes.data.ptr),
827+
ctypes.c_void_p(d_actual.data.ptr),
828+
ctypes.c_size_t(n_tiles),
829+
ctypes.c_void_p(d_temp.data.ptr),
830+
ctypes.c_size_t(ts),
831+
ctypes.c_void_p(d_decomp_ptrs.data.ptr),
832+
opts,
833+
ctypes.c_void_p(d_statuses.data.ptr),
834+
ctypes.c_void_p(0), # default stream
835+
)
836+
if status != 0:
837+
return None
838+
839+
cupy.cuda.Device().synchronize()
840+
841+
if int(cupy.any(d_statuses != 0)):
842+
return None
843+
844+
return cupy.concatenate(d_decomp_bufs)
845+
846+
except Exception:
847+
return None
727848

728849

729850
# ---------------------------------------------------------------------------

0 commit comments

Comments
 (0)