2323
2424from aqt .jax .v2 import aqt_tensor as aqt
2525from flax import nnx
26+ from flax import struct
2627import jax
2728from jax import ad_checkpoint as adc
2829from jax .experimental import xla_metadata
5657COMBINE = "combine"
5758
5859
@struct.dataclass
class RouteMetadata:
  """Expert-parallel (EP) communication state captured during routing.

  Carries what is needed to undo the forward all-to-all after the expert
  computation has run.

  Attributes:
    expert_shard_id: Index of this device's EP shard.
    local_sorted_indices: Permutation of the tokens received by this expert
      shard, sorted by local expert ID.
    reshaped_group_sizes: Shape [num_ep]. Aggregates group_sizes per EP shard;
      tracks how many local tokens are routed to each EP shard.
    all_shards_group_sizes: Shape [num_ep, num_ep]. all_gather of
      reshaped_group_sizes across EP shards; entry [i, j] is the number of
      tokens from batch shard i sent to expert shard j.
  """

  expert_shard_id: int
  local_sorted_indices: Optional[jax.Array]
  reshaped_group_sizes: Optional[jax.Array]
  all_shards_group_sizes: Optional[jax.Array]
73+
74+
@struct.dataclass
class RouteOutput:
  """Results of routing tokens to experts.

  Attributes:
    group_sizes: Shape [num experts]. Number of local tokens routed to every
      expert.
    selected_experts: Indices of the experts chosen for each token.
    sorted_selected_experts: Tokens sorted by the experts they are routed to.
    weights: Weights for each of the selected experts.
    lb_loss: Auxiliary loss for token distribution among experts.
    bias_updates: Dynamic bias updates for loss-free load balancing, used only
      for DeepSeek models.
  """

  group_sizes: jax.Array
  selected_experts: jax.Array
  sorted_selected_experts: jax.Array
  weights: jax.Array
  lb_loss: Optional[jax.Array]
  bias_updates: Optional[jax.Array]
91+
92+
5993def _sort_activations (
6094 inputs : jax .Array ,
6195 sort_indices : jax .Array ,
@@ -1300,37 +1334,15 @@ def get_routed_moe_shardings(is_batch_sharded_by_expert):
13001334 ) = get_routed_moe_shardings (is_batch_sharded_by_expert )
13011335 w0_pspec , w1_pspec , wo_pspec = maybe_aqt_partition (w0_kernel , w0_pspec , w1_kernel , w1_pspec , wo_kernel , wo_pspec )
13021336
1303- @functools .partial (
1304- jax .shard_map ,
1305- mesh = self .mesh ,
1306- in_specs = (
1307- input_partition_pspec ,
1308- gate_logits_pspec ,
1309- pre_bias_logits_pspec ,
1310- w0_pspec ,
1311- w1_pspec ,
1312- wo_pspec ,
1313- w0_bias_pspec ,
1314- w1_bias_pspec ,
1315- wo_bias_pspec ,
1316- P (), # Replicate the input key
1317- ),
1318- out_specs = (
1319- self ._logical_to_mesh_axes ((batch_logical_axis , "activation_norm_length" , "activation_embed" )),
1320- P (), # Handle None or replicate the output
1321- P (), # Handle None or replicate the output
1322- ),
1323- check_vma = False ,
1324- )
1325- def wrapper (x , logits , pre_bias_logits , w0 , w1 , wo , w0_bias , w1_bias , wo_bias , rngs ):
1326- batch_size , sequence_length , _ = x .shape
1327- num_expert_parallelism = self .get_expert_parallelism_size ()
1328- if num_expert_parallelism > 1 :
1329- expert_shard_id = jax .lax .axis_index (self ._expert_parallelism_name )
1330- else :
1331- expert_shard_id = 0
1332- num_expert_parallelism = self .get_expert_parallelism_size ()
1333- global_group_sizes = None
1337+ def route (x , logits , pre_bias_logits , rngs ):
1338+ """Performs both across device and within device token routing/sorting"""
1339+ num_ep = self .get_expert_parallelism_size ()
1340+ expert_shard_id = jax .lax .axis_index (self ._expert_parallelism_name ) if num_ep > 1 else 0
1341+
1342+ local_sorted_indices = None
1343+ all_shards_group_sizes = None
1344+ reshaped_group_sizes = None
1345+
13341346 if self .config .use_ring_of_experts :
13351347 # The ring-of-experts strategy first duplicates the inputs to all
13361348 # expert shards, and then routes within each shard.
@@ -1342,7 +1354,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
13421354 )
13431355
13441356 # "Route" tokens within each shard.
1345- num_experts_per_shard = self .config .num_experts // num_expert_parallelism
1357+ num_experts_per_shard = self .config .num_experts // num_ep
13461358 x , sorted_selected_experts , weights , group_sizes , selected_experts , lb_loss , bias_updates = self .permute (
13471359 x ,
13481360 logits ,
@@ -1362,10 +1374,10 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
13621374 x , logits , pre_bias_logits , self .config .use_custom_sort_vjp , rngs
13631375 )
13641376
1365- if num_expert_parallelism > 1 :
1377+ if num_ep > 1 :
13661378 batch_axis = self ._expert_parallelism_name if is_batch_sharded_by_expert else "data"
13671379 # get group sizes for all shards
1368- local_expert_size = self .config .num_experts // num_expert_parallelism
1380+ local_expert_size = self .config .num_experts // num_ep
13691381 reshaped_group_sizes = jnp .sum (group_sizes .reshape (- 1 , local_expert_size ), axis = 1 )
13701382 global_group_sizes = group_sizes
13711383
@@ -1374,12 +1386,12 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
13741386 input_offsets , send_sizes , output_offsets , recv_sizes = RoutedMoE .get_all_to_all_params (
13751387 all_shards_group_sizes ,
13761388 expert_shard_id ,
1377- num_expert_parallelism ,
1389+ num_ep ,
13781390 )
13791391
13801392 buffer_size = self .get_ragged_buffer_size (
13811393 jnp .shape (x )[0 ],
1382- num_expert_parallelism ,
1394+ num_ep ,
13831395 self .config .num_experts ,
13841396 self .config .num_experts_per_tok ,
13851397 self .config .ragged_buffer_factor ,
@@ -1416,35 +1428,40 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
14161428 use_ragged_sort = self .config .use_ragged_sort ,
14171429 )
14181430
1419- if self .config .mlp_bias :
1420- w0_bias , w1_bias , wo_bias = self .transform_bias (selected_experts , w0_bias , w1_bias , wo_bias )
1421-
1422- def get_active_sharding_axes (pspec_dim_axes , tensor_dim_index ):
1423- if pspec_dim_axes is None :
1424- return []
1425- axes = (pspec_dim_axes ,) if isinstance (pspec_dim_axes , str ) else pspec_dim_axes
1426- active = []
1427- for ax in axes :
1428- if ax and self .mesh .shape .get (ax , 1 ) > 1 :
1429- active .append ((ax , tensor_dim_index ))
1430- return active
1431+ return (
1432+ x ,
1433+ RouteOutput (
1434+ group_sizes = group_sizes ,
1435+ selected_experts = selected_experts ,
1436+ sorted_selected_experts = sorted_selected_experts ,
1437+ weights = weights ,
1438+ lb_loss = lb_loss ,
1439+ bias_updates = bias_updates ,
1440+ ),
1441+ RouteMetadata (
1442+ expert_shard_id = expert_shard_id ,
1443+ local_sorted_indices = local_sorted_indices ,
1444+ all_shards_group_sizes = all_shards_group_sizes ,
1445+ reshaped_group_sizes = reshaped_group_sizes ,
1446+ ),
1447+ )
14311448
def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
  """Pair each actively sharded mesh axis of one pspec entry with a tensor dim.

  Args:
    pspec_dim_axes: One entry of a partition spec: None, a single mesh-axis
      name, or an iterable of mesh-axis names.
    tensor_dim_index: Tensor dimension that this pspec entry describes.

  Returns:
    A list of (axis_name, tensor_dim_index) tuples, in the given order, for
    every non-empty axis name whose size in `self.mesh.shape` is greater than
    one. Axis names missing from the mesh count as size one and are dropped.
  """
  if pspec_dim_axes is None:
    return []
  # A bare string is a single axis name, not an iterable of characters.
  if isinstance(pspec_dim_axes, str):
    candidate_axes = (pspec_dim_axes,)
  else:
    candidate_axes = pspec_dim_axes
  mesh_shape = self.mesh.shape
  return [
      (axis_name, tensor_dim_index)
      for axis_name in candidate_axes
      if axis_name and mesh_shape.get(axis_name, 1) > 1
  ]
1458+
1459+ def get_wi_gmm_params ():
14321460 wi_gather_axes = []
1433- wo_gather_axes = []
1434-
14351461 if weight_gather :
14361462 # wi [Experts, In, Hidden] -> Gather Exp(0) and Hidden(2)
14371463 wi_gather_axes .extend (get_active_sharding_axes (w0_pspec [0 ], 0 ))
14381464 wi_gather_axes .extend (get_active_sharding_axes (w0_pspec [2 ], 2 ))
1439-
1440- # wo [Experts, Hidden, Out] -> Gather Exp(0) and Hidden(1)
1441- wo_gather_axes .extend (get_active_sharding_axes (wo_pspec [0 ], 0 ))
1442- wo_gather_axes .extend (get_active_sharding_axes (wo_pspec [1 ], 1 ))
1443- gmm_fn = functools .partial (
1444- gmm ,
1445- group_sizes = group_sizes ,
1446- expert_assignments = selected_experts ,
1447- )
14481465 wi_tile_size = (
14491466 self .config .wi_tile_fwd_batch_seq , # m (LHS batch)
14501467 self .config .wi_tile_fwd_embed_dim , # k (contracting)
@@ -1456,6 +1473,14 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
14561473 self .config .wi_tile_drhs_embed_dim , # Called k in megablox, but this is LHS batch dim
14571474 self .config .wi_tile_drhs_mlp_dim , # Called n in megablox, and indeed is RHS batch dim
14581475 )
1476+ return wi_gather_axes , wi_tile_size
1477+
1478+ def get_wo_gmm_params ():
1479+ wo_gather_axes = []
1480+ if weight_gather :
1481+ # wo [Experts, Hidden, Out] -> Gather Exp(0) and Hidden(1)
1482+ wo_gather_axes .extend (get_active_sharding_axes (wo_pspec [0 ], 0 ))
1483+ wo_gather_axes .extend (get_active_sharding_axes (wo_pspec [1 ], 1 ))
14591484 wo_tile_size = (
14601485 self .config .wo_tile_fwd_batch_seq , # m (LHS batch)
14611486 self .config .wo_tile_fwd_mlp_dim , # k (contracting)
@@ -1467,32 +1492,59 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
14671492 self .config .wo_tile_drhs_mlp_dim , # Called k in megablox, but this is LHS batch dim
14681493 self .config .wo_tile_drhs_embed_dim , # Called n in megablox, and indeed is the RHS batch dim
14691494 )
1495+ return wo_gather_axes , wo_tile_size
14701496
1471- layer_w0 = gmm_fn (
1472- x ,
1473- w0 ,
1474- tiling = wi_tile_size ,
1475- weight_gather_axes = wi_gather_axes ,
1476- )
def gmm_up(x, w0, w1, w0_bias, w1_bias, gmm_fn, weight_gather):
  """Run the two up-projections (gate + up) and apply the FFN activation.

  Args:
    x: Routed activations, passed as the LHS of both grouped matmuls.
    w0: Kernel of the first up-projection.
    w1: Kernel of the second up-projection.
    w0_bias: Bias added to the first projection when `config.mlp_bias` is set.
    w1_bias: Bias added to the second projection when `config.mlp_bias` is set.
    gmm_fn: Grouped-matmul callable, invoked as
      `gmm_fn(lhs, rhs, tiling=..., weight_gather_axes=...)`.
    weight_gather: Unused here. The gather axes come from
      `get_wi_gmm_params()`, which reads the enclosing scope's value. Kept so
      existing call sites keep working.

  Returns:
    The result of `self.apply_ffn_activation` on the two projections.
  """
  del weight_gather  # See docstring: gather axes come from get_wi_gmm_params().
  wi_gather_axes, wi_tile_size = get_wi_gmm_params()

  def _project(kernel, bias, name):
    """One up-projection: gmm, optional psum, optional bias, checkpoint tags."""
    out = gmm_fn(x, kernel, tiling=wi_tile_size, weight_gather_axes=wi_gather_axes)
    if self.get_tensor_transpose_parallelism_size() > 1:
      out = jax.lax.psum(out, "tensor_transpose")
    if self.config.mlp_bias:
      out = out + bias
    # Tag with both the generic and the MoE-specific name so that remat
    # policies keyed on either one keep matching. Both projections must be
    # tagged symmetrically ("mlpwi_0"/"moe_mlpwi_0", "mlpwi_1"/"moe_mlpwi_1").
    return adc.checkpoint_name(adc.checkpoint_name(out, name), f"moe_{name}")

  layer_w0 = _project(w0, w0_bias, "mlpwi_0")
  layer_w1 = _project(w1, w1_bias, "mlpwi_1")
  return self.apply_ffn_activation(layer_w0, layer_w1)
1514+
1515+ @functools .partial (
1516+ jax .shard_map ,
1517+ mesh = self .mesh ,
1518+ in_specs = (
1519+ input_partition_pspec ,
1520+ gate_logits_pspec ,
1521+ pre_bias_logits_pspec ,
1522+ w0_pspec ,
1523+ w1_pspec ,
1524+ wo_pspec ,
1525+ w0_bias_pspec ,
1526+ w1_bias_pspec ,
1527+ wo_bias_pspec ,
1528+ P (), # Replicate the input key
1529+ ),
1530+ out_specs = (
1531+ self ._logical_to_mesh_axes ((batch_logical_axis , "activation_norm_length" , "activation_embed" )),
1532+ P (), # Handle None or replicate the output
1533+ P (), # Handle None or replicate the output
1534+ ),
1535+ check_vma = False ,
1536+ )
1537+ def wrapper (x , logits , pre_bias_logits , w0 , w1 , wo , w0_bias , w1_bias , wo_bias , rngs ):
1538+ batch_size , sequence_length , _ = x .shape
1539+ x , routing , route_metadata = route (x , logits , pre_bias_logits , rngs )
14951540
1541+ if self .config .mlp_bias :
1542+ w0_bias , w1_bias , wo_bias = self .transform_bias (routing .selected_experts , w0_bias , w1_bias , wo_bias )
1543+
1544+ gmm_fn = functools .partial (gmm , group_sizes = routing .group_sizes , expert_assignments = routing .selected_experts )
1545+ intermediate_layer = gmm_up (x , w0 , w1 , w0_bias , w1_bias , gmm_fn , weight_gather )
1546+
1547+ wo_gather_axes , wo_tile_size = get_wo_gmm_params ()
14961548 intermediate_output = gmm_fn (
14971549 intermediate_layer ,
14981550 wo ,
@@ -1509,18 +1561,18 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
15091561
15101562 if self .config .use_ring_of_experts :
15111563 # Set the outputs of tokens which were not processed to 0.
1512- mask = jnp .arange (intermediate_output .shape [0 ]) < jnp .sum (group_sizes )
1564+ mask = jnp .arange (intermediate_output .shape [0 ]) < jnp .sum (routing . group_sizes )
15131565 intermediate_output = jnp .where (mask [:, None ], intermediate_output , 0 )
15141566
15151567 # Unsort and deduplicate the outputs locally.
15161568 output = self .unpermute (
15171569 intermediate_output ,
1518- sorted_selected_experts ,
1519- weights ,
1570+ routing . sorted_selected_experts ,
1571+ routing . weights ,
15201572 batch_size = batch_size ,
15211573 sequence_length = sequence_length ,
15221574 use_custom_sort_vjp = self .config .use_custom_sort_vjp ,
1523- group_sizes = group_sizes ,
1575+ group_sizes = routing . group_sizes ,
15241576 )
15251577
15261578 # Sum up the partial outputs across the expert shards.
@@ -1530,9 +1582,9 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
15301582 output = jax .lax .psum_scatter (output , self ._expert_parallelism_name , scatter_dimension = 0 , tiled = True )
15311583
15321584 else :
1533- if num_expert_parallelism > 1 :
1585+ if self . get_expert_parallelism_size () > 1 :
15341586 original_inputs_first_dim = batch_size * sequence_length * self .config .num_experts_per_tok
1535- if sorted_selected_experts .shape [0 ] != original_inputs_first_dim :
1587+ if routing . sorted_selected_experts .shape [0 ] != original_inputs_first_dim :
15361588 raise ValueError ("original_inputs_first_dim does not match the original tensor" " shape!" )
15371589 output_shape = jax .lax .empty (
15381590 (
@@ -1548,22 +1600,23 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
15481600 # Mirror the ragged-prefix gather used in `local_permute`. The
15491601 # un-permute can use the same valid-prefix length because the
15501602 # routed token count is identical for forward and backward.
1551- valid_end = jnp .sum (group_sizes ).astype (jnp .int32 ) # pylint: disable=undefined-variable
1603+ valid_end = jnp .sum (routing . group_sizes ).astype (jnp .int32 )
15521604 local_output = a2a_ragged_unsort (
15531605 intermediate_output ,
1554- jnp .argsort (local_sorted_indices ), # pylint: disable=undefined-variable
1606+ jnp .argsort (route_metadata . local_sorted_indices ), # pylint: disable=undefined-variable
15551607 valid_end ,
15561608 )
15571609 else :
15581610 local_output = _sort_activations (
15591611 intermediate_output ,
1560- jnp .argsort (local_sorted_indices ), # pylint: disable=undefined-variable
1612+ jnp .argsort (route_metadata . local_sorted_indices ),
15611613 self .config .use_custom_sort_vjp ,
15621614 )
1615+
15631616 input_offsets , send_sizes , output_offsets , recv_sizes = RoutedMoE .get_all_to_all_params (
1564- jnp .transpose (all_shards_group_sizes ), # pylint: disable=undefined-variable
1565- expert_shard_id ,
1566- num_expert_parallelism ,
1617+ jnp .transpose (route_metadata . all_shards_group_sizes ),
1618+ route_metadata . expert_shard_id ,
1619+ self . get_expert_parallelism_size () ,
15671620 )
15681621 intermediate_output = jax .lax .ragged_all_to_all (
15691622 local_output ,
@@ -1575,14 +1628,13 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
15751628 axis_name = self ._expert_parallelism_name ,
15761629 )
15771630 else :
1578- # If bach is replicated across EP shards then each shard should send
1631+ # If batch is replicated across EP shards then each shard should send
15791632 # 0..local_shard_size data to the other shards and receive the
1580- # local_shard data from all of the other shards using
1581- # ragged_all_to_all.
1633+ # local_shard data from all of the other shards using ragged_all_to_all.
15821634 input_offsets , send_sizes , output_offsets , recv_sizes = RoutedMoE .get_all_to_all_params (
1583- reshaped_group_sizes , # pylint: disable=undefined-variable
1584- expert_shard_id ,
1585- num_expert_parallelism ,
1635+ route_metadata . reshaped_group_sizes ,
1636+ route_metadata . expert_shard_id ,
1637+ self . get_expert_parallelism_size () ,
15861638 is_batch_sharded = False ,
15871639 )
15881640 intermediate_output = jax .lax .ragged_all_to_all (
@@ -1597,15 +1649,15 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
15971649
15981650 output = self .unpermute (
15991651 intermediate_output ,
1600- sorted_selected_experts ,
1601- weights ,
1652+ routing . sorted_selected_experts ,
1653+ routing . weights ,
16021654 batch_size = batch_size ,
16031655 sequence_length = sequence_length ,
16041656 use_custom_sort_vjp = self .config .use_custom_sort_vjp ,
1605- group_sizes = group_sizes ,
1657+ group_sizes = routing . group_sizes ,
16061658 )
16071659
1608- return output , lb_loss , bias_updates
1660+ return output , routing . lb_loss , routing . bias_updates
16091661
16101662 if self .config .moe_fsdp_use_two_stage_all_gather :
16111663 # Unshard on fsdp axis
0 commit comments