feat(engine): add Qwen3-VL dense and MoE support to Megatron path #1301
Extend the Megatron engine to train Qwen3-VL dense models end-to-end: mcore→HF weight conversion for update_weights and HF→mcore loading that handles Qwen3-VL's nested HF config layout. Without this, GRPO/PPO of any Qwen3-VL model on the Megatron backend is blocked.

Changes
=======
- ``areal/engine/megatron_utils/megatron.py``: ``convert_qwen3_vl_to_hf`` anchored on ``mbridge.models.qwen3_vl.Qwen3VLBridge``, registered before ``"qwen3"`` in ``_CONVERSION_FN_REGISTRY``. Includes a defensive early-raise on Qwen3-VL-MoE expert/router param names so a future MoE arch routing here surfaces clearly.
- ``areal/engine/core/model.py``: ``lang_config`` helper consolidates the ``getattr(hf_config, "text_config", hf_config)`` accessor used by both ``hf_load.py`` and ``megatron_engine.py:_collect_param``. Add ``qwen3_vl`` to ``VALID_VISION_MODELS`` and ``is_qwen3_vl_model``.
- ``areal/models/mcore/hf_load.py``: route language-side config reads through ``lang_config`` so Qwen3-VL's nested ``text_config`` works alongside Qwen2.5-VL and pure text models.
- ``areal/engine/megatron_engine.py``: ``_collect_param`` reads ``text_config.vocab_size`` with the same fallback.

Tests
=====
- Split VLM tests into CPU-only (``tests/test_megatron_engine_vlm.py``) and distributed integration (``tests/test_megatron_engine_vlm_distributed.py``) to mirror the convention used for dense LLM tests.
- ``test_megatron_engine_vlm.py``: ``TestConvertQwen3VLToHF`` with fixture dims pinned to real ``Qwen/Qwen3-VL-2B-Instruct`` values; existing ``TestConvertQwen25VLToHF``/detection/remove-padding tests preserved.
- ``test_megatron_engine_vlm_distributed.py``: parametric ``init`` / ``simple_forward`` / ``hf_save_load_weights`` / ``train_tensor_parallel`` tests over qwen25_vl + qwen3_vl. Helper infers ``nproc`` from ``ModelAllocation.from_str(backend).parallel.world_size``.
- ``tests/torchrun/run_megatron_engine_vlm_distributed.py`` (renamed from ``run_megatron_engine_vlm.py``): ``mock_vlm_input`` reads patch geometry from ``engine.hf_config.vision_config`` so it works for both Qwen2.5-VL (patch=14) and Qwen3-VL (patch=16) without code-side branching.
- ``areal/utils/testing_utils.py``: register ``qwen2_5_vl`` and ``qwen3_vl`` in ``DENSE_MODEL_PATHS`` so the parametric test fixtures consume one source of truth.

Mapping anchored on mbridge.models.qwen3_vl.Qwen3VLBridge.

Notes
=====
- ``_vision_qkv_mcore_to_hf`` asserts ``num_kv_heads == num_heads`` (no GQA) for the vision tower. Both Qwen2.5-VL and Qwen3-VL satisfy this; the assertion catches future vision-GQA VLMs that would otherwise silently miscompile QKV.
- ``_CONVERSION_FN_REGISTRY`` dispatches via substring matching, so a later ``qwen3_vl_moe`` model_type would silently fall through to the dense converter unless registered before ``qwen3_vl``. The early-raise on MoE-shaped names makes that requirement actionable.
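The ordering constraint in the last note can be illustrated with a minimal sketch. It assumes the registry behaves like an ordered mapping that is scanned for the first key contained in the model type. The ``resolve`` helper and the converter-name strings are hypothetical stand-ins for illustration, not AReaL's actual code.

```python
# Hypothetical stand-in for _CONVERSION_FN_REGISTRY: an ordered mapping whose
# keys are matched as substrings of the model type; the first hit wins.
def resolve(registry: dict, model_type: str) -> str:
    for key, converter in registry.items():
        if key in model_type:
            return converter
    raise ValueError(f"no converter registered for {model_type!r}")


# Converter names below are placeholder strings, not real function objects.
# Specific key first: "qwen3_vl" is tried before the generic "qwen3".
good_order = {"qwen3_vl": "vl_dense_converter", "qwen3": "text_converter"}
# Generic key first: every Qwen3-VL model type already contains "qwen3".
bad_order = {"qwen3": "text_converter", "qwen3_vl": "vl_dense_converter"}

print(resolve(good_order, "qwen3_vl"))      # vl_dense_converter
print(resolve(bad_order, "qwen3_vl"))       # text_converter (wrong converter)
# Without a dedicated "qwen3_vl_moe" key, the MoE model type falls through
# to the dense VL converter, which is the case the early-raise guards.
print(resolve(good_order, "qwen3_vl_moe"))  # vl_dense_converter
```

Under this assumption, any more specific key has to be inserted ahead of every key it contains as a substring.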
Extend the Megatron engine to train Qwen3-VL-MoE end-to-end (verified on ``Qwen/Qwen3-VL-30B-A3B-Instruct``). Without this, RL of any Qwen3-VL-MoE checkpoint on the Megatron backend is blocked at engine init: the dense converter raises on MoE-shaped param names, the HF loader's per-expert slicer assumes 2D HF tensors, and the HF saver's per-rank emission buffer never fills under EP>1.

Conversion (mcore↔HF)
=====================
- ``areal/engine/megatron_utils/megatron.py``: factor the dense ``convert_qwen3_vl_to_hf`` body into ``_convert_qwen3_vl_lm_global``, ``_convert_qwen3_vl_lm_attention``, and ``_convert_qwen3_vl_vision_to_hf`` helpers. Add ``convert_qwen3_vl_moe_to_hf`` mirroring ``convert_qwen3moe_to_hf``'s per-expert flat HF emission so the XCCL ``update_weights`` path to vLLM/SGLang reuses the proven shape contract. Insert ``qwen3_vl_moe`` into ``_CONVERSION_FN_REGISTRY`` BEFORE ``qwen3_vl``, ``qwen3_moe``, and ``qwen3`` (substring-match dispatch). Drop the dead ``NotImplementedError`` guard in the dense converter.
- ``areal/engine/core/model.py``: register ``qwen3_vl_moe`` in ``VALID_VISION_MODELS`` and ``VALID_MOE_MODELS``; add ``is_qwen3_vl_moe_model``; broaden ``is_qwen3_vl_model`` to the family so ``fsdp_engine``, ``fsdp_utils/parallel``, and ``awex/fsdp_adapter`` call sites cover both dense and MoE without duplication.

HF→mcore loader
===============
``areal/models/mcore/hf_load.py``: add ``_slice_moe_expert_fc1_stacked_gate_up`` and ``_slice_moe_expert_fc2_stacked_down`` helpers that unpack 3D stacked HF expert tensors (``[E, hidden, 2*expert_dim]`` for ``gate_up_proj``, ``[E, expert_dim, hidden]`` for ``down_proj``) into mcore's per-expert 2D layout. Dispatch from ``_weight_to_mcore_tp`` when shape rank indicates the stacked HF format.
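As a rough illustration of the stacked-expert unpacking described in the loader section, the sketch below uses nested Python lists in place of tensors. The helper name ``slice_stacked_expert`` is hypothetical, and two points are assumptions rather than facts taken from the PR: that the per-expert target layout is ``[out, in]`` (hence the transpose), and that the gate half precedes the up half on the fused axis. Tensor-parallel slicing is left out entirely.

```python
def slice_stacked_expert(stacked, expert_idx):
    """Select one expert's [in, out] matrix from an [E, in, out] stack and
    return it transposed to [out, in] (assumed per-expert target layout)."""
    per_expert = stacked[expert_idx]
    return [list(column) for column in zip(*per_expert)]


# Toy stacked gate_up_proj with E=2 experts, hidden=2, 2*expert_dim=4.
gate_up = [
    [[1, 2, 3, 4], [5, 6, 7, 8]],
    [[9, 10, 11, 12], [13, 14, 15, 16]],
]

fc1_expert1 = slice_stacked_expert(gate_up, 1)  # shape [4, 2]
# Assumption: first half of the fused axis is gate, second half is up.
gate, up = fc1_expert1[:2], fc1_expert1[2:]

print(fc1_expert1)  # [[9, 13], [10, 14], [11, 15], [12, 16]]
print(gate, up)     # [[9, 13], [10, 14]] [[11, 15], [12, 16]]
```

With real tensors the same selection would be an index on the leading expert axis followed by a transpose of the remaining two axes.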
HF saver under EP>1
===================
``areal/models/mcore/hf_save.py``: add ``_bridge_uses_stacked_experts`` predicate (inferred from the bridge's ``_MLP_MAPPING`` per-expert template) plus a stacked-MoE branch in ``save_weights_to_hf_with_mbridge_fast`` that EP-all-gathers expert tensors per ``(layer, fc)`` group, feeds mbridge in global expert-index order so its internal buffer fills exactly once, consolidates writes on ``ep_rank==0``, and offloads stacked outputs to CPU memory immediately to bound peak GPU footprint by one stacked tensor at a time. Adjust shard-count math to drop the ``* ep_size`` multiplication for stacked emission. Per-expert-flat bridges (Qwen3-MoE, BailingMoeV2, DeepSeekV3) keep the existing fast path unchanged.

Tests
=====
- ``tests/test_megatron_engine_vlm.py``: add ``TestConvertQwen3VLMoEToHF`` covering registry-ordering invariant, expert fc1 chunk, fc2 passthrough, router rename, ``pre_mlp_layernorm`` rename, dense MLP fallback (for future ``decoder_sparse_step > 1`` variants), shared vision tower delegation, and unknown-name error paths. Extend ``TestVisionModelDetection`` for the new family.
- ``tests/test_megatron_engine_vlm_distributed.py``: add ``test_qwen3vl_moe_expert_parallel`` (forward smoke under ``(attn:d2t4|ffn:d2e4)``) and ``test_qwen3vl_moe_dcp_save_load`` (parameter round-trip under ``(attn:d2p1t4|ffn:d1p1t2e4)``); add ``qwen3_vl_moe`` to ``_VLM_MODELS`` (skipped < 8 GPUs).
- ``tests/torchrun/run_megatron_engine_vlm_distributed.py``: add ``dcp_save_load`` test type and ``wrap_with_ddp`` parameter to ``make_vlm_engine`` (forward-only paths skip the ~2× model-bytes grad-buffer alloc to clear NCCL-coordination headroom on 30B+ MoE).
- ``areal/utils/testing_utils.py``: register ``qwen3_vl_moe`` in ``MOE_MODEL_PATHS``.

Validation
==========
End-to-end on 8×H100 (``Qwen/Qwen3-VL-30B-A3B-Instruct``):
- 84 CPU unit tests pass (incl. dense regression).
- Dense Qwen3-VL-2B compare TP=1 / TP=2: 625 params zero diff.
- Qwen3-VL-MoE compare under ``(attn:d2t4|ffn:d2e4)`` (mbridge round-trip via the new EP-gather save path): 5394 params zero diff.
- ``test_qwen3vl_moe_expert_parallel``: PASS.
- ``test_qwen3vl_moe_dcp_save_load``: PASS.

Notes
=====
- Use ``bridge_type='mbridge'`` (default). megatron-bridge ≤ 0.3.0 has a deepstack gradient-checkpoint bug (separate upstream issue).
- mbridge 0.15.1+ already registers ``Qwen3VLMoEBridge`` for the ``qwen3_vl_moe`` model_type; AReaL just needs to feed it correctly under EP>1, which is what this change does.
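The parallel-strategy strings used by the MoE tests above can be sanity-checked against the 8-GPU requirement with a small sketch. This is a hypothetical parser, not ``ModelAllocation.from_str``; it assumes each section's world size is simply the product of its per-dimension sizes, and that the letters denote data, pipeline, tensor, and expert parallel degrees.

```python
import math
import re


def section_world_sizes(spec: str) -> dict:
    """Return {section name: product of its dimension sizes} for a spec such
    as "(attn:d2t4|ffn:d2e4)". Hypothetical helper for illustration only."""
    sizes = {}
    for section in spec.strip("()").split("|"):
        name, dims = section.split(":")
        factors = [int(n) for _, n in re.findall(r"([a-z])(\d+)", dims)]
        sizes[name] = math.prod(factors)
    return sizes


print(section_world_sizes("(attn:d2t4|ffn:d2e4)"))        # {'attn': 8, 'ffn': 8}
print(section_world_sizes("(attn:d2p1t4|ffn:d1p1t2e4)"))  # {'attn': 8, 'ffn': 8}
```

Under that assumption both layouts resolve to 8 ranks for the attention and FFN sections alike, which matches the "skipped < 8 GPUs" condition.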
Code Review
This pull request adds support for Qwen3-VL and Qwen3-VL-MoE models, including new weight conversion logic between Megatron and HuggingFace formats. It introduces specialized handling for 'stacked experts' in MoE models, ensuring correct weight slicing during loading and EP-gathering during saving. The test suite has been refactored to separate CPU-only unit tests from distributed GPU integration tests. Feedback highlights a logic error in hf_save.py where the calculation of rank-local expert indices incorrectly subtracts an offset from an already local index, potentially resulting in negative values.
    layer_idx, fc_kind, local_idx_str = m.groups()
    # Recover the rank-local idx (0..num_experts_per_rank-1) from
    # the global idx baked into global_name above.
    local_id = int(local_idx_str) - num_experts_per_rank * ep_rank
The calculation of local_id appears to be logically incorrect based on the source of local_idx_str.
local_idx_str is extracted from s.local_name, which comes directly from model.named_parameters(). In Megatron Core (specifically for TEGroupedLinear), these names use rank-local indices (e.g., weight0, weight1, ... up to num_experts_per_rank - 1).
Subtracting num_experts_per_rank * ep_rank from an already rank-local index will result in negative values for any rank where ep_rank > 0. While the subsequent sort might still maintain the correct relative order (since the shift is constant per rank), it makes the code confusing and contradicts the comment on line 602. If the intent is to sort by the rank-local index, it should just be int(local_idx_str).
    -local_id = int(local_idx_str) - num_experts_per_rank * ep_rank
    +local_id = int(local_idx_str)
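A small worked example makes the reviewer's point concrete. The values (4 experts per rank, ``ep_rank = 1``, an arbitrary suffix order) are made up for illustration; only the arithmetic mirrors the line under review.

```python
num_experts_per_rank = 4
ep_rank = 1
# Rank-local suffixes as parsed from weight0..weight3, in arbitrary order.
local_suffixes = ["2", "0", "3", "1"]

# Original expression: subtracts the rank offset from an already local index.
shifted = [int(s) - num_experts_per_rank * ep_rank for s in local_suffixes]
# Suggested expression: use the rank-local index directly.
plain = [int(s) for s in local_suffixes]

# Positions of the entries once sorted by each key.
order_shifted = sorted(range(len(shifted)), key=shifted.__getitem__)
order_plain = sorted(range(len(plain)), key=plain.__getitem__)

print(shifted)                       # [-2, -4, -1, -3]
print(order_shifted == order_plain)  # True: a constant shift keeps the order
```

The sort result is unchanged, so the subtraction is harmless for ordering but yields negative keys on every rank with ``ep_rank > 0``.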
Good call, addressed in follow-up commit
Two corrections in the stacked-MoE branch of
``save_weights_to_hf_with_mbridge_fast``:
1. Drop the bogus ``- num_experts_per_rank * ep_rank`` subtraction when
parsing the per-expert sort key. ``local_name`` already carries the
rank-local TEGroupedLinear suffix (``weight0..weight{num_experts_per_rank-1}``);
subtracting the rank offset produced negative ``local_id`` values on
``ep_rank > 0``. Sort still happened to produce the right relative
order (offset is constant per rank), but the negative values were
confusing and the explanatory comment ("Recover the rank-local idx
from the global idx baked into global_name above") was wrong — the
parse target is ``local_name``, not ``global_name``.
Reported by gemini-code-assist on PR areal-project#1301.
2. Add a one-time invariant check across the EP group: every rank must
see the same set of ``(layer, fc)`` groups with the same per-group
expert count before driving the ``all_gather_into_tensor`` collective.
Without it, a future refactor that injects a vision-tower expert
spec on some ranks only — or a VPP layout where local layers differ
across EP ranks — would silently mismatch the collective and either
hang or scramble tensors across layers. One cheap
``all_gather_object`` eliminates that class of
silent-corruption-on-future-refactor failures.
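The invariant in item 2 can be sketched without a process group. In a real run the per-rank list would presumably be collected with ``torch.distributed.all_gather_object``; here it is passed in directly so the example stays runnable on a single process. The function name and the dictionary layout are hypothetical.

```python
def check_ep_group_invariant(per_rank_groups):
    """per_rank_groups[i] maps (layer, fc) -> expert count as seen by EP rank i.
    Raises if any rank disagrees with rank 0 on the groups or their counts."""
    reference = per_rank_groups[0]
    for rank, groups in enumerate(per_rank_groups[1:], start=1):
        if groups != reference:
            raise RuntimeError(
                f"EP rank {rank} sees expert groups {sorted(groups.items())}, "
                f"but rank 0 sees {sorted(reference.items())}"
            )


consistent = [
    {(0, "fc1"): 4, (0, "fc2"): 4},
    {(0, "fc1"): 4, (0, "fc2"): 4},
]
check_ep_group_invariant(consistent)  # passes silently

mismatched = [
    {(0, "fc1"): 4, (0, "fc2"): 4},
    {(0, "fc1"): 4},  # this rank is missing the fc2 group
]
try:
    check_ep_group_invariant(mismatched)
except RuntimeError as exc:
    print("caught:", exc)
```

Failing before the tensor collective turns a potential hang or cross-layer scramble into an immediate, readable error.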
garrett4wade left a comment
Looks goooood to me except for a minor coding style issue.
…to helpers

Pulls the two ``if stacked_experts and ep_size > 1: ... else: ...`` branches in ``save_weights_to_hf_with_mbridge_fast`` out into module-level helpers:

- ``_emit_stacked_moe_expert_sd``: stacked-MoE save (Qwen3-VL-MoE). Owns the (layer, fc) grouping, EP-invariant assertion, EP all-gather, global-index-ordered mbridge feed, and ``ep_rank==0`` CPU offload.
- ``_emit_per_expert_flat_expert_sd``: per-expert-flat save (Qwen3-MoE, BailingMoeV2, DeepSeekV3). Each EP rank emits its local-experts subset independently.

The caller's expert section now reduces to a 20-line dispatch. No behavioural change: pure extract-method. ETP gather / shard-count math / shard-write / weight_map collective stay in the caller since they are pre-/post-branch and shared across both paths.

Addresses review feedback on PR areal-project#1301.
After main split the integration tests into ``tests/test_megatron_engine_vlm_distributed.py`` and renamed the torchrun script to ``run_megatron_engine_vlm_distributed.py`` (areal-project#1301), re-apply the same NPU compatibility shims as the previous combined-file version:

- ``try: import mindspeed.megatron_adaptor`` at the top of both files so MindSpeed's adapters land before mbridge transitive imports on NPU.
- Replace ``CUDA_AVAILABLE`` with ``ACCELERATOR_AVAILABLE`` driven by ``current_platform.device_type in ("cuda", "npu")`` and update the five ``@pytest.mark.skipif`` decorators / two ``device_count()`` checks.
- ``current_platform.synchronize()`` instead of ``torch.cuda.synchronize()`` in the torchrun cleanup path.

The model path is already centralized in ``areal/utils/testing_utils.py::DENSE_MODEL_PATHS`` (env override via ``VLM_MODEL_PATH``), so no additional path resolution is needed here.
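The two shim patterns named in this commit message can be sketched in isolation. ``optional_import`` and ``accelerator_available`` are hypothetical helpers written for this example; the real test files use a bare ``try: import`` and AReaL's ``current_platform`` object, which is replaced here by a plain string argument.

```python
import importlib


def optional_import(name):
    """Import a module if it is installed and return None otherwise, so the
    same file loads on hosts that lack the optional package."""
    try:
        return importlib.import_module(name)
    except ImportError:
        return None


def accelerator_available(device_type: str) -> bool:
    """Stand-in for the ACCELERATOR_AVAILABLE flag: true for CUDA and NPU."""
    return device_type in ("cuda", "npu")


# On a host without MindSpeed this returns None instead of raising.
adaptor = optional_import("mindspeed.megatron_adaptor")

print(accelerator_available("npu"))  # True
print(accelerator_available("cpu"))  # False
```

Placing the optional import at the top of the file preserves the ordering requirement from the commit message: the adaptor loads before anything that imports mbridge.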
Description
Adds Qwen3-VL dense and MoE support to AReaL's Megatron engine via mbridge, so GRPO/PPO of any Qwen3-VL or Qwen3-VL-MoE checkpoint on the Megatron backend is unblocked. Resubmission of #1299 (closed) extended with Qwen3-VL-MoE in a second commit.
Commits
- b2057e3: feat(engine): add Qwen3-VL dense support to Megatron path (squash of the original #1299 + reorganized test layout)
- 4434fea: feat(engine): add Qwen3-VL-MoE support to Megatron path (new)

Commit 1: Qwen3-VL dense
- ``areal/engine/megatron_utils/megatron.py``: new ``convert_qwen3_vl_to_hf`` registered before ``"qwen3"`` in ``_CONVERSION_FN_REGISTRY`` (mapping anchored on ``mbridge.models.qwen3_vl.Qwen3VLBridge``). Carries an early-raise on Qwen3-VL-MoE param names so ``qwen3_vl_moe`` cannot silently dispatch to the dense converter via substring match. ``_vision_qkv_mcore_to_hf`` gains a no-GQA assertion that guards future vision-GQA VLMs.
- ``areal/engine/core/model.py``: new ``lang_config(hf_config)`` helper, i.e. ``getattr(hf_config, "text_config", hf_config)``, so callers can read language-side attrs (``vocab_size``, ``num_attention_heads``, ``num_key_value_heads``, ``hidden_size``, ``head_dim``) uniformly across Qwen3-VL (nested) and Qwen2.5-VL / pure text (flat).
- ``areal/models/mcore/hf_load.py``: ``_merge_qkv_weights``, ``_load_fused_qkv_weight``, and the GQA branch of ``_weight_to_mcore_tp`` use the shared ``lang_config`` helper.
- ``areal/engine/megatron_engine.py``: ``_collect_param`` uses ``lang_config(self.hf_config).vocab_size`` for ``remove_padding``.
- Split ``tests/test_megatron_engine_vlm.py`` into CPU-only unit tests + a new ``tests/test_megatron_engine_vlm_distributed.py`` for torchrun-launched integration tests, mirroring the existing ``test_megatron_engine_distributed.py`` convention. Renamed ``tests/torchrun/run_megatron_engine_vlm.py`` → ``run_megatron_engine_vlm_distributed.py``. Helper ``_run_vlm_test`` infers ``nproc`` from ``ModelAllocation.from_str(backend).parallel.world_size``.
- ``areal/utils/testing_utils.py``: register ``qwen2_5_vl`` and ``qwen3_vl`` in ``DENSE_MODEL_PATHS`` so parametric test fixtures consume one source of truth.

Commit 2: Qwen3-VL-MoE
End-to-end on ``Qwen/Qwen3-VL-30B-A3B-Instruct``. Without this, RL is blocked at engine init: the dense converter raises on MoE-shaped param names, the HF loader's per-expert slicer assumes 2D HF tensors, and the HF saver's per-rank emission buffer never fills under EP>1.

Conversion (mcore↔HF)
- ``areal/engine/megatron_utils/megatron.py``: factor the dense ``convert_qwen3_vl_to_hf`` body into ``_convert_qwen3_vl_lm_global``, ``_convert_qwen3_vl_lm_attention``, and ``_convert_qwen3_vl_vision_to_hf`` helpers. Add ``convert_qwen3_vl_moe_to_hf`` mirroring ``convert_qwen3moe_to_hf``'s per-expert flat HF emission so the XCCL ``update_weights`` path to vLLM/SGLang reuses the proven shape contract. Insert ``qwen3_vl_moe`` into ``_CONVERSION_FN_REGISTRY`` before ``qwen3_vl``, ``qwen3_moe``, and ``qwen3`` (substring-match dispatch). Drop the dead ``NotImplementedError`` guard in the dense converter.
- ``areal/engine/core/model.py``: register ``qwen3_vl_moe`` in ``VALID_VISION_MODELS`` and ``VALID_MOE_MODELS``; add ``is_qwen3_vl_moe_model``; broaden ``is_qwen3_vl_model`` to the family so ``fsdp_engine``, ``fsdp_utils/parallel``, and ``awex/fsdp_adapter`` call sites cover both dense and MoE without duplication.

HF→mcore loader
- ``areal/models/mcore/hf_load.py``: add ``_slice_moe_expert_fc1_stacked_gate_up`` and ``_slice_moe_expert_fc2_stacked_down`` that unpack 3D stacked HF expert tensors (``[E, hidden, 2*expert_dim]`` for ``gate_up_proj``, ``[E, expert_dim, hidden]`` for ``down_proj``) into mcore's per-expert 2D layout. Dispatch from ``_weight_to_mcore_tp`` when shape rank indicates the stacked HF format.

HF saver under EP>1
- ``areal/models/mcore/hf_save.py``: add ``_bridge_uses_stacked_experts`` predicate (inferred from the bridge's ``_MLP_MAPPING`` per-expert template) plus a stacked-MoE branch in ``save_weights_to_hf_with_mbridge_fast`` that EP-all-gathers expert tensors per ``(layer, fc)`` group, feeds mbridge in global expert-index order so its internal buffer fills exactly once, consolidates writes on ``ep_rank==0``, and offloads stacked outputs to CPU memory immediately to bound peak GPU footprint by one stacked tensor at a time. Adjust shard-count math to drop the ``* ep_size`` multiplication for stacked emission. Per-expert-flat bridges (Qwen3-MoE, BailingMoeV2, DeepSeekV3) keep the existing fast path unchanged.
- ``areal/utils/testing_utils.py``: register ``qwen3_vl_moe`` in ``MOE_MODEL_PATHS``.

Tests added
CPU unit tests (``tests/test_megatron_engine_vlm.py``):
- ``TestConvertQwen3VLToHF``: registry dispatch (``qwen3_vl`` resolves before ``qwen3`` substring fallback), language model embedding/final-norm/output-layer prefixes, QKV weight + bias split with GQA, ``q_norm``/``k_norm``/``o_proj``/``input_layernorm``/``post_attention_layernorm``, gated MLP fc1 split + fc2; vision direct mappings, vision per-block (QKV per-head→grouped reorder, attn proj, norm1/norm2 weight + bias, non-gated MLP regression guard, fc2); deepstack mergers parametrized over indices ``[0, 1, 2]``; error handling.
- ``TestConvertQwen3VLMoEToHF``: registry-ordering invariant, expert fc1 chunk, fc2 passthrough, router rename, ``pre_mlp_layernorm`` rename, dense MLP fallback (for future ``decoder_sparse_step > 1`` variants), shared vision tower delegation, unknown-name error paths.
- ``TestVisionModelDetection``: extended for the qwen3-vl-moe family.

Distributed integration tests (``tests/test_megatron_engine_vlm_distributed.py``):
- ``test_engine_initializes`` / ``test_simple_forward`` / ``test_hf_save_load_weights`` / ``test_train_tensor_parallel`` over ``_VLM_MODELS``; entries ``qwen25_vl``, ``qwen3_vl``, ``qwen3_vl_moe``.
- ``test_qwen3vl_moe_expert_parallel``: forward smoke under ``(attn:d2t4|ffn:d2e4)``.
- ``test_qwen3vl_moe_dcp_save_load``: parameter round-trip under ``(attn:d2p1t4|ffn:d1p1t2e4)``.

Validation
End-to-end on ``Qwen/Qwen3-VL-30B-A3B-Instruct``:
- ``test_qwen3vl_moe_expert_parallel``: PASS.
- ``test_qwen3vl_moe_dcp_save_load``: PASS.

Scope
- ``bridge_type: mbridge`` only. The ``bridge_type: megatron-bridge`` path with Qwen3-VL + ``gradient_checkpointing: true`` crashes inside ``Qwen3VLTransformerBlock._checkpointed_forward`` with ``TypeError: save_for_backward can only save variables, but argument 6 is of type list``, because ``deepstack_visual_embeds`` is handed verbatim to ``tensor_parallel.checkpoint``. Fixed upstream in megatron-bridge v0.4.0; this PR intentionally does NOT vendor-patch the megatron-bridge path, so it lights up automatically when the dependency upgrade lands, as #1206 (chore(deps): upgrade runtime dependencies and CI workflow) plans to do.
- Context parallelism (``context_parallel_size > 1``) continues to raise ``NotImplementedError``, which matches the Qwen2.5-VL state.
- mbridge 0.15.1+ already registers ``Qwen3VLMoEBridge`` for the ``qwen3_vl_moe`` model_type; AReaL just needs to feed it correctly under EP>1, which is what this PR does.
- ``qwen3_vl_moe`` LoRA path; ``pixel_values_videos`` plumbing for streaming video inputs.

Related Issue
N/A — net-new feature support. Resubmission of #1299 (closed) extended with MoE.
Breaking Change Details (if applicable):
N/A.
Additional Context
Fix for #1298 is required to run integration tests or actual training. This PR was tested with a local patch which is not committed.
Training Reward Example
Image: ``ghcr.io/inclusionai/areal-runtime:v1.0.3-vllm`` with mbridge upgraded according to #1258
Dataset: Geometry3k
Model: Qwen3-VL-3B-Instruct / Qwen3-VL-32B-Instruct / Qwen3-VL-30B-A3B-Instruct
Scheduler: Slurm