Skip to content

Commit fe44697

Browse files
committed
jax/moe: strip PR-response framing from comments; drop sparse_probs NaN sanitizer
* Rewrite the inline justifications added in 078a7d80 so each one reads as standalone code documentation, not as a reply to a reviewer: drop "per PR #3116 review", "review feedback", "Renamed from ... per PR ..." and similar PR/thread references from moe.py, flax/moe.py, and tests/jax/test_te_ep_moe.py. Technical content (why the fp32 promotion is needed for the MoE silu+multiply, why _with_sharding_constraint_cast_bwd exists, physical-vs-logical axis split in moe() docstring, the 128 alignment rationale) is preserved and reframed to be useful to a reader who has no PR context. * Drop the jnp.where(jnp.isnan(sparse_probs), 0, sparse_probs) guard. Tracing fused_topk_with_score_function.cu shows the kernel divides by sum_scores + 1e-20, so finite non-negative sigmoid scores cannot produce NaN here; the filter was only defense against upstream NaNs, which would mask a real regression if anything ever did start producing them. Signed-off-by: tdophung <tdophung@nvidia.com>
1 parent 3c24517 commit fe44697

3 files changed

Lines changed: 41 additions & 75 deletions

File tree

tests/jax/test_te_ep_moe.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -531,10 +531,9 @@ def _make_inputs(key):
531531
# multiply (intermediate * recv_w * mask) is debugged. Late
532532
# weighting (combine-side) is unaffected and stays covered above.
533533
# Note: align_size is no longer a user-facing parameter; it is
534-
# hard-coded to _ALIGN_SIZE = 128 in moe.py (per PR #3116
535-
# review). Re-add a distinct align-size config only if the
536-
# constant is loosened, or a recipe-driven inference is added
537-
# that selects a >128 alignment.
534+
# hard-coded to _ALIGN_SIZE = 128 in moe.py. Re-add a distinct
535+
# align-size config only if the constant is loosened, or a
536+
# recipe-driven inference is added that selects a >128 alignment.
538537
pytest.param(
539538
dict(score_function="sigmoid"),
540539
id="sigmoid",

transformer_engine/jax/flax/moe.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,7 @@ class _MoEBlock(TransformerEngineBase):
8585
If ``True``, registers a per-expert routing bias (shape ``[E]``)
8686
used by the topk selection. Only meaningful with
8787
``score_function="sigmoid"``; the underlying primitive validates
88-
the pairing. (Renamed from ``use_expert_bias`` per PR #3116
89-
review for symmetry with ``use_ffn_bias``.)
88+
the pairing.
9089
aux_loss_coeff : float
9190
If ``> 0``, return the MoE auxiliary load-balancing loss scalar
9291
in addition to the main output.
@@ -106,19 +105,17 @@ class _MoEBlock(TransformerEngineBase):
106105
*inside* each shard before ``ep_combine`` (saves one global
107106
reduction at the cost of an extra broadcast). Default ``False``.
108107
109-
Note that the per-expert dispatch-slot alignment is fixed internally
110-
at 128 tokens (see ``moe._ALIGN_SIZE``). Per PR #3116 review there's
111-
no current model that wants a >128 alignment, so this is not exposed
112-
as a parameter; re-introduce a knob (or recipe-driven inference) if
113-
a future FP8 recipe needs >128.
108+
The per-expert dispatch-slot alignment is fixed internally at 128
109+
tokens (see ``moe._ALIGN_SIZE``) -- the value required by NCCL EP
110+
HT and satisfied by every current TE grouped-GEMM recipe -- and is
111+
therefore not exposed as a per-instance knob.
114112
115113
dtype : jnp.dtype
116114
Compute / parameter dtype.
117115
kernel_init, bias_init, expert_bias_init : Initializers.
118116
use_ffn_bias : bool
119117
Register per-expert FFN biases (``wi_0_bias``, ``wi_1_bias``,
120-
``wo_bias``). (Renamed from ``use_bias`` per PR #3116 review
121-
for symmetry with ``use_expert_routing_bias``.)
118+
``wo_bias``).
122119
123120
Quantization is currently configured via the standard TE autocast
124121
context (``fp8_autocast``/``with_quantizer_set``) and threaded

transformer_engine/jax/moe.py

Lines changed: 32 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -61,41 +61,32 @@
6161
# Per-expert dispatch-slot alignment fed to ``tex.ep_prepare`` as
6262
# ``dispatch_output_per_expert_alignment``. NCCL EP HT requires the
6363
# per-expert recv block to be at least 128-token aligned, and all current
64-
# TE grouped-GEMM recipes (bf16/fp16/fp8/mxfp8) are satisfied by a
65-
# 128-token tile, so a single hard-coded constant suffices.
66-
#
67-
# We deliberately omit a user-facing knob: per PR #3116 review there's no
68-
# current model that wants a >128 alignment, and exposing it widens the
69-
# MoEBlock API surface without buying anything. Re-introduce a parameter
70-
# (or recipe-driven inference, see jberchtold's follow-up) if a future
71-
# recipe needs >128.
64+
# TE grouped-GEMM recipes (bf16/fp16/fp8/mxfp8) are satisfied by the
65+
# same 128-token tile, so a single constant covers every supported path.
7266
_ALIGN_SIZE = 128
7367

7468

7569
def _with_sharding_constraint_cast_bwd(x: jnp.ndarray, sharding) -> jnp.ndarray:
76-
"""Apply a sharding constraint while keeping bwd cotangents in the primal dtype.
77-
78-
Plain ``jax.lax.with_sharding_constraint`` propagates cotangents in
79-
whatever dtype the upstream gradient lands in. Under mixed precision
80-
that can be wider than the primal, blowing up bandwidth and (for
81-
bf16 primals) breaking downstream kernels that pin a bf16 input
82-
layout. This wrapper re-casts the cotangent back to the primal
83-
dtype and re-asserts the same sharding on the bwd path.
84-
85-
Why MoE specifically needs this (per PR #3116 review): unlike a
86-
plain LN+MLP block, the MoE bwd composes two cotangent paths into
87-
``d_x`` -- one through ``ep_dispatch_bwd`` (bf16) and one through
88-
``d_logits_2d @ gate_kernel.T``. The latter starts from
89-
``fused_topk_with_score_function_bwd``, which returns ``d_logits_2d``
90-
in fp32 because the fwd promoted ``logits_2d`` to fp32 (the topk /
91-
softmax / sigmoid kernels are only validated at fp32; see
92-
``tests/pytorch/test_fused_router.py``). The fp32 ``d_logits_2d``
93-
then composes with ``gate_kernel.T`` and adds into the bf16
94-
``d_x_from_dispatch``, yielding an fp32 sum even though the user's
95-
``x`` is bf16. Without this cast, the user-visible ``d_x`` flows
96-
back into the optimizer at fp32 -- silently doubling the activation
97-
grad bandwidth and tripping any downstream kernel that pins a bf16
98-
input layout (e.g. an LN bwd that fuses into our ``d_x``).
70+
"""Sharding constraint that keeps bwd cotangents in the primal dtype.
71+
72+
Plain ``jax.lax.with_sharding_constraint`` is identity on the fwd
73+
but does not constrain the dtype of the cotangent that flows back
74+
through it. In this MoE bwd, ``d_x`` is built from two paths:
75+
76+
* ``d_x_from_dispatch`` from ``ep_dispatch_bwd`` -- primal dtype
77+
(bf16 in mixed precision).
78+
* ``d_x_from_gate = d_logits_2d @ gate_kernel.T`` where
79+
``d_logits_2d`` is produced by
80+
``fused_topk_with_score_function_bwd``. That primitive runs at
81+
fp32 because the fwd promoted ``logits_2d`` to fp32 (the fused
82+
topk/softmax/sigmoid kernels are only validated at fp32).
83+
84+
JAX's type promotion then makes ``d_x_from_gate + d_x_from_dispatch``
85+
fp32, so the user-visible ``d_x`` ends up wider than ``x``. That
86+
doubles activation-grad bandwidth and breaks any downstream kernel
87+
that pins a bf16 input layout. This wrapper inserts an explicit
88+
cast back to the primal dtype on the bwd side and re-asserts the
89+
same sharding there as well.
9990
"""
10091

10192
@jax.custom_vjp
@@ -550,22 +541,14 @@ def _ffn_bwd_per_shard(
550541
else:
551542
d_recv_w_from_intermediate = jnp.zeros_like(recv_w_flat)
552543

553-
# Activation bwd. Mirror the fwd's fp32 promotion of silu+multiply
554-
# so the silu derivative composes through the gradient at fp32 too;
555-
# cast back to the bf16 layout the wi grouped_quantize expects.
556-
#
557-
# Why MoE specifically needs this (per PR #3116 review): the
558-
# non-MoE LN+MLP block can stay in the activation dtype because
559-
# its silu accumulates over a single per-row dot product whose
560-
# numerical drift is absorbed by the downstream LN. The MoE
561-
# silu, by contrast, sits on the *expert* side of the routing,
562-
# so its bf16 rounding error rides directly into ``expert_outputs``
563-
# and is summed (weighted by routing probs) across topk experts
564-
# by ep_combine -- bf16 silu alone drifts ~1% vs fp32 silu, which
565-
# compounds through wo->combine into the ~1.4% per-element parity
566-
# gap we measured against the pure-JAX softmax reference. Mirroring
567-
# the fwd fp32 promotion keeps the bwd's silu' derivative in lock-
568-
# step with the fwd's silu and preserves grad parity.
544+
# Activation bwd. The fwd already computes silu+multiply at fp32
545+
# because the MoE silu sits on the expert side of routing: its
546+
# output rides into ``expert_outputs`` and is then summed -- weighted
547+
# by routing probabilities -- across topk experts by ep_combine.
548+
# Doing silu/silu' in bf16 drifts by ~1% per element vs fp32 and
549+
# that drift compounds through wo->combine. Mirror the fwd's fp32
550+
# promotion here so silu' lines up with silu, then cast back to the
551+
# bf16 layout the wi grouped_quantize expects.
569552
gp_fp32 = gate_proj_out.astype(jnp.float32)
570553
up_fp32 = up_proj_out.astype(jnp.float32)
571554
d_int_fp32 = d_intermediate.astype(jnp.float32)
@@ -743,20 +726,7 @@ def _moe_fwd_rule(
743726
expert_bias=eb_arg,
744727
compute_aux_scores=False,
745728
)
746-
# NOTE (PR #3116 review): this NaN filter is a *correctness
747-
# requirement*, NOT a debugging artifact. Sigmoid + K>1 normalises
748-
# as ``weights / (weights.sum + 1e-20)``; for tokens whose top-K
749-
# sigmoid scores all underflow at bf16/fp32, the output is NaN
750-
# at the selected positions. Those NaNs ride
751-
# ep_dispatch -> recv_topk_weights -> combine and poison the per-token
752-
# weighted sum, leaving entire output rows as NaN. Sanitize at the
753-
# source so neither the fwd combine nor the bwd's manual
754-
# ``grad_pre_combine * w`` sees them. Padded positions in
755-
# sparse_probs are already zero (routing_map is False there); only
756-
# the rare sigmoid-underflow path emits NaN, which is why the
757-
# filter is observationally a no-op in dense unit tests but must
758-
# stay in for sparse / production routing distributions.
759-
sparse_probs = jnp.where(jnp.isnan(sparse_probs), 0, sparse_probs).astype(dtype)
729+
sparse_probs = sparse_probs.astype(dtype)
760730

761731
# ---------------- Aux loss (global view, replicated) ----------------
762732
# ``fused_moe_aux_loss_fwd`` sums probs and tokens_per_expert across
@@ -1401,7 +1371,7 @@ def moe(
14011371
at 128 tokens (``_ALIGN_SIZE``); see that constant's docstring for
14021372
rationale and how to extend if a future recipe needs >128.
14031373
1404-
Axis-name parameters (per PR #3116 review):
1374+
Axis-name parameters:
14051375
14061376
* ``ep_axis`` and ``data_parallelism_axes`` are *physical mesh
14071377
axis names* -- they index ``jax.sharding.Mesh.shape`` directly

0 commit comments

Comments
 (0)