Skip to content

Commit 149bd9a

Browse files
committed
fix
1 parent 56db4d5 commit 149bd9a

1 file changed

Lines changed: 148 additions & 95 deletions

File tree

src/maxtext/layers/moe.py

Lines changed: 148 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
"""MoE related Layers."""
1717

18+
import dataclasses
1819
import enum
1920
import functools
2021
import math
@@ -23,6 +24,7 @@
2324

2425
from aqt.jax.v2 import aqt_tensor as aqt
2526
from flax import nnx
27+
from flax import struct
2628
import jax
2729
from jax import ad_checkpoint as adc
2830
from jax.experimental import xla_metadata
@@ -56,6 +58,39 @@
5658
COMBINE = "combine"
5759

5860

61+
@struct.dataclass
class RouteMetadata:
  """EP communication state needed to undo the forward all-to-all after expert computation."""

  # Index of this device's EP shard. Populated from `jax.lax.axis_index` over the
  # expert-parallelism axis, or None when there is only a single EP shard.
  expert_shard_id: Optional[jax.Array]
  # Permutation of [tokens received by this expert shard], sorted by local expert ID.
  local_sorted_indices: Optional[jax.Array]
  # Shape [num_ep]. Aggregates group_sizes per EP shard; tracks how many local tokens are routed to each EP shard.
  reshaped_group_sizes: Optional[jax.Array]
  # Shape [num_ep, num_ep]. all_gather of reshaped_group_sizes across EP shards.
  # [i, j] = number of tokens from batch shard i sent to expert shard j.
  all_shards_group_sizes: Optional[jax.Array]
75+
76+
@struct.dataclass
class RouteOutput:
  """Holds the state produced by token routing.

  Attributes:
    group_sizes: Shape [num experts]; number of local tokens routed to every expert.
    selected_experts: Indices of the experts chosen for each token.
    sorted_selected_experts: Tokens sorted by the experts they are routed to.
    weights: Weights for each of the selected experts.
    lb_loss: Auxiliary loss for token distribution among experts (may be None).
    bias_updates: Dynamic bias updates for loss-free load balancing, used only
      for Deepseek models (may be None).
  """

  group_sizes: jax.Array
  selected_experts: jax.Array
  sorted_selected_experts: jax.Array
  weights: jax.Array
  lb_loss: Optional[jax.Array]
  bias_updates: Optional[jax.Array]
92+
93+
5994
def _sort_activations(
6095
inputs: jax.Array,
6196
sort_indices: jax.Array,
@@ -1300,37 +1335,15 @@ def get_routed_moe_shardings(is_batch_sharded_by_expert):
13001335
) = get_routed_moe_shardings(is_batch_sharded_by_expert)
13011336
w0_pspec, w1_pspec, wo_pspec = maybe_aqt_partition(w0_kernel, w0_pspec, w1_kernel, w1_pspec, wo_kernel, wo_pspec)
13021337

1303-
@functools.partial(
1304-
jax.shard_map,
1305-
mesh=self.mesh,
1306-
in_specs=(
1307-
input_partition_pspec,
1308-
gate_logits_pspec,
1309-
pre_bias_logits_pspec,
1310-
w0_pspec,
1311-
w1_pspec,
1312-
wo_pspec,
1313-
w0_bias_pspec,
1314-
w1_bias_pspec,
1315-
wo_bias_pspec,
1316-
P(), # Replicate the input key
1317-
),
1318-
out_specs=(
1319-
self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", "activation_embed")),
1320-
P(), # Handle None or replicate the output
1321-
P(), # Handle None or replicate the output
1322-
),
1323-
check_vma=False,
1324-
)
1325-
def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, rngs):
1326-
batch_size, sequence_length, _ = x.shape
1327-
num_expert_parallelism = self.get_expert_parallelism_size()
1328-
if num_expert_parallelism > 1:
1329-
expert_shard_id = jax.lax.axis_index(self._expert_parallelism_name)
1330-
else:
1331-
expert_shard_id = 0
1332-
num_expert_parallelism = self.get_expert_parallelism_size()
1333-
global_group_sizes = None
1338+
def route(x, logits, pre_bias_logits, rngs):
1339+
"""Performs both cross-device and within-device token routing and sorting."""
1340+
num_ep = self.get_expert_parallelism_size()
1341+
expert_shard_id = jax.lax.axis_index(self._expert_parallelism_name) if num_ep > 1 else None
1342+
1343+
local_sorted_indices = None
1344+
all_shards_group_sizes = None
1345+
reshaped_group_sizes = None
1346+
13341347
if self.config.use_ring_of_experts:
13351348
# The ring-of-experts strategy first duplicates the inputs to all
13361349
# expert shards, and then routes within each shard.
@@ -1342,7 +1355,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
13421355
)
13431356

13441357
# "Route" tokens within each shard.
1345-
num_experts_per_shard = self.config.num_experts // num_expert_parallelism
1358+
num_experts_per_shard = self.config.num_experts // num_ep
13461359
x, sorted_selected_experts, weights, group_sizes, selected_experts, lb_loss, bias_updates = self.permute(
13471360
x,
13481361
logits,
@@ -1362,10 +1375,10 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
13621375
x, logits, pre_bias_logits, self.config.use_custom_sort_vjp, rngs
13631376
)
13641377

1365-
if num_expert_parallelism > 1:
1378+
if num_ep > 1:
13661379
batch_axis = self._expert_parallelism_name if is_batch_sharded_by_expert else "data"
13671380
# get group sizes for all shards
1368-
local_expert_size = self.config.num_experts // num_expert_parallelism
1381+
local_expert_size = self.config.num_experts // num_ep
13691382
reshaped_group_sizes = jnp.sum(group_sizes.reshape(-1, local_expert_size), axis=1)
13701383
global_group_sizes = group_sizes
13711384

@@ -1374,12 +1387,12 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
13741387
input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params(
13751388
all_shards_group_sizes,
13761389
expert_shard_id,
1377-
num_expert_parallelism,
1390+
num_ep,
13781391
)
13791392

13801393
buffer_size = self.get_ragged_buffer_size(
13811394
jnp.shape(x)[0],
1382-
num_expert_parallelism,
1395+
num_ep,
13831396
self.config.num_experts,
13841397
self.config.num_experts_per_tok,
13851398
self.config.ragged_buffer_factor,
@@ -1416,35 +1429,40 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
14161429
use_ragged_sort=self.config.use_ragged_sort,
14171430
)
14181431

1419-
if self.config.mlp_bias:
1420-
w0_bias, w1_bias, wo_bias = self.transform_bias(selected_experts, w0_bias, w1_bias, wo_bias)
1421-
1422-
def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
1423-
if pspec_dim_axes is None:
1424-
return []
1425-
axes = (pspec_dim_axes,) if isinstance(pspec_dim_axes, str) else pspec_dim_axes
1426-
active = []
1427-
for ax in axes:
1428-
if ax and self.mesh.shape.get(ax, 1) > 1:
1429-
active.append((ax, tensor_dim_index))
1430-
return active
1432+
return (
1433+
x,
1434+
RouteOutput(
1435+
group_sizes=group_sizes,
1436+
selected_experts=selected_experts,
1437+
sorted_selected_experts=sorted_selected_experts,
1438+
weights=weights,
1439+
lb_loss=lb_loss,
1440+
bias_updates=bias_updates,
1441+
),
1442+
RouteMetadata(
1443+
expert_shard_id=expert_shard_id,
1444+
local_sorted_indices=local_sorted_indices,
1445+
all_shards_group_sizes=all_shards_group_sizes,
1446+
reshaped_group_sizes=reshaped_group_sizes,
1447+
),
1448+
)
14311449

1450+
def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
  """Pair each mesh axis that is actually sharded with the given tensor dimension.

  Args:
    pspec_dim_axes: One partition-spec entry: None, a single axis name, or an
      iterable of axis names.
    tensor_dim_index: Tensor dimension index to attach to every active axis.

  Returns:
    A list of `(axis_name, tensor_dim_index)` tuples, one per non-empty axis
    whose size in `self.mesh.shape` is greater than 1. Empty when
    `pspec_dim_axes` is None.
  """
  if pspec_dim_axes is None:
    return []
  if isinstance(pspec_dim_axes, str):
    axis_names = (pspec_dim_axes,)
  else:
    axis_names = pspec_dim_axes
  # Axes missing from the mesh default to size 1 and are therefore skipped.
  return [
      (axis_name, tensor_dim_index)
      for axis_name in axis_names
      if axis_name and self.mesh.shape.get(axis_name, 1) > 1
  ]
1459+
1460+
def get_wi_gmm_params():
14321461
wi_gather_axes = []
1433-
wo_gather_axes = []
1434-
14351462
if weight_gather:
14361463
# wi [Experts, In, Hidden] -> Gather Exp(0) and Hidden(2)
14371464
wi_gather_axes.extend(get_active_sharding_axes(w0_pspec[0], 0))
14381465
wi_gather_axes.extend(get_active_sharding_axes(w0_pspec[2], 2))
1439-
1440-
# wo [Experts, Hidden, Out] -> Gather Exp(0) and Hidden(1)
1441-
wo_gather_axes.extend(get_active_sharding_axes(wo_pspec[0], 0))
1442-
wo_gather_axes.extend(get_active_sharding_axes(wo_pspec[1], 1))
1443-
gmm_fn = functools.partial(
1444-
gmm,
1445-
group_sizes=group_sizes,
1446-
expert_assignments=selected_experts,
1447-
)
14481466
wi_tile_size = (
14491467
self.config.wi_tile_fwd_batch_seq, # m (LHS batch)
14501468
self.config.wi_tile_fwd_embed_dim, # k (contracting)
@@ -1456,6 +1474,14 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
14561474
self.config.wi_tile_drhs_embed_dim, # Called k in megablox, but this is LHS batch dim
14571475
self.config.wi_tile_drhs_mlp_dim, # Called n in megablox, and indeed is RHS batch dim
14581476
)
1477+
return wi_gather_axes, wi_tile_size
1478+
1479+
def get_wo_gmm_params():
1480+
wo_gather_axes = []
1481+
if weight_gather:
1482+
# wo [Experts, Hidden, Out] -> Gather Exp(0) and Hidden(1)
1483+
wo_gather_axes.extend(get_active_sharding_axes(wo_pspec[0], 0))
1484+
wo_gather_axes.extend(get_active_sharding_axes(wo_pspec[1], 1))
14591485
wo_tile_size = (
14601486
self.config.wo_tile_fwd_batch_seq, # m (LHS batch)
14611487
self.config.wo_tile_fwd_mlp_dim, # k (contracting)
@@ -1467,32 +1493,59 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
14671493
self.config.wo_tile_drhs_mlp_dim, # Called k in megablox, but this is LHS batch dim
14681494
self.config.wo_tile_drhs_embed_dim, # Called n in megablox, and indeed is the RHS batch dim
14691495
)
1496+
return wo_gather_axes, wo_tile_size
14701497

1471-
layer_w0 = gmm_fn(
1472-
x,
1473-
w0,
1474-
tiling=wi_tile_size,
1475-
weight_gather_axes=wi_gather_axes,
1476-
)
1498+
def gmm_up(x, w0, w1, w0_bias, w1_bias, gmm_fn, weight_gather):
  """Run the two up-projections (gate + up) and apply the FFN activation.

  Args:
    x: Left-hand operand passed to both grouped matmuls.
    w0: Right-hand operand of the first projection.
    w1: Right-hand operand of the second projection.
    w0_bias: Added to the first projection when `self.config.mlp_bias` is set.
    w1_bias: Added to the second projection when `self.config.mlp_bias` is set.
    gmm_fn: Callable invoked as
      `gmm_fn(lhs, rhs, tiling=..., weight_gather_axes=...)`.
    weight_gather: Not read in this body; the gather axes come from
      `get_wi_gmm_params()`. Kept so existing call sites remain valid.

  Returns:
    The result of `self.apply_ffn_activation(layer_w0, layer_w1)`.
  """
  wi_gather_axes, wi_tile_size = get_wi_gmm_params()

  layer_w0 = gmm_fn(x, w0, tiling=wi_tile_size, weight_gather_axes=wi_gather_axes)
  if self.get_tensor_transpose_parallelism_size() > 1:
    layer_w0 = jax.lax.psum(layer_w0, "tensor_transpose")
  if self.config.mlp_bias:
    layer_w0 = layer_w0 + w0_bias
  layer_w0 = adc.checkpoint_name(adc.checkpoint_name(layer_w0, "mlpwi_0"), "moe_mlpwi_0")

  layer_w1 = gmm_fn(x, w1, tiling=wi_tile_size, weight_gather_axes=wi_gather_axes)
  if self.get_tensor_transpose_parallelism_size() > 1:
    layer_w1 = jax.lax.psum(layer_w1, "tensor_transpose")
  if self.config.mlp_bias:
    layer_w1 = layer_w1 + w1_bias
  # Tag with both names, mirroring layer_w0, so that anything keyed on either
  # "mlpwi_1" or "moe_mlpwi_1" still matches this activation.
  layer_w1 = adc.checkpoint_name(adc.checkpoint_name(layer_w1, "mlpwi_1"), "moe_mlpwi_1")
  return self.apply_ffn_activation(layer_w0, layer_w1)
1515+
1516+
@functools.partial(
1517+
jax.shard_map,
1518+
mesh=self.mesh,
1519+
in_specs=(
1520+
input_partition_pspec,
1521+
gate_logits_pspec,
1522+
pre_bias_logits_pspec,
1523+
w0_pspec,
1524+
w1_pspec,
1525+
wo_pspec,
1526+
w0_bias_pspec,
1527+
w1_bias_pspec,
1528+
wo_bias_pspec,
1529+
P(), # Replicate the input key
1530+
),
1531+
out_specs=(
1532+
self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", "activation_embed")),
1533+
P(), # Handle None or replicate the output
1534+
P(), # Handle None or replicate the output
1535+
),
1536+
check_vma=False,
1537+
)
1538+
def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, rngs):
1539+
batch_size, sequence_length, _ = x.shape
1540+
x, routing, route_metadata = route(x, logits, pre_bias_logits, rngs)
14951541

1542+
if self.config.mlp_bias:
1543+
w0_bias, w1_bias, wo_bias = self.transform_bias(routing.selected_experts, w0_bias, w1_bias, wo_bias)
1544+
1545+
gmm_fn = functools.partial(gmm, group_sizes=routing.group_sizes, expert_assignments=routing.selected_experts)
1546+
intermediate_layer = gmm_up(x, w0, w1, w0_bias, w1_bias, gmm_fn, weight_gather)
1547+
1548+
wo_gather_axes, wo_tile_size = get_wo_gmm_params()
14961549
intermediate_output = gmm_fn(
14971550
intermediate_layer,
14981551
wo,
@@ -1509,18 +1562,18 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
15091562

15101563
if self.config.use_ring_of_experts:
15111564
# Set the outputs of tokens which were not processed to 0.
1512-
mask = jnp.arange(intermediate_output.shape[0]) < jnp.sum(group_sizes)
1565+
mask = jnp.arange(intermediate_output.shape[0]) < jnp.sum(routing.group_sizes)
15131566
intermediate_output = jnp.where(mask[:, None], intermediate_output, 0)
15141567

15151568
# Unsort and deduplicate the outputs locally.
15161569
output = self.unpermute(
15171570
intermediate_output,
1518-
sorted_selected_experts,
1519-
weights,
1571+
routing.sorted_selected_experts,
1572+
routing.weights,
15201573
batch_size=batch_size,
15211574
sequence_length=sequence_length,
15221575
use_custom_sort_vjp=self.config.use_custom_sort_vjp,
1523-
group_sizes=group_sizes,
1576+
group_sizes=routing.group_sizes,
15241577
)
15251578

15261579
# Sum up the partial outputs across the expert shards.
@@ -1530,9 +1583,9 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
15301583
output = jax.lax.psum_scatter(output, self._expert_parallelism_name, scatter_dimension=0, tiled=True)
15311584

15321585
else:
1533-
if num_expert_parallelism > 1:
1586+
if self.get_expert_parallelism_size() > 1:
15341587
original_inputs_first_dim = batch_size * sequence_length * self.config.num_experts_per_tok
1535-
if sorted_selected_experts.shape[0] != original_inputs_first_dim:
1588+
if routing.sorted_selected_experts.shape[0] != original_inputs_first_dim:
15361589
raise ValueError("original_inputs_first_dim does not match the original tensor" " shape!")
15371590
output_shape = jax.lax.empty(
15381591
(
@@ -1548,22 +1601,23 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
15481601
# Mirror the ragged-prefix gather used in `local_permute`. The
15491602
# un-permute can use the same valid-prefix length because the
15501603
# routed token count is identical for forward and backward.
1551-
valid_end = jnp.sum(group_sizes).astype(jnp.int32) # pylint: disable=undefined-variable
1604+
valid_end = jnp.sum(routing.group_sizes).astype(jnp.int32)
15521605
local_output = a2a_ragged_unsort(
15531606
intermediate_output,
1554-
jnp.argsort(local_sorted_indices), # pylint: disable=undefined-variable
1607+
jnp.argsort(route_metadata.local_sorted_indices),
15551608
valid_end,
15561609
)
15571610
else:
15581611
local_output = _sort_activations(
15591612
intermediate_output,
1560-
jnp.argsort(local_sorted_indices), # pylint: disable=undefined-variable
1613+
jnp.argsort(route_metadata.local_sorted_indices),
15611614
self.config.use_custom_sort_vjp,
15621615
)
1616+
15631617
input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params(
1564-
jnp.transpose(all_shards_group_sizes), # pylint: disable=undefined-variable
1565-
expert_shard_id,
1566-
num_expert_parallelism,
1618+
jnp.transpose(route_metadata.all_shards_group_sizes),
1619+
route_metadata.expert_shard_id,
1620+
self.get_expert_parallelism_size(),
15671621
)
15681622
intermediate_output = jax.lax.ragged_all_to_all(
15691623
local_output,
@@ -1575,14 +1629,13 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
15751629
axis_name=self._expert_parallelism_name,
15761630
)
15771631
else:
1578-
# If bach is replicated across EP shards then each shard should send
1632+
# If batch is replicated across EP shards then each shard should send
15791633
# 0..local_shard_size data to the other shards and receive the
1580-
# local_shard data from all of the other shards using
1581-
# ragged_all_to_all.
1634+
# local_shard data from all of the other shards using ragged_all_to_all.
15821635
input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params(
1583-
reshaped_group_sizes, # pylint: disable=undefined-variable
1584-
expert_shard_id,
1585-
num_expert_parallelism,
1636+
route_metadata.reshaped_group_sizes,
1637+
route_metadata.expert_shard_id,
1638+
self.get_expert_parallelism_size(),
15861639
is_batch_sharded=False,
15871640
)
15881641
intermediate_output = jax.lax.ragged_all_to_all(
@@ -1597,15 +1650,15 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
15971650

15981651
output = self.unpermute(
15991652
intermediate_output,
1600-
sorted_selected_experts,
1601-
weights,
1653+
routing.sorted_selected_experts,
1654+
routing.weights,
16021655
batch_size=batch_size,
16031656
sequence_length=sequence_length,
16041657
use_custom_sort_vjp=self.config.use_custom_sort_vjp,
1605-
group_sizes=group_sizes,
1658+
group_sizes=routing.group_sizes,
16061659
)
16071660

1608-
return output, lb_loss, bias_updates
1661+
return output, routing.lb_loss, routing.bias_updates
16091662

16101663
if self.config.moe_fsdp_use_two_stage_all_gather:
16111664
# Unshard on fsdp axis

0 commit comments

Comments
 (0)