diff --git a/CHANGELOG.rst b/CHANGELOG.rst index be2210a33f2..cd8a0451129 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -25,6 +25,7 @@ Changelog - Add ``--cast_mxfp4_to_nvfp4`` flag to ``examples/llm_ptq/hf_ptq.py`` for closed-form, bit-exact MXFP4 → NVFP4 weight conversion. Supports the GPT-OSS family (``openai/gpt-oss-20b``, ``openai/gpt-oss-120b``). See `examples/llm_ptq/README.md `__ for usage. - DeepSeek PTQ (``examples/deepseek/ptq.py``) now defaults to native top-k calibration with post-hoc per-layer peer-max sync of expert ``input_quantizer.amax``; the all-experts path is preserved behind ``--calib_all_experts``. - Add NVFP4 W4A16 weight-only quantization (``w4a16_nvfp4``): FP4 weights with group_size=16, BF16 activations, no calibration forward pass required. Use ``mtq.W4A16_NVFP4_CFG`` or ``--qformat w4a16_nvfp4`` in ``hf_ptq.py``. vLLM deployment support is in progress. +- Add Megatron Core export/import mapping for Qwen3-VL (``Qwen3VLForConditionalGeneration``) vision-language models. The mapping handles the ``model.language_model.`` weight prefix used by Qwen3-VL. - Add ``DATASET_COMBOS`` to ``modelopt.torch.utils.dataset_utils`` — single ``--dataset`` tokens that fan out to multiple registered datasets; per-entry ``num_samples`` is split evenly across the members. Initial combos: ``cnn_nemotron_v2_mix`` (``cnn_dailymail`` + ``nemotron-post-training-dataset-v2``, used by ``hf_ptq.py`` when no ``--dataset`` is provided) and ``nemotron-post-training-v3`` (the seven ``nvidia/Nemotron-*`` SFT datasets added in #1498, mirroring the `nemotron-post-training-v3 collection `_). Combo names are listed by ``get_supported_datasets()`` and surfaced in ``--dataset`` help. ``get_dataset_dataloader`` rejects inputs that mix a combo with one of its member datasets (e.g. ``cnn_dailymail,cnn_nemotron_v2_mix``) to avoid double-sampling, and ``get_dataset_samples`` rejects combo names so callers route through the dataloader. ``hf_ptq.py`` default ``--calib_size`` is bumped from ``512`` to ``1024`` so the total calibration sample count under the new default combo matches the previous two-dataset fallback. - The ``nemotron-sft-agentic-v2`` registered dataset (added in #1498) now uses only the ``search`` split. The previously configured ``interactive_agent`` and ``tool_calling`` splits contain content-level defects (heterogeneous schema and a malformed JSON row, respectively) that cause pyarrow's streaming JSON reader to fail deterministically. diff --git a/docs/source/deployment/3_unified_hf.rst b/docs/source/deployment/3_unified_hf.rst index 9124164b576..6664f987f72 100644 --- a/docs/source/deployment/3_unified_hf.rst +++ b/docs/source/deployment/3_unified_hf.rst @@ -61,6 +61,7 @@ Models: * Llama 4, 3.x (FP8, NVFP4) * Qwen 3, 2.5 (FP8, NVFP4) * Qwen 3 MoE (FP8, NVFP4) + * Qwen 3-VL (FP8, NVFP4) * Deepseek R1/V3 (NVFP4) * Mixtral 8x7B (FP8, NVFP4) * Medusa (FP8) diff --git a/modelopt/torch/export/plugins/mcore_common.py b/modelopt/torch/export/plugins/mcore_common.py index d5bab9b4ece..15395b7a1e5 100644 --- a/modelopt/torch/export/plugins/mcore_common.py +++ b/modelopt/torch/export/plugins/mcore_common.py @@ -39,6 +39,7 @@ qwen25_causal_lm_export, qwen25_causal_lm_import, ) +from .mcore_qwen3vl import qwen3vl_causal_lm_export, qwen3vl_causal_lm_import all_mcore_hf_export_mapping: dict[str, Any] = { "DeepseekV2ForCausalLM": deepseek_causal_lm_export, @@ -54,6 +55,7 @@ "Qwen3MoeForCausalLM": qwen3_causal_lm_export, "Qwen2ForCausalLM": qwen25_causal_lm_export, "GptOssForCausalLM": gptoss_causal_lm_export, + "Qwen3VLForConditionalGeneration": qwen3vl_causal_lm_export, } all_mcore_hf_import_mapping: dict[str, Any] = { @@ -66,4 +68,5 @@ "Qwen3MoeForCausalLM": qwen3_causal_lm_import, "Qwen2ForCausalLM": qwen25_causal_lm_import, "GptOssForCausalLM": gptoss_causal_lm_import, + "Qwen3VLForConditionalGeneration": qwen3vl_causal_lm_import, } diff --git a/modelopt/torch/export/plugins/mcore_qwen3vl.py b/modelopt/torch/export/plugins/mcore_qwen3vl.py new file mode 100644 index 00000000000..1f2d3830d61 --- /dev/null +++ b/modelopt/torch/export/plugins/mcore_qwen3vl.py @@ -0,0 +1,62 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Custom mapping from Qwen3-VL Hugging Face models to Megatron Core models. + +Qwen3-VL differs from Qwen3 in one structural way: language-model weights live +under ``model.language_model.`` instead of ``model.``, while ``lm_head.weight`` +remains at the root level. The mappings below are derived automatically from +the Qwen3 mappings by inserting ``language_model.`` after ``model.`` for every +prefix that starts with ``model.``. + +Note: the visual encoder (``model.visual.*``) is intentionally excluded — this +mapping covers only the language-model decoder used for quantization and export. + +Note: ``Qwen3VLMoeForConditionalGeneration`` is **not** supported here. The MoE +variant stores expert weights as 3-D tensors (``mlp.experts.gate_up_proj``, +``mlp.experts.down_proj``) that require a dedicated fused-expert mapping and +cannot reuse the dense Qwen3 rules. + +Reference: https://huggingface.co/Qwen/Qwen3-VL-8B-Instruct/blob/main/model.safetensors.index.json +""" + +import copy + +from .mcore_custom import CustomModuleMapping +from .mcore_qwen import qwen3_causal_lm_export, qwen3_causal_lm_import + + +def _with_language_model_prefix( + mapping: dict[str, CustomModuleMapping], +) -> dict[str, CustomModuleMapping]: + """Derive a VL mapping from a base Qwen3 mapping. + + Rewrites every ``target_name_or_prefix`` that starts with ``model.`` to + ``model.language_model.``. Prefixes that do not start with + ``model.`` (e.g. ``lm_head.``) are left unchanged. + """ + result = {} + for key, m in mapping.items(): + prefix = m.target_name_or_prefix + if prefix.startswith("model."): + prefix = "model.language_model." + prefix[len("model.") :] + result[key] = type(m)( + target_name_or_prefix=prefix, func_kwargs=copy.deepcopy(m.func_kwargs) + ) + return result + + +qwen3vl_causal_lm_import = _with_language_model_prefix(qwen3_causal_lm_import) +qwen3vl_causal_lm_export = _with_language_model_prefix(qwen3_causal_lm_export) diff --git a/tests/_test_utils/torch/transformers_models.py b/tests/_test_utils/torch/transformers_models.py index 34bc96cd0ae..fdd492013c1 100644 --- a/tests/_test_utils/torch/transformers_models.py +++ b/tests/_test_utils/torch/transformers_models.py @@ -28,6 +28,7 @@ BertConfig, GptOssConfig, LlamaConfig, + NemotronConfig, PreTrainedModel, Qwen3Config, Qwen3MoeConfig, @@ -120,6 +121,91 @@ def create_tiny_qwen3_moe_dir( return qwen3_moe_dir +##### Qwen3-VL ##### +def get_tiny_qwen3vl(**config_kwargs) -> PreTrainedModel: + # Lazy imports — Qwen3VL classes live under transformers.models.qwen3_vl which + # may not exist in older transformers builds, and this module is imported by + # every test that uses transformers_models.py. + from transformers import Qwen3VLConfig + from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration + + set_seed(SEED) + + # Defaults: hidden_size=num_attention_heads*head_dim (e.g. 4*8=32). + # Pass config_kwargs to override for multi-GPU tests (e.g. num_attention_heads=num_gpus, + # num_key_value_heads=num_gpus, hidden_size=num_gpus*head_dim). + text_kwargs = { + "hidden_size": 32, + "intermediate_size": 32, + "num_hidden_layers": 2, + "num_attention_heads": 4, + "num_key_value_heads": 2, + "head_dim": 8, + "max_position_embeddings": 32, + "vocab_size": 32, + } + text_kwargs.update(config_kwargs) + # Pass as dicts — transformers 5.3.0 Qwen3VLConfig.__init__ only handles + # vision_config/text_config when they are dicts or None, not instances. + vision_kwargs = { + "depth": 1, + "hidden_size": 16, + "intermediate_size": 16, + "num_heads": 2, + "in_channels": 3, + "patch_size": 4, + "spatial_merge_size": 1, + "temporal_patch_size": 1, + "out_hidden_size": text_kwargs["hidden_size"], # must match text hidden_size + } + cfg = Qwen3VLConfig(text_config=text_kwargs, vision_config=vision_kwargs) + return Qwen3VLForConditionalGeneration(cfg) + + +def create_tiny_qwen3vl_dir( + tmp_path: Path | str, with_tokenizer: bool = False, **config_kwargs +) -> Path: + qwen3vl_dir = Path(tmp_path) / "tiny_qwen3vl" + if with_tokenizer: + tokenizer = get_tiny_tokenizer() + tokenizer.save_pretrained(qwen3vl_dir) + config_kwargs["vocab_size"] = tokenizer.vocab_size + get_tiny_qwen3vl(**config_kwargs).save_pretrained(qwen3vl_dir) + return qwen3vl_dir + + +##### NEMOTRON ##### +def get_tiny_nemotron(**config_kwargs) -> PreTrainedModel: + set_seed(SEED) + + # hidden_size=64, ffn_hidden_size=128: relu2 activation needs non-trivial dims + # to avoid all-zero activations (scaling factor 0) in NVFP4 quantization. + kwargs = { + "dtype": torch.bfloat16, + "hidden_size": 64, + "intermediate_size": 128, + "num_hidden_layers": 2, + "num_attention_heads": 8, + "num_key_value_heads": 1, + "max_position_embeddings": 32, + "vocab_size": 32, + } + kwargs.update(**config_kwargs) + return AutoModelForCausalLM.from_config(NemotronConfig(**kwargs)) + + +def create_tiny_nemotron_dir( + tmp_path: Path | str, with_tokenizer: bool = False, **config_kwargs +) -> Path: + nemotron_dir = Path(tmp_path) / "tiny_nemotron" + if with_tokenizer: + tokenizer = get_tiny_tokenizer() + tokenizer.save_pretrained(nemotron_dir) + config_kwargs["vocab_size"] = tokenizer.vocab_size + get_tiny_nemotron(**config_kwargs).save_pretrained(nemotron_dir) + return nemotron_dir + + ##### GPT-OSS ##### def get_tiny_gpt_oss(**config_kwargs) -> PreTrainedModel: set_seed(SEED) diff --git a/tests/gpu_megatron/torch/export/test_unified_export_megatron.py b/tests/gpu_megatron/torch/export/test_unified_export_megatron.py index 3fac8269ccd..f7aad042c32 100644 --- a/tests/gpu_megatron/torch/export/test_unified_export_megatron.py +++ b/tests/gpu_megatron/torch/export/test_unified_export_megatron.py @@ -23,7 +23,11 @@ import transformers from _test_utils.torch.megatron.models import get_mcore_gpt_model from _test_utils.torch.megatron.utils import get_forward -from _test_utils.torch.transformers_models import create_tiny_llama_dir, get_tiny_tokenizer +from _test_utils.torch.transformers_models import ( + create_tiny_llama_dir, + create_tiny_nemotron_dir, + create_tiny_qwen3vl_dir, +) from safetensors import safe_open from safetensors.torch import save_file @@ -71,21 +75,57 @@ def _verify_model_quant_config( assert quant_config_dict["kv_cache_quant_algo"] == KV_CACHE_FP8 +def _merge_vision_weights(src_dir: Path, dst_dir: Path) -> None: + """Copy model.visual.* tensors from src safetensors into dst_dir. + + The mcore export only writes language-model weights. To produce a complete + Qwen3-VL checkpoint the vision-encoder weights must be merged from the + original pretrained checkpoint. + """ + vision_tensors = {} + for sf in sorted(src_dir.glob("*.safetensors")): + with safe_open(str(sf), framework="pt", device="cpu") as f: + for key in f.keys(): # noqa: SIM118 + if key.startswith("model.visual."): + vision_tensors[key] = f.get_tensor(key) + if vision_tensors: + save_file(vision_tensors, str(dst_dir / "model-vision.safetensors")) + + def _test_unified_export_megatron( - tmp_path, model_type, arch, extra_module, quant_config, kv_cache_quant_cfg, rank, size + tmp_path, + model_type, + extra_module, + quant_config, + kv_cache_quant_cfg, + rank, + size, + model_dir=None, ): - tokenizer = get_tiny_tokenizer() - tokenizer.save_pretrained(tmp_path) + if model_type == "qwen3vl": + config = transformers.AutoConfig.from_pretrained(model_dir) + text_cfg = config.text_config + num_layers = text_cfg.num_hidden_layers + hidden_size = text_cfg.hidden_size + num_attention_heads = text_cfg.num_attention_heads + num_query_groups = text_cfg.num_key_value_heads + ffn_hidden_size = text_cfg.intermediate_size + max_sequence_length = text_cfg.max_position_embeddings + vocab_size = text_cfg.vocab_size + extra_kwargs = {"kv_channels": text_cfg.head_dim, "qk_layernorm": True} + elif model_type in {"llama", "nemotron"}: + config = transformers.AutoConfig.from_pretrained(model_dir) + num_layers = config.num_hidden_layers + hidden_size = config.hidden_size + num_attention_heads = config.num_attention_heads + num_query_groups = config.num_key_value_heads + ffn_hidden_size = config.intermediate_size + max_sequence_length = config.max_position_embeddings + vocab_size = config.vocab_size + extra_kwargs = {} + else: + raise ValueError(f"Unsupported model_type: {model_type}") - num_layers = 2 - hidden_size = 64 - num_attention_heads = 8 - num_query_groups = size - ffn_hidden_size = 128 - max_sequence_length = 32 - vocab_size = tokenizer.vocab_size - - arch = "NemotronForCausalLM" if model_type == "nemotron" else "LlamaForCausalLM" activation_func = "squared_relu" if model_type == "nemotron" else "swiglu" normalization = "LayerNorm" if model_type == "nemotron" else "RMSNorm" @@ -103,6 +143,7 @@ def _test_unified_export_megatron( activation_func=activation_func, normalization=normalization, transformer_impl="modelopt", + **extra_kwargs, ).cuda() if quant_config: @@ -127,26 +168,12 @@ def _test_unified_export_megatron( model = mtsp.convert(model, [("eagle", config)]) assert isinstance(model, _DynamicEagleGPTModel) - pretrained_config = { - "architectures": [arch], - "attention_bias": False, - "hidden_size": hidden_size, - "intermediate_size": ffn_hidden_size, - "max_position_embeddings": max_sequence_length, - "model_type": "llama", - "num_attention_heads": num_attention_heads, - "num_hidden_layers": num_layers, - "num_key_value_heads": num_query_groups, - "torch_dtype": "bfloat16", - } - - with open(tmp_path / "config.json", "w") as f: - json.dump(pretrained_config, f) + hf_config_dir = model_dir tmp_export_dir = tmp_path / "export" export_mcore_gpt_to_hf( model, - tmp_path if arch is not None else None, + hf_config_dir, dtype=torch.bfloat16, export_dir=str(tmp_export_dir), ) @@ -154,74 +181,125 @@ def _test_unified_export_megatron( if quant_config: _verify_model_quant_config(tmp_export_dir, quant_config, kv_cache_quant_cfg) + if model_type == "qwen3vl": + torch.distributed.barrier() + if rank == 0: + _merge_vision_weights(Path(model_dir), tmp_export_dir) + # sanity check that the vision encoder weights were merged + keys = [] + for sf in sorted(tmp_export_dir.glob("*.safetensors")): + with safe_open(str(sf), framework="pt", device="cpu") as f: + keys.extend(f.keys()) + assert any(k.startswith("model.language_model.") for k in keys), ( + "language model keys missing from combined export" + ) + assert any(k.startswith("model.visual.") for k in keys), ( + "vision encoder keys missing from combined export" + ) + # try to load the model and run a forward pass + from transformers.models.qwen3_vl.modeling_qwen3_vl import ( + Qwen3VLForConditionalGeneration, + ) + + vl_model = Qwen3VLForConditionalGeneration.from_pretrained( + tmp_export_dir, torch_dtype=torch.bfloat16 + ).cuda() + input_ids = torch.zeros(1, 4, dtype=torch.long).cuda() + with torch.no_grad(): + out = vl_model(input_ids=input_ids) + assert out.logits.shape[-1] == vl_model.config.text_config.vocab_size + @pytest.mark.parametrize( - ("model_type", "arch", "extra_module", "quant_config", "kv_cache_quant_cfg"), + ("model_type", "extra_module", "quant_config", "kv_cache_quant_cfg"), [ - ("nemotron", "NemotronForCausalLM", None, None, None), - ("nemotron", "NemotronForCausalLM", None, "NVFP4_DEFAULT_CFG", None), - ("nemotron", "NemotronForCausalLM", None, "NVFP4_DEFAULT_CFG", "FP8_KV_CFG"), - ("nemotron", "NemotronForCausalLM", "eagle", None, None), - ("nemotron", "NemotronForCausalLM", "medusa", None, None), - ("llama", "LlamaForCausalLM", None, None, None), - ("llama", "LlamaForCausalLM", None, "FP8_DEFAULT_CFG", None), - ("llama", "LlamaForCausalLM", None, "FP8_DEFAULT_CFG", "FP8_KV_CFG"), - ("llama", "LlamaForCausalLM", "eagle", None, None), - ("llama", "LlamaForCausalLM", "medusa", None, None), + ("nemotron", None, None, None), + ("nemotron", None, "NVFP4_DEFAULT_CFG", None), + ("nemotron", None, "NVFP4_DEFAULT_CFG", "FP8_KV_CFG"), + ("nemotron", "eagle", None, None), + ("nemotron", "medusa", None, None), + ("llama", None, None, None), + ("llama", None, "FP8_DEFAULT_CFG", None), + ("llama", None, "FP8_DEFAULT_CFG", "FP8_KV_CFG"), + ("llama", "eagle", None, None), + ("llama", "medusa", None, None), + ("qwen3vl", None, None, None), + ("qwen3vl", None, "FP8_DEFAULT_CFG", None), ], ) def test_unified_export_megatron( - dist_workers_size_1, tmp_path, model_type, arch, extra_module, quant_config, kv_cache_quant_cfg + dist_workers_size_1, tmp_path, model_type, extra_module, quant_config, kv_cache_quant_cfg ): + if model_type == "llama": + model_dir = create_tiny_llama_dir(tmp_path) + elif model_type == "qwen3vl": + model_dir = create_tiny_qwen3vl_dir(tmp_path) + elif model_type == "nemotron": + model_dir = create_tiny_nemotron_dir(tmp_path) + else: + raise ValueError(f"Unsupported model_type: {model_type}") # TODO: Fix TP>1 failures dist_workers_size_1.run( partial( _test_unified_export_megatron, tmp_path, model_type, - arch, extra_module, quant_config, kv_cache_quant_cfg, + model_dir=model_dir, ), ) -def _test_unified_import_megatron(tiny_llama_dir, rank, size): - config = transformers.AutoConfig.from_pretrained(tiny_llama_dir) +def _test_unified_import_megatron(model_dir, rank, size, model_type="llama"): + config = transformers.AutoConfig.from_pretrained(model_dir) - num_layers = config.num_hidden_layers - hidden_size = config.hidden_size - num_attention_heads = config.num_attention_heads - num_query_groups = config.num_key_value_heads - ffn_hidden_size = config.intermediate_size - max_sequence_length = config.max_position_embeddings - vocab_size = config.vocab_size - activation_func = "swiglu" - normalization = "RMSNorm" + if model_type == "qwen3vl": + cfg = config.text_config + extra_kwargs = { + "kv_channels": cfg.head_dim, + "transformer_impl": "modelopt", + "qk_layernorm": True, + } + else: + cfg = config + extra_kwargs = {} model = get_mcore_gpt_model( tensor_model_parallel_size=size, pipeline_model_parallel_size=1, initialize_megatron=True, - num_layers=num_layers, - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - num_query_groups=num_query_groups, - ffn_hidden_size=ffn_hidden_size, - max_sequence_length=max_sequence_length, - vocab_size=vocab_size, - activation_func=activation_func, - normalization=normalization, + num_layers=cfg.num_hidden_layers, + hidden_size=cfg.hidden_size, + num_attention_heads=cfg.num_attention_heads, + num_query_groups=cfg.num_key_value_heads, + ffn_hidden_size=cfg.intermediate_size, + max_sequence_length=cfg.max_position_embeddings, + vocab_size=cfg.vocab_size, + activation_func="swiglu", + normalization="RMSNorm", + **extra_kwargs, ).cuda() - import_mcore_gpt_from_hf(model, tiny_llama_dir) + import_mcore_gpt_from_hf(model, model_dir) -def test_unified_import_megatron(dist_workers, tmp_path): +@pytest.mark.parametrize("model_type", ["llama", "qwen3vl"]) +def test_unified_import_megatron(dist_workers, tmp_path, model_type): num_gpus = torch.cuda.device_count() - tiny_llama_dir = create_tiny_llama_dir(tmp_path, num_key_value_heads=num_gpus) - dist_workers.run(partial(_test_unified_import_megatron, tiny_llama_dir)) + if model_type == "llama": + model_dir = create_tiny_llama_dir(tmp_path, num_key_value_heads=num_gpus) + elif model_type == "qwen3vl": + model_dir = create_tiny_qwen3vl_dir( + tmp_path, + num_attention_heads=num_gpus, + num_key_value_heads=num_gpus, + hidden_size=num_gpus * 8, # head_dim=8 + ) + else: + raise ValueError(f"Unsupported model_type: {model_type}") + dist_workers.run(partial(_test_unified_import_megatron, model_dir, model_type=model_type)) def _test_qkv_slicing_gqa_tp2(tmp_path, rank, size):