-
Notifications
You must be signed in to change notification settings - Fork 403
Add Qwen3VL MCore Export support from PR 895 #1482
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
a7d1170
36da6de
ff1152f
e8101a7
aecbbfa
80495e6
5bf943b
425145c
d6f03cd
6ad8d0e
77adc9d
5cdb6b4
3637fe7
57a4608
e8e2d7b
1a86b05
63a229a
73d74b3
4dbffb2
cf0fb9f
1243b42
3f0b921
8266670
f56e4c2
74019e3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,62 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """Custom mapping from Qwen3-VL Hugging Face models to Megatron Core models. | ||
|
|
||
| Qwen3-VL differs from Qwen3 in one structural way: language-model weights live | ||
| under ``model.language_model.`` instead of ``model.``, while ``lm_head.weight`` | ||
| remains at the root level. The mappings below are derived automatically from | ||
| the Qwen3 mappings by inserting ``language_model.`` after ``model.`` for every | ||
| prefix that starts with ``model.``. | ||
|
|
||
| Note: the visual encoder (``model.visual.*``) is intentionally excluded — this | ||
| mapping covers only the language-model decoder used for quantization and export. | ||
|
|
||
| Note: ``Qwen3VLMoeForConditionalGeneration`` is **not** supported here. The MoE | ||
| variant stores expert weights as 3-D tensors (``mlp.experts.gate_up_proj``, | ||
| ``mlp.experts.down_proj``) that require a dedicated fused-expert mapping and | ||
| cannot reuse the dense Qwen3 rules. | ||
|
|
||
| Reference: https://huggingface.co/Qwen/Qwen3-VL-8B-Instruct/blob/main/model.safetensors.index.json | ||
| """ | ||
|
|
||
| import copy | ||
|
|
||
| from .mcore_custom import CustomModuleMapping | ||
| from .mcore_qwen import qwen3_causal_lm_export, qwen3_causal_lm_import | ||
|
|
||
|
|
||
| def _with_language_model_prefix( | ||
| mapping: dict[str, CustomModuleMapping], | ||
| ) -> dict[str, CustomModuleMapping]: | ||
| """Derive a VL mapping from a base Qwen3 mapping. | ||
|
|
||
| Rewrites every ``target_name_or_prefix`` that starts with ``model.`` to | ||
| ``model.language_model.<rest>``. Prefixes that do not start with | ||
| ``model.`` (e.g. ``lm_head.``) are left unchanged. | ||
| """ | ||
| result = {} | ||
| for key, m in mapping.items(): | ||
| prefix = m.target_name_or_prefix | ||
| if prefix.startswith("model."): | ||
| prefix = "model.language_model." + prefix[len("model.") :] | ||
| result[key] = type(m)( | ||
| target_name_or_prefix=prefix, func_kwargs=copy.deepcopy(m.func_kwargs) | ||
| ) | ||
| return result | ||
|
Comment on lines
+50
to
+58
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [SUGGESTION] It's harmless today (these dicts are treated as immutable in the rest of the codebase), but a future caller that mutates result[key] = type(m)(target_name_or_prefix=prefix, func_kwargs=dict(m.func_kwargs)) |
||
|
|
||
|
Comment on lines
+42
to
+59
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [SUGGESTION] Reconstructing each mapping via A more robust pattern is to deep-copy the original mapping and just rewrite the prefix: def _with_language_model_prefix(
mapping: dict[str, CustomModuleMapping],
) -> dict[str, CustomModuleMapping]:
result = {}
for key, m in mapping.items():
new_m = copy.deepcopy(m)
if new_m.target_name_or_prefix.startswith("model."):
new_m.target_name_or_prefix = (
"model.language_model." + new_m.target_name_or_prefix[len("model.") :]
)
result[key] = new_m
return resultThis preserves |
||
|
|
||
| qwen3vl_causal_lm_import = _with_language_model_prefix(qwen3_causal_lm_import) | ||
| qwen3vl_causal_lm_export = _with_language_model_prefix(qwen3_causal_lm_export) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[SUGGESTION] Scope-clarifying note: Qwen3-VL ships in two architectures —
Qwen3VLForConditionalGeneration(dense) andQwen3VLMoeForConditionalGeneration(MoE, e.g. Qwen/Qwen3-VL-30B-A3B-Instruct). This PR only registers the dense variant.The MoE variant cannot reuse
qwen3_causal_lm_exportwith a prefix rewrite because Qwen3-VL-MoE stores experts in fused form (mlp.experts.gate_up_proj/mlp.experts.down_projas 3-D tensors) rather than the per-expert layout (mlp.experts.{}.down_proj) thatqwen3_causal_lm_*assumes. So this is a real limitation, not just a missing registration.Consider adding a one-line note to the module docstring (e.g. "Covers the dense Qwen3VL variant only;
Qwen3VLMoeForConditionalGenerationuses a fused-expert layout and requires a separate mapping.") so that users hittingKeyError: 'Qwen3VLMoeForConditionalGeneration'in_populate_rule_bookknow it's intentional and what's missing.