Skip to content

Commit fb0fdce

Browse files
Merge pull request #3894 from AI-Hypercomputer:chengnuojin-2dfsdp
PiperOrigin-RevId: 915542547
2 parents b7afa2d + f28f38c commit fb0fdce

11 files changed

Lines changed: 3973 additions & 109 deletions

File tree

src/maxtext/common/common_types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,3 +146,4 @@ class CustomRule(enum.Enum):
146146
CP_AS_EP = "cp-as-ep" # Support CP and EP together
147147
EP_AS_CP = "ep-as-cp" # Support EP only
148148
PIPELINE_LARGE_MOE = "pipeline-large-moe"
149+
FSDP_2D = "2d-fsdp"

src/maxtext/configs/base.yml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -251,8 +251,6 @@ moe_fsdp_use_two_stage_all_gather: false
251251
# Shard the expert dimension of the MLP weights on the FSDP axis.
252252
# This configuration is recommended only when num_experts is a multiple of fsdp_parallelism
253253
shard_exp_on_fsdp: False
254-
# use fsdp and fsdp_transpose axes for sharding the moe weights
255-
use_2d_fsdp_sharding: False
256254

257255
# deepseek moe
258256
first_num_dense_layers: 0 # number of initial dense layers in the model
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# When scaling to a large number of devices with limited model dimensions,
16+
# introducing an additional FSDP axis prevents sharding limits and improves
17+
# GMM efficiency. This rule demonstrates using both `fsdp` and `fsdp_transpose`
18+
# to enable efficient training across O(1000) chips.
19+
20+
mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'expert']
21+
data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'expert']]
22+
context_sharding: 'context'
23+
logical_axis_rules: [
24+
# ==========================================
25+
# Vocabulary Embedding
26+
# ==========================================
27+
# Vocab Activations
28+
['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
29+
['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'expert']],
30+
# Vocab Weights
31+
['vocab', []],
32+
['embed_vocab', ['fsdp', 'fsdp_transpose', 'context', 'expert']],
33+
# ==========================================
34+
# Attention
35+
# ==========================================
36+
# Attention Activations
37+
['activation_batch_attn', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
38+
['activation_length_attn', ['context']],
39+
['activation_q_length', ['context']],
40+
['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
41+
# Attention Weights
42+
['q_lora', ['fsdp']],
43+
["q_lora_up_proj", ['fsdp_transpose', 'expert']],
44+
['kv_lora', ['fsdp']],
45+
["kv_lora_up_proj", ['fsdp_transpose', 'expert']],
46+
# ==========================================
47+
# Mixture of Experts (MoE)
48+
# ==========================================
49+
# MoE Activations
50+
['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']],
51+
['activation_length_moe', ['context']],
52+
['activation_norm_length_moe', ['context']],
53+
['activation_mlp_moe', []],
54+
['activation_exp', ['expert']],
55+
# MoE Weights
56+
['exp', 'expert'],
57+
['mlp_moe', ['fsdp_transpose']],
58+
['embed_moe', ['fsdp', 'context']],
59+
# ==========================================
60+
# Standard MLP / Dense Layers / Model Structure
61+
# ==========================================
62+
# Dense Activations
63+
['activation_mlp', []],
64+
# Note activation batch and length also get used in vocab
65+
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
66+
['activation_length', ['context']],
67+
['activation_norm_length', ['context']],
68+
['activation_embed', []],
69+
['activation_stage', 'stage'],
70+
# General Weights
71+
['mlp', ['fsdp_transpose']],
72+
['embed', ['fsdp', 'context', 'expert']],
73+
['norm', []],
74+
['layers', 'stage'],
75+
]

src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml

Lines changed: 0 additions & 86 deletions
This file was deleted.

src/maxtext/configs/types.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,6 @@ class ProfilerType(str, Enum):
221221
"deepseek2-16b",
222222
"deepseek2-236b",
223223
"deepseek3-671b",
224-
"deepseek3-671b-2dfsdp",
225224
"deepseek3-671b-batchsplit",
226225
"deepseek3-test",
227226
"deepseek3-tiny",
@@ -718,10 +717,6 @@ class MoEGeneral(BaseModel):
718717
description="Shard the expert dimension of the MLP weights on the FSDP axis, "
719718
"and recommended only when num_experts is a multiple of fsdp_parallelism",
720719
)
721-
use_2d_fsdp_sharding: bool = Field(
722-
False,
723-
description="Use `fsdp` and `fsdp_transpose` axes for 2D FSDP sharding.",
724-
)
725720
norm_topk_prob: bool = Field(
726721
False,
727722
description="Enable top-k probability normalization for router weights (Qwen3-specific).",
@@ -3050,13 +3045,9 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
30503045
"tensor": self.ici_tensor_parallelism,
30513046
"tensor_transpose": self.ici_tensor_transpose_parallelism,
30523047
"tensor_sequence": self.ici_tensor_sequence_parallelism,
3053-
"model": self.ici_tensor_parallelism,
30543048
"expert": self.ici_expert_parallelism,
30553049
"autoregressive": self.ici_autoregressive_parallelism,
3056-
"attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
3057-
"attn_dp_expert": 1, # initialized to 1, vLLM will auto calculate this value based on EP
30583050
}
3059-
self.ici_parallelism = [ici_map[axis] for axis in self.mesh_axes]
30603051

30613052
dcn_map = {
30623053
"diloco": self.dcn_diloco_parallelism,
@@ -3070,12 +3061,37 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
30703061
"tensor": self.dcn_tensor_parallelism,
30713062
"tensor_transpose": self.dcn_tensor_transpose_parallelism,
30723063
"tensor_sequence": self.dcn_tensor_sequence_parallelism,
3073-
"model": self.dcn_tensor_parallelism,
30743064
"expert": self.dcn_expert_parallelism,
30753065
"autoregressive": self.dcn_autoregressive_parallelism,
3076-
"attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
3077-
"attn_dp_expert": 1, # initialized to 1, vLLM will auto calculate this value based on EP
30783066
}
3067+
3068+
# Conditionally include vLLM RPA specific axes
3069+
if self.attention == "vllm_rpa":
3070+
ici_map.update(
3071+
{
3072+
"model": self.ici_tensor_parallelism,
3073+
"attn_dp": 1,
3074+
"attn_dp_expert": 1,
3075+
}
3076+
)
3077+
dcn_map.update(
3078+
{
3079+
"model": self.dcn_tensor_parallelism,
3080+
"attn_dp": 1,
3081+
"attn_dp_expert": 1,
3082+
}
3083+
)
3084+
3085+
# Validate that any axis with configured parallelism > 1 is present in mesh_axes
3086+
for axis, ici_size in ici_map.items():
3087+
if axis not in self.mesh_axes:
3088+
if ici_size > 1 or dcn_map[axis] > 1:
3089+
raise ValueError(
3090+
f"Mesh axis '{axis}' has configured parallelism > 1 "
3091+
f"(ici: {ici_size}, dcn: {dcn_map[axis]}) "
3092+
f"but is not included in self.mesh_axes: {self.mesh_axes}"
3093+
)
3094+
self.ici_parallelism = [ici_map[axis] for axis in self.mesh_axes]
30793095
self.dcn_parallelism = [dcn_map[axis] for axis in self.mesh_axes]
30803096

30813097
# Diloco params

src/maxtext/layers/moe.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -358,9 +358,6 @@ def __init__(
358358
# special sharding for dsv3
359359
self.wi_kernel_axes = ("embed_moe", None, "mlp_moe")
360360
self.wo_kernel_axes = ("embed_moe", "mlp_moe", None)
361-
elif self.config.use_2d_fsdp_sharding:
362-
self.wi_kernel_axes = ("embed_moe", "mlp_moe", None)
363-
self.wo_kernel_axes = ("embed_moe", "mlp_moe", None)
364361
elif self.config.use_batch_split_schedule:
365362
self.wi_kernel_axes, self.wo_kernel_axes = get_batchsplit_init_kernel_axes()
366363
else:
@@ -1217,10 +1214,6 @@ def get_routed_moe_shardings(is_batch_sharded_by_expert):
12171214
w0_pspec = self._logical_to_mesh_axes(("embed_tensor_transpose", None, "mlp_no_fsdp"))
12181215
w1_pspec = self._logical_to_mesh_axes(("embed_tensor_transpose", None, "mlp_no_fsdp"))
12191216
wo_pspec = self._logical_to_mesh_axes(("embed_tensor_transpose", "mlp_no_fsdp", None))
1220-
elif self.config.use_2d_fsdp_sharding:
1221-
w0_pspec = self._logical_to_mesh_axes(("embed_tensor_transpose", "mlp_no_fsdp", None))
1222-
w1_pspec = self._logical_to_mesh_axes(("embed_tensor_transpose", "mlp_no_fsdp", None))
1223-
wo_pspec = self._logical_to_mesh_axes(("embed_tensor_transpose", "mlp_no_fsdp", None))
12241217
else:
12251218
# These are the main shardings used by default - they use funky rules to AG over FSDP.
12261219
w0_pspec = self._logical_to_mesh_axes(("exp", "embed_tensor_transpose", "mlp_no_fsdp"))

tests/unit/configs_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,6 @@ def test_gpt_configs(config_file):
200200
os.path.join(CONFIGS_DIR, "models", "deepseek2-236b.yml"),
201201
os.path.join(CONFIGS_DIR, "models", "deepseek3-test.yml"),
202202
os.path.join(CONFIGS_DIR, "models", "deepseek3-671b.yml"),
203-
os.path.join(CONFIGS_DIR, "models", "deepseek3-671b-2dfsdp.yml"),
204203
os.path.join(CONFIGS_DIR, "models", "deepseek3-671b-batchsplit.yml"),
205204
]
206205

tests/utils/sharding_dump.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,13 @@
6262
"cp-as-ep",
6363
("ici_fsdp_parallelism=-1", "ici_context_parallelism=2", "ici_expert_parallelism=2"),
6464
),
65+
(
66+
"deepseek2-16b",
67+
"tpu7x-8",
68+
1,
69+
"2d-fsdp",
70+
("ici_fsdp_parallelism=-1", "ici_fsdp_transpose_parallelism=2"),
71+
),
6572
("qwen3-0.6b", "tpu7x-16", 1, "", ()),
6673
("gpt-oss-20b", "tpu7x-16", 1, "", ()),
6774
("gpt-oss-20b", "tpu7x-16", 1, "", ("ici_fsdp_parallelism=-1", "ici_expert_parallelism=2")),
@@ -168,7 +175,14 @@ def main(argv: Sequence[str]) -> None:
168175
validate_config(config)
169176
print(f"Sharding debug: {config.debug_sharding}")
170177

171-
rule_name = f"rule_{config.custom_mesh_and_rule}" if config.custom_mesh_and_rule else "rule_default"
178+
# Extract custom_mesh_and_rule directly from argv test case string
179+
custom_mesh_and_rule = ""
180+
for arg in argv:
181+
if arg.startswith("custom_mesh_and_rule="):
182+
custom_mesh_and_rule = arg.split("=", 1)[1]
183+
break
184+
185+
rule_name = f"rule_{custom_mesh_and_rule}" if custom_mesh_and_rule else "rule_default"
172186
# Find overrides from argv to append to rule_name
173187
overrides = []
174188
for arg in argv:

0 commit comments

Comments
 (0)