Commit bd8b9ea
[megatron] Fix TE sliding-window attention crash at micro_batch>1 on the non-packed forward path (#1848)
# [megatron] Fix TE sliding-window attention crash at micro_batch>1 on
the non-packed forward path
## Summary
On the non-packed Megatron forward path
(`remove_microbatch_padding=False`),
`megatron_model_wrapper.py` passes the 2-D `[batch, seq]` keep-mask
returned by
`remove_left_padding()` straight to `model(...)`. For models whose
attention uses a Transformer
Engine **sliding-window** mask, TE's `get_full_mask` combines that mask
with the SWA mask via
`torch.logical_or(swa_mask, attention_mask)`. A 2-D `[batch, seq]` mask
is not broadcastable against
the `[..., seq, seq]` SWA mask, so at **micro_batch_size > 1** the batch
dimension collides with the
sequence dimension and the forward crashes.
This adds a pure-torch helper `to_te_attention_mask` that converts the
2-D keep-mask to a
`[batch, 1, 1, seq]` padding mask (`True` marks padding), and routes the
`model(...)` call sites
through it. The original 2-D mask is left untouched for the downstream
`recover_left_padding()`.
## Reproduction (current `main`, `0a18dd72`)
Real model **`Qwen/Qwen3.6-35B-A3B`** (hybrid GatedDeltaNet + MoE; its
full-attention layers use TE
sliding-window attention), `tp=2, pp=1`, `micro_batch_size=2`,
`remove_microbatch_padding=False`,
left-padded batch:
- **Stock `main` crashes** in the megatron forward:
```
transformer_engine/pytorch/attention/dot_product_attention/utils.py:1363,
in get_full_mask
attention_mask = torch.logical_or(swa_mask, attention_mask)
RuntimeError: The size of tensor a (24) must match the size of tensor b
(2) at non-singleton dimension 2
```
(`24` = sequence length, `2` = `micro_batch_size`.)
- **With this patch** the same forward completes and returns logprobs —
test passes.
- **Regression:** a standard non-SWA model (`Qwen/Qwen3-0.6B`,
non-packed, `micro_batch_size=2`,
left-padded) passes the HF-vs-megatron parity test **identically on
stock and patched** — the
helper is a no-op there.
Lifting `micro_batch_size` from 1 to 4 on this path (once unblocked)
increased training throughput
~3.3× at unchanged reward in our multi-LoRA RL runs.
## Changes
- Add `skyrl/backends/skyrl_train/distributed/megatron/mask_utils.py`
with `to_te_attention_mask`.
- Route both `model(...)` call sites in `megatron_model_wrapper.py` (4
invocations after #1841)
through the helper.
- Add `tests/backends/skyrl_train/distributed/test_mask_utils.py`.
## Design notes
- The reshape lives in a standalone pure-torch helper so it is
unit-testable on CPU without the
`megatron` extra.
- No-op on the packed path: there `new_attention_mask` is `None` and
`packed_seq_params` drives
masking, so the helper returns the input unchanged. An already
higher-rank mask is also returned
unchanged, so non-SWA TE attention (which already accepts the 2-D mask)
is unaffected.
- Not folded into `remove_left_padding()` because that function's 2-D
return is also consumed by
`recover_left_padding()`, which needs the 2-D form.
- `True = padding` matches TE's mask convention. GatedDeltaNet /
linear-attention layers ignore
`attention_mask`. No config or API change; behavior is identical at
`micro_batch_size = 1`.
## Testing
- Unit (CPU): `uv run --extra dev pytest
tests/backends/skyrl_train/distributed/test_mask_utils.py`
— 5 passed.
- Integration (GPU, run on 2× B300):
`test_megatron_forward[tp2_pp1_policy]` with
`Qwen/Qwen3.6-35B-A3B` — **fails on stock** at `get_full_mask`, **passes
with this patch**;
`Qwen/Qwen3-0.6B` passes on both (regression-safe).
## Checklist
- [x] Follows Google Python style; comments describe what the code does.
- [x] CPU unit test added; GPU reproduction + fix verified on
Qwen3.6-35B-A3B.
- [x] No paths/naming changes affecting `.claude/` docs; no
example-script or doc changes needed.
- [x] `bash format.sh` (ruff 0.11.9 + black 24.10.0 + gitleaks) clean —
verified, zero auto-fixes.
---------
Co-authored-by: Eric Tang <erictang000@gmail.com>1 parent 19b0e13 commit bd8b9ea
2 files changed
Lines changed: 20 additions & 4 deletions
File tree
- skyrl/backends/skyrl_train
- distributed/megatron
- workers/megatron
Lines changed: 15 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
658 | 658 | | |
659 | 659 | | |
660 | 660 | | |
| 661 | + | |
| 662 | + | |
| 663 | + | |
| 664 | + | |
| 665 | + | |
| 666 | + | |
| 667 | + | |
| 668 | + | |
| 669 | + | |
| 670 | + | |
| 671 | + | |
| 672 | + | |
| 673 | + | |
| 674 | + | |
| 675 | + | |
Lines changed: 5 additions & 4 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
16 | 16 | | |
17 | 17 | | |
18 | 18 | | |
| 19 | + | |
19 | 20 | | |
20 | 21 | | |
21 | 22 | | |
| |||
311 | 312 | | |
312 | 313 | | |
313 | 314 | | |
314 | | - | |
| 315 | + | |
315 | 316 | | |
316 | 317 | | |
317 | 318 | | |
| |||
321 | 322 | | |
322 | 323 | | |
323 | 324 | | |
324 | | - | |
| 325 | + | |
325 | 326 | | |
326 | 327 | | |
327 | 328 | | |
| |||
748 | 749 | | |
749 | 750 | | |
750 | 751 | | |
751 | | - | |
| 752 | + | |
752 | 753 | | |
753 | 754 | | |
754 | 755 | | |
| |||
758 | 759 | | |
759 | 760 | | |
760 | 761 | | |
761 | | - | |
| 762 | + | |
762 | 763 | | |
763 | 764 | | |
764 | 765 | | |
| |||
0 commit comments