
Commit 23ee333

TimDettmers and claude committed
feat: Add comprehensive metadata to save_quantized for load_quantized
Metadata now includes all fields needed by from_quantized() to reconstruct a KbitLoraModel without the original HF model:

- Model config: hidden_size, num_attention_heads, num_key_value_heads, head_dim, intermediate_size, vocab_size, rms_norm_eps, rope_theta
- MoE config: expert_intermediate_size, has_shared_expert, has_qk_norm, dense_layer_indices
- Per-projection dims: N, K, N_padded, k for every attention/MLP/expert projection in every layer, plus LM head dims

Updated test_checkpoint.py to verify all metadata fields are present and correct for a tiny Llama model.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent b736d14 · commit 23ee333
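Because safetensors metadata is a flat str-to-str mapping, every value in the diff below is serialized with str(), and a loader has to parse the types back out. A minimal sketch of that round-trip, where read_metadata is a hypothetical helper and not part of this commit:

```python
# Hypothetical helper (not in this commit): parse the str -> str
# safetensors metadata written by save_quantized back into typed values.
from safetensors import safe_open

def read_metadata(path: str) -> dict:
    with safe_open(path, framework="pt", device="cpu") as sf:
        meta = sf.metadata()
    return {
        "hidden_size": int(meta["hidden_size"]),
        "num_attention_heads": int(meta["num_attention_heads"]),
        "rms_norm_eps": float(meta["rms_norm_eps"]),
        "is_moe": meta["is_moe"] == "True",  # booleans were saved via str()
        "dense_layer_indices": (
            [int(i) for i in meta["dense_layer_indices"].split(",")]
            if meta["dense_layer_indices"]
            else None  # empty string means None or all-MoE, per the diff
        ),
    }
```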

File tree (2 files changed: +114 −2 lines changed)

bitsandbytes/checkpoint.py
tests/test_checkpoint.py

bitsandbytes/checkpoint.py

Lines changed: 63 additions & 1 deletion
```diff
@@ -79,24 +79,86 @@ def save_quantized(model, path: str):
     if model.embed_tokens is not None:
         tensors["embed_tokens.weight"] = model.embed_tokens.weight.data
 
-    # Metadata
+    # Metadata — comprehensive, enables load_quantized without the HF model
     metadata = {
+        # Model architecture
         "model_type": model.model_type,
         "hidden_size": str(model.hidden_size),
         "num_layers": str(model.num_layers),
         "num_loaded_layers": str(model._num_loaded_layers),
         "layer_start": str(model._layer_start),
         "layer_end": str(model._layer_end),
+        "num_attention_heads": str(model.num_heads),
+        "num_key_value_heads": str(model.num_kv_heads),
+        "head_dim": str(model.head_dim),
+        "intermediate_size": str(model.intermediate_size),
+        "vocab_size": str(model.vocab_size),
+        "rms_norm_eps": str(model.rms_norm_eps),
+        "rope_theta": str(model.rope_theta),
+        # Quantization config
         "k_attention": str(model.k_attention),
         "k_mlp": str(model.k_mlp),
         "k_lm_head": str(model.k_lm_head),
         "k_experts": str(model.k_experts),
         "k_shared_expert": str(model.k_shared_expert),
+        # MoE config
         "is_moe": str(model.arch.is_moe),
         "num_experts": str(model.arch.num_experts),
         "num_active_experts": str(model.arch.num_active_experts),
+        "expert_intermediate_size": str(model.arch.expert_intermediate_size),
+        "has_shared_expert": str(model.arch.has_shared_expert),
+        "has_qk_norm": str(model.arch.has_qk_norm),
     }
 
+    # Dense layer indices (comma-separated, empty if None or all MoE)
+    if model.arch.dense_layer_indices is not None:
+        metadata["dense_layer_indices"] = ",".join(
+            str(i) for i in model.arch.dense_layer_indices
+        )
+    else:
+        metadata["dense_layer_indices"] = ""
+
+    # Per-projection dimensions (needed for LoRA initialization in load_quantized)
+    for i, layer_info in enumerate(model._layer_data):
+        prefix = f"layer.{i}"
+
+        # Attention projections
+        for proj in ["q_proj", "k_proj", "v_proj", "o_proj"]:
+            metadata[f"{prefix}.attn.{proj}.N"] = str(layer_info[proj]["N"])
+            metadata[f"{prefix}.attn.{proj}.K"] = str(layer_info[proj]["K"])
+            metadata[f"{prefix}.attn.{proj}.N_padded"] = str(layer_info[proj]["N_padded"])
+            metadata[f"{prefix}.attn.{proj}.k"] = str(layer_info[proj]["k"])
+
+        # MLP or MoE
+        if layer_info.get("is_moe"):
+            # Shared expert dims
+            if "shared_gate_proj" in layer_info:
+                for proj in ["shared_gate_proj", "shared_up_proj", "shared_down_proj"]:
+                    metadata[f"{prefix}.moe.{proj}.N"] = str(layer_info[proj]["N"])
+                    metadata[f"{prefix}.moe.{proj}.K"] = str(layer_info[proj]["K"])
+                    metadata[f"{prefix}.moe.{proj}.N_padded"] = str(layer_info[proj]["N_padded"])
+                    metadata[f"{prefix}.moe.{proj}.k"] = str(layer_info[proj]["k"])
+
+            # Expert dims (same for all experts — store once)
+            metadata[f"{prefix}.moe.experts.N"] = str(layer_info.get("expert_N", 0))
+            metadata[f"{prefix}.moe.experts.K"] = str(layer_info.get("expert_K", 0))
+            metadata[f"{prefix}.moe.experts.N_padded"] = str(layer_info.get("expert_N_padded", 0))
+            metadata[f"{prefix}.moe.experts.k"] = str(layer_info.get("expert_k", 0))
+        else:
+            for proj in ["gate_proj", "up_proj", "down_proj"]:
+                metadata[f"{prefix}.mlp.{proj}.N"] = str(layer_info[proj]["N"])
+                metadata[f"{prefix}.mlp.{proj}.K"] = str(layer_info[proj]["K"])
+                metadata[f"{prefix}.mlp.{proj}.N_padded"] = str(layer_info[proj]["N_padded"])
+                metadata[f"{prefix}.mlp.{proj}.k"] = str(layer_info[proj]["k"])
+
+    # LM head dims
+    if model._lm_head_info is not None:
+        lm = model._lm_head_info
+        metadata["lm_head.N"] = str(lm["N"])
+        metadata["lm_head.K"] = str(lm["K"])
+        metadata["lm_head.N_padded"] = str(lm["N_padded"])
+        metadata["lm_head.k"] = str(lm["k"])
+
     # Move all tensors to CPU for saving
     cpu_tensors = OrderedDict()
     for k, v in tensors.items():
```
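The per-projection entries above form one flat key namespace, layer.{i}.{attn|mlp|moe}.{proj}.{N|K|N_padded|k}. Before initializing LoRA adapters, from_quantized() needs the inverse mapping. A minimal sketch of regrouping those flat keys into per-layer dicts; the key pattern is taken from the diff, while regroup_projection_dims is a hypothetical name:

```python
# Hypothetical inverse of the flat metadata written above: the key scheme
# comes from the diff, the helper itself is not part of this commit.
import re
from collections import defaultdict

_DIM_KEY = re.compile(r"^layer\.(\d+)\.(attn|mlp|moe)\.(\w+)\.(N_padded|N|K|k)$")

def regroup_projection_dims(meta: dict) -> dict:
    """Turn {"layer.0.attn.q_proj.N": "256", ...} into
    {0: {"attn.q_proj": {"N": 256, ...}, ...}, ...}."""
    layers = defaultdict(lambda: defaultdict(dict))
    for key, value in meta.items():
        m = _DIM_KEY.match(key)
        if m:
            idx, block, proj, field = m.groups()
            layers[int(idx)][f"{block}.{proj}"][field] = int(value)
    return {i: {p: dict(d) for p, d in projs.items()} for i, projs in layers.items()}
```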

tests/test_checkpoint.py

Lines changed: 51 additions & 1 deletion
```diff
@@ -81,9 +81,59 @@ def test_metadata_present(self, kbit_model):
             save_quantized(kbit_model, path)
             sf = safe_open(path, framework="pt", device="cpu")
             meta = sf.metadata()
+
+            # Model architecture
             assert meta["model_type"] == "llama"
-            assert meta["k_attention"] == "4"
+            assert int(meta["hidden_size"]) == 256
             assert int(meta["num_layers"]) == 2
+            assert int(meta["num_attention_heads"]) == 4
+            assert int(meta["num_key_value_heads"]) == 2
+            assert int(meta["head_dim"]) == 64  # 256 / 4
+            assert int(meta["intermediate_size"]) == 512
+            assert int(meta["vocab_size"]) == 1000
+            assert float(meta["rms_norm_eps"]) > 0
+            assert float(meta["rope_theta"]) > 0
+
+            # Quantization config
+            assert meta["k_attention"] == "4"
+            assert meta["k_mlp"] == "4"
+            assert meta["k_lm_head"] == "4"
+            assert meta["k_experts"] == "4"
+            assert meta["k_shared_expert"] == "4"
+
+            # MoE config
+            assert meta["is_moe"] == "False"
+            assert meta["has_shared_expert"] == "False"
+            assert meta["has_qk_norm"] == "False"
+            assert meta["dense_layer_indices"] == ""
+
+            # Per-projection dims for layer 0 attention
+            assert int(meta["layer.0.attn.q_proj.N"]) == 256  # q_dim = 4 * 64
+            assert int(meta["layer.0.attn.q_proj.K"]) == 256  # hidden_size
+            assert int(meta["layer.0.attn.q_proj.N_padded"]) == 256  # already mult of 128
+            assert int(meta["layer.0.attn.q_proj.k"]) == 4
+
+            assert int(meta["layer.0.attn.k_proj.N"]) == 128  # kv_dim = 2 * 64
+            assert int(meta["layer.0.attn.k_proj.K"]) == 256
+
+            # MLP dims
+            assert int(meta["layer.0.mlp.gate_proj.N"]) == 512  # intermediate
+            assert int(meta["layer.0.mlp.gate_proj.K"]) == 256  # hidden
+
+            # LM head dims
+            assert int(meta["lm_head.N"]) == 1000  # vocab_size
+            assert int(meta["lm_head.K"]) == 256  # hidden_size
+
+            # Check all layers have dims
+            for i in range(2):
+                for proj in ["q_proj", "k_proj", "v_proj", "o_proj"]:
+                    assert f"layer.{i}.attn.{proj}.N" in meta
+                    assert f"layer.{i}.attn.{proj}.K" in meta
+                    assert f"layer.{i}.attn.{proj}.N_padded" in meta
+                    assert f"layer.{i}.attn.{proj}.k" in meta
+                for proj in ["gate_proj", "up_proj", "down_proj"]:
+                    assert f"layer.{i}.mlp.{proj}.N" in meta
+                    assert f"layer.{i}.mlp.{proj}.K" in meta
         finally:
             os.unlink(path)
```
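The test's "already mult of 128" comment implies N_padded rounds N up to the next multiple of 128. Assuming that granularity (an inference from the comment, not something this commit states explicitly), the padding rule would be:

```python
# Assumed padding rule inferred from the test's "already mult of 128" comment.
def pad_to_multiple(n: int, multiple: int = 128) -> int:
    # Round n up to the next multiple of `multiple`.
    return ((n + multiple - 1) // multiple) * multiple

assert pad_to_multiple(256) == 256    # q_proj.N above: already a multiple of 128
assert pad_to_multiple(1000) == 1024  # e.g. a vocab-sized N would round up
```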
