Skip to content

Commit bd8b9ea

Browse files
[megatron] Fix TE sliding-window attention crash at micro_batch>1 on the non-packed forward path (#1848)
# [megatron] Fix TE sliding-window attention crash at micro_batch>1 on the non-packed forward path ## Summary On the non-packed Megatron forward path (`remove_microbatch_padding=False`), `megatron_model_wrapper.py` passes the 2-D `[batch, seq]` keep-mask returned by `remove_left_padding()` straight to `model(...)`. For models whose attention uses a Transformer Engine **sliding-window** mask, TE's `get_full_mask` combines that mask with the SWA mask via `torch.logical_or(swa_mask, attention_mask)`. A 2-D `[batch, seq]` mask is not broadcastable against the `[..., seq, seq]` SWA mask, so at **micro_batch_size > 1** the batch dimension collides with the sequence dimension and the forward crashes. This adds a pure-torch helper `to_te_attention_mask` that converts the 2-D keep-mask to a `[batch, 1, 1, seq]` padding mask (`True` marks padding), and routes the `model(...)` call sites through it. The original 2-D mask is left untouched for the downstream `recover_left_padding()`. ## Reproduction (current `main`, `0a18dd72`) Real model **`Qwen/Qwen3.6-35B-A3B`** (hybrid GatedDeltaNet + MoE; its full-attention layers use TE sliding-window attention), `tp=2, pp=1`, `micro_batch_size=2`, `remove_microbatch_padding=False`, left-padded batch: - **Stock `main` crashes** in the megatron forward: ``` transformer_engine/pytorch/attention/dot_product_attention/utils.py:1363, in get_full_mask attention_mask = torch.logical_or(swa_mask, attention_mask) RuntimeError: The size of tensor a (24) must match the size of tensor b (2) at non-singleton dimension 2 ``` (`24` = sequence length, `2` = `micro_batch_size`.) - **With this patch** the same forward completes and returns logprobs — test passes. - **Regression:** a standard non-SWA model (`Qwen/Qwen3-0.6B`, non-packed, `micro_batch_size=2`, left-padded) passes the HF-vs-megatron parity test **identically on stock and patched** — the helper is a no-op there. Lifting `micro_batch_size` from 1 to 4 on this path (once unblocked) increased training throughput ~3.3× at unchanged reward in our multi-LoRA RL runs. ## Changes - Add `skyrl/backends/skyrl_train/distributed/megatron/mask_utils.py` with `to_te_attention_mask`. - Route both `model(...)` call sites in `megatron_model_wrapper.py` (4 invocations after #1841) through the helper. - Add `tests/backends/skyrl_train/distributed/test_mask_utils.py`. ## Design notes - The reshape lives in a standalone pure-torch helper so it is unit-testable on CPU without the `megatron` extra. - No-op on the packed path: there `new_attention_mask` is `None` and `packed_seq_params` drives masking, so the helper returns the input unchanged. An already higher-rank mask is also returned unchanged, so non-SWA TE attention (which already accepts the 2-D mask) is unaffected. - Not folded into `remove_left_padding()` because that function's 2-D return is also consumed by `recover_left_padding()`, which needs the 2-D form. - `True = padding` matches TE's mask convention. GatedDeltaNet / linear-attention layers ignore `attention_mask`. No config or API change; behavior is identical at `micro_batch_size = 1`. ## Testing - Unit (CPU): `uv run --extra dev pytest tests/backends/skyrl_train/distributed/test_mask_utils.py` — 5 passed. - Integration (GPU, run on 2× B300): `test_megatron_forward[tp2_pp1_policy]` with `Qwen/Qwen3.6-35B-A3B` — **fails on stock** at `get_full_mask`, **passes with this patch**; `Qwen/Qwen3-0.6B` passes on both (regression-safe). ## Checklist - [x] Follows Google Python style; comments describe what the code does. - [x] CPU unit test added; GPU reproduction + fix verified on Qwen3.6-35B-A3B. - [x] No paths/naming changes affecting `.claude/` docs; no example-script or doc changes needed. - [x] `bash format.sh` (ruff 0.11.9 + black 24.10.0 + gitleaks) clean — verified, zero auto-fixes. --------- Co-authored-by: Eric Tang <erictang000@gmail.com>
1 parent 19b0e13 commit bd8b9ea

2 files changed

Lines changed: 20 additions & 4 deletions

File tree

skyrl/backends/skyrl_train/distributed/megatron/megatron_utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -658,3 +658,18 @@ def broadcast_object_across_pp_ranks(obj):
658658
torch.distributed.broadcast_object_list(obj_list, src=global_src, group=pp_group)
659659

660660
return obj_list[0]
661+
662+
663+
def to_te_attention_mask(attention_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
664+
"""Convert a 2-D keep-mask to a 4-D padding mask for Transformer Engine attention.
665+
666+
``remove_left_padding`` returns a 2-D ``[batch, seq]`` keep-mask where 1 marks a real token.
667+
Transformer Engine's sliding-window ``get_full_mask`` expects a mask broadcastable to
668+
``[batch, 1, q_seq, kv_seq]``; a 2-D mask collides the batch dimension with the sequence
669+
dimension and fails for ``micro_batch_size > 1``. Return a ``[batch, 1, 1, seq]`` padding mask
670+
where True marks padding. A ``None`` mask (packed sequences) or an already higher-rank mask is
671+
returned unchanged.
672+
"""
673+
if attention_mask is None or attention_mask.dim() != 2:
674+
return attention_mask
675+
return (~attention_mask.bool())[:, None, None, :]

skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
preprocess_packed_seqs,
1717
recover_left_padding,
1818
remove_left_padding,
19+
to_te_attention_mask,
1920
)
2021
from skyrl.backends.skyrl_train.distributed.megatron.model_utils import (
2122
_fused_vocab_parallel_entropy_from_hidden,
@@ -311,7 +312,7 @@ def forward_step(batch_iter, model):
311312
outputs = model(
312313
new_sequences,
313314
new_position_ids,
314-
new_attention_mask,
315+
to_te_attention_mask(new_attention_mask),
315316
packed_seq_params=packed_seq_params,
316317
output_processor=_fused_lm_head_output_processor,
317318
output_processor_context=_op_ctx,
@@ -321,7 +322,7 @@ def forward_step(batch_iter, model):
321322
outputs = model(
322323
new_sequences,
323324
new_position_ids,
324-
new_attention_mask,
325+
to_te_attention_mask(new_attention_mask),
325326
packed_seq_params=packed_seq_params,
326327
)
327328

@@ -748,7 +749,7 @@ def forward_step(batch_iter, model):
748749
outputs = model(
749750
new_sequences,
750751
new_position_ids,
751-
new_attention_mask,
752+
to_te_attention_mask(new_attention_mask),
752753
packed_seq_params=packed_seq_params,
753754
output_processor=_fused_lm_head_output_processor,
754755
output_processor_context=_op_ctx,
@@ -758,7 +759,7 @@ def forward_step(batch_iter, model):
758759
outputs = model(
759760
new_sequences,
760761
new_position_ids,
761-
new_attention_mask,
762+
to_te_attention_mask(new_attention_mask),
762763
packed_seq_params=packed_seq_params,
763764
)
764765

0 commit comments

Comments
 (0)