Skip to content

Commit 56541ae

Browse files
Merge pull request #3981 from AI-Hypercomputer:refactor_moe
PiperOrigin-RevId: 926186385
2 parents 39c1b73 + 554cd4f commit 56541ae

1 file changed

Lines changed: 147 additions & 95 deletions

File tree

src/maxtext/layers/moe.py

Lines changed: 147 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from aqt.jax.v2 import aqt_tensor as aqt
2525
from flax import nnx
26+
from flax import struct
2627
import jax
2728
from jax import ad_checkpoint as adc
2829
from jax.experimental import xla_metadata
@@ -56,6 +57,39 @@
5657
COMBINE = "combine"
5758

5859

60+
@struct.dataclass
61+
class RouteMetadata:
62+
"""EP communication state needed to undo the forward all-to-all after expert computation."""
63+
64+
# Index of this device's EP shard.
65+
expert_shard_id: int
66+
# Permutation of [tokens received by this expert shard], sorted by local expert ID.
67+
local_sorted_indices: Optional[jax.Array]
68+
# Shape [num_ep]. Aggregates group_sizes per EP shard; tracks how many local tokens are routed to each EP shard.
69+
reshaped_group_sizes: Optional[jax.Array]
70+
# Shape [num_ep, num_ep]. all_gather of reshaped_group_sizes across EP shards.
71+
# [i, j] = number of tokens from batch shard i sent to expert shard j.
72+
all_shards_group_sizes: Optional[jax.Array]
73+
74+
75+
@struct.dataclass
76+
class RouteOutput:
77+
"""Holds state of routing output"""
78+
79+
# Shape [num experts], tracks number of local tokens routed to every expert.
80+
group_sizes: jax.Array
81+
# Indices of experts chosen for each token.
82+
selected_experts: jax.Array
83+
# Tokens sorted by experts they are routed to.
84+
sorted_selected_experts: jax.Array
85+
# Weights for each of the selected experts.
86+
weights: jax.Array
87+
# Auxiliary loss for token distribution among experts.
88+
lb_loss: Optional[jax.Array]
89+
# Dynamic bias updates for loss-free load balancing, used only for Deepseek models
90+
bias_updates: Optional[jax.Array]
91+
92+
5993
def _sort_activations(
6094
inputs: jax.Array,
6195
sort_indices: jax.Array,
@@ -1300,37 +1334,15 @@ def get_routed_moe_shardings(is_batch_sharded_by_expert):
13001334
) = get_routed_moe_shardings(is_batch_sharded_by_expert)
13011335
w0_pspec, w1_pspec, wo_pspec = maybe_aqt_partition(w0_kernel, w0_pspec, w1_kernel, w1_pspec, wo_kernel, wo_pspec)
13021336

1303-
@functools.partial(
1304-
jax.shard_map,
1305-
mesh=self.mesh,
1306-
in_specs=(
1307-
input_partition_pspec,
1308-
gate_logits_pspec,
1309-
pre_bias_logits_pspec,
1310-
w0_pspec,
1311-
w1_pspec,
1312-
wo_pspec,
1313-
w0_bias_pspec,
1314-
w1_bias_pspec,
1315-
wo_bias_pspec,
1316-
P(), # Replicate the input key
1317-
),
1318-
out_specs=(
1319-
self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", "activation_embed")),
1320-
P(), # Handle None or replicate the output
1321-
P(), # Handle None or replicate the output
1322-
),
1323-
check_vma=False,
1324-
)
1325-
def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, rngs):
1326-
batch_size, sequence_length, _ = x.shape
1327-
num_expert_parallelism = self.get_expert_parallelism_size()
1328-
if num_expert_parallelism > 1:
1329-
expert_shard_id = jax.lax.axis_index(self._expert_parallelism_name)
1330-
else:
1331-
expert_shard_id = 0
1332-
num_expert_parallelism = self.get_expert_parallelism_size()
1333-
global_group_sizes = None
1337+
def route(x, logits, pre_bias_logits, rngs):
1338+
"""Performs both across device and within device token routing/sorting"""
1339+
num_ep = self.get_expert_parallelism_size()
1340+
expert_shard_id = jax.lax.axis_index(self._expert_parallelism_name) if num_ep > 1 else 0
1341+
1342+
local_sorted_indices = None
1343+
all_shards_group_sizes = None
1344+
reshaped_group_sizes = None
1345+
13341346
if self.config.use_ring_of_experts:
13351347
# The ring-of-experts strategy first duplicates the inputs to all
13361348
# expert shards, and then routes within each shard.
@@ -1342,7 +1354,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
13421354
)
13431355

13441356
# "Route" tokens within each shard.
1345-
num_experts_per_shard = self.config.num_experts // num_expert_parallelism
1357+
num_experts_per_shard = self.config.num_experts // num_ep
13461358
x, sorted_selected_experts, weights, group_sizes, selected_experts, lb_loss, bias_updates = self.permute(
13471359
x,
13481360
logits,
@@ -1362,10 +1374,10 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
13621374
x, logits, pre_bias_logits, self.config.use_custom_sort_vjp, rngs
13631375
)
13641376

1365-
if num_expert_parallelism > 1:
1377+
if num_ep > 1:
13661378
batch_axis = self._expert_parallelism_name if is_batch_sharded_by_expert else "data"
13671379
# get group sizes for all shards
1368-
local_expert_size = self.config.num_experts // num_expert_parallelism
1380+
local_expert_size = self.config.num_experts // num_ep
13691381
reshaped_group_sizes = jnp.sum(group_sizes.reshape(-1, local_expert_size), axis=1)
13701382
global_group_sizes = group_sizes
13711383

@@ -1374,12 +1386,12 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
13741386
input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params(
13751387
all_shards_group_sizes,
13761388
expert_shard_id,
1377-
num_expert_parallelism,
1389+
num_ep,
13781390
)
13791391

13801392
buffer_size = self.get_ragged_buffer_size(
13811393
jnp.shape(x)[0],
1382-
num_expert_parallelism,
1394+
num_ep,
13831395
self.config.num_experts,
13841396
self.config.num_experts_per_tok,
13851397
self.config.ragged_buffer_factor,
@@ -1416,35 +1428,40 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
14161428
use_ragged_sort=self.config.use_ragged_sort,
14171429
)
14181430

1419-
if self.config.mlp_bias:
1420-
w0_bias, w1_bias, wo_bias = self.transform_bias(selected_experts, w0_bias, w1_bias, wo_bias)
1421-
1422-
def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
1423-
if pspec_dim_axes is None:
1424-
return []
1425-
axes = (pspec_dim_axes,) if isinstance(pspec_dim_axes, str) else pspec_dim_axes
1426-
active = []
1427-
for ax in axes:
1428-
if ax and self.mesh.shape.get(ax, 1) > 1:
1429-
active.append((ax, tensor_dim_index))
1430-
return active
1431+
return (
1432+
x,
1433+
RouteOutput(
1434+
group_sizes=group_sizes,
1435+
selected_experts=selected_experts,
1436+
sorted_selected_experts=sorted_selected_experts,
1437+
weights=weights,
1438+
lb_loss=lb_loss,
1439+
bias_updates=bias_updates,
1440+
),
1441+
RouteMetadata(
1442+
expert_shard_id=expert_shard_id,
1443+
local_sorted_indices=local_sorted_indices,
1444+
all_shards_group_sizes=all_shards_group_sizes,
1445+
reshaped_group_sizes=reshaped_group_sizes,
1446+
),
1447+
)
14311448

1449+
def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
1450+
if pspec_dim_axes is None:
1451+
return []
1452+
axes = (pspec_dim_axes,) if isinstance(pspec_dim_axes, str) else pspec_dim_axes
1453+
active = []
1454+
for ax in axes:
1455+
if ax and self.mesh.shape.get(ax, 1) > 1:
1456+
active.append((ax, tensor_dim_index))
1457+
return active
1458+
1459+
def get_wi_gmm_params():
14321460
wi_gather_axes = []
1433-
wo_gather_axes = []
1434-
14351461
if weight_gather:
14361462
# wi [Experts, In, Hidden] -> Gather Exp(0) and Hidden(2)
14371463
wi_gather_axes.extend(get_active_sharding_axes(w0_pspec[0], 0))
14381464
wi_gather_axes.extend(get_active_sharding_axes(w0_pspec[2], 2))
1439-
1440-
# wo [Experts, Hidden, Out] -> Gather Exp(0) and Hidden(1)
1441-
wo_gather_axes.extend(get_active_sharding_axes(wo_pspec[0], 0))
1442-
wo_gather_axes.extend(get_active_sharding_axes(wo_pspec[1], 1))
1443-
gmm_fn = functools.partial(
1444-
gmm,
1445-
group_sizes=group_sizes,
1446-
expert_assignments=selected_experts,
1447-
)
14481465
wi_tile_size = (
14491466
self.config.wi_tile_fwd_batch_seq, # m (LHS batch)
14501467
self.config.wi_tile_fwd_embed_dim, # k (contracting)
@@ -1456,6 +1473,14 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
14561473
self.config.wi_tile_drhs_embed_dim, # Called k in megablox, but this is LHS batch dim
14571474
self.config.wi_tile_drhs_mlp_dim, # Called n in megablox, and indeed is RHS batch dim
14581475
)
1476+
return wi_gather_axes, wi_tile_size
1477+
1478+
def get_wo_gmm_params():
1479+
wo_gather_axes = []
1480+
if weight_gather:
1481+
# wo [Experts, Hidden, Out] -> Gather Exp(0) and Hidden(1)
1482+
wo_gather_axes.extend(get_active_sharding_axes(wo_pspec[0], 0))
1483+
wo_gather_axes.extend(get_active_sharding_axes(wo_pspec[1], 1))
14591484
wo_tile_size = (
14601485
self.config.wo_tile_fwd_batch_seq, # m (LHS batch)
14611486
self.config.wo_tile_fwd_mlp_dim, # k (contracting)
@@ -1467,32 +1492,59 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
14671492
self.config.wo_tile_drhs_mlp_dim, # Called k in megablox, but this is LHS batch dim
14681493
self.config.wo_tile_drhs_embed_dim, # Called n in megablox, and indeed is the RHS batch dim
14691494
)
1495+
return wo_gather_axes, wo_tile_size
14701496

1471-
layer_w0 = gmm_fn(
1472-
x,
1473-
w0,
1474-
tiling=wi_tile_size,
1475-
weight_gather_axes=wi_gather_axes,
1476-
)
1497+
def gmm_up(x, w0, w1, w0_bias, w1_bias, gmm_fn, weight_gather):
1498+
"""Run the two up-projections (gate + up) and apply the FFN activation."""
1499+
wi_gather_axes, wi_tile_size = get_wi_gmm_params()
1500+
layer_w0 = gmm_fn(x, w0, tiling=wi_tile_size, weight_gather_axes=wi_gather_axes)
14771501
if self.get_tensor_transpose_parallelism_size() > 1:
14781502
layer_w0 = jax.lax.psum(layer_w0, "tensor_transpose")
14791503
if self.config.mlp_bias:
14801504
layer_w0 = layer_w0 + w0_bias
14811505
layer_w0 = adc.checkpoint_name(adc.checkpoint_name(layer_w0, "mlpwi_0"), "moe_mlpwi_0")
14821506

1483-
layer_w1 = gmm_fn(
1484-
x,
1485-
w1,
1486-
tiling=wi_tile_size,
1487-
weight_gather_axes=wi_gather_axes,
1488-
)
1507+
layer_w1 = gmm_fn(x, w1, tiling=wi_tile_size, weight_gather_axes=wi_gather_axes)
14891508
if self.get_tensor_transpose_parallelism_size() > 1:
14901509
layer_w1 = jax.lax.psum(layer_w1, "tensor_transpose")
14911510
if self.config.mlp_bias:
14921511
layer_w1 = layer_w1 + w1_bias
1493-
layer_w1 = adc.checkpoint_name(adc.checkpoint_name(layer_w1, "mlpwi_1"), "moe_mlpwi_1")
1494-
intermediate_layer = self.apply_ffn_activation(layer_w0, layer_w1)
1512+
layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1")
1513+
return self.apply_ffn_activation(layer_w0, layer_w1)
1514+
1515+
@functools.partial(
1516+
jax.shard_map,
1517+
mesh=self.mesh,
1518+
in_specs=(
1519+
input_partition_pspec,
1520+
gate_logits_pspec,
1521+
pre_bias_logits_pspec,
1522+
w0_pspec,
1523+
w1_pspec,
1524+
wo_pspec,
1525+
w0_bias_pspec,
1526+
w1_bias_pspec,
1527+
wo_bias_pspec,
1528+
P(), # Replicate the input key
1529+
),
1530+
out_specs=(
1531+
self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", "activation_embed")),
1532+
P(), # Handle None or replicate the output
1533+
P(), # Handle None or replicate the output
1534+
),
1535+
check_vma=False,
1536+
)
1537+
def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, rngs):
1538+
batch_size, sequence_length, _ = x.shape
1539+
x, routing, route_metadata = route(x, logits, pre_bias_logits, rngs)
14951540

1541+
if self.config.mlp_bias:
1542+
w0_bias, w1_bias, wo_bias = self.transform_bias(routing.selected_experts, w0_bias, w1_bias, wo_bias)
1543+
1544+
gmm_fn = functools.partial(gmm, group_sizes=routing.group_sizes, expert_assignments=routing.selected_experts)
1545+
intermediate_layer = gmm_up(x, w0, w1, w0_bias, w1_bias, gmm_fn, weight_gather)
1546+
1547+
wo_gather_axes, wo_tile_size = get_wo_gmm_params()
14961548
intermediate_output = gmm_fn(
14971549
intermediate_layer,
14981550
wo,
@@ -1509,18 +1561,18 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
15091561

15101562
if self.config.use_ring_of_experts:
15111563
# Set the outputs of tokens which were not processed to 0.
1512-
mask = jnp.arange(intermediate_output.shape[0]) < jnp.sum(group_sizes)
1564+
mask = jnp.arange(intermediate_output.shape[0]) < jnp.sum(routing.group_sizes)
15131565
intermediate_output = jnp.where(mask[:, None], intermediate_output, 0)
15141566

15151567
# Unsort and deduplicate the outputs locally.
15161568
output = self.unpermute(
15171569
intermediate_output,
1518-
sorted_selected_experts,
1519-
weights,
1570+
routing.sorted_selected_experts,
1571+
routing.weights,
15201572
batch_size=batch_size,
15211573
sequence_length=sequence_length,
15221574
use_custom_sort_vjp=self.config.use_custom_sort_vjp,
1523-
group_sizes=group_sizes,
1575+
group_sizes=routing.group_sizes,
15241576
)
15251577

15261578
# Sum up the partial outputs across the expert shards.
@@ -1530,9 +1582,9 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
15301582
output = jax.lax.psum_scatter(output, self._expert_parallelism_name, scatter_dimension=0, tiled=True)
15311583

15321584
else:
1533-
if num_expert_parallelism > 1:
1585+
if self.get_expert_parallelism_size() > 1:
15341586
original_inputs_first_dim = batch_size * sequence_length * self.config.num_experts_per_tok
1535-
if sorted_selected_experts.shape[0] != original_inputs_first_dim:
1587+
if routing.sorted_selected_experts.shape[0] != original_inputs_first_dim:
15361588
raise ValueError("original_inputs_first_dim does not match the original tensor" " shape!")
15371589
output_shape = jax.lax.empty(
15381590
(
@@ -1548,22 +1600,23 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
15481600
# Mirror the ragged-prefix gather used in `local_permute`. The
15491601
# un-permute can use the same valid-prefix length because the
15501602
# routed token count is identical for forward and backward.
1551-
valid_end = jnp.sum(group_sizes).astype(jnp.int32) # pylint: disable=undefined-variable
1603+
valid_end = jnp.sum(routing.group_sizes).astype(jnp.int32)
15521604
local_output = a2a_ragged_unsort(
15531605
intermediate_output,
1554-
jnp.argsort(local_sorted_indices), # pylint: disable=undefined-variable
1606+
jnp.argsort(route_metadata.local_sorted_indices), # pylint: disable=undefined-variable
15551607
valid_end,
15561608
)
15571609
else:
15581610
local_output = _sort_activations(
15591611
intermediate_output,
1560-
jnp.argsort(local_sorted_indices), # pylint: disable=undefined-variable
1612+
jnp.argsort(route_metadata.local_sorted_indices),
15611613
self.config.use_custom_sort_vjp,
15621614
)
1615+
15631616
input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params(
1564-
jnp.transpose(all_shards_group_sizes), # pylint: disable=undefined-variable
1565-
expert_shard_id,
1566-
num_expert_parallelism,
1617+
jnp.transpose(route_metadata.all_shards_group_sizes),
1618+
route_metadata.expert_shard_id,
1619+
self.get_expert_parallelism_size(),
15671620
)
15681621
intermediate_output = jax.lax.ragged_all_to_all(
15691622
local_output,
@@ -1575,14 +1628,13 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
15751628
axis_name=self._expert_parallelism_name,
15761629
)
15771630
else:
1578-
# If bach is replicated across EP shards then each shard should send
1631+
# If batch is replicated across EP shards then each shard should send
15791632
# 0..local_shard_size data to the other shards and receive the
1580-
# local_shard data from all of the other shards using
1581-
# ragged_all_to_all.
1633+
# local_shard data from all of the other shards using ragged_all_to_all.
15821634
input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params(
1583-
reshaped_group_sizes, # pylint: disable=undefined-variable
1584-
expert_shard_id,
1585-
num_expert_parallelism,
1635+
route_metadata.reshaped_group_sizes,
1636+
route_metadata.expert_shard_id,
1637+
self.get_expert_parallelism_size(),
15861638
is_batch_sharded=False,
15871639
)
15881640
intermediate_output = jax.lax.ragged_all_to_all(
@@ -1597,15 +1649,15 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
15971649

15981650
output = self.unpermute(
15991651
intermediate_output,
1600-
sorted_selected_experts,
1601-
weights,
1652+
routing.sorted_selected_experts,
1653+
routing.weights,
16021654
batch_size=batch_size,
16031655
sequence_length=sequence_length,
16041656
use_custom_sort_vjp=self.config.use_custom_sort_vjp,
1605-
group_sizes=group_sizes,
1657+
group_sizes=routing.group_sizes,
16061658
)
16071659

1608-
return output, lb_loss, bias_updates
1660+
return output, routing.lb_loss, routing.bias_updates
16091661

16101662
if self.config.moe_fsdp_use_two_stage_all_gather:
16111663
# Unshard on fsdp axis

0 commit comments

Comments
 (0)