Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion areal/engine/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"qwen2_vl",
"qwen2_5_vl",
"qwen3_vl",
"qwen3_vl_moe",
"gemma3",
]
# This registry is used to check if a model is a vision model that we have checked it works with AReaL.
Expand All @@ -23,19 +24,44 @@ def is_qwen2_vl_model(model_type: str) -> bool:


def is_qwen3_vl_model(model_type: str) -> bool:
return model_type in ["qwen3_vl"]
"""True for the Qwen3-VL family (dense and MoE).

Existing call sites in ``fsdp_engine``, ``fsdp_utils/parallel``, and
``awex/fsdp_adapter`` gate family-level behaviour (mRoPE index,
attention-mask handling) that is identical for dense and MoE, so this
helper covers both. Use ``is_qwen3_vl_moe_model`` when the MoE-vs-dense
distinction matters.
"""
return model_type in ("qwen3_vl", "qwen3_vl_moe")


def is_qwen3_vl_moe_model(model_type: str) -> bool:
return model_type == "qwen3_vl_moe"


def is_qwen_vl_model(model_type: str) -> bool:
return is_qwen2_vl_model(model_type) or is_qwen3_vl_model(model_type)


def lang_config(hf_config):
"""Return the language-model side of a (possibly nested) HF config.

Qwen3-VL and similar VLMs nest text-model attributes (vocab_size,
num_attention_heads, num_key_value_heads, hidden_size, head_dim) under
``hf_config.text_config``. Qwen2.5-VL and pure text models keep them
flat. Use this anywhere the caller wants a language-side attribute and
doesn't know the model family up front.
"""
return getattr(hf_config, "text_config", hf_config)


def is_gemma3_model(model_type: str) -> bool:
return model_type in ["gemma3"]


VALID_MOE_MODELS = [
"qwen3_moe",
"qwen3_vl_moe",
"qwen3_5_moe",
"qwen3_5_moe_text",
"bailing_moe_v2",
Expand Down
8 changes: 6 additions & 2 deletions areal/engine/megatron_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,11 @@
init_custom_process_group,
warmup_process_groups,
)
from areal.engine.core.model import disable_dropout_in_model, is_valid_vision_model
from areal.engine.core.model import (
disable_dropout_in_model,
is_valid_vision_model,
lang_config,
)
from areal.engine.megatron_utils.checkpointer import MegatronCheckpointManager
from areal.engine.megatron_utils.deterministic import set_deterministic_algorithms
from areal.engine.megatron_utils.fp8 import FP8BlockwiseTensorHelper
Expand Down Expand Up @@ -1463,7 +1467,7 @@ def _collect_param(
duplicated_param_names=self._duplicated_param_names,
gated_linear_unit=is_glu,
)
param = remove_padding(name, param, self.hf_config.vocab_size)
param = remove_padding(name, param, lang_config(self.hf_config).vocab_size)

if isinstance(param, FP8BlockwiseTensorHelper):
# FP8 is stored as uint8, so element_size is 1 byte
Expand Down
Loading
Loading