Skip to content

Commit fecb0ed

Browse files
committed
[JAX] MoE: cherry-pick 3 independent fixes from jberchtold/te_ep_integration
Pull three small, orthogonal correctness improvements from jberchtold's parallel work on teddy/te_ep_integration that don't touch the FFN shard_map or our dispatch zero-init workaround:

1. ``effective_align = max(align_size, 128)`` floor on the per-rank receive slots in ``moe.py``. NCCL EP requires each expert-major output block to be at least 128-token aligned; the previous ``align_size > 0`` branch could emit a smaller natural block on tiny configs and trip the dispatch buffer check. (df61642)
2. Size-1-axis guard in ``_ep_outer_axis()`` in both ``cpp_extensions/ep.py`` and ``ep.py``. A dp/fsdp axis that is sized 1 in the active mesh is now treated as absent so we don't pin EP-output specs to a degenerate axis that JAX may silently collapse. Mirrored the helper into ``ep.py`` so both files share the same predicate. (2210702)
3. ``_with_sharding_constraint_cast_bwd`` custom-VJP wrapper in ``moe.py``, applied to the inbound activation re-pin. Keeps the bwd cotangent in the primal dtype and re-asserts the same sharding on the bwd path, instead of letting a wider gradient land back at the caller. (2210702)

Deliberately deferred: his shard_map removal + new global-view FFN call sites in ``2210702a``'s ``moe.py`` rewrite. Those depend on the grouped-GEMM custom partitioning landing on main and are a later-phase integration sweep.
1 parent 42db5b6 commit fecb0ed

3 files changed

Lines changed: 52 additions & 9 deletions

File tree

transformer_engine/jax/cpp_extensions/ep.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
import transformer_engine_jax
2626
from .base import BasePrimitive, register_primitive
27-
from ..sharding import global_mesh_resource
27+
from ..sharding import global_mesh_resource, get_mesh_axis_size
2828

2929
__all__ = [
3030
"EpConfig",
@@ -187,8 +187,15 @@ def _ep_outer_axis():
187187
188188
When set, EP-output globals carry an extra leading ``dp_size`` dim so SPMD
189189
sees each DP color's slab as distinct (rather than replicated across DP).
190+
191+
A dp/fsdp axis that is sized 1 in the active mesh is treated as absent so
192+
we don't pin EP-output specs to a degenerate axis that JAX may collapse.
190193
"""
191194
gsr = global_mesh_resource()
195+
if gsr.dp_resource is not None and get_mesh_axis_size(gsr.dp_resource) > 1:
196+
return gsr.dp_resource
197+
if gsr.fsdp_resource is not None and get_mesh_axis_size(gsr.fsdp_resource) > 1:
198+
return gsr.fsdp_resource
192199
return gsr.dp_resource or gsr.fsdp_resource
193200

194201

@@ -536,7 +543,7 @@ def _resolve_out_partition_spec(out_partition_spec, num_leading):
536543
"ep_combine: ep_resource is not set on the active MeshResource;"
537544
" pass out_sharding=... explicitly."
538545
)
539-
outer = gsr.dp_resource or gsr.fsdp_resource
546+
outer = _ep_outer_axis()
540547
leading = (outer, gsr.ep_resource) if outer is not None else gsr.ep_resource
541548
return (leading,) + (None,) * num_leading
542549

transformer_engine/jax/ep.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,14 +229,24 @@ def _dispatch_fwd(handle, topk_idx, tokens, topk_weights, recv_capacity_per_rank
229229
return primal, (handle_mem, out_leading, top_k)
230230

231231

232+
def _ep_outer_axis():
    """Return the outer (dp/fsdp) mesh axis to pair with the EP axis, or ``None``.

    Counterpart of ``cpp_extensions.ep._ep_outer_axis``. The dp resource is
    preferred over the fsdp resource. An axis that is unset, or whose size
    in the active mesh is 1, is treated as absent, so the result is either
    an axis name with size > 1 or ``None``.

    Returns:
        The first of ``dp_resource`` / ``fsdp_resource`` on the active
        ``MeshResource`` whose mesh axis size is greater than 1, else
        ``None``.
    """
    gsr = global_mesh_resource()
    for axis in (gsr.dp_resource, gsr.fsdp_resource):
        if axis is not None and get_mesh_axis_size(axis) > 1:
            return axis
    # No usable outer axis. Falling back to ``dp_resource or fsdp_resource``
    # here would hand back exactly the size-1 axis the checks above rejected.
    return None
240+
241+
232242
def _dispatch_bwd(handle, recv_capacity_per_rank, res, g_outputs):
233243
del recv_capacity_per_rank
234244
handle_mem, out_leading, top_k = res
235245
# Re-pin cotangent sharding: XLA transpose can drop the EP axis on a
236246
# single-fwd-output cotangent, landing a global tensor in the FFI.
237247
gsr = global_mesh_resource()
238248
ep_axis = gsr.ep_resource
239-
outer = gsr.dp_resource or gsr.fsdp_resource
249+
outer = _ep_outer_axis()
240250
leading = (outer, ep_axis) if outer is not None else ep_axis
241251
g_recv_tokens = jax.lax.with_sharding_constraint(
242252
g_outputs[0], jax.sharding.PartitionSpec(leading, None, None)
@@ -315,7 +325,7 @@ def _combine_bwd(handle, _num_local_tokens, _out_sharding, res, g_result):
315325
spec = jax.sharding.PartitionSpec(*_out_sharding)
316326
else:
317327
ep_axis = gsr.ep_resource
318-
outer = gsr.dp_resource or gsr.fsdp_resource
328+
outer = _ep_outer_axis()
319329
leading = (outer, ep_axis) if outer is not None and ep_axis is not None else ep_axis
320330
spec = (
321331
jax.sharding.PartitionSpec(leading, *([None] * (g_result.ndim - 1)))

transformer_engine/jax/moe.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,31 @@ def _inspect(x: jnp.ndarray, name: str) -> jnp.ndarray:
8383
__all__ = ["moe"]
8484

8585

86+
def _with_sharding_constraint_cast_bwd(x: jnp.ndarray, sharding) -> jnp.ndarray:
87+
"""Apply a sharding constraint while keeping bwd cotangents in the primal dtype.
88+
89+
Plain ``jax.lax.with_sharding_constraint`` propagates cotangents in
90+
whatever dtype the upstream gradient lands in; under mixed precision
91+
that can be wider than the primal, blowing up bandwidth and (for
92+
bf16 primals) breaking downstream kernels that pin a bf16 input
93+
layout. This wrapper re-casts the cotangent back to the primal
94+
dtype and re-asserts the same sharding on the bwd path.
95+
"""
96+
97+
@jax.custom_vjp
98+
def _constraint(y):
99+
return jax.lax.with_sharding_constraint(y, sharding)
100+
101+
def _constraint_fwd(y):
102+
return jax.lax.with_sharding_constraint(y, sharding), jnp.zeros((), dtype=y.dtype)
103+
104+
def _constraint_bwd(dtype_ref, grad):
105+
return (jax.lax.with_sharding_constraint(grad.astype(dtype_ref.dtype), sharding),)
106+
107+
_constraint.defvjp(_constraint_fwd, _constraint_bwd)
108+
return _constraint(x)
109+
110+
86111
# =============================================================================
87112
# Process-level NCCL EP bootstrap (must run eagerly, outside jax.jit)
88113
# =============================================================================
@@ -631,10 +656,11 @@ def _moe_fwd_rule(
631656
# local expert. We must size to that worst case or NCCL EP's HT kernel
632657
# rejects the dispatch buffer with ``invalid argument``.
633658
natural_spe = num_ep * max_tokens_per_rank # = (B // dp_size) * S
634-
if align_size > 0:
635-
slots_per_expert = ((natural_spe + align_size - 1) // align_size) * align_size
636-
else:
637-
slots_per_expert = natural_spe
659+
# NCCL EP requires each expert-major output block to be at least
660+
# 128-token aligned. Keep larger caller-requested alignments, but
661+
# do not emit a smaller natural block size for tiny tests.
662+
effective_align = max(int(align_size), 128)
663+
slots_per_expert = ((natural_spe + effective_align - 1) // effective_align) * effective_align
638664
recv_pr = num_local_experts * slots_per_expert
639665

640666
_te_ep_assert_compatible_bootstrap(
@@ -1406,7 +1432,7 @@ def moe(
14061432
UserWarning,
14071433
stacklevel=2,
14081434
)
1409-
x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, expected_spec))
1435+
x = _with_sharding_constraint_cast_bwd(x, NamedSharding(mesh, expected_spec))
14101436

14111437
# custom_vjp can't trace through None args; lower expert_bias to an
14121438
# empty shape-(0,) tensor that fused_topk_with_score_function treats

0 commit comments

Comments
 (0)