Commit 2b8763e
Pad MoE expert input to multiple of 32 for MXFP8 compatibility
After all-to-all dispatch in the MoE block, the per-rank token count is
data-dependent (routing decisions produce different expert loads per step).
MXFP8 requires tensor dims divisible by 32, and FP8 requires product of
non-last dims divisible by 8 - these assertions fire on the post-dispatch
expert_input when the batch token count happens to land on an unaligned
value, causing training to hang or crash.
Pad the token dimension to the next multiple of 32 before GroupedLinear,
attribute the padding to the last expert so m_splits still sums correctly,
then slice the padding off the output. Branch is a no-op for non-MXFP8
runs and when the count is already aligned.
Upstream attention layers get alignment via the collator's
pad_sequences_to_be_divisible_by config; this patch only addresses the
MoE block where alltoall creates a second source of misalignment.
Verified on 8x B300 SXM6 with Mixtral-8x7B EP=8 at SEQ=8192:
FP8 1.196 s/step, MXFP8 1.248 s/step (previously hung/crashed).
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Timur Rvachov <trvachov@nvidia.com>1 parent 39faaca commit 2b8763e
3 files changed
Lines changed: 66 additions & 3 deletions
File tree
- bionemo-recipes
- models/mixtral
- recipes
- mixtral_native_te
- opengenome2_mixtral_native_te
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
409 | 409 | | |
410 | 410 | | |
411 | 411 | | |
412 | | - | |
| 412 | + | |
| 413 | + | |
| 414 | + | |
| 415 | + | |
| 416 | + | |
| 417 | + | |
| 418 | + | |
| 419 | + | |
| 420 | + | |
| 421 | + | |
| 422 | + | |
| 423 | + | |
| 424 | + | |
| 425 | + | |
| 426 | + | |
| 427 | + | |
| 428 | + | |
| 429 | + | |
| 430 | + | |
| 431 | + | |
| 432 | + | |
| 433 | + | |
413 | 434 | | |
414 | 435 | | |
415 | 436 | | |
| |||
Lines changed: 22 additions & 1 deletion
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
415 | 415 | | |
416 | 416 | | |
417 | 417 | | |
418 | | - | |
| 418 | + | |
| 419 | + | |
| 420 | + | |
| 421 | + | |
| 422 | + | |
| 423 | + | |
| 424 | + | |
| 425 | + | |
| 426 | + | |
| 427 | + | |
| 428 | + | |
| 429 | + | |
| 430 | + | |
| 431 | + | |
| 432 | + | |
| 433 | + | |
| 434 | + | |
| 435 | + | |
| 436 | + | |
| 437 | + | |
| 438 | + | |
| 439 | + | |
419 | 440 | | |
420 | 441 | | |
421 | 442 | | |
| |||
Lines changed: 22 additions & 1 deletion
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
409 | 409 | | |
410 | 410 | | |
411 | 411 | | |
412 | | - | |
| 412 | + | |
| 413 | + | |
| 414 | + | |
| 415 | + | |
| 416 | + | |
| 417 | + | |
| 418 | + | |
| 419 | + | |
| 420 | + | |
| 421 | + | |
| 422 | + | |
| 423 | + | |
| 424 | + | |
| 425 | + | |
| 426 | + | |
| 427 | + | |
| 428 | + | |
| 429 | + | |
| 430 | + | |
| 431 | + | |
| 432 | + | |
| 433 | + | |
413 | 434 | | |
414 | 435 | | |
415 | 436 | | |
| |||
0 commit comments