Skip to content

Commit 9c4c497

Browse files
authored
feat(engine): add Qwen3-VL dense and MoE support to Megatron path (#1301)
* feat(engine): add Qwen3-VL dense support to Megatron path Extend the Megatron engine to train Qwen3-VL dense models end-to-end: mcore→HF weight conversion for update_weights and HF→mcore loading that handles Qwen3-VL's nested HF config layout. Without this, GRPO/PPO of any Qwen3-VL model on the Megatron backend is blocked.
1 parent 2755661 commit 9c4c497

9 files changed

Lines changed: 1842 additions & 202 deletions

File tree

areal/engine/core/model.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
"qwen2_vl",
77
"qwen2_5_vl",
88
"qwen3_vl",
9+
"qwen3_vl_moe",
910
"gemma3",
1011
]
1112
# This registry is used to check if a model is a vision model that we have checked it works with AReaL.
@@ -23,19 +24,44 @@ def is_qwen2_vl_model(model_type: str) -> bool:
2324

2425

2526
def is_qwen3_vl_model(model_type: str) -> bool:
26-
return model_type in ["qwen3_vl"]
27+
"""True for the Qwen3-VL family (dense and MoE).
28+
29+
Existing call sites in ``fsdp_engine``, ``fsdp_utils/parallel``, and
30+
``awex/fsdp_adapter`` gate family-level behaviour (mRoPE index,
31+
attention-mask handling) that is identical for dense and MoE, so this
32+
helper covers both. Use ``is_qwen3_vl_moe_model`` when the MoE-vs-dense
33+
distinction matters.
34+
"""
35+
return model_type in ("qwen3_vl", "qwen3_vl_moe")
36+
37+
38+
def is_qwen3_vl_moe_model(model_type: str) -> bool:
39+
return model_type == "qwen3_vl_moe"
2740

2841

2942
def is_qwen_vl_model(model_type: str) -> bool:
3043
return is_qwen2_vl_model(model_type) or is_qwen3_vl_model(model_type)
3144

3245

46+
def lang_config(hf_config):
47+
"""Return the language-model side of a (possibly nested) HF config.
48+
49+
Qwen3-VL and similar VLMs nest text-model attributes (vocab_size,
50+
num_attention_heads, num_key_value_heads, hidden_size, head_dim) under
51+
``hf_config.text_config``. Qwen2.5-VL and pure text models keep them
52+
flat. Use this anywhere the caller wants a language-side attribute and
53+
doesn't know the model family up front.
54+
"""
55+
return getattr(hf_config, "text_config", hf_config)
56+
57+
3358
def is_gemma3_model(model_type: str) -> bool:
3459
return model_type in ["gemma3"]
3560

3661

3762
VALID_MOE_MODELS = [
3863
"qwen3_moe",
64+
"qwen3_vl_moe",
3965
"qwen3_5_moe",
4066
"qwen3_5_moe_text",
4167
"bailing_moe_v2",

areal/engine/megatron_engine.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,11 @@
5656
init_custom_process_group,
5757
warmup_process_groups,
5858
)
59-
from areal.engine.core.model import disable_dropout_in_model, is_valid_vision_model
59+
from areal.engine.core.model import (
60+
disable_dropout_in_model,
61+
is_valid_vision_model,
62+
lang_config,
63+
)
6064
from areal.engine.megatron_utils.checkpointer import MegatronCheckpointManager
6165
from areal.engine.megatron_utils.deterministic import set_deterministic_algorithms
6266
from areal.engine.megatron_utils.fp8 import FP8BlockwiseTensorHelper
@@ -1463,7 +1467,7 @@ def _collect_param(
14631467
duplicated_param_names=self._duplicated_param_names,
14641468
gated_linear_unit=is_glu,
14651469
)
1466-
param = remove_padding(name, param, self.hf_config.vocab_size)
1470+
param = remove_padding(name, param, lang_config(self.hf_config).vocab_size)
14671471

14681472
if isinstance(param, FP8BlockwiseTensorHelper):
14691473
# FP8 is stored as uint8, so element_size is 1 byte

0 commit comments

Comments
 (0)