
Commit 34d4dc3

TimDettmers and claude committed
feat: Add from_quantized classmethod and fix weight streaming bugs
- Add KbitLoraModel.from_quantized() classmethod that loads a pre-quantized safetensors checkpoint without requiring the original HuggingFace model (Path B). Reconstructs ArchConfig from metadata, populates _layer_data, creates LoRA adapters, and optionally initializes weight streaming.
- Fix _init_weight_streaming GPU slot allocation to handle nested projection dicts (packed/absmax/codebook per projection). Previously it only worked with flat tensor values.
- Fix _stream_load_layer to handle both nested projection dicts and flat expert tensors. Previously it would crash on the first call with nested dicts.
- Fix byte-counting in the streaming summary to handle nested dicts.
- Add comprehensive round-trip tests:
  - Dense model: data match, forward match, streaming, attributes
  - MoE model: data match, streaming with expert weights

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 23ee333 commit 34d4dc3
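
A minimal usage sketch of the round trip, assuming save_quantized() takes a destination path (only from_quantized()'s signature appears in this diff) and using a placeholder checkpoint filename:

    import torch
    from bitsandbytes.kbit_lora import KbitLoraModel

    # Path A (pre-existing): quantize a HF model once and write a checkpoint.
    # model = KbitLoraModel(hf_model, ...)            # built from the HF model
    # model.save_quantized("model-kbit.safetensors")  # assumed signature

    # Path B (this commit): reload without the original HuggingFace model.
    model = KbitLoraModel.from_quantized(
        "model-kbit.safetensors",  # placeholder path
        lora_r=64,
        lora_alpha=16.0,
        weight_streaming=True,     # keep weights CPU-pinned, stream per layer
        target_device=torch.device("cuda:0"),
    )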

File tree

2 files changed, +606 -15 lines changed


bitsandbytes/kbit_lora.py

Lines changed: 341 additions & 15 deletions
@@ -200,6 +200,306 @@ def __init__(
         for p in self._norm_weights.parameters():
             p.requires_grad_(True)
 
+    # ─── Load from pre-quantized checkpoint ───
+
+    @classmethod
+    def from_quantized(
+        cls,
+        checkpoint_path: str,
+        lora_r: int = 64,
+        lora_alpha: float = 16.0,
+        attn_chunk_size: int = 4096,
+        mlp_chunk_size: int = 4096,
+        ce_chunk_size: int = 8192,
+        compute_dtype: torch.dtype = torch.bfloat16,
+        weight_streaming: bool = True,
+        target_device: torch.device = torch.device("cuda:0"),
+        lora_on_experts: bool = False,
+        expert_chunk_size: int = 32,
+        lora_checkpoint: Optional[str] = None,
+    ) -> "KbitLoraModel":
+        """Load a pre-quantized model from a safetensors checkpoint.
+
+        This is Path B: load pre-quantized weights without requiring the
+        original HuggingFace model. Use save_quantized() to create the
+        checkpoint (Path A).
+
+        Args:
+            checkpoint_path: Path to safetensors file from save_quantized().
+            lora_r: LoRA rank.
+            lora_alpha: LoRA scaling factor.
+            attn_chunk_size: Sequence chunk size for attention.
+            mlp_chunk_size: Sequence chunk size for MLP.
+            ce_chunk_size: Vocab chunk size for cross-entropy.
+            compute_dtype: Computation dtype.
+            weight_streaming: If True, keep weights in CPU pinned memory
+                and stream to GPU layer-by-layer.
+            target_device: GPU device for computation.
+            lora_on_experts: If True, add LoRA to expert projections.
+            expert_chunk_size: Experts processed at once in MoE forward.
+            lora_checkpoint: Optional path to saved LoRA weights to load.
+        """
+        from safetensors import safe_open
+
+        # 1. Open safetensors and read metadata
+        sf = safe_open(checkpoint_path, framework="pt", device="cpu")
+        meta = sf.metadata()
+
+        # 2. Create instance without calling __init__
+        self = cls.__new__(cls)
+        nn.Module.__init__(self)
+
+        # 3. Reconstruct ArchConfig from metadata
+        class _MinimalConfig:
+            pass
+
+        cfg = _MinimalConfig()
+        cfg.model_type = meta["model_type"]
+        if meta.get("is_moe") == "True":
+            cfg.num_experts = int(meta["num_experts"])
+            cfg.num_local_experts = int(meta["num_experts"])
+            cfg.num_experts_per_tok = int(meta["num_active_experts"])
+            cfg.moe_intermediate_size = int(meta["expert_intermediate_size"])
+
+        self.arch = detect_arch_config(cfg)
+
+        # 4. Set attributes from metadata and parameters
+        self.config = None
+        self.model_type = meta["model_type"]
+        self.lora_r = lora_r
+        self.lora_s = lora_alpha / lora_r
+        self.k = int(meta.get("k_attention", "4"))
+        self.k_config = {}
+        self.k_attention = int(meta["k_attention"])
+        self.k_mlp = int(meta["k_mlp"])
+        self.k_lm_head = int(meta["k_lm_head"])
+        self.k_experts = int(meta["k_experts"])
+        self.k_shared_expert = int(meta["k_shared_expert"])
+        self.attn_chunk_size = attn_chunk_size
+        self.mlp_chunk_size = mlp_chunk_size
+        self.ce_chunk_size = ce_chunk_size
+        self.compute_dtype = compute_dtype
+        self.cpu_offload = weight_streaming
+        self.weight_streaming = weight_streaming
+        self.include_embed = True
+        self.include_lm_head = True
+        self.lora_on_experts = lora_on_experts
+        self.expert_chunk_size = expert_chunk_size
+
+        self.hidden_size = int(meta["hidden_size"])
+        self.num_heads = int(meta["num_attention_heads"])
+        self.num_kv_heads = int(meta["num_key_value_heads"])
+        self.head_dim = int(meta["head_dim"])
+        self.q_dim = self.num_heads * self.head_dim
+        self.kv_dim = self.num_kv_heads * self.head_dim
+        self.intermediate_size = int(meta["intermediate_size"])
+        self.vocab_size = int(meta["vocab_size"])
+        self.num_layers = int(meta["num_layers"])
+        self.rms_norm_eps = float(meta["rms_norm_eps"])
+        self.rope_theta = float(meta["rope_theta"])
+
+        self._layer_start = int(meta.get("layer_start", "0"))
+        self._layer_end = int(meta.get("layer_end", meta["num_layers"]))
+        self._num_loaded_layers = int(meta.get("num_loaded_layers", meta["num_layers"]))
+
+        self._streaming = True
+        self._target_device = target_device
+        self.model = None
+        self.lm_head_tied = False
+
+        # 5. Initialize parameter containers
+        self._quantized_weights = nn.ParameterDict()
+        self._lora_params = nn.ParameterDict()
+        self._norm_weights = nn.ParameterDict()
+
+        # 6. Load embedding
+        if "embed_tokens.weight" in sf.keys():
+            embed_weight = sf.get_tensor("embed_tokens.weight").to(target_device)
+            self.embed_tokens = nn.Embedding(self.vocab_size, self.hidden_size)
+            self.embed_tokens.weight = nn.Parameter(embed_weight, requires_grad=False)
+        else:
+            self.embed_tokens = None
+
+        # 7. Populate _layer_data from safetensors
+        self._layer_data = []
+        for i in range(self._num_loaded_layers):
+            prefix = f"layer.{i}"
+            layer_info = {}
+
+            # Attention projections
+            for proj in ["q_proj", "k_proj", "v_proj", "o_proj"]:
+                N = int(meta[f"{prefix}.attn.{proj}.N"])
+                K = int(meta[f"{prefix}.attn.{proj}.K"])
+                N_padded = int(meta[f"{prefix}.attn.{proj}.N_padded"])
+                k_val = int(meta[f"{prefix}.attn.{proj}.k"])
+
+                packed = sf.get_tensor(f"{prefix}.attn.{proj}.packed")
+                absmax = sf.get_tensor(f"{prefix}.attn.{proj}.absmax")
+                codebook = sf.get_tensor(f"{prefix}.attn.{proj}.codebook")
+
+                if not weight_streaming:
+                    packed = packed.to(target_device)
+                    absmax = absmax.to(target_device)
+                    codebook = codebook.to(target_device)
+
+                A, B = self._create_lora(f"layers_{i}_attn_{proj}", N, K)
+
+                layer_info[proj] = {
+                    "packed": packed, "absmax": absmax, "codebook": codebook,
+                    "N_padded": N_padded, "N": N, "K": K, "k": k_val,
+                    "A": A, "B": B,
+                }
+
+            # MLP or MoE
+            global_layer_idx = self._layer_start + i
+            is_moe_layer = self.arch.is_moe_layer(global_layer_idx)
+
+            if is_moe_layer:
+                layer_info["is_moe"] = True
+
+                # Router weight (always on GPU, not quantized)
+                router_weight = sf.get_tensor(f"{prefix}.moe.router_weight")
+                layer_info["router_weight"] = router_weight.to(
+                    target_device, dtype=compute_dtype
+                )
+
+                # Shared expert (if present)
+                if self.arch.has_shared_expert:
+                    for proj in ["shared_gate_proj", "shared_up_proj", "shared_down_proj"]:
+                        N = int(meta[f"{prefix}.moe.{proj}.N"])
+                        K = int(meta[f"{prefix}.moe.{proj}.K"])
+                        N_padded = int(meta[f"{prefix}.moe.{proj}.N_padded"])
+                        k_val = int(meta[f"{prefix}.moe.{proj}.k"])
+
+                        packed = sf.get_tensor(f"{prefix}.moe.{proj}.packed")
+                        absmax = sf.get_tensor(f"{prefix}.moe.{proj}.absmax")
+                        codebook = sf.get_tensor(f"{prefix}.moe.{proj}.codebook")
+
+                        if not weight_streaming:
+                            packed = packed.to(target_device)
+                            absmax = absmax.to(target_device)
+                            codebook = codebook.to(target_device)
+
+                        A, B = self._create_lora(f"layers_{i}_moe_{proj}", N, K)
+
+                        layer_info[proj] = {
+                            "packed": packed, "absmax": absmax, "codebook": codebook,
+                            "N_padded": N_padded, "N": N, "K": K, "k": k_val,
+                            "A": A, "B": B,
+                        }
+
+                # Expert weights (concatenated across all experts)
+                expert_N = int(meta[f"{prefix}.moe.experts.N"])
+                expert_K = int(meta[f"{prefix}.moe.experts.K"])
+                expert_N_padded = int(meta[f"{prefix}.moe.experts.N_padded"])
+                expert_k = int(meta[f"{prefix}.moe.experts.k"])
+
+                for expert_proj in ["gate", "up", "down"]:
+                    for suffix in ["packed", "absmax"]:
+                        key = f"expert_{expert_proj}_{suffix}"
+                        tensor = sf.get_tensor(f"{prefix}.moe.experts.{expert_proj}.{suffix}")
+                        if not weight_streaming:
+                            tensor = tensor.to(target_device)
+                        layer_info[key] = tensor
+
+                expert_codebook = sf.get_tensor(f"{prefix}.moe.experts.codebook")
+                if not weight_streaming:
+                    expert_codebook = expert_codebook.to(target_device)
+                layer_info["expert_codebook"] = expert_codebook
+                layer_info["expert_k"] = expert_k
+                layer_info["expert_N"] = expert_N
+                layer_info["expert_K"] = expert_K
+                layer_info["expert_N_padded"] = expert_N_padded
+            else:
+                # Dense MLP
+                for proj in ["gate_proj", "up_proj", "down_proj"]:
+                    N = int(meta[f"{prefix}.mlp.{proj}.N"])
+                    K = int(meta[f"{prefix}.mlp.{proj}.K"])
+                    N_padded = int(meta[f"{prefix}.mlp.{proj}.N_padded"])
+                    k_val = int(meta[f"{prefix}.mlp.{proj}.k"])
+
+                    packed = sf.get_tensor(f"{prefix}.mlp.{proj}.packed")
+                    absmax = sf.get_tensor(f"{prefix}.mlp.{proj}.absmax")
+                    codebook = sf.get_tensor(f"{prefix}.mlp.{proj}.codebook")
+
+                    if not weight_streaming:
+                        packed = packed.to(target_device)
+                        absmax = absmax.to(target_device)
+                        codebook = codebook.to(target_device)
+
+                    A, B = self._create_lora(f"layers_{i}_mlp_{proj}", N, K)
+
+                    layer_info[proj] = {
+                        "packed": packed, "absmax": absmax, "codebook": codebook,
+                        "N_padded": N_padded, "N": N, "K": K, "k": k_val,
+                        "A": A, "B": B,
+                    }
+
+            # Norm weights (always on GPU)
+            for nk in ["input_layernorm", "post_attention_layernorm"]:
+                tensor_name = f"{prefix}.{nk}.weight"
+                if tensor_name in sf.keys():
+                    weight = sf.get_tensor(tensor_name).to(
+                        target_device, dtype=compute_dtype
+                    )
+                    safe_name = f"layers_{i}_{nk}_weight"
+                    self._norm_weights[safe_name] = nn.Parameter(weight)
+                    layer_info[nk] = self._norm_weights[safe_name]
+
+            # QK norms (Qwen3)
+            if self.arch.has_qk_norm:
+                for nk in ["q_norm", "k_norm"]:
+                    tensor_name = f"{prefix}.{nk}.weight"
+                    if tensor_name in sf.keys():
+                        weight = sf.get_tensor(tensor_name).to(
+                            target_device, dtype=compute_dtype
+                        )
+                        safe_name = f"layers_{i}_attn_{nk}_weight"
+                        self._norm_weights[safe_name] = nn.Parameter(weight)
+                        layer_info[nk] = self._norm_weights[safe_name]
+
+            self._layer_data.append(layer_info)
+
+        # 8. Final norm
+        if "final_norm.weight" in sf.keys():
+            weight = sf.get_tensor("final_norm.weight").to(
+                target_device, dtype=compute_dtype
+            )
+            self._norm_weights["final_norm_weight"] = nn.Parameter(weight)
+
+        # 9. LM head (always on GPU — small relative to layer weights)
+        self._lm_head_info = None
+        if "lm_head.packed" in sf.keys():
+            self._lm_head_info = {
+                "packed": sf.get_tensor("lm_head.packed").to(target_device),
+                "absmax": sf.get_tensor("lm_head.absmax").to(target_device),
+                "codebook": sf.get_tensor("lm_head.codebook").to(target_device),
+                "N_padded": int(meta["lm_head.N_padded"]),
+                "N": int(meta["lm_head.N"]),
+                "K": int(meta["lm_head.K"]),
+                "k": int(meta["lm_head.k"]),
+            }
+
+        # 10. Build RoPE cache
+        self._build_rope_cache(target_device)
+
+        # 11. Init weight streaming
+        if weight_streaming:
+            self._init_weight_streaming()
+
+        # 12. Set trainable params
+        for p in self._lora_params.parameters():
+            p.requires_grad_(True)
+        for p in self._norm_weights.parameters():
+            p.requires_grad_(True)
+
+        # 13. Load LoRA checkpoint (optional)
+        if lora_checkpoint is not None:
+            from bitsandbytes.checkpoint import load_lora
+            load_lora(self, lora_checkpoint)
+
+        return self
+
     # ─── Quantization & LoRA creation ───
 
     def _quantize_weight(self, weight: torch.Tensor, name: str, k: int | None = None):
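
For orientation, a small inspection sketch of the checkpoint layout the loader above expects; the key names come from the get_tensor()/metadata() calls in this hunk, while the filename is a placeholder:

    from safetensors import safe_open

    with safe_open("model-kbit.safetensors", framework="pt", device="cpu") as sf:
        meta = sf.metadata()
        print(meta["model_type"], meta["num_layers"], meta["k_attention"])
        # Each projection stores packed/absmax/codebook tensors plus N/K/N_padded/k metadata.
        for proj in ["q_proj", "k_proj", "v_proj", "o_proj"]:
            packed = sf.get_tensor(f"layer.0.attn.{proj}.packed")
            N = int(meta[f"layer.0.attn.{proj}.N"])
            K = int(meta[f"layer.0.attn.{proj}.K"])
            k = int(meta[f"layer.0.attn.{proj}.k"])
            print(f"{proj}: logical {N}x{K}, {k}-bit, packed shape {tuple(packed.shape)}")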
@@ -550,21 +850,37 @@ def _init_weight_streaming(self):
 
         # Pre-allocate 2 GPU buffer slots sized for the largest layer
         self._copy_stream = torch.cuda.Stream(device=device)
+
+        def _layer_bytes(cpu_layer):
+            total = 0
+            for v in cpu_layer.values():
+                if isinstance(v, dict):
+                    total += sum(t.nbytes for t in v.values())
+                else:
+                    total += v.nbytes
+            return total
+
+        largest_idx = max(range(len(self._cpu_weights)), key=lambda i: _layer_bytes(self._cpu_weights[i]))
+        largest_cpu_layer = self._cpu_weights[largest_idx]
+
         self._gpu_slots = []
         for _ in range(2):
             slot = {}
-            for i, cpu_layer in enumerate(self._cpu_weights):
-                if i == 0:
-                    for key, cpu_tensor in cpu_layer.items():
-                        slot[key] = torch.empty_like(cpu_tensor, device=device)
-                    break
+            for key, value in largest_cpu_layer.items():
+                if isinstance(value, dict):
+                    slot[key] = {wk: torch.empty_like(t, device=device) for wk, t in value.items()}
+                else:
+                    slot[key] = torch.empty_like(value, device=device)
             self._gpu_slots.append(slot)
         self._current_slot = 0
 
+        def _entry_bytes(v):
+            return sum(t.nbytes for t in v.values()) if isinstance(v, dict) else v.nbytes
+
         total_cpu_bytes = sum(
-            sum(t.nbytes for t in cl.values()) for cl in self._cpu_weights
+            sum(_entry_bytes(v) for v in cl.values()) for cl in self._cpu_weights
         )
-        slot_bytes = sum(t.nbytes for t in self._gpu_slots[0].values())
+        slot_bytes = sum(_entry_bytes(v) for v in self._gpu_slots[0].values())
         print(
             f"Weight streaming: {total_cpu_bytes / 1e9:.1f} GB on CPU pinned, "
             f"{2 * slot_bytes / 1e6:.0f} MB GPU double-buffer "
@@ -576,17 +892,27 @@ def _stream_load_layer(self, layer_idx: int, slot: int, sync: bool = False):
         cpu_layer = self._cpu_weights[layer_idx]
         gpu_slot = self._gpu_slots[slot]
 
+        def _do_copies(non_blocking: bool):
+            for key, cpu_value in cpu_layer.items():
+                if isinstance(cpu_value, dict):
+                    # Nested proj dict: {packed: tensor, absmax: tensor, codebook: tensor}
+                    if key not in gpu_slot:
+                        gpu_slot[key] = {}
+                    for wk, cpu_tensor in cpu_value.items():
+                        if wk not in gpu_slot[key]:
+                            gpu_slot[key][wk] = torch.empty_like(cpu_tensor, device=self._target_device)
+                        gpu_slot[key][wk].copy_(cpu_tensor, non_blocking=non_blocking)
+                else:
+                    # Flat tensor (expert concatenated weights)
+                    if key not in gpu_slot:
+                        gpu_slot[key] = torch.empty_like(cpu_value, device=self._target_device)
+                    gpu_slot[key].copy_(cpu_value, non_blocking=non_blocking)
+
         if sync:
-            for key, cpu_tensor in cpu_layer.items():
-                if key not in gpu_slot:
-                    gpu_slot[key] = torch.empty_like(cpu_tensor, device=self._target_device)
-                gpu_slot[key].copy_(cpu_tensor)
+            _do_copies(non_blocking=False)
         else:
             with torch.cuda.stream(self._copy_stream):
-                for key, cpu_tensor in cpu_layer.items():
-                    if key not in gpu_slot:
-                        gpu_slot[key] = torch.empty_like(cpu_tensor, device=self._target_device)
-                    gpu_slot[key].copy_(cpu_tensor, non_blocking=True)
+                _do_copies(non_blocking=True)
 
     def _get_layer_gpu_weights(self, layer_idx: int, slot: int) -> dict:
         """Build a layer_info-compatible dict from GPU slot + always-resident data."""