Skip to content

Commit 0d2c59a

Browse files
TimDettmers and claude
committed
feat: Add GDS/kvikio integration for NVMe→GPU weight streaming
Implements GPUDirect Storage support for reading quantized weights directly from NVMe into GPU memory, bypassing CPU entirely on workstation/datacenter GPUs (RTX PRO, Quadro, A100+). Key changes: - parse_safetensors_offsets(): parse raw byte offsets from header - _detect_gds_support(): check kvikio availability and GPU type - _gds_load_to_gpu(): read via kvikio.CuFile.pread into GPU slots - GDS strategy in _init_weight_streaming(): stores per-tensor offset info instead of CPU weights, allocates GPU slots from offset shapes - _stream_load_layer(): dispatches to GDS, pinned, or mmap path - use_gds parameter on from_quantized() with automatic fallback - Falls back to CPU path on GeForce GPUs (compat mode only) Tests: 6 new tests covering detection, fallback, strategy selection, forward/backward correctness, gradient match vs pinned path, and offset parsing. All 45 tests pass. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent dc521ef commit 0d2c59a

File tree

2 files changed

+352
-4
lines changed

2 files changed

+352
-4
lines changed

bitsandbytes/kbit_lora.py

Lines changed: 183 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,11 @@
99
Supported model_types: llama, mistral, qwen2, qwen3, qwen3_moe, glm4
1010
"""
1111

12+
import json
1213
import math
1314
import os
15+
import struct
16+
import warnings
1417
from dataclasses import dataclass, field
1518
from typing import Optional
1619

@@ -28,6 +31,34 @@
2831
from bitsandbytes.training import checkpoint_cpu_offload
2932

3033

34+
# Map safetensors dtype strings to (torch dtype, bytes per element).
# Unknown dtype strings fall back to (uint8, 1) so the raw bytes can
# still be addressed for byte-level reads, even if they cannot be
# reinterpreted without extra information.
_SAFETENSORS_DTYPES = {
    "F64": (torch.float64, 8),
    "F32": (torch.float32, 4),
    "F16": (torch.float16, 2),
    "BF16": (torch.bfloat16, 2),
    "I64": (torch.int64, 8),
    "I32": (torch.int32, 4),
    "I16": (torch.int16, 2),
    "I8": (torch.int8, 1),
    "U8": (torch.uint8, 1),
    "BOOL": (torch.bool, 1),
}


def parse_safetensors_offsets(path: str) -> dict:
    """Parse a safetensors header into per-tensor byte-offset records.

    Args:
        path: Path to a ``.safetensors`` file.

    Returns:
        Dict mapping tensor name -> ``(byte_offset, byte_size, shape,
        torch_dtype)``. ``byte_offset`` is absolute from the start of
        the file (the 8-byte length prefix and JSON header included),
        so it can be passed directly to positional reads (``pread``).

    Raises:
        ValueError: If the file is too short to contain the 8-byte
            header-length prefix.
    """
    with open(path, "rb") as fp:
        prefix = fp.read(8)
        if len(prefix) != 8:
            raise ValueError(f"{path}: truncated safetensors header")
        header_size = struct.unpack("<Q", prefix)[0]
        header = json.loads(fp.read(header_size))

    # Tensor data begins immediately after the length prefix + header;
    # data_offsets in the header are relative to this point.
    data_start = 8 + header_size
    offsets = {}
    for name, info in header.items():
        if name == "__metadata__":
            continue  # free-form metadata entry, not a tensor
        start, end = info["data_offsets"]
        torch_dtype, _ = _SAFETENSORS_DTYPES.get(info["dtype"], (torch.uint8, 1))
        offsets[name] = (data_start + start, end - start, info["shape"], torch_dtype)
    return offsets
60+
61+
3162
def get_available_ram_bytes() -> int:
3263
"""Read MemAvailable from /proc/meminfo (Linux only)."""
3364
try:
@@ -260,6 +291,7 @@ def from_quantized(
260291
expert_chunk_size: int = 32,
261292
batch_size: int = 8,
262293
seq_len: int = 1024,
294+
use_gds: bool = False,
263295
lora_checkpoint: Optional[str] = None,
264296
) -> "KbitLoraModel":
265297
"""Load a pre-quantized model from a safetensors checkpoint.
@@ -283,6 +315,8 @@ def from_quantized(
283315
expert_chunk_size: Experts processed at once in MoE forward.
284316
batch_size: Batch size hint for VRAM estimation (partial residency).
285317
seq_len: Sequence length hint for VRAM estimation (partial residency).
318+
use_gds: If True, use GPUDirect Storage (kvikio) for NVMe→GPU
319+
streaming. Falls back to CPU path if kvikio unavailable.
286320
lora_checkpoint: Optional path to saved LoRA weights to load.
287321
"""
288322
from safetensors import safe_open
@@ -356,6 +390,15 @@ class _MinimalConfig:
356390
self.model = None
357391
self.lm_head_tied = False
358392

393+
# GDS detection and fallback
394+
if use_gds and not cls._detect_gds_support():
395+
warnings.warn(
396+
"GDS requested but not available (kvikio not installed or "
397+
"GeForce GPU detected). Falling back to CPU path."
398+
)
399+
use_gds = False
400+
self._use_gds = use_gds
401+
359402
# 5. Initialize parameter containers
360403
self._quantized_weights = nn.ParameterDict()
361404
self._lora_params = nn.ParameterDict()
@@ -849,6 +892,68 @@ def _extend_rope_cache(self, seq_len: int, device):
849892
return
850893
self._build_rope_cache(device, max_seq_len=seq_len)
851894

895+
# ─── GDS support ───
896+
897+
@staticmethod
898+
def _detect_gds_support() -> bool:
899+
"""Check if GDS (GPUDirect Storage) is available and beneficial.
900+
901+
Returns False if kvikio is not installed or if the GPU is a GeForce
902+
(which only supports GDS in compatibility mode with no benefit).
903+
"""
904+
try:
905+
import kvikio # noqa: F401
906+
except ImportError:
907+
return False
908+
# GeForce GPUs only support GDS in compat mode (bounce buffer through
909+
# CPU), which is no faster than the CPU pinned path. Only workstation/
910+
# datacenter GPUs (RTX PRO, Quadro, A100+) benefit from true GDS DMA.
911+
gpu_name = torch.cuda.get_device_name(0)
912+
if "GeForce" in gpu_name:
913+
return False
914+
return True
915+
916+
def _gds_load_to_gpu(self, cpu_idx: int, slot: int, sync: bool = False):
    """Read from NVMe directly into a GPU slot via kvikio.CuFile.

    Uses kvikio's thread pool for parallel reads. The ``sync`` parameter
    is currently ignored — every read completes before this returns,
    which keeps the caller's CUDA stream sync pattern correct.
    """
    import kvikio

    layer_offsets = self._gds_layer_info[cpu_idx]
    dest = self._gpu_slots[slot]

    # Flatten the (possibly nested) offset records into a single work
    # list of (target GPU buffer, file byte offset, byte size) triples.
    reads = []
    for key, entry in layer_offsets.items():
        if isinstance(entry, dict):
            # Nested proj: {packed: (offset, size, shape, dtype), ...}
            for wk, (offset, size, _shape, _dtype) in entry.items():
                reads.append((dest[key][wk], offset, size))
        else:
            # Flat tensor: (offset, size, shape, dtype)
            offset, size, _shape, _dtype = entry
            reads.append((dest[key], offset, size))

    # CuFile must stay open until all futures complete — pread is async
    # and the file handle must remain valid until the reads finish.
    with kvikio.CuFile(self._checkpoint_path, "r") as f:
        pending = [
            f.pread(buf=buf, file_offset=offset, size=size)
            for buf, offset, size in reads
        ]
        # Block on every read while the handle is still open.
        for fut in pending:
            fut.get()
956+
852957
# ─── Weight streaming ───
853958

854959
def _get_streaming_weight_keys(self, layer_info: dict) -> list[str]:
@@ -1049,11 +1154,16 @@ def _init_weight_streaming(self):
10491154

10501155
# Select RAM strategy
10511156
has_checkpoint = getattr(self, "_checkpoint_path", None) is not None
1157+
use_gds = getattr(self, "_use_gds", False)
10521158
available_ram = get_available_ram_bytes()
10531159
headroom = 4 * 1024**3 # 4 GB safety margin
10541160
usable_ram = max(0, available_ram - headroom)
10551161

1056-
if usable_ram >= total_streamed_bytes or not has_checkpoint:
1162+
if use_gds and has_checkpoint:
1163+
# GDS: read directly from NVMe to GPU, no CPU memory needed
1164+
self._ram_strategy = "gds"
1165+
n_pinned = 0
1166+
elif usable_ram >= total_streamed_bytes or not has_checkpoint:
10571167
# All-pinned: pre-load everything into CPU pinned RAM
10581168
# Also forced when no checkpoint file (from __init__ path)
10591169
self._ram_strategy = "pinned"
@@ -1078,20 +1188,50 @@ def _init_weight_streaming(self):
10781188
self._safetensors_file = None
10791189
self._staging_buffers = []
10801190
self._mmap_layer_names = {}
1191+
self._gds_layer_info = {}
10811192

10821193
if self._ram_strategy in ("hybrid", "mmap") and has_checkpoint:
10831194
from safetensors import safe_open
10841195
self._safetensors_file = safe_open(
10851196
self._checkpoint_path, framework="pt", device="cpu"
10861197
)
10871198

1088-
# Move non-resident layers: pinned or leave for mmap
1199+
# Parse byte offsets for GDS path
1200+
_sf_offsets = None
1201+
if self._ram_strategy == "gds" and has_checkpoint:
1202+
_sf_offsets = parse_safetensors_offsets(self._checkpoint_path)
1203+
1204+
# Move non-resident layers: pinned, mmap, or GDS
10891205
self._cpu_weights = []
10901206
for si in range(n_streamed):
10911207
layer_idx = self._n_resident + si
10921208
layer_info = self._layer_data[layer_idx]
10931209

1094-
if si < n_pinned:
1210+
if self._ram_strategy == "gds":
1211+
# GDS: store byte offset info for each tensor
1212+
tensor_names = self._tensor_name_map[layer_idx]
1213+
gds_layer = {}
1214+
for key, names in tensor_names.items():
1215+
if isinstance(names, dict):
1216+
gds_layer[key] = {
1217+
wk: _sf_offsets[tn]
1218+
for wk, tn in names.items()
1219+
}
1220+
else:
1221+
gds_layer[key] = _sf_offsets[names]
1222+
self._gds_layer_info[si] = gds_layer
1223+
# Clear weight tensors from _layer_data
1224+
proj_keys = self._get_streaming_weight_keys(layer_info)
1225+
for proj in proj_keys:
1226+
for wk in weight_keys:
1227+
layer_info[proj][wk] = None
1228+
if layer_info.get("is_moe"):
1229+
for expert_proj in ["gate", "up", "down"]:
1230+
for suffix in ["packed", "absmax"]:
1231+
layer_info[f"expert_{expert_proj}_{suffix}"] = None
1232+
layer_info["expert_codebook"] = None
1233+
self._cpu_weights.append(None)
1234+
elif si < n_pinned:
10951235
# Pinned: copy to CPU pinned memory
10961236
cpu_layer = {}
10971237
proj_keys = self._get_streaming_weight_keys(layer_info)
@@ -1173,6 +1313,21 @@ def _init_weight_streaming(self):
11731313
else:
11741314
ref_layer[key] = self._safetensors_file.get_tensor(value)
11751315

1316+
if ref_layer is None and self._gds_layer_info:
1317+
# GDS path — build reference from offset info (shapes + dtypes)
1318+
first_gds_si = min(self._gds_layer_info.keys())
1319+
gds_info = self._gds_layer_info[first_gds_si]
1320+
ref_layer = {}
1321+
for key, value in gds_info.items():
1322+
if isinstance(value, dict):
1323+
ref_layer[key] = {
1324+
wk: torch.empty(shape, dtype=dtype, device="cpu")
1325+
for wk, (offset, size, shape, dtype) in value.items()
1326+
}
1327+
else:
1328+
offset, size, shape, dtype = value
1329+
ref_layer[key] = torch.empty(shape, dtype=dtype, device="cpu")
1330+
11761331
def _entry_bytes(v):
11771332
return sum(t.nbytes for t in v.values()) if isinstance(v, dict) else v.nbytes
11781333

@@ -1185,6 +1340,26 @@ def _ref_bytes(layer):
11851340
if cpu_layer is not None and _ref_bytes(cpu_layer) > _ref_bytes(largest_ref):
11861341
largest_ref = cpu_layer
11871342

1343+
# For GDS, also check all GDS layers by building temp ref from offsets
1344+
for si, gds_info in self._gds_layer_info.items():
1345+
gds_bytes = sum(
1346+
sum(info[1] for info in v.values()) if isinstance(v, dict)
1347+
else v[1]
1348+
for v in gds_info.values()
1349+
)
1350+
if gds_bytes > _ref_bytes(largest_ref):
1351+
# Build a temp ref from this GDS layer's shapes
1352+
largest_ref = {}
1353+
for key, value in gds_info.items():
1354+
if isinstance(value, dict):
1355+
largest_ref[key] = {
1356+
wk: torch.empty(shape, dtype=dtype, device="cpu")
1357+
for wk, (offset, size, shape, dtype) in value.items()
1358+
}
1359+
else:
1360+
offset, size, shape, dtype = value
1361+
largest_ref[key] = torch.empty(shape, dtype=dtype, device="cpu")
1362+
11881363
self._gpu_slots = []
11891364
for _ in range(2):
11901365
slot = {}
@@ -1242,14 +1417,18 @@ def _ref_bytes(layer):
12421417
def _stream_load_layer(self, layer_idx: int, slot: int, sync: bool = False):
    """Load a layer's quantized weights into a GPU slot.

    Dispatches to one of three sources: pinned (direct DMA from CPU
    pinned memory), GDS (NVMe → GPU via kvikio), or mmap
    (safetensors → staging buffer → GPU).
    """
    stream_idx = layer_idx - self._n_resident
    pinned_layer = self._cpu_weights[stream_idx]

    # Pinned path: the layer already lives in CPU pinned RAM, so a
    # plain async DMA copy suffices.
    if pinned_layer is not None:
        self._copy_pinned_to_gpu(pinned_layer, slot, sync)
        return

    # GDS path: read from NVMe directly into the GPU slot.
    if self._ram_strategy == "gds":
        self._gds_load_to_gpu(stream_idx, slot, sync)
        return

    # Mmap fallback: safetensors → staging buffer → GPU.
    self._mmap_load_to_gpu(stream_idx, slot, sync)

0 commit comments

Comments
 (0)