@@ -278,7 +278,7 @@ def __call__(self, inputs: jax.Array, _initializing: bool = False) -> Tuple[jax.
 
     contract_ind = tuple(range(0, len(norm_axis)))
     output_sharding = (
-        create_sharding(self.mesh, ("activation_batch_no_exp", "activation_length_no_exp", None))
+        create_sharding(self.mesh, ("activation_batch_no_exp_moe", "activation_length_no_exp_moe", None))
         if self.shard_mode == ShardMode.EXPLICIT
         else None
     )
@@ -351,16 +351,16 @@ def __init__(
 
     if self.config.shard_exp_on_fsdp:
       # special sharding for dsv3
-      self.wi_kernel_axes = ("embed_no_exp", None, "mlp")
-      self.wo_kernel_axes = ("embed_no_exp", "mlp", None)
+      self.wi_kernel_axes = ("embed_no_exp_moe", None, "mlp")
+      self.wo_kernel_axes = ("embed_no_exp_moe", "mlp", None)
     elif self.config.use_2d_fsdp_sharding:
-      self.wi_kernel_axes = ("embed_no_exp", "mlp", None)
-      self.wo_kernel_axes = ("embed_no_exp", "mlp", None)
+      self.wi_kernel_axes = ("embed_no_exp_moe", "mlp", None)
+      self.wo_kernel_axes = ("embed_no_exp_moe", "mlp", None)
     elif self.config.use_batch_split_schedule:
       self.wi_kernel_axes, self.wo_kernel_axes = get_batchsplit_init_kernel_axes()
     else:
-      self.wi_kernel_axes = ("exp", "embed_no_exp", "mlp")
-      self.wo_kernel_axes = ("exp", "mlp", "embed_no_exp")
+      self.wi_kernel_axes = ("exp", "embed_no_exp_moe", "mlp")
+      self.wo_kernel_axes = ("exp", "mlp", "embed_no_exp_moe")
 
     if self.config.attention == "vllm_rpa":
       # vLLM uses 'model' as the tensor parallelism axis name
@@ -437,7 +437,7 @@ def __init__(
 
     if self.config.mlp_bias:
       wi_bias_axes = ("exp", "activation_mlp")
-      wo_bias_axes = ("exp", "activation_embed")
+      wo_bias_axes = ("exp", "activation_embed_moe")
       wi_bias_shape = (self.num_experts, self.intermediate_dim)
       wo_bias_shape = (self.num_experts, self.config.emb_dim)
       self.wi_0_bias = nnx.Param(
@@ -1018,7 +1018,7 @@ def gmm(
           self._expert_parallelism_name
           in tuple(
               filter(
-                  lambda tup: tup[0] == "activation_batch",
+                  lambda tup: tup[0] == "activation_batch_moe",
                   self.config.logical_axis_rules,
               )
           )[
@@ -1028,26 +1028,26 @@ def gmm(
     except:  # pylint: disable=bare-except
       is_batch_sharded_by_expert = False
     if is_batch_sharded_by_expert and inputs.shape[0] > 1:
-      batch_logical_axis = "activation_batch"
+      batch_logical_axis = "activation_batch_moe"
     else:
-      batch_logical_axis = "activation_batch_no_exp"
+      batch_logical_axis = "activation_batch_no_exp_moe"
 
     if self.get_tensor_transpose_parallelism_size() > 1:
       input_partition_pspec = self._logical_to_mesh_axes(
-          (batch_logical_axis, "activation_norm_length", "activation_embed")
+          (batch_logical_axis, "activation_norm_length_moe", "activation_embed_moe")
       )
       w0_bias_pspec = self._logical_to_mesh_axes(("exp", None))
       w1_bias_pspec = self._logical_to_mesh_axes(("exp", None))
-      wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed"))
+      wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed_moe"))
     else:
-      input_partition_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None))
+      input_partition_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length_moe", None))
       w0_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_mlp"))
       w1_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_mlp"))
-      wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed"))
+      wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed_moe"))
 
-    gate_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None))
+    gate_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length_moe", None))
     if self.config.model_name.startswith("deepseek3"):
-      pre_bias_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None))
+      pre_bias_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length_moe", None))
     else:
       # pre_bias_logits is None for non-DeepSeek v3 models
       pre_bias_logits_pspec = None
@@ -1099,7 +1099,7 @@ def gmm(
             P(),  # Replicate the input key
         ),
         out_specs=(
-            self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", "activation_embed")),
+            self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length_moe", "activation_embed_moe")),
             P(),  # Handle None or replicate the output
             P(),  # Handle None or replicate the output
         ),
@@ -1411,13 +1411,13 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
     wo_kernel = self._maybe_shard_with_logical(wo_kernel, ("exp_with_fsdp", "mlp_no_fsdp", "embed_tensor_transpose"))
 
     if self.get_tensor_transpose_parallelism_size() > 1:
-      input_axes = (batch_logical_axis, "activation_norm_length", "activation_embed")
+      input_axes = (batch_logical_axis, "activation_norm_length_moe", "activation_embed_moe")
     else:
-      input_axes = (batch_logical_axis, "activation_norm_length", None)
+      input_axes = (batch_logical_axis, "activation_norm_length_moe", None)
 
-    gate_logits_axes = (batch_logical_axis, "activation_norm_length", None)
+    gate_logits_axes = (batch_logical_axis, "activation_norm_length_moe", None)
     if self.config.model_name.startswith("deepseek3"):
-      pre_bias_logits_axes = (batch_logical_axis, "activation_norm_length", None)
+      pre_bias_logits_axes = (batch_logical_axis, "activation_norm_length_moe", None)
     else:
       pre_bias_logits_axes = None
 
@@ -1436,13 +1436,13 @@ def reshape_and_update_weights(self, weights, indices):
     update_weights = jnp.zeros((weights.shape[0], weights.shape[1], self.num_experts), dtype=self.dtype)
     index_update = (
         self._maybe_shard_with_logical(
-            jnp.arange(weights.shape[0])[:, None, None], ("activation_batch_no_exp", None, None)
+            jnp.arange(weights.shape[0])[:, None, None], ("activation_batch_no_exp_moe", None, None)
         ),
-        self._maybe_shard_with_logical(jnp.arange(weights.shape[1])[:, None], ("activation_length_no_exp", None)),
+        self._maybe_shard_with_logical(jnp.arange(weights.shape[1])[:, None], ("activation_length_no_exp_moe", None)),
         indices,
     )
     weight_sharding = (
-        create_sharding(self.mesh, ("activation_batch_no_exp", "activation_length_no_exp", None))
+        create_sharding(self.mesh, ("activation_batch_no_exp_moe", "activation_length_no_exp_moe", None))
         if self.config.shard_mode == ShardMode.EXPLICIT
         else None
     )
@@ -1497,15 +1497,15 @@ def generate_masks_subgroup(self, top_k_indices, softmax_probs):
         expert_mask,
         (batch_size, cp, sub_seq * self.num_experts_per_tok, self.num_experts),
     )
-    expert_mask_fused = self._maybe_shard_with_logical(expert_mask_fused, ("activation_batch", None, None, None))
+    expert_mask_fused = self._maybe_shard_with_logical(expert_mask_fused, ("activation_batch_moe", None, None, None))
     expert_token_count_fused = jnp.cumsum(expert_mask_fused, axis=2)
     expert_token_count = jnp.reshape(
         expert_token_count_fused,
         ((batch_size, cp, sub_seq, self.num_experts_per_tok, self.num_experts)),
     )
     expert_token_count = self._maybe_shard_with_logical(
         expert_token_count,
-        ("activation_batch", "activation_norm_length", None, None, None),
+        ("activation_batch_moe", "activation_norm_length_moe", None, None, None),
     )
     trunc_expert_mask = expert_mask * jnp.less_equal(expert_token_count, expert_capacity_per_batch)
     combined_expert_mask = jnp.sum(trunc_expert_mask, axis=3)
@@ -1585,15 +1585,15 @@ def generate_masks(self, top_k_indices, softmax_probs):
         expert_mask,
         (batch_size, seq_len * self.num_experts_per_tok, self.num_experts),
     )
-    expert_mask_fused = self._maybe_shard_with_logical(expert_mask_fused, ("activation_batch", None, None))
+    expert_mask_fused = self._maybe_shard_with_logical(expert_mask_fused, ("activation_batch_moe", None, None))
     expert_token_count_fused = jnp.cumsum(expert_mask_fused, axis=1)
     expert_token_count = jnp.reshape(
         expert_token_count_fused,
         ((batch_size, seq_len, self.num_experts_per_tok, self.num_experts)),
     )
     expert_token_count = self._maybe_shard_with_logical(
         expert_token_count,
-        ("activation_batch", "activation_norm_length", None, None),
+        ("activation_batch_moe", "activation_norm_length_moe", None, None),
     )
     trunc_expert_mask = expert_mask * jnp.less_equal(expert_token_count, expert_capacity_per_batch)
     combined_expert_mask = jnp.sum(trunc_expert_mask, axis=2)
@@ -1691,11 +1691,13 @@ def dense_matmul(
   ) -> tuple[jax.Array, Optional[jax.Array], Optional[jax.Array]]:
     """Dense matrix multiplication."""
     # gate_logits: batch, length, expert
-    gate_logits = self._maybe_shard_with_logical(gate_logits, ("activation_batch", "activation_norm_length", None))
+    gate_logits = self._maybe_shard_with_logical(
+        gate_logits, ("activation_batch_moe", "activation_length_no_exp_moe", None)
+    )
     if self.config.model_name.startswith("deepseek3"):
       # pre_bias_logits is None for non-DeepSeek v3 models
       pre_bias_logits = self._maybe_shard_with_logical(
-          pre_bias_logits, ("activation_batch", "activation_norm_length", None)
+          pre_bias_logits, ("activation_batch_moe", "activation_length_no_exp_moe", None)
       )
     top_k_weights, top_k_indices = self.get_topk(gate_logits, pre_bias_logits, self.rngs)
     is_llama4_decoder_layer = self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4
@@ -1735,16 +1737,16 @@ def dense_matmul(
       dispatch_mask, combine_mask = self.generate_masks(
           top_k_indices, weights  # pylint: disable=undefined-variable,possibly-used-before-assignment
       )
-      mask_axes = ("activation_batch", "activation_norm_length", None, None)
+      mask_axes = ("activation_batch_moe", "activation_norm_length_moe", None, None)
       dispatch_axis = (
           "activation_exp",
-          "activation_batch_no_exp",
+          "activation_batch_no_exp_moe",
           None,
-          "activation_embed",
+          "activation_embed_moe",
       )
       mlp_axis = (
           "activation_exp",
-          "activation_batch_no_exp",
+          "activation_batch_no_exp_moe",
           None,
           "activation_mlp",
       )
@@ -1759,56 +1761,56 @@ def dense_matmul(
       dispatch_mask, combine_mask = self.generate_masks_subgroup(top_k_indices, softmax_probs)
       if self.get_context_autoregressive_parallelism_size() > 0 and cp == 1:
         mask_axes = (
-            "activation_norm_length",
-            "activation_batch",
+            "activation_norm_length_moe",
+            "activation_batch_moe",
             None,
             None,
             None,
         )
         input_axis = (
-            "activation_norm_length",
-            "activation_batch",
+            "activation_norm_length_moe",
+            "activation_batch_moe",
             None,
-            "activation_embed",
+            "activation_embed_moe",
         )
         dispatch_axis = (
             "activation_exp",
-            "activation_batch_no_exp",
+            "activation_batch_no_exp_moe",
             None,
             None,
-            "activation_embed",
+            "activation_embed_moe",
         )
         mlp_axis = (
             "activation_exp",
-            "activation_batch_no_exp",
+            "activation_batch_no_exp_moe",
             None,
             None,
             "activation_mlp",
         )
       else:
         mask_axes = (
-            "activation_batch",
-            "activation_norm_length",
+            "activation_batch_moe",
+            "activation_norm_length_moe",
             None,
             None,
             None,
         )
         input_axis = (
-            "activation_batch",
-            "activation_norm_length",
+            "activation_batch_moe",
+            "activation_norm_length_moe",
             None,
-            "activation_embed",
+            "activation_embed_moe",
         )
         dispatch_axis = (
             "activation_exp",
-            "activation_batch_no_exp",
+            "activation_batch_no_exp_moe",
             None,
             None,
-            "activation_embed",
+            "activation_embed_moe",
         )
         mlp_axis = (
             "activation_exp",
-            "activation_batch_no_exp",
+            "activation_batch_no_exp_moe",
             None,
             None,
             "activation_mlp",
@@ -1834,10 +1836,10 @@ def dense_matmul(
           dispatch,
           (
               None,
-              "activation_batch_no_exp",
-              "activation_norm_length",
+              "activation_batch_no_exp_moe",
+              "activation_norm_length_moe",
               None,
-              "activation_embed",
+              "activation_embed_moe",
           ),
       )
       dispatch = self._maybe_shard_with_logical(
@@ -1897,9 +1899,9 @@ def dense_matmul(
           intermediate_layer,
           (
               "activation_exp",
-              "activation_batch_no_exp",
+              "activation_batch_no_exp_moe",
               None,
-              "activation_embed",
+              "activation_embed_moe",
           ),
       )
       intermediate_layer = adc.checkpoint_name(intermediate_layer, "mlpwo")
@@ -1922,7 +1924,9 @@ def dense_matmul(
       )
       return output, lb_loss, bias_updates
     else:
-      inputs = self._maybe_shard_with_logical(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
+      inputs = self._maybe_shard_with_logical(
+          inputs, ("activation_batch_moe", "activation_norm_length_moe", "activation_embed_moe")
+      )
       with jax.named_scope("wi_0"):
         layer_w0 = self.get_einsum(rhs_mesh_axes=self.wi_kernel_axes)(
             "BSM,EMH -> BSEH", inputs, w0_kernel, precision=matmul_precision
@@ -2082,7 +2086,7 @@ def __init__(
         num_experts_per_tok=self.config.num_experts_per_tok,
         mesh=self.mesh,
         kernel_init=nd_dense_init(1.0, "fan_in", "truncated_normal"),
-        kernel_axes=("embed", None),
+        kernel_axes=("embed_moe", None),
         intermediate_dim=self.config.moe_mlp_dim,
         dtype=self.config.dtype,
         weight_dtype=self.config.weight_dtype,
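
Context for the renamed logical axes (not part of the diff above): the `*_moe` names are logical axis labels that the layer resolves to physical mesh axes through the configured `logical_axis_rules` before building a sharding. The snippet below is a minimal, self-contained sketch of that lookup using only public JAX APIs; the rule table and the `logical_to_mesh_axes` helper are illustrative assumptions, not the repository's `create_sharding` / `_logical_to_mesh_axes` implementations.

```python
# Hypothetical sketch: map logical axis names (e.g. "activation_batch_no_exp_moe")
# to mesh axes via an assumed rule table, then build a NamedSharding from the result.
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumed rules for illustration only; real runs take these from logical_axis_rules.
LOGICAL_AXIS_RULES = {
    "activation_batch_moe": ("data", "expert"),
    "activation_batch_no_exp_moe": ("data",),
    "activation_length_no_exp_moe": None,  # replicated along the sequence dimension
    "activation_embed_moe": ("model",),
}


def logical_to_mesh_axes(logical_axes):
  """Translate a tuple of logical axis names (or None) into a PartitionSpec."""
  # Unknown names and None both resolve to None, i.e. replicated on that dimension.
  return P(*(LOGICAL_AXIS_RULES.get(name) for name in logical_axes))


devices = np.array(jax.devices()).reshape(-1, 1, 1)  # fold all devices onto the "data" axis
mesh = Mesh(devices, ("data", "expert", "model"))

# Analogue of create_sharding(mesh, (...)) in the diff: logical axes -> NamedSharding.
spec = logical_to_mesh_axes(("activation_batch_no_exp_moe", "activation_length_no_exp_moe", None))
out_sharding = NamedSharding(mesh, spec)
x = jax.device_put(np.zeros((devices.size * 2, 16, 32), np.float32), out_sharding)
```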