Skip to content

Commit 4d486c4

Browse files
committed
add ici_attn_dp_expert_parallelism config
1 parent eb81131 commit 4d486c4

4 files changed

Lines changed: 60 additions & 2 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -592,6 +592,7 @@ ici_tensor_sequence_parallelism: 1
592592
ici_autoregressive_parallelism: 1
593593
ici_pipeline_parallelism: 1
594594
ici_expert_parallelism: 1
595+
ici_attn_dp_expert_parallelism: 1
595596

596597
# Enable ZeRO-1 optimizer sharding over data axis
597598
shard_optimizer_over_data: False
@@ -985,7 +986,7 @@ xprof_e2e_enable_fw_power_level_event: False
985986
xprof_e2e_enable_fw_thermal_event: False
986987
profile_power_events: False # Set to True to enable TPU-specific power/thermal profiling events. Defaults to False to avoid breaking GPU xplane tracing.
987988

988-
log_config: True # Prints the config (after defaults have been set by pyconfig logic)
989+
log_config: False # Prints the config (after defaults have been set by pyconfig logic)
989990
debug_sharding: False # Prints model weights sharding info
990991

991992
# Checkpoint Structured logging
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
mesh_axes: ['data', 'attn_dp', 'model', 'expert', 'attn_dp_expert']
2+
logical_axis_rules: [
3+
['activation_batch', []],
4+
['activation_batch_moe', ['data']],
5+
['activation_embed_and_logits_batch', ['data']],
6+
['activation_embed_and_logits_batch_sequence', ['data']],
7+
['activation_heads', ['model', 'expert']],
8+
['activation_kv_heads', ['model', 'expert']],
9+
['activation_attn_length', []],
10+
['activation_length', []],
11+
['activation_length_moe', []],
12+
['activation_q_length', ['expert']],
13+
['activation_attn_embed', 'model'],
14+
# Expert is deliberately omitted from activation_embed despite using TP.
15+
# We are going for a replicate-AR style of TP as opposed to our typical AG-RS style of TP
16+
# due to the output sharding of the fused_moe_gmm kernel in tpu-inference.
17+
['activation_embed', ['model', 'attn_dp', 'attn_dp_expert']],
18+
['activation_embed_moe', ['model', 'attn_dp', 'attn_dp_expert']],
19+
['activation_mlp', ['model']],
20+
['activation_mlp_moe', ['model']],
21+
['activation_kv', ['model']],
22+
['activation_prefill_kv_batch', ['expert']],
23+
['activation_kv_batch', ['data', 'attn_dp_expert']],
24+
['activation_kv_head_dim', ['model']],
25+
['activation_vocab', ['model', 'attn_dp']],
26+
['activation_norm_length', []],
27+
['activation_norm_length_moe', []],
28+
['activation_exp', ['expert', 'attn_dp_expert']],
29+
['decode_batch', ['data', 'attn_dp_expert']],
30+
['decode_batch_moe', ['data', 'attn_dp_expert']],
31+
['decode_length', []],
32+
['mlp', ['model', 'attn_dp']],
33+
['mlp_moe', ['model', 'attn_dp']],
34+
['mlp_no_fsdp', ['model', 'attn_dp']],
35+
['vocab', ['model', 'attn_dp']],
36+
# Expert is intended to act like TP for attention.
37+
# We target two all-reduces, one at the end of attention out projection and one at the end of the feedforward.
38+
['heads', ['model', 'expert']],
39+
['q_heads', ['model', 'expert']],
40+
['kv_heads', ['model', 'expert']],
41+
['kv_head_dim', []],
42+
['kv', []],
43+
['embed', []],
44+
['embed_moe', []],
45+
['embed_tensor_transpose', ['attn_dp', 'model']],
46+
['q_lora', ['expert']],
47+
['kv_lora', ['expert']],
48+
['norm', []],
49+
['cache_heads', ['model']],
50+
['exp', ['expert', 'attn_dp_expert']],
51+
['paged_kv_heads', ['model']],
52+
]
53+
data_sharding: [['data', 'attn_dp', 'model', 'expert', 'attn_dp_expert']]

src/maxtext/configs/types.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -886,6 +886,7 @@ class IciParallelism(BaseModel):
886886
ici_autoregressive_parallelism: int = Field(1, description="ICI axis for autoregressive parallelism.")
887887
ici_pipeline_parallelism: int = Field(1, description="ICI axis for pipeline parallelism.")
888888
ici_expert_parallelism: int = Field(1, description="ICI axis for expert parallelism.")
889+
ici_attn_dp_expert_parallelism: int = Field(1, description="ICI axis for attn dp expert parallelism.")
889890

890891

891892
class PipelineParallelism(BaseModel):
@@ -2746,7 +2747,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
27462747
"expert": self.ici_expert_parallelism,
27472748
"autoregressive": self.ici_autoregressive_parallelism,
27482749
"attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
2749-
"attn_dp_expert": 1, # initialized to 1, vLLM will auto calculate this value based on EP
2750+
"attn_dp_expert": self.ici_attn_dp_expert_parallelism,
27502751
}
27512752
self.ici_parallelism = [ici_map[axis] for axis in self.mesh_axes]
27522753

src/maxtext/inference/vllm_decode.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,9 @@ def decode_with_vllm(config: Config) -> None:
100100
enable_expert_parallel = config.ici_expert_parallelism > 1
101101
if enable_expert_parallel:
102102
vllm_args["additional_config"]["sharding"]["sharding_strategy"]["expert_parallelism"] = config.ici_expert_parallelism
103+
vllm_args["additional_config"]["sharding"]["sharding_strategy"][
104+
"attention_data_expert_parallelism"
105+
] = config.ici_attn_dp_expert_parallelism
103106
vllm_args["enable_expert_parallel"] = enable_expert_parallel
104107

105108
max_logging.log(

0 commit comments

Comments
 (0)