diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 85db2bcce4..eea7b9c2d0 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -435,7 +435,9 @@ custom_mesh_and_rule: "" # replace default mesh and logical rule by specifying y mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'] logical_axis_rules: [ ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], + ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']], + ['activation_batch_no_exp_moe', ['data', 'fsdp', 'fsdp_transpose']], ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']], ['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence','autoregressive']], @@ -448,7 +450,10 @@ logical_axis_rules: [ ['activation_attn_length_no_exp', ['context']], ['activation_length_no_exp', ['sequence', 'context']], ['activation_length_no_exp', ['context']], + ['activation_length_no_exp_moe', ['sequence', 'context']], + ['activation_length_no_exp_moe', ['context']], ['activation_norm_length', ['tensor_sequence', 'context', 'sequence']], + ['activation_norm_length_moe', ['tensor_sequence', 'context', 'sequence']], ['activation_q_length', ['context', 'expert']], ['activation_q_length_no_exp', ['context']], ['prefill_activation_length', ['sequence', 'context']], @@ -456,6 +461,7 @@ logical_axis_rules: [ ['activation_kv_length', []], ['activation_attn_embed', ['tensor', 'tensor_transpose']], ['activation_embed', ['tensor', 'tensor_transpose']], + ['activation_embed_moe', ['tensor', 'tensor_transpose']], ['activation_mlp', ['tensor', 'tensor_transpose', 
'tensor_sequence']], ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], @@ -484,6 +490,14 @@ logical_axis_rules: [ ['embed_no_exp', ['fsdp', 'sequence', 'tensor_transpose', 'context']], ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context']], ['embed_no_exp', ['fsdp', 'sequence', 'context']], + ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']], + ['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context', 'expert']], + ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']], + ['embed_moe', ['fsdp', 'sequence', 'context', 'expert']], + ['embed_no_exp_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']], + ['embed_no_exp_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']], + ['embed_no_exp_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']], + ['embed_no_exp_moe', ['fsdp', 'sequence', 'context']], ['embed_tensor_transpose', ['tensor_transpose']], ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']], ['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']], diff --git a/src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml b/src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml index 62f2dbe370..7b99d96978 100644 --- a/src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml +++ b/src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml @@ -28,7 +28,9 @@ mesh_axes: ['data', 'stage', 'fsdp', 'tensor', 'expert'] data_sharding: [['data', 'stage', 'fsdp', 'tensor', 'expert']] logical_axis_rules: [ ['activation_batch', ['data', 'fsdp', 'expert']], + ['activation_batch_moe', ['data', 'fsdp', 'expert']], ['activation_batch_no_exp', ['data', 'fsdp']], + ['activation_batch_no_exp_moe', ['data', 'fsdp']], ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp',
'expert']], ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'expert']], ['activation_heads', ['tensor']], @@ -38,6 +40,7 @@ logical_axis_rules: [ ['activation_q_length', ['expert']], ['activation_attn_embed', ['tensor']], ['activation_embed', ['tensor']], + ['activation_embed_moe', ['tensor']], ['activation_mlp', ['tensor']], ['activation_kv', ['tensor']], ['activation_prefill_kv_batch', ['data', 'fsdp', 'expert']], @@ -55,7 +58,10 @@ logical_axis_rules: [ ['q_heads', ['tensor']], ['kv_heads', ['tensor']], ['embed', ['fsdp', 'expert']], + ['embed_moe', ['fsdp', 'expert']], ['embed_no_exp', ['fsdp']], + ['embed_no_exp_moe', ['fsdp']], + ['embed_moe', ['fsdp']], ['q_lora', ['fsdp']], ['kv_lora', ['fsdp']], ['norm', ['tensor']], diff --git a/src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml b/src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml index db9aafce8b..c8a28c5b24 100644 --- a/src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml +++ b/src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml @@ -19,6 +19,8 @@ data_sharding: [['fsdp']] logical_axis_rules: [ ['activation_batch', ['fsdp']], ['activation_batch_no_exp', ['fsdp']], + ['activation_batch_moe', ['fsdp']], + ['activation_batch_no_exp_moe', ['fsdp']], ['activation_embed_and_logits_batch', ['fsdp']], ['activation_embed_and_logits_batch_sequence', ['fsdp']], ['activation_prefill_kv_batch', ['fsdp']], @@ -27,6 +29,8 @@ logical_axis_rules: [ ['decode_batch', ['fsdp']], ['embed', ['fsdp']], ['embed_no_exp', ['fsdp']], + ['embed_moe', ['fsdp']], + ['embed_no_exp_moe', ['fsdp']], ['q_lora', ['fsdp']], ['kv_lora', ['fsdp']], ['exp_with_fsdp', 'fsdp'], diff --git a/src/maxtext/configs/inference/vllm.yml b/src/maxtext/configs/inference/vllm.yml index 52e8255420..98c56bb61a 100644 --- a/src/maxtext/configs/inference/vllm.yml +++ b/src/maxtext/configs/inference/vllm.yml @@ -30,7 +30,9 @@ weight_dtype: bfloat16 mesh_axes: ['data', 'attn_dp', 'model', 'expert', 'attn_dp_expert'] 
logical_axis_rules: [ ['activation_batch', ['expert']], + ['activation_batch_moe', ['expert']], ['activation_batch_no_exp', []], + ['activation_batch_no_exp_moe', []], ['activation_embed_and_logits_batch', ['expert']], ['activation_embed_and_logits_batch_sequence', ['expert']], ['activation_heads', ['model']], @@ -38,10 +40,13 @@ logical_axis_rules: [ ['activation_attn_length', ['expert']], ['activation_attn_length_no_exp', []], ['activation_length', ['data', 'expert']], + ['activation_length_moe', ['data', 'expert']], ['activation_length_no_exp', 'data'], + ['activation_length_no_exp_moe', 'data'], ['activation_q_length', ['expert', 'attn_dp_expert']], ['activation_attn_embed', 'model'], ['activation_embed', ['model', 'attn_dp']], + ['activation_embed_moe', ['model', 'attn_dp']], ['activation_mlp', ['model', 'attn_dp']], ['activation_kv', ['model']], ['activation_prefill_kv_batch', ['expert', 'attn_dp_expert']], @@ -50,6 +55,7 @@ logical_axis_rules: [ ['activation_kv_head_dim', ['model']], ['activation_vocab', ['model', 'attn_dp']], ['activation_norm_length', []], + ['activation_norm_length_moe', []], ['activation_exp', ['expert', 'attn_dp_expert']], ['decode_batch', ['expert', 'attn_dp_expert']], ['decode_length', []], @@ -63,8 +69,10 @@ logical_axis_rules: [ ['kv_head_dim', []], ['kv', []], ['embed', ['expert', 'attn_dp_expert']], + ['embed_moe', ['expert', 'attn_dp_expert']], ['embed_tensor_transpose', ['attn_dp', 'model']], ['embed_no_exp', []], + ['embed_no_exp_moe', []], ['q_lora', ['expert', 'attn_dp_expert']], ['kv_lora', ['expert', 'attn_dp_expert']], ['norm', []], diff --git a/src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml b/src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml index c137d94c98..e55a173f1f 100644 --- a/src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml +++ b/src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml @@ -60,14 +60,18 @@ mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context'] data_sharding: [['data', 
'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']] logical_axis_rules: [ ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']], + ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']], ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']], ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']], ['activation_embed_and_logits_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_norm_length', ['context']], + ['activation_norm_length_moe', ['context']], ['activation_heads', []], ['activation_stage', 'stage'], ['embed', ['fsdp']], + ['embed_moe', ['fsdp']], ['embed_no_exp', ['fsdp']], + ['embed_no_exp_moe', ['fsdp']], ['q_lora', ['fsdp']], ['kv_lora', ['fsdp']], ['layers', 'stage'], diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index cff98af9f1..3ddee5232a 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -278,7 +278,7 @@ def __call__(self, inputs: jax.Array, _initializing: bool = False) -> Tuple[jax. 
contract_ind = tuple(range(0, len(norm_axis))) output_sharding = ( - create_sharding(self.mesh, ("activation_batch_no_exp", "activation_length_no_exp", None)) + create_sharding(self.mesh, ("activation_batch_no_exp_moe", "activation_length_no_exp_moe", None)) if self.shard_mode == ShardMode.EXPLICIT else None ) @@ -351,16 +351,16 @@ def __init__( if self.config.shard_exp_on_fsdp: # special sharding for dsv3 - self.wi_kernel_axes = ("embed_no_exp", None, "mlp") - self.wo_kernel_axes = ("embed_no_exp", "mlp", None) + self.wi_kernel_axes = ("embed_no_exp_moe", None, "mlp") + self.wo_kernel_axes = ("embed_no_exp_moe", "mlp", None) elif self.config.use_2d_fsdp_sharding: - self.wi_kernel_axes = ("embed_no_exp", "mlp", None) - self.wo_kernel_axes = ("embed_no_exp", "mlp", None) + self.wi_kernel_axes = ("embed_no_exp_moe", "mlp", None) + self.wo_kernel_axes = ("embed_no_exp_moe", "mlp", None) elif self.config.use_batch_split_schedule: self.wi_kernel_axes, self.wo_kernel_axes = get_batchsplit_init_kernel_axes() else: - self.wi_kernel_axes = ("exp", "embed_no_exp", "mlp") - self.wo_kernel_axes = ("exp", "mlp", "embed_no_exp") + self.wi_kernel_axes = ("exp", "embed_no_exp_moe", "mlp") + self.wo_kernel_axes = ("exp", "mlp", "embed_no_exp_moe") if self.config.attention == "vllm_rpa": # vLLM uses 'model' as the tensor parallelism axis name @@ -437,7 +437,7 @@ def __init__( if self.config.mlp_bias: wi_bias_axes = ("exp", "activation_mlp") - wo_bias_axes = ("exp", "activation_embed") + wo_bias_axes = ("exp", "activation_embed_moe") wi_bias_shape = (self.num_experts, self.intermediate_dim) wo_bias_shape = (self.num_experts, self.config.emb_dim) self.wi_0_bias = nnx.Param( @@ -1018,7 +1018,7 @@ def gmm( self._expert_parallelism_name in tuple( filter( - lambda tup: tup[0] == "activation_batch", + lambda tup: tup[0] == "activation_batch_moe", self.config.logical_axis_rules, ) )[ @@ -1028,26 +1028,26 @@ def gmm( except: # pylint: disable=bare-except is_batch_sharded_by_expert = False if 
is_batch_sharded_by_expert and inputs.shape[0] > 1: - batch_logical_axis = "activation_batch" + batch_logical_axis = "activation_batch_moe" else: - batch_logical_axis = "activation_batch_no_exp" + batch_logical_axis = "activation_batch_no_exp_moe" if self.get_tensor_transpose_parallelism_size() > 1: input_partition_pspec = self._logical_to_mesh_axes( - (batch_logical_axis, "activation_norm_length", "activation_embed") + (batch_logical_axis, "activation_norm_length_moe", "activation_embed_moe") ) w0_bias_pspec = self._logical_to_mesh_axes(("exp", None)) w1_bias_pspec = self._logical_to_mesh_axes(("exp", None)) - wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed")) + wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed_moe")) else: - input_partition_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None)) + input_partition_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length_moe", None)) w0_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_mlp")) w1_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_mlp")) - wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed")) + wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed_moe")) - gate_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None)) + gate_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length_moe", None)) if self.config.model_name.startswith("deepseek3"): - pre_bias_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None)) + pre_bias_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length_moe", None)) else: # pre_bias_logits is None for non-DeepSeek v3 models pre_bias_logits_pspec = None @@ -1099,7 +1099,7 @@ def gmm( P(), # Replicate the input key ), out_specs=( - self._logical_to_mesh_axes((batch_logical_axis, 
"activation_norm_length", "activation_embed")), + self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length_moe", "activation_embed_moe")), P(), # Handle None or replicate the output P(), # Handle None or replicate the output ), @@ -1411,13 +1411,13 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): wo_kernel = self._maybe_shard_with_logical(wo_kernel, ("exp_with_fsdp", "mlp_no_fsdp", "embed_tensor_transpose")) if self.get_tensor_transpose_parallelism_size() > 1: - input_axes = (batch_logical_axis, "activation_norm_length", "activation_embed") + input_axes = (batch_logical_axis, "activation_norm_length_moe", "activation_embed_moe") else: - input_axes = (batch_logical_axis, "activation_norm_length", None) + input_axes = (batch_logical_axis, "activation_norm_length_moe", None) - gate_logits_axes = (batch_logical_axis, "activation_norm_length", None) + gate_logits_axes = (batch_logical_axis, "activation_norm_length_moe", None) if self.config.model_name.startswith("deepseek3"): - pre_bias_logits_axes = (batch_logical_axis, "activation_norm_length", None) + pre_bias_logits_axes = (batch_logical_axis, "activation_norm_length_moe", None) else: pre_bias_logits_axes = None @@ -1436,13 +1436,13 @@ def reshape_and_update_weights(self, weights, indices): update_weights = jnp.zeros((weights.shape[0], weights.shape[1], self.num_experts), dtype=self.dtype) index_update = ( self._maybe_shard_with_logical( - jnp.arange(weights.shape[0])[:, None, None], ("activation_batch_no_exp", None, None) + jnp.arange(weights.shape[0])[:, None, None], ("activation_batch_no_exp_moe", None, None) ), - self._maybe_shard_with_logical(jnp.arange(weights.shape[1])[:, None], ("activation_length_no_exp", None)), + self._maybe_shard_with_logical(jnp.arange(weights.shape[1])[:, None], ("activation_length_no_exp_moe", None)), indices, ) weight_sharding = ( - create_sharding(self.mesh, ("activation_batch_no_exp", "activation_length_no_exp", None)) + create_sharding(self.mesh, 
("activation_batch_no_exp_moe", "activation_length_no_exp_moe", None)) if self.config.shard_mode == ShardMode.EXPLICIT else None ) @@ -1497,7 +1497,7 @@ def generate_masks_subgroup(self, top_k_indices, softmax_probs): expert_mask, (batch_size, cp, sub_seq * self.num_experts_per_tok, self.num_experts), ) - expert_mask_fused = self._maybe_shard_with_logical(expert_mask_fused, ("activation_batch", None, None, None)) + expert_mask_fused = self._maybe_shard_with_logical(expert_mask_fused, ("activation_batch_moe", None, None, None)) expert_token_count_fused = jnp.cumsum(expert_mask_fused, axis=2) expert_token_count = jnp.reshape( expert_token_count_fused, @@ -1505,7 +1505,7 @@ def generate_masks_subgroup(self, top_k_indices, softmax_probs): ) expert_token_count = self._maybe_shard_with_logical( expert_token_count, - ("activation_batch", "activation_norm_length", None, None, None), + ("activation_batch_moe", "activation_norm_length_moe", None, None, None), ) trunc_expert_mask = expert_mask * jnp.less_equal(expert_token_count, expert_capacity_per_batch) combined_expert_mask = jnp.sum(trunc_expert_mask, axis=3) @@ -1585,7 +1585,7 @@ def generate_masks(self, top_k_indices, softmax_probs): expert_mask, (batch_size, seq_len * self.num_experts_per_tok, self.num_experts), ) - expert_mask_fused = self._maybe_shard_with_logical(expert_mask_fused, ("activation_batch", None, None)) + expert_mask_fused = self._maybe_shard_with_logical(expert_mask_fused, ("activation_batch_moe", None, None)) expert_token_count_fused = jnp.cumsum(expert_mask_fused, axis=1) expert_token_count = jnp.reshape( expert_token_count_fused, @@ -1593,7 +1593,7 @@ def generate_masks(self, top_k_indices, softmax_probs): ) expert_token_count = self._maybe_shard_with_logical( expert_token_count, - ("activation_batch", "activation_norm_length", None, None), + ("activation_batch_moe", "activation_norm_length_moe", None, None), ) trunc_expert_mask = expert_mask * jnp.less_equal(expert_token_count, 
expert_capacity_per_batch) combined_expert_mask = jnp.sum(trunc_expert_mask, axis=2) @@ -1691,11 +1691,13 @@ def dense_matmul( ) -> tuple[jax.Array, Optional[jax.Array], Optional[jax.Array]]: """Dense matrix multiplication.""" # gate_logits: batch, length, expert - gate_logits = self._maybe_shard_with_logical(gate_logits, ("activation_batch", "activation_norm_length", None)) + gate_logits = self._maybe_shard_with_logical( + gate_logits, ("activation_batch_moe", "activation_length_no_exp_moe", None) + ) if self.config.model_name.startswith("deepseek3"): # pre_bias_logits is None for non-DeepSeek v3 models pre_bias_logits = self._maybe_shard_with_logical( - pre_bias_logits, ("activation_batch", "activation_norm_length", None) + pre_bias_logits, ("activation_batch_moe", "activation_length_no_exp_moe", None) ) top_k_weights, top_k_indices = self.get_topk(gate_logits, pre_bias_logits, self.rngs) is_llama4_decoder_layer = self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4 @@ -1735,16 +1737,16 @@ def dense_matmul( dispatch_mask, combine_mask = self.generate_masks( top_k_indices, weights # pylint: disable=undefined-variable,possibly-used-before-assignment ) - mask_axes = ("activation_batch", "activation_norm_length", None, None) + mask_axes = ("activation_batch_moe", "activation_norm_length_moe", None, None) dispatch_axis = ( "activation_exp", - "activation_batch_no_exp", + "activation_batch_no_exp_moe", None, - "activation_embed", + "activation_embed_moe", ) mlp_axis = ( "activation_exp", - "activation_batch_no_exp", + "activation_batch_no_exp_moe", None, "activation_mlp", ) @@ -1759,56 +1761,56 @@ def dense_matmul( dispatch_mask, combine_mask = self.generate_masks_subgroup(top_k_indices, softmax_probs) if self.get_context_autoregressive_parallelism_size() > 0 and cp == 1: mask_axes = ( - "activation_norm_length", - "activation_batch", + "activation_norm_length_moe", + "activation_batch_moe", None, None, None, ) input_axis = ( - "activation_norm_length", - 
"activation_batch", + "activation_norm_length_moe", + "activation_batch_moe", None, - "activation_embed", + "activation_embed_moe", ) dispatch_axis = ( "activation_exp", - "activation_batch_no_exp", + "activation_batch_no_exp_moe", None, None, - "activation_embed", + "activation_embed_moe", ) mlp_axis = ( "activation_exp", - "activation_batch_no_exp", + "activation_batch_no_exp_moe", None, None, "activation_mlp", ) else: mask_axes = ( - "activation_batch", - "activation_norm_length", + "activation_batch_moe", + "activation_norm_length_moe", None, None, None, ) input_axis = ( - "activation_batch", - "activation_norm_length", + "activation_batch_moe", + "activation_norm_length_moe", None, - "activation_embed", + "activation_embed_moe", ) dispatch_axis = ( "activation_exp", - "activation_batch_no_exp", + "activation_batch_no_exp_moe", None, None, - "activation_embed", + "activation_embed_moe", ) mlp_axis = ( "activation_exp", - "activation_batch_no_exp", + "activation_batch_no_exp_moe", None, None, "activation_mlp", @@ -1834,10 +1836,10 @@ def dense_matmul( dispatch, ( None, - "activation_batch_no_exp", - "activation_norm_length", + "activation_batch_no_exp_moe", + "activation_norm_length_moe", None, - "activation_embed", + "activation_embed_moe", ), ) dispatch = self._maybe_shard_with_logical( @@ -1897,9 +1899,9 @@ def dense_matmul( intermediate_layer, ( "activation_exp", - "activation_batch_no_exp", + "activation_batch_no_exp_moe", None, - "activation_embed", + "activation_embed_moe", ), ) intermediate_layer = adc.checkpoint_name(intermediate_layer, "mlpwo") @@ -1922,7 +1924,9 @@ def dense_matmul( ) return output, lb_loss, bias_updates else: - inputs = self._maybe_shard_with_logical(inputs, ("activation_batch", "activation_norm_length", "activation_embed")) + inputs = self._maybe_shard_with_logical( + inputs, ("activation_batch_moe", "activation_norm_length_moe", "activation_embed_moe") + ) with jax.named_scope("wi_0"): layer_w0 = 
self.get_einsum(rhs_mesh_axes=self.wi_kernel_axes)( "BSM,EMH -> BSEH", inputs, w0_kernel, precision=matmul_precision @@ -2082,7 +2086,7 @@ def __init__( num_experts_per_tok=self.config.num_experts_per_tok, mesh=self.mesh, kernel_init=nd_dense_init(1.0, "fan_in", "truncated_normal"), - kernel_axes=("embed", None), + kernel_axes=("embed_moe", None), intermediate_dim=self.config.moe_mlp_dim, dtype=self.config.dtype, weight_dtype=self.config.weight_dtype, diff --git a/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/input_shardings.json b/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/input_shardings.json index 2ca5429163..cbee49b201 100644 --- a/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/input_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/input_shardings.json @@ -122,13 +122,13 @@ }, { "moe/inputs: bfloat16[192,2048,2048]": { - "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", "PartitionSpec": "P('fsdp', None, None)" } }, { "moe/gate_logits: bfloat16[192,2048,64]": { - "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", "PartitionSpec": "P('fsdp', None, None)" } }, diff --git a/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/logical_shardings.json b/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/logical_shardings.json index 0d224005c5..5e33fc22b8 100644 --- a/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/logical_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/logical_shardings.json @@ -143,7 +143,7 @@ }, ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { "partition_spec": [ - "embed", + "embed_moe", "moe_layers", null ], @@ -157,7 +157,7 @@ "partition_spec": [ "exp", "moe_layers", - 
"embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -171,7 +171,7 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -186,7 +186,7 @@ "exp", "moe_layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 64, @@ -469,7 +469,7 @@ }, ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { "partition_spec": [ - "embed", + "embed_moe", "moe_layers", null ], @@ -483,7 +483,7 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -497,7 +497,7 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -512,7 +512,7 @@ "exp", "moe_layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 64, @@ -791,7 +791,7 @@ }, ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { "partition_spec": [ - "embed", + "embed_moe", "moe_layers", null ], @@ -805,7 +805,7 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -819,7 +819,7 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -834,7 +834,7 @@ "exp", "moe_layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 64, diff --git a/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/input_shardings.json b/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/input_shardings.json index c3bec496eb..a12030dbd9 100644 --- a/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/input_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/input_shardings.json @@ -122,13 +122,13 @@ }, { "moe/inputs: bfloat16[768,2048,2048]": { - "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", 
"PartitionSpec": "P(('data', 'fsdp'), None, None)" } }, { "moe/gate_logits: bfloat16[768,2048,64]": { - "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", "PartitionSpec": "P(('data', 'fsdp'), None, None)" } }, diff --git a/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/logical_shardings.json b/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/logical_shardings.json index 0d224005c5..5e33fc22b8 100644 --- a/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/logical_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/logical_shardings.json @@ -143,7 +143,7 @@ }, ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { "partition_spec": [ - "embed", + "embed_moe", "moe_layers", null ], @@ -157,7 +157,7 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -171,7 +171,7 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -186,7 +186,7 @@ "exp", "moe_layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 64, @@ -469,7 +469,7 @@ }, ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { "partition_spec": [ - "embed", + "embed_moe", "moe_layers", null ], @@ -483,7 +483,7 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -497,7 +497,7 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -512,7 +512,7 @@ "exp", "moe_layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 64, @@ -791,7 +791,7 @@ }, ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { "partition_spec": [ - "embed", + "embed_moe", "moe_layers", 
null ], @@ -805,7 +805,7 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -819,7 +819,7 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -834,7 +834,7 @@ "exp", "moe_layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 64, diff --git a/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/input_shardings.json b/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/input_shardings.json index 6bdd341c12..4172fc960f 100644 --- a/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/input_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/input_shardings.json @@ -122,13 +122,13 @@ }, { "moe/inputs: bfloat16[96,2048,2048]": { - "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", "PartitionSpec": "P('fsdp', None, None)" } }, { "moe/gate_logits: bfloat16[96,2048,64]": { - "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", "PartitionSpec": "P('fsdp', None, None)" } }, diff --git a/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/logical_shardings.json b/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/logical_shardings.json index 0d224005c5..5e33fc22b8 100644 --- a/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/logical_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/logical_shardings.json @@ -143,7 +143,7 @@ }, ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { "partition_spec": [ - "embed", + "embed_moe", "moe_layers", null ], @@ -157,7 +157,7 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -171,7 +171,7 @@ "partition_spec": [ "exp", "moe_layers", - 
"embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -186,7 +186,7 @@ "exp", "moe_layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 64, @@ -469,7 +469,7 @@ }, ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { "partition_spec": [ - "embed", + "embed_moe", "moe_layers", null ], @@ -483,7 +483,7 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -497,7 +497,7 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -512,7 +512,7 @@ "exp", "moe_layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 64, @@ -791,7 +791,7 @@ }, ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { "partition_spec": [ - "embed", + "embed_moe", "moe_layers", null ], @@ -805,7 +805,7 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -819,7 +819,7 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -834,7 +834,7 @@ "exp", "moe_layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 64, diff --git a/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_4/input_shardings.json b/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_4/input_shardings.json index a9b0c8c577..2789aa367e 100644 --- a/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_4/input_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_4/input_shardings.json @@ -122,13 +122,13 @@ }, { "moe/inputs: bfloat16[384,2048,2048]": { - "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", "PartitionSpec": "P(('data', 'fsdp'), None, None)" } }, { "moe/gate_logits: bfloat16[384,2048,64]": { - "logic_axes": 
"('activation_batch', 'activation_norm_length', None)", + "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", "PartitionSpec": "P(('data', 'fsdp'), None, None)" } }, diff --git a/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_4/logical_shardings.json b/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_4/logical_shardings.json index 0d224005c5..5e33fc22b8 100644 --- a/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_4/logical_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_4/logical_shardings.json @@ -143,7 +143,7 @@ }, ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { "partition_spec": [ - "embed", + "embed_moe", "moe_layers", null ], @@ -157,7 +157,7 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -171,7 +171,7 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -186,7 +186,7 @@ "exp", "moe_layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 64, @@ -469,7 +469,7 @@ }, ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { "partition_spec": [ - "embed", + "embed_moe", "moe_layers", null ], @@ -483,7 +483,7 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -497,7 +497,7 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -512,7 +512,7 @@ "exp", "moe_layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 64, @@ -791,7 +791,7 @@ }, ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { "partition_spec": [ - "embed", + "embed_moe", "moe_layers", null ], @@ -805,7 +805,7 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": 
[ @@ -819,7 +819,7 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -834,7 +834,7 @@ "exp", "moe_layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 64, diff --git a/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_1/input_shardings.json b/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_1/input_shardings.json index 2ca5429163..cbee49b201 100644 --- a/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_1/input_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_1/input_shardings.json @@ -122,13 +122,13 @@ }, { "moe/inputs: bfloat16[192,2048,2048]": { - "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", "PartitionSpec": "P('fsdp', None, None)" } }, { "moe/gate_logits: bfloat16[192,2048,64]": { - "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", "PartitionSpec": "P('fsdp', None, None)" } }, diff --git a/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_1/logical_shardings.json b/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_1/logical_shardings.json index 0d224005c5..5e33fc22b8 100644 --- a/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_1/logical_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_1/logical_shardings.json @@ -143,7 +143,7 @@ }, ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { "partition_spec": [ - "embed", + "embed_moe", "moe_layers", null ], @@ -157,7 +157,7 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -171,7 +171,7 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -186,7 +186,7 @@ "exp", "moe_layers", "mlp", - "embed_no_exp" + 
"embed_no_exp_moe" ], "shape": [ 64, @@ -469,7 +469,7 @@ }, ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { "partition_spec": [ - "embed", + "embed_moe", "moe_layers", null ], @@ -483,7 +483,7 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -497,7 +497,7 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -512,7 +512,7 @@ "exp", "moe_layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 64, @@ -791,7 +791,7 @@ }, ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { "partition_spec": [ - "embed", + "embed_moe", "moe_layers", null ], @@ -805,7 +805,7 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -819,7 +819,7 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -834,7 +834,7 @@ "exp", "moe_layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 64, diff --git a/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_4/input_shardings.json b/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_4/input_shardings.json index c3bec496eb..a12030dbd9 100644 --- a/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_4/input_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_4/input_shardings.json @@ -122,13 +122,13 @@ }, { "moe/inputs: bfloat16[768,2048,2048]": { - "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", "PartitionSpec": "P(('data', 'fsdp'), None, None)" } }, { "moe/gate_logits: bfloat16[768,2048,64]": { - "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", 
"PartitionSpec": "P(('data', 'fsdp'), None, None)" } }, diff --git a/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_4/logical_shardings.json b/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_4/logical_shardings.json index 0d224005c5..5e33fc22b8 100644 --- a/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_4/logical_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_4/logical_shardings.json @@ -143,7 +143,7 @@ }, ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { "partition_spec": [ - "embed", + "embed_moe", "moe_layers", null ], @@ -157,7 +157,7 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -171,7 +171,7 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -186,7 +186,7 @@ "exp", "moe_layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 64, @@ -469,7 +469,7 @@ }, ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { "partition_spec": [ - "embed", + "embed_moe", "moe_layers", null ], @@ -483,7 +483,7 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -497,7 +497,7 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -512,7 +512,7 @@ "exp", "moe_layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 64, @@ -791,7 +791,7 @@ }, ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { "partition_spec": [ - "embed", + "embed_moe", "moe_layers", null ], @@ -805,7 +805,7 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -819,7 +819,7 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -834,7 
+834,7 @@ "exp", "moe_layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 64, diff --git a/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_1/input_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_1/input_shardings.json index 8409398a06..1f050c09b8 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_1/input_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_1/input_shardings.json @@ -56,13 +56,13 @@ }, { "moe/inputs: bfloat16[192,2048,2880]": { - "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", "PartitionSpec": "P('fsdp', None, None)" } }, { "moe/gate_logits: bfloat16[192,2048,32]": { - "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", "PartitionSpec": "P('fsdp', None, None)" } } diff --git a/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_1/logical_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_1/logical_shardings.json index 1b90463c89..35b79ae83c 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_1/logical_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_1/logical_shardings.json @@ -149,7 +149,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -175,7 +175,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -202,7 +202,7 @@ "exp", "layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 32, @@ -215,7 +215,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed" + "activation_embed_moe" ], "shape": [ 32, @@ -381,7 +381,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -407,7 +407,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + 
"embed_no_exp_moe", "mlp" ], "shape": [ @@ -434,7 +434,7 @@ "exp", "layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 32, @@ -447,7 +447,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed" + "activation_embed_moe" ], "shape": [ 32, @@ -645,7 +645,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -671,7 +671,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -698,7 +698,7 @@ "exp", "layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 32, @@ -711,7 +711,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed" + "activation_embed_moe" ], "shape": [ 32, @@ -877,7 +877,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -903,7 +903,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -930,7 +930,7 @@ "exp", "layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 32, @@ -943,7 +943,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed" + "activation_embed_moe" ], "shape": [ 32, @@ -1137,7 +1137,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -1163,7 +1163,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -1190,7 +1190,7 @@ "exp", "layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 32, @@ -1203,7 +1203,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed" + "activation_embed_moe" ], "shape": [ 32, @@ -1369,7 +1369,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -1395,7 +1395,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -1422,7 +1422,7 @@ "exp", "layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 32, @@ -1435,7 +1435,7 @@ 
"partition_spec": [ "exp", "layers", - "activation_embed" + "activation_embed_moe" ], "shape": [ 32, diff --git a/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_4/input_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_4/input_shardings.json index 37aeba83cc..96fab6247a 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_4/input_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_4/input_shardings.json @@ -56,13 +56,13 @@ }, { "moe/inputs: bfloat16[768,2048,2880]": { - "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", "PartitionSpec": "P(('data', 'fsdp'), None, None)" } }, { "moe/gate_logits: bfloat16[768,2048,32]": { - "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", "PartitionSpec": "P(('data', 'fsdp'), None, None)" } } diff --git a/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_4/logical_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_4/logical_shardings.json index 1b90463c89..35b79ae83c 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_4/logical_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_4/logical_shardings.json @@ -149,7 +149,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -175,7 +175,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -202,7 +202,7 @@ "exp", "layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 32, @@ -215,7 +215,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed" + "activation_embed_moe" ], "shape": [ 32, @@ -381,7 +381,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -407,7 +407,7 @@ "partition_spec": [ "exp", "layers", 
- "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -434,7 +434,7 @@ "exp", "layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 32, @@ -447,7 +447,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed" + "activation_embed_moe" ], "shape": [ 32, @@ -645,7 +645,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -671,7 +671,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -698,7 +698,7 @@ "exp", "layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 32, @@ -711,7 +711,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed" + "activation_embed_moe" ], "shape": [ 32, @@ -877,7 +877,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -903,7 +903,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -930,7 +930,7 @@ "exp", "layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 32, @@ -943,7 +943,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed" + "activation_embed_moe" ], "shape": [ 32, @@ -1137,7 +1137,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -1163,7 +1163,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -1190,7 +1190,7 @@ "exp", "layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 32, @@ -1203,7 +1203,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed" + "activation_embed_moe" ], "shape": [ 32, @@ -1369,7 +1369,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -1395,7 +1395,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -1422,7 +1422,7 @@ "exp", "layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 32, @@ 
-1435,7 +1435,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed" + "activation_embed_moe" ], "shape": [ 32, diff --git a/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_1/input_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_1/input_shardings.json index 5c25c0f2a5..ab45563642 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_1/input_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_1/input_shardings.json @@ -56,13 +56,13 @@ }, { "moe/inputs: bfloat16[96,2048,2880]": { - "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", "PartitionSpec": "P('fsdp', None, None)" } }, { "moe/gate_logits: bfloat16[96,2048,32]": { - "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", "PartitionSpec": "P('fsdp', None, None)" } } diff --git a/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_1/logical_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_1/logical_shardings.json index 1b90463c89..35b79ae83c 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_1/logical_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_1/logical_shardings.json @@ -149,7 +149,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -175,7 +175,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -202,7 +202,7 @@ "exp", "layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 32, @@ -215,7 +215,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed" + "activation_embed_moe" ], "shape": [ 32, @@ -381,7 +381,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -407,7 +407,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + 
"embed_no_exp_moe", "mlp" ], "shape": [ @@ -434,7 +434,7 @@ "exp", "layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 32, @@ -447,7 +447,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed" + "activation_embed_moe" ], "shape": [ 32, @@ -645,7 +645,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -671,7 +671,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -698,7 +698,7 @@ "exp", "layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 32, @@ -711,7 +711,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed" + "activation_embed_moe" ], "shape": [ 32, @@ -877,7 +877,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -903,7 +903,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -930,7 +930,7 @@ "exp", "layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 32, @@ -943,7 +943,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed" + "activation_embed_moe" ], "shape": [ 32, @@ -1137,7 +1137,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -1163,7 +1163,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -1190,7 +1190,7 @@ "exp", "layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 32, @@ -1203,7 +1203,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed" + "activation_embed_moe" ], "shape": [ 32, @@ -1369,7 +1369,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -1395,7 +1395,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -1422,7 +1422,7 @@ "exp", "layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 32, @@ -1435,7 +1435,7 @@ 
"partition_spec": [ "exp", "layers", - "activation_embed" + "activation_embed_moe" ], "shape": [ 32, diff --git a/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_4/input_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_4/input_shardings.json index 26f3df46f2..0a86cb5c83 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_4/input_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_4/input_shardings.json @@ -56,13 +56,13 @@ }, { "moe/inputs: bfloat16[384,2048,2880]": { - "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", "PartitionSpec": "P(('data', 'fsdp'), None, None)" } }, { "moe/gate_logits: bfloat16[384,2048,32]": { - "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", "PartitionSpec": "P(('data', 'fsdp'), None, None)" } } diff --git a/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_4/logical_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_4/logical_shardings.json index 1b90463c89..35b79ae83c 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_4/logical_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_4/logical_shardings.json @@ -149,7 +149,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -175,7 +175,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -202,7 +202,7 @@ "exp", "layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 32, @@ -215,7 +215,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed" + "activation_embed_moe" ], "shape": [ 32, @@ -381,7 +381,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -407,7 +407,7 @@ "partition_spec": [ "exp", "layers", - 
"embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -434,7 +434,7 @@ "exp", "layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 32, @@ -447,7 +447,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed" + "activation_embed_moe" ], "shape": [ 32, @@ -645,7 +645,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -671,7 +671,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -698,7 +698,7 @@ "exp", "layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 32, @@ -711,7 +711,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed" + "activation_embed_moe" ], "shape": [ 32, @@ -877,7 +877,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -903,7 +903,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -930,7 +930,7 @@ "exp", "layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 32, @@ -943,7 +943,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed" + "activation_embed_moe" ], "shape": [ 32, @@ -1137,7 +1137,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -1163,7 +1163,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -1190,7 +1190,7 @@ "exp", "layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 32, @@ -1203,7 +1203,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed" + "activation_embed_moe" ], "shape": [ 32, @@ -1369,7 +1369,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -1395,7 +1395,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -1422,7 +1422,7 @@ "exp", "layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 32, @@ 
-1435,7 +1435,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed" + "activation_embed_moe" ], "shape": [ 32, diff --git a/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_1/input_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_1/input_shardings.json index 8409398a06..1f050c09b8 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_1/input_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_1/input_shardings.json @@ -56,13 +56,13 @@ }, { "moe/inputs: bfloat16[192,2048,2880]": { - "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", "PartitionSpec": "P('fsdp', None, None)" } }, { "moe/gate_logits: bfloat16[192,2048,32]": { - "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", "PartitionSpec": "P('fsdp', None, None)" } } diff --git a/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_1/logical_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_1/logical_shardings.json index 1b90463c89..35b79ae83c 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_1/logical_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_1/logical_shardings.json @@ -149,7 +149,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -175,7 +175,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -202,7 +202,7 @@ "exp", "layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 32, @@ -215,7 +215,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed" + "activation_embed_moe" ], "shape": [ 32, @@ -381,7 +381,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -407,7 +407,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", 
+ "embed_no_exp_moe", "mlp" ], "shape": [ @@ -434,7 +434,7 @@ "exp", "layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 32, @@ -447,7 +447,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed" + "activation_embed_moe" ], "shape": [ 32, @@ -645,7 +645,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -671,7 +671,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -698,7 +698,7 @@ "exp", "layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 32, @@ -711,7 +711,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed" + "activation_embed_moe" ], "shape": [ 32, @@ -877,7 +877,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -903,7 +903,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -930,7 +930,7 @@ "exp", "layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 32, @@ -943,7 +943,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed" + "activation_embed_moe" ], "shape": [ 32, @@ -1137,7 +1137,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -1163,7 +1163,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -1190,7 +1190,7 @@ "exp", "layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 32, @@ -1203,7 +1203,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed" + "activation_embed_moe" ], "shape": [ 32, @@ -1369,7 +1369,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -1395,7 +1395,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -1422,7 +1422,7 @@ "exp", "layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 32, @@ -1435,7 +1435,7 @@ 
"partition_spec": [ "exp", "layers", - "activation_embed" + "activation_embed_moe" ], "shape": [ 32, diff --git a/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_4/input_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_4/input_shardings.json index 37aeba83cc..96fab6247a 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_4/input_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_4/input_shardings.json @@ -56,13 +56,13 @@ }, { "moe/inputs: bfloat16[768,2048,2880]": { - "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", "PartitionSpec": "P(('data', 'fsdp'), None, None)" } }, { "moe/gate_logits: bfloat16[768,2048,32]": { - "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", "PartitionSpec": "P(('data', 'fsdp'), None, None)" } } diff --git a/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_4/logical_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_4/logical_shardings.json index 1b90463c89..35b79ae83c 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_4/logical_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_4/logical_shardings.json @@ -149,7 +149,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -175,7 +175,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -202,7 +202,7 @@ "exp", "layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 32, @@ -215,7 +215,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed" + "activation_embed_moe" ], "shape": [ 32, @@ -381,7 +381,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -407,7 +407,7 @@ "partition_spec": [ "exp", "layers", - 
"embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -434,7 +434,7 @@ "exp", "layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 32, @@ -447,7 +447,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed" + "activation_embed_moe" ], "shape": [ 32, @@ -645,7 +645,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -671,7 +671,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -698,7 +698,7 @@ "exp", "layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 32, @@ -711,7 +711,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed" + "activation_embed_moe" ], "shape": [ 32, @@ -877,7 +877,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -903,7 +903,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -930,7 +930,7 @@ "exp", "layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 32, @@ -943,7 +943,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed" + "activation_embed_moe" ], "shape": [ 32, @@ -1137,7 +1137,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -1163,7 +1163,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -1190,7 +1190,7 @@ "exp", "layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 32, @@ -1203,7 +1203,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed" + "activation_embed_moe" ], "shape": [ 32, @@ -1369,7 +1369,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -1395,7 +1395,7 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp", + "embed_no_exp_moe", "mlp" ], "shape": [ @@ -1422,7 +1422,7 @@ "exp", "layers", "mlp", - "embed_no_exp" + "embed_no_exp_moe" ], "shape": [ 32, @@ 
-1435,7 +1435,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed" + "activation_embed_moe" ], "shape": [ 32,