Commit fecb0ed
committed
[JAX] MoE: cherry-pick 3 independent fixes from jberchtold/te_ep_integration
Pull three small, orthogonal correctness improvements from jberchtold's
parallel work on teddy/te_ep_integration that don't touch the FFN
shard_map or our dispatch zero-init workaround:
1. ``effective_align = max(align_size, 128)`` floor on the per-rank
receive slots in ``moe.py``. NCCL EP requires each expert-major
output block to be at least 128-token aligned; the previous
``align_size > 0`` branch could emit a smaller natural block on
tiny configs and trip the dispatch buffer check. (df61642)
2. Size-1-axis guard in ``_ep_outer_axis()`` in both
``cpp_extensions/ep.py`` and ``ep.py``. A dp/fsdp axis that is
sized 1 in the active mesh is now treated as absent so we don't
pin EP-output specs to a degenerate axis that JAX may silently
collapse. Mirrored the helper into ``ep.py`` so both files share
the same predicate. (2210702)
3. ``_with_sharding_constraint_cast_bwd`` custom-VJP wrapper in
``moe.py``, applied to the inbound activation re-pin. Keeps the
bwd cotangent in the primal dtype and re-asserts the same
sharding on the bwd path, instead of letting a wider gradient
land back at the caller. (2210702)
Deliberately deferred: his shard_map removal + new global-view
FFN call sites in ``2210702a``'s ``moe.py`` rewrite. Those depend
on the grouped-GEMM custom partitioning landing on main and are a
later-phase integration sweep.1 parent 42db5b6 commit fecb0ed
3 files changed
Lines changed: 52 additions & 9 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
24 | 24 | | |
25 | 25 | | |
26 | 26 | | |
27 | | - | |
| 27 | + | |
28 | 28 | | |
29 | 29 | | |
30 | 30 | | |
| |||
187 | 187 | | |
188 | 188 | | |
189 | 189 | | |
| 190 | + | |
| 191 | + | |
| 192 | + | |
190 | 193 | | |
191 | 194 | | |
| 195 | + | |
| 196 | + | |
| 197 | + | |
| 198 | + | |
192 | 199 | | |
193 | 200 | | |
194 | 201 | | |
| |||
536 | 543 | | |
537 | 544 | | |
538 | 545 | | |
539 | | - | |
| 546 | + | |
540 | 547 | | |
541 | 548 | | |
542 | 549 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
229 | 229 | | |
230 | 230 | | |
231 | 231 | | |
| 232 | + | |
| 233 | + | |
| 234 | + | |
| 235 | + | |
| 236 | + | |
| 237 | + | |
| 238 | + | |
| 239 | + | |
| 240 | + | |
| 241 | + | |
232 | 242 | | |
233 | 243 | | |
234 | 244 | | |
235 | 245 | | |
236 | 246 | | |
237 | 247 | | |
238 | 248 | | |
239 | | - | |
| 249 | + | |
240 | 250 | | |
241 | 251 | | |
242 | 252 | | |
| |||
315 | 325 | | |
316 | 326 | | |
317 | 327 | | |
318 | | - | |
| 328 | + | |
319 | 329 | | |
320 | 330 | | |
321 | 331 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
83 | 83 | | |
84 | 84 | | |
85 | 85 | | |
| 86 | + | |
| 87 | + | |
| 88 | + | |
| 89 | + | |
| 90 | + | |
| 91 | + | |
| 92 | + | |
| 93 | + | |
| 94 | + | |
| 95 | + | |
| 96 | + | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
| 100 | + | |
| 101 | + | |
| 102 | + | |
| 103 | + | |
| 104 | + | |
| 105 | + | |
| 106 | + | |
| 107 | + | |
| 108 | + | |
| 109 | + | |
| 110 | + | |
86 | 111 | | |
87 | 112 | | |
88 | 113 | | |
| |||
631 | 656 | | |
632 | 657 | | |
633 | 658 | | |
634 | | - | |
635 | | - | |
636 | | - | |
637 | | - | |
| 659 | + | |
| 660 | + | |
| 661 | + | |
| 662 | + | |
| 663 | + | |
638 | 664 | | |
639 | 665 | | |
640 | 666 | | |
| |||
1406 | 1432 | | |
1407 | 1433 | | |
1408 | 1434 | | |
1409 | | - | |
| 1435 | + | |
1410 | 1436 | | |
1411 | 1437 | | |
1412 | 1438 | | |
| |||
0 commit comments