diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml
index 85db2bcce4..366cd3f688 100644
--- a/src/maxtext/configs/base.yml
+++ b/src/maxtext/configs/base.yml
@@ -476,14 +476,13 @@ logical_axis_rules: [
   ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
   ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
   ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
-  ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']],
-  ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'context' , 'expert']],
-  ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
-  ['embed', ['fsdp', 'sequence', 'context', 'expert']],
+  ['embed', ['expert']], # Instead of pure FSDP (both FSDP and EP act like FSDP during attention), we replace FSDP with DP so embed is sharded less; otherwise it would be sharded too many ways (e.g. 512+), leaving tiny shard shapes. We keep EP since EP is 2D in the target config.
   ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
   ['embed_no_exp', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
   ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
   ['embed_no_exp', ['fsdp', 'sequence', 'context']],
+  # For the full solution, embed_no_exp should be renamed to embed_moe.
+  # That may also require remove_fsdp functionality, though.
   ['embed_tensor_transpose', ['tensor_transpose']],
   ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
   ['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
@@ -522,6 +521,7 @@ logical_axis_rules: [
 # Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
 data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
 input_data_sharding_logical_axes: ['activation_embed_and_logits_batch', 'activation_norm_length']
+embed_shard: "expert_only"  # Whether to shard embed (embed_attn) over both fsdp and expert ("both"), expert only ("expert_only"), or fsdp only ("fsdp_only")
 # sharding tolerance: float between 0.0 and 1.0 representing the allowed percentage of non-sharded parameters.
 sharding_tolerance: 0.02
diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py
index 9ee1a7bb59..77e0614fe6 100644
--- a/src/maxtext/configs/types.py
+++ b/src/maxtext/configs/types.py
@@ -817,6 +817,7 @@ class LayoutAndSharding(BaseModel):
   shard_optimizer_over_data: bool = Field(False, description="Enable ZeRO-1 optimizer sharding over the data axis.")
   internal_compile: bool = Field(False, description="Use internal_compile to bypass open-source topology mappings.")
   internal_compile_num_devices: int = Field(-1, description="Number of devices when using internal_compile.")
+  embed_shard: str = Field("expert_only", description="Which axes to shard embed (embed_attention) on: 'expert_only', 'fsdp_only', or 'both'.")


 class DcnParallelism(BaseModel):
@@ -2278,6 +2279,21 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
     if self.expert_shard_attention_option == "context":
       cp_size *= self.ici_expert_parallelism * self.dcn_expert_parallelism
     self.context_parallel_size = cp_size
+
+    # Override the embed rule. NOTE: this is a hacky (non-mergeable) implementation, to be replaced
+    # once there is a cleaner way to share/override logical axis rules.
+    for rule in self.logical_axis_rules:
+      if rule and rule[0] == "embed":
+        if self.embed_shard == "expert_only":
+          rule[1] = ["expert"]
+        elif self.embed_shard == "fsdp_only":
+          rule[1] = ["fsdp"]
+        elif self.embed_shard == "both":
+          rule[1] = ["fsdp", "expert"]
+        else:
+          raise ValueError(f"Invalid embed_shard: {self.embed_shard}. Must be 'expert_only', 'fsdp_only', or 'both'.")
+        break
+
     if self.pipeline_parallel_layers == -1:
       if self.decoder_block == DecoderBlockType.DEEPSEEK:
         moe_layers = self.num_decoder_layers - self.first_num_dense_layers
diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py
index cff98af9f1..c4f9880d6a 100644
--- a/src/maxtext/layers/moe.py
+++ b/src/maxtext/layers/moe.py
@@ -1073,6 +1073,7 @@ def gmm(
       w1_pspec = self._logical_to_mesh_axes(("embed_tensor_transpose", "mlp_no_fsdp", None))
       wo_pspec = self._logical_to_mesh_axes(("embed_tensor_transpose", "mlp_no_fsdp", None))
     else:
+      # embed_tensor_transpose is unusual here, but it does not map to FSDP, so we all-gather the weights over FSDP.
       w0_pspec = self._logical_to_mesh_axes(("exp", "embed_tensor_transpose", "mlp_no_fsdp"))
      w1_pspec = self._logical_to_mesh_axes(("exp", "embed_tensor_transpose", "mlp_no_fsdp"))
       wo_pspec = self._logical_to_mesh_axes(("exp", "mlp_no_fsdp", "embed_tensor_transpose"))
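
Reviewer note: below is a minimal, standalone sketch (not the MaxText implementation) of what the types.py override above does to the logical axis rules. The function name apply_embed_shard and the example rule list are illustrative only; the rule structure is assumed to match the [logical_axis, mesh_axes] pairs in base.yml. In a config, the new knob is simply set as embed_shard: "both" (or "expert_only" / "fsdp_only").

# Illustrative sketch only -- mirrors the embed_shard override added in types.py above.
# Assumes rules are [logical_axis, mesh_axes] pairs, as in base.yml's logical_axis_rules.
def apply_embed_shard(logical_axis_rules, embed_shard="expert_only"):
  """Rewrites the mesh axes of the first 'embed' rule according to embed_shard."""
  mesh_axes_by_option = {
      "expert_only": ["expert"],
      "fsdp_only": ["fsdp"],
      "both": ["fsdp", "expert"],
  }
  if embed_shard not in mesh_axes_by_option:
    raise ValueError(f"Invalid embed_shard: {embed_shard}. Must be 'expert_only', 'fsdp_only', or 'both'.")
  for rule in logical_axis_rules:
    if rule and rule[0] == "embed":
      rule[1] = mesh_axes_by_option[embed_shard]
      break
  return logical_axis_rules

rules = [
    ["embed", ["fsdp", "sequence", "context", "expert"]],
    ["embed_no_exp", ["fsdp", "sequence", "context"]],
]
print(apply_embed_shard(rules, "both"))  # [['embed', ['fsdp', 'expert']], ['embed_no_exp', ['fsdp', 'sequence', 'context']]]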
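
Reviewer note: the moe.py comment above is about how the logical name 'embed_tensor_transpose' resolves to mesh axes when the pspecs are built. As a hedged illustration of that resolution step, using Flax's public logical_to_mesh_axes helper rather than MaxText's internal _logical_to_mesh_axes wrapper, and with a made-up subset of rules:

# Illustrative only: resolving logical axis names to a PartitionSpec with Flax's public helper.
# The rules below are an assumed subset; moe.py goes through its own _logical_to_mesh_axes wrapper.
from flax import linen as nn

rules = (
    ("exp", "expert"),
    ("embed_tensor_transpose", "tensor_transpose"),
    ("mlp_no_fsdp", ("tensor", "tensor_sequence")),
)

# Mirrors w0_pspec in the else-branch above. Because 'embed_tensor_transpose' maps only to
# 'tensor_transpose' (no 'fsdp'), weights laid out this way get all-gathered over the fsdp axis.
w0_pspec = nn.logical_to_mesh_axes(("exp", "embed_tensor_transpose", "mlp_no_fsdp"), rules)
print(w0_pspec)  # roughly PartitionSpec('expert', 'tensor_transpose', ('tensor', 'tensor_sequence'))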