99Supported model_types: llama, mistral, qwen2, qwen3, qwen3_moe, glm4
1010"""
1111
12+ import json
1213import math
1314import os
15+ import struct
16+ import warnings
1417from dataclasses import dataclass , field
1518from typing import Optional
1619
2831from bitsandbytes .training import checkpoint_cpu_offload
2932
3033
def parse_safetensors_offsets(path: str) -> dict:
    """Parse a safetensors header without loading any tensor data.

    Args:
        path: Path to a ``.safetensors`` file.

    Returns:
        Mapping of tensor name -> ``(byte_offset, byte_size, shape,
        torch_dtype)``. Offsets are absolute file positions (the 8-byte
        length prefix and JSON header are already accounted for), so they
        can be handed directly to pread-style APIs.
    """
    _dtype_map = {
        "F64": (torch.float64, 8),
        "F32": (torch.float32, 4),
        "F16": (torch.float16, 2),
        "BF16": (torch.bfloat16, 2),
        "I64": (torch.int64, 8),
        "I32": (torch.int32, 4),
        "I16": (torch.int16, 2),
        "I8": (torch.int8, 1),
        "U8": (torch.uint8, 1),
        "BOOL": (torch.bool, 1),
    }
    with open(path, "rb") as fp:
        # safetensors layout: u64 little-endian header length, JSON header,
        # then the raw tensor data region.
        header_size = struct.unpack("<Q", fp.read(8))[0]
        header = json.loads(fp.read(header_size))
    data_start = 8 + header_size
    offsets = {}
    for name, info in header.items():
        if name == "__metadata__":
            continue
        start, end = info["data_offsets"]
        dtype_str = info["dtype"]
        if dtype_str not in _dtype_map:
            # Byte offsets/sizes stay correct (they come from data_offsets),
            # but the tensor will be interpreted as raw bytes — warn instead
            # of silently mislabeling the dtype.
            warnings.warn(
                f"Unknown safetensors dtype {dtype_str!r} for tensor "
                f"{name!r}; treating as uint8."
            )
        torch_dtype, _ = _dtype_map.get(dtype_str, (torch.uint8, 1))
        offsets[name] = (data_start + start, end - start, info["shape"], torch_dtype)
    return offsets
60+
61+
3162def get_available_ram_bytes () -> int :
3263 """Read MemAvailable from /proc/meminfo (Linux only)."""
3364 try :
@@ -260,6 +291,7 @@ def from_quantized(
260291 expert_chunk_size : int = 32 ,
261292 batch_size : int = 8 ,
262293 seq_len : int = 1024 ,
294+ use_gds : bool = False ,
263295 lora_checkpoint : Optional [str ] = None ,
264296 ) -> "KbitLoraModel" :
265297 """Load a pre-quantized model from a safetensors checkpoint.
@@ -283,6 +315,8 @@ def from_quantized(
283315 expert_chunk_size: Experts processed at once in MoE forward.
284316 batch_size: Batch size hint for VRAM estimation (partial residency).
285317 seq_len: Sequence length hint for VRAM estimation (partial residency).
318+ use_gds: If True, use GPUDirect Storage (kvikio) for NVMe→GPU
319+ streaming. Falls back to CPU path if kvikio unavailable.
286320 lora_checkpoint: Optional path to saved LoRA weights to load.
287321 """
288322 from safetensors import safe_open
@@ -356,6 +390,15 @@ class _MinimalConfig:
356390 self .model = None
357391 self .lm_head_tied = False
358392
393+ # GDS detection and fallback
394+ if use_gds and not cls ._detect_gds_support ():
395+ warnings .warn (
396+ "GDS requested but not available (kvikio not installed or "
397+ "GeForce GPU detected). Falling back to CPU path."
398+ )
399+ use_gds = False
400+ self ._use_gds = use_gds
401+
359402 # 5. Initialize parameter containers
360403 self ._quantized_weights = nn .ParameterDict ()
361404 self ._lora_params = nn .ParameterDict ()
@@ -849,6 +892,68 @@ def _extend_rope_cache(self, seq_len: int, device):
849892 return
850893 self ._build_rope_cache (device , max_seq_len = seq_len )
851894
895+ # ─── GDS support ───
896+
897+ @staticmethod
898+ def _detect_gds_support () -> bool :
899+ """Check if GDS (GPUDirect Storage) is available and beneficial.
900+
901+ Returns False if kvikio is not installed or if the GPU is a GeForce
902+ (which only supports GDS in compatibility mode with no benefit).
903+ """
904+ try :
905+ import kvikio # noqa: F401
906+ except ImportError :
907+ return False
908+ # GeForce GPUs only support GDS in compat mode (bounce buffer through
909+ # CPU), which is no faster than the CPU pinned path. Only workstation/
910+ # datacenter GPUs (RTX PRO, Quadro, A100+) benefit from true GDS DMA.
911+ gpu_name = torch .cuda .get_device_name (0 )
912+ if "GeForce" in gpu_name :
913+ return False
914+ return True
915+
916+ def _gds_load_to_gpu (self , cpu_idx : int , slot : int , sync : bool = False ):
917+ """Read from NVMe directly into GPU slot via kvikio.CuFile.
918+
919+ Uses kvikio's thread pool for parallel reads. The sync parameter is
920+ currently ignored — all reads complete before returning. This ensures
921+ correctness with the CUDA stream sync pattern used by the caller.
922+ """
923+ import kvikio
924+
925+ gds_info = self ._gds_layer_info [cpu_idx ]
926+ gpu_slot = self ._gpu_slots [slot ]
927+ file_path = self ._checkpoint_path
928+ futures = []
929+
930+ # CuFile must stay open until all futures complete — pread is async
931+ # and the file handle must remain valid until the reads finish.
932+ with kvikio .CuFile (file_path , "r" ) as f :
933+ for key , value in gds_info .items ():
934+ if isinstance (value , dict ):
935+ # Nested proj: {packed: (offset, size, shape, dtype), ...}
936+ for wk , (offset , size , shape , dtype ) in value .items ():
937+ fut = f .pread (
938+ buf = gpu_slot [key ][wk ],
939+ file_offset = offset ,
940+ size = size ,
941+ )
942+ futures .append (fut )
943+ else :
944+ # Flat tensor: (offset, size, shape, dtype)
945+ offset , size , shape , dtype = value
946+ fut = f .pread (
947+ buf = gpu_slot [key ],
948+ file_offset = offset ,
949+ size = size ,
950+ )
951+ futures .append (fut )
952+
953+ # Wait for all reads to complete while the file handle is still open
954+ for fut in futures :
955+ fut .get ()
956+
852957 # ─── Weight streaming ───
853958
854959 def _get_streaming_weight_keys (self , layer_info : dict ) -> list [str ]:
@@ -1049,11 +1154,16 @@ def _init_weight_streaming(self):
10491154
10501155 # Select RAM strategy
10511156 has_checkpoint = getattr (self , "_checkpoint_path" , None ) is not None
1157+ use_gds = getattr (self , "_use_gds" , False )
10521158 available_ram = get_available_ram_bytes ()
10531159 headroom = 4 * 1024 ** 3 # 4 GB safety margin
10541160 usable_ram = max (0 , available_ram - headroom )
10551161
1056- if usable_ram >= total_streamed_bytes or not has_checkpoint :
1162+ if use_gds and has_checkpoint :
1163+ # GDS: read directly from NVMe to GPU, no CPU memory needed
1164+ self ._ram_strategy = "gds"
1165+ n_pinned = 0
1166+ elif usable_ram >= total_streamed_bytes or not has_checkpoint :
10571167 # All-pinned: pre-load everything into CPU pinned RAM
10581168 # Also forced when no checkpoint file (from __init__ path)
10591169 self ._ram_strategy = "pinned"
@@ -1078,20 +1188,50 @@ def _init_weight_streaming(self):
10781188 self ._safetensors_file = None
10791189 self ._staging_buffers = []
10801190 self ._mmap_layer_names = {}
1191+ self ._gds_layer_info = {}
10811192
10821193 if self ._ram_strategy in ("hybrid" , "mmap" ) and has_checkpoint :
10831194 from safetensors import safe_open
10841195 self ._safetensors_file = safe_open (
10851196 self ._checkpoint_path , framework = "pt" , device = "cpu"
10861197 )
10871198
1088- # Move non-resident layers: pinned or leave for mmap
1199+ # Parse byte offsets for GDS path
1200+ _sf_offsets = None
1201+ if self ._ram_strategy == "gds" and has_checkpoint :
1202+ _sf_offsets = parse_safetensors_offsets (self ._checkpoint_path )
1203+
1204+ # Move non-resident layers: pinned, mmap, or GDS
10891205 self ._cpu_weights = []
10901206 for si in range (n_streamed ):
10911207 layer_idx = self ._n_resident + si
10921208 layer_info = self ._layer_data [layer_idx ]
10931209
1094- if si < n_pinned :
1210+ if self ._ram_strategy == "gds" :
1211+ # GDS: store byte offset info for each tensor
1212+ tensor_names = self ._tensor_name_map [layer_idx ]
1213+ gds_layer = {}
1214+ for key , names in tensor_names .items ():
1215+ if isinstance (names , dict ):
1216+ gds_layer [key ] = {
1217+ wk : _sf_offsets [tn ]
1218+ for wk , tn in names .items ()
1219+ }
1220+ else :
1221+ gds_layer [key ] = _sf_offsets [names ]
1222+ self ._gds_layer_info [si ] = gds_layer
1223+ # Clear weight tensors from _layer_data
1224+ proj_keys = self ._get_streaming_weight_keys (layer_info )
1225+ for proj in proj_keys :
1226+ for wk in weight_keys :
1227+ layer_info [proj ][wk ] = None
1228+ if layer_info .get ("is_moe" ):
1229+ for expert_proj in ["gate" , "up" , "down" ]:
1230+ for suffix in ["packed" , "absmax" ]:
1231+ layer_info [f"expert_{ expert_proj } _{ suffix } " ] = None
1232+ layer_info ["expert_codebook" ] = None
1233+ self ._cpu_weights .append (None )
1234+ elif si < n_pinned :
10951235 # Pinned: copy to CPU pinned memory
10961236 cpu_layer = {}
10971237 proj_keys = self ._get_streaming_weight_keys (layer_info )
@@ -1173,6 +1313,21 @@ def _init_weight_streaming(self):
11731313 else :
11741314 ref_layer [key ] = self ._safetensors_file .get_tensor (value )
11751315
1316+ if ref_layer is None and self ._gds_layer_info :
1317+ # GDS path — build reference from offset info (shapes + dtypes)
1318+ first_gds_si = min (self ._gds_layer_info .keys ())
1319+ gds_info = self ._gds_layer_info [first_gds_si ]
1320+ ref_layer = {}
1321+ for key , value in gds_info .items ():
1322+ if isinstance (value , dict ):
1323+ ref_layer [key ] = {
1324+ wk : torch .empty (shape , dtype = dtype , device = "cpu" )
1325+ for wk , (offset , size , shape , dtype ) in value .items ()
1326+ }
1327+ else :
1328+ offset , size , shape , dtype = value
1329+ ref_layer [key ] = torch .empty (shape , dtype = dtype , device = "cpu" )
1330+
11761331 def _entry_bytes (v ):
11771332 return sum (t .nbytes for t in v .values ()) if isinstance (v , dict ) else v .nbytes
11781333
@@ -1185,6 +1340,26 @@ def _ref_bytes(layer):
11851340 if cpu_layer is not None and _ref_bytes (cpu_layer ) > _ref_bytes (largest_ref ):
11861341 largest_ref = cpu_layer
11871342
1343+ # For GDS, also check all GDS layers by building temp ref from offsets
1344+ for si , gds_info in self ._gds_layer_info .items ():
1345+ gds_bytes = sum (
1346+ sum (info [1 ] for info in v .values ()) if isinstance (v , dict )
1347+ else v [1 ]
1348+ for v in gds_info .values ()
1349+ )
1350+ if gds_bytes > _ref_bytes (largest_ref ):
1351+ # Build a temp ref from this GDS layer's shapes
1352+ largest_ref = {}
1353+ for key , value in gds_info .items ():
1354+ if isinstance (value , dict ):
1355+ largest_ref [key ] = {
1356+ wk : torch .empty (shape , dtype = dtype , device = "cpu" )
1357+ for wk , (offset , size , shape , dtype ) in value .items ()
1358+ }
1359+ else :
1360+ offset , size , shape , dtype = value
1361+ largest_ref [key ] = torch .empty (shape , dtype = dtype , device = "cpu" )
1362+
11881363 self ._gpu_slots = []
11891364 for _ in range (2 ):
11901365 slot = {}
@@ -1242,14 +1417,18 @@ def _ref_bytes(layer):
12421417 def _stream_load_layer (self , layer_idx : int , slot : int , sync : bool = False ):
12431418 """Load a layer's quantized weights into a GPU slot.
12441419
1245- Handles both pinned (direct DMA) and mmap (safetensors → staging → GPU) sources.
1420+ Handles pinned (direct DMA), mmap (safetensors → staging → GPU),
1421+ and GDS (NVMe → GPU via kvikio) sources.
12461422 """
12471423 cpu_idx = layer_idx - self ._n_resident
12481424 cpu_layer = self ._cpu_weights [cpu_idx ]
12491425
12501426 if cpu_layer is not None :
12511427 # Pinned path: async DMA from CPU pinned to GPU
12521428 self ._copy_pinned_to_gpu (cpu_layer , slot , sync )
1429+ elif self ._ram_strategy == "gds" :
1430+ # GDS path: read from NVMe directly into GPU slot
1431+ self ._gds_load_to_gpu (cpu_idx , slot , sync )
12531432 else :
12541433 # Mmap path: load from safetensors → staging buffer → GPU
12551434 self ._mmap_load_to_gpu (cpu_idx , slot , sync )
0 commit comments