@@ -672,6 +672,39 @@ def _assemble_tiles_kernel(
672672 output [dst_byte + b ] = decompressed_buf [src_byte + b ]
673673
674674
675+ # ---------------------------------------------------------------------------
676+ # KvikIO GDS (GPUDirect Storage) -- read file directly to GPU
677+ # ---------------------------------------------------------------------------
678+
679+ def _try_kvikio_read_tiles (file_path , tile_offsets , tile_byte_counts , tile_bytes ):
680+ """Read compressed tile bytes directly from SSD to GPU via GDS.
681+
682+ When kvikio is available and GDS is supported, file data is DMA'd
683+ directly from the NVMe drive to GPU VRAM, bypassing CPU entirely.
684+ Falls back to None if kvikio is not installed or GDS is not available.
685+
686+ Returns list of cupy arrays (one per tile) on GPU, or None.
687+ """
688+ try :
689+ import kvikio
690+ import cupy
691+ except ImportError :
692+ return None
693+
694+ try :
695+ d_tiles = []
696+ with kvikio .CuFile (file_path , 'r' ) as f :
697+ for off , bc in zip (tile_offsets , tile_byte_counts ):
698+ buf = cupy .empty (bc , dtype = cupy .uint8 )
699+ f .pread (buf , file_offset = off )
700+ d_tiles .append (buf )
701+ return d_tiles
702+ except Exception :
703+ # GDS not available (no NVMe, no kernel module, etc.)
704+ # Fall back to normal CPU read path
705+ return None
706+
707+
675708# ---------------------------------------------------------------------------
676709# nvCOMP batch decompression (optional, fast path)
677710# ---------------------------------------------------------------------------
@@ -851,6 +884,175 @@ class _NvcompDeflateDecompOpts(ctypes.Structure):
851884# High-level GPU decode pipeline
852885# ---------------------------------------------------------------------------
853886
887+ def gpu_decode_tiles_from_file (
888+ file_path : str ,
889+ tile_offsets : list | tuple ,
890+ tile_byte_counts : list | tuple ,
891+ tile_width : int ,
892+ tile_height : int ,
893+ image_width : int ,
894+ image_height : int ,
895+ compression : int ,
896+ predictor : int ,
897+ dtype : np .dtype ,
898+ samples : int = 1 ,
899+ ):
900+ """Decode tiles from a file, using GDS if available.
901+
902+ Tries KvikIO GDS (SSD → GPU direct) first, then falls back to
903+ CPU mmap + gpu_decode_tiles.
904+ """
905+ import cupy
906+
907+ # Try GDS: read compressed tiles directly from SSD to GPU
908+ d_tiles = _try_kvikio_read_tiles (
909+ file_path , tile_offsets , tile_byte_counts ,
910+ tile_width * tile_height * dtype .itemsize * samples )
911+
912+ if d_tiles is not None :
913+ # Tiles are already on GPU as cupy arrays.
914+ # Try nvCOMP batch decompress on them directly.
915+ tile_bytes = tile_width * tile_height * dtype .itemsize * samples
916+
917+ if compression in (50000 ,) and _get_nvcomp () is not None :
918+ # ZSTD: nvCOMP can decompress directly from GPU buffers
919+ result = _try_nvcomp_from_device_bufs (
920+ d_tiles , tile_bytes , compression )
921+ if result is not None :
922+ decomp_offsets = np .arange (len (d_tiles ), dtype = np .int64 ) * tile_bytes
923+ d_decomp = result
924+ d_decomp_offsets = cupy .asarray (decomp_offsets )
925+ # Apply predictor + assemble (shared code below)
926+ return _apply_predictor_and_assemble (
927+ d_decomp , d_decomp_offsets , len (d_tiles ),
928+ tile_width , tile_height , image_width , image_height ,
929+ predictor , dtype , samples , tile_bytes )
930+
931+ # GDS read succeeded but nvCOMP can't decompress on GPU,
932+ # or it's LZW/deflate. Copy tiles to host and use normal path.
933+ compressed_tiles = [t .get ().tobytes () for t in d_tiles ]
934+ else :
935+ # No GDS -- read tiles via CPU mmap (caller provides bytes)
936+ # This path is used when called from gpu_decode_tiles()
937+ return None # signal caller to use the bytes-based path
938+
939+ return gpu_decode_tiles (
940+ compressed_tiles , tile_width , tile_height ,
941+ image_width , image_height , compression , predictor , dtype , samples )
942+
943+
944+ def _try_nvcomp_from_device_bufs (d_tiles , tile_bytes , compression ):
945+ """Run nvCOMP batch decompress on tiles already in GPU memory."""
946+ import ctypes
947+ import cupy
948+
949+ lib = _get_nvcomp ()
950+ if lib is None :
951+ return None
952+
953+ class _NvcompDecompOpts (ctypes .Structure ):
954+ _fields_ = [('backend' , ctypes .c_int ), ('reserved' , ctypes .c_char * 60 )]
955+
956+ try :
957+ n = len (d_tiles )
958+ d_decomp_bufs = [cupy .empty (tile_bytes , dtype = cupy .uint8 ) for _ in range (n )]
959+
960+ d_comp_ptrs = cupy .array ([t .data .ptr for t in d_tiles ], dtype = cupy .uint64 )
961+ d_decomp_ptrs = cupy .array ([b .data .ptr for b in d_decomp_bufs ], dtype = cupy .uint64 )
962+ d_comp_sizes = cupy .array ([t .size for t in d_tiles ], dtype = cupy .uint64 )
963+ d_buf_sizes = cupy .full (n , tile_bytes , dtype = cupy .uint64 )
964+ d_actual = cupy .empty (n , dtype = cupy .uint64 )
965+
966+ opts = _NvcompDecompOpts (backend = 0 , reserved = b'\x00 ' * 60 )
967+
968+ fn_name = {50000 : 'nvcompBatchedZstdDecompressGetTempSizeAsync' }.get (compression )
969+ dec_name = {50000 : 'nvcompBatchedZstdDecompressAsync' }.get (compression )
970+ if fn_name is None :
971+ return None
972+
973+ temp_fn = getattr (lib , fn_name )
974+ temp_fn .restype = ctypes .c_int
975+ temp_size = ctypes .c_size_t (0 )
976+ s = temp_fn (n , tile_bytes , opts , ctypes .byref (temp_size ), n * tile_bytes )
977+ if s != 0 :
978+ return None
979+
980+ ts = max (temp_size .value , 1 )
981+ d_temp = cupy .empty (ts , dtype = cupy .uint8 )
982+ d_statuses = cupy .zeros (n , dtype = cupy .int32 )
983+
984+ dec_fn = getattr (lib , dec_name )
985+ dec_fn .restype = ctypes .c_int
986+ s = dec_fn (
987+ ctypes .c_void_p (d_comp_ptrs .data .ptr ),
988+ ctypes .c_void_p (d_comp_sizes .data .ptr ),
989+ ctypes .c_void_p (d_buf_sizes .data .ptr ),
990+ ctypes .c_void_p (d_actual .data .ptr ),
991+ ctypes .c_size_t (n ),
992+ ctypes .c_void_p (d_temp .data .ptr ), ctypes .c_size_t (ts ),
993+ ctypes .c_void_p (d_decomp_ptrs .data .ptr ),
994+ opts ,
995+ ctypes .c_void_p (d_statuses .data .ptr ),
996+ ctypes .c_void_p (0 ),
997+ )
998+ if s != 0 :
999+ return None
1000+
1001+ cupy .cuda .Device ().synchronize ()
1002+ if int (cupy .any (d_statuses != 0 )):
1003+ return None
1004+
1005+ return cupy .concatenate (d_decomp_bufs )
1006+ except Exception :
1007+ return None
1008+
1009+
1010+ def _apply_predictor_and_assemble (d_decomp , d_decomp_offsets , n_tiles ,
1011+ tile_width , tile_height ,
1012+ image_width , image_height ,
1013+ predictor , dtype , samples , tile_bytes ):
1014+ """Apply predictor decode and tile assembly on GPU."""
1015+ import cupy
1016+
1017+ bytes_per_pixel = dtype .itemsize * samples
1018+
1019+ if predictor == 2 :
1020+ total_rows = n_tiles * tile_height
1021+ tpb = min (256 , total_rows )
1022+ bpg = math .ceil (total_rows / tpb )
1023+ _predictor_decode_kernel [bpg , tpb ](
1024+ d_decomp , tile_width * samples , total_rows , dtype .itemsize * samples )
1025+ cuda .synchronize ()
1026+ elif predictor == 3 :
1027+ total_rows = n_tiles * tile_height
1028+ tpb = min (256 , total_rows )
1029+ bpg = math .ceil (total_rows / tpb )
1030+ d_tmp = cupy .empty_like (d_decomp )
1031+ _fp_predictor_decode_kernel [bpg , tpb ](
1032+ d_decomp , d_tmp , tile_width * samples , total_rows , dtype .itemsize )
1033+ cuda .synchronize ()
1034+
1035+ tiles_across = math .ceil (image_width / tile_width )
1036+ total_pixels = image_width * image_height
1037+ d_output = cupy .empty (total_pixels * bytes_per_pixel , dtype = cupy .uint8 )
1038+
1039+ tpb = 256
1040+ bpg = math .ceil (total_pixels / tpb )
1041+ _assemble_tiles_kernel [bpg , tpb ](
1042+ d_decomp , d_decomp_offsets ,
1043+ tile_width , tile_height , bytes_per_pixel ,
1044+ image_width , image_height , tiles_across ,
1045+ d_output ,
1046+ )
1047+ cuda .synchronize ()
1048+
1049+ if samples > 1 :
1050+ return d_output .view (dtype = cupy .dtype (dtype )).reshape (
1051+ image_height , image_width , samples )
1052+ return d_output .view (dtype = cupy .dtype (dtype )).reshape (
1053+ image_height , image_width )
1054+
1055+
8541056def gpu_decode_tiles (
8551057 compressed_tiles : list [bytes ],
8561058 tile_width : int ,
0 commit comments