2020
2121
2222@contextmanager
23- def initialize_distributed (cfg ):
23+ def initialize_distributed (
24+ dp_outer : int = 1 ,
25+ dp_shard : int = 1 ,
26+ cp : int = 1 ,
27+ tp : int = 1 ,
28+ ):
2429 """
2530 Setup the DeviceMesh for distributed training.
2631
2732 Args:
28- cfg: Hydra config.
33+ dp_outer: The size of the data parallelism outer dimension.
34+ dp_shard: The size of the data parallelism shard dimension.
35+ cp: The size of the context parallelism dimension.
36+ tp: The size of the tensor parallelism dimension.
2937
3038 Yields:
3139 device_mesh: The DeviceMesh.
@@ -45,30 +53,30 @@ def initialize_distributed(cfg):
4553 # TODO(@cspades): Will add TE-backed context parallelism (CP) in the future, just need to
4654 # modify the ViT model to shard the sequence dimension after tokenization. For now, we
4755 # setup the CP dimension for demonstrating how to use DeviceMesh and CP with Megatron-FSDP.
48- if cfg . distributed . dp_inter * cfg . distributed . dp_shard * cfg . distributed . cp != torch .distributed .get_world_size ():
56+ if dp_outer * dp_shard * cp != torch .distributed .get_world_size ():
4957 raise ValueError (
50- f"Invalid parallelism sizes: dp_inter( { cfg . distributed . dp_inter } ) * dp_shard({ cfg . distributed . dp_shard } ) * cp({ cfg . distributed . cp } ) * tp(1 ) != world_size({ torch .distributed .get_world_size ()} )"
58+ f"Invalid parallelism sizes: dp_outer( { dp_outer } ) * dp_shard({ dp_shard } ) * cp({ cp } ) * tp({ tp } ) != world_size({ torch .distributed .get_world_size ()} )"
5159 )
5260 device_mesh = torch .distributed .device_mesh .init_device_mesh (
5361 "cuda" ,
5462 mesh_shape = (
55- cfg . distributed . dp_inter ,
56- cfg . distributed . dp_shard ,
57- cfg . distributed . cp ,
58- 1 , # Needed to use TransformerEngine layers with Megatron-FSDP. "TP is always 1."
63+ dp_outer ,
64+ dp_shard ,
65+ cp ,
66+ tp , # Needed to use TransformerEngine layers with Megatron-FSDP.
5967 ),
60- mesh_dim_names = ("dp_inter " , "dp_shard" , "cp" , "tp" ),
68+ mesh_dim_names = ("dp_outer " , "dp_shard" , "cp" , "tp" ),
6169 )
6270
6371 # Sub-meshes (possibly) required for Megatron-FSDP.
6472 # WARNING: These have a tendency to be deleted by Torch. Save references
6573 # or pass them to all classes or functions that use them.
6674 # DP: Only relevant when using HSDP, where we need the flattened DP group for data parallelism. (Otherwise, just pass dp_shard.)
67- device_mesh [("dp_inter " , "dp_shard" )]._flatten ("dp" )
75+ device_mesh [("dp_outer " , "dp_shard" )]._flatten ("dp" )
6876 # DP-Shard-CP: Only required if using CP. Otherwise, just pass dp_shard to FSDP.
6977 device_mesh [("dp_shard" , "cp" )]._flatten ("dp_cp_shard" )
7078 # HSDP (DP-CP): Only required if using HSDP. Otherwise, don't pass hybrid_fsdp_group to Megatron-FSDP.
71- device_mesh [("dp_inter " , "dp_shard" , "cp" )]._flatten ("hsdp" )
79+ device_mesh [("dp_outer " , "dp_shard" , "cp" )]._flatten ("hsdp" )
7280
7381 # Yield DeviceMesh.
7482 yield device_mesh
0 commit comments