Skip to content

Commit f5a5c88

Browse files
committed
fix mfsdp reduction, update error to be the difference in multiple
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent 96db758 commit f5a5c88

1 file changed

Lines changed: 13 additions & 8 deletions

File tree

models/esm2/tests/test_distributed_strategies.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -66,15 +66,9 @@ def test_ddp_vs_fsdp_single_gpu(strategy, backend):
6666

6767

6868
@requires_multi_gpu
69-
@pytest.mark.parametrize("strategy", ["fsdp2", "mfsdp"])
69+
@pytest.mark.parametrize("strategy", ["fsdp2", pytest.param("mfsdp", marks=pytest.mark.xfail(reason="BIONEMO-2726"))])
7070
@pytest.mark.parametrize("backend", ["te", "eager"])
7171
def test_ddp_vs_fsdp_multi_gpu(strategy, backend):
72-
if strategy == "mfsdp":
73-
pytest.skip(
74-
"MFSDP multi-gpu tests are currently failing because tensors are not always evenly sharded, leaving p.grad "
75-
"to be None on some ranks (BIONEMO-2726)"
76-
)
77-
7872
cmd = [
7973
"torchrun",
8074
"--nproc_per_node=2",
@@ -213,7 +207,18 @@ def run_forward_backward(use_te: bool, strategy: Strategy, input_data: dict, dis
213207
grads = {name: p.grad for name, p in model.module.named_parameters() if p.grad is not None}
214208

215209
elif strategy is Strategy.MFSDP:
216-
grads = {name: p.grad.full_tensor() for name, p in model.module.named_parameters() if p.grad is not None}
210+
# Because of uneven sharding, we need to manually gather the gradients.
211+
sharded_grads = [(name, p.grad) for name, p in model.module.named_parameters()]
212+
grads = {}
213+
for name, grad in sharded_grads:
214+
grad_shards = [None] * device_mesh["dp"].size()
215+
# For FSDP, we are not strided sharding, so gathering across dp_shard_cp is sufficient.
216+
# For HSDP, we need to first gather across dp_shard_cp, then gather across dp_inter,
217+
# not the other way around or you'll get wrong zig-zags.
218+
torch.distributed.all_gather_object(grad_shards, grad, group=device_mesh["dp"].get_group())
219+
all_valid_shards = [shard for shard in grad_shards if shard is not None]
220+
# Megatron-FSDP is always sharded across dim=0.
221+
grads[name] = torch.cat([s.to_local().to(device) for s in all_valid_shards], dim=0)
217222

218223
del model
219224
torch.cuda.empty_cache()

0 commit comments

Comments
 (0)