
[dev] [follow-up] Qwen3.5 support: MoE aux loss padding_mask #4776

Draft

wplf wants to merge 2 commits into NVIDIA:dev from wplf:fix/moe-padding-mask


Conversation


@wplf wplf commented May 13, 2026

Qwen3.5 support series

This is a follow-up to the 5-PR series adding Qwen3.5-VL support; it lands on top of #4751 (the example), not the core changes.

Dev PRs:

Main mirror: opened separately against main as a sibling.

Why

In examples/multimodal_dev/models/base.py, MultimodalModel.forward() was calling self.language_model(...) without a padding_mask argument. GPTModel.forward() documents padding_mask as "Only used for MoE layers to exclude padding tokens from routing computations." With padding_mask=None the router skips three masking sites that affect MoE numerics:

| Router site (megatron/core/transformer/moe/router.py) | Effect when padding_mask=None |
| --- | --- |
| apply_z_loss(logits, padding_mask=None) (line 526) | z-loss averages over all tokens, including padding |
| _apply_aux_loss(..., with_padding_mask=padding_mask is not None) (line 736) | aux load-balancing loss computed over all tokens |
| _apply_expert_bias(routing_map, padding_mask=None) (line 604) | expert-bias EMA accumulates padded-token routing |
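
For intuition, here is a minimal, self-contained sketch of what the aux-loss masking changes. It uses an illustrative Switch-style load-balancing formula rather than the actual router.py code, and all function and variable names below are assumptions:

```python
import torch

def aux_load_balancing_loss(router_probs, routing_map, padding_mask=None):
    # router_probs: [T, E] softmax over experts; routing_map: [T, E] bool top-k selection.
    # padding_mask: [T] bool, True at collate-padded positions (illustrative convention).
    if padding_mask is not None:
        keep = ~padding_mask
        router_probs, routing_map = router_probs[keep], routing_map[keep]
    load = routing_map.float().mean(dim=0)   # fraction of tokens routed to each expert
    importance = router_probs.mean(dim=0)    # mean router probability per expert
    return router_probs.shape[1] * torch.sum(load * importance)

torch.manual_seed(0)
T, E = 8, 4
probs = torch.randn(T, E).softmax(dim=-1)
routing = torch.zeros(T, E).scatter_(1, probs.topk(2, dim=-1).indices, 1.0).bool()
pad = torch.tensor([False] * 6 + [True] * 2)   # last two positions are collate padding

print(aux_load_balancing_loss(probs, routing))        # averages over padded slots too
print(aux_load_balancing_loss(probs, routing, pad))   # real tokens only
```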

Both data paths in pack_or_pad_batch (forward_step.py) introduce padding:

  • BSHD (lines 272-275): every sample padded to target_seqlens with input_ids=0, labels=-100, loss_mask=0.
  • THD packed (lines 224-227): each sample padded so its length is a multiple of divisible_by; cu_seqlens_q_padded differs from cu_seqlens_q.

For MoE variants (proxy, 35b_a3b, 35b_a3b_light, 122b_a10b, 397b_a17b) this means the load-balancing signal is diluted by padded positions whose router logits don't reflect any real token.

What this PR does

  1. Build padding_mask at collate time in both branches of pack_or_pad_batch (see the sketch after this list):
    • BSHD: [B, target_seqlens] bool, True past each sample's real length.
    • THD: [1, T] bool, True between cu_seqlens_padded[i] + real_len[i] and cu_seqlens_padded[i+1].
  2. Thread it through forward_step → MultimodalModel.forward → _cp_split_for_forward → language_model.forward, mirroring how loss_mask is handled. CP split uses the same BSHD zigzag / THD tex.thd_get_partitioned_indices index as the rest.
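
A minimal sketch of item 1's mask construction, assuming the shapes described above (real_lens and cu_seqlens_padded are illustrative names, not the actual pack_or_pad_batch variables):

```python
import torch

def build_bshd_padding_mask(real_lens, target_seqlen):
    # [B, target_seqlen] bool, True past each sample's real length.
    positions = torch.arange(target_seqlen)
    return positions.unsqueeze(0) >= torch.tensor(real_lens).unsqueeze(1)

def build_thd_padding_mask(real_lens, cu_seqlens_padded):
    # [1, T] bool, True between cu_seqlens_padded[i] + real_lens[i] and cu_seqlens_padded[i+1].
    mask = torch.zeros(1, cu_seqlens_padded[-1], dtype=torch.bool)
    for i, n in enumerate(real_lens):
        mask[0, cu_seqlens_padded[i] + n : cu_seqlens_padded[i + 1]] = True
    return mask

print(build_bshd_padding_mask([3, 5], target_seqlen=6))
print(build_thd_padding_mask([3, 5], cu_seqlens_padded=[0, 4, 12]))
```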

Why not derive from loss_mask or labels == -100

For SFT data, prompt tokens carry loss_mask=0 / label=-100 but are real tokens that should still participate in routing. Folding them into padding_mask would under-route real activations — a different bug. The collate-time mask only marks tokens added by the padder.
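
A hypothetical 8-position SFT sample makes the distinction concrete (values are illustrative):

```python
#                 prompt prompt prompt resp   resp   resp   pad   pad
loss_mask    =    [0,     0,     0,     1,     1,     1,     0,    0]     # supervise the response only
padding_mask =    [False, False, False, False, False, False, True, True]  # mark collate padding only
# Deriving padding_mask from loss_mask == 0 (or labels == -100) would also flag the three
# prompt tokens, excluding real activations from MoE routing statistics.
```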

Dependency

Depends on #4751 (PR-5: the Qwen3.5-VL example) — MultimodalModel and pack_or_pad_batch are introduced there. The diff vs dev therefore shows PR-5 + this fix; reviewers should compare against #4751's tip for the isolated padding_mask delta.

Risk

  • Dense variants unaffected — padding_mask is only consumed inside MoE layers.
  • _cp_split_for_forward and MultimodalModel.forward gain one optional kwarg with None default; existing callers unaffected.
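
A toy sketch of that optional-kwarg threading (class and function names are hypothetical stand-ins; only the plumbing pattern is the point):

```python
import torch

class ToyLanguageModel:
    def forward(self, input_ids, loss_mask=None, padding_mask=None):
        # Stand-in for GPTModel.forward: padding_mask would only be consulted by MoE routers.
        return input_ids.float()

class ToyMultimodalModel:
    def __init__(self):
        self.language_model = ToyLanguageModel()

    def _cp_split_for_forward(self, input_ids, padding_mask=None):
        # A real CP split would index padding_mask with the same BSHD zigzag / THD
        # tex.thd_get_partitioned_indices index as the other tensors (omitted here).
        return input_ids, padding_mask

    def forward(self, input_ids, loss_mask=None, padding_mask=None):
        input_ids, padding_mask = self._cp_split_for_forward(input_ids, padding_mask)
        return self.language_model.forward(
            input_ids, loss_mask=loss_mask, padding_mask=padding_mask
        )

model = ToyMultimodalModel()
ids = torch.arange(6).unsqueeze(0)
model.forward(ids)                                                    # existing callers: mask stays None
model.forward(ids, padding_mask=torch.zeros(1, 6, dtype=torch.bool))  # new collate-time mask
```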

🤖 Generated with Claude Code

Adds a standalone VLM training playground under
``examples/multimodal_dev/`` with Qwen3.5-VL end-to-end.

Highlights
- Model-agnostic entry point (``pretrain_multimodal.py``) with a
  ``MODEL_REGISTRY`` so adding a new architecture is just a registry
  entry plus a backing module.
- Qwen3.5-VL model: vision encoder, MRoPE, decoder, factory, specs,
  configurations covering proxy / 9B / 397B-A17B variants.
- Datasets: mock data and CORD-V2 VLM dataset, with THD pack/pad in the
  collate function.
- THD + CP support consolidated in ``forward_step.py`` and the model
  layer (uses MRoPE THD pre-computation and ``cu_seqlens_q_padded`` CP
  partitioning).
- Run script + README, plus tests for MRoPE parity, CP correctness, CP
  support, and THD correctness / e2e.

Also gates the torch DataLoader vanilla-collate path on the new
``use_vanilla_collate_fn`` arg (one-line change to
``megatron/training/datasets/data_samplers.py``) so CORD-V2 works under
BSHD.

Functional dependency: the new model arch sets ``mrope_interleaved=True``
in its config and relies on the core MRoPE interleaved layout introduced
in a separate PR.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@wplf wplf added the Run tests label May 13, 2026

copy-pr-bot Bot commented May 13, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

…ecoder

`MultimodalModel.forward()` was calling `self.language_model(...)`
without a `padding_mask` argument. With `padding_mask=None`, the
language decoder's MoE router skips three masking sites that document
themselves as MoE-only:

  - `apply_z_loss(logits, padding_mask=None)` — z-loss averages over
    all tokens including collate padding.
  - `_apply_aux_loss(..., with_padding_mask=padding_mask is not None)`
    — aux load-balancing loss is computed over all tokens, so padded
    positions dilute the signal and bias balancing toward whatever the
    model emits for input_id=0 at those slots.
  - `_apply_expert_bias(routing_map, padding_mask=None)` — expert-bias
    EMA accumulates routing statistics from padded tokens.

Both code paths in `pack_or_pad_batch` introduce padded positions:
BSHD pads each sample to `target_seqlens` with input_id=0 /
label=-100 / loss_mask=0; THD pads each sample's length to a multiple
of `divisible_by` so `cu_seqlens_q_padded` differs from `cu_seqlens_q`.

Fix:

  - Build `padding_mask` at collate time in both branches (BSHD: ``[B,
    target_seqlens]``; THD: ``[1, T]``). True marks collate-padded
    positions only — distinct from `loss_mask` so SFT prompt tokens
    (which carry `loss_mask=0` but are real tokens) still participate
    in routing.
  - Thread `padding_mask` through `forward_step` → `MultimodalModel.forward`
    → `_cp_split_for_forward` → `GPTModel.forward`, mirroring how
    `loss_mask` is handled. CP split uses the same BSHD zigzag /
    THD `tex.thd_get_partitioned_indices` index as the other tensors.

Dense Qwen3.5-VL variants are unaffected — `padding_mask` is only
consumed inside MoE layers.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Co-Authored-By: BestJuly <19769279+BestJuly@users.noreply.github.com>
@wplf wplf force-pushed the fix/moe-padding-mask branch from 5418b65 to af7d670 on May 13, 2026 at 10:24