Skip to content

Commit a80c31f

Browse files
committed
add 2d fsdp custom mesh
1 parent 41a5d98 commit a80c31f

8 files changed

Lines changed: 3974 additions & 9 deletions

File tree

src/maxtext/common/common_types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,3 +146,4 @@ class CustomRule(enum.Enum):
146146
CP_AS_EP = "cp-as-ep" # Support CP and EP together
147147
EP_AS_CP = "ep-as-cp" # Support EP only
148148
PIPELINE_LARGE_MOE = "pipeline-large-moe"
149+
FSDP_2D = "2d-fsdp"
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# When scaling to a large number of devices with limited model dimensions,
16+
# introducing an additional FSDP axis prevents sharding limits and improves
17+
# GMM efficiency. This rule demonstrates using both `fsdp` and `fsdp_transpose`
18+
# to enable efficient training across O(1000) chips.
19+
20+
mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'expert']
21+
data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'expert']]
22+
context_sharding: 'context'
23+
logical_axis_rules: [
24+
# ==========================================
25+
# Vocabulary Embedding
26+
# ==========================================
27+
# Vocab Activations
28+
['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
29+
['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'expert']],
30+
# Vocab Weights
31+
['vocab', []],
32+
['embed_vocab', ['fsdp', 'fsdp_transpose', 'context', 'expert']],
33+
# ==========================================
34+
# Attention
35+
# ==========================================
36+
# Attention Activations
37+
['activation_batch_attn', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
38+
['activation_length_attn', ['context']],
39+
['activation_q_length', ['context']],
40+
['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
41+
# Attention Weights
42+
['q_lora', ['fsdp']],
43+
["q_lora_up_proj", ['fsdp_transpose', 'expert']],
44+
['kv_lora', ['fsdp']],
45+
["kv_lora_up_proj", ['fsdp_transpose', 'expert']],
46+
# ==========================================
47+
# Mixture of Experts (MoE)
48+
# ==========================================
49+
# MoE Activations
50+
['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']],
51+
['activation_length_moe', ['context']],
52+
['activation_norm_length_moe', ['context']],
53+
['activation_mlp_moe', []],
54+
['activation_exp', ['expert']],
55+
# MoE Weights
56+
['exp', 'expert'],
57+
['mlp_moe', ['fsdp_transpose']],
58+
['embed_moe', ['fsdp', 'context']],
59+
# ==========================================
60+
# Standard MLP / Dense Layers / Model Structure
61+
# ==========================================
62+
# Dense Activations
63+
['activation_mlp', []],
64+
# Note activation batch and length also get used in vocab
65+
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
66+
['activation_length', ['context']],
67+
['activation_norm_length', ['context']],
68+
['activation_embed', []],
69+
['activation_stage', 'stage'],
70+
# General Weights
71+
['mlp', ['fsdp_transpose']],
72+
['embed', ['fsdp', 'context', 'expert']],
73+
['norm', []],
74+
['layers', 'stage'],
75+
]

src/maxtext/configs/types.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2994,13 +2994,9 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
29942994
"tensor": self.ici_tensor_parallelism,
29952995
"tensor_transpose": self.ici_tensor_transpose_parallelism,
29962996
"tensor_sequence": self.ici_tensor_sequence_parallelism,
2997-
"model": self.ici_tensor_parallelism,
29982997
"expert": self.ici_expert_parallelism,
29992998
"autoregressive": self.ici_autoregressive_parallelism,
3000-
"attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
3001-
"attn_dp_expert": 1, # initialized to 1, vLLM will auto calculate this value based on EP
30022999
}
3003-
self.ici_parallelism = [ici_map[axis] for axis in self.mesh_axes]
30043000

30053001
dcn_map = {
30063002
"diloco": self.dcn_diloco_parallelism,
@@ -3014,12 +3010,37 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
30143010
"tensor": self.dcn_tensor_parallelism,
30153011
"tensor_transpose": self.dcn_tensor_transpose_parallelism,
30163012
"tensor_sequence": self.dcn_tensor_sequence_parallelism,
3017-
"model": self.dcn_tensor_parallelism,
30183013
"expert": self.dcn_expert_parallelism,
30193014
"autoregressive": self.dcn_autoregressive_parallelism,
3020-
"attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
3021-
"attn_dp_expert": 1, # initialized to 1, vLLM will auto calculate this value based on EP
30223015
}
3016+
3017+
# Conditionally include vLLM RPA specific axes
3018+
if self.attention == "vllm_rpa":
3019+
ici_map.update(
3020+
{
3021+
"model": self.ici_tensor_parallelism,
3022+
"attn_dp": 1,
3023+
"attn_dp_expert": 1,
3024+
}
3025+
)
3026+
dcn_map.update(
3027+
{
3028+
"model": self.dcn_tensor_parallelism,
3029+
"attn_dp": 1,
3030+
"attn_dp_expert": 1,
3031+
}
3032+
)
3033+
3034+
# Validate that any axis with configured parallelism > 1 is present in mesh_axes
3035+
for axis, ici_size in ici_map.items():
3036+
if axis not in self.mesh_axes:
3037+
if ici_size > 1 or dcn_map[axis] > 1:
3038+
raise ValueError(
3039+
f"Mesh axis '{axis}' has configured parallelism > 1 "
3040+
f"(ici: {ici_size}, dcn: {dcn_map[axis]}) "
3041+
f"but is not included in self.mesh_axes: {self.mesh_axes}"
3042+
)
3043+
self.ici_parallelism = [ici_map[axis] for axis in self.mesh_axes]
30233044
self.dcn_parallelism = [dcn_map[axis] for axis in self.mesh_axes]
30243045

30253046
# Diloco params

src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,4 +323,4 @@ def load_weights(self, rng_key: jax.Array) -> None:
323323
model = model_creation_utils.from_pretrained(
324324
self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
325325
)
326-
self.model = nnx.data(model)
326+
self.model = nnx.data(model)

tests/utils/sharding_dump.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,13 @@
6262
"cp-as-ep",
6363
("ici_fsdp_parallelism=-1", "ici_context_parallelism=2", "ici_expert_parallelism=2"),
6464
),
65+
(
66+
"deepseek2-16b",
67+
"tpu7x-8",
68+
1,
69+
"2d-fsdp",
70+
("ici_fsdp_parallelism=-1", "ici_fsdp_transpose_parallelism=2"),
71+
),
6572
("qwen3-0.6b", "tpu7x-16", 1, "", ()),
6673
("gpt-oss-20b", "tpu7x-16", 1, "", ()),
6774
("gpt-oss-20b", "tpu7x-16", 1, "", ("ici_fsdp_parallelism=-1", "ici_expert_parallelism=2")),
@@ -168,7 +175,14 @@ def main(argv: Sequence[str]) -> None:
168175
validate_config(config)
169176
print(f"Sharding debug: {config.debug_sharding}")
170177

171-
rule_name = f"rule_{config.custom_mesh_and_rule}" if config.custom_mesh_and_rule else "rule_default"
178+
# Extract custom_mesh_and_rule directly from argv test case string
179+
custom_mesh_and_rule = ""
180+
for arg in argv:
181+
if arg.startswith("custom_mesh_and_rule="):
182+
custom_mesh_and_rule = arg.split("=", 1)[1]
183+
break
184+
185+
rule_name = f"rule_{custom_mesh_and_rule}" if custom_mesh_and_rule else "rule_default"
172186
# Find overrides from argv to append to rule_name
173187
overrides = []
174188
for arg in argv:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
{
2+
"Activation Sharding Dump": [
3+
{
4+
"deepseek/inputs: bfloat16[96,2048,2048]": {
5+
"logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')",
6+
"PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None)"
7+
}
8+
},
9+
{
10+
"deepseek/pre_attention_norm: bfloat16[96,2048,2048]": {
11+
"logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')",
12+
"PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None)"
13+
}
14+
},
15+
{
16+
"attention_mla/inputs_q: bfloat16[96,2048,2048]": {
17+
"logic_axes": "('activation_batch_attn', 'activation_length', 'activation_embed')",
18+
"PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None)"
19+
}
20+
},
21+
{
22+
"attention_mla/inputs_kv: bfloat16[96,2048,2048]": {
23+
"logic_axes": "('activation_batch_attn', 'activation_length', 'activation_embed')",
24+
"PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None)"
25+
}
26+
},
27+
{
28+
"attention_mla/q_nope: bfloat16[96,2048,16,128]": {
29+
"logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')",
30+
"PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None, None)"
31+
}
32+
},
33+
{
34+
"attention_mla/q_pe: bfloat16[96,2048,16,64]": {
35+
"logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')",
36+
"PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None, None)"
37+
}
38+
},
39+
{
40+
"attention_mla/query: bfloat16[96,2048,16,192]": {
41+
"logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')",
42+
"PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None, None)"
43+
}
44+
},
45+
{
46+
"attention_mla/key_nope: bfloat16[96,2048,16,128]": {
47+
"logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')",
48+
"PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None, None)"
49+
}
50+
},
51+
{
52+
"attention_mla/key_rope: bfloat16[96,2048,16,64]": {
53+
"logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')",
54+
"PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None, None)"
55+
}
56+
},
57+
{
58+
"attention_mla/key: bfloat16[96,2048,16,192]": {
59+
"logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')",
60+
"PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None, None)"
61+
}
62+
},
63+
{
64+
"attention_mla/value: bfloat16[96,2048,16,128]": {
65+
"logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')",
66+
"PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None, None)"
67+
}
68+
},
69+
{
70+
"attention_op/arr: int8[1,4,4]": {
71+
"logic_axes": "Unknown",
72+
"PartitionSpec": "P(None, None)"
73+
}
74+
},
75+
{
76+
"attention_op/arr: int32[2048]": {
77+
"logic_axes": "Unknown",
78+
"PartitionSpec": "P(None,)"
79+
}
80+
},
81+
{
82+
"attention_op/query: bfloat16[96,16,2048,192]": {
83+
"logic_axes": "Unknown",
84+
"PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None, None)"
85+
}
86+
},
87+
{
88+
"attention_op/key: bfloat16[96,16,2048,192]": {
89+
"logic_axes": "Unknown",
90+
"PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None, None)"
91+
}
92+
},
93+
{
94+
"attention_op/value: bfloat16[96,16,2048,128]": {
95+
"logic_axes": "Unknown",
96+
"PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None, None)"
97+
}
98+
},
99+
{
100+
"attention_mla/out: bfloat16[96,2048,16,128]": {
101+
"logic_axes": "('activation_batch_attn', 'activation_length', 'activation_heads', 'activation_kv')",
102+
"PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None, None)"
103+
}
104+
},
105+
{
106+
"deepseek/attention_result: bfloat16[96,2048,2048]": {
107+
"logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')",
108+
"PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None)"
109+
}
110+
},
111+
{
112+
"deepseek/post_attention_norm: bfloat16[96,2048,2048]": {
113+
"logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')",
114+
"PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None)"
115+
}
116+
},
117+
{
118+
"linears/x: bfloat16[96,2048,10944]": {
119+
"logic_axes": "('activation_batch', 'activation_length', 'activation_mlp')",
120+
"PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None)"
121+
}
122+
},
123+
{
124+
"deepseek/mlp: bfloat16[96,2048,2048]": {
125+
"logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')",
126+
"PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None)"
127+
}
128+
},
129+
{
130+
"deepseek/x: bfloat16[96,2048,2048]": {
131+
"logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')",
132+
"PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None)"
133+
}
134+
},
135+
{
136+
"moe/inputs: bfloat16[96,2048,2048]": {
137+
"logic_axes": "('activation_batch', 'activation_norm_length', None)",
138+
"PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None)"
139+
}
140+
},
141+
{
142+
"moe/gate_logits: bfloat16[96,2048,64]": {
143+
"logic_axes": "('activation_batch', 'activation_norm_length', None)",
144+
"PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None)"
145+
}
146+
},
147+
{
148+
"moe/w0_kernel: bfloat16[64,2048,1408]": {
149+
"logic_axes": "Unknown",
150+
"PartitionSpec": "P(None, None, None)"
151+
}
152+
},
153+
{
154+
"moe/w1_kernel: bfloat16[64,2048,1408]": {
155+
"logic_axes": "Unknown",
156+
"PartitionSpec": "P(None, None, None)"
157+
}
158+
},
159+
{
160+
"moe/wo_kernel: bfloat16[64,1408,2048]": {
161+
"logic_axes": "Unknown",
162+
"PartitionSpec": "P(None, None, None)"
163+
}
164+
},
165+
{
166+
"linears/x: bfloat16[96,2048,2816]": {
167+
"logic_axes": "('activation_batch', 'activation_length', 'activation_mlp')",
168+
"PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None)"
169+
}
170+
},
171+
{
172+
"deepseek/mlp_lnx: bfloat16[96,2048,2048]": {
173+
"logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')",
174+
"PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None)"
175+
}
176+
}
177+
]
178+
}

0 commit comments

Comments
 (0)