@@ -1274,7 +1274,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
       layer_w0 = jax.lax.psum(layer_w0, "tensor_transpose")
     if self.config.mlp_bias:
       layer_w0 = layer_w0 + w0_bias
-    layer_w0 = adc.checkpoint_name(layer_w0, "mlpwi_0")
+    layer_w0 = adc.checkpoint_name(layer_w0, "moe_mlpwi_0")

     layer_w1 = gmm_fn(
         x,
@@ -1288,7 +1288,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
       layer_w1 = jax.lax.psum(layer_w1, "tensor_transpose")
     if self.config.mlp_bias:
       layer_w1 = layer_w1 + w1_bias
-    layer_w1 = adc.checkpoint_name(layer_w1, "mlpwi_1")
+    layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1")
     intermediate_layer = self.apply_ffn_activation(layer_w0, layer_w1)

     intermediate_output = gmm_fn(
@@ -1305,7 +1305,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
     )
     if self.config.mlp_bias:
       intermediate_output = intermediate_output + wo_bias
-    intermediate_output = adc.checkpoint_name(intermediate_output, "mlpwo")
+    intermediate_output = adc.checkpoint_name(intermediate_output, "moe_mlpwo")

     if self.config.use_ring_of_experts:
       # Set the outputs of tokens which were not processed to 0.
@@ -1860,7 +1860,7 @@ def dense_matmul(
           layer_w0,
           mlp_axis,
       )
-      layer_w0 = adc.checkpoint_name(layer_w0, "mlpwi_0")
+      layer_w0 = adc.checkpoint_name(layer_w0, "moe_mlpwi_0")
     with jax.named_scope("wi_1"):
       w1_kernel_axes = ("exp", None, "mlp")
       w1_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(w1_kernel, w1_kernel_axes)
@@ -1876,7 +1876,7 @@ def dense_matmul(
           layer_w1,
           mlp_axis,
       )
-      layer_w1 = adc.checkpoint_name(layer_w1, "mlpwi_1")
+      layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1")
     layer_multiply = self.apply_ffn_activation(layer_w0, layer_w1)
     with jax.named_scope("wo"):
       wo_kernel_axes = ("exp", "mlp", None)
@@ -1902,7 +1902,7 @@ def dense_matmul(
               "activation_embed",
           ),
       )
-      intermediate_layer = adc.checkpoint_name(intermediate_layer, "mlpwo")
+      intermediate_layer = adc.checkpoint_name(intermediate_layer, "moe_mlpwo")
     with jax.named_scope("combine"):
       # Matmul & element wise operation
       output = self.get_einsum(rhs_mesh_axes=mask_axes, einsum_name=COMBINE)(
@@ -1931,7 +1931,7 @@ def dense_matmul(
         layer_w0 = layer_w0 + w0_bias[None, None, :, :]
       if self.config.activations_in_float32:
         layer_w0 = layer_w0.astype(jnp.float32)
-      layer_w0 = adc.checkpoint_name(layer_w0, "mlpwi_0")
+      layer_w0 = adc.checkpoint_name(layer_w0, "moe_mlpwi_0")
     with jax.named_scope("wi_1"):
       layer_w1 = self.get_einsum(rhs_mesh_axes=self.wi_kernel_axes)(
           "BSM,EMH -> BSEH", inputs, w1_kernel, precision=matmul_precision
@@ -1940,7 +1940,7 @@ def dense_matmul(
         layer_w1 = layer_w1 + w1_bias[None, None, :, :]
       if self.config.activations_in_float32:
         layer_w1 = layer_w1.astype(jnp.float32)
-      layer_w1 = adc.checkpoint_name(layer_w1, "mlpwi_1")
+      layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1")
     layer_multiply = self.apply_ffn_activation(layer_w0, layer_w1)

     with jax.named_scope("wo"):
@@ -1954,7 +1954,7 @@ def dense_matmul(
         intermediate_layer = intermediate_layer + wo_bias[None, None, :, :]
       if self.config.activations_in_float32:
         intermediate_layer = intermediate_layer.astype(jnp.float32)
-      intermediate_layer = adc.checkpoint_name(intermediate_layer, "mlpwo")
+      intermediate_layer = adc.checkpoint_name(intermediate_layer, "moe_mlpwo")
     with jax.named_scope("weight_sum"):
       if is_llama4_decoder_layer:
         weights = self.reshape_and_update_weights(jnp.ones_like(top_k_weights), top_k_indices)
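For context on what these tags do: `adc` is presumably `jax.ad_checkpoint`, and `checkpoint_name` attaches a label to an intermediate value so a rematerialization policy can decide by name whether to save or recompute it. The sketch below is a minimal, self-contained illustration of how the renamed `moe_mlpwi_0` / `moe_mlpwi_1` / `moe_mlpwo` tags can be targeted by a standard JAX remat policy; the MoE body, shapes, and weights are toy stand-ins, not MaxText's actual implementation.

```python
# Minimal sketch: a toy stand-in for the MoE feed-forward, tagged with the renamed names.
import jax
import jax.numpy as jnp
from jax import ad_checkpoint as adc


def moe_ffn(x, w0, w1, wo):
  # Tag the wi_0 / wi_1 / wo activations so a remat policy can refer to them by name.
  layer_w0 = adc.checkpoint_name(x @ w0, "moe_mlpwi_0")
  layer_w1 = adc.checkpoint_name(x @ w1, "moe_mlpwi_1")
  hidden = jax.nn.silu(layer_w0) * layer_w1
  return adc.checkpoint_name(hidden @ wo, "moe_mlpwo")


# Save only the named MoE activations on the forward pass; everything else is
# recomputed during the backward pass.
policy = jax.checkpoint_policies.save_only_these_names(
    "moe_mlpwi_0", "moe_mlpwi_1", "moe_mlpwo"
)
moe_ffn_remat = jax.checkpoint(moe_ffn, policy=policy)

x = jnp.ones((4, 8))
w0, w1, wo = jnp.ones((8, 16)), jnp.ones((8, 16)), jnp.ones((16, 8))
grads = jax.grad(lambda *args: moe_ffn_remat(*args).sum())(x, w0, w1, wo)
```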