@@ -358,9 +358,6 @@ def __init__(
358358 # special sharding for dsv3
359359 self .wi_kernel_axes = ("embed_moe" , None , "mlp_moe" )
360360 self .wo_kernel_axes = ("embed_moe" , "mlp_moe" , None )
361- elif self .config .use_2d_fsdp_sharding :
362- self .wi_kernel_axes = ("embed_moe" , "mlp_moe" , None )
363- self .wo_kernel_axes = ("embed_moe" , "mlp_moe" , None )
364361 elif self .config .use_batch_split_schedule :
365362 self .wi_kernel_axes , self .wo_kernel_axes = get_batchsplit_init_kernel_axes ()
366363 else :
@@ -1217,10 +1214,6 @@ def get_routed_moe_shardings(is_batch_sharded_by_expert):
12171214 w0_pspec = self ._logical_to_mesh_axes (("embed_tensor_transpose" , None , "mlp_no_fsdp" ))
12181215 w1_pspec = self ._logical_to_mesh_axes (("embed_tensor_transpose" , None , "mlp_no_fsdp" ))
12191216 wo_pspec = self ._logical_to_mesh_axes (("embed_tensor_transpose" , "mlp_no_fsdp" , None ))
1220- elif self .config .use_2d_fsdp_sharding :
1221- w0_pspec = self ._logical_to_mesh_axes (("embed_tensor_transpose" , "mlp_no_fsdp" , None ))
1222- w1_pspec = self ._logical_to_mesh_axes (("embed_tensor_transpose" , "mlp_no_fsdp" , None ))
1223- wo_pspec = self ._logical_to_mesh_axes (("embed_tensor_transpose" , "mlp_no_fsdp" , None ))
12241217 else :
12251218 # These are the main shardings used by default - they use funky rules to AG over FSDP.
12261219 w0_pspec = self ._logical_to_mesh_axes (("exp" , "embed_tensor_transpose" , "mlp_no_fsdp" ))
0 commit comments