1515
1616"""MoE related Layers."""
1717
18+ import dataclasses
1819import enum
1920import functools
2021import math
2324
2425from aqt .jax .v2 import aqt_tensor as aqt
2526from flax import nnx
27+ from flax import struct
2628import jax
2729from jax import ad_checkpoint as adc
2830from jax .experimental import xla_metadata
5658COMBINE = "combine"
5759
5860
@struct.dataclass
class RouteMetadata:
  """EP communication state needed to undo the forward all-to-all after expert computation.

  Produced by `route` and read back after the expert matmuls to reverse the
  local sort and the ragged all-to-all. "EP" abbreviates expert parallelism.
  `route` initialises the optional array fields to None before routing.
  """

  # Index of this device's EP shard. `route` stores the result of
  # `jax.lax.axis_index` here, or None when there is only one EP shard.
  expert_shard_id: Optional[jax.Array]
  # Permutation of [tokens received by this expert shard], sorted by local expert ID.
  # Its argsort is used to restore the pre-sort order of the expert outputs.
  local_sorted_indices: Optional[jax.Array]
  # Shape [num_ep]. Aggregates group_sizes per EP shard; tracks how many local tokens are routed to each EP shard.
  reshaped_group_sizes: Optional[jax.Array]
  # Shape [num_ep, num_ep]. all_gather of reshaped_group_sizes across EP shards.
  # [i, j] = number of tokens from batch shard i sent to expert shard j.
  all_shards_group_sizes: Optional[jax.Array]
74+
75+
@struct.dataclass
class RouteOutput:
  """Holds the state of the routing output.

  Returned by `route` next to the routed activations; the fields are read when
  running the grouped matmuls and when un-permuting the expert outputs.
  """

  # Shape [num experts], tracks number of local tokens routed to every expert.
  group_sizes: jax.Array
  # Indices of experts chosen for each token.
  selected_experts: jax.Array
  # Tokens sorted by experts they are routed to.
  # NOTE(review): the exact contents are produced by `permute`, which is not visible here -- confirm.
  sorted_selected_experts: jax.Array
  # Weights for each of the selected experts.
  weights: jax.Array
  # Auxiliary loss for token distribution among experts; may be None.
  lb_loss: Optional[jax.Array]
  # Dynamic bias updates for loss-free load balancing, used only for Deepseek models; may be None.
  bias_updates: Optional[jax.Array]
92+
93+
5994def _sort_activations (
6095 inputs : jax .Array ,
6196 sort_indices : jax .Array ,
@@ -1300,37 +1335,15 @@ def get_routed_moe_shardings(is_batch_sharded_by_expert):
13001335 ) = get_routed_moe_shardings (is_batch_sharded_by_expert )
13011336 w0_pspec , w1_pspec , wo_pspec = maybe_aqt_partition (w0_kernel , w0_pspec , w1_kernel , w1_pspec , wo_kernel , wo_pspec )
13021337
1303- @functools .partial (
1304- jax .shard_map ,
1305- mesh = self .mesh ,
1306- in_specs = (
1307- input_partition_pspec ,
1308- gate_logits_pspec ,
1309- pre_bias_logits_pspec ,
1310- w0_pspec ,
1311- w1_pspec ,
1312- wo_pspec ,
1313- w0_bias_pspec ,
1314- w1_bias_pspec ,
1315- wo_bias_pspec ,
1316- P (), # Replicate the input key
1317- ),
1318- out_specs = (
1319- self ._logical_to_mesh_axes ((batch_logical_axis , "activation_norm_length" , "activation_embed" )),
1320- P (), # Handle None or replicate the output
1321- P (), # Handle None or replicate the output
1322- ),
1323- check_vma = False ,
1324- )
1325- def wrapper (x , logits , pre_bias_logits , w0 , w1 , wo , w0_bias , w1_bias , wo_bias , rngs ):
1326- batch_size , sequence_length , _ = x .shape
1327- num_expert_parallelism = self .get_expert_parallelism_size ()
1328- if num_expert_parallelism > 1 :
1329- expert_shard_id = jax .lax .axis_index (self ._expert_parallelism_name )
1330- else :
1331- expert_shard_id = 0
1332- num_expert_parallelism = self .get_expert_parallelism_size ()
1333- global_group_sizes = None
1338+ def route (x , logits , pre_bias_logits , rngs ):
1339+ """Performs both across device and within device token routing/sorting"""
1340+ num_ep = self .get_expert_parallelism_size ()
1341+ expert_shard_id = jax .lax .axis_index (self ._expert_parallelism_name ) if num_ep > 1 else None
1342+
1343+ local_sorted_indices = None
1344+ all_shards_group_sizes = None
1345+ reshaped_group_sizes = None
1346+
13341347 if self .config .use_ring_of_experts :
13351348 # The ring-of-experts strategy first duplicates the inputs to all
13361349 # expert shards, and then routes within each shard.
@@ -1342,7 +1355,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
13421355 )
13431356
13441357 # "Route" tokens within each shard.
1345- num_experts_per_shard = self .config .num_experts // num_expert_parallelism
1358+ num_experts_per_shard = self .config .num_experts // num_ep
13461359 x , sorted_selected_experts , weights , group_sizes , selected_experts , lb_loss , bias_updates = self .permute (
13471360 x ,
13481361 logits ,
@@ -1362,10 +1375,10 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
13621375 x , logits , pre_bias_logits , self .config .use_custom_sort_vjp , rngs
13631376 )
13641377
1365- if num_expert_parallelism > 1 :
1378+ if num_ep > 1 :
13661379 batch_axis = self ._expert_parallelism_name if is_batch_sharded_by_expert else "data"
13671380 # get group sizes for all shards
1368- local_expert_size = self .config .num_experts // num_expert_parallelism
1381+ local_expert_size = self .config .num_experts // num_ep
13691382 reshaped_group_sizes = jnp .sum (group_sizes .reshape (- 1 , local_expert_size ), axis = 1 )
13701383 global_group_sizes = group_sizes
13711384
@@ -1374,12 +1387,12 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
13741387 input_offsets , send_sizes , output_offsets , recv_sizes = RoutedMoE .get_all_to_all_params (
13751388 all_shards_group_sizes ,
13761389 expert_shard_id ,
1377- num_expert_parallelism ,
1390+ num_ep ,
13781391 )
13791392
13801393 buffer_size = self .get_ragged_buffer_size (
13811394 jnp .shape (x )[0 ],
1382- num_expert_parallelism ,
1395+ num_ep ,
13831396 self .config .num_experts ,
13841397 self .config .num_experts_per_tok ,
13851398 self .config .ragged_buffer_factor ,
@@ -1416,35 +1429,40 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
14161429 use_ragged_sort = self .config .use_ragged_sort ,
14171430 )
14181431
1419- if self .config .mlp_bias :
1420- w0_bias , w1_bias , wo_bias = self .transform_bias (selected_experts , w0_bias , w1_bias , wo_bias )
1421-
1422- def get_active_sharding_axes (pspec_dim_axes , tensor_dim_index ):
1423- if pspec_dim_axes is None :
1424- return []
1425- axes = (pspec_dim_axes ,) if isinstance (pspec_dim_axes , str ) else pspec_dim_axes
1426- active = []
1427- for ax in axes :
1428- if ax and self .mesh .shape .get (ax , 1 ) > 1 :
1429- active .append ((ax , tensor_dim_index ))
1430- return active
1432+ return (
1433+ x ,
1434+ RouteOutput (
1435+ group_sizes = group_sizes ,
1436+ selected_experts = selected_experts ,
1437+ sorted_selected_experts = sorted_selected_experts ,
1438+ weights = weights ,
1439+ lb_loss = lb_loss ,
1440+ bias_updates = bias_updates ,
1441+ ),
1442+ RouteMetadata (
1443+ expert_shard_id = expert_shard_id ,
1444+ local_sorted_indices = local_sorted_indices ,
1445+ all_shards_group_sizes = all_shards_group_sizes ,
1446+ reshaped_group_sizes = reshaped_group_sizes ,
1447+ ),
1448+ )
14311449
def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
  """Pair each mesh axis that actually shards a tensor dimension with that dimension's index.

  Args:
    pspec_dim_axes: Partition-spec entry for one tensor dimension: None, a
      single axis name, or an iterable of axis names.
    tensor_dim_index: Index of the tensor dimension the entry belongs to.

  Returns:
    A list of (axis_name, tensor_dim_index) tuples, one per named axis whose
    size in `self.mesh.shape` is greater than 1. Empty when nothing qualifies.
  """
  if pspec_dim_axes is None:
    return []

  # A bare string names a single axis; wrap it so it is not iterated per character.
  if isinstance(pspec_dim_axes, str):
    candidate_axes = (pspec_dim_axes,)
  else:
    candidate_axes = pspec_dim_axes

  # Axes missing from the mesh count as size 1, i.e. not sharded.
  return [
      (axis_name, tensor_dim_index)
      for axis_name in candidate_axes
      if axis_name and self.mesh.shape.get(axis_name, 1) > 1
  ]
1459+
1460+ def get_wi_gmm_params ():
14321461 wi_gather_axes = []
1433- wo_gather_axes = []
1434-
14351462 if weight_gather :
14361463 # wi [Experts, In, Hidden] -> Gather Exp(0) and Hidden(2)
14371464 wi_gather_axes .extend (get_active_sharding_axes (w0_pspec [0 ], 0 ))
14381465 wi_gather_axes .extend (get_active_sharding_axes (w0_pspec [2 ], 2 ))
1439-
1440- # wo [Experts, Hidden, Out] -> Gather Exp(0) and Hidden(1)
1441- wo_gather_axes .extend (get_active_sharding_axes (wo_pspec [0 ], 0 ))
1442- wo_gather_axes .extend (get_active_sharding_axes (wo_pspec [1 ], 1 ))
1443- gmm_fn = functools .partial (
1444- gmm ,
1445- group_sizes = group_sizes ,
1446- expert_assignments = selected_experts ,
1447- )
14481466 wi_tile_size = (
14491467 self .config .wi_tile_fwd_batch_seq , # m (LHS batch)
14501468 self .config .wi_tile_fwd_embed_dim , # k (contracting)
@@ -1456,6 +1474,14 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
14561474 self .config .wi_tile_drhs_embed_dim , # Called k in megablox, but this is LHS batch dim
14571475 self .config .wi_tile_drhs_mlp_dim , # Called n in megablox, and indeed is RHS batch dim
14581476 )
1477+ return wi_gather_axes , wi_tile_size
1478+
1479+ def get_wo_gmm_params ():
1480+ wo_gather_axes = []
1481+ if weight_gather :
1482+ # wo [Experts, Hidden, Out] -> Gather Exp(0) and Hidden(1)
1483+ wo_gather_axes .extend (get_active_sharding_axes (wo_pspec [0 ], 0 ))
1484+ wo_gather_axes .extend (get_active_sharding_axes (wo_pspec [1 ], 1 ))
14591485 wo_tile_size = (
14601486 self .config .wo_tile_fwd_batch_seq , # m (LHS batch)
14611487 self .config .wo_tile_fwd_mlp_dim , # k (contracting)
@@ -1467,32 +1493,59 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
14671493 self .config .wo_tile_drhs_mlp_dim , # Called k in megablox, but this is LHS batch dim
14681494 self .config .wo_tile_drhs_embed_dim , # Called n in megablox, and indeed is the RHS batch dim
14691495 )
1496+ return wo_gather_axes , wo_tile_size
14701497
1471- layer_w0 = gmm_fn (
1472- x ,
1473- w0 ,
1474- tiling = wi_tile_size ,
1475- weight_gather_axes = wi_gather_axes ,
1476- )
def gmm_up(x, w0, w1, w0_bias, w1_bias, gmm_fn, weight_gather):
  """Run the two up-projections (gate + up) and apply the FFN activation.

  Both projections share the same tiling and weight-gather axes, and both are
  tagged with a generic and a MoE-specific checkpoint name ("mlpwi_N" and
  "moe_mlpwi_N") so that either name can be matched for w0 and w1 alike.

  Args:
    x: Activations used as the left-hand side of both grouped matmuls.
    w0: Kernel of the first projection.
    w1: Kernel of the second projection.
    w0_bias: Added to the first projection when `self.config.mlp_bias` is set.
    w1_bias: Added to the second projection when `self.config.mlp_bias` is set.
    gmm_fn: Grouped-matmul callable, invoked as
      `gmm_fn(lhs, rhs, tiling=..., weight_gather_axes=...)`.
    weight_gather: Not read here; kept so existing call sites stay valid. The
      gather axes come from `get_wi_gmm_params()`.

  Returns:
    The result of `self.apply_ffn_activation(layer_w0, layer_w1)`.
  """
  del weight_gather  # Unused; see the docstring.
  wi_gather_axes, wi_tile_size = get_wi_gmm_params()

  def up_projection(kernel, bias, checkpoint_names):
    """Grouped matmul of x with one kernel, plus reduction, bias and checkpoint tags."""
    projected = gmm_fn(x, kernel, tiling=wi_tile_size, weight_gather_axes=wi_gather_axes)
    if self.get_tensor_transpose_parallelism_size() > 1:
      projected = jax.lax.psum(projected, "tensor_transpose")
    if self.config.mlp_bias:
      projected = projected + bias
    for name in checkpoint_names:
      projected = adc.checkpoint_name(projected, name)
    return projected

  layer_w0 = up_projection(w0, w0_bias, ("mlpwi_0", "moe_mlpwi_0"))
  # Tag w1 with "mlpwi_1" as well, mirroring w0; the earlier inline code did so.
  layer_w1 = up_projection(w1, w1_bias, ("mlpwi_1", "moe_mlpwi_1"))
  return self.apply_ffn_activation(layer_w0, layer_w1)
1515+
1516+ @functools .partial (
1517+ jax .shard_map ,
1518+ mesh = self .mesh ,
1519+ in_specs = (
1520+ input_partition_pspec ,
1521+ gate_logits_pspec ,
1522+ pre_bias_logits_pspec ,
1523+ w0_pspec ,
1524+ w1_pspec ,
1525+ wo_pspec ,
1526+ w0_bias_pspec ,
1527+ w1_bias_pspec ,
1528+ wo_bias_pspec ,
1529+ P (), # Replicate the input key
1530+ ),
1531+ out_specs = (
1532+ self ._logical_to_mesh_axes ((batch_logical_axis , "activation_norm_length" , "activation_embed" )),
1533+ P (), # Handle None or replicate the output
1534+ P (), # Handle None or replicate the output
1535+ ),
1536+ check_vma = False ,
1537+ )
1538+ def wrapper (x , logits , pre_bias_logits , w0 , w1 , wo , w0_bias , w1_bias , wo_bias , rngs ):
1539+ batch_size , sequence_length , _ = x .shape
1540+ x , routing , route_metadata = route (x , logits , pre_bias_logits , rngs )
14951541
1542+ if self .config .mlp_bias :
1543+ w0_bias , w1_bias , wo_bias = self .transform_bias (routing .selected_experts , w0_bias , w1_bias , wo_bias )
1544+
1545+ gmm_fn = functools .partial (gmm , group_sizes = routing .group_sizes , expert_assignments = routing .selected_experts )
1546+ intermediate_layer = gmm_up (x , w0 , w1 , w0_bias , w1_bias , gmm_fn , weight_gather )
1547+
1548+ wo_gather_axes , wo_tile_size = get_wo_gmm_params ()
14961549 intermediate_output = gmm_fn (
14971550 intermediate_layer ,
14981551 wo ,
@@ -1509,18 +1562,18 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
15091562
15101563 if self .config .use_ring_of_experts :
15111564 # Set the outputs of tokens which were not processed to 0.
1512- mask = jnp .arange (intermediate_output .shape [0 ]) < jnp .sum (group_sizes )
1565+ mask = jnp .arange (intermediate_output .shape [0 ]) < jnp .sum (routing . group_sizes )
15131566 intermediate_output = jnp .where (mask [:, None ], intermediate_output , 0 )
15141567
15151568 # Unsort and deduplicate the outputs locally.
15161569 output = self .unpermute (
15171570 intermediate_output ,
1518- sorted_selected_experts ,
1519- weights ,
1571+ routing . sorted_selected_experts ,
1572+ routing . weights ,
15201573 batch_size = batch_size ,
15211574 sequence_length = sequence_length ,
15221575 use_custom_sort_vjp = self .config .use_custom_sort_vjp ,
1523- group_sizes = group_sizes ,
1576+ group_sizes = routing . group_sizes ,
15241577 )
15251578
15261579 # Sum up the partial outputs across the expert shards.
@@ -1530,9 +1583,9 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
15301583 output = jax .lax .psum_scatter (output , self ._expert_parallelism_name , scatter_dimension = 0 , tiled = True )
15311584
15321585 else :
1533- if num_expert_parallelism > 1 :
1586+ if self . get_expert_parallelism_size () > 1 :
15341587 original_inputs_first_dim = batch_size * sequence_length * self .config .num_experts_per_tok
1535- if sorted_selected_experts .shape [0 ] != original_inputs_first_dim :
1588+ if routing . sorted_selected_experts .shape [0 ] != original_inputs_first_dim :
15361589 raise ValueError ("original_inputs_first_dim does not match the original tensor" " shape!" )
15371590 output_shape = jax .lax .empty (
15381591 (
@@ -1548,22 +1601,23 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
15481601 # Mirror the ragged-prefix gather used in `local_permute`. The
15491602 # un-permute can use the same valid-prefix length because the
15501603 # routed token count is identical for forward and backward.
1551- valid_end = jnp .sum (group_sizes ).astype (jnp .int32 ) # pylint: disable=undefined-variable
1604+ valid_end = jnp .sum (routing . group_sizes ).astype (jnp .int32 )
15521605 local_output = a2a_ragged_unsort (
15531606 intermediate_output ,
1554- jnp .argsort (local_sorted_indices ), # pylint: disable=undefined-variable
1607+ jnp .argsort (route_metadata . local_sorted_indices ), # pylint: disable=undefined-variable
15551608 valid_end ,
15561609 )
15571610 else :
15581611 local_output = _sort_activations (
15591612 intermediate_output ,
1560- jnp .argsort (local_sorted_indices ), # pylint: disable=undefined-variable
1613+ jnp .argsort (route_metadata . local_sorted_indices ),
15611614 self .config .use_custom_sort_vjp ,
15621615 )
1616+
15631617 input_offsets , send_sizes , output_offsets , recv_sizes = RoutedMoE .get_all_to_all_params (
1564- jnp .transpose (all_shards_group_sizes ), # pylint: disable=undefined-variable
1565- expert_shard_id ,
1566- num_expert_parallelism ,
1618+ jnp .transpose (route_metadata . all_shards_group_sizes ),
1619+ route_metadata . expert_shard_id ,
1620+ self . get_expert_parallelism_size () ,
15671621 )
15681622 intermediate_output = jax .lax .ragged_all_to_all (
15691623 local_output ,
@@ -1575,14 +1629,13 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
15751629 axis_name = self ._expert_parallelism_name ,
15761630 )
15771631 else :
1578- # If bach is replicated across EP shards then each shard should send
1632+ # If batch is replicated across EP shards then each shard should send
15791633 # 0..local_shard_size data to the other shards and receive the
1580- # local_shard data from all of the other shards using
1581- # ragged_all_to_all.
1634+ # local_shard data from all of the other shards using ragged_all_to_all.
15821635 input_offsets , send_sizes , output_offsets , recv_sizes = RoutedMoE .get_all_to_all_params (
1583- reshaped_group_sizes , # pylint: disable=undefined-variable
1584- expert_shard_id ,
1585- num_expert_parallelism ,
1636+ route_metadata . reshaped_group_sizes ,
1637+ route_metadata . expert_shard_id ,
1638+ self . get_expert_parallelism_size () ,
15861639 is_batch_sharded = False ,
15871640 )
15881641 intermediate_output = jax .lax .ragged_all_to_all (
@@ -1597,15 +1650,15 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
15971650
15981651 output = self .unpermute (
15991652 intermediate_output ,
1600- sorted_selected_experts ,
1601- weights ,
1653+ routing . sorted_selected_experts ,
1654+ routing . weights ,
16021655 batch_size = batch_size ,
16031656 sequence_length = sequence_length ,
16041657 use_custom_sort_vjp = self .config .use_custom_sort_vjp ,
1605- group_sizes = group_sizes ,
1658+ group_sizes = routing . group_sizes ,
16061659 )
16071660
1608- return output , lb_loss , bias_updates
1661+ return output , routing . lb_loss , routing . bias_updates
16091662
16101663 if self .config .moe_fsdp_use_two_stage_all_gather :
16111664 # Unshard on fsdp axis
0 commit comments