diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 139db7fd54..68fde88dbd 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -1237,6 +1237,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r global_sorted_experts=selected_experts, use_custom_sort_vjp=self.config.use_custom_sort_vjp, ) + mask = jnp.arange(x.shape[0]) < jnp.sum(group_sizes) if self.config.mlp_bias: w0_bias, w1_bias, wo_bias = self.transform_bias(selected_experts, w0_bias, w1_bias, wo_bias) @@ -1300,6 +1301,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): layer_w0 = jax.lax.psum(layer_w0, "tensor_transpose") if self.config.mlp_bias: layer_w0 = layer_w0 + w0_bias + layer_w0 = jnp.where(mask[:, None], layer_w0, 0) layer_w0 = adc.checkpoint_name(layer_w0, "moe_mlpwi_0") layer_w1 = gmm_fn( @@ -1312,6 +1314,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): layer_w1 = jax.lax.psum(layer_w1, "tensor_transpose") if self.config.mlp_bias: layer_w1 = layer_w1 + w1_bias + layer_w1 = jnp.where(mask[:, None], layer_w1, 0) layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1") intermediate_layer = self.apply_ffn_activation(layer_w0, layer_w1) @@ -1327,6 +1330,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): ) if self.config.mlp_bias: intermediate_output = intermediate_output + wo_bias + intermediate_output = jnp.where(mask[:, None], intermediate_output, 0) intermediate_output = adc.checkpoint_name(intermediate_output, "moe_mlpwo") if self.config.use_ring_of_experts: