Commit f0d2237

hychiang-git and claude authored
Add Qwen3VL MCore Export support from PR 895 (#1482)
# [Megatron Export] Add Qwen3-VL mcore ↔ HF weight mapping

> This PR is duplicated from [PR #895](#895).
> The original branch source is no longer available; this new branch carries the same changes forward.

## What does this PR do?

**New feature:** Add Qwen3-VL (Vision-Language) model support to the Megatron Core export/import plugin, enabling HuggingFace-to-mcore weight conversion for PTQ/QAT/QAD workflows.

### Overview

Qwen3-VL has a different weight structure from Qwen3 text-only models:

- Language model weights are under the `model.language_model.` prefix (not `model.`)
- Visual encoder weights are under the `model.visual.` prefix
- `lm_head` is at root level, not nested under `language_model`

### What changed

| File | Change |
|---|---|
| `modelopt/torch/export/plugins/mcore_qwen3vl.py` | New plugin: derives Qwen3-VL mcore↔HF mapping by rewriting `model.*` → `model.language_model.*` on top of the existing Qwen3 dense rules; `lm_head.` is intentionally left unchanged |
| `modelopt/torch/export/plugins/mcore_common.py` | Registers `Qwen3VLForConditionalGeneration` in `all_mcore_hf_export_mapping` and `all_mcore_hf_import_mapping` |
| `modelopt/torch/export/plugins/hf_checkpoint_utils.py` | Generalized `load_multimodal_components` with a `prefixes` parameter; sharded checkpoints now scan all shards (not just the first) |
| `modelopt/torch/export/unified_export_megatron.py` | `save_pretrained`: added Qwen3-VL branch that copies `model.visual.*` vision-encoder weights from the original HF checkpoint into the exported directory, producing a complete, loadable checkpoint |
| `tests/_test_utils/torch/transformers_models.py` | Added `get_tiny_qwen3vl` / `create_tiny_qwen3vl_dir` helpers; Qwen3VL classes are lazy-imported inside the function to avoid collection failures on older transformers builds |
| `tests/gpu_megatron/torch/export/test_unified_export_megatron.py` | Integrated Qwen3-VL export/import tests into the existing `test_unified_export_megatron` / `test_unified_import_megatron` parametrized suites; removed standalone `test_mcore_qwen3vl.py` |
| `docs/source/deployment/3_unified_hf.rst` | Added Qwen3-VL (FP8 / NVFP4) to the deployment support matrix for TensorRT-LLM |

### Workflow coverage

| Step | Status | Files |
|---|---|---|
| 1. Quantize Qwen3-VL with `hf_ptq` | ✅ existing | — |
| 2. Export quantized mcore → HF | ✅ this PR | `plugins/mcore_qwen3vl.py` (weight name mapping), `unified_export_megatron.py` (export path) |
| 3. Vision-encoder weights merged into export dir | ✅ this PR | `plugins/hf_checkpoint_utils.py` (`load_multimodal_components` with `prefixes`), `unified_export_megatron.py` (calls it when `arch == "Qwen3VLForConditionalGeneration"`) |
| 4. Import HF checkpoint back to mcore | ✅ this PR | `plugins/mcore_qwen3vl.py` (same mapping, reverse direction), `unified_export_megatron.py` (import path) |

### Design notes

- **MoE not supported**: `Qwen3VLMoeForConditionalGeneration` stores expert weights as 3-D tensors (`mlp.experts.gate_up_proj`, `mlp.experts.down_proj`) that require a dedicated fused-expert mapping. A note in the plugin's module docstring documents this explicitly.
- **`copy.deepcopy` on `func_kwargs`**: each mapping entry gets its own copy to prevent shared-dict mutation when both Qwen3 and Qwen3-VL rules are loaded.
- **`prefixes` parameter on `load_multimodal_components`**: backward-compatible default preserves existing LLaVA behaviour (`"multi_modal_projector"`, `"vision_model"`); Qwen3-VL callers pass `("model.visual.",)`.
- **Sharded checkpoint scan**: the old code only looked in the first shard. The Qwen3-VL vision encoder can span multiple shards, so all shards are now scanned.

## Usage

From the [Megatron-LM PR comment](NVIDIA/Megatron-LM#3444 (comment)):

> Qwen3VL is supported within [Megatron-Bridge](https://github.com/NVIDIA-NeMo/Megatron-Bridge), and pretraining and PEFT recipes for Qwen3VL are [here](https://github.com/NVIDIA-NeMo/Megatron-Bridge/blob/main/src/megatron/bridge/recipes/qwen_vl/qwen3_vl.py) and the core code logic [here](https://github.com/NVIDIA-NeMo/Megatron-Bridge/tree/main/src/megatron/bridge/models/qwen_vl).

Create `Megatron-LM/examples/post_training/modelopt/conf/Qwen/Qwen3-VL-8B-Instruct.sh`:

```bash
#!/bin/bash
# Qwen3-VL-8B-Instruct text-model config for Megatron-LM import/quantize.
#
# Text-model dimensions are identical to Qwen3-8B (4096 hidden, 36 layers,
# 32 heads, GQA=8). Differences: rope_theta=5000000, checkpoint path uses
# model.language_model.* prefix (handled by mcore_qwen3vl plugin).

if [ -z ${HF_MODEL_CKPT} ]; then
    HF_MODEL_CKPT=Qwen/Qwen3-VL-8B-Instruct
    TOKENIZER_MODEL=Qwen/Qwen3-VL-8B-Instruct
else
    TOKENIZER_MODEL=${HF_MODEL_CKPT}
fi

MODEL_ARGS=" \
    --save-interval 100000 \
    --micro-batch-size 1 \
    --bf16 \
    --no-masked-softmax-fusion \
    --disable-bias-linear \
    --untie-embeddings-and-output-weights \
    --position-embedding-type rope \
    --no-rope-fusion \
    --normalization RMSNorm \
    --swiglu \
    --num-layers 36 \
    --hidden-size 4096 \
    --ffn-hidden-size 12288 \
    --num-attention-heads 32 \
    --group-query-attention \
    --num-query-groups 8 \
    --kv-channels 128 \
    --qk-layernorm \
    --seq-length 4096 \
    --max-position-embeddings 262144 \
    --tokenizer-type HuggingFaceTokenizer \
    --make-vocab-size-divisible-by 1187 \
    --use-mcore-models \
    --rotary-percent 1.0 \
    --rotary-base 5000000 \
    --no-bias-swiglu-fusion \
"
```

Import Qwen3-VL from HuggingFace to MCore (local, requires GPUs):

```bash
MLM_MODEL_CFG=Qwen/Qwen3-VL-8B-Instruct \
HF_MODEL_CKPT=Qwen/Qwen3-VL-8B-Instruct \
MLM_MODEL_SAVE=/tmp/qwen3vl_mcore \
TP=1 \
bash Megatron-LM/examples/post_training/modelopt/convert.sh Qwen/Qwen3-VL-8B-Instruct
```

Quantize (PTQ via Megatron-LM path):

```bash
MLM_MODEL_CFG=Qwen/Qwen3-VL-8B-Instruct \
HF_MODEL_CKPT=Qwen/Qwen3-VL-8B-Instruct \
QUANT_CFG=NVFP4_DEFAULT_CFG \
TP=4 \
bash Megatron-LM/examples/post_training/modelopt/quantize.sh Qwen/Qwen3-VL-8B-Instruct
```

## Testing

- Verified round-trip import/export with Qwen3-VL-8B-Instruct using the example usage above
- Unit/GPU tests covering:
  - Registration in global export/import mappings
  - Import mapping: dense keys, `model.language_model.` prefix, `lm_head.` at root, `QKVMerging`, `GatedMLPMerging`, `REPLICATE` for layernorms, TP sharding configs
  - Export mapping: `QKVSlicing`, `GatedMLPSlicing`, no `parallel_config`
  - Import/export symmetry: same mcore keys, matching HF prefixes
  - Qwen3-VL vs Qwen3 difference: same keys, VL adds `language_model.` prefix, `lm_head` unchanged

## Before your PR is "Ready for review"

- Is this change backward compatible?: Yes, additive only
- Did you write any new necessary tests?: Yes, `tests/gpu_megatron/torch/export/test_unified_export_megatron.py`
- Did you add or update any necessary documentation?: Yes, see `docs/source/deployment/3_unified_hf.rst`
- Did you update Changelog?: Yes, see `CHANGELOG.rst`

## Additional Information

Companion Megatron-LM PR adds `Qwen3VLModel`, `Qwen3VLDataset`, and `pretrain_qwenvl.py`. See: NVIDIA/Megatron-LM#3444

---------

Signed-off-by: Hung-Yueh Chiang <hungyuehc@nvidia.com>
Signed-off-by: hychiang <hungyuehc@nvidia.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 2b7668d commit f0d2237

8 files changed

Lines changed: 302 additions & 91 deletions

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@ Changelog
 - Add ``--cast_mxfp4_to_nvfp4`` flag to ``examples/llm_ptq/hf_ptq.py`` for closed-form, bit-exact MXFP4 → NVFP4 weight conversion. Supports the GPT-OSS family (``openai/gpt-oss-20b``, ``openai/gpt-oss-120b``). See `examples/llm_ptq/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/llm_ptq#mxfp4--nvfp4-cast-for-gpt-oss>`__ for usage.
 - DeepSeek PTQ (``examples/deepseek/ptq.py``) now defaults to native top-k calibration with post-hoc per-layer peer-max sync of expert ``input_quantizer.amax``; the all-experts path is preserved behind ``--calib_all_experts``.
 - Add NVFP4 W4A16 weight-only quantization (``w4a16_nvfp4``): FP4 weights with group_size=16, BF16 activations, no calibration forward pass required. Use ``mtq.W4A16_NVFP4_CFG`` or ``--qformat w4a16_nvfp4`` in ``hf_ptq.py``. vLLM deployment support is in progress.
+- Add Megatron Core export/import mapping for Qwen3-VL (``Qwen3VLForConditionalGeneration``) vision-language models. The mapping handles the ``model.language_model.`` weight prefix used by Qwen3-VL.
 - Add ``DATASET_COMBOS`` to ``modelopt.torch.utils.dataset_utils`` — single ``--dataset`` tokens that fan out to multiple registered datasets; per-entry ``num_samples`` is split evenly across the members. Initial combos: ``cnn_nemotron_v2_mix`` (``cnn_dailymail`` + ``nemotron-post-training-dataset-v2``, used by ``hf_ptq.py`` when no ``--dataset`` is provided) and ``nemotron-post-training-v3`` (the seven ``nvidia/Nemotron-*`` SFT datasets added in #1498, mirroring the `nemotron-post-training-v3 collection <https://huggingface.co/collections/nvidia/nemotron-post-training-v3>`_). Combo names are listed by ``get_supported_datasets()`` and surfaced in ``--dataset`` help. ``get_dataset_dataloader`` rejects inputs that mix a combo with one of its member datasets (e.g. ``cnn_dailymail,cnn_nemotron_v2_mix``) to avoid double-sampling, and ``get_dataset_samples`` rejects combo names so callers route through the dataloader. ``hf_ptq.py`` default ``--calib_size`` is bumped from ``512`` to ``1024`` so the total calibration sample count under the new default combo matches the previous two-dataset fallback.
 - The ``nemotron-sft-agentic-v2`` registered dataset (added in #1498) now uses only the ``search`` split. The previously configured ``interactive_agent`` and ``tool_calling`` splits contain content-level defects (heterogeneous schema and a malformed JSON row, respectively) that cause pyarrow's streaming JSON reader to fail deterministically.
 - Add shared Megatron-Core calibration forward loop: ``modelopt.torch.utils.plugins.megatron_calibration.get_megatron_calibration_forward_loop`` produces the ``forward_loop`` callable expected by ``mtq.quantize`` / ``mtp.prune``. Replaces the bespoke calibration loops in Megatron-LM and Megatron-Bridge for quantization and pruning with a single canonical implementation.

docs/source/deployment/3_unified_hf.rst

Lines changed: 1 addition & 0 deletions
@@ -61,6 +61,7 @@ Models:
 * Llama 4, 3.x (FP8, NVFP4)
 * Qwen 3, 2.5 (FP8, NVFP4)
 * Qwen 3 MoE (FP8, NVFP4)
+* Qwen 3-VL (FP8, NVFP4)
 * Deepseek R1/V3 (NVFP4)
 * Mixtral 8x7B (FP8, NVFP4)
 * Medusa (FP8)

modelopt/torch/export/plugins/hf_checkpoint_utils.py

Lines changed: 11 additions & 22 deletions
@@ -88,11 +88,15 @@ def copy_hf_ckpt_remote_code(
 
 def load_multimodal_components(
     pretrained_model_path: str | os.PathLike,
+    prefixes: tuple[str, ...] = ("multi_modal_projector", "vision_model"),
 ) -> dict[str, torch.Tensor]:
     """Load multimodal components from safetensors file.
 
     Args:
         pretrained_model_path: Path to the pretrained model.
+        prefixes: Tensor key prefixes to select. Defaults to the LLaVA-style
+            ``multi_modal_projector`` / ``vision_model`` prefixes. Pass
+            ``("model.visual.",)`` for Qwen3-VL checkpoints.
 
     Returns:
         A dictionary of multimodal components.
@@ -114,7 +118,7 @@ def load_multimodal_components(
             multimodal_keys = [
                 key
                 for key in f.keys()  # noqa: SIM118
-                if key.startswith(("multi_modal_projector", "vision_model"))
+                if key.startswith(prefixes)
             ]
             for key in tqdm(multimodal_keys, desc="Loading multimodal tensors"):
                 multimodal_state_dict[key] = f.get_tensor(key)
@@ -124,28 +128,13 @@ def load_multimodal_components(
         with open(safetensors_index_file) as f:
             safetensors_index = json.load(f)
 
-        # For multimodal models, vision_model and multi_modal_projector are in the first shard
         all_shard_files = sorted(set(safetensors_index["weight_map"].values()))
-        first_shard_file = all_shard_files[0]  # e.g., "model-00001-of-00050.safetensors"
-
-        # Load multimodal components from the first shard file
-        safetensors_filepath = Path(hf_checkpoint_path) / first_shard_file
-        print(f"Loading multimodal components from {first_shard_file}")
-
-        with safe_open(safetensors_filepath, framework="pt") as f:
-            shard_keys = list(f.keys())
-            multimodal_keys_in_shard = [
-                k for k in shard_keys if k.startswith(("multi_modal_projector", "vision_model"))
-            ]
-
-            if multimodal_keys_in_shard:
-                print(
-                    f"Found {len(multimodal_keys_in_shard)} multimodal tensors in {first_shard_file}"
-                )
-                for key in tqdm(multimodal_keys_in_shard, desc="Loading multimodal tensors"):
-                    multimodal_state_dict[key] = f.get_tensor(key)
-            else:
-                print(f"No multimodal components found in {first_shard_file}")
+        for shard_file in all_shard_files:
+            safetensors_filepath = Path(hf_checkpoint_path) / shard_file
+            with safe_open(safetensors_filepath, framework="pt") as f:
+                for key in f.keys():  # noqa: SIM118
+                    if key.startswith(prefixes):
+                        multimodal_state_dict[key] = f.get_tensor(key)
 
     else:
         print(f"Warning: No safetensors files found in {hf_checkpoint_path}")
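Review note on `hf_checkpoint_utils.py`: the new loop relies on `str.startswith` accepting a tuple of prefixes, and on the selected keys possibly living in more than one shard. The sketch below illustrates that selection on a plain `weight_map` dictionary without opening any safetensors file. The helper name `group_matching_keys_by_shard` and the tensor key names are hypothetical, chosen only for illustration; they are not part of the PR.

```python
from collections import defaultdict


def group_matching_keys_by_shard(
    weight_map: dict[str, str], prefixes: tuple[str, ...]
) -> dict[str, list[str]]:
    """Group tensor keys that start with any of `prefixes` by their shard file."""
    grouped: dict[str, list[str]] = defaultdict(list)
    for key, shard_file in weight_map.items():
        # str.startswith accepts a tuple (not a list) of candidate prefixes.
        if key.startswith(prefixes):
            grouped[shard_file].append(key)
    return {shard: sorted(keys) for shard, keys in sorted(grouped.items())}


# Hypothetical index contents: the vision encoder is split across two shards.
weight_map = {
    "model.visual.patch_embed.proj.weight": "model-00001-of-00002.safetensors",
    "model.visual.merger.norm.weight": "model-00002-of-00002.safetensors",
    "model.language_model.embed_tokens.weight": "model-00001-of-00002.safetensors",
    "lm_head.weight": "model-00002-of-00002.safetensors",
}

vision_by_shard = group_matching_keys_by_shard(weight_map, ("model.visual.",))
for shard, keys in vision_by_shard.items():
    print(shard, keys)
```

With an index like this one, a first-shard-only scan would miss the vision tensor stored in the second shard, which is the case the all-shards loop is meant to cover.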

modelopt/torch/export/plugins/mcore_common.py

Lines changed: 3 additions & 0 deletions
@@ -39,6 +39,7 @@
     qwen25_causal_lm_export,
     qwen25_causal_lm_import,
 )
+from .mcore_qwen3vl import qwen3vl_causal_lm_export, qwen3vl_causal_lm_import
 
 all_mcore_hf_export_mapping: dict[str, Any] = {
     "DeepseekV2ForCausalLM": deepseek_causal_lm_export,
@@ -54,6 +55,7 @@
     "Qwen3MoeForCausalLM": qwen3_causal_lm_export,
     "Qwen2ForCausalLM": qwen25_causal_lm_export,
     "GptOssForCausalLM": gptoss_causal_lm_export,
+    "Qwen3VLForConditionalGeneration": qwen3vl_causal_lm_export,
 }
 
 all_mcore_hf_import_mapping: dict[str, Any] = {
@@ -66,4 +68,5 @@
     "Qwen3MoeForCausalLM": qwen3_causal_lm_import,
     "Qwen2ForCausalLM": qwen25_causal_lm_import,
     "GptOssForCausalLM": gptoss_causal_lm_import,
+    "Qwen3VLForConditionalGeneration": qwen3vl_causal_lm_import,
 }
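Review note on `mcore_common.py`: both registries are plain dictionaries keyed by the HF architecture string, so only `Qwen3VLForConditionalGeneration` gains a mapping and the MoE variant stays unregistered. The sketch below shows one way a caller could resolve an architecture and fail loudly for an unregistered one. The registry contents and the `resolve_mapping` helper are stand-ins invented for this example; how the real exporter consumes `all_mcore_hf_export_mapping` is not shown in this diff.

```python
from typing import Any

# Stand-in registry for illustration only; the real one is
# all_mcore_hf_export_mapping in mcore_common.py.
toy_export_registry: dict[str, Any] = {
    "Qwen3ForCausalLM": {"rules": "qwen3"},
    "Qwen3VLForConditionalGeneration": {"rules": "qwen3vl"},
}


def resolve_mapping(registry: dict[str, Any], arch: str) -> Any:
    """Return the mapping registered for `arch`, or raise with the supported names."""
    try:
        return registry[arch]
    except KeyError:
        supported = ", ".join(sorted(registry))
        raise NotImplementedError(
            f"No mcore<->HF mapping registered for {arch!r}; supported: {supported}"
        ) from None


print(resolve_mapping(toy_export_registry, "Qwen3VLForConditionalGeneration"))

try:
    resolve_mapping(toy_export_registry, "Qwen3VLMoeForConditionalGeneration")
except NotImplementedError as err:
    print(err)
```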
modelopt/torch/export/plugins/mcore_qwen3vl.py

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Custom mapping from Qwen3-VL Hugging Face models to Megatron Core models.
+
+Qwen3-VL differs from Qwen3 in one structural way: language-model weights live
+under ``model.language_model.`` instead of ``model.``, while ``lm_head.weight``
+remains at the root level. The mappings below are derived automatically from
+the Qwen3 mappings by inserting ``language_model.`` after ``model.`` for every
+prefix that starts with ``model.``.
+
+Note: the visual encoder (``model.visual.*``) is intentionally excluded — this
+mapping covers only the language-model decoder used for quantization and export.
+
+Note: ``Qwen3VLMoeForConditionalGeneration`` is **not** supported here. The MoE
+variant stores expert weights as 3-D tensors (``mlp.experts.gate_up_proj``,
+``mlp.experts.down_proj``) that require a dedicated fused-expert mapping and
+cannot reuse the dense Qwen3 rules.
+
+Reference: https://huggingface.co/Qwen/Qwen3-VL-8B-Instruct/blob/main/model.safetensors.index.json
+"""
+
+import copy
+
+from .mcore_custom import CustomModuleMapping
+from .mcore_qwen import qwen3_causal_lm_export, qwen3_causal_lm_import
+
+
+def _with_language_model_prefix(
+    mapping: dict[str, CustomModuleMapping],
+) -> dict[str, CustomModuleMapping]:
+    """Derive a VL mapping from a base Qwen3 mapping.
+
+    Rewrites every ``target_name_or_prefix`` that starts with ``model.`` to
+    ``model.language_model.<rest>``. Prefixes that do not start with
+    ``model.`` (e.g. ``lm_head.``) are left unchanged.
+    """
+    result = {}
+    for key, m in mapping.items():
+        prefix = m.target_name_or_prefix
+        if prefix.startswith("model."):
+            prefix = "model.language_model." + prefix[len("model.") :]
+        result[key] = type(m)(
+            target_name_or_prefix=prefix, func_kwargs=copy.deepcopy(m.func_kwargs)
+        )
+    return result
+
+
+qwen3vl_causal_lm_import = _with_language_model_prefix(qwen3_causal_lm_import)
+qwen3vl_causal_lm_export = _with_language_model_prefix(qwen3_causal_lm_export)
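Review note on `mcore_qwen3vl.py`: the prefix rewrite and the reason for `copy.deepcopy` can be exercised without Megatron or ModelOpt installed. The sketch below mirrors the logic of `_with_language_model_prefix` on a stand-in dataclass. `ToyMapping`, the mapping keys, and the `parallel_config` contents are assumptions made for this example; the real `CustomModuleMapping` classes and Qwen3 rules may look different.

```python
import copy
from dataclasses import dataclass, field


@dataclass
class ToyMapping:
    """Stand-in for CustomModuleMapping (hypothetical, for illustration only)."""

    target_name_or_prefix: str
    func_kwargs: dict = field(default_factory=dict)


def with_language_model_prefix(mapping: dict[str, ToyMapping]) -> dict[str, ToyMapping]:
    """Insert 'language_model.' after a leading 'model.'; leave other prefixes alone."""
    result = {}
    for key, m in mapping.items():
        prefix = m.target_name_or_prefix
        if prefix.startswith("model."):
            prefix = "model.language_model." + prefix[len("model."):]
        # Deep-copy the kwargs so the derived entry never shares state with the base.
        result[key] = type(m)(
            target_name_or_prefix=prefix, func_kwargs=copy.deepcopy(m.func_kwargs)
        )
    return result


# Hypothetical base rules resembling a dense text-only mapping.
base = {
    "word_embeddings": ToyMapping("model.embed_tokens."),
    "output_layer": ToyMapping("lm_head."),
    "linear_qkv": ToyMapping("model.layers.{}.self_attn.", {"parallel_config": {"axis": 0}}),
}

vl = with_language_model_prefix(base)
print(vl["word_embeddings"].target_name_or_prefix)
print(vl["output_layer"].target_name_or_prefix)

# Mutating the derived entry's nested kwargs must not leak into the base rules.
vl["linear_qkv"].func_kwargs["parallel_config"]["axis"] = 1
print(base["linear_qkv"].func_kwargs["parallel_config"]["axis"])
```

A shallow copy of `func_kwargs` would still share the nested `parallel_config` dictionary between the two rule sets, which is the mutation hazard the design notes mention.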

modelopt/torch/export/unified_export_megatron.py

Lines changed: 14 additions & 3 deletions
@@ -382,9 +382,20 @@ def save_pretrained(
         # Add multimodal components to state_dict. Since only support decoder model quantization,
         # no changes will be made to the multimodal components. We copy the multimodal components
         # from the pretrained model directly to the state_dict to avoid implementing the export logic.
-        if is_first_stage_main_rank and self.is_multimodal:
-            multimodal_state_dict = load_multimodal_components(pretrained_model_name_or_path)
-            layer_state_dicts[0].update(multimodal_state_dict)
+        if is_first_stage_main_rank:
+            # layer_state_dicts is keyed by layer_number (1-indexed), so the first
+            # decoder layer on this (first) PP stage is the smallest key, not 0.
+            # Merge the multimodal components into that shard so they land in a file
+            # the index builder picks up (it scans shards 1..num_layers).
+            first_layer_key = next(iter(layer_state_dicts))
+            if self.is_multimodal:
+                multimodal_state_dict = load_multimodal_components(pretrained_model_name_or_path)
+                layer_state_dicts[first_layer_key].update(multimodal_state_dict)
+            elif self.arch == "Qwen3VLForConditionalGeneration":
+                vision_state_dict = load_multimodal_components(
+                    pretrained_model_name_or_path, prefixes=("model.visual.",)
+                )
+                layer_state_dicts[first_layer_key].update(vision_state_dict)
 
         # Barrier to ensure the export_dir has been created.
         torch.distributed.barrier()
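Review note on `unified_export_megatron.py`: the new comment describes the target shard as the smallest key, while the code takes `next(iter(layer_state_dicts))`, which returns the first inserted key. The two agree only if the dictionary is filled in ascending layer order, which this hunk does not show. The sketch below demonstrates the difference on toy dictionaries; `first_and_smallest_key` is a hypothetical helper, not code from the PR.

```python
def first_and_smallest_key(state_dicts: dict[int, dict]) -> tuple[int, int]:
    """Return (first inserted key, smallest key) of a layer-indexed dict."""
    # Python dicts preserve insertion order, so iteration starts at the first insert.
    return next(iter(state_dicts)), min(state_dicts)


ascending = {1: {}, 2: {}, 3: {}}  # layers inserted in order
shuffled = {3: {}, 1: {}, 2: {}}   # same layers, different insertion order

print(first_and_smallest_key(ascending))
print(first_and_smallest_key(shuffled))
```

If insertion order is not guaranteed by the surrounding code, `min(layer_state_dicts)` would match the comment's stated intent more directly.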

tests/_test_utils/torch/transformers_models.py

Lines changed: 86 additions & 0 deletions
@@ -29,6 +29,7 @@
     DeepseekV3Config,
     GptOssConfig,
     LlamaConfig,
+    NemotronConfig,
     PreTrainedModel,
     Qwen3Config,
     Qwen3MoeConfig,
@@ -121,6 +122,91 @@ def create_tiny_qwen3_moe_dir(
     return qwen3_moe_dir
 
 
+##### Qwen3-VL #####
+def get_tiny_qwen3vl(**config_kwargs) -> PreTrainedModel:
+    # Lazy imports — Qwen3VL classes live under transformers.models.qwen3_vl which
+    # may not exist in older transformers builds, and this module is imported by
+    # every test that uses transformers_models.py.
+    from transformers import Qwen3VLConfig
+    from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration
+
+    set_seed(SEED)
+
+    # Defaults: hidden_size=num_attention_heads*head_dim (e.g. 4*8=32).
+    # Pass config_kwargs to override for multi-GPU tests (e.g. num_attention_heads=num_gpus,
+    # num_key_value_heads=num_gpus, hidden_size=num_gpus*head_dim).
+    text_kwargs = {
+        "hidden_size": 32,
+        "intermediate_size": 32,
+        "num_hidden_layers": 2,
+        "num_attention_heads": 4,
+        "num_key_value_heads": 2,
+        "head_dim": 8,
+        "max_position_embeddings": 32,
+        "vocab_size": 32,
+    }
+    text_kwargs.update(config_kwargs)
+    # Pass as dicts — transformers 5.3.0 Qwen3VLConfig.__init__ only handles
+    # vision_config/text_config when they are dicts or None, not instances.
+    vision_kwargs = {
+        "depth": 1,
+        "hidden_size": 16,
+        "intermediate_size": 16,
+        "num_heads": 2,
+        "in_channels": 3,
+        "patch_size": 4,
+        "spatial_merge_size": 1,
+        "temporal_patch_size": 1,
+        "out_hidden_size": text_kwargs["hidden_size"],  # must match text hidden_size
+    }
+    cfg = Qwen3VLConfig(text_config=text_kwargs, vision_config=vision_kwargs)
+    return Qwen3VLForConditionalGeneration(cfg)
+
+
+def create_tiny_qwen3vl_dir(
+    tmp_path: Path | str, with_tokenizer: bool = False, **config_kwargs
+) -> Path:
+    qwen3vl_dir = Path(tmp_path) / "tiny_qwen3vl"
+    if with_tokenizer:
+        tokenizer = get_tiny_tokenizer()
+        tokenizer.save_pretrained(qwen3vl_dir)
+        config_kwargs["vocab_size"] = tokenizer.vocab_size
+    get_tiny_qwen3vl(**config_kwargs).save_pretrained(qwen3vl_dir)
+    return qwen3vl_dir
+
+
+##### NEMOTRON #####
+def get_tiny_nemotron(**config_kwargs) -> PreTrainedModel:
+    set_seed(SEED)
+
+    # hidden_size=64, ffn_hidden_size=128: relu2 activation needs non-trivial dims
+    # to avoid all-zero activations (scaling factor 0) in NVFP4 quantization.
+    kwargs = {
+        "dtype": torch.bfloat16,
+        "hidden_size": 64,
+        "intermediate_size": 128,
+        "num_hidden_layers": 2,
+        "num_attention_heads": 8,
+        "num_key_value_heads": 1,
+        "max_position_embeddings": 32,
+        "vocab_size": 32,
+    }
+    kwargs.update(**config_kwargs)
+    return AutoModelForCausalLM.from_config(NemotronConfig(**kwargs))
+
+
+def create_tiny_nemotron_dir(
+    tmp_path: Path | str, with_tokenizer: bool = False, **config_kwargs
+) -> Path:
+    nemotron_dir = Path(tmp_path) / "tiny_nemotron"
+    if with_tokenizer:
+        tokenizer = get_tiny_tokenizer()
+        tokenizer.save_pretrained(nemotron_dir)
+        config_kwargs["vocab_size"] = tokenizer.vocab_size
+    get_tiny_nemotron(**config_kwargs).save_pretrained(nemotron_dir)
+    return nemotron_dir
+
+
 ##### DeepSeek V3 #####
 def get_tiny_deepseek_v3(**config_kwargs) -> PreTrainedModel:
     set_seed(SEED)
