Skip to content

Commit e520fe3

Browse files
hychiang-git and claude committed
test: add unit tests for Qwen3-VL mcore weight mapping
Add tests/unit/torch/export/test_mcore_qwen3vl_mapping.py covering mapping types, lm_head root placement, language_model prefix presence, layernorm replication annotations, TP sharding, and key-set symmetry. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Hung-Yueh Chiang <hungyuehc@nvidia.com>
1 parent 80495e6 commit e520fe3

1 file changed

Lines changed: 101 additions & 0 deletions

File tree

Lines changed: 101 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,101 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Unit tests for the Qwen3-VL mcore weight mapping (mcore_qwen3vl.py).
17+
18+
Verifies that every key in qwen3vl_causal_lm_import / _export carries the
19+
``model.language_model.`` prefix (except ``lm_head.`` which stays at root),
20+
and that TP sharding and layernorm replication annotations are correct.
21+
"""
22+
23+
import pytest
24+
25+
from modelopt.torch.export.plugins.mcore_custom import (
26+
COL_TP,
27+
REPLICATE,
28+
ROW_TP,
29+
GatedMLPMerging,
30+
GatedMLPSlicing,
31+
NameRemapping,
32+
QKVMerging,
33+
QKVSlicing,
34+
)
35+
from modelopt.torch.export.plugins.mcore_qwen3vl import (
36+
qwen3vl_causal_lm_export,
37+
qwen3vl_causal_lm_import,
38+
)
39+
40+
41+
def test_mapping_types():
    """Check the mapping classes used for the fused QKV and gated-MLP entries.

    The import mapping is expected to hold the ``*Merging`` classes and the
    export mapping the ``*Slicing`` classes for the same keys.
    """
    expectations = (
        (qwen3vl_causal_lm_import, "linear_qkv", QKVMerging),
        (qwen3vl_causal_lm_import, "linear_fc1", GatedMLPMerging),
        (qwen3vl_causal_lm_export, "linear_qkv", QKVSlicing),
        (qwen3vl_causal_lm_export, "linear_fc1", GatedMLPSlicing),
    )
    for mapping, name, expected_cls in expectations:
        assert isinstance(mapping[name], expected_cls)
46+
47+
48+
def test_lm_head_at_root():
    """Check that ``output_layer`` targets ``lm_head.`` in both mappings."""
    for mapping in (qwen3vl_causal_lm_import, qwen3vl_causal_lm_export):
        output_layer = mapping["output_layer"]
        assert output_layer.target_name_or_prefix == "lm_head."
51+
52+
53+
@pytest.mark.parametrize(
    "key",
    [
        "word_embeddings",
        "final_layernorm",
        "input_layernorm",
        "linear_qkv",
        "linear_proj",
        "q_layernorm",
        "k_layernorm",
        "pre_mlp_layernorm",
        "linear_fc1",
        "linear_fc2",
    ],
)
def test_language_model_prefix(key):
    """Check that the target of ``key`` starts with ``model.language_model.``.

    The check runs against both the import and the export mapping. It uses
    ``str.startswith`` rather than substring containment so that a target
    which merely embeds the text somewhere after another leading component
    does not pass as correctly prefixed.
    """
    prefix = "model.language_model."
    assert qwen3vl_causal_lm_import[key].target_name_or_prefix.startswith(prefix)
    assert qwen3vl_causal_lm_export[key].target_name_or_prefix.startswith(prefix)
71+
72+
73+
@pytest.mark.parametrize(
    "key",
    [
        "input_layernorm",
        "q_layernorm",
        "k_layernorm",
        "pre_mlp_layernorm",
        "final_layernorm",
    ],
)
def test_layernorm_replicated(key):
    """Check that the import entry for ``key`` is a replicated ``NameRemapping``.

    The entry must be a ``NameRemapping`` instance whose ``func_kwargs``
    compare equal to ``REPLICATE``.
    """
    layernorm_mapping = qwen3vl_causal_lm_import[key]
    assert isinstance(layernorm_mapping, NameRemapping)
    assert layernorm_mapping.func_kwargs == REPLICATE
81+
82+
83+
def test_tp_sharding():
    """Check the ``func_kwargs`` of selected import entries.

    ``word_embeddings`` and ``output_layer`` must equal ``COL_TP``;
    ``linear_proj`` must equal ``ROW_TP``.
    """
    expected_kwargs = {
        "word_embeddings": COL_TP,
        "output_layer": COL_TP,
        "linear_proj": ROW_TP,
    }
    for name, expected in expected_kwargs.items():
        assert qwen3vl_causal_lm_import[name].func_kwargs == expected
87+
88+
89+
def test_export_no_parallel_config():
    """Check that selected export entries have no ``parallel_config`` kwarg."""
    checked_names = (
        "word_embeddings",
        "final_layernorm",
        "output_layer",
        "input_layernorm",
        "linear_proj",
    )
    # all() short-circuits, so entries are inspected in the listed order and
    # evaluation stops at the first one that contains "parallel_config".
    assert all(
        "parallel_config" not in qwen3vl_causal_lm_export[name].func_kwargs
        for name in checked_names
    )
98+
99+
100+
def test_import_export_same_keys():
    """Check that the import and export mappings have the same set of keys."""
    import_names = set(qwen3vl_causal_lm_import.keys())
    export_names = set(qwen3vl_causal_lm_export.keys())
    assert import_names == export_names

0 commit comments

Comments
 (0)