Commit 3c24517
jax/moe: address PR #3116 review feedback (hardcode align + expand inline justifications)
Responds to jberchtold-nvidia's PR #3116 review threads on
``transformer_engine/jax/moe.py``. All changes are confined to a
single file because each review thread targets a localized region
and splitting mid-file would risk reordering bugs.
Per review thread:
1. "Why do we need _with_sharding_constraint_cast_bwd? I haven't
seen something like this required for our other VJPs."
-- Expand the helper's docstring to spell out exactly why MoE
needs it: unlike LN+MLP, the MoE bwd composes a bf16 cotangent
from ep_dispatch_bwd with an fp32 cotangent from
fused_topk_with_score_function_bwd (which the fwd's
logits_2d -> fp32 promotion forces). Without the cast, ``d_x``
surfaces at fp32 even when ``x`` is bf16, doubling activation
grad bandwidth and breaking any downstream LN bwd that pins a
bf16 layout. (Review thread "Why do we need this utility
function?".)
2. "Why is this dtype casting required? I don't recall us needing
it for the non-MoE LNMLP block."
-- Expand the comment above the bwd activation fp32 promotion
to explain the MoE-specific math: LN+MLP's silu sits behind a
downstream LN that absorbs the bf16 rounding error, while
MoE's silu sits on the *expert* side of routing -- the bf16
rounding rides directly into expert_outputs and is summed
across topk experts by ep_combine. Bf16 silu alone drifts ~1%
vs fp32 silu and compounds through wo->combine into the ~1.4%
per-element parity gap we measured against the pure-JAX
softmax reference. Mirroring the fwd's fp32 promotion in the
bwd keeps silu' in lock-step with silu. (Review thread on
"# Activation bwd. Mirror the fwd's fp32 promotion of
silu+multiply".)
3. "Do we have a use-case for user-specified alignments beyond
128 currently? ... it'd make sense to instead hardcode
_ALIGN_SIZE = 128 as a constant at the top of the file for
now to simplify this MoEBlock API. We can always expand the
API to support a user-specified align size in the future."
-- Implement the suggestion. Drop ``align_size`` from
``_moe_fwd_rule`` / ``_moe_bwd_rule`` / ``_moe`` / public
``moe()``; shrink the ``custom_vjp`` ``nondiff_argnums`` from
``range(9, 27)`` to ``range(9, 26)``; replace ``effective_align
= max(int(align_size), 128)`` with the new module-level
``_ALIGN_SIZE = 128`` constant. Trim the ``moe()`` docstring
accordingly. (Review thread on
"natural_spe = num_ep * max_tokens_per_rank".)
4. "Which axis name inputs are physical mesh axes and which can be
logical axes? ... No need to make any changes for now, I just
want to assess which are which and then we can discuss if it
makes sense to support logical on some/all or if some are
required to be physical axes."
-- Add an "Axis-name parameters" section to ``moe()``'s
docstring listing which kwargs are physical mesh axes
(``ep_axis``, ``data_parallelism_axes`` -- they index
``Mesh.shape`` directly to compute ``num_ep`` / ``dp_size``
and to construct the ``P((dp..., ep), None, None)`` for
``jax.lax.with_sharding_constraint``) vs logical axes
(``input_axes``, ``gate_kernel_axes``, ``wi_kernel_axes``,
``wo_kernel_axes`` -- resolved via the Flax logical-axis
rules). Also document why ``ep_axis`` / ``data_parallelism_axes``
are intentionally non-logical: the EP comm-group construction
(``dp_color = rank // ep_size``) and the bootstrap signature
check both require concrete integer sizes. (Review thread on
"batch_pspec_axis = (*data_parallelism_axes, ep_axis)".)
5. "Is this NaN filtering a debugging artifact or something we
need in the final version?"
-- Strengthen the inline comment above
``sparse_probs = jnp.where(jnp.isnan(sparse_probs), 0, ...)``
to explicitly call this out as a CORRECTNESS REQUIREMENT, not
a debugging artifact: it covers the sigmoid+K>1 underflow
path where top-K sigmoid scores all round to zero and the
``weights / (weights.sum + 1e-20)`` normalisation emits NaN.
Observationally the filter is a no-op on the dense unit-test
distributions, but it must stay in for sparse / production
routing. (Review thread on
"sparse_probs = jnp.where(jnp.isnan(sparse_probs), ...).")
Not addressed in this commit (intentional):
* Review thread on the ``align_size: int = 0`` placeholder in
``flax/moe.py`` ("Placeholder comment for me to fix this so
align_size is inferred automatically based on the recipe and
doesn't need to be specified by the user"). That's
jberchtold's own follow-up.
* Review thread on the explicit ``tree_flatten`` /
``tree_unflatten`` on ``_Ctx`` ("better to use the
``@flax_struct.dataclass``"). Deferred to a separate, testable
commit because changing a ``custom_vjp`` residual's pytree
registration touches subtle ordering / None-handling semantics
that warrant their own bisect surface.
* Review thread on ``use_bias`` / ``use_expert_bias`` renames --
handled in the immediately preceding commit
``jax/flax,tests: rename use_bias/use_expert_bias for symmetry``.
* Review thread on the ``expert_bias`` fp32 init -- already
resolved during the resync with Phuong's PR #3036 (the redundant
``jnp.float32`` second-dtype argument on ``self.param`` was
dropped; ``expert_bias`` now lives at ``self.dtype``).
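To make the physical-axis argument from item 4 concrete, here is a minimal pure-Python sketch of why ``ep_axis`` / ``data_parallelism_axes`` need concrete integer sizes. It is illustrative only and is not the code in ``moe.py``: the ``mesh_shape`` dict is a hypothetical stand-in for ``Mesh.shape``, the ``ep_layout`` helper is invented for this example, and the rank ordering (``ep`` as the fastest-varying axis) is an assumption. Only the ``dp_color = rank // ep_size`` formula is taken from the commit message.

```python
# Hypothetical stand-in for ``Mesh.shape``: physical axis name -> concrete size.
mesh_shape = {"dp": 2, "ep": 4}


def ep_layout(mesh_shape, ep_axis, data_parallelism_axes):
    """Return (num_ep, dp_size) by indexing the mesh shape directly.

    This only works for physical axis names; a logical axis name has no
    entry in the mesh shape until the logical-axis rules resolve it.
    """
    num_ep = mesh_shape[ep_axis]
    dp_size = 1
    for axis in data_parallelism_axes:
        dp_size *= mesh_shape[axis]
    return num_ep, dp_size


def dp_color(rank, ep_size):
    """Colour of the EP comm group that ``rank`` belongs to.

    Assumes ranks are laid out with the EP axis varying fastest.
    """
    return rank // ep_size


num_ep, dp_size = ep_layout(mesh_shape, "ep", ("dp",))
colors = [dp_color(rank, num_ep) for rank in range(num_ep * dp_size)]
print(num_ep, dp_size, colors)
```

Under these assumptions, ranks 0-3 share one EP group and ranks 4-7 share the other. Both the integer division and the size product need plain Python ints at trace time, which is the constraint the docstring section records.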
Signed-off-by: tdophung <tdophung@nvidia.com>
1 parent 5e524d0 commit 3c24517
1 file changed
Lines changed: 86 additions & 29 deletions