@@ -200,6 +200,306 @@ def __init__(
200200 for p in self ._norm_weights .parameters ():
201201 p .requires_grad_ (True )
202202
# ─── Load from pre-quantized checkpoint ───

@classmethod
def from_quantized(
    cls,
    checkpoint_path: str,
    lora_r: int = 64,
    lora_alpha: float = 16.0,
    attn_chunk_size: int = 4096,
    mlp_chunk_size: int = 4096,
    ce_chunk_size: int = 8192,
    compute_dtype: torch.dtype = torch.bfloat16,
    weight_streaming: bool = True,
    target_device: torch.device = torch.device("cuda:0"),
    lora_on_experts: bool = False,
    expert_chunk_size: int = 32,
    lora_checkpoint: Optional[str] = None,
) -> "KbitLoraModel":
    """Load a pre-quantized model from a safetensors checkpoint.

    This is Path B: load pre-quantized weights without requiring the
    original HuggingFace model. Use save_quantized() to create the
    checkpoint (Path A).

    Args:
        checkpoint_path: Path to safetensors file from save_quantized().
        lora_r: LoRA rank.
        lora_alpha: LoRA scaling factor.
        attn_chunk_size: Sequence chunk size for attention.
        mlp_chunk_size: Sequence chunk size for MLP.
        ce_chunk_size: Vocab chunk size for cross-entropy.
        compute_dtype: Computation dtype.
        weight_streaming: If True, keep weights in CPU pinned memory
            and stream to GPU layer-by-layer.
        target_device: GPU device for computation.
        lora_on_experts: If True, add LoRA to expert projections.
        expert_chunk_size: Experts processed at once in MoE forward.
        lora_checkpoint: Optional path to saved LoRA weights to load.

    Returns:
        A KbitLoraModel instance with quantized weights loaded and fresh
        (or checkpoint-restored) LoRA parameters attached.
    """
    from safetensors import safe_open

    # 1. Open the safetensors file and read metadata.  A context manager
    # closes the handle deterministically (the original code leaked it);
    # get_tensor() returns CPU copies, so they stay valid after close.
    with safe_open(checkpoint_path, framework="pt", device="cpu") as sf:
        meta = sf.metadata()

        # 2. Create instance without calling __init__ — there is no HF
        # model/config to construct from in Path B.
        self = cls.__new__(cls)
        nn.Module.__init__(self)

        # 3. Reconstruct ArchConfig from metadata via a minimal stand-in
        # config object carrying only the fields detect_arch_config reads.
        class _MinimalConfig:
            pass

        cfg = _MinimalConfig()
        cfg.model_type = meta["model_type"]
        if meta.get("is_moe") == "True":
            cfg.num_experts = int(meta["num_experts"])
            cfg.num_local_experts = int(meta["num_experts"])
            cfg.num_experts_per_tok = int(meta["num_active_experts"])
            cfg.moe_intermediate_size = int(meta["expert_intermediate_size"])

        self.arch = detect_arch_config(cfg)

        # 4. Set attributes from metadata and parameters.
        self.config = None
        self.model_type = meta["model_type"]
        self.lora_r = lora_r
        self.lora_s = lora_alpha / lora_r
        self.k = int(meta.get("k_attention", "4"))
        self.k_config = {}
        self.k_attention = int(meta["k_attention"])
        self.k_mlp = int(meta["k_mlp"])
        self.k_lm_head = int(meta["k_lm_head"])
        self.k_experts = int(meta["k_experts"])
        self.k_shared_expert = int(meta["k_shared_expert"])
        self.attn_chunk_size = attn_chunk_size
        self.mlp_chunk_size = mlp_chunk_size
        self.ce_chunk_size = ce_chunk_size
        self.compute_dtype = compute_dtype
        self.cpu_offload = weight_streaming
        self.weight_streaming = weight_streaming
        self.include_embed = True
        self.include_lm_head = True
        self.lora_on_experts = lora_on_experts
        self.expert_chunk_size = expert_chunk_size

        self.hidden_size = int(meta["hidden_size"])
        self.num_heads = int(meta["num_attention_heads"])
        self.num_kv_heads = int(meta["num_key_value_heads"])
        self.head_dim = int(meta["head_dim"])
        self.q_dim = self.num_heads * self.head_dim
        self.kv_dim = self.num_kv_heads * self.head_dim
        self.intermediate_size = int(meta["intermediate_size"])
        self.vocab_size = int(meta["vocab_size"])
        self.num_layers = int(meta["num_layers"])
        self.rms_norm_eps = float(meta["rms_norm_eps"])
        self.rope_theta = float(meta["rope_theta"])

        # Layer-range support: the checkpoint may contain a slice of the
        # full model (defaults cover the whole model).
        self._layer_start = int(meta.get("layer_start", "0"))
        self._layer_end = int(meta.get("layer_end", meta["num_layers"]))
        self._num_loaded_layers = int(
            meta.get("num_loaded_layers", meta["num_layers"])
        )

        self._streaming = True
        self._target_device = target_device
        self.model = None
        self.lm_head_tied = False

        # 5. Initialize parameter containers (must exist before
        # _create_lora / norm loading below registers into them).
        self._quantized_weights = nn.ParameterDict()
        self._lora_params = nn.ParameterDict()
        self._norm_weights = nn.ParameterDict()

        def _load_proj(prefix: str, lora_name: str) -> dict:
            """Load one quantized projection (packed/absmax/codebook plus
            shape metadata) stored under `prefix`, move it to the GPU when
            not streaming, and attach freshly-initialized LoRA factors."""
            N = int(meta[f"{prefix}.N"])
            K = int(meta[f"{prefix}.K"])
            info = {
                "packed": sf.get_tensor(f"{prefix}.packed"),
                "absmax": sf.get_tensor(f"{prefix}.absmax"),
                "codebook": sf.get_tensor(f"{prefix}.codebook"),
                "N_padded": int(meta[f"{prefix}.N_padded"]),
                "N": N,
                "K": K,
                "k": int(meta[f"{prefix}.k"]),
            }
            if not weight_streaming:
                for wk in ("packed", "absmax", "codebook"):
                    info[wk] = info[wk].to(target_device)
            info["A"], info["B"] = self._create_lora(lora_name, N, K)
            return info

        def _load_norm(tensor_name: str, safe_name: str):
            """Register a (small, always-on-GPU) norm weight if present in
            the checkpoint; return the Parameter or None."""
            if tensor_name not in sf.keys():
                return None
            weight = sf.get_tensor(tensor_name).to(
                target_device, dtype=compute_dtype
            )
            self._norm_weights[safe_name] = nn.Parameter(weight)
            return self._norm_weights[safe_name]

        # 6. Token embedding (kept unquantized, frozen, on the GPU).
        if "embed_tokens.weight" in sf.keys():
            embed_weight = sf.get_tensor("embed_tokens.weight").to(target_device)
            self.embed_tokens = nn.Embedding(self.vocab_size, self.hidden_size)
            self.embed_tokens.weight = nn.Parameter(
                embed_weight, requires_grad=False
            )
        else:
            self.embed_tokens = None

        # 7. Populate _layer_data from safetensors.
        self._layer_data = []
        for i in range(self._num_loaded_layers):
            prefix = f"layer.{i}"
            layer_info = {}

            # Attention projections.
            for proj in ("q_proj", "k_proj", "v_proj", "o_proj"):
                layer_info[proj] = _load_proj(
                    f"{prefix}.attn.{proj}", f"layers_{i}_attn_{proj}"
                )

            # MLP or MoE — decided per *global* layer index, since this
            # checkpoint may hold only a slice of the model.
            global_layer_idx = self._layer_start + i
            if self.arch.is_moe_layer(global_layer_idx):
                layer_info["is_moe"] = True

                # Router weight (small; always on GPU, not quantized).
                layer_info["router_weight"] = sf.get_tensor(
                    f"{prefix}.moe.router_weight"
                ).to(target_device, dtype=compute_dtype)

                # Shared expert (if present).
                if self.arch.has_shared_expert:
                    for proj in (
                        "shared_gate_proj",
                        "shared_up_proj",
                        "shared_down_proj",
                    ):
                        layer_info[proj] = _load_proj(
                            f"{prefix}.moe.{proj}", f"layers_{i}_moe_{proj}"
                        )

                # Expert weights (concatenated across all experts).
                # NOTE(review): lora_on_experts is stored on self but no
                # expert LoRA factors are created here — presumably handled
                # elsewhere; confirm against the Path A constructor.
                for expert_proj in ("gate", "up", "down"):
                    for suffix in ("packed", "absmax"):
                        tensor = sf.get_tensor(
                            f"{prefix}.moe.experts.{expert_proj}.{suffix}"
                        )
                        if not weight_streaming:
                            tensor = tensor.to(target_device)
                        layer_info[f"expert_{expert_proj}_{suffix}"] = tensor

                expert_codebook = sf.get_tensor(f"{prefix}.moe.experts.codebook")
                if not weight_streaming:
                    expert_codebook = expert_codebook.to(target_device)
                layer_info["expert_codebook"] = expert_codebook
                layer_info["expert_k"] = int(meta[f"{prefix}.moe.experts.k"])
                layer_info["expert_N"] = int(meta[f"{prefix}.moe.experts.N"])
                layer_info["expert_K"] = int(meta[f"{prefix}.moe.experts.K"])
                layer_info["expert_N_padded"] = int(
                    meta[f"{prefix}.moe.experts.N_padded"]
                )
            else:
                # Dense MLP.
                for proj in ("gate_proj", "up_proj", "down_proj"):
                    layer_info[proj] = _load_proj(
                        f"{prefix}.mlp.{proj}", f"layers_{i}_mlp_{proj}"
                    )

            # Per-layer norm weights (small; always on GPU, trainable).
            for nk in ("input_layernorm", "post_attention_layernorm"):
                p = _load_norm(f"{prefix}.{nk}.weight", f"layers_{i}_{nk}_weight")
                if p is not None:
                    layer_info[nk] = p

            # QK norms (Qwen3-style architectures).
            if self.arch.has_qk_norm:
                for nk in ("q_norm", "k_norm"):
                    p = _load_norm(
                        f"{prefix}.{nk}.weight", f"layers_{i}_attn_{nk}_weight"
                    )
                    if p is not None:
                        layer_info[nk] = p

            self._layer_data.append(layer_info)

        # 8. Final norm.
        _load_norm("final_norm.weight", "final_norm_weight")

        # 9. LM head (always on GPU — small relative to layer weights;
        # note: no LoRA is attached to it, so _load_proj is not used).
        self._lm_head_info = None
        if "lm_head.packed" in sf.keys():
            self._lm_head_info = {
                "packed": sf.get_tensor("lm_head.packed").to(target_device),
                "absmax": sf.get_tensor("lm_head.absmax").to(target_device),
                "codebook": sf.get_tensor("lm_head.codebook").to(target_device),
                "N_padded": int(meta["lm_head.N_padded"]),
                "N": int(meta["lm_head.N"]),
                "K": int(meta["lm_head.K"]),
                "k": int(meta["lm_head.k"]),
            }

    # 10. Build RoPE cache on the compute device.
    self._build_rope_cache(target_device)

    # 11. Set up layer-by-layer streaming (CPU pinned memory -> GPU).
    if weight_streaming:
        self._init_weight_streaming()

    # 12. Only LoRA factors and norm weights are trainable.
    for p in self._lora_params.parameters():
        p.requires_grad_(True)
    for p in self._norm_weights.parameters():
        p.requires_grad_(True)

    # 13. Optionally restore previously trained LoRA weights.
    if lora_checkpoint is not None:
        from bitsandbytes.checkpoint import load_lora

        load_lora(self, lora_checkpoint)

    return self
502+
203503 # ─── Quantization & LoRA creation ───
204504
205505 def _quantize_weight (self , weight : torch .Tensor , name : str , k : int | None = None ):
@@ -550,21 +850,37 @@ def _init_weight_streaming(self):
550850
551851 # Pre-allocate 2 GPU buffer slots sized for the largest layer
552852 self ._copy_stream = torch .cuda .Stream (device = device )
853+
854+ def _layer_bytes (cpu_layer ):
855+ total = 0
856+ for v in cpu_layer .values ():
857+ if isinstance (v , dict ):
858+ total += sum (t .nbytes for t in v .values ())
859+ else :
860+ total += v .nbytes
861+ return total
862+
863+ largest_idx = max (range (len (self ._cpu_weights )), key = lambda i : _layer_bytes (self ._cpu_weights [i ]))
864+ largest_cpu_layer = self ._cpu_weights [largest_idx ]
865+
553866 self ._gpu_slots = []
554867 for _ in range (2 ):
555868 slot = {}
556- for i , cpu_layer in enumerate ( self . _cpu_weights ):
557- if i == 0 :
558- for key , cpu_tensor in cpu_layer .items ():
559- slot [ key ] = torch . empty_like ( cpu_tensor , device = device )
560- break
869+ for key , value in largest_cpu_layer . items ( ):
870+ if isinstance ( value , dict ) :
871+ slot [ key ] = { wk : torch . empty_like ( t , device = device ) for wk , t in value .items ()}
872+ else :
873+ slot [ key ] = torch . empty_like ( value , device = device )
561874 self ._gpu_slots .append (slot )
562875 self ._current_slot = 0
563876
877+ def _entry_bytes (v ):
878+ return sum (t .nbytes for t in v .values ()) if isinstance (v , dict ) else v .nbytes
879+
564880 total_cpu_bytes = sum (
565- sum (t . nbytes for t in cl .values ()) for cl in self ._cpu_weights
881+ sum (_entry_bytes ( v ) for v in cl .values ()) for cl in self ._cpu_weights
566882 )
567- slot_bytes = sum (t . nbytes for t in self ._gpu_slots [0 ].values ())
883+ slot_bytes = sum (_entry_bytes ( v ) for v in self ._gpu_slots [0 ].values ())
568884 print (
569885 f"Weight streaming: { total_cpu_bytes / 1e9 :.1f} GB on CPU pinned, "
570886 f"{ 2 * slot_bytes / 1e6 :.0f} MB GPU double-buffer "
@@ -576,17 +892,27 @@ def _stream_load_layer(self, layer_idx: int, slot: int, sync: bool = False):
576892 cpu_layer = self ._cpu_weights [layer_idx ]
577893 gpu_slot = self ._gpu_slots [slot ]
578894
895+ def _do_copies (non_blocking : bool ):
896+ for key , cpu_value in cpu_layer .items ():
897+ if isinstance (cpu_value , dict ):
898+ # Nested proj dict: {packed: tensor, absmax: tensor, codebook: tensor}
899+ if key not in gpu_slot :
900+ gpu_slot [key ] = {}
901+ for wk , cpu_tensor in cpu_value .items ():
902+ if wk not in gpu_slot [key ]:
903+ gpu_slot [key ][wk ] = torch .empty_like (cpu_tensor , device = self ._target_device )
904+ gpu_slot [key ][wk ].copy_ (cpu_tensor , non_blocking = non_blocking )
905+ else :
906+ # Flat tensor (expert concatenated weights)
907+ if key not in gpu_slot :
908+ gpu_slot [key ] = torch .empty_like (cpu_value , device = self ._target_device )
909+ gpu_slot [key ].copy_ (cpu_value , non_blocking = non_blocking )
910+
579911 if sync :
580- for key , cpu_tensor in cpu_layer .items ():
581- if key not in gpu_slot :
582- gpu_slot [key ] = torch .empty_like (cpu_tensor , device = self ._target_device )
583- gpu_slot [key ].copy_ (cpu_tensor )
912+ _do_copies (non_blocking = False )
584913 else :
585914 with torch .cuda .stream (self ._copy_stream ):
586- for key , cpu_tensor in cpu_layer .items ():
587- if key not in gpu_slot :
588- gpu_slot [key ] = torch .empty_like (cpu_tensor , device = self ._target_device )
589- gpu_slot [key ].copy_ (cpu_tensor , non_blocking = True )
915+ _do_copies (non_blocking = True )
590916
591917 def _get_layer_gpu_weights (self , layer_idx : int , slot : int ) -> dict :
592918 """Build a layer_info-compatible dict from GPU slot + always-resident data."""
0 commit comments