Expert parallelism (EP) partitions MoE experts across an EP process
group — each rank holds num_experts / ep experts and tokens are
shuffled between ranks by all-to-all so every token reaches its
assigned expert. With ep=1 (the default), experts are replicated on
every rank and EP is a no-op.
Entry points:
- `apply_expert_parallel(model, device_mesh)` — prunes each `MoEMLP` to the local expert slice and stores EP metadata (`ep_group`, `ep_world_size`, `local_expert_start`, `num_local_experts`).
- `ep_dispatch_and_compute(x, weights, indices, moe, ...)` — runs inside `MoEMLP.forward()` when `ep_world_size > 1` and implements the all-to-all dispatch / local compute / all-to-all combine.
```python
# kempnerforge/model/moe.py, MoEMLP.forward
if self.ep_world_size > 1:
    output = ep_dispatch_and_compute(
        x, weights, indices, self,
        self.ep_group, self.local_expert_start,
        self.num_local_experts, self.ep_world_size,
        gradient_scale=self.gradient_scale,
    )
```

With `ep=1`, `ep_world_size` stays at 1 (the default set in `MoEMLP.__init__`) and the forward path runs experts locally. With `ep>1`, `apply_expert_parallel` bumps `ep_world_size` to the EP mesh size and populates the other metadata.
`ep_dispatch_and_compute` is a seven-step sequence:
| # | Step | What it does |
|---|---|---|
| 1 | Expand | (num_tokens, top_k) → flat (num_tokens · top_k,) list of (token_id, expert_id, weight) entries |
| 2 | Sort | Stable-sort entries by target EP rank so same-destination tokens are contiguous |
| 3 | Exchange counts | dist.all_to_all_single on send_counts → every rank learns recv_counts |
| 4 | Dispatch | _AllToAll sends x_sorted to expert-owning ranks; a second all-to-all ships the expert IDs |
| 5 | Local compute | Grouped GEMM over received tokens (sorted by local expert) when torch._grouped_mm is available; fallback is per-expert masked forward |
| 6 | Combine | Reverse all-to-all sends processed tokens back to the originating ranks |
| 7 | Weighted sum | scatter_add_ combines the top_k expert outputs per token with their router weights |
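A hedged sketch of steps 1–3, assuming the router produces `indices` of shape `(num_tokens, top_k)`; the function name, argument list, and return values here are illustrative, not kempnerforge's actual signature:

```python
import torch
import torch.distributed as dist

def expand_and_sort(x, indices, num_experts, ep_world_size, ep_group):
    """Steps 1-3: expand top-k routing, sort by destination rank, swap counts."""
    num_tokens, top_k = indices.shape
    experts_per_rank = num_experts // ep_world_size

    # 1. Expand: one flat (token, expert) entry per routed pair.
    flat_experts = indices.reshape(-1)  # (num_tokens * top_k,)
    flat_tokens = torch.arange(num_tokens, device=x.device).repeat_interleave(top_k)

    # 2. Stable-sort by destination EP rank so same-destination entries
    #    are contiguous in the send buffer.
    dest_rank = flat_experts // experts_per_rank
    order = torch.argsort(dest_rank, stable=True)
    x_sorted = x[flat_tokens[order]]

    # 3. Exchange counts: one scalar per peer, so every rank learns how
    #    many rows it will receive before the dispatch all-to-all.
    send_counts = torch.bincount(dest_rank, minlength=ep_world_size)
    recv_counts = torch.empty_like(send_counts)
    dist.all_to_all_single(recv_counts, send_counts, group=ep_group)
    return x_sorted, flat_experts[order], order, send_counts.tolist(), recv_counts.tolist()
```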
The dispatch all-to-all is wrapped in `_AllToAll` — a custom `torch.autograd.Function` whose backward is the same all-to-all with the send/recv counts swapped. That's what makes the forward path differentiable.
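A minimal sketch of such a function (the real `_AllToAll` lives in kempnerforge; the class name and signature here are assumptions):

```python
import torch
import torch.distributed as dist

class AllToAllSketch(torch.autograd.Function):
    """Differentiable all-to-all: backward reruns the same collective
    with the send/recv split sizes swapped."""

    @staticmethod
    def forward(ctx, x, send_counts, recv_counts, group):
        ctx.send_counts, ctx.recv_counts, ctx.group = send_counts, recv_counts, group
        out = x.new_empty(sum(recv_counts), *x.shape[1:])
        dist.all_to_all_single(
            out, x.contiguous(),
            output_split_sizes=recv_counts,
            input_split_sizes=send_counts,
            group=group,
        )
        return out

    @staticmethod
    def backward(ctx, grad_out):
        # Same collective, counts swapped: each gradient row travels back
        # to the rank that originally sent the corresponding token.
        grad_in = grad_out.new_empty(sum(ctx.send_counts), *grad_out.shape[1:])
        dist.all_to_all_single(
            grad_in, grad_out.contiguous(),
            output_split_sizes=ctx.send_counts,
            input_split_sizes=ctx.recv_counts,
            group=ctx.group,
        )
        return grad_in, None, None, None
```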
If a local expert receives zero tokens in a step, its parameters never enter the autograd graph — and FSDP2's reduce-scatter, which fires only after every parameter in a unit has accumulated a gradient, hangs forever. `ep_dispatch_and_compute` forces an `AccumulateGrad` hook to fire on each unused expert by adding a zero-valued sum of its parameters into the output:
```python
for i in range(num_local_experts):
    if tokens_per_expert[i] == 0:
        for p in moe.experts[i].parameters():
            local_output = local_output + p.sum() * 0
```

Similar zero-contributions handle the packed-expert path and the case where the dispatch all-to-all would otherwise have no gradient edge back from `local_output` to `received_tokens` (which would cause the backward all-to-all to be skipped on one side — positional mismatch in NCCL → deadlock).
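In the same spirit, a sketch of that last guard (the exact condition and placement are assumptions; the text above only states that a zero-valued contribution keeps the edge alive):

```python
# Illustrative, not the exact kempnerforge code: keep a gradient edge
# from local_output back to received_tokens so the backward all-to-all
# is issued on every rank, even when no local expert consumed the rows.
local_output = local_output + received_tokens.sum() * 0
```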
`apply_expert_parallel` walks the model and prunes each `MoEMLP` to its local slice:

```python
for module in model.modules():
    if not isinstance(module, MoEMLP):
        continue
    assert num_experts % ep_size == 0
    start = ep_rank * (num_experts // ep_size)
    end = start + (num_experts // ep_size)
    if module.packed_experts:
        # replace Parameter with a sliced copy (can't resize in place)
        module.up_w = Parameter(module.up_w.data[start:end].clone())
        module.down_w = Parameter(module.down_w.data[start:end].clone())
        ...
    else:
        module.experts = ModuleList([module.experts[i] for i in range(start, end)])
    module.ep_world_size = ep_size
    module.ep_group = ep_group
    module.local_expert_start = start
    module.num_local_experts = num_experts // ep_size
```

The router (`moe.router`) is not sharded — every rank keeps the full router weights so it can make the routing decision locally before dispatch. Shared experts (`moe.shared_expert`) are also kept on every rank.
EP runs after TP and before FSDP2 — see Parallelism order.
- EP + TP: TP shards the non-MoE Linears (attention q/k/v/o and shared-expert gate/up/down) along the `tp` mesh dim. EP shards the routed experts along `ep`. Dense TP layers are untouched by `apply_expert_parallel`.
- EP + FSDP2: FSDP2 wraps the MoE layer's `attention` and `mlp` separately (per-sub-module wrapping) rather than the whole block — see FSDP2 § EP-MoE. Per-block wrapping would cause FSDP2's reduce-scatter to fire between the two EP all-to-alls in backward, deadlocking.
- EP + FP8: expert Linears are excluded from the Float8 pass (`"experts" in fqn → False` in the filter; a sketch follows below). The grouped GEMM path (`torch._grouped_mm`) doesn't go through `Float8Linear.forward`, so FP8 applied there is ineffective and adds surprise failures. See FP8 § Exclusion rules.
```toml
[distributed]
ep = 2                      # expert parallelism degree

[model]
num_experts = 8             # global expert count
moe_top_k = 2               # experts per token
moe_shared_experts = 1      # optional always-on expert
moe_packed_experts = false  # grouped GEMM with packed weights (opt-in)
moe_gradient_scale = false  # per-expert gradient normalization (opt-in)
```

`num_experts % ep == 0` is checked at apply time. The parallelism
arithmetic (see DistributedConfig)
requires dp_replicate · dp_shard · tp · pp · cp · ep == world_size.
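A minimal sketch of those two checks (the field names on `DistributedConfigSketch` mirror the config keys above; the real `DistributedConfig` may differ):

```python
from dataclasses import dataclass

@dataclass
class DistributedConfigSketch:
    # Hypothetical field names mirroring the [distributed] config keys.
    dp_replicate: int = 1
    dp_shard: int = 1
    tp: int = 1
    pp: int = 1
    cp: int = 1
    ep: int = 1

def validate(cfg: DistributedConfigSketch, num_experts: int, world_size: int) -> None:
    assert num_experts % cfg.ep == 0, "experts must split evenly across EP ranks"
    product = cfg.dp_replicate * cfg.dp_shard * cfg.tp * cfg.pp * cfg.cp * cfg.ep
    assert product == world_size, f"parallelism product {product} != world_size {world_size}"
```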
```
dp_shard=4, tp=4, ep=2, pp=1
num_experts=8, moe_top_k=2
mesh: ("dp_shard", "ep", "tp") → (4, 2, 4)
```

- Each EP group (size 2) splits the 8 experts: `experts 0-3` on rank 0 and `experts 4-7` on rank 1.
- Within each EP group, TP shards the non-MoE Linears (attention and shared-expert) along `tp=4`.
- FSDP shards the remaining params across the `dp_shard=4` axis with per-sub-module wrapping.
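A hedged sketch of building that mesh with PyTorch's `init_device_mesh` (the dim-name indexing is standard DeviceMesh API; whether kempnerforge constructs it exactly this way is an assumption):

```python
from torch.distributed.device_mesh import init_device_mesh

# 32 ranks laid out as (dp_shard, ep, tp) = (4, 2, 4)
mesh = init_device_mesh("cuda", (4, 2, 4), mesh_dim_names=("dp_shard", "ep", "tp"))
ep_group = mesh["ep"].get_group()      # size-2 process group per EP pair
ep_rank = mesh["ep"].get_local_rank()  # 0 or 1 → experts 0-3 or 4-7
```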
Benchmark and reproducer: Benchmarks § MoE Expert Parallelism.
When `moe_gradient_scale = true`, the output of each local expert is multiplied by `avg_tokens / tokens_for_this_expert` so high-traffic experts don't dominate the gradient. The scaling happens on `local_output` before the combine all-to-all, so the adjusted gradient flows back through the dispatch all-to-all to the router and expert params correctly. Disabled by default — it changes gradient magnitudes and should be validated against a baseline run before flipping it on.
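A hedged sketch of that rescaling, applied to `local_output` before the combine (tensor names are assumptions based on the step table above, not kempnerforge's exact code):

```python
import torch

def scale_by_traffic(local_output, tokens_per_expert, local_expert_ids):
    """Multiply each row by avg_tokens / tokens_for_its_expert.

    local_output:      (num_recv_rows, hidden) expert outputs, pre-combine
    tokens_per_expert: (num_local_experts,) rows routed to each local expert
    local_expert_ids:  (num_recv_rows,) which local expert produced each row
    """
    avg_tokens = tokens_per_expert.float().mean()
    # clamp(min=1) avoids division by zero for idle experts
    scale = avg_tokens / tokens_per_expert.float().clamp(min=1)
    return local_output * scale[local_expert_ids].unsqueeze(-1)
```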
- MoE overview — architecture, routers, auxiliary losses. This distributed/ page is the canonical EP reference; the MoE pages link here.
- FSDP2 § EP-MoE — the per-sub-module wrapping pattern.
- FP8 § Exclusion rules — why `experts` + `router` + `shared_expert` are skipped by `apply_float8`.
- Parallelism order — EP's place in the apply sequence.
- Validation rules § Expert parallel — the `num_experts % ep == 0` check.
- Benchmarks § MoE Expert Parallelism — measured EP speedup.