@@ -676,54 +676,175 @@ def _assemble_tiles_kernel(
676676# nvCOMP batch decompression (optional, fast path)
677677# ---------------------------------------------------------------------------
678678
679+ def _find_nvcomp_lib ():
680+ """Find and load libnvcomp.so. Returns ctypes.CDLL or None."""
681+ import ctypes
682+ import os
683+
684+ # Try common locations
685+ search_paths = [
686+ 'libnvcomp.so' , # system LD_LIBRARY_PATH
687+ ]
688+
689+ # Check conda envs
690+ conda_prefix = os .environ .get ('CONDA_PREFIX' , '' )
691+ if conda_prefix :
692+ search_paths .append (os .path .join (conda_prefix , 'lib' , 'libnvcomp.so' ))
693+
694+ # Also check sibling conda envs that might have rapids
695+ conda_base = os .path .dirname (conda_prefix ) if conda_prefix else ''
696+ if conda_base :
697+ for env in ['rapids' , 'test-again' , 'rtxpy-fire' ]:
698+ p = os .path .join (conda_base , env , 'lib' , 'libnvcomp.so' )
699+ if os .path .exists (p ):
700+ search_paths .append (p )
701+
702+ for path in search_paths :
703+ try :
704+ return ctypes .CDLL (path )
705+ except OSError :
706+ continue
707+ return None
708+
709+
# Lazily-populated cache for the nvCOMP library handle.
_nvcomp_lib = None
# Tracks whether the (filesystem-touching) search already ran, so a
# "not found" result is also cached and the search runs at most once.
_nvcomp_checked = False


def _get_nvcomp():
    """Return the memoized nvCOMP CDLL handle, or None if unavailable.

    The first call performs the library lookup; every later call returns
    the cached result, including a cached None on lookup failure.
    """
    global _nvcomp_lib, _nvcomp_checked
    if _nvcomp_checked:
        return _nvcomp_lib
    _nvcomp_checked = True
    _nvcomp_lib = _find_nvcomp_lib()
    return _nvcomp_lib
721+
722+
def _try_nvcomp_batch_decompress(compressed_tiles, tile_bytes, compression):
    """Try batch decompression via the nvCOMP C API. Returns CuPy array or None.

    Decompresses all tiles in a single GPU API call using
    ``nvcompBatchedDeflateDecompressAsync`` (or the Zstd variant).  When
    the raw library cannot be loaded, falls back to ``kvikio.nvcomp``'s
    DeflateManager for deflate tiles only.  Any failure returns None so
    the caller can take its CPU decode path.

    Parameters:
        compressed_tiles: sequence of per-tile compressed byte strings
            (zlib-framed for the deflate codecs).
        tile_bytes: decompressed size of each tile, in bytes.
        compression: TIFF compression tag — 8 / 32946 (Deflate) or
            50000 (ZSTD); anything else returns None.

    Returns:
        A 1-D CuPy uint8 array of all tiles concatenated in order, or
        None if the fast path is unavailable or fails.
    """
    if compression not in (8, 32946, 50000):  # Deflate and ZSTD only
        return None

    lib = _get_nvcomp()
    if lib is None:
        # kvikio's DeflateManager only understands deflate streams.
        # Bug fix: the previous code also ran ZSTD (50000) tiles through
        # this path — stripping a zlib header they don't have and using
        # the wrong codec.  ZSTD now requires the raw library.
        if compression not in (8, 32946):
            return None
        try:
            import kvikio.nvcomp as nvcomp
        except ImportError:
            return None

        import cupy
        try:
            # zlib framing: 2-byte header, raw deflate data, 4-byte adler32.
            raw_tiles = [t[2:-4] if len(t) > 6 else t for t in compressed_tiles]
            manager = nvcomp.DeflateManager(chunk_size=tile_bytes)
            d_compressed = [cupy.asarray(np.frombuffer(t, dtype=np.uint8))
                            for t in raw_tiles]
            d_decompressed = manager.decompress(d_compressed)
            return cupy.concatenate([d.ravel() for d in d_decompressed])
        except Exception:
            # Best-effort fast path; caller falls back to CPU decode.
            return None

    # Direct ctypes path against the nvCOMP C API.
    import ctypes
    import cupy

    class _NvcompDecompOpts(ctypes.Structure):
        """Generic batched decompression options (passed by value)."""
        _fields_ = [
            ('backend', ctypes.c_int),
            ('reserved', ctypes.c_char * 60),
        ]

    class _NvcompDeflateDecompOpts(ctypes.Structure):
        # Deflate's options struct adds a sort_before_hw_decompress field.
        _fields_ = [
            ('backend', ctypes.c_int),
            ('sort_before_hw_decompress', ctypes.c_int),
            ('reserved', ctypes.c_char * 56),
        ]

    try:
        n_tiles = len(compressed_tiles)

        if compression in (8, 32946):  # Deflate
            # Strip the 2-byte zlib header and 4-byte adler32 checksum.
            raw_tiles = [t[2:-4] if len(t) > 6 else t for t in compressed_tiles]
            get_temp_fn = 'nvcompBatchedDeflateDecompressGetTempSizeAsync'
            decomp_fn = 'nvcompBatchedDeflateDecompressAsync'
            # Bug fix: the old code wrote reserved=b'\x00 ' * 56 — that
            # literal is 112 bytes and overflows the 56-byte field,
            # raising ValueError.  ctypes zero-fills omitted fields.
            opts = _NvcompDeflateDecompOpts(backend=0,
                                            sort_before_hw_decompress=0)
        else:  # compression == 50000, ZSTD — no container framing to strip
            raw_tiles = list(compressed_tiles)
            get_temp_fn = 'nvcompBatchedZstdDecompressGetTempSizeAsync'
            decomp_fn = 'nvcompBatchedZstdDecompressAsync'
            opts = _NvcompDecompOpts(backend=0)

        # Device buffers: compressed inputs, per-tile outputs, and the
        # device-resident pointer/size arrays the batched API expects.
        d_comp_bufs = [cupy.asarray(np.frombuffer(t, dtype=np.uint8))
                       for t in raw_tiles]
        d_decomp_bufs = [cupy.empty(tile_bytes, dtype=cupy.uint8)
                         for _ in range(n_tiles)]

        d_comp_ptrs = cupy.array([b.data.ptr for b in d_comp_bufs],
                                 dtype=cupy.uint64)
        d_decomp_ptrs = cupy.array([b.data.ptr for b in d_decomp_bufs],
                                   dtype=cupy.uint64)
        d_comp_sizes = cupy.array([len(t) for t in raw_tiles],
                                  dtype=cupy.uint64)
        d_buf_sizes = cupy.full(n_tiles, tile_bytes, dtype=cupy.uint64)
        d_actual = cupy.empty(n_tiles, dtype=cupy.uint64)

        # Declare argtypes so the options struct is passed by value with
        # the correct ABI (the old code's comment claimed this but never
        # actually set argtypes).
        temp_fn = getattr(lib, get_temp_fn)
        temp_fn.restype = ctypes.c_int
        temp_fn.argtypes = [ctypes.c_size_t, ctypes.c_size_t, type(opts),
                            ctypes.POINTER(ctypes.c_size_t), ctypes.c_size_t]

        temp_size = ctypes.c_size_t(0)
        status = temp_fn(
            n_tiles,
            tile_bytes,
            opts,
            ctypes.byref(temp_size),
            n_tiles * tile_bytes,
        )
        if status != 0:
            return None

        # Keep the temp buffer non-empty so its device pointer is valid
        # even when nvCOMP reports zero scratch requirement.
        ts = max(temp_size.value, 1)
        d_temp = cupy.empty(ts, dtype=cupy.uint8)
        d_statuses = cupy.zeros(n_tiles, dtype=cupy.int32)

        dec_fn = getattr(lib, decomp_fn)
        dec_fn.restype = ctypes.c_int
        dec_fn.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p,
                           ctypes.c_void_p, ctypes.c_size_t, ctypes.c_void_p,
                           ctypes.c_size_t, ctypes.c_void_p, type(opts),
                           ctypes.c_void_p, ctypes.c_void_p]

        status = dec_fn(
            d_comp_ptrs.data.ptr,
            d_comp_sizes.data.ptr,
            d_buf_sizes.data.ptr,
            d_actual.data.ptr,
            n_tiles,
            d_temp.data.ptr,
            ts,
            d_decomp_ptrs.data.ptr,
            opts,
            d_statuses.data.ptr,
            None,  # default CUDA stream
        )
        if status != 0:
            return None

        # The decompress call is async on the default stream; wait before
        # reading back the per-tile status codes.
        cupy.cuda.Device().synchronize()

        if int(cupy.any(d_statuses != 0)):
            return None

        return cupy.concatenate(d_decomp_bufs)

    except Exception:
        # Best-effort fast path: any failure falls back to CPU decode.
        return None
727848
728849
729850# ---------------------------------------------------------------------------
0 commit comments