Skip to content

Commit 4297eed

Browse files
Update Qwen3 vLLM layer names to match tpu-inference mappings.
1 parent f8ea954 commit 4297eed

1 file changed

Lines changed: 13 additions & 13 deletions

File tree

  • src/maxtext/integration/tunix/weight_mapping

src/maxtext/integration/tunix/weight_mapping/qwen3.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,12 @@ def to_hf_mapping():
6767
return {
6868
# Token embeddings - shard vocab dimension
6969
"base.token_embedder.embedding": (
70-
"model.embed.embedding",
70+
"model.embed_tokens.weight",
7171
("model", None),
7272
),
7373
# Final layer norm - no sharding needed
7474
"base.decoder.decoder_norm.scale": (
75-
"model.norm.scale",
75+
"model.norm.weight",
7676
(None,),
7777
),
7878
# LM head (logits projection) - shard vocab dimension
@@ -83,49 +83,49 @@ def to_hf_mapping():
8383
# Layer-specific mappings (scanned -> unscanned)
8484
# MLP components - shard hidden dimensions
8585
"base.decoder.layers.mlp.wi_0.kernel": (
86-
"model.layers.*.mlp.gate_proj.kernel",
86+
"model.layers.*.mlp.gate_proj.weight",
8787
(None, "layer", "model"),
8888
),
8989
"base.decoder.layers.mlp.wi_1.kernel": (
90-
"model.layers.*.mlp.up_proj.kernel",
90+
"model.layers.*.mlp.up_proj.weight",
9191
(None, "layer", "model"),
9292
),
9393
"base.decoder.layers.mlp.wo.kernel": (
94-
"model.layers.*.mlp.down_proj.kernel",
94+
"model.layers.*.mlp.down_proj.weight",
9595
("model", "layer", None),
9696
),
9797
# Layer norms - no sharding needed
9898
"base.decoder.layers.pre_self_attention_layer_norm.scale": (
99-
"model.layers.*.input_layernorm.scale",
99+
"model.layers.*.input_layernorm.weight",
100100
(None, "layer"),
101101
),
102102
"base.decoder.layers.post_self_attention_layer_norm.scale": (
103-
"model.layers.*.post_attention_layernorm.scale",
103+
"model.layers.*.post_attention_layernorm.weight",
104104
(None, "layer"),
105105
),
106106
# Attention components - shard head dimensions
107107
"base.decoder.layers.self_attention.query.kernel": (
108-
"model.layers.*.self_attn.q_proj.kernel",
108+
"model.layers.*.self_attn.q_proj.weight",
109109
(None, "layer", "model", None),
110110
),
111111
"base.decoder.layers.self_attention.key.kernel": (
112-
"model.layers.*.self_attn.k_proj.kernel",
112+
"model.layers.*.self_attn.k_proj.weight",
113113
(None, "layer", "model", None),
114114
),
115115
"base.decoder.layers.self_attention.value.kernel": (
116-
"model.layers.*.self_attn.v_proj.kernel",
116+
"model.layers.*.self_attn.v_proj.weight",
117117
(None, "layer", "model", None),
118118
),
119119
"base.decoder.layers.self_attention.out.kernel": (
120-
"model.layers.*.self_attn.o_proj.kernel",
120+
"model.layers.*.self_attn.o_proj.weight",
121121
("model", "layer", None, None),
122122
),
123123
"base.decoder.layers.self_attention.query_norm.scale": (
124-
"model.layers.*.self_attn.q_norm.scale",
124+
"model.layers.*.self_attn.q_norm.weight",
125125
(None, "layer"),
126126
),
127127
"base.decoder.layers.self_attention.key_norm.scale": (
128-
"model.layers.*.self_attn.k_norm.scale",
128+
"model.layers.*.self_attn.k_norm.weight",
129129
(None, "layer"),
130130
),
131131
}

0 commit comments

Comments (0)