Skip to content

Commit 80495e6

Browse files
hychiang-git and claude committed
refactor: derive Qwen3-VL mcore mapping from Qwen3 via prefix rewrite
Replace the hand-written dict literals in mcore_qwen3vl.py with a helper that derives the VL mapping from qwen3_causal_lm_import/export by inserting 'language_model.' after 'model.' in every prefix. The root-level 'lm_head.' prefix is left unchanged. Remove TestQwen3VLvsQwen3Difference, since it would now test the implementation against itself. Note: the visual encoder (model.visual.*) is intentionally excluded from the mapping. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Hung-Yueh Chiang <hungyuehc@nvidia.com>
1 parent aecbbfa commit 80495e6

2 files changed

Lines changed: 30 additions & 146 deletions

File tree

modelopt/torch/export/plugins/mcore_qwen3vl.py

Lines changed: 30 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -15,106 +15,39 @@
1515

1616
"""Custom mapping from Qwen3-VL Hugging Face models to Megatron Core models.
1717
18-
Qwen3-VL model structure differs from Qwen3:
19-
- Language model weights are under `model.language_model.` prefix
20-
- Visual encoder weights are under `model.visual.` prefix
18+
Qwen3-VL differs from Qwen3 in one structural way: language-model weights live
19+
under ``model.language_model.`` instead of ``model.``, while ``lm_head.weight``
20+
remains at the root level. The mappings below are derived automatically from
21+
the Qwen3 mappings by inserting ``language_model.`` after ``model.`` for every
22+
prefix that starts with ``model.``.
2123
22-
This module handles the language model conversion for PTQ/QAT workflows.
23-
Visual components are typically kept in full precision.
24+
Note: the visual encoder (``model.visual.*``) is intentionally excluded — this
25+
mapping covers only the language-model decoder used for quantization and export.
2426
25-
HuggingFace Qwen3-VL-8B structure:
26-
- model.language_model.embed_tokens.weight
27-
- model.language_model.layers.{L}.input_layernorm.weight
28-
- model.language_model.layers.{L}.self_attn.q_proj.weight
29-
- model.language_model.layers.{L}.self_attn.k_proj.weight
30-
- model.language_model.layers.{L}.self_attn.v_proj.weight
31-
- model.language_model.layers.{L}.self_attn.q_norm.weight
32-
- model.language_model.layers.{L}.self_attn.k_norm.weight
33-
- model.language_model.layers.{L}.self_attn.o_proj.weight
34-
- model.language_model.layers.{L}.post_attention_layernorm.weight
35-
- model.language_model.layers.{L}.mlp.gate_proj.weight
36-
- model.language_model.layers.{L}.mlp.up_proj.weight
37-
- model.language_model.layers.{L}.mlp.down_proj.weight
38-
- model.language_model.norm.weight
39-
- lm_head.weight
27+
Reference: https://huggingface.co/Qwen/Qwen3-VL-8B-Instruct/blob/main/model.safetensors.index.json
4028
"""
4129

42-
from .mcore_custom import (
43-
COL_ETP,
44-
COL_TP,
45-
REPLICATE,
46-
ROW_ETP,
47-
ROW_TP,
48-
CustomModuleMapping,
49-
GatedMLPMerging,
50-
GatedMLPSlicing,
51-
NameRemapping,
52-
QKVMerging,
53-
QKVSlicing,
54-
)
30+
from .mcore_custom import CustomModuleMapping
31+
from .mcore_qwen import qwen3_causal_lm_export, qwen3_causal_lm_import
5532

56-
# Import rules: HuggingFace -> Megatron Core
57-
qwen3vl_causal_lm_import: dict[str, CustomModuleMapping] = {
58-
# Embeddings - note the language_model prefix
59-
"word_embeddings": NameRemapping("model.language_model.embed_tokens.", COL_TP),
60-
# Final layer norm
61-
"final_layernorm": NameRemapping("model.language_model.norm.", REPLICATE),
62-
# Output layer (lm_head is at root level, not under language_model)
63-
"output_layer": NameRemapping("lm_head.", COL_TP),
64-
# Attention - input layernorm
65-
"input_layernorm": NameRemapping("model.language_model.layers.{}.input_layernorm.", REPLICATE),
66-
# Attention - QKV projection (merged)
67-
"linear_qkv": QKVMerging("model.language_model.layers.{}.self_attn.", COL_TP),
68-
# Attention - output projection
69-
"linear_proj": NameRemapping("model.language_model.layers.{}.self_attn.o_proj.", ROW_TP),
70-
# Attention - Q/K layer norms (Qwen3 uses RMSNorm on Q and K)
71-
"q_layernorm": NameRemapping("model.language_model.layers.{}.self_attn.q_norm.", REPLICATE),
72-
"k_layernorm": NameRemapping("model.language_model.layers.{}.self_attn.k_norm.", REPLICATE),
73-
# MLP - pre-MLP layernorm (post_attention_layernorm in HF)
74-
"pre_mlp_layernorm": NameRemapping(
75-
"model.language_model.layers.{}.post_attention_layernorm.", REPLICATE
76-
),
77-
# MLP - gate_proj + up_proj merged into linear_fc1
78-
"linear_fc1": GatedMLPMerging("model.language_model.layers.{}.mlp.", COL_TP),
79-
# MLP - down_proj as linear_fc2
80-
"linear_fc2": NameRemapping("model.language_model.layers.{}.mlp.down_proj.", ROW_TP),
81-
# MoE support (for Qwen3-VL MoE variants like 30B-A3B)
82-
"router": NameRemapping("model.language_model.layers.{}.mlp.gate.", REPLICATE),
83-
"local_experts.linear_fc1": GatedMLPMerging(
84-
"model.language_model.layers.{}.mlp.experts.{}.", COL_ETP
85-
),
86-
"local_experts.linear_fc2": NameRemapping(
87-
"model.language_model.layers.{}.mlp.experts.{}.down_proj.", ROW_ETP
88-
),
89-
}
9033

91-
# Export rules: Megatron Core -> HuggingFace
92-
qwen3vl_causal_lm_export: dict[str, CustomModuleMapping] = {
93-
# Embeddings
94-
"word_embeddings": NameRemapping("model.language_model.embed_tokens."),
95-
# Final layer norm
96-
"final_layernorm": NameRemapping("model.language_model.norm."),
97-
# Output layer
98-
"output_layer": NameRemapping("lm_head."),
99-
# Attention - input layernorm
100-
"input_layernorm": NameRemapping("model.language_model.layers.{}.input_layernorm."),
101-
# Attention - QKV projection (sliced back to separate q/k/v)
102-
"linear_qkv": QKVSlicing("model.language_model.layers.{}.self_attn."),
103-
# Attention - output projection
104-
"linear_proj": NameRemapping("model.language_model.layers.{}.self_attn.o_proj."),
105-
# Attention - Q/K layer norms
106-
"q_layernorm": NameRemapping("model.language_model.layers.{}.self_attn.q_norm."),
107-
"k_layernorm": NameRemapping("model.language_model.layers.{}.self_attn.k_norm."),
108-
# MLP - pre-MLP layernorm
109-
"pre_mlp_layernorm": NameRemapping("model.language_model.layers.{}.post_attention_layernorm."),
110-
# MLP - linear_fc1 sliced back to gate_proj + up_proj
111-
"linear_fc1": GatedMLPSlicing("model.language_model.layers.{}.mlp."),
112-
# MLP - down_proj
113-
"linear_fc2": NameRemapping("model.language_model.layers.{}.mlp.down_proj."),
114-
# MoE support
115-
"router": NameRemapping("model.language_model.layers.{}.mlp.gate."),
116-
"local_experts.linear_fc1": GatedMLPSlicing("model.language_model.layers.{}.mlp.experts.{}."),
117-
"local_experts.linear_fc2": NameRemapping(
118-
"model.language_model.layers.{}.mlp.experts.{}.down_proj."
119-
),
120-
}
34+
def _with_language_model_prefix(
    mapping: dict[str, CustomModuleMapping],
) -> dict[str, CustomModuleMapping]:
    """Derive a Qwen3-VL mapping from a base Qwen3 mapping.

    Every ``target_name_or_prefix`` that starts with ``model.`` is rewritten to
    ``model.language_model.<rest>``. Prefixes that do not start with ``model.``
    (e.g. ``lm_head.``) are kept as they are.

    Args:
        mapping: Base rules keyed by rule name. Each value must expose
            ``target_name_or_prefix`` and ``func_kwargs``, and its class must
            accept both as keyword arguments.

    Returns:
        A new dict with the same keys, in the same order, holding freshly
        constructed rules of the same class as the originals. Neither
        ``mapping`` nor the rules it contains are modified.
    """
    base_prefix = "model."
    vl_prefix = "model.language_model."

    result: dict[str, CustomModuleMapping] = {}
    for key, rule in mapping.items():
        prefix = rule.target_name_or_prefix
        if prefix.startswith(base_prefix):
            prefix = vl_prefix + prefix[len(base_prefix) :]

        # Shallow-copy the kwargs so the derived rule does not alias the base
        # rule's dict: a later in-place change to one mapping must not leak
        # into the other. ``None`` is passed through untouched.
        kwargs = rule.func_kwargs
        if kwargs is not None:
            kwargs = dict(kwargs)

        # NOTE(review): only ``target_name_or_prefix`` and ``func_kwargs`` are
        # carried over; this assumes the rule class re-creates any other state
        # in its constructor — confirm against the mapping classes.
        result[key] = type(rule)(target_name_or_prefix=prefix, func_kwargs=kwargs)
    return result
50+
51+
52+
# Import rules (HuggingFace -> Megatron Core), derived from the Qwen3 import rules.
qwen3vl_causal_lm_import = _with_language_model_prefix(qwen3_causal_lm_import)
53+
# Export rules (Megatron Core -> HuggingFace), derived from the Qwen3 export rules.
qwen3vl_causal_lm_export = _with_language_model_prefix(qwen3_causal_lm_export)

tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py

Lines changed: 0 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -239,52 +239,3 @@ def test_mlp_fc1_matching_prefix(self):
239239
imp = qwen3vl_causal_lm_import["linear_fc1"]
240240
exp = qwen3vl_causal_lm_export["linear_fc1"]
241241
assert imp.target_name_or_prefix == exp.target_name_or_prefix
242-
243-
244-
class TestQwen3VLvsQwen3Difference:
245-
"""Test that Qwen3-VL differs from Qwen3 only in the language_model prefix."""
246-
247-
def test_same_keys_as_qwen3(self):
248-
from modelopt.torch.export.plugins.mcore_qwen import (
249-
qwen3_causal_lm_export,
250-
qwen3_causal_lm_import,
251-
)
252-
253-
assert set(qwen3vl_causal_lm_import.keys()) == set(qwen3_causal_lm_import.keys())
254-
assert set(qwen3vl_causal_lm_export.keys()) == set(qwen3_causal_lm_export.keys())
255-
256-
@pytest.mark.parametrize(
257-
"key",
258-
[
259-
"word_embeddings",
260-
"final_layernorm",
261-
"input_layernorm",
262-
"linear_qkv",
263-
"linear_proj",
264-
"q_layernorm",
265-
"k_layernorm",
266-
"pre_mlp_layernorm",
267-
"linear_fc1",
268-
"linear_fc2",
269-
"router",
270-
"local_experts.linear_fc1",
271-
"local_experts.linear_fc2",
272-
],
273-
)
274-
def test_vl_adds_language_model_prefix(self, key):
275-
"""Qwen3-VL should have 'language_model.' inserted after 'model.'."""
276-
from modelopt.torch.export.plugins.mcore_qwen import qwen3_causal_lm_import
277-
278-
qwen3_prefix = qwen3_causal_lm_import[key].target_name_or_prefix
279-
qwen3vl_prefix = qwen3vl_causal_lm_import[key].target_name_or_prefix
280-
expected = qwen3_prefix.replace("model.", "model.language_model.", 1)
281-
assert qwen3vl_prefix == expected, f"{key}: expected '{expected}', got '{qwen3vl_prefix}'"
282-
283-
def test_output_layer_same(self):
284-
"""lm_head is at root level for both Qwen3 and Qwen3-VL."""
285-
from modelopt.torch.export.plugins.mcore_qwen import qwen3_causal_lm_import
286-
287-
assert (
288-
qwen3vl_causal_lm_import["output_layer"].target_name_or_prefix
289-
== qwen3_causal_lm_import["output_layer"].target_name_or_prefix
290-
)

0 commit comments

Comments
 (0)