Skip to content

Commit f28f38c

Browse files
committed
deprecate old 2dfsdp functions
1 parent a80c31f commit f28f38c

5 files changed

Lines changed: 0 additions & 101 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -250,8 +250,6 @@ moe_fsdp_use_two_stage_all_gather: false
250250
# Shard the expert dimension of the MLP weights on the FSDP axis.
251251
# This configuration is recommended only when num_experts is a multiple of fsdp_parallelism
252252
shard_exp_on_fsdp: False
253-
# use fsdp and fsdp_transpose axes for sharding the moe weights
254-
use_2d_fsdp_sharding: False
255253

256254
# deepseek moe
257255
first_num_dense_layers: 0 # number of initial dense layers in the model

src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml

Lines changed: 0 additions & 86 deletions
This file was deleted.

src/maxtext/configs/types.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,6 @@ class ProfilerType(str, Enum):
221221
"deepseek2-16b",
222222
"deepseek2-236b",
223223
"deepseek3-671b",
224-
"deepseek3-671b-2dfsdp",
225224
"deepseek3-671b-batchsplit",
226225
"deepseek3-test",
227226
"deepseek3-tiny",
@@ -705,10 +704,6 @@ class MoEGeneral(BaseModel):
705704
description="Shard the expert dimension of the MLP weights on the FSDP axis, "
706705
"and recommended only when num_experts is a multiple of fsdp_parallelism",
707706
)
708-
use_2d_fsdp_sharding: bool = Field(
709-
False,
710-
description="Use `fsdp` and `fsdp_transpose` axes for 2D FSDP sharding.",
711-
)
712707
norm_topk_prob: bool = Field(
713708
False,
714709
description="Enable top-k probability normalization for router weights (Qwen3-specific).",

src/maxtext/layers/moe.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -358,9 +358,6 @@ def __init__(
358358
# special sharding for dsv3
359359
self.wi_kernel_axes = ("embed_moe", None, "mlp_moe")
360360
self.wo_kernel_axes = ("embed_moe", "mlp_moe", None)
361-
elif self.config.use_2d_fsdp_sharding:
362-
self.wi_kernel_axes = ("embed_moe", "mlp_moe", None)
363-
self.wo_kernel_axes = ("embed_moe", "mlp_moe", None)
364361
elif self.config.use_batch_split_schedule:
365362
self.wi_kernel_axes, self.wo_kernel_axes = get_batchsplit_init_kernel_axes()
366363
else:
@@ -1217,10 +1214,6 @@ def get_routed_moe_shardings(is_batch_sharded_by_expert):
12171214
w0_pspec = self._logical_to_mesh_axes(("embed_tensor_transpose", None, "mlp_no_fsdp"))
12181215
w1_pspec = self._logical_to_mesh_axes(("embed_tensor_transpose", None, "mlp_no_fsdp"))
12191216
wo_pspec = self._logical_to_mesh_axes(("embed_tensor_transpose", "mlp_no_fsdp", None))
1220-
elif self.config.use_2d_fsdp_sharding:
1221-
w0_pspec = self._logical_to_mesh_axes(("embed_tensor_transpose", "mlp_no_fsdp", None))
1222-
w1_pspec = self._logical_to_mesh_axes(("embed_tensor_transpose", "mlp_no_fsdp", None))
1223-
wo_pspec = self._logical_to_mesh_axes(("embed_tensor_transpose", "mlp_no_fsdp", None))
12241217
else:
12251218
# These are the main shardings used by default - they use funky rules to AG over FSDP.
12261219
w0_pspec = self._logical_to_mesh_axes(("exp", "embed_tensor_transpose", "mlp_no_fsdp"))

tests/unit/configs_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,6 @@ def test_gpt_configs(config_file):
200200
os.path.join(CONFIGS_DIR, "models", "deepseek2-236b.yml"),
201201
os.path.join(CONFIGS_DIR, "models", "deepseek3-test.yml"),
202202
os.path.join(CONFIGS_DIR, "models", "deepseek3-671b.yml"),
203-
os.path.join(CONFIGS_DIR, "models", "deepseek3-671b-2dfsdp.yml"),
204203
os.path.join(CONFIGS_DIR, "models", "deepseek3-671b-batchsplit.yml"),
205204
]
206205

0 commit comments

Comments
 (0)