Open
Commits
25 commits
a7d1170
[Megatron Export] Add Qwen3-VL export/import mapping
hychiang-git Feb 14, 2026
36da6de
fix: ruff formatting and PT006 parametrize tuple fix
hychiang-git May 13, 2026
ff1152f
Merge branch 'main' into hungyueh/pr-895
hychiang-git May 13, 2026
e8101a7
fix: apply ruff formatting to mcore_qwen3vl plugin and test files
hychiang-git May 13, 2026
aecbbfa
fix: collapse single-item imports in test_mcore_qwen3vl per ruff
hychiang-git May 14, 2026
80495e6
refactor: derive Qwen3-VL mcore mapping from Qwen3 via prefix rewrite
hychiang-git May 14, 2026
5bf943b
Merge branch 'main' into hungyueh/pr-895
hychiang-git May 14, 2026
425145c
Merge branch 'main' into hungyueh/pr-895
hychiang-git May 15, 2026
d6f03cd
Merge branch 'main' into hungyueh/pr-895
hychiang-git May 18, 2026
6ad8d0e
Integrate Qwen3-VL mcore weight mapping tests into unified export tes…
hychiang-git May 18, 2026
77adc9d
Run _verify_model_quant_config for qwen3vl export
hychiang-git May 18, 2026
5cdb6b4
Merge branch 'main' into hungyueh/pr-895
hychiang-git May 18, 2026
3637fe7
Fix ruff lint errors in test_unified_export_megatron.py
hychiang-git May 18, 2026
57a4608
Address PR review findings for Qwen3-VL mcore mapping
hychiang-git May 18, 2026
e8e2d7b
Address second-round PR review suggestions
hychiang-git May 18, 2026
1a86b05
Merge branch 'main' into hungyueh/pr-895
hychiang-git May 19, 2026
63a229a
Move Qwen3VL imports back to module top level
hychiang-git May 19, 2026
73d74b3
Merge branch 'hungyueh/pr-895' of github.com:NVIDIA/Model-Optimizer i…
hychiang-git May 19, 2026
4dbffb2
Merge branch 'main' into hungyueh/pr-895
hychiang-git May 19, 2026
cf0fb9f
Merge branch 'main' into hungyueh/pr-895
hychiang-git May 19, 2026
1243b42
Merge branch 'main' into hungyueh/pr-895
hychiang-git May 20, 2026
3f0b921
Guard Qwen3VL imports with try/except in transformers_models.py
hychiang-git May 20, 2026
8266670
Revert "Guard Qwen3VL imports with try/except in transformers_models.py"
hychiang-git May 20, 2026
f56e4c2
Revert "Move Qwen3VL imports back to module top level"
hychiang-git May 20, 2026
74019e3
Merge branch 'main' into hungyueh/pr-895
hychiang-git May 20, 2026
1 change: 1 addition & 0 deletions CHANGELOG.rst
@@ -25,6 +25,7 @@ Changelog
- Add ``--cast_mxfp4_to_nvfp4`` flag to ``examples/llm_ptq/hf_ptq.py`` for closed-form, bit-exact MXFP4 → NVFP4 weight conversion. Supports the GPT-OSS family (``openai/gpt-oss-20b``, ``openai/gpt-oss-120b``). See `examples/llm_ptq/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/llm_ptq#mxfp4--nvfp4-cast-for-gpt-oss>`__ for usage.
- DeepSeek PTQ (``examples/deepseek/ptq.py``) now defaults to native top-k calibration with post-hoc per-layer peer-max sync of expert ``input_quantizer.amax``; the all-experts path is preserved behind ``--calib_all_experts``.
- Add NVFP4 W4A16 weight-only quantization (``w4a16_nvfp4``): FP4 weights with group_size=16, BF16 activations, no calibration forward pass required. Use ``mtq.W4A16_NVFP4_CFG`` or ``--qformat w4a16_nvfp4`` in ``hf_ptq.py``. vLLM deployment support is in progress.
- Add Megatron Core export/import mapping for Qwen3-VL (``Qwen3VLForConditionalGeneration``) vision-language models. The mapping handles the ``model.language_model.`` weight prefix used by Qwen3-VL.
- Add ``DATASET_COMBOS`` to ``modelopt.torch.utils.dataset_utils`` — single ``--dataset`` tokens that fan out to multiple registered datasets; per-entry ``num_samples`` is split evenly across the members. Initial combos: ``cnn_nemotron_v2_mix`` (``cnn_dailymail`` + ``nemotron-post-training-dataset-v2``, used by ``hf_ptq.py`` when no ``--dataset`` is provided) and ``nemotron-post-training-v3`` (the seven ``nvidia/Nemotron-*`` SFT datasets added in #1498, mirroring the `nemotron-post-training-v3 collection <https://huggingface.co/collections/nvidia/nemotron-post-training-v3>`_). Combo names are listed by ``get_supported_datasets()`` and surfaced in ``--dataset`` help. ``get_dataset_dataloader`` rejects inputs that mix a combo with one of its member datasets (e.g. ``cnn_dailymail,cnn_nemotron_v2_mix``) to avoid double-sampling, and ``get_dataset_samples`` rejects combo names so callers route through the dataloader. ``hf_ptq.py`` default ``--calib_size`` is bumped from ``512`` to ``1024`` so the total calibration sample count under the new default combo matches the previous two-dataset fallback.
- The ``nemotron-sft-agentic-v2`` registered dataset (added in #1498) now uses only the ``search`` split. The previously configured ``interactive_agent`` and ``tool_calling`` splits contain content-level defects (heterogeneous schema and a malformed JSON row, respectively) that cause pyarrow's streaming JSON reader to fail deterministically.

1 change: 1 addition & 0 deletions docs/source/deployment/3_unified_hf.rst
@@ -61,6 +61,7 @@ Models:
* Llama 4, 3.x (FP8, NVFP4)
* Qwen 3, 2.5 (FP8, NVFP4)
* Qwen 3 MoE (FP8, NVFP4)
* Qwen 3-VL (FP8, NVFP4)
* Deepseek R1/V3 (NVFP4)
* Mixtral 8x7B (FP8, NVFP4)
* Medusa (FP8)
3 changes: 3 additions & 0 deletions modelopt/torch/export/plugins/mcore_common.py
@@ -39,6 +39,7 @@
    qwen25_causal_lm_export,
    qwen25_causal_lm_import,
)
from .mcore_qwen3vl import qwen3vl_causal_lm_export, qwen3vl_causal_lm_import

all_mcore_hf_export_mapping: dict[str, Any] = {
    "DeepseekV2ForCausalLM": deepseek_causal_lm_export,
@@ -54,6 +55,7 @@
    "Qwen3MoeForCausalLM": qwen3_causal_lm_export,
    "Qwen2ForCausalLM": qwen25_causal_lm_export,
    "GptOssForCausalLM": gptoss_causal_lm_export,
    "Qwen3VLForConditionalGeneration": qwen3vl_causal_lm_export,
}

all_mcore_hf_import_mapping: dict[str, Any] = {
@@ -66,4 +68,5 @@
    "Qwen3MoeForCausalLM": qwen3_causal_lm_import,
    "Qwen2ForCausalLM": qwen25_causal_lm_import,
    "GptOssForCausalLM": gptoss_causal_lm_import,
    "Qwen3VLForConditionalGeneration": qwen3vl_causal_lm_import,
}
62 changes: 62 additions & 0 deletions modelopt/torch/export/plugins/mcore_qwen3vl.py
@@ -0,0 +1,62 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Custom mapping from Qwen3-VL Hugging Face models to Megatron Core models.

Qwen3-VL differs from Qwen3 in one structural way: language-model weights live
under ``model.language_model.`` instead of ``model.``, while ``lm_head.weight``
remains at the root level. The mappings below are derived automatically from
the Qwen3 mappings by inserting ``language_model.`` after ``model.`` for every
prefix that starts with ``model.``.

Note: the visual encoder (``model.visual.*``) is intentionally excluded — this
mapping covers only the language-model decoder used for quantization and export.

Note: ``Qwen3VLMoeForConditionalGeneration`` is **not** supported here. The MoE
variant stores expert weights as 3-D tensors (``mlp.experts.gate_up_proj``,
``mlp.experts.down_proj``) that require a dedicated fused-expert mapping and
cannot reuse the dense Qwen3 rules.

Reference: https://huggingface.co/Qwen/Qwen3-VL-8B-Instruct/blob/main/model.safetensors.index.json
"""
Comment on lines +16 to +33

[SUGGESTION] Scope-clarifying note: Qwen3-VL ships in two architectures — Qwen3VLForConditionalGeneration (dense) and Qwen3VLMoeForConditionalGeneration (MoE, e.g. Qwen/Qwen3-VL-30B-A3B-Instruct). This PR only registers the dense variant.

The MoE variant cannot reuse qwen3_causal_lm_export with a prefix rewrite because Qwen3-VL-MoE stores experts in fused form (mlp.experts.gate_up_proj/mlp.experts.down_proj as 3-D tensors) rather than the per-expert layout (mlp.experts.{}.down_proj) that qwen3_causal_lm_* assumes. So this is a real limitation, not just a missing registration.

Consider adding a one-line note to the module docstring (e.g. "Covers the dense Qwen3VL variant only; Qwen3VLMoeForConditionalGeneration uses a fused-expert layout and requires a separate mapping.") so that users hitting KeyError: 'Qwen3VLMoeForConditionalGeneration' in _populate_rule_book know it's intentional and what's missing.
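One way to make the limitation visible at lookup time, rather than only in the docstring, is sketched below. This is a minimal illustration and not the repository's code: `EXPORT_MAPPINGS`, `KNOWN_UNSUPPORTED`, and `lookup_export_mapping` are hypothetical names, and the registry values are empty placeholders standing in for the real mapping dicts.

```python
# Hypothetical stand-in for the registry in mcore_common.py. The real values
# are dicts of CustomModuleMapping objects; empty dicts are used here.
EXPORT_MAPPINGS: dict[str, dict] = {
    "Qwen3ForCausalLM": {},
    "Qwen3VLForConditionalGeneration": {},
}

# Hypothetical table of architectures that are deliberately out of scope.
KNOWN_UNSUPPORTED: dict[str, str] = {
    "Qwen3VLMoeForConditionalGeneration": (
        "uses a fused-expert layout (3-D mlp.experts.gate_up_proj / "
        "mlp.experts.down_proj) and requires a separate mapping"
    ),
}


def lookup_export_mapping(arch: str) -> dict:
    """Return the mapping for ``arch`` or raise a descriptive KeyError."""
    try:
        return EXPORT_MAPPINGS[arch]
    except KeyError:
        # Prefer the documented reason; fall back to a generic message.
        reason = KNOWN_UNSUPPORTED.get(arch, "no mapping is registered")
        raise KeyError(f"{arch}: {reason}") from None
```

With a helper like this, a user loading the MoE checkpoint would see the reason in the exception message instead of a bare architecture name.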


import copy

from .mcore_custom import CustomModuleMapping
from .mcore_qwen import qwen3_causal_lm_export, qwen3_causal_lm_import


def _with_language_model_prefix(
    mapping: dict[str, CustomModuleMapping],
) -> dict[str, CustomModuleMapping]:
    """Derive a VL mapping from a base Qwen3 mapping.

    Rewrites every ``target_name_or_prefix`` that starts with ``model.`` to
    ``model.language_model.<rest>``. Prefixes that do not start with
    ``model.`` (e.g. ``lm_head.``) are left unchanged.
    """
    result = {}
    for key, m in mapping.items():
        prefix = m.target_name_or_prefix
        if prefix.startswith("model."):
            prefix = "model.language_model." + prefix[len("model.") :]
        result[key] = type(m)(
            target_name_or_prefix=prefix, func_kwargs=copy.deepcopy(m.func_kwargs)
        )
    return result
Comment on lines +50 to +58

[SUGGESTION] m.func_kwargs is reused by reference (not copied). Because the parent Qwen3 mappings use module-level constants like COL_TP, ROW_TP, REPLICATE as func_kwargs, the resulting qwen3vl_* mapping entries hold the same parallel_config dict objects as the corresponding qwen3_* entries — and as each other when the source uses the same constant.

It's harmless today (these dicts are treated as immutable in the rest of the codebase), but a future caller that mutates mapping[k].func_kwargs to e.g. tweak a ParallelConfig for one entry would silently propagate the change to the Qwen3 base mapping and to any other entry sharing that constant. A shallow copy.copy(m.func_kwargs) (or dict(m.func_kwargs)) when constructing the rebuilt mapping would prevent the foot-gun:

result[key] = type(m)(target_name_or_prefix=prefix, func_kwargs=dict(m.func_kwargs))
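The aliasing risk described in this comment can be reproduced with a small stand-alone sketch. `Mapping` is a simplified stand-in for `CustomModuleMapping`, and the contents of `COL_TP` are invented for illustration; only the sharing behaviour matters. Note that the diff as shown above passes `copy.deepcopy(m.func_kwargs)`, which corresponds to the `deep` case below.

```python
import copy
from dataclasses import dataclass, field

# Stand-in for a shared module-level constant such as COL_TP.
# The nested contents are hypothetical.
COL_TP = {"parallel_config": {"sharding_dim": 0}}


@dataclass
class Mapping:  # simplified stand-in for CustomModuleMapping
    target_name_or_prefix: str
    func_kwargs: dict = field(default_factory=dict)


base = Mapping("model.layers.{}.mlp.up_proj.", COL_TP)

# Three ways of carrying func_kwargs into a rebuilt entry.
by_reference = Mapping(base.target_name_or_prefix, base.func_kwargs)
shallow = Mapping(base.target_name_or_prefix, dict(base.func_kwargs))
deep = Mapping(base.target_name_or_prefix, copy.deepcopy(base.func_kwargs))

# Top-level mutation: only the by-reference entry leaks into COL_TP.
by_reference.func_kwargs["extra"] = True
shallow.func_kwargs["other"] = True

# Nested mutation: a shallow copy still shares the inner dict with COL_TP,
# while the deep copy is fully isolated.
shallow.func_kwargs["parallel_config"]["sharding_dim"] = 1
deep.func_kwargs["parallel_config"]["sharding_dim"] = 2
```

A shallow copy protects against top-level edits only; if the kwargs hold nested mutable objects, a deep copy is the safer choice at the cost of copying more data.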


Comment on lines +42 to +59

[SUGGESTION] Reconstructing each mapping via type(m)(target_name_or_prefix=..., func_kwargs=...) works today only because every CustomModuleMapping subclass in mcore_custom.py happens to share the same (target_name_or_prefix, func_kwargs) init signature. If a future subclass adds another required constructor arg (or changes the keyword names), this loop will silently break or drop information.

A more robust pattern is to deep-copy the original mapping and just rewrite the prefix:

def _with_language_model_prefix(
    mapping: dict[str, CustomModuleMapping],
) -> dict[str, CustomModuleMapping]:
    result = {}
    for key, m in mapping.items():
        new_m = copy.deepcopy(m)
        if new_m.target_name_or_prefix.startswith("model."):
            new_m.target_name_or_prefix = (
                "model.language_model." + new_m.target_name_or_prefix[len("model.") :]
            )
        result[key] = new_m
    return result

This preserves func_name and any future fields automatically and removes the implicit coupling to subclass __init__ signatures.
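The difference between the two approaches shows up as soon as a subclass carries an extra field. The sketch below uses simplified stand-ins: `Mapping` mimics the `(target_name_or_prefix, func_kwargs)` constructor, and `ScaledMapping` with its `scale` argument is a hypothetical future subclass that does not exist in `mcore_custom.py`.

```python
import copy


class Mapping:  # simplified stand-in for CustomModuleMapping
    def __init__(self, target_name_or_prefix="", func_kwargs=None):
        self.func_name = "name_remapping"
        self.target_name_or_prefix = target_name_or_prefix
        self.func_kwargs = func_kwargs or {}


class ScaledMapping(Mapping):  # hypothetical subclass with an extra field
    def __init__(self, target_name_or_prefix="", func_kwargs=None, scale=1.0):
        super().__init__(target_name_or_prefix, func_kwargs)
        self.scale = scale


def rebuild(m):
    # Reconstruction through the two-argument constructor: any field the
    # constructor call does not forward silently falls back to its default.
    return type(m)(
        target_name_or_prefix=m.target_name_or_prefix,
        func_kwargs=dict(m.func_kwargs),
    )


original = ScaledMapping("model.norm.", {}, scale=0.5)
rebuilt = rebuild(original)       # scale is reset to the default 1.0
cloned = copy.deepcopy(original)  # scale is preserved as 0.5
```

Here the rebuilt object loses `scale` without any error, whereas the deep copy keeps every attribute and only the prefix needs to be rewritten afterwards.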


qwen3vl_causal_lm_import = _with_language_model_prefix(qwen3_causal_lm_import)
qwen3vl_causal_lm_export = _with_language_model_prefix(qwen3_causal_lm_export)
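The effect of the prefix rewrite on individual names can be checked in isolation. The sketch below applies the same string rule to a few example prefixes; `add_language_model_prefix` is a hypothetical helper name, and the example prefixes are illustrative rather than copied from the Qwen3 mapping.

```python
def add_language_model_prefix(prefix: str) -> str:
    """Insert ``language_model.`` after a leading ``model.``, if present."""
    if prefix.startswith("model."):
        return "model.language_model." + prefix[len("model."):]
    return prefix


# Illustrative prefixes (assumed, not taken from the real mapping).
examples = [
    "model.embed_tokens.",
    "model.layers.{}.self_attn.q_proj.",
    "lm_head.",
]
rewritten = [add_language_model_prefix(p) for p in examples]
# rewritten == [
#     "model.language_model.embed_tokens.",
#     "model.language_model.layers.{}.self_attn.q_proj.",
#     "lm_head.",
# ]
```

The rule is purely textual, so it would also rewrite a `model.visual.` prefix if one were present; it relies on the base Qwen3 mapping containing only language-model entries.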
86 changes: 86 additions & 0 deletions tests/_test_utils/torch/transformers_models.py
@@ -28,6 +28,7 @@
    BertConfig,
    GptOssConfig,
    LlamaConfig,
    NemotronConfig,
    PreTrainedModel,
    Qwen3Config,
    Qwen3MoeConfig,
@@ -120,6 +121,91 @@ def create_tiny_qwen3_moe_dir(
    return qwen3_moe_dir


##### Qwen3-VL #####
def get_tiny_qwen3vl(**config_kwargs) -> PreTrainedModel:
    # Lazy imports: Qwen3VL classes live under transformers.models.qwen3_vl which
    # may not exist in older transformers builds, and this module is imported by
    # every test that uses transformers_models.py.
    from transformers import Qwen3VLConfig
    from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration

    set_seed(SEED)

    # Defaults: hidden_size=num_attention_heads*head_dim (e.g. 4*8=32).
    # Pass config_kwargs to override for multi-GPU tests (e.g. num_attention_heads=num_gpus,
    # num_key_value_heads=num_gpus, hidden_size=num_gpus*head_dim).
    text_kwargs = {
        "hidden_size": 32,
        "intermediate_size": 32,
        "num_hidden_layers": 2,
        "num_attention_heads": 4,
        "num_key_value_heads": 2,
        "head_dim": 8,
        "max_position_embeddings": 32,
        "vocab_size": 32,
    }
    text_kwargs.update(config_kwargs)
    # Pass as dicts: transformers 5.3.0 Qwen3VLConfig.__init__ only handles
    # vision_config/text_config when they are dicts or None, not instances.
    vision_kwargs = {
        "depth": 1,
        "hidden_size": 16,
        "intermediate_size": 16,
        "num_heads": 2,
        "in_channels": 3,
        "patch_size": 4,
        "spatial_merge_size": 1,
        "temporal_patch_size": 1,
        "out_hidden_size": text_kwargs["hidden_size"],  # must match text hidden_size
    }
    cfg = Qwen3VLConfig(text_config=text_kwargs, vision_config=vision_kwargs)
    return Qwen3VLForConditionalGeneration(cfg)


def create_tiny_qwen3vl_dir(
    tmp_path: Path | str, with_tokenizer: bool = False, **config_kwargs
) -> Path:
    qwen3vl_dir = Path(tmp_path) / "tiny_qwen3vl"
    if with_tokenizer:
        tokenizer = get_tiny_tokenizer()
        tokenizer.save_pretrained(qwen3vl_dir)
        config_kwargs["vocab_size"] = tokenizer.vocab_size
    get_tiny_qwen3vl(**config_kwargs).save_pretrained(qwen3vl_dir)
    return qwen3vl_dir


##### NEMOTRON #####
def get_tiny_nemotron(**config_kwargs) -> PreTrainedModel:
    set_seed(SEED)

    # hidden_size=64, ffn_hidden_size=128: relu2 activation needs non-trivial dims
    # to avoid all-zero activations (scaling factor 0) in NVFP4 quantization.
    kwargs = {
        "dtype": torch.bfloat16,
        "hidden_size": 64,
        "intermediate_size": 128,
        "num_hidden_layers": 2,
        "num_attention_heads": 8,
        "num_key_value_heads": 1,
        "max_position_embeddings": 32,
        "vocab_size": 32,
    }
    kwargs.update(**config_kwargs)
    return AutoModelForCausalLM.from_config(NemotronConfig(**kwargs))


def create_tiny_nemotron_dir(
    tmp_path: Path | str, with_tokenizer: bool = False, **config_kwargs
) -> Path:
    nemotron_dir = Path(tmp_path) / "tiny_nemotron"
    if with_tokenizer:
        tokenizer = get_tiny_tokenizer()
        tokenizer.save_pretrained(nemotron_dir)
        config_kwargs["vocab_size"] = tokenizer.vocab_size
    get_tiny_nemotron(**config_kwargs).save_pretrained(nemotron_dir)
    return nemotron_dir


##### GPT-OSS #####
def get_tiny_gpt_oss(**config_kwargs) -> PreTrainedModel:
    set_seed(SEED)