|
61 | 61 | # Per-expert dispatch-slot alignment fed to ``tex.ep_prepare`` as |
62 | 62 | # ``dispatch_output_per_expert_alignment``. NCCL EP HT requires the |
63 | 63 | # per-expert recv block to be at least 128-token aligned, and all current |
64 | | -# TE grouped-GEMM recipes (bf16/fp16/fp8/mxfp8) are satisfied by a |
65 | | -# 128-token tile, so a single hard-coded constant suffices. |
66 | | -# |
67 | | -# We deliberately omit a user-facing knob: per PR #3116 review there's no |
68 | | -# current model that wants a >128 alignment, and exposing it widens the |
69 | | -# MoEBlock API surface without buying anything. Re-introduce a parameter |
70 | | -# (or recipe-driven inference, see jberchtold's follow-up) if a future |
71 | | -# recipe needs >128. |
| 64 | +# TE grouped-GEMM recipes (bf16/fp16/fp8/mxfp8) are satisfied by the |
| 65 | +# same 128-token tile, so a single constant covers every supported path. |
72 | 66 | _ALIGN_SIZE = 128 |
73 | 67 |
|
74 | 68 |
|
75 | 69 | def _with_sharding_constraint_cast_bwd(x: jnp.ndarray, sharding) -> jnp.ndarray: |
76 | | - """Apply a sharding constraint while keeping bwd cotangents in the primal dtype. |
77 | | -
|
78 | | - Plain ``jax.lax.with_sharding_constraint`` propagates cotangents in |
79 | | - whatever dtype the upstream gradient lands in. Under mixed precision |
80 | | - that can be wider than the primal, blowing up bandwidth and (for |
81 | | - bf16 primals) breaking downstream kernels that pin a bf16 input |
82 | | - layout. This wrapper re-casts the cotangent back to the primal |
83 | | - dtype and re-asserts the same sharding on the bwd path. |
84 | | -
|
85 | | - Why MoE specifically needs this (per PR #3116 review): unlike a |
86 | | - plain LN+MLP block, the MoE bwd composes two cotangent paths into |
87 | | - ``d_x`` -- one through ``ep_dispatch_bwd`` (bf16) and one through |
88 | | - ``d_logits_2d @ gate_kernel.T``. The latter starts from |
89 | | - ``fused_topk_with_score_function_bwd``, which returns ``d_logits_2d`` |
90 | | - in fp32 because the fwd promoted ``logits_2d`` to fp32 (the topk / |
91 | | - softmax / sigmoid kernels are only validated at fp32; see |
92 | | - ``tests/pytorch/test_fused_router.py``). The fp32 ``d_logits_2d`` |
93 | | - then composes with ``gate_kernel.T`` and adds into the bf16 |
94 | | - ``d_x_from_dispatch``, yielding an fp32 sum even though the user's |
95 | | - ``x`` is bf16. Without this cast, the user-visible ``d_x`` flows |
96 | | - back into the optimizer at fp32 -- silently doubling the activation |
97 | | - grad bandwidth and tripping any downstream kernel that pins a bf16 |
98 | | - input layout (e.g. an LN bwd that fuses into our ``d_x``). |
| 70 | + """Sharding constraint that keeps bwd cotangents in the primal dtype. |
| 71 | +
|
| 72 | + Plain ``jax.lax.with_sharding_constraint`` is identity on the fwd |
| 73 | + but does not constrain the dtype of the cotangent that flows back |
| 74 | + through it. In this MoE bwd, ``d_x`` is built from two paths: |
| 75 | +
|
| 76 | + * ``d_x_from_dispatch`` from ``ep_dispatch_bwd`` -- primal dtype |
| 77 | + (bf16 in mixed precision). |
| 78 | + * ``d_x_from_gate = d_logits_2d @ gate_kernel.T`` where |
| 79 | + ``d_logits_2d`` is produced by |
| 80 | + ``fused_topk_with_score_function_bwd``. That primitive runs at |
| 81 | + fp32 because the fwd promoted ``logits_2d`` to fp32 (the fused |
| 82 | + topk/softmax/sigmoid kernels are only validated at fp32). |
| 83 | +
|
| 84 | + JAX's type promotion then makes ``d_x_from_gate + d_x_from_dispatch`` |
| 85 | + fp32, so the user-visible ``d_x`` ends up wider than ``x``. That |
| 86 | + doubles activation-grad bandwidth and breaks any downstream kernel |
| 87 | + that pins a bf16 input layout. This wrapper inserts an explicit |
| 88 | + cast back to the primal dtype on the bwd side and re-asserts the |
| 89 | + same sharding there as well. |
99 | 90 | """ |
100 | 91 |
|
101 | 92 | @jax.custom_vjp |
@@ -550,22 +541,14 @@ def _ffn_bwd_per_shard( |
550 | 541 | else: |
551 | 542 | d_recv_w_from_intermediate = jnp.zeros_like(recv_w_flat) |
552 | 543 |
|
553 | | - # Activation bwd. Mirror the fwd's fp32 promotion of silu+multiply |
554 | | - # so the silu derivative composes through the gradient at fp32 too; |
555 | | - # cast back to the bf16 layout the wi grouped_quantize expects. |
556 | | - # |
557 | | - # Why MoE specifically needs this (per PR #3116 review): the |
558 | | - # non-MoE LN+MLP block can stay in the activation dtype because |
559 | | - # its silu accumulates over a single per-row dot product whose |
560 | | - # numerical drift is absorbed by the downstream LN. The MoE |
561 | | - # silu, by contrast, sits on the *expert* side of the routing, |
562 | | - # so its bf16 rounding error rides directly into ``expert_outputs`` |
563 | | - # and is summed (weighted by routing probs) across topk experts |
564 | | - # by ep_combine -- bf16 silu alone drifts ~1% vs fp32 silu, which |
565 | | - # compounds through wo->combine into the ~1.4% per-element parity |
566 | | - # gap we measured against the pure-JAX softmax reference. Mirroring |
567 | | - # the fwd fp32 promotion keeps the bwd's silu' derivative in lock- |
568 | | - # step with the fwd's silu and preserves grad parity. |
| 544 | + # Activation bwd. The fwd already computes silu+multiply at fp32 |
| 545 | + # because the MoE silu sits on the expert side of routing: its |
| 546 | + # output rides into ``expert_outputs`` and is then summed -- weighted |
| 547 | + # by routing probabilities -- across topk experts by ep_combine. |
| 548 | + # Doing silu/silu' in bf16 drifts by ~1% per element vs fp32 and |
| 549 | + # that drift compounds through wo->combine. Mirror the fwd's fp32 |
| 550 | + # promotion here so silu' lines up with silu, then cast back to the |
| 551 | + # bf16 layout the wi grouped_quantize expects. |
569 | 552 | gp_fp32 = gate_proj_out.astype(jnp.float32) |
570 | 553 | up_fp32 = up_proj_out.astype(jnp.float32) |
571 | 554 | d_int_fp32 = d_intermediate.astype(jnp.float32) |
@@ -743,20 +726,7 @@ def _moe_fwd_rule( |
743 | 726 | expert_bias=eb_arg, |
744 | 727 | compute_aux_scores=False, |
745 | 728 | ) |
746 | | - # NOTE (PR #3116 review): this NaN filter is a *correctness |
747 | | - # requirement*, NOT a debugging artifact. Sigmoid + K>1 normalises |
748 | | - # as ``weights / (weights.sum + 1e-20)``; for tokens whose top-K |
749 | | - # sigmoid scores all underflow at bf16/fp32, the output is NaN |
750 | | - # at the selected positions. Those NaNs ride |
751 | | - # ep_dispatch -> recv_topk_weights -> combine and poison the per-token |
752 | | - # weighted sum, leaving entire output rows as NaN. Sanitize at the |
753 | | - # source so neither the fwd combine nor the bwd's manual |
754 | | - # ``grad_pre_combine * w`` sees them. Padded positions in |
755 | | - # sparse_probs are already zero (routing_map is False there); only |
756 | | - # the rare sigmoid-underflow path emits NaN, which is why the |
757 | | - # filter is observationally a no-op in dense unit tests but must |
758 | | - # stay in for sparse / production routing distributions. |
759 | | - sparse_probs = jnp.where(jnp.isnan(sparse_probs), 0, sparse_probs).astype(dtype) |
| 729 | + sparse_probs = sparse_probs.astype(dtype) |
760 | 730 |
|
761 | 731 | # ---------------- Aux loss (global view, replicated) ---------------- |
762 | 732 | # ``fused_moe_aux_loss_fwd`` sums probs and tokens_per_expert across |
@@ -1401,7 +1371,7 @@ def moe( |
1401 | 1371 | at 128 tokens (``_ALIGN_SIZE``); see the comment above that constant |
1402 | 1372 | for the rationale. Revisit the value if a future recipe needs >128. |
1403 | 1373 |
|
1404 | | - Axis-name parameters (per PR #3116 review): |
| 1374 | + Axis-name parameters: |
1405 | 1375 |
|
1406 | 1376 | * ``ep_axis`` and ``data_parallelism_axes`` are *physical mesh |
1407 | 1377 | axis names* -- they index ``jax.sharding.Mesh.shape`` directly |
|
0 commit comments