Skip to content

Commit 11623d9

Browse files
committed
introduce cp-as-ep rule for long context training or strong scaling
1 parent 412902a commit 11623d9

6 files changed

Lines changed: 3789 additions & 0 deletions

File tree

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# This rule uses the data, stage, FSDP, context, and expert axes. The expert axis acts as context parallelism in
16+
# all components except the core dMoE part (between the EP all-to-all collectives).
17+
mesh_axes: ['data', 'stage', 'fsdp', 'context', 'expert']
18+
data_sharding: [['data', 'stage', 'fsdp', 'context', 'expert']]
19+
context_sharding: 'context'
20+
logical_axis_rules: [
21+
# ==========================================
22+
# Vocabulary Embedding
23+
# ==========================================
24+
# Vocab Activations
25+
['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'expert']],
26+
['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'expert']],
27+
# Vocab Weights
28+
['vocab', []],
29+
['embed_vocab', ['fsdp', 'context', 'expert']],
30+
# ==========================================
31+
# Attention
32+
# ==========================================
33+
# Attention Activations
34+
['activation_batch_attn', ['data', 'fsdp', 'expert']],
35+
['activation_heads', []],
36+
['activation_kv_heads', []],
37+
['activation_length_attn', ['context']],
38+
['activation_q_length', ['context']],
39+
['activation_kv_length', []],
40+
['activation_embed_attn', []],
41+
['activation_kv', []],
42+
['activation_kv_batch', ['data', 'fsdp', 'expert']],
43+
['activation_kv_head_dim', []],
44+
# Attention Weights
45+
['heads', []],
46+
['q_heads', []],
47+
['kv_heads', []],
48+
['qkv', []],
49+
['kv', []],
50+
['kv_head_dim', []],
51+
['q_lora', ['fsdp', 'context', 'expert']],
52+
["q_lora_up_proj", []],
53+
['kv_lora', ['fsdp', 'context', 'expert']],
54+
["kv_lora_up_proj", []],
55+
# ==========================================
56+
# Mixture of Experts (MoE)
57+
# ==========================================
58+
# MoE Activations
59+
['activation_batch_moe', ['data', 'fsdp']],
60+
['activation_exp', ['context', 'expert']],
61+
# MoE Weights
62+
['exp', ['context', 'expert']],
63+
['embed_moe', ['fsdp']],
64+
# ==========================================
65+
# Standard MLP / Dense Layers / Model Structure
66+
# ==========================================
67+
# Dense Activations
68+
['activation_mlp', []],
69+
['activation_batch', ['data', 'fsdp', 'expert']],
70+
['activation_length', ['context']],
71+
['activation_norm_length', ['context']],
72+
['activation_embed', []],
73+
['activation_stage', 'stage'],
74+
# General Weights
75+
['mlp', []],
76+
['layers', 'stage'],
77+
['embed', ['fsdp', 'context', 'expert']],
78+
]

src/maxtext/layers/moe.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,9 @@ def __init__(
376376

377377
if self.config.attention == "vllm_rpa" and self.config.enable_dp_attention:
378378
self._expert_parallelism_name = "attn_dp_expert"
379+
elif self.config.custom_mesh_and_rule == "cp-as-ep":
380+
# When custom_mesh_and_rule is "cp-as-ep", the context axis is treated the same as the expert axis in the MoE component
381+
self._expert_parallelism_name = ("context", "expert")
379382
else:
380383
self._expert_parallelism_name = "expert"
381384

tests/utils/sharding_dump.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,13 @@
5555
"ep-as-cp",
5656
("ici_fsdp_parallelism=-1", "ici_expert_parallelism=2"),
5757
),
58+
(
59+
"deepseek2-16b",
60+
"tpu7x-8",
61+
1,
62+
"cp-as-ep",
63+
("ici_fsdp_parallelism=-1", "ici_context_parallelism=2", "ici_expert_parallelism=2"),
64+
),
5865
("qwen3-0.6b", "tpu7x-16", 1, "", ()),
5966
("gpt-oss-20b", "tpu7x-16", 1, "", ()),
6067
("gpt-oss-20b", "tpu7x-16", 1, "", ("ici_fsdp_parallelism=-1", "ici_expert_parallelism=2")),
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
{
2+
"Activation Sharding Dump": [
3+
{
4+
"deepseek/inputs: bfloat16[96,2048,2048]": {
5+
"logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')",
6+
"PartitionSpec": "P(('fsdp', 'expert'), 'context', None)"
7+
}
8+
},
9+
{
10+
"deepseek/pre_attention_norm: bfloat16[96,2048,2048]": {
11+
"logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')",
12+
"PartitionSpec": "P(('fsdp', 'expert'), 'context', None)"
13+
}
14+
},
15+
{
16+
"attention_mla/inputs_q: bfloat16[96,2048,2048]": {
17+
"logic_axes": "('activation_batch_attn', 'activation_length', 'activation_embed')",
18+
"PartitionSpec": "P(('fsdp', 'expert'), 'context', None)"
19+
}
20+
},
21+
{
22+
"attention_mla/inputs_kv: bfloat16[96,2048,2048]": {
23+
"logic_axes": "('activation_batch_attn', 'activation_length', 'activation_embed')",
24+
"PartitionSpec": "P(('fsdp', 'expert'), 'context', None)"
25+
}
26+
},
27+
{
28+
"attention_mla/q_nope: bfloat16[96,2048,16,128]": {
29+
"logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')",
30+
"PartitionSpec": "P(('fsdp', 'expert'), 'context', None, None)"
31+
}
32+
},
33+
{
34+
"attention_mla/q_pe: bfloat16[96,2048,16,64]": {
35+
"logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')",
36+
"PartitionSpec": "P(('fsdp', 'expert'), 'context', None, None)"
37+
}
38+
},
39+
{
40+
"attention_mla/query: bfloat16[96,2048,16,192]": {
41+
"logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')",
42+
"PartitionSpec": "P(('fsdp', 'expert'), 'context', None, None)"
43+
}
44+
},
45+
{
46+
"attention_mla/key_nope: bfloat16[96,2048,16,128]": {
47+
"logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')",
48+
"PartitionSpec": "P(('fsdp', 'expert'), 'context', None, None)"
49+
}
50+
},
51+
{
52+
"attention_mla/key_rope: bfloat16[96,2048,16,64]": {
53+
"logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')",
54+
"PartitionSpec": "P(('fsdp', 'expert'), 'context', None, None)"
55+
}
56+
},
57+
{
58+
"attention_mla/key: bfloat16[96,2048,16,192]": {
59+
"logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')",
60+
"PartitionSpec": "P(('fsdp', 'expert'), 'context', None, None)"
61+
}
62+
},
63+
{
64+
"attention_mla/value: bfloat16[96,2048,16,128]": {
65+
"logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')",
66+
"PartitionSpec": "P(('fsdp', 'expert'), 'context', None, None)"
67+
}
68+
},
69+
{
70+
"attention_op/arr: int8[1,4,4]": {
71+
"logic_axes": "Unknown",
72+
"PartitionSpec": "P(None, 'context')"
73+
}
74+
},
75+
{
76+
"attention_op/arr: int32[2048]": {
77+
"logic_axes": "Unknown",
78+
"PartitionSpec": "P('context',)"
79+
}
80+
},
81+
{
82+
"attention_op/query: bfloat16[96,16,2048,192]": {
83+
"logic_axes": "Unknown",
84+
"PartitionSpec": "P(('fsdp', 'expert'), None, 'context', None)"
85+
}
86+
},
87+
{
88+
"attention_op/key: bfloat16[96,16,2048,192]": {
89+
"logic_axes": "Unknown",
90+
"PartitionSpec": "P(('fsdp', 'expert'), None, None, None)"
91+
}
92+
},
93+
{
94+
"attention_op/value: bfloat16[96,16,2048,128]": {
95+
"logic_axes": "Unknown",
96+
"PartitionSpec": "P(('fsdp', 'expert'), None, None, None)"
97+
}
98+
},
99+
{
100+
"attention_mla/out: bfloat16[96,2048,16,128]": {
101+
"logic_axes": "('activation_batch_attn', 'activation_length', 'activation_heads', 'activation_kv')",
102+
"PartitionSpec": "P(('fsdp', 'expert'), 'context', None, None)"
103+
}
104+
},
105+
{
106+
"deepseek/attention_result: bfloat16[96,2048,2048]": {
107+
"logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')",
108+
"PartitionSpec": "P(('fsdp', 'expert'), 'context', None)"
109+
}
110+
},
111+
{
112+
"deepseek/post_attention_norm: bfloat16[96,2048,2048]": {
113+
"logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')",
114+
"PartitionSpec": "P(('fsdp', 'expert'), 'context', None)"
115+
}
116+
},
117+
{
118+
"linears/x: bfloat16[96,2048,10944]": {
119+
"logic_axes": "('activation_batch', 'activation_length', 'activation_mlp')",
120+
"PartitionSpec": "P(('fsdp', 'expert'), 'context', None)"
121+
}
122+
},
123+
{
124+
"deepseek/mlp: bfloat16[96,2048,2048]": {
125+
"logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')",
126+
"PartitionSpec": "P(('fsdp', 'expert'), 'context', None)"
127+
}
128+
},
129+
{
130+
"deepseek/x: bfloat16[96,2048,2048]": {
131+
"logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')",
132+
"PartitionSpec": "P(('fsdp', 'expert'), 'context', None)"
133+
}
134+
},
135+
{
136+
"moe/inputs: bfloat16[96,2048,2048]": {
137+
"logic_axes": "('activation_batch', 'activation_norm_length', None)",
138+
"PartitionSpec": "P(('fsdp', 'expert'), 'context', None)"
139+
}
140+
},
141+
{
142+
"moe/gate_logits: bfloat16[96,2048,64]": {
143+
"logic_axes": "('activation_batch', 'activation_norm_length', None)",
144+
"PartitionSpec": "P(('fsdp', 'expert'), 'context', None)"
145+
}
146+
},
147+
{
148+
"moe/w0_kernel: bfloat16[64,2048,1408]": {
149+
"logic_axes": "Unknown",
150+
"PartitionSpec": "P(('context', 'expert'), None, None)"
151+
}
152+
},
153+
{
154+
"moe/w1_kernel: bfloat16[64,2048,1408]": {
155+
"logic_axes": "Unknown",
156+
"PartitionSpec": "P(('context', 'expert'), None, None)"
157+
}
158+
},
159+
{
160+
"moe/wo_kernel: bfloat16[64,1408,2048]": {
161+
"logic_axes": "Unknown",
162+
"PartitionSpec": "P(('context', 'expert'), None, None)"
163+
}
164+
},
165+
{
166+
"linears/x: bfloat16[96,2048,2816]": {
167+
"logic_axes": "('activation_batch', 'activation_length', 'activation_mlp')",
168+
"PartitionSpec": "P(('fsdp', 'expert'), 'context', None)"
169+
}
170+
},
171+
{
172+
"deepseek/mlp_lnx: bfloat16[96,2048,2048]": {
173+
"logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')",
174+
"PartitionSpec": "P(('fsdp', 'expert'), 'context', None)"
175+
}
176+
}
177+
]
178+
}

0 commit comments

Comments
 (0)