55quantizer that converts HF checkpoints layer-by-layer with minimal memory.
66"""
77
8+ from collections import OrderedDict
89import json
910import os
1011import shutil
1112import struct
12- from collections import OrderedDict
1313from typing import Optional
1414
15- import torch
16-
17- from safetensors .torch import save_file
1815from safetensors import safe_open
16+ from safetensors .torch import save_file
17+ import torch
1918
2019from bitsandbytes .arch_config import ArchConfig , detect_arch_config
2120
@@ -119,9 +118,7 @@ def save_quantized(model, path: str):
119118
120119 # Dense layer indices (comma-separated, empty if None or all MoE)
121120 if model .arch .dense_layer_indices is not None :
122- metadata ["dense_layer_indices" ] = "," .join (
123- str (i ) for i in model .arch .dense_layer_indices
124- )
121+ metadata ["dense_layer_indices" ] = "," .join (str (i ) for i in model .arch .dense_layer_indices )
125122 else :
126123 metadata ["dense_layer_indices" ] = ""
127124
@@ -245,7 +242,7 @@ def streaming_quantize(
245242 k : int = 4 ,
246243 k_config : Optional [dict [str , int ]] = None ,
247244 arch_config : Optional [ArchConfig ] = None ,
248- device : torch .device = torch . device ( "cuda:0" ) ,
245+ device : Optional [ torch .device ] = None ,
249246):
250247 """Quantize a HuggingFace model layer-by-layer and write to safetensors.
251248
@@ -267,6 +264,9 @@ def streaming_quantize(
267264 arch_config: Optional ArchConfig override.
268265 device: GPU device for quantization kernels.
269266 """
267+ if device is None :
268+ device = torch .device ("cuda:0" )
269+
270270 from transformers import AutoConfig
271271
272272 import bitsandbytes .functional as F
@@ -297,6 +297,7 @@ def streaming_quantize(
297297 model_dir = model_name_or_path
298298 else :
299299 from huggingface_hub import snapshot_download
300+
300301 model_dir = snapshot_download (model_name_or_path )
301302
302303 # ─── Build weight map: tensor_name → shard_filename ───
@@ -409,15 +410,19 @@ def _add_expert_concat(out_prefix, layer_idx, proj_attr, k_val, meta_prefix):
409410
410411 # --- Per-layer tensor specs ---
411412 _attn_projs = [
412- ("q_proj" , arch .q_proj ), ("k_proj" , arch .k_proj ),
413- ("v_proj" , arch .v_proj ), ("o_proj" , arch .o_proj ),
413+ ("q_proj" , arch .q_proj ),
414+ ("k_proj" , arch .k_proj ),
415+ ("v_proj" , arch .v_proj ),
416+ ("o_proj" , arch .o_proj ),
414417 ]
415418 _mlp_projs = [
416- ("gate_proj" , arch .gate_proj ), ("up_proj" , arch .up_proj ),
419+ ("gate_proj" , arch .gate_proj ),
420+ ("up_proj" , arch .up_proj ),
417421 ("down_proj" , arch .down_proj ),
418422 ]
419423 _expert_projs = [
420- ("gate" , arch .expert_gate_proj ), ("up" , arch .expert_up_proj ),
424+ ("gate" , arch .expert_gate_proj ),
425+ ("up" , arch .expert_up_proj ),
421426 ("down" , arch .expert_down_proj ),
422427 ]
423428
@@ -441,15 +446,20 @@ def _add_expert_concat(out_prefix, layer_idx, proj_attr, k_val, meta_prefix):
441446 ("shared_down_proj" , arch .down_proj ),
442447 ]:
443448 _add_quantized (
444- f"{ pfx } .moe.{ name } " , _hf_shared_expert (i , attr ),
445- k_shared_expert , f"{ pfx } .moe.{ name } " ,
449+ f"{ pfx } .moe.{ name } " ,
450+ _hf_shared_expert (i , attr ),
451+ k_shared_expert ,
452+ f"{ pfx } .moe.{ name } " ,
446453 )
447454
448455 # Experts (concatenated)
449456 for name , attr in _expert_projs :
450457 _add_expert_concat (
451- f"{ pfx } .moe.experts.{ name } " , i , attr ,
452- k_experts , f"{ pfx } .moe.experts" ,
458+ f"{ pfx } .moe.experts.{ name } " ,
459+ i ,
460+ attr ,
461+ k_experts ,
462+ f"{ pfx } .moe.experts" ,
453463 )
454464
455465 # Expert codebook (shared across projection types)
@@ -480,35 +490,35 @@ def _add_expert_concat(out_prefix, layer_idx, proj_attr, k_val, meta_prefix):
480490 _add_copy ("embed_tokens.weight" , f"{ arch .embed_path } .weight" )
481491
482492 # --- Global metadata ---
483- metadata .update ({
484- "model_type" : config . model_type ,
485- "hidden_size " : str ( hidden_size ) ,
486- "num_layers " : str (num_layers ),
487- "num_loaded_layers " : str (num_layers ),
488- "layer_start " : "0" ,
489- "layer_end " : str ( num_layers ) ,
490- "num_attention_heads " : str (num_heads ),
491- "num_key_value_heads " : str (num_kv_heads ),
492- "head_dim " : str (head_dim ),
493- "intermediate_size " : str (intermediate_size ),
494- "vocab_size " : str (vocab_size ),
495- "rms_norm_eps " : str (rms_norm_eps ),
496- "rope_theta " : str (rope_theta ),
497- "k_attention " : str (k_attn ),
498- "k_mlp " : str (k_mlp ),
499- "k_lm_head " : str (k_lm_head ),
500- "k_experts " : str (k_experts ),
501- "k_shared_expert " : str (k_shared_expert ),
502- "is_moe " : str (arch . is_moe ),
503- "num_experts " : str (arch .num_experts ),
504- "num_active_experts " : str (arch .num_active_experts ),
505- "expert_intermediate_size " : str (arch .expert_intermediate_size ),
506- "has_shared_expert " : str (arch .has_shared_expert ),
507- "has_qk_norm " : str (arch .has_qk_norm ),
508- "dense_layer_indices " : "," . join (
509- str (x ) for x in (arch .dense_layer_indices or [])
510- ),
511- } )
493+ metadata .update (
494+ {
495+ "model_type " : config . model_type ,
496+ "hidden_size " : str (hidden_size ),
497+ "num_layers " : str (num_layers ),
498+ "num_loaded_layers " : str ( num_layers ) ,
499+ "layer_start " : "0" ,
500+ "layer_end " : str (num_layers ),
501+ "num_attention_heads " : str (num_heads ),
502+ "num_key_value_heads " : str (num_kv_heads ),
503+ "head_dim " : str (head_dim ),
504+ "intermediate_size " : str (intermediate_size ),
505+ "vocab_size " : str (vocab_size ),
506+ "rms_norm_eps " : str (rms_norm_eps ),
507+ "rope_theta " : str (rope_theta ),
508+ "k_attention " : str (k_attn ),
509+ "k_mlp " : str (k_mlp ),
510+ "k_lm_head " : str (k_lm_head ),
511+ "k_experts " : str (k_experts ),
512+ "k_shared_expert " : str (k_shared_expert ),
513+ "is_moe " : str (arch .is_moe ),
514+ "num_experts " : str (arch .num_experts ),
515+ "num_active_experts " : str (arch .num_active_experts ),
516+ "expert_intermediate_size " : str (arch .expert_intermediate_size ),
517+ "has_shared_expert " : str (arch .has_shared_expert ),
518+ "has_qk_norm " : str ( arch . has_qk_norm ),
519+ "dense_layer_indices" : "," . join ( str (x ) for x in (arch .dense_layer_indices or [])),
520+ }
521+ )
512522
513523 # ─── Write safetensors header + pre-allocate file ───
514524
@@ -549,9 +559,7 @@ def _add_expert_concat(out_prefix, layer_idx, proj_attr, k_val, meta_prefix):
549559 def _load_hf (hf_name ):
550560 shard = _get_shard (hf_name )
551561 if shard not in _shard_handles :
552- _shard_handles [shard ] = safe_open (
553- os .path .join (model_dir , shard ), framework = "pt" , device = "cpu"
554- )
562+ _shard_handles [shard ] = safe_open (os .path .join (model_dir , shard ), framework = "pt" , device = "cpu" )
555563 return _shard_handles [shard ].get_tensor (hf_name )
556564
557565 _TORCH_DTYPE = {"F16" : torch .float16 , "BF16" : torch .bfloat16 , "F32" : torch .float32 }
@@ -572,17 +580,15 @@ def _write(out_name, tensor):
572580 def _quantize_and_write (out_prefix , hf_name , k_val ):
573581 """Load, pad, quantize one projection, write packed/absmax/codebook."""
574582 weight = _load_hf (hf_name ).to (device )
575- N , K_dim = weight .shape
583+ N , _K_dim = weight .shape
576584 N_padded = ((N + 127 ) // 128 ) * 128
577585 if N_padded != N :
578586 w = torch .nn .functional .pad (weight .float (), (0 , 0 , 0 , N_padded - N ))
579587 else :
580588 w = weight .float ()
581589 del weight
582590
583- packed , absmax , codebook = F .quantize_kbit (
584- w .reshape (- 1 ), k = k_val , absmax_format = "fp32"
585- )
591+ packed , absmax , codebook = F .quantize_kbit (w .reshape (- 1 ), k = k_val , absmax_format = "fp32" )
586592 del w
587593
588594 _write (f"{ out_prefix } .packed" , packed )
@@ -622,7 +628,8 @@ def _copy_and_write(out_name, hf_name):
622628 ("shared_down_proj" , arch .down_proj ),
623629 ]:
624630 _quantize_and_write (
625- f"{ pfx } .moe.{ name } " , _hf_shared_expert (i , attr ),
631+ f"{ pfx } .moe.{ name } " ,
632+ _hf_shared_expert (i , attr ),
626633 k_shared_expert ,
627634 )
628635
@@ -634,16 +641,14 @@ def _copy_and_write(out_name, hf_name):
634641
635642 for e in range (arch .num_experts ):
636643 w = _load_hf (_hf_expert (i , e , attr )).to (device )
637- N , K_dim = w .shape
644+ N , _K_dim = w .shape
638645 N_padded = ((N + 127 ) // 128 ) * 128
639646 if N_padded != N :
640647 w = torch .nn .functional .pad (w .float (), (0 , 0 , 0 , N_padded - N ))
641648 else :
642649 w = w .float ()
643650
644- packed , absmax , codebook = F .quantize_kbit (
645- w .reshape (- 1 ), k = k_experts , absmax_format = "fp32"
646- )
651+ packed , absmax , codebook = F .quantize_kbit (w .reshape (- 1 ), k = k_experts , absmax_format = "fp32" )
647652 del w
648653
649654 all_packed .append (packed .cpu ())
@@ -669,9 +674,7 @@ def _copy_and_write(out_name, hf_name):
669674
670675 # Norms
671676 _copy_and_write (f"{ pfx } .input_layernorm.weight" , _hf_norm (i , arch .input_norm ))
672- _copy_and_write (
673- f"{ pfx } .post_attention_layernorm.weight" , _hf_norm (i , arch .post_attn_norm )
674- )
677+ _copy_and_write (f"{ pfx } .post_attention_layernorm.weight" , _hf_norm (i , arch .post_attn_norm ))
675678
676679 if arch .has_qk_norm :
677680 _copy_and_write (f"{ pfx } .q_norm.weight" , _hf_qk_norm (i , arch .q_norm ))
0 commit comments