From a7d1170cf43615027708abe0347152e8062b0cb0 Mon Sep 17 00:00:00 2001 From: Hung-Yueh Date: Sat, 14 Feb 2026 01:54:53 +0000 Subject: [PATCH 01/14] [Megatron Export] Add Qwen3-VL export/import mapping Add Megatron Core export/import mapping for Qwen3-VL (Qwen3VLForConditionalGeneration). Handles the model.language_model. weight prefix and supports both dense and MoE variants. Signed-off-by: Hung-Yueh mv test_mcore_qwen3vl.py to tests/gpu_megatron/torch/export/ Signed-off-by: Hung-Yueh Chiang --- CHANGELOG.rst | 1 + docs/source/deployment/3_unified_hf.rst | 1 + modelopt/torch/export/plugins/mcore_common.py | 6 + .../torch/export/plugins/mcore_qwen3vl.py | 120 +++++++ .../torch/export/test_mcore_qwen3vl.py | 306 ++++++++++++++++++ 5 files changed, 434 insertions(+) create mode 100644 modelopt/torch/export/plugins/mcore_qwen3vl.py create mode 100644 tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 62f2b0041cb..8fd414ec5fc 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -24,6 +24,7 @@ Changelog - Add support for ``active_params`` (for MoE models) and ``memory_mb`` constraints in Minitron pruning on top of existing ``params`` constraint. You can also provide multiple constraints. See `examples/pruning/README.md `_ for more details. The underlying utility functions ``mcore_param_count``, ``mcore_memory_footprint_mb``, and ``print_mcore_model_stats`` in ``modelopt.torch.nas.plugins.megatron_model_stats`` are also available for standalone use to compute parameter counts and memory footprints (weights + KV-cache + Mamba state) for any Megatron-Core model. - Add ``--cast_mxfp4_to_nvfp4`` flag to ``examples/llm_ptq/hf_ptq.py`` for closed-form, bit-exact MXFP4 → NVFP4 weight conversion. Supports the GPT-OSS family (``openai/gpt-oss-20b``, ``openai/gpt-oss-120b``). See `examples/llm_ptq/README.md `__ for usage. - DeepSeek PTQ (``examples/deepseek/ptq.py``) now defaults to native top-k calibration with post-hoc per-layer peer-max sync of expert ``input_quantizer.amax``; the all-experts path is preserved behind ``--calib_all_experts``. +- Add Megatron Core export/import mapping for Qwen3-VL (``Qwen3VLForConditionalGeneration``) vision-language models. The mapping handles the ``model.language_model.`` weight prefix used by Qwen3-VL and supports both dense and MoE variants. 0.44 (2026-05-18) ^^^^^^^^^^^^^^^^^ diff --git a/docs/source/deployment/3_unified_hf.rst b/docs/source/deployment/3_unified_hf.rst index 9124164b576..6664f987f72 100644 --- a/docs/source/deployment/3_unified_hf.rst +++ b/docs/source/deployment/3_unified_hf.rst @@ -61,6 +61,7 @@ Models: * Llama 4, 3.x (FP8, NVFP4) * Qwen 3, 2.5 (FP8, NVFP4) * Qwen 3 MoE (FP8, NVFP4) + * Qwen 3-VL (FP8, NVFP4) * Deepseek R1/V3 (NVFP4) * Mixtral 8x7B (FP8, NVFP4) * Medusa (FP8) diff --git a/modelopt/torch/export/plugins/mcore_common.py b/modelopt/torch/export/plugins/mcore_common.py index d5bab9b4ece..660e4eac96d 100644 --- a/modelopt/torch/export/plugins/mcore_common.py +++ b/modelopt/torch/export/plugins/mcore_common.py @@ -39,6 +39,10 @@ qwen25_causal_lm_export, qwen25_causal_lm_import, ) +from .mcore_qwen3vl import ( + qwen3vl_causal_lm_export, + qwen3vl_causal_lm_import, +) all_mcore_hf_export_mapping: dict[str, Any] = { "DeepseekV2ForCausalLM": deepseek_causal_lm_export, @@ -54,6 +58,7 @@ "Qwen3MoeForCausalLM": qwen3_causal_lm_export, "Qwen2ForCausalLM": qwen25_causal_lm_export, "GptOssForCausalLM": gptoss_causal_lm_export, + "Qwen3VLForConditionalGeneration": qwen3vl_causal_lm_export, } all_mcore_hf_import_mapping: dict[str, Any] = { @@ -66,4 +71,5 @@ "Qwen3MoeForCausalLM": qwen3_causal_lm_import, "Qwen2ForCausalLM": qwen25_causal_lm_import, "GptOssForCausalLM": gptoss_causal_lm_import, + "Qwen3VLForConditionalGeneration": qwen3vl_causal_lm_import, } diff --git a/modelopt/torch/export/plugins/mcore_qwen3vl.py b/modelopt/torch/export/plugins/mcore_qwen3vl.py new file mode 100644 index 00000000000..40eb99adb50 --- /dev/null +++ b/modelopt/torch/export/plugins/mcore_qwen3vl.py @@ -0,0 +1,120 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Custom mapping from Qwen3-VL Hugging Face models to Megatron Core models. + +Qwen3-VL model structure differs from Qwen3: +- Language model weights are under `model.language_model.` prefix +- Visual encoder weights are under `model.visual.` prefix + +This module handles the language model conversion for PTQ/QAT workflows. +Visual components are typically kept in full precision. + +HuggingFace Qwen3-VL-8B structure: +- model.language_model.embed_tokens.weight +- model.language_model.layers.{L}.input_layernorm.weight +- model.language_model.layers.{L}.self_attn.q_proj.weight +- model.language_model.layers.{L}.self_attn.k_proj.weight +- model.language_model.layers.{L}.self_attn.v_proj.weight +- model.language_model.layers.{L}.self_attn.q_norm.weight +- model.language_model.layers.{L}.self_attn.k_norm.weight +- model.language_model.layers.{L}.self_attn.o_proj.weight +- model.language_model.layers.{L}.post_attention_layernorm.weight +- model.language_model.layers.{L}.mlp.gate_proj.weight +- model.language_model.layers.{L}.mlp.up_proj.weight +- model.language_model.layers.{L}.mlp.down_proj.weight +- model.language_model.norm.weight +- lm_head.weight +""" + +from .mcore_custom import ( + COL_ETP, + COL_TP, + REPLICATE, + ROW_ETP, + ROW_TP, + CustomModuleMapping, + GatedMLPMerging, + GatedMLPSlicing, + NameRemapping, + QKVMerging, + QKVSlicing, +) + +# Import rules: HuggingFace -> Megatron Core +qwen3vl_causal_lm_import: dict[str, CustomModuleMapping] = { + # Embeddings - note the language_model prefix + "word_embeddings": NameRemapping("model.language_model.embed_tokens.", COL_TP), + # Final layer norm + "final_layernorm": NameRemapping("model.language_model.norm.", REPLICATE), + # Output layer (lm_head is at root level, not under language_model) + "output_layer": NameRemapping("lm_head.", COL_TP), + # Attention - input layernorm + "input_layernorm": NameRemapping("model.language_model.layers.{}.input_layernorm.", REPLICATE), + # Attention - QKV projection (merged) + "linear_qkv": QKVMerging("model.language_model.layers.{}.self_attn.", COL_TP), + # Attention - output projection + "linear_proj": NameRemapping("model.language_model.layers.{}.self_attn.o_proj.", ROW_TP), + # Attention - Q/K layer norms (Qwen3 uses RMSNorm on Q and K) + "q_layernorm": NameRemapping("model.language_model.layers.{}.self_attn.q_norm.", REPLICATE), + "k_layernorm": NameRemapping("model.language_model.layers.{}.self_attn.k_norm.", REPLICATE), + # MLP - pre-MLP layernorm (post_attention_layernorm in HF) + "pre_mlp_layernorm": NameRemapping( + "model.language_model.layers.{}.post_attention_layernorm.", REPLICATE + ), + # MLP - gate_proj + up_proj merged into linear_fc1 + "linear_fc1": GatedMLPMerging("model.language_model.layers.{}.mlp.", COL_TP), + # MLP - down_proj as linear_fc2 + "linear_fc2": NameRemapping("model.language_model.layers.{}.mlp.down_proj.", ROW_TP), + # MoE support (for Qwen3-VL MoE variants like 30B-A3B) + "router": NameRemapping("model.language_model.layers.{}.mlp.gate.", REPLICATE), + "local_experts.linear_fc1": GatedMLPMerging( + "model.language_model.layers.{}.mlp.experts.{}.", COL_ETP + ), + "local_experts.linear_fc2": NameRemapping( + "model.language_model.layers.{}.mlp.experts.{}.down_proj.", ROW_ETP + ), +} + +# Export rules: Megatron Core -> HuggingFace +qwen3vl_causal_lm_export: dict[str, CustomModuleMapping] = { + # Embeddings + "word_embeddings": NameRemapping("model.language_model.embed_tokens."), + # Final layer norm + "final_layernorm": NameRemapping("model.language_model.norm."), + # Output layer + "output_layer": NameRemapping("lm_head."), + # Attention - input layernorm + "input_layernorm": NameRemapping("model.language_model.layers.{}.input_layernorm."), + # Attention - QKV projection (sliced back to separate q/k/v) + "linear_qkv": QKVSlicing("model.language_model.layers.{}.self_attn."), + # Attention - output projection + "linear_proj": NameRemapping("model.language_model.layers.{}.self_attn.o_proj."), + # Attention - Q/K layer norms + "q_layernorm": NameRemapping("model.language_model.layers.{}.self_attn.q_norm."), + "k_layernorm": NameRemapping("model.language_model.layers.{}.self_attn.k_norm."), + # MLP - pre-MLP layernorm + "pre_mlp_layernorm": NameRemapping("model.language_model.layers.{}.post_attention_layernorm."), + # MLP - linear_fc1 sliced back to gate_proj + up_proj + "linear_fc1": GatedMLPSlicing("model.language_model.layers.{}.mlp."), + # MLP - down_proj + "linear_fc2": NameRemapping("model.language_model.layers.{}.mlp.down_proj."), + # MoE support + "router": NameRemapping("model.language_model.layers.{}.mlp.gate."), + "local_experts.linear_fc1": GatedMLPSlicing("model.language_model.layers.{}.mlp.experts.{}."), + "local_experts.linear_fc2": NameRemapping( + "model.language_model.layers.{}.mlp.experts.{}.down_proj." + ), +} \ No newline at end of file diff --git a/tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py b/tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py new file mode 100644 index 00000000000..3f57cb9c478 --- /dev/null +++ b/tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py @@ -0,0 +1,306 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for Qwen3-VL Megatron Core export/import plugin.""" + +import pytest + +from modelopt.torch.export.plugins.mcore_custom import ( + COL_TP, + REPLICATE, + ROW_TP, + GatedMLPMerging, + GatedMLPSlicing, + NameRemapping, + QKVMerging, + QKVSlicing, +) +from modelopt.torch.export.plugins.mcore_qwen3vl import ( + qwen3vl_causal_lm_export, + qwen3vl_causal_lm_import, +) + + +# All mcore keys that a dense (non-MoE) Qwen3-VL model should have +DENSE_MCORE_KEYS = { + "word_embeddings", + "final_layernorm", + "output_layer", + "input_layernorm", + "linear_qkv", + "linear_proj", + "q_layernorm", + "k_layernorm", + "pre_mlp_layernorm", + "linear_fc1", + "linear_fc2", +} + +# Additional MoE keys +MOE_MCORE_KEYS = { + "router", + "local_experts.linear_fc1", + "local_experts.linear_fc2", +} + + +class TestQwen3VLRegistration: + """Test that Qwen3-VL is registered in the global mapping.""" + + def test_registered_in_export_mapping(self): + from modelopt.torch.export.plugins.mcore_common import ( + all_mcore_hf_export_mapping, + ) + + assert "Qwen3VLForConditionalGeneration" in all_mcore_hf_export_mapping + assert ( + all_mcore_hf_export_mapping["Qwen3VLForConditionalGeneration"] + is qwen3vl_causal_lm_export + ) + + def test_registered_in_import_mapping(self): + from modelopt.torch.export.plugins.mcore_common import ( + all_mcore_hf_import_mapping, + ) + + assert "Qwen3VLForConditionalGeneration" in all_mcore_hf_import_mapping + assert ( + all_mcore_hf_import_mapping["Qwen3VLForConditionalGeneration"] + is qwen3vl_causal_lm_import + ) + + +class TestQwen3VLImportMapping: + """Test the HuggingFace -> Megatron Core import mapping.""" + + def test_has_all_dense_keys(self): + assert DENSE_MCORE_KEYS.issubset(qwen3vl_causal_lm_import.keys()) + + def test_has_all_moe_keys(self): + assert MOE_MCORE_KEYS.issubset(qwen3vl_causal_lm_import.keys()) + + def test_language_model_prefix(self): + """Qwen3-VL uses model.language_model. prefix (not model.).""" + prefix_keys = [ + "word_embeddings", + "final_layernorm", + "input_layernorm", + "linear_qkv", + "linear_proj", + "q_layernorm", + "k_layernorm", + "pre_mlp_layernorm", + "linear_fc1", + "linear_fc2", + ] + for key in prefix_keys: + mapping = qwen3vl_causal_lm_import[key] + assert "model.language_model." in mapping.target_name_or_prefix, ( + f"{key}: expected 'model.language_model.' prefix, " + f"got '{mapping.target_name_or_prefix}'" + ) + + def test_output_layer_at_root(self): + """lm_head is at root level, not under language_model.""" + mapping = qwen3vl_causal_lm_import["output_layer"] + assert mapping.target_name_or_prefix == "lm_head." + + def test_qkv_uses_merging(self): + assert isinstance(qwen3vl_causal_lm_import["linear_qkv"], QKVMerging) + + def test_mlp_uses_gated_merging(self): + assert isinstance( + qwen3vl_causal_lm_import["linear_fc1"], GatedMLPMerging + ) + + @pytest.mark.parametrize( + "key", + [ + "input_layernorm", + "q_layernorm", + "k_layernorm", + "pre_mlp_layernorm", + "final_layernorm", + ], + ) + def test_layernorms_are_replicated(self, key): + """Layernorms should use REPLICATE (empty func_kwargs).""" + mapping = qwen3vl_causal_lm_import[key] + assert isinstance(mapping, NameRemapping) + assert mapping.func_kwargs == REPLICATE + + @pytest.mark.parametrize( + "key,expected_kwargs", + [ + ("word_embeddings", COL_TP), + ("output_layer", COL_TP), + ("linear_proj", ROW_TP), + ], + ) + def test_tp_sharding(self, key, expected_kwargs): + mapping = qwen3vl_causal_lm_import[key] + assert mapping.func_kwargs == expected_kwargs + + +class TestQwen3VLExportMapping: + """Test the Megatron Core -> HuggingFace export mapping.""" + + def test_has_all_dense_keys(self): + assert DENSE_MCORE_KEYS.issubset(qwen3vl_causal_lm_export.keys()) + + def test_has_all_moe_keys(self): + assert MOE_MCORE_KEYS.issubset(qwen3vl_causal_lm_export.keys()) + + def test_language_model_prefix(self): + """Export paths should also use model.language_model. prefix.""" + prefix_keys = [ + "word_embeddings", + "final_layernorm", + "input_layernorm", + "linear_qkv", + "linear_proj", + "q_layernorm", + "k_layernorm", + "pre_mlp_layernorm", + "linear_fc1", + "linear_fc2", + ] + for key in prefix_keys: + mapping = qwen3vl_causal_lm_export[key] + assert "model.language_model." in mapping.target_name_or_prefix, ( + f"{key}: expected 'model.language_model.' prefix, " + f"got '{mapping.target_name_or_prefix}'" + ) + + def test_output_layer_at_root(self): + mapping = qwen3vl_causal_lm_export["output_layer"] + assert mapping.target_name_or_prefix == "lm_head." + + def test_qkv_uses_slicing(self): + assert isinstance(qwen3vl_causal_lm_export["linear_qkv"], QKVSlicing) + + def test_mlp_uses_gated_slicing(self): + assert isinstance( + qwen3vl_causal_lm_export["linear_fc1"], GatedMLPSlicing + ) + + def test_export_has_no_parallel_config(self): + """Export mappings should not have parallel configs.""" + for key in ["word_embeddings", "final_layernorm", "output_layer", + "input_layernorm", "linear_proj"]: + mapping = qwen3vl_causal_lm_export[key] + assert "parallel_config" not in mapping.func_kwargs + + +class TestQwen3VLImportExportSymmetry: + """Test that import and export mappings are consistent.""" + + def test_same_mcore_keys(self): + assert set(qwen3vl_causal_lm_import.keys()) == set( + qwen3vl_causal_lm_export.keys() + ) + + @pytest.mark.parametrize( + "key", + [ + "word_embeddings", + "final_layernorm", + "output_layer", + "input_layernorm", + "linear_proj", + "q_layernorm", + "k_layernorm", + "pre_mlp_layernorm", + "linear_fc2", + "router", + ], + ) + def test_matching_hf_prefixes(self, key): + """Import and export should map to the same HF prefix.""" + imp = qwen3vl_causal_lm_import[key] + exp = qwen3vl_causal_lm_export[key] + assert imp.target_name_or_prefix == exp.target_name_or_prefix, ( + f"{key}: import prefix '{imp.target_name_or_prefix}' != " + f"export prefix '{exp.target_name_or_prefix}'" + ) + + def test_qkv_matching_prefix(self): + imp = qwen3vl_causal_lm_import["linear_qkv"] + exp = qwen3vl_causal_lm_export["linear_qkv"] + assert imp.target_name_or_prefix == exp.target_name_or_prefix + + def test_mlp_fc1_matching_prefix(self): + imp = qwen3vl_causal_lm_import["linear_fc1"] + exp = qwen3vl_causal_lm_export["linear_fc1"] + assert imp.target_name_or_prefix == exp.target_name_or_prefix + + +class TestQwen3VLvsQwen3Difference: + """Test that Qwen3-VL differs from Qwen3 only in the language_model prefix.""" + + def test_same_keys_as_qwen3(self): + from modelopt.torch.export.plugins.mcore_qwen import ( + qwen3_causal_lm_export, + qwen3_causal_lm_import, + ) + + assert set(qwen3vl_causal_lm_import.keys()) == set( + qwen3_causal_lm_import.keys() + ) + assert set(qwen3vl_causal_lm_export.keys()) == set( + qwen3_causal_lm_export.keys() + ) + + @pytest.mark.parametrize( + "key", + [ + "word_embeddings", + "final_layernorm", + "input_layernorm", + "linear_qkv", + "linear_proj", + "q_layernorm", + "k_layernorm", + "pre_mlp_layernorm", + "linear_fc1", + "linear_fc2", + "router", + "local_experts.linear_fc1", + "local_experts.linear_fc2", + ], + ) + def test_vl_adds_language_model_prefix(self, key): + """Qwen3-VL should have 'language_model.' inserted after 'model.'.""" + from modelopt.torch.export.plugins.mcore_qwen import ( + qwen3_causal_lm_import, + ) + + qwen3_prefix = qwen3_causal_lm_import[key].target_name_or_prefix + qwen3vl_prefix = qwen3vl_causal_lm_import[key].target_name_or_prefix + expected = qwen3_prefix.replace("model.", "model.language_model.", 1) + assert qwen3vl_prefix == expected, ( + f"{key}: expected '{expected}', got '{qwen3vl_prefix}'" + ) + + def test_output_layer_same(self): + """lm_head is at root level for both Qwen3 and Qwen3-VL.""" + from modelopt.torch.export.plugins.mcore_qwen import ( + qwen3_causal_lm_import, + ) + + assert ( + qwen3vl_causal_lm_import["output_layer"].target_name_or_prefix + == qwen3_causal_lm_import["output_layer"].target_name_or_prefix + ) From 36da6deac2d5d419a460766e4de0799051ca7e63 Mon Sep 17 00:00:00 2001 From: Hung-Yueh Chiang Date: Wed, 13 May 2026 21:39:56 +0000 Subject: [PATCH 02/14] fix: ruff formatting and PT006 parametrize tuple fix Signed-off-by: Hung-Yueh Chiang --- tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py b/tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py index 3f57cb9c478..c0d4cf9bb07 100644 --- a/tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py +++ b/tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py @@ -142,7 +142,7 @@ def test_layernorms_are_replicated(self, key): assert mapping.func_kwargs == REPLICATE @pytest.mark.parametrize( - "key,expected_kwargs", + ("key", "expected_kwargs"), [ ("word_embeddings", COL_TP), ("output_layer", COL_TP), From e8101a7f8a14cc82a7998a9d987c2195239506b0 Mon Sep 17 00:00:00 2001 From: Hung-Yueh Chiang Date: Wed, 13 May 2026 23:14:20 +0000 Subject: [PATCH 03/14] fix: apply ruff formatting to mcore_qwen3vl plugin and test files Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Hung-Yueh Chiang --- modelopt/torch/export/plugins/mcore_common.py | 5 +-- .../torch/export/plugins/mcore_qwen3vl.py | 2 +- .../torch/export/test_mcore_qwen3vl.py | 33 ++++++++----------- 3 files changed, 15 insertions(+), 25 deletions(-) diff --git a/modelopt/torch/export/plugins/mcore_common.py b/modelopt/torch/export/plugins/mcore_common.py index 660e4eac96d..15395b7a1e5 100644 --- a/modelopt/torch/export/plugins/mcore_common.py +++ b/modelopt/torch/export/plugins/mcore_common.py @@ -39,10 +39,7 @@ qwen25_causal_lm_export, qwen25_causal_lm_import, ) -from .mcore_qwen3vl import ( - qwen3vl_causal_lm_export, - qwen3vl_causal_lm_import, -) +from .mcore_qwen3vl import qwen3vl_causal_lm_export, qwen3vl_causal_lm_import all_mcore_hf_export_mapping: dict[str, Any] = { "DeepseekV2ForCausalLM": deepseek_causal_lm_export, diff --git a/modelopt/torch/export/plugins/mcore_qwen3vl.py b/modelopt/torch/export/plugins/mcore_qwen3vl.py index 40eb99adb50..4dc3c63f4a2 100644 --- a/modelopt/torch/export/plugins/mcore_qwen3vl.py +++ b/modelopt/torch/export/plugins/mcore_qwen3vl.py @@ -117,4 +117,4 @@ "local_experts.linear_fc2": NameRemapping( "model.language_model.layers.{}.mlp.experts.{}.down_proj." ), -} \ No newline at end of file +} diff --git a/tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py b/tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py index c0d4cf9bb07..c7a5efc47d4 100644 --- a/tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py +++ b/tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py @@ -121,9 +121,7 @@ def test_qkv_uses_merging(self): assert isinstance(qwen3vl_causal_lm_import["linear_qkv"], QKVMerging) def test_mlp_uses_gated_merging(self): - assert isinstance( - qwen3vl_causal_lm_import["linear_fc1"], GatedMLPMerging - ) + assert isinstance(qwen3vl_causal_lm_import["linear_fc1"], GatedMLPMerging) @pytest.mark.parametrize( "key", @@ -192,14 +190,17 @@ def test_qkv_uses_slicing(self): assert isinstance(qwen3vl_causal_lm_export["linear_qkv"], QKVSlicing) def test_mlp_uses_gated_slicing(self): - assert isinstance( - qwen3vl_causal_lm_export["linear_fc1"], GatedMLPSlicing - ) + assert isinstance(qwen3vl_causal_lm_export["linear_fc1"], GatedMLPSlicing) def test_export_has_no_parallel_config(self): """Export mappings should not have parallel configs.""" - for key in ["word_embeddings", "final_layernorm", "output_layer", - "input_layernorm", "linear_proj"]: + for key in [ + "word_embeddings", + "final_layernorm", + "output_layer", + "input_layernorm", + "linear_proj", + ]: mapping = qwen3vl_causal_lm_export[key] assert "parallel_config" not in mapping.func_kwargs @@ -208,9 +209,7 @@ class TestQwen3VLImportExportSymmetry: """Test that import and export mappings are consistent.""" def test_same_mcore_keys(self): - assert set(qwen3vl_causal_lm_import.keys()) == set( - qwen3vl_causal_lm_export.keys() - ) + assert set(qwen3vl_causal_lm_import.keys()) == set(qwen3vl_causal_lm_export.keys()) @pytest.mark.parametrize( "key", @@ -256,12 +255,8 @@ def test_same_keys_as_qwen3(self): qwen3_causal_lm_import, ) - assert set(qwen3vl_causal_lm_import.keys()) == set( - qwen3_causal_lm_import.keys() - ) - assert set(qwen3vl_causal_lm_export.keys()) == set( - qwen3_causal_lm_export.keys() - ) + assert set(qwen3vl_causal_lm_import.keys()) == set(qwen3_causal_lm_import.keys()) + assert set(qwen3vl_causal_lm_export.keys()) == set(qwen3_causal_lm_export.keys()) @pytest.mark.parametrize( "key", @@ -290,9 +285,7 @@ def test_vl_adds_language_model_prefix(self, key): qwen3_prefix = qwen3_causal_lm_import[key].target_name_or_prefix qwen3vl_prefix = qwen3vl_causal_lm_import[key].target_name_or_prefix expected = qwen3_prefix.replace("model.", "model.language_model.", 1) - assert qwen3vl_prefix == expected, ( - f"{key}: expected '{expected}', got '{qwen3vl_prefix}'" - ) + assert qwen3vl_prefix == expected, f"{key}: expected '{expected}', got '{qwen3vl_prefix}'" def test_output_layer_same(self): """lm_head is at root level for both Qwen3 and Qwen3-VL.""" From aecbbfae4db2034e62b950532e81fcc63ddbb82d Mon Sep 17 00:00:00 2001 From: Hung-Yueh Chiang Date: Thu, 14 May 2026 16:03:07 +0000 Subject: [PATCH 04/14] fix: collapse single-item imports in test_mcore_qwen3vl per ruff Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Hung-Yueh Chiang --- .../torch/export/test_mcore_qwen3vl.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py b/tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py index c7a5efc47d4..a9b8ddd5a0f 100644 --- a/tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py +++ b/tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py @@ -32,7 +32,6 @@ qwen3vl_causal_lm_import, ) - # All mcore keys that a dense (non-MoE) Qwen3-VL model should have DENSE_MCORE_KEYS = { "word_embeddings", @@ -60,9 +59,7 @@ class TestQwen3VLRegistration: """Test that Qwen3-VL is registered in the global mapping.""" def test_registered_in_export_mapping(self): - from modelopt.torch.export.plugins.mcore_common import ( - all_mcore_hf_export_mapping, - ) + from modelopt.torch.export.plugins.mcore_common import all_mcore_hf_export_mapping assert "Qwen3VLForConditionalGeneration" in all_mcore_hf_export_mapping assert ( @@ -71,9 +68,7 @@ def test_registered_in_export_mapping(self): ) def test_registered_in_import_mapping(self): - from modelopt.torch.export.plugins.mcore_common import ( - all_mcore_hf_import_mapping, - ) + from modelopt.torch.export.plugins.mcore_common import all_mcore_hf_import_mapping assert "Qwen3VLForConditionalGeneration" in all_mcore_hf_import_mapping assert ( @@ -278,9 +273,7 @@ def test_same_keys_as_qwen3(self): ) def test_vl_adds_language_model_prefix(self, key): """Qwen3-VL should have 'language_model.' inserted after 'model.'.""" - from modelopt.torch.export.plugins.mcore_qwen import ( - qwen3_causal_lm_import, - ) + from modelopt.torch.export.plugins.mcore_qwen import qwen3_causal_lm_import qwen3_prefix = qwen3_causal_lm_import[key].target_name_or_prefix qwen3vl_prefix = qwen3vl_causal_lm_import[key].target_name_or_prefix @@ -289,9 +282,7 @@ def test_vl_adds_language_model_prefix(self, key): def test_output_layer_same(self): """lm_head is at root level for both Qwen3 and Qwen3-VL.""" - from modelopt.torch.export.plugins.mcore_qwen import ( - qwen3_causal_lm_import, - ) + from modelopt.torch.export.plugins.mcore_qwen import qwen3_causal_lm_import assert ( qwen3vl_causal_lm_import["output_layer"].target_name_or_prefix From 80495e6bad87688e77687889a2fdfa359b2233c6 Mon Sep 17 00:00:00 2001 From: Hung-Yueh Chiang Date: Thu, 14 May 2026 16:36:49 +0000 Subject: [PATCH 05/14] refactor: derive Qwen3-VL mcore mapping from Qwen3 via prefix rewrite Replace the hand-written dict literals in mcore_qwen3vl.py with a helper that derives the VL mapping from qwen3_causal_lm_import/export by inserting 'language_model.' after 'model.' in every prefix. lm_head. (root-level) is left unchanged. Remove TestQwen3VLvsQwen3Difference since it now tests the implementation against itself. Note visual encoder (model.visual.*) is intentionally excluded from the mapping. Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Hung-Yueh Chiang --- .../torch/export/plugins/mcore_qwen3vl.py | 127 +++++------------- .../torch/export/test_mcore_qwen3vl.py | 49 ------- 2 files changed, 30 insertions(+), 146 deletions(-) diff --git a/modelopt/torch/export/plugins/mcore_qwen3vl.py b/modelopt/torch/export/plugins/mcore_qwen3vl.py index 4dc3c63f4a2..2f35b1291e0 100644 --- a/modelopt/torch/export/plugins/mcore_qwen3vl.py +++ b/modelopt/torch/export/plugins/mcore_qwen3vl.py @@ -15,106 +15,39 @@ """Custom mapping from Qwen3-VL Hugging Face models to Megatron Core models. -Qwen3-VL model structure differs from Qwen3: -- Language model weights are under `model.language_model.` prefix -- Visual encoder weights are under `model.visual.` prefix +Qwen3-VL differs from Qwen3 in one structural way: language-model weights live +under ``model.language_model.`` instead of ``model.``, while ``lm_head.weight`` +remains at the root level. The mappings below are derived automatically from +the Qwen3 mappings by inserting ``language_model.`` after ``model.`` for every +prefix that starts with ``model.``. -This module handles the language model conversion for PTQ/QAT workflows. -Visual components are typically kept in full precision. +Note: the visual encoder (``model.visual.*``) is intentionally excluded — this +mapping covers only the language-model decoder used for quantization and export. -HuggingFace Qwen3-VL-8B structure: -- model.language_model.embed_tokens.weight -- model.language_model.layers.{L}.input_layernorm.weight -- model.language_model.layers.{L}.self_attn.q_proj.weight -- model.language_model.layers.{L}.self_attn.k_proj.weight -- model.language_model.layers.{L}.self_attn.v_proj.weight -- model.language_model.layers.{L}.self_attn.q_norm.weight -- model.language_model.layers.{L}.self_attn.k_norm.weight -- model.language_model.layers.{L}.self_attn.o_proj.weight -- model.language_model.layers.{L}.post_attention_layernorm.weight -- model.language_model.layers.{L}.mlp.gate_proj.weight -- model.language_model.layers.{L}.mlp.up_proj.weight -- model.language_model.layers.{L}.mlp.down_proj.weight -- model.language_model.norm.weight -- lm_head.weight +Reference: https://huggingface.co/Qwen/Qwen3-VL-8B-Instruct/blob/main/model.safetensors.index.json """ -from .mcore_custom import ( - COL_ETP, - COL_TP, - REPLICATE, - ROW_ETP, - ROW_TP, - CustomModuleMapping, - GatedMLPMerging, - GatedMLPSlicing, - NameRemapping, - QKVMerging, - QKVSlicing, -) +from .mcore_custom import CustomModuleMapping +from .mcore_qwen import qwen3_causal_lm_export, qwen3_causal_lm_import -# Import rules: HuggingFace -> Megatron Core -qwen3vl_causal_lm_import: dict[str, CustomModuleMapping] = { - # Embeddings - note the language_model prefix - "word_embeddings": NameRemapping("model.language_model.embed_tokens.", COL_TP), - # Final layer norm - "final_layernorm": NameRemapping("model.language_model.norm.", REPLICATE), - # Output layer (lm_head is at root level, not under language_model) - "output_layer": NameRemapping("lm_head.", COL_TP), - # Attention - input layernorm - "input_layernorm": NameRemapping("model.language_model.layers.{}.input_layernorm.", REPLICATE), - # Attention - QKV projection (merged) - "linear_qkv": QKVMerging("model.language_model.layers.{}.self_attn.", COL_TP), - # Attention - output projection - "linear_proj": NameRemapping("model.language_model.layers.{}.self_attn.o_proj.", ROW_TP), - # Attention - Q/K layer norms (Qwen3 uses RMSNorm on Q and K) - "q_layernorm": NameRemapping("model.language_model.layers.{}.self_attn.q_norm.", REPLICATE), - "k_layernorm": NameRemapping("model.language_model.layers.{}.self_attn.k_norm.", REPLICATE), - # MLP - pre-MLP layernorm (post_attention_layernorm in HF) - "pre_mlp_layernorm": NameRemapping( - "model.language_model.layers.{}.post_attention_layernorm.", REPLICATE - ), - # MLP - gate_proj + up_proj merged into linear_fc1 - "linear_fc1": GatedMLPMerging("model.language_model.layers.{}.mlp.", COL_TP), - # MLP - down_proj as linear_fc2 - "linear_fc2": NameRemapping("model.language_model.layers.{}.mlp.down_proj.", ROW_TP), - # MoE support (for Qwen3-VL MoE variants like 30B-A3B) - "router": NameRemapping("model.language_model.layers.{}.mlp.gate.", REPLICATE), - "local_experts.linear_fc1": GatedMLPMerging( - "model.language_model.layers.{}.mlp.experts.{}.", COL_ETP - ), - "local_experts.linear_fc2": NameRemapping( - "model.language_model.layers.{}.mlp.experts.{}.down_proj.", ROW_ETP - ), -} -# Export rules: Megatron Core -> HuggingFace -qwen3vl_causal_lm_export: dict[str, CustomModuleMapping] = { - # Embeddings - "word_embeddings": NameRemapping("model.language_model.embed_tokens."), - # Final layer norm - "final_layernorm": NameRemapping("model.language_model.norm."), - # Output layer - "output_layer": NameRemapping("lm_head."), - # Attention - input layernorm - "input_layernorm": NameRemapping("model.language_model.layers.{}.input_layernorm."), - # Attention - QKV projection (sliced back to separate q/k/v) - "linear_qkv": QKVSlicing("model.language_model.layers.{}.self_attn."), - # Attention - output projection - "linear_proj": NameRemapping("model.language_model.layers.{}.self_attn.o_proj."), - # Attention - Q/K layer norms - "q_layernorm": NameRemapping("model.language_model.layers.{}.self_attn.q_norm."), - "k_layernorm": NameRemapping("model.language_model.layers.{}.self_attn.k_norm."), - # MLP - pre-MLP layernorm - "pre_mlp_layernorm": NameRemapping("model.language_model.layers.{}.post_attention_layernorm."), - # MLP - linear_fc1 sliced back to gate_proj + up_proj - "linear_fc1": GatedMLPSlicing("model.language_model.layers.{}.mlp."), - # MLP - down_proj - "linear_fc2": NameRemapping("model.language_model.layers.{}.mlp.down_proj."), - # MoE support - "router": NameRemapping("model.language_model.layers.{}.mlp.gate."), - "local_experts.linear_fc1": GatedMLPSlicing("model.language_model.layers.{}.mlp.experts.{}."), - "local_experts.linear_fc2": NameRemapping( - "model.language_model.layers.{}.mlp.experts.{}.down_proj." - ), -} +def _with_language_model_prefix( + mapping: dict[str, CustomModuleMapping], +) -> dict[str, CustomModuleMapping]: + """Derive a VL mapping from a base Qwen3 mapping. + + Rewrites every ``target_name_or_prefix`` that starts with ``model.`` to + ``model.language_model.``. Prefixes that do not start with + ``model.`` (e.g. ``lm_head.``) are left unchanged. + """ + result = {} + for key, m in mapping.items(): + prefix = m.target_name_or_prefix + if prefix.startswith("model."): + prefix = "model.language_model." + prefix[len("model.") :] + result[key] = type(m)(target_name_or_prefix=prefix, func_kwargs=m.func_kwargs) + return result + + +qwen3vl_causal_lm_import = _with_language_model_prefix(qwen3_causal_lm_import) +qwen3vl_causal_lm_export = _with_language_model_prefix(qwen3_causal_lm_export) diff --git a/tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py b/tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py index a9b8ddd5a0f..f5f62058bf4 100644 --- a/tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py +++ b/tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py @@ -239,52 +239,3 @@ def test_mlp_fc1_matching_prefix(self): imp = qwen3vl_causal_lm_import["linear_fc1"] exp = qwen3vl_causal_lm_export["linear_fc1"] assert imp.target_name_or_prefix == exp.target_name_or_prefix - - -class TestQwen3VLvsQwen3Difference: - """Test that Qwen3-VL differs from Qwen3 only in the language_model prefix.""" - - def test_same_keys_as_qwen3(self): - from modelopt.torch.export.plugins.mcore_qwen import ( - qwen3_causal_lm_export, - qwen3_causal_lm_import, - ) - - assert set(qwen3vl_causal_lm_import.keys()) == set(qwen3_causal_lm_import.keys()) - assert set(qwen3vl_causal_lm_export.keys()) == set(qwen3_causal_lm_export.keys()) - - @pytest.mark.parametrize( - "key", - [ - "word_embeddings", - "final_layernorm", - "input_layernorm", - "linear_qkv", - "linear_proj", - "q_layernorm", - "k_layernorm", - "pre_mlp_layernorm", - "linear_fc1", - "linear_fc2", - "router", - "local_experts.linear_fc1", - "local_experts.linear_fc2", - ], - ) - def test_vl_adds_language_model_prefix(self, key): - """Qwen3-VL should have 'language_model.' inserted after 'model.'.""" - from modelopt.torch.export.plugins.mcore_qwen import qwen3_causal_lm_import - - qwen3_prefix = qwen3_causal_lm_import[key].target_name_or_prefix - qwen3vl_prefix = qwen3vl_causal_lm_import[key].target_name_or_prefix - expected = qwen3_prefix.replace("model.", "model.language_model.", 1) - assert qwen3vl_prefix == expected, f"{key}: expected '{expected}', got '{qwen3vl_prefix}'" - - def test_output_layer_same(self): - """lm_head is at root level for both Qwen3 and Qwen3-VL.""" - from modelopt.torch.export.plugins.mcore_qwen import qwen3_causal_lm_import - - assert ( - qwen3vl_causal_lm_import["output_layer"].target_name_or_prefix - == qwen3_causal_lm_import["output_layer"].target_name_or_prefix - ) From 6ad8d0e806db391543774b11ba3cfbf57fa14bc2 Mon Sep 17 00:00:00 2001 From: Hung-Yueh Chiang Date: Mon, 18 May 2026 22:18:53 +0000 Subject: [PATCH 06/14] Integrate Qwen3-VL mcore weight mapping tests into unified export test suite - Remove standalone test_mcore_qwen3vl.py; fold tests into test_unified_export_megatron.py using the same parametrized patterns as llama/nemotron/gpt-oss - Add model_type param to _test_unified_import_megatron and test_unified_import_megatron; qwen3vl uses nested text_config attrs and extra kv_channels/qk_layernorm mcore kwargs - Add model_type param to _test_unified_export_megatron and test_unified_export_megatron; qwen3vl post-export merges vision weights and validates both language_model and visual tensors are present - Add create_tiny_nemotron_dir and create_tiny_qwen3vl_dir helpers in transformers_models.py so all model types use the same dir-based pattern - Fix get_tiny_nemotron dims (hidden_size=64, intermediate_size=128) to prevent all-zero relu2 activations that break NVFP4 scaling factor checks - Fix get_tiny_qwen3vl to pass sub-configs as dicts to Qwen3VLConfig to work with transformers 5.3.0 where instances are not accepted Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Hung-Yueh Chiang --- .../_test_utils/torch/transformers_models.py | 82 ++++++ .../torch/export/test_mcore_qwen3vl.py | 241 ------------------ .../export/test_unified_export_megatron.py | 206 ++++++++++----- 3 files changed, 221 insertions(+), 308 deletions(-) delete mode 100644 tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py diff --git a/tests/_test_utils/torch/transformers_models.py b/tests/_test_utils/torch/transformers_models.py index 34bc96cd0ae..708162e55dd 100644 --- a/tests/_test_utils/torch/transformers_models.py +++ b/tests/_test_utils/torch/transformers_models.py @@ -28,12 +28,15 @@ BertConfig, GptOssConfig, LlamaConfig, + NemotronConfig, PreTrainedModel, Qwen3Config, Qwen3MoeConfig, + Qwen3VLConfig, T5Config, T5ForConditionalGeneration, ) +from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration import modelopt.torch.opt as mto @@ -120,6 +123,85 @@ def create_tiny_qwen3_moe_dir( return qwen3_moe_dir +##### Qwen3-VL ##### +def get_tiny_qwen3vl(**config_kwargs) -> PreTrainedModel: + set_seed(SEED) + + # Defaults: hidden_size=num_attention_heads*head_dim (e.g. 4*8=32). + # Pass config_kwargs to override for multi-GPU tests (e.g. num_attention_heads=num_gpus, + # num_key_value_heads=num_gpus, hidden_size=num_gpus*head_dim). + text_kwargs = { + "hidden_size": 32, + "intermediate_size": 32, + "num_hidden_layers": 2, + "num_attention_heads": 4, + "num_key_value_heads": 2, + "head_dim": 8, + "max_position_embeddings": 32, + "vocab_size": 32, + } + text_kwargs.update(config_kwargs) + # Pass as dicts — transformers 5.3.0 Qwen3VLConfig.__init__ only handles + # vision_config/text_config when they are dicts or None, not instances. + vision_kwargs = { + "depth": 1, + "hidden_size": 16, + "intermediate_size": 16, + "num_heads": 2, + "in_channels": 3, + "patch_size": 4, + "spatial_merge_size": 1, + "temporal_patch_size": 1, + "out_hidden_size": text_kwargs["hidden_size"], # must match text hidden_size + } + cfg = Qwen3VLConfig(text_config=text_kwargs, vision_config=vision_kwargs) + return Qwen3VLForConditionalGeneration(cfg) + + +def create_tiny_qwen3vl_dir( + tmp_path: Path | str, with_tokenizer: bool = False, **config_kwargs +) -> Path: + qwen3vl_dir = Path(tmp_path) / "tiny_qwen3vl" + if with_tokenizer: + tokenizer = get_tiny_tokenizer() + tokenizer.save_pretrained(qwen3vl_dir) + config_kwargs["vocab_size"] = tokenizer.vocab_size + get_tiny_qwen3vl(**config_kwargs).save_pretrained(qwen3vl_dir) + return qwen3vl_dir + + +##### NEMOTRON ##### +def get_tiny_nemotron(**config_kwargs) -> PreTrainedModel: + set_seed(SEED) + + # hidden_size=64, ffn_hidden_size=128: relu2 activation needs non-trivial dims + # to avoid all-zero activations (scaling factor 0) in NVFP4 quantization. + kwargs = { + "dtype": torch.bfloat16, + "hidden_size": 64, + "intermediate_size": 128, + "num_hidden_layers": 2, + "num_attention_heads": 8, + "num_key_value_heads": 1, + "max_position_embeddings": 32, + "vocab_size": 32, + } + kwargs.update(**config_kwargs) + return AutoModelForCausalLM.from_config(NemotronConfig(**kwargs)) + + +def create_tiny_nemotron_dir( + tmp_path: Path | str, with_tokenizer: bool = False, **config_kwargs +) -> Path: + nemotron_dir = Path(tmp_path) / "tiny_nemotron" + if with_tokenizer: + tokenizer = get_tiny_tokenizer() + tokenizer.save_pretrained(nemotron_dir) + config_kwargs["vocab_size"] = tokenizer.vocab_size + get_tiny_nemotron(**config_kwargs).save_pretrained(nemotron_dir) + return nemotron_dir + + ##### GPT-OSS ##### def get_tiny_gpt_oss(**config_kwargs) -> PreTrainedModel: set_seed(SEED) diff --git a/tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py b/tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py deleted file mode 100644 index f5f62058bf4..00000000000 --- a/tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py +++ /dev/null @@ -1,241 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Unit tests for Qwen3-VL Megatron Core export/import plugin.""" - -import pytest - -from modelopt.torch.export.plugins.mcore_custom import ( - COL_TP, - REPLICATE, - ROW_TP, - GatedMLPMerging, - GatedMLPSlicing, - NameRemapping, - QKVMerging, - QKVSlicing, -) -from modelopt.torch.export.plugins.mcore_qwen3vl import ( - qwen3vl_causal_lm_export, - qwen3vl_causal_lm_import, -) - -# All mcore keys that a dense (non-MoE) Qwen3-VL model should have -DENSE_MCORE_KEYS = { - "word_embeddings", - "final_layernorm", - "output_layer", - "input_layernorm", - "linear_qkv", - "linear_proj", - "q_layernorm", - "k_layernorm", - "pre_mlp_layernorm", - "linear_fc1", - "linear_fc2", -} - -# Additional MoE keys -MOE_MCORE_KEYS = { - "router", - "local_experts.linear_fc1", - "local_experts.linear_fc2", -} - - -class TestQwen3VLRegistration: - """Test that Qwen3-VL is registered in the global mapping.""" - - def test_registered_in_export_mapping(self): - from modelopt.torch.export.plugins.mcore_common import all_mcore_hf_export_mapping - - assert "Qwen3VLForConditionalGeneration" in all_mcore_hf_export_mapping - assert ( - all_mcore_hf_export_mapping["Qwen3VLForConditionalGeneration"] - is qwen3vl_causal_lm_export - ) - - def test_registered_in_import_mapping(self): - from modelopt.torch.export.plugins.mcore_common import all_mcore_hf_import_mapping - - assert "Qwen3VLForConditionalGeneration" in all_mcore_hf_import_mapping - assert ( - all_mcore_hf_import_mapping["Qwen3VLForConditionalGeneration"] - is qwen3vl_causal_lm_import - ) - - -class TestQwen3VLImportMapping: - """Test the HuggingFace -> Megatron Core import mapping.""" - - def test_has_all_dense_keys(self): - assert DENSE_MCORE_KEYS.issubset(qwen3vl_causal_lm_import.keys()) - - def test_has_all_moe_keys(self): - assert MOE_MCORE_KEYS.issubset(qwen3vl_causal_lm_import.keys()) - - def test_language_model_prefix(self): - """Qwen3-VL uses model.language_model. prefix (not model.).""" - prefix_keys = [ - "word_embeddings", - "final_layernorm", - "input_layernorm", - "linear_qkv", - "linear_proj", - "q_layernorm", - "k_layernorm", - "pre_mlp_layernorm", - "linear_fc1", - "linear_fc2", - ] - for key in prefix_keys: - mapping = qwen3vl_causal_lm_import[key] - assert "model.language_model." in mapping.target_name_or_prefix, ( - f"{key}: expected 'model.language_model.' prefix, " - f"got '{mapping.target_name_or_prefix}'" - ) - - def test_output_layer_at_root(self): - """lm_head is at root level, not under language_model.""" - mapping = qwen3vl_causal_lm_import["output_layer"] - assert mapping.target_name_or_prefix == "lm_head." - - def test_qkv_uses_merging(self): - assert isinstance(qwen3vl_causal_lm_import["linear_qkv"], QKVMerging) - - def test_mlp_uses_gated_merging(self): - assert isinstance(qwen3vl_causal_lm_import["linear_fc1"], GatedMLPMerging) - - @pytest.mark.parametrize( - "key", - [ - "input_layernorm", - "q_layernorm", - "k_layernorm", - "pre_mlp_layernorm", - "final_layernorm", - ], - ) - def test_layernorms_are_replicated(self, key): - """Layernorms should use REPLICATE (empty func_kwargs).""" - mapping = qwen3vl_causal_lm_import[key] - assert isinstance(mapping, NameRemapping) - assert mapping.func_kwargs == REPLICATE - - @pytest.mark.parametrize( - ("key", "expected_kwargs"), - [ - ("word_embeddings", COL_TP), - ("output_layer", COL_TP), - ("linear_proj", ROW_TP), - ], - ) - def test_tp_sharding(self, key, expected_kwargs): - mapping = qwen3vl_causal_lm_import[key] - assert mapping.func_kwargs == expected_kwargs - - -class TestQwen3VLExportMapping: - """Test the Megatron Core -> HuggingFace export mapping.""" - - def test_has_all_dense_keys(self): - assert DENSE_MCORE_KEYS.issubset(qwen3vl_causal_lm_export.keys()) - - def test_has_all_moe_keys(self): - assert MOE_MCORE_KEYS.issubset(qwen3vl_causal_lm_export.keys()) - - def test_language_model_prefix(self): - """Export paths should also use model.language_model. prefix.""" - prefix_keys = [ - "word_embeddings", - "final_layernorm", - "input_layernorm", - "linear_qkv", - "linear_proj", - "q_layernorm", - "k_layernorm", - "pre_mlp_layernorm", - "linear_fc1", - "linear_fc2", - ] - for key in prefix_keys: - mapping = qwen3vl_causal_lm_export[key] - assert "model.language_model." in mapping.target_name_or_prefix, ( - f"{key}: expected 'model.language_model.' prefix, " - f"got '{mapping.target_name_or_prefix}'" - ) - - def test_output_layer_at_root(self): - mapping = qwen3vl_causal_lm_export["output_layer"] - assert mapping.target_name_or_prefix == "lm_head." - - def test_qkv_uses_slicing(self): - assert isinstance(qwen3vl_causal_lm_export["linear_qkv"], QKVSlicing) - - def test_mlp_uses_gated_slicing(self): - assert isinstance(qwen3vl_causal_lm_export["linear_fc1"], GatedMLPSlicing) - - def test_export_has_no_parallel_config(self): - """Export mappings should not have parallel configs.""" - for key in [ - "word_embeddings", - "final_layernorm", - "output_layer", - "input_layernorm", - "linear_proj", - ]: - mapping = qwen3vl_causal_lm_export[key] - assert "parallel_config" not in mapping.func_kwargs - - -class TestQwen3VLImportExportSymmetry: - """Test that import and export mappings are consistent.""" - - def test_same_mcore_keys(self): - assert set(qwen3vl_causal_lm_import.keys()) == set(qwen3vl_causal_lm_export.keys()) - - @pytest.mark.parametrize( - "key", - [ - "word_embeddings", - "final_layernorm", - "output_layer", - "input_layernorm", - "linear_proj", - "q_layernorm", - "k_layernorm", - "pre_mlp_layernorm", - "linear_fc2", - "router", - ], - ) - def test_matching_hf_prefixes(self, key): - """Import and export should map to the same HF prefix.""" - imp = qwen3vl_causal_lm_import[key] - exp = qwen3vl_causal_lm_export[key] - assert imp.target_name_or_prefix == exp.target_name_or_prefix, ( - f"{key}: import prefix '{imp.target_name_or_prefix}' != " - f"export prefix '{exp.target_name_or_prefix}'" - ) - - def test_qkv_matching_prefix(self): - imp = qwen3vl_causal_lm_import["linear_qkv"] - exp = qwen3vl_causal_lm_export["linear_qkv"] - assert imp.target_name_or_prefix == exp.target_name_or_prefix - - def test_mlp_fc1_matching_prefix(self): - imp = qwen3vl_causal_lm_import["linear_fc1"] - exp = qwen3vl_causal_lm_export["linear_fc1"] - assert imp.target_name_or_prefix == exp.target_name_or_prefix diff --git a/tests/gpu_megatron/torch/export/test_unified_export_megatron.py b/tests/gpu_megatron/torch/export/test_unified_export_megatron.py index 3fac8269ccd..b5b182113ac 100644 --- a/tests/gpu_megatron/torch/export/test_unified_export_megatron.py +++ b/tests/gpu_megatron/torch/export/test_unified_export_megatron.py @@ -23,7 +23,12 @@ import transformers from _test_utils.torch.megatron.models import get_mcore_gpt_model from _test_utils.torch.megatron.utils import get_forward -from _test_utils.torch.transformers_models import create_tiny_llama_dir, get_tiny_tokenizer +from _test_utils.torch.transformers_models import ( + create_tiny_llama_dir, + create_tiny_nemotron_dir, + create_tiny_qwen3vl_dir, +) +from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration from safetensors import safe_open from safetensors.torch import save_file @@ -71,21 +76,61 @@ def _verify_model_quant_config( assert quant_config_dict["kv_cache_quant_algo"] == KV_CACHE_FP8 +def _merge_vision_weights(src_dir: Path, dst_dir: Path) -> None: + """Copy model.visual.* tensors from src safetensors into dst_dir. + + The mcore export only writes language-model weights. To produce a complete + Qwen3-VL checkpoint the vision-encoder weights must be merged from the + original pretrained checkpoint. + """ + vision_tensors = {} + for sf in sorted(src_dir.glob("*.safetensors")): + with safe_open(str(sf), framework="pt", device="cpu") as f: + for key in f.keys(): # noqa: SIM118 + if key.startswith("model.visual."): + vision_tensors[key] = f.get_tensor(key) + if vision_tensors: + save_file(vision_tensors, str(dst_dir / "model-vision.safetensors")) + + def _test_unified_export_megatron( - tmp_path, model_type, arch, extra_module, quant_config, kv_cache_quant_cfg, rank, size + tmp_path, model_type, extra_module, quant_config, kv_cache_quant_cfg, rank, size, + model_dir=None, ): - tokenizer = get_tiny_tokenizer() - tokenizer.save_pretrained(tmp_path) + if model_type == "qwen3vl": + config = transformers.AutoConfig.from_pretrained(model_dir) + text_cfg = config.text_config + num_layers = text_cfg.num_hidden_layers + hidden_size = text_cfg.hidden_size + num_attention_heads = text_cfg.num_attention_heads + num_query_groups = text_cfg.num_key_value_heads + ffn_hidden_size = text_cfg.intermediate_size + max_sequence_length = text_cfg.max_position_embeddings + vocab_size = text_cfg.vocab_size + extra_kwargs = dict(kv_channels=text_cfg.head_dim, qk_layernorm=True) + elif model_type == "llama": + config = transformers.AutoConfig.from_pretrained(model_dir) + num_layers = config.num_hidden_layers + hidden_size = config.hidden_size + num_attention_heads = config.num_attention_heads + num_query_groups = config.num_key_value_heads + ffn_hidden_size = config.intermediate_size + max_sequence_length = config.max_position_embeddings + vocab_size = config.vocab_size + extra_kwargs = {} + elif model_type == "nemotron": + config = transformers.AutoConfig.from_pretrained(model_dir) + num_layers = config.num_hidden_layers + hidden_size = config.hidden_size + num_attention_heads = config.num_attention_heads + num_query_groups = config.num_key_value_heads + ffn_hidden_size = config.intermediate_size + max_sequence_length = config.max_position_embeddings + vocab_size = config.vocab_size + extra_kwargs = {} + else: + raise ValueError(f"Unsupported model_type: {model_type}") - num_layers = 2 - hidden_size = 64 - num_attention_heads = 8 - num_query_groups = size - ffn_hidden_size = 128 - max_sequence_length = 32 - vocab_size = tokenizer.vocab_size - - arch = "NemotronForCausalLM" if model_type == "nemotron" else "LlamaForCausalLM" activation_func = "squared_relu" if model_type == "nemotron" else "swiglu" normalization = "LayerNorm" if model_type == "nemotron" else "RMSNorm" @@ -103,6 +148,7 @@ def _test_unified_export_megatron( activation_func=activation_func, normalization=normalization, transformer_impl="modelopt", + **extra_kwargs, ).cuda() if quant_config: @@ -127,101 +173,125 @@ def _test_unified_export_megatron( model = mtsp.convert(model, [("eagle", config)]) assert isinstance(model, _DynamicEagleGPTModel) - pretrained_config = { - "architectures": [arch], - "attention_bias": False, - "hidden_size": hidden_size, - "intermediate_size": ffn_hidden_size, - "max_position_embeddings": max_sequence_length, - "model_type": "llama", - "num_attention_heads": num_attention_heads, - "num_hidden_layers": num_layers, - "num_key_value_heads": num_query_groups, - "torch_dtype": "bfloat16", - } - - with open(tmp_path / "config.json", "w") as f: - json.dump(pretrained_config, f) + hf_config_dir = model_dir tmp_export_dir = tmp_path / "export" export_mcore_gpt_to_hf( model, - tmp_path if arch is not None else None, + hf_config_dir, dtype=torch.bfloat16, export_dir=str(tmp_export_dir), ) - if quant_config: + if quant_config and model_type != "qwen3vl": _verify_model_quant_config(tmp_export_dir, quant_config, kv_cache_quant_cfg) + if model_type == "qwen3vl" and rank == 0: + _merge_vision_weights(Path(model_dir), tmp_export_dir) + keys = [] + for sf in sorted(tmp_export_dir.glob("*.safetensors")): + with safe_open(str(sf), framework="pt", device="cpu") as f: + keys.extend(f.keys()) + assert any(k.startswith("model.language_model.") for k in keys), ( + "language model keys missing from combined export" + ) + assert any(k.startswith("model.visual.") for k in keys), ( + "vision encoder keys missing from combined export" + ) + vl_model = Qwen3VLForConditionalGeneration.from_pretrained( + tmp_export_dir, torch_dtype=torch.bfloat16 + ).cuda() + input_ids = torch.zeros(1, 4, dtype=torch.long).cuda() + with torch.no_grad(): + out = vl_model(input_ids=input_ids) + assert out.logits.shape[-1] == vl_model.config.text_config.vocab_size + @pytest.mark.parametrize( - ("model_type", "arch", "extra_module", "quant_config", "kv_cache_quant_cfg"), + ("model_type", "extra_module", "quant_config", "kv_cache_quant_cfg"), [ - ("nemotron", "NemotronForCausalLM", None, None, None), - ("nemotron", "NemotronForCausalLM", None, "NVFP4_DEFAULT_CFG", None), - ("nemotron", "NemotronForCausalLM", None, "NVFP4_DEFAULT_CFG", "FP8_KV_CFG"), - ("nemotron", "NemotronForCausalLM", "eagle", None, None), - ("nemotron", "NemotronForCausalLM", "medusa", None, None), - ("llama", "LlamaForCausalLM", None, None, None), - ("llama", "LlamaForCausalLM", None, "FP8_DEFAULT_CFG", None), - ("llama", "LlamaForCausalLM", None, "FP8_DEFAULT_CFG", "FP8_KV_CFG"), - ("llama", "LlamaForCausalLM", "eagle", None, None), - ("llama", "LlamaForCausalLM", "medusa", None, None), + ("nemotron", None, None, None), + ("nemotron", None, "NVFP4_DEFAULT_CFG", None), + ("nemotron", None, "NVFP4_DEFAULT_CFG", "FP8_KV_CFG"), + ("nemotron", "eagle", None, None), + ("nemotron", "medusa", None, None), + ("llama", None, None, None), + ("llama", None, "FP8_DEFAULT_CFG", None), + ("llama", None, "FP8_DEFAULT_CFG", "FP8_KV_CFG"), + ("llama", "eagle", None, None), + ("llama", "medusa", None, None), + ("qwen3vl", None, "FP8_DEFAULT_CFG", None), ], ) def test_unified_export_megatron( - dist_workers_size_1, tmp_path, model_type, arch, extra_module, quant_config, kv_cache_quant_cfg + dist_workers_size_1, tmp_path, model_type, extra_module, quant_config, kv_cache_quant_cfg ): + if model_type == "llama": + model_dir = create_tiny_llama_dir(tmp_path) + elif model_type == "qwen3vl": + model_dir = create_tiny_qwen3vl_dir(tmp_path) + elif model_type == "nemotron": + model_dir = create_tiny_nemotron_dir(tmp_path) + else: + raise ValueError(f"Unsupported model_type: {model_type}") # TODO: Fix TP>1 failures dist_workers_size_1.run( partial( _test_unified_export_megatron, tmp_path, model_type, - arch, extra_module, quant_config, kv_cache_quant_cfg, + model_dir=model_dir, ), ) -def _test_unified_import_megatron(tiny_llama_dir, rank, size): - config = transformers.AutoConfig.from_pretrained(tiny_llama_dir) +def _test_unified_import_megatron(model_dir, rank, size, model_type="llama"): + config = transformers.AutoConfig.from_pretrained(model_dir) - num_layers = config.num_hidden_layers - hidden_size = config.hidden_size - num_attention_heads = config.num_attention_heads - num_query_groups = config.num_key_value_heads - ffn_hidden_size = config.intermediate_size - max_sequence_length = config.max_position_embeddings - vocab_size = config.vocab_size - activation_func = "swiglu" - normalization = "RMSNorm" + if model_type == "qwen3vl": + cfg = config.text_config + extra_kwargs = dict(kv_channels=cfg.head_dim, transformer_impl="modelopt", qk_layernorm=True) + else: + cfg = config + extra_kwargs = {} model = get_mcore_gpt_model( tensor_model_parallel_size=size, pipeline_model_parallel_size=1, initialize_megatron=True, - num_layers=num_layers, - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - num_query_groups=num_query_groups, - ffn_hidden_size=ffn_hidden_size, - max_sequence_length=max_sequence_length, - vocab_size=vocab_size, - activation_func=activation_func, - normalization=normalization, + num_layers=cfg.num_hidden_layers, + hidden_size=cfg.hidden_size, + num_attention_heads=cfg.num_attention_heads, + num_query_groups=cfg.num_key_value_heads, + ffn_hidden_size=cfg.intermediate_size, + max_sequence_length=cfg.max_position_embeddings, + vocab_size=cfg.vocab_size, + activation_func="swiglu", + normalization="RMSNorm", + **extra_kwargs, ).cuda() - import_mcore_gpt_from_hf(model, tiny_llama_dir) + import_mcore_gpt_from_hf(model, model_dir) -def test_unified_import_megatron(dist_workers, tmp_path): +@pytest.mark.parametrize("model_type", ["llama", "qwen3vl"]) +def test_unified_import_megatron(dist_workers, tmp_path, model_type): num_gpus = torch.cuda.device_count() - tiny_llama_dir = create_tiny_llama_dir(tmp_path, num_key_value_heads=num_gpus) - dist_workers.run(partial(_test_unified_import_megatron, tiny_llama_dir)) + if model_type == "llama": + model_dir = create_tiny_llama_dir(tmp_path, num_key_value_heads=num_gpus) + elif model_type == "qwen3vl": + model_dir = create_tiny_qwen3vl_dir( + tmp_path, + num_attention_heads=num_gpus, + num_key_value_heads=num_gpus, + hidden_size=num_gpus * 8, # head_dim=8 + ) + else: + raise ValueError(f"Unsupported model_type: {model_type}") + dist_workers.run(partial(_test_unified_import_megatron, model_dir, model_type=model_type)) def _test_qkv_slicing_gqa_tp2(tmp_path, rank, size): @@ -372,3 +442,5 @@ def test_mtp_state_dict_index_file(tmp_path): assert "mtp.0.hnorm.weight" in mtp_state_dict assert torch.allclose(mtp_state_dict["mtp.0.hnorm.weight"], torch.full((32,), 3.0)) assert "mtp*" in exporter.exclude_modules + + From 77adc9d3e99fbe40f349b02293d2cee7c1c026f8 Mon Sep 17 00:00:00 2001 From: Hung-Yueh Chiang Date: Mon, 18 May 2026 22:24:55 +0000 Subject: [PATCH 07/14] Run _verify_model_quant_config for qwen3vl export MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The exception for qwen3vl was conservative — export_mcore_gpt_to_hf writes config.json and hf_quant_config.json for any quantized model regardless of architecture, so the verification applies equally. Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Hung-Yueh Chiang --- .../gpu_megatron/torch/export/test_unified_export_megatron.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/gpu_megatron/torch/export/test_unified_export_megatron.py b/tests/gpu_megatron/torch/export/test_unified_export_megatron.py index b5b182113ac..7f470e890c9 100644 --- a/tests/gpu_megatron/torch/export/test_unified_export_megatron.py +++ b/tests/gpu_megatron/torch/export/test_unified_export_megatron.py @@ -183,11 +183,12 @@ def _test_unified_export_megatron( export_dir=str(tmp_export_dir), ) - if quant_config and model_type != "qwen3vl": + if quant_config: _verify_model_quant_config(tmp_export_dir, quant_config, kv_cache_quant_cfg) if model_type == "qwen3vl" and rank == 0: _merge_vision_weights(Path(model_dir), tmp_export_dir) + # sanity check that the vision encoder weights were merged keys = [] for sf in sorted(tmp_export_dir.glob("*.safetensors")): with safe_open(str(sf), framework="pt", device="cpu") as f: @@ -198,6 +199,7 @@ def _test_unified_export_megatron( assert any(k.startswith("model.visual.") for k in keys), ( "vision encoder keys missing from combined export" ) + # try to load the model and run a forward pass vl_model = Qwen3VLForConditionalGeneration.from_pretrained( tmp_export_dir, torch_dtype=torch.bfloat16 ).cuda() From 3637fe7d4d71a29a2112deef9f403ea4d036c2c7 Mon Sep 17 00:00:00 2001 From: Hung-Yueh Chiang Date: Mon, 18 May 2026 22:45:24 +0000 Subject: [PATCH 08/14] Fix ruff lint errors in test_unified_export_megatron.py - C408: Replace dict() calls with dict literals - PLR1714: Merge llama/nemotron branches using `model_type in {"llama", "nemotron"}` - Format: apply ruff formatting (import order, function signature, trailing newlines) Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Hung-Yueh Chiang --- .../export/test_unified_export_megatron.py | 32 +++++++++---------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/tests/gpu_megatron/torch/export/test_unified_export_megatron.py b/tests/gpu_megatron/torch/export/test_unified_export_megatron.py index 7f470e890c9..fd1477aa147 100644 --- a/tests/gpu_megatron/torch/export/test_unified_export_megatron.py +++ b/tests/gpu_megatron/torch/export/test_unified_export_megatron.py @@ -28,9 +28,9 @@ create_tiny_nemotron_dir, create_tiny_qwen3vl_dir, ) -from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration from safetensors import safe_open from safetensors.torch import save_file +from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration import modelopt.torch.quantization as mtq import modelopt.torch.speculative as mtsp @@ -94,7 +94,13 @@ def _merge_vision_weights(src_dir: Path, dst_dir: Path) -> None: def _test_unified_export_megatron( - tmp_path, model_type, extra_module, quant_config, kv_cache_quant_cfg, rank, size, + tmp_path, + model_type, + extra_module, + quant_config, + kv_cache_quant_cfg, + rank, + size, model_dir=None, ): if model_type == "qwen3vl": @@ -107,18 +113,8 @@ def _test_unified_export_megatron( ffn_hidden_size = text_cfg.intermediate_size max_sequence_length = text_cfg.max_position_embeddings vocab_size = text_cfg.vocab_size - extra_kwargs = dict(kv_channels=text_cfg.head_dim, qk_layernorm=True) - elif model_type == "llama": - config = transformers.AutoConfig.from_pretrained(model_dir) - num_layers = config.num_hidden_layers - hidden_size = config.hidden_size - num_attention_heads = config.num_attention_heads - num_query_groups = config.num_key_value_heads - ffn_hidden_size = config.intermediate_size - max_sequence_length = config.max_position_embeddings - vocab_size = config.vocab_size - extra_kwargs = {} - elif model_type == "nemotron": + extra_kwargs = {"kv_channels": text_cfg.head_dim, "qk_layernorm": True} + elif model_type in {"llama", "nemotron"}: config = transformers.AutoConfig.from_pretrained(model_dir) num_layers = config.num_hidden_layers hidden_size = config.hidden_size @@ -255,7 +251,11 @@ def _test_unified_import_megatron(model_dir, rank, size, model_type="llama"): if model_type == "qwen3vl": cfg = config.text_config - extra_kwargs = dict(kv_channels=cfg.head_dim, transformer_impl="modelopt", qk_layernorm=True) + extra_kwargs = { + "kv_channels": cfg.head_dim, + "transformer_impl": "modelopt", + "qk_layernorm": True, + } else: cfg = config extra_kwargs = {} @@ -444,5 +444,3 @@ def test_mtp_state_dict_index_file(tmp_path): assert "mtp.0.hnorm.weight" in mtp_state_dict assert torch.allclose(mtp_state_dict["mtp.0.hnorm.weight"], torch.full((32,), 3.0)) assert "mtp*" in exporter.exclude_modules - - From 57a4608d196a79e821cfea3b847bbd7233004a35 Mon Sep 17 00:00:00 2001 From: Hung-Yueh Chiang Date: Mon, 18 May 2026 22:49:11 +0000 Subject: [PATCH 09/14] Address PR review findings for Qwen3-VL mcore mapping IMPORTANT fixes: - Make Qwen3VLConfig / Qwen3VLForConditionalGeneration lazy imports inside get_tiny_qwen3vl() and _test_unified_export_megatron() so older transformers builds do not break collection of every test importing transformers_models.py - Drop unsupported MoE claim from CHANGELOG; only Qwen3VLForConditionalGeneration (dense) is registered in the mcore dispatch tables Suggestions: - Copy func_kwargs in _with_language_model_prefix to avoid shared-dict mutation between Qwen3 and Qwen3VL mapping entries - Add unquantized qwen3vl export test case alongside the FP8 one - Add torch.distributed.barrier() before the rank-0 vision-weight merge block to keep the code safe if TP > 1 is used later Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Hung-Yueh Chiang --- CHANGELOG.rst | 2 +- modelopt/torch/export/plugins/mcore_qwen3vl.py | 2 +- tests/_test_utils/torch/transformers_models.py | 8 ++++++-- .../torch/export/test_unified_export_megatron.py | 6 +++++- 4 files changed, 13 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 982a7da1194..faf5e82c15c 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -25,7 +25,7 @@ Changelog - Add ``--cast_mxfp4_to_nvfp4`` flag to ``examples/llm_ptq/hf_ptq.py`` for closed-form, bit-exact MXFP4 → NVFP4 weight conversion. Supports the GPT-OSS family (``openai/gpt-oss-20b``, ``openai/gpt-oss-120b``). See `examples/llm_ptq/README.md `__ for usage. - DeepSeek PTQ (``examples/deepseek/ptq.py``) now defaults to native top-k calibration with post-hoc per-layer peer-max sync of expert ``input_quantizer.amax``; the all-experts path is preserved behind ``--calib_all_experts``. - Add NVFP4 W4A16 weight-only quantization (``w4a16_nvfp4``): FP4 weights with group_size=16, BF16 activations, no calibration forward pass required. Use ``mtq.W4A16_NVFP4_CFG`` or ``--qformat w4a16_nvfp4`` in ``hf_ptq.py``. vLLM deployment support is in progress. -- Add Megatron Core export/import mapping for Qwen3-VL (``Qwen3VLForConditionalGeneration``) vision-language models. The mapping handles the ``model.language_model.`` weight prefix used by Qwen3-VL and supports both dense and MoE variants. +- Add Megatron Core export/import mapping for Qwen3-VL (``Qwen3VLForConditionalGeneration``) vision-language models. The mapping handles the ``model.language_model.`` weight prefix used by Qwen3-VL. 0.44 (2026-05-18) ^^^^^^^^^^^^^^^^^ diff --git a/modelopt/torch/export/plugins/mcore_qwen3vl.py b/modelopt/torch/export/plugins/mcore_qwen3vl.py index 2f35b1291e0..16edba27e0e 100644 --- a/modelopt/torch/export/plugins/mcore_qwen3vl.py +++ b/modelopt/torch/export/plugins/mcore_qwen3vl.py @@ -45,7 +45,7 @@ def _with_language_model_prefix( prefix = m.target_name_or_prefix if prefix.startswith("model."): prefix = "model.language_model." + prefix[len("model.") :] - result[key] = type(m)(target_name_or_prefix=prefix, func_kwargs=m.func_kwargs) + result[key] = type(m)(target_name_or_prefix=prefix, func_kwargs=dict(m.func_kwargs)) return result diff --git a/tests/_test_utils/torch/transformers_models.py b/tests/_test_utils/torch/transformers_models.py index 708162e55dd..fdd492013c1 100644 --- a/tests/_test_utils/torch/transformers_models.py +++ b/tests/_test_utils/torch/transformers_models.py @@ -32,11 +32,9 @@ PreTrainedModel, Qwen3Config, Qwen3MoeConfig, - Qwen3VLConfig, T5Config, T5ForConditionalGeneration, ) -from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration import modelopt.torch.opt as mto @@ -125,6 +123,12 @@ def create_tiny_qwen3_moe_dir( ##### Qwen3-VL ##### def get_tiny_qwen3vl(**config_kwargs) -> PreTrainedModel: + # Lazy imports — Qwen3VL classes live under transformers.models.qwen3_vl which + # may not exist in older transformers builds, and this module is imported by + # every test that uses transformers_models.py. + from transformers import Qwen3VLConfig + from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration + set_seed(SEED) # Defaults: hidden_size=num_attention_heads*head_dim (e.g. 4*8=32). diff --git a/tests/gpu_megatron/torch/export/test_unified_export_megatron.py b/tests/gpu_megatron/torch/export/test_unified_export_megatron.py index fd1477aa147..12670d5c08b 100644 --- a/tests/gpu_megatron/torch/export/test_unified_export_megatron.py +++ b/tests/gpu_megatron/torch/export/test_unified_export_megatron.py @@ -30,7 +30,6 @@ ) from safetensors import safe_open from safetensors.torch import save_file -from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration import modelopt.torch.quantization as mtq import modelopt.torch.speculative as mtsp @@ -182,6 +181,8 @@ def _test_unified_export_megatron( if quant_config: _verify_model_quant_config(tmp_export_dir, quant_config, kv_cache_quant_cfg) + if model_type == "qwen3vl": + torch.distributed.barrier() if model_type == "qwen3vl" and rank == 0: _merge_vision_weights(Path(model_dir), tmp_export_dir) # sanity check that the vision encoder weights were merged @@ -196,6 +197,8 @@ def _test_unified_export_megatron( "vision encoder keys missing from combined export" ) # try to load the model and run a forward pass + from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration + vl_model = Qwen3VLForConditionalGeneration.from_pretrained( tmp_export_dir, torch_dtype=torch.bfloat16 ).cuda() @@ -218,6 +221,7 @@ def _test_unified_export_megatron( ("llama", None, "FP8_DEFAULT_CFG", "FP8_KV_CFG"), ("llama", "eagle", None, None), ("llama", "medusa", None, None), + ("qwen3vl", None, None, None), ("qwen3vl", None, "FP8_DEFAULT_CFG", None), ], ) From e8e2d7b63646f80dfbccec8c30635af6fec731ea Mon Sep 17 00:00:00 2001 From: Hung-Yueh Chiang Date: Mon, 18 May 2026 22:58:05 +0000 Subject: [PATCH 10/14] Address second-round PR review suggestions - Add docstring note that Qwen3VLMoeForConditionalGeneration is unsupported (MoE variant uses 3-D fused expert tensors incompatible with dense mapping) - Replace dict(m.func_kwargs) with copy.deepcopy for full isolation of copied mapping entries - Fix asymmetric barrier: fold rank==0 check under the barrier so all ranks participate before any rank writes vision weights Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Hung-Yueh Chiang --- .../torch/export/plugins/mcore_qwen3vl.py | 11 ++++- .../export/test_unified_export_megatron.py | 48 ++++++++++--------- 2 files changed, 35 insertions(+), 24 deletions(-) diff --git a/modelopt/torch/export/plugins/mcore_qwen3vl.py b/modelopt/torch/export/plugins/mcore_qwen3vl.py index 16edba27e0e..1f2d3830d61 100644 --- a/modelopt/torch/export/plugins/mcore_qwen3vl.py +++ b/modelopt/torch/export/plugins/mcore_qwen3vl.py @@ -24,9 +24,16 @@ Note: the visual encoder (``model.visual.*``) is intentionally excluded — this mapping covers only the language-model decoder used for quantization and export. +Note: ``Qwen3VLMoeForConditionalGeneration`` is **not** supported here. The MoE +variant stores expert weights as 3-D tensors (``mlp.experts.gate_up_proj``, +``mlp.experts.down_proj``) that require a dedicated fused-expert mapping and +cannot reuse the dense Qwen3 rules. + Reference: https://huggingface.co/Qwen/Qwen3-VL-8B-Instruct/blob/main/model.safetensors.index.json """ +import copy + from .mcore_custom import CustomModuleMapping from .mcore_qwen import qwen3_causal_lm_export, qwen3_causal_lm_import @@ -45,7 +52,9 @@ def _with_language_model_prefix( prefix = m.target_name_or_prefix if prefix.startswith("model."): prefix = "model.language_model." + prefix[len("model.") :] - result[key] = type(m)(target_name_or_prefix=prefix, func_kwargs=dict(m.func_kwargs)) + result[key] = type(m)( + target_name_or_prefix=prefix, func_kwargs=copy.deepcopy(m.func_kwargs) + ) return result diff --git a/tests/gpu_megatron/torch/export/test_unified_export_megatron.py b/tests/gpu_megatron/torch/export/test_unified_export_megatron.py index 12670d5c08b..f7aad042c32 100644 --- a/tests/gpu_megatron/torch/export/test_unified_export_megatron.py +++ b/tests/gpu_megatron/torch/export/test_unified_export_megatron.py @@ -183,29 +183,31 @@ def _test_unified_export_megatron( if model_type == "qwen3vl": torch.distributed.barrier() - if model_type == "qwen3vl" and rank == 0: - _merge_vision_weights(Path(model_dir), tmp_export_dir) - # sanity check that the vision encoder weights were merged - keys = [] - for sf in sorted(tmp_export_dir.glob("*.safetensors")): - with safe_open(str(sf), framework="pt", device="cpu") as f: - keys.extend(f.keys()) - assert any(k.startswith("model.language_model.") for k in keys), ( - "language model keys missing from combined export" - ) - assert any(k.startswith("model.visual.") for k in keys), ( - "vision encoder keys missing from combined export" - ) - # try to load the model and run a forward pass - from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration - - vl_model = Qwen3VLForConditionalGeneration.from_pretrained( - tmp_export_dir, torch_dtype=torch.bfloat16 - ).cuda() - input_ids = torch.zeros(1, 4, dtype=torch.long).cuda() - with torch.no_grad(): - out = vl_model(input_ids=input_ids) - assert out.logits.shape[-1] == vl_model.config.text_config.vocab_size + if rank == 0: + _merge_vision_weights(Path(model_dir), tmp_export_dir) + # sanity check that the vision encoder weights were merged + keys = [] + for sf in sorted(tmp_export_dir.glob("*.safetensors")): + with safe_open(str(sf), framework="pt", device="cpu") as f: + keys.extend(f.keys()) + assert any(k.startswith("model.language_model.") for k in keys), ( + "language model keys missing from combined export" + ) + assert any(k.startswith("model.visual.") for k in keys), ( + "vision encoder keys missing from combined export" + ) + # try to load the model and run a forward pass + from transformers.models.qwen3_vl.modeling_qwen3_vl import ( + Qwen3VLForConditionalGeneration, + ) + + vl_model = Qwen3VLForConditionalGeneration.from_pretrained( + tmp_export_dir, torch_dtype=torch.bfloat16 + ).cuda() + input_ids = torch.zeros(1, 4, dtype=torch.long).cuda() + with torch.no_grad(): + out = vl_model(input_ids=input_ids) + assert out.logits.shape[-1] == vl_model.config.text_config.vocab_size @pytest.mark.parametrize( From 63a229a9191aaefe9dd4acc49745f9dbeeb85483 Mon Sep 17 00:00:00 2001 From: Hung-Yueh Chiang Date: Tue, 19 May 2026 16:09:35 +0000 Subject: [PATCH 11/14] Move Qwen3VL imports back to module top level Per reviewer feedback: the imports are now protected by the module-level pytest.importorskip("transformers") guard (transformers_models.py) and by the gpu_megatron test environment assumption (test file), so top-level imports are the right style rather than lazy per-call imports. Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Hung-Yueh Chiang --- tests/_test_utils/torch/transformers_models.py | 8 ++------ .../torch/export/test_unified_export_megatron.py | 5 +---- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/tests/_test_utils/torch/transformers_models.py b/tests/_test_utils/torch/transformers_models.py index fdd492013c1..708162e55dd 100644 --- a/tests/_test_utils/torch/transformers_models.py +++ b/tests/_test_utils/torch/transformers_models.py @@ -32,9 +32,11 @@ PreTrainedModel, Qwen3Config, Qwen3MoeConfig, + Qwen3VLConfig, T5Config, T5ForConditionalGeneration, ) +from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration import modelopt.torch.opt as mto @@ -123,12 +125,6 @@ def create_tiny_qwen3_moe_dir( ##### Qwen3-VL ##### def get_tiny_qwen3vl(**config_kwargs) -> PreTrainedModel: - # Lazy imports — Qwen3VL classes live under transformers.models.qwen3_vl which - # may not exist in older transformers builds, and this module is imported by - # every test that uses transformers_models.py. - from transformers import Qwen3VLConfig - from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration - set_seed(SEED) # Defaults: hidden_size=num_attention_heads*head_dim (e.g. 4*8=32). diff --git a/tests/gpu_megatron/torch/export/test_unified_export_megatron.py b/tests/gpu_megatron/torch/export/test_unified_export_megatron.py index f7aad042c32..4c899349dab 100644 --- a/tests/gpu_megatron/torch/export/test_unified_export_megatron.py +++ b/tests/gpu_megatron/torch/export/test_unified_export_megatron.py @@ -30,6 +30,7 @@ ) from safetensors import safe_open from safetensors.torch import save_file +from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration import modelopt.torch.quantization as mtq import modelopt.torch.speculative as mtsp @@ -197,10 +198,6 @@ def _test_unified_export_megatron( "vision encoder keys missing from combined export" ) # try to load the model and run a forward pass - from transformers.models.qwen3_vl.modeling_qwen3_vl import ( - Qwen3VLForConditionalGeneration, - ) - vl_model = Qwen3VLForConditionalGeneration.from_pretrained( tmp_export_dir, torch_dtype=torch.bfloat16 ).cuda() From 3f0b9210df32cde593809a60281580ee3f7f6a3c Mon Sep 17 00:00:00 2001 From: Hung-Yueh Chiang Date: Wed, 20 May 2026 06:18:39 +0000 Subject: [PATCH 12/14] Guard Qwen3VL imports with try/except in transformers_models.py Top-level import of transformers.models.qwen3_vl fails on older transformers builds that don't have the submodule, breaking collection of every test that imports transformers_models.py. Use try/except so the module loads cleanly and get_tiny_qwen3vl() calls pytest.skip() when the classes are unavailable. Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Hung-Yueh Chiang --- tests/_test_utils/torch/transformers_models.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/_test_utils/torch/transformers_models.py b/tests/_test_utils/torch/transformers_models.py index 708162e55dd..89fdb448c72 100644 --- a/tests/_test_utils/torch/transformers_models.py +++ b/tests/_test_utils/torch/transformers_models.py @@ -32,11 +32,16 @@ PreTrainedModel, Qwen3Config, Qwen3MoeConfig, - Qwen3VLConfig, T5Config, T5ForConditionalGeneration, ) -from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration + +try: + from transformers import Qwen3VLConfig + from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration +except ImportError: + Qwen3VLConfig = None # type: ignore[assignment,misc] + Qwen3VLForConditionalGeneration = None # type: ignore[assignment,misc] import modelopt.torch.opt as mto @@ -125,6 +130,8 @@ def create_tiny_qwen3_moe_dir( ##### Qwen3-VL ##### def get_tiny_qwen3vl(**config_kwargs) -> PreTrainedModel: + if Qwen3VLConfig is None: + pytest.skip("transformers does not have Qwen3VL support") set_seed(SEED) # Defaults: hidden_size=num_attention_heads*head_dim (e.g. 4*8=32). From 826667061b67178636b3ec50176af78e94404511 Mon Sep 17 00:00:00 2001 From: Hung-Yueh Chiang Date: Wed, 20 May 2026 06:47:29 +0000 Subject: [PATCH 13/14] Revert "Guard Qwen3VL imports with try/except in transformers_models.py" This reverts commit 10038f0aafa504831f269a9f5b9f68a5d30603c5. Signed-off-by: Hung-Yueh Chiang --- tests/_test_utils/torch/transformers_models.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/tests/_test_utils/torch/transformers_models.py b/tests/_test_utils/torch/transformers_models.py index 89fdb448c72..708162e55dd 100644 --- a/tests/_test_utils/torch/transformers_models.py +++ b/tests/_test_utils/torch/transformers_models.py @@ -32,16 +32,11 @@ PreTrainedModel, Qwen3Config, Qwen3MoeConfig, + Qwen3VLConfig, T5Config, T5ForConditionalGeneration, ) - -try: - from transformers import Qwen3VLConfig - from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration -except ImportError: - Qwen3VLConfig = None # type: ignore[assignment,misc] - Qwen3VLForConditionalGeneration = None # type: ignore[assignment,misc] +from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration import modelopt.torch.opt as mto @@ -130,8 +125,6 @@ def create_tiny_qwen3_moe_dir( ##### Qwen3-VL ##### def get_tiny_qwen3vl(**config_kwargs) -> PreTrainedModel: - if Qwen3VLConfig is None: - pytest.skip("transformers does not have Qwen3VL support") set_seed(SEED) # Defaults: hidden_size=num_attention_heads*head_dim (e.g. 4*8=32). From f56e4c2d3ce2b84eefabd76d81c9470ea4d8bbd8 Mon Sep 17 00:00:00 2001 From: Hung-Yueh Chiang Date: Wed, 20 May 2026 06:49:15 +0000 Subject: [PATCH 14/14] Revert "Move Qwen3VL imports back to module top level" This reverts commit 63a229a9191aaefe9dd4acc49745f9dbeeb85483. Signed-off-by: Hung-Yueh Chiang --- tests/_test_utils/torch/transformers_models.py | 8 ++++++-- .../torch/export/test_unified_export_megatron.py | 5 ++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/_test_utils/torch/transformers_models.py b/tests/_test_utils/torch/transformers_models.py index 708162e55dd..fdd492013c1 100644 --- a/tests/_test_utils/torch/transformers_models.py +++ b/tests/_test_utils/torch/transformers_models.py @@ -32,11 +32,9 @@ PreTrainedModel, Qwen3Config, Qwen3MoeConfig, - Qwen3VLConfig, T5Config, T5ForConditionalGeneration, ) -from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration import modelopt.torch.opt as mto @@ -125,6 +123,12 @@ def create_tiny_qwen3_moe_dir( ##### Qwen3-VL ##### def get_tiny_qwen3vl(**config_kwargs) -> PreTrainedModel: + # Lazy imports — Qwen3VL classes live under transformers.models.qwen3_vl which + # may not exist in older transformers builds, and this module is imported by + # every test that uses transformers_models.py. + from transformers import Qwen3VLConfig + from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration + set_seed(SEED) # Defaults: hidden_size=num_attention_heads*head_dim (e.g. 4*8=32). diff --git a/tests/gpu_megatron/torch/export/test_unified_export_megatron.py b/tests/gpu_megatron/torch/export/test_unified_export_megatron.py index 4c899349dab..f7aad042c32 100644 --- a/tests/gpu_megatron/torch/export/test_unified_export_megatron.py +++ b/tests/gpu_megatron/torch/export/test_unified_export_megatron.py @@ -30,7 +30,6 @@ ) from safetensors import safe_open from safetensors.torch import save_file -from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration import modelopt.torch.quantization as mtq import modelopt.torch.speculative as mtsp @@ -198,6 +197,10 @@ def _test_unified_export_megatron( "vision encoder keys missing from combined export" ) # try to load the model and run a forward pass + from transformers.models.qwen3_vl.modeling_qwen3_vl import ( + Qwen3VLForConditionalGeneration, + ) + vl_model = Qwen3VLForConditionalGeneration.from_pretrained( tmp_export_dir, torch_dtype=torch.bfloat16 ).cuda()