Skip to content

Commit 37c0562

Browse files
committed
Pin TP to 1, remove from config.
Signed-off-by: Cory Ye <cye@nvidia.com>
1 parent ba27aa8 commit 37c0562

3 files changed

Lines changed: 3 additions & 8 deletions

File tree

recipes/vit/config/defaults.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ distributed:
4444
dp_inter: 1
4545
dp_shard: 1
4646
cp: 1
47-
tp: 1
4847

4948
fsdp:
5049
init_model_with_meta_device: true

recipes/vit/config/vit_base_patch16_224.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ distributed:
4242
dp_inter: 1
4343
dp_shard: 1
4444
cp: 1
45-
tp: 1
4645

4746
fsdp:
4847
init_model_with_meta_device: true

recipes/vit/train.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -172,20 +172,17 @@ def setup_device_mesh(cfg):
172172
# TODO(@cspades): Will add TE-backed context parallelism (CP) in the future, just need to
173173
# modify the ViT model to shard the sequence dimension after tokenization. For now, we
174174
# setup the CP dimension for demonstrating how to use DeviceMesh and CP with Megatron-FSDP.
175-
if (
176-
cfg.distributed.dp_inter * cfg.distributed.dp_shard * cfg.distributed.cp * cfg.distributed.tp
177-
!= torch.distributed.get_world_size()
178-
):
175+
if cfg.distributed.dp_inter * cfg.distributed.dp_shard * cfg.distributed.cp != torch.distributed.get_world_size():
179176
raise ValueError(
180-
f"Invalid parallelism sizes: dp_inter({cfg.distributed.dp_inter}) * dp_shard({cfg.distributed.dp_shard}) * cp({cfg.distributed.cp}) * tp({cfg.distributed.tp}) != world_size({torch.distributed.get_world_size()})"
177+
f"Invalid parallelism sizes: dp_inter({cfg.distributed.dp_inter}) * dp_shard({cfg.distributed.dp_shard}) * cp({cfg.distributed.cp}) * tp(1) != world_size({torch.distributed.get_world_size()})"
181178
)
182179
device_mesh = torch.distributed.device_mesh.init_device_mesh(
183180
"cuda",
184181
mesh_shape=(
185182
cfg.distributed.dp_inter,
186183
cfg.distributed.dp_shard,
187184
cfg.distributed.cp,
188-
cfg.distributed.tp,
185+
1, # Needed to use TransformerEngine layers with Megatron-FSDP. "TP is always 1."
189186
),
190187
mesh_dim_names=("dp_inter", "dp_shard", "cp", "tp"),
191188
)

0 commit comments

Comments
 (0)