Skip to content

Commit 3c24517

Browse files
committed
jax/moe: address PR #3116 review feedback (hardcode align + expand inline justifications)
Responds to jberchtold-nvidia's PR #3116 review threads on ``transformer_engine/jax/moe.py``. All changes are confined to a single file because each review thread targets a localized region and splitting mid-file would risk reordering bugs. Per review thread: 1. "Why do we need _with_sharding_constraint_cast_bwd? I haven't seen something like this required for our other VJPs." -- Expand the helper's docstring to spell out exactly why MoE needs it: unlike LN+MLP, the MoE bwd composes a bf16 cotangent from ep_dispatch_bwd with an fp32 cotangent from fused_topk_with_score_function_bwd (which the fwd's logits_2d -> fp32 promotion forces). Without the cast, ``d_x`` surfaces at fp32 even when ``x`` is bf16, doubling activation grad bandwidth and breaking any downstream LN bwd that pins a bf16 layout. (Review thread "Why do we need this utility function?".) 2. "Why is this dtype casting required? I don't recall us needing it for the non-MoE LNMLP block." -- Expand the comment above the bwd activation fp32 promotion to explain the MoE-specific math: LN+MLP's silu sits behind a downstream LN that absorbs the bf16 rounding error, while MoE's silu sits on the *expert* side of routing -- the bf16 rounding rides directly into expert_outputs and is summed across topk experts by ep_combine. Bf16 silu alone drifts ~1% vs fp32 silu and compounds through wo->combine into the ~1.4% per-element parity gap we measured against the pure-JAX softmax reference. Mirroring the fwd's fp32 promotion in the bwd keeps silu' in lock-step with silu. (Review thread on "# Activation bwd. Mirror the fwd's fp32 promotion of silu+multiply".) 3. "Do we have a use-case for user-specified alignments beyond 128 currently? ... it'd make sense to instead hardcode _ALIGN_SIZE = 128 as a constant at the top of the file for now to simplify this MoEBlock API. We can always expand the API to support a user-specified align size in the future." -- Implement the suggestion. Drop ``align_size`` from ``_moe_fwd_rule`` / ``_moe_bwd_rule`` / ``_moe`` / public ``moe()``; shift the ``custom_vjp`` ``nondiff_argnums`` from ``range(9, 27)`` -> ``range(9, 26)``; replace ``effective_align = max(int(align_size), 128)`` with the new module-level ``_ALIGN_SIZE = 128`` constant. Trim the ``moe()`` docstring accordingly. (Review thread on "natural_spe = num_ep * max_tokens_per_rank".) 4. "Which axis name inputs are physical mesh axes and why can be logical axes? ... No need to make any changes for now, I just want to assess which are which and then we can discuss if it makes sense to support logical on some/all or if some are required to be physical axes." -- Add an "Axis-name parameters" section to ``moe()``'s docstring listing which kwargs are physical mesh axes (``ep_axis``, ``data_parallelism_axes`` -- they index ``Mesh.shape`` directly to compute ``num_ep`` / ``dp_size`` and to construct the ``P((dp..., ep), None, None)`` for ``jax.lax.with_sharding_constraint``) vs logical axes (``input_axes``, ``gate_kernel_axes``, ``wi_kernel_axes``, ``wo_kernel_axes`` -- resolved via the Flax logical-axis rules). Also document why ``ep_axis`` / ``data_parallelism_axes`` are intentionally non-logical: the EP comm-group construction (``dp_color = rank // ep_size``) and the bootstrap signature check both require concrete integer sizes. (Review thread on "batch_pspec_axis = (*data_parallelism_axes, ep_axis)".) 5. "Is this NaN filtering a debugging artifact or something we need in the final version?" -- Strengthen the inline comment above ``sparse_probs = jnp.where(jnp.isnan(sparse_probs), 0, ...)`` to explicitly call this out as a CORRECTNESS REQUIREMENT, not a debugging artifact: it covers the sigmoid+K>1 underflow path where top-K sigmoid scores all round to zero and the ``weights / (weights.sum + 1e-20)`` normalisation emits NaN. Observationally the filter is a no-op on the dense unit-test distributions, but it must stay in for sparse / production routing. (Review thread on "sparse_probs = jnp.where(jnp.isnan(sparse_probs), ...).") Not addressed in this commit (intentional): * Review thread on the ``align_size: int = 0`` placeholder in ``flax/moe.py`` ("Placeholder comment for me to fix this so align_size is inferred automatically based on the recipe and doesn't need to be specified by the user"). That's jberchtold's own follow-up. * Review thread on the explicit ``tree_flatten`` / ``tree_unflatten`` on ``_Ctx`` ("better to use the ``@flax_struct.dataclass``"). Deferred to a separate, testable commit because changing a ``custom_vjp`` residual's pytree registration touches subtle ordering / None-handling semantics that warrant their own bisect surface. * Review thread on ``use_bias`` / ``use_expert_bias`` renames -- handled in the immediately preceding commit ``jax/flax,tests: rename use_bias/use_expert_bias for symmetry``. * Review thread on the ``expert_bias`` fp32 init -- already resolved during the Phuong PR #3036 resync (the redundant ``jnp.float32`` second-dtype argument on ``self.param`` was dropped; ``expert_bias`` now lives at ``self.dtype``). Signed-off-by: tdophung <tdophung@nvidia.com>
1 parent 5e524d0 commit 3c24517

1 file changed

Lines changed: 86 additions & 29 deletions

File tree

transformer_engine/jax/moe.py

Lines changed: 86 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,44 @@
5858
__all__ = ["moe"]
5959

6060

61+
# Per-expert dispatch-slot alignment fed to ``tex.ep_prepare`` as
62+
# ``dispatch_output_per_expert_alignment``. NCCL EP HT requires the
63+
# per-expert recv block to be at least 128-token aligned, and all current
64+
# TE grouped-GEMM recipes (bf16/fp16/fp8/mxfp8) are satisfied by a
65+
# 128-token tile, so a single hard-coded constant suffices.
66+
#
67+
# We deliberately omit a user-facing knob: per PR #3116 review there's no
68+
# current model that wants a >128 alignment, and exposing it widens the
69+
# MoEBlock API surface without buying anything. Re-introduce a parameter
70+
# (or recipe-driven inference, see jberchtold's follow-up) if a future
71+
# recipe needs >128.
72+
_ALIGN_SIZE = 128
73+
74+
6175
def _with_sharding_constraint_cast_bwd(x: jnp.ndarray, sharding) -> jnp.ndarray:
6276
"""Apply a sharding constraint while keeping bwd cotangents in the primal dtype.
6377
6478
Plain ``jax.lax.with_sharding_constraint`` propagates cotangents in
65-
whatever dtype the upstream gradient lands in; under mixed precision
79+
whatever dtype the upstream gradient lands in. Under mixed precision
6680
that can be wider than the primal, blowing up bandwidth and (for
6781
bf16 primals) breaking downstream kernels that pin a bf16 input
6882
layout. This wrapper re-casts the cotangent back to the primal
6983
dtype and re-asserts the same sharding on the bwd path.
84+
85+
Why MoE specifically needs this (per PR #3116 review): unlike a
86+
plain LN+MLP block, the MoE bwd composes two cotangent paths into
87+
``d_x`` -- one through ``ep_dispatch_bwd`` (bf16) and one through
88+
``d_logits_2d @ gate_kernel.T``. The latter starts from
89+
``fused_topk_with_score_function_bwd``, which returns ``d_logits_2d``
90+
in fp32 because the fwd promoted ``logits_2d`` to fp32 (the topk /
91+
softmax / sigmoid kernels are only validated at fp32; see
92+
``tests/pytorch/test_fused_router.py``). The fp32 ``d_logits_2d``
93+
then composes with ``gate_kernel.T`` and adds into the bf16
94+
``d_x_from_dispatch``, yielding an fp32 sum even though the user's
95+
``x`` is bf16. Without this cast, the user-visible ``d_x`` flows
96+
back into the optimizer at fp32 -- silently doubling the activation
97+
grad bandwidth and tripping any downstream kernel that pins a bf16
98+
input layout (e.g. an LN bwd that fuses into our ``d_x``).
7099
"""
71100

72101
@jax.custom_vjp
@@ -524,6 +553,19 @@ def _ffn_bwd_per_shard(
524553
# Activation bwd. Mirror the fwd's fp32 promotion of silu+multiply
525554
# so the silu derivative composes through the gradient at fp32 too;
526555
# cast back to the bf16 layout the wi grouped_quantize expects.
556+
#
557+
# Why MoE specifically needs this (per PR #3116 review): the
558+
# non-MoE LN+MLP block can stay in the activation dtype because
559+
# its silu accumulates over a single per-row dot product whose
560+
# numerical drift is absorbed by the downstream LN. The MoE
561+
# silu, by contrast, sits on the *expert* side of the routing,
562+
# so its bf16 rounding error rides directly into ``expert_outputs``
563+
# and is summed (weighted by routing probs) across topk experts
564+
# by ep_combine -- bf16 silu alone drifts ~1% vs fp32 silu, which
565+
# compounds through wo->combine into the ~1.4% per-element parity
566+
# gap we measured against the pure-JAX softmax reference. Mirroring
567+
# the fwd fp32 promotion keeps the bwd's silu' derivative in lock-
568+
# step with the fwd's silu and preserves grad parity.
527569
gp_fp32 = gate_proj_out.astype(jnp.float32)
528570
up_fp32 = up_proj_out.astype(jnp.float32)
529571
d_int_fp32 = d_intermediate.astype(jnp.float32)
@@ -606,7 +648,6 @@ def _moe_fwd_rule(
606648
wo_kernel_axes,
607649
dtype,
608650
apply_topk_weights_early,
609-
align_size,
610651
):
611652
"""Forward: gate -> topk -> ep_dispatch -> shard_map(FFN) -> ep_combine.
612653
@@ -651,10 +692,8 @@ def _moe_fwd_rule(
651692
# rejects the dispatch buffer with ``invalid argument``.
652693
natural_spe = num_ep * max_tokens_per_rank # = (B // dp_size) * S
653694
# NCCL EP requires each expert-major output block to be at least
654-
# 128-token aligned. Keep larger caller-requested alignments, but
655-
# do not emit a smaller natural block size for tiny tests.
656-
effective_align = max(int(align_size), 128)
657-
slots_per_expert = ((natural_spe + effective_align - 1) // effective_align) * effective_align
695+
# ``_ALIGN_SIZE`` (=128) tokens; see the constant's docstring.
696+
slots_per_expert = ((natural_spe + _ALIGN_SIZE - 1) // _ALIGN_SIZE) * _ALIGN_SIZE
658697
recv_pr = num_local_experts * slots_per_expert
659698

660699
_te_ep_assert_compatible_bootstrap(
@@ -704,15 +743,19 @@ def _moe_fwd_rule(
704743
expert_bias=eb_arg,
705744
compute_aux_scores=False,
706745
)
707-
# Sigmoid + K>1 normalises as `weights / (weights.sum + 1e-20)`; for
708-
# tokens whose top-K sigmoid scores all underflow at bf16/fp32 the
709-
# output is NaN at the selected positions. Those NaNs ride
746+
# NOTE (PR #3116 review): this NaN filter is a *correctness
747+
# requirement*, NOT a debugging artifact. Sigmoid + K>1 normalises
748+
# as ``weights / (weights.sum + 1e-20)``; for tokens whose top-K
749+
# sigmoid scores all underflow at bf16/fp32, the output is NaN
750+
# at the selected positions. Those NaNs ride
710751
# ep_dispatch -> recv_topk_weights -> combine and poison the per-token
711752
# weighted sum, leaving entire output rows as NaN. Sanitize at the
712753
# source so neither the fwd combine nor the bwd's manual
713-
# `grad_pre_combine * w` sees them. Padded positions in sparse_probs
714-
# are already zero (routing_map is False there); only the rare
715-
# underflow path emits NaN.
754+
# ``grad_pre_combine * w`` sees them. Padded positions in
755+
# sparse_probs are already zero (routing_map is False there); only
756+
# the rare sigmoid-underflow path emits NaN, which is why the
757+
# filter is observationally a no-op in dense unit tests but must
758+
# stay in for sparse / production routing distributions.
716759
sparse_probs = jnp.where(jnp.isnan(sparse_probs), 0, sparse_probs).astype(dtype)
717760

718761
# ---------------- Aux loss (global view, replicated) ----------------
@@ -963,12 +1006,11 @@ def _moe_bwd_rule(
9631006
wo_kernel_axes,
9641007
dtype,
9651008
apply_topk_weights_early,
966-
align_size,
9671009
residuals,
9681010
cotangents,
9691011
):
9701012
"""Backward mirror of :func:`_moe_fwd_rule`."""
971-
del num_groups, group_topk, dtype, align_size # captured in residuals / unused in bwd
1013+
del num_groups, group_topk, dtype # captured in residuals / unused in bwd
9721014
from jax.experimental.shard_map import shard_map
9731015

9741016
d_output, d_aux_loss = cotangents
@@ -1243,7 +1285,7 @@ def _bwd_body(*args):
12431285
# =============================================================================
12441286

12451287

1246-
@partial(jax.custom_vjp, nondiff_argnums=tuple(range(9, 27)))
1288+
@partial(jax.custom_vjp, nondiff_argnums=tuple(range(9, 26)))
12471289
def _moe(
12481290
x,
12491291
gate_kernel,
@@ -1271,7 +1313,6 @@ def _moe(
12711313
wo_kernel_axes,
12721314
dtype,
12731315
apply_topk_weights_early,
1274-
align_size,
12751316
):
12761317
primal, _ = _moe_fwd_rule(
12771318
x,
@@ -1300,7 +1341,6 @@ def _moe(
13001341
wo_kernel_axes,
13011342
dtype,
13021343
apply_topk_weights_early,
1303-
align_size,
13041344
)
13051345
return primal
13061346

@@ -1329,7 +1369,6 @@ def moe(
13291369
scaling_factor: float = 1.0,
13301370
aux_loss_coeff: float = 0.0,
13311371
apply_topk_weights_early: bool = False,
1332-
align_size: int = 0,
13331372
ep_axis: str,
13341373
data_parallelism_axes: Tuple[str, ...] = (),
13351374
input_axes: Tuple[Optional[str], ...] = (),
@@ -1357,16 +1396,35 @@ def moe(
13571396
all-gather over the routing-side logits is inserted so the
13581397
``fused_moe_aux_loss`` kernel sees a global ``[T_global, E]``
13591398
view; this lives off the dispatch critical path.
1360-
align_size : int
1361-
Minimum per-expert slot alignment passed to ``tex.ep_prepare``
1362-
as ``dispatch_output_per_expert_alignment``. ``0`` (default)
1363-
means use the NCCL-EP-required natural slot count
1364-
``ep_size * max_tokens_per_rank == (B/dp)*S`` (the per-rank
1365-
all-tokens-to-one-expert worst case the HT kernel demands).
1366-
Any positive value rounds that count up to the nearest
1367-
multiple, growing the per-rank receive buffer accordingly.
1368-
Set to ``128`` for FP8 recipes that require 128-aligned
1369-
grouped-GEMM tiles.
1399+
1400+
Note that the per-expert dispatch-slot alignment is fixed internally
1401+
at 128 tokens (``_ALIGN_SIZE``); see that constant's docstring for
1402+
rationale and how to extend if a future recipe needs >128.
1403+
1404+
Axis-name parameters (per PR #3116 review):
1405+
1406+
* ``ep_axis`` and ``data_parallelism_axes`` are *physical mesh
1407+
axis names* -- they index ``jax.sharding.Mesh.shape`` directly
1408+
(to compute ``num_ep`` / ``dp_size`` and to construct
1409+
``P((dp..., ep), None, None)`` for the per-shard
1410+
``jax.lax.with_sharding_constraint`` calls that JAX requires
1411+
to refer to real mesh axes).
1412+
* ``input_axes``, ``gate_kernel_axes``, ``wi_kernel_axes``,
1413+
``wo_kernel_axes`` are *logical axis names* (e.g.
1414+
``"batch"``, ``"embed"``, ``"mlp"``, ``"exp"``) -- they get
1415+
resolved via the active Flax logical-axis rules and consumed
1416+
by ``with_sharding_constraint_by_logical_axes``. They are
1417+
``Optional[str]`` tuples so a rule of ``None`` means
1418+
"replicated on this axis".
1419+
1420+
Logical-axis support for ``ep_axis`` / ``data_parallelism_axes``
1421+
is intentionally out of scope: the EP comm-group construction
1422+
(``dp_color = rank // ep_size``) and the bootstrap signature
1423+
check both require concrete integer sizes, so a logical name
1424+
would have to be resolved to a physical one anyway before any
1425+
EP primitive is called. If a downstream pipeline needs to plumb
1426+
logical names all the way to ``moe()``, do the rule lookup at
1427+
the call site.
13701428
13711429
See module docstring for the rest of the parameter semantics and the
13721430
surrounding design rationale.
@@ -1428,7 +1486,6 @@ def moe(
14281486
wo_kernel_axes,
14291487
dtype,
14301488
apply_topk_weights_early,
1431-
align_size,
14321489
)
14331490
if aux_loss_coeff <= 0.0:
14341491
aux_loss = None

0 commit comments

Comments
 (0)