@@ -66,15 +66,9 @@ def test_ddp_vs_fsdp_single_gpu(strategy, backend):
6666
6767
6868@requires_multi_gpu
69- @pytest .mark .parametrize ("strategy" , ["fsdp2" , "mfsdp" ])
69+ @pytest .mark .parametrize ("strategy" , ["fsdp2" , pytest . param ( "mfsdp" , marks = pytest . mark . xfail ( reason = "BIONEMO-2726" )) ])
7070@pytest .mark .parametrize ("backend" , ["te" , "eager" ])
7171def test_ddp_vs_fsdp_multi_gpu (strategy , backend ):
72- if strategy == "mfsdp" :
73- pytest .skip (
74- "MFSDP multi-gpu tests are currently failing because tensors are not always evenly sharded, leaving p.grad "
75- "to be None on some ranks (BIONEMO-2726)"
76- )
77-
7872 cmd = [
7973 "torchrun" ,
8074 "--nproc_per_node=2" ,
@@ -213,7 +207,18 @@ def run_forward_backward(use_te: bool, strategy: Strategy, input_data: dict, dis
213207 grads = {name : p .grad for name , p in model .module .named_parameters () if p .grad is not None }
214208
215209 elif strategy is Strategy .MFSDP :
216- grads = {name : p .grad .full_tensor () for name , p in model .module .named_parameters () if p .grad is not None }
210+ # Because of uneven sharding, we need to manually gather the gradients.
211+ sharded_grads = [(name , p .grad ) for name , p in model .module .named_parameters ()]
212+ grads = {}
213+ for name , grad in sharded_grads :
214+ grad_shards = [None ] * device_mesh ["dp" ].size ()
215+ # For FSDP, we are not strided sharding, so gathering across dp_shard_cp is sufficient.
216+ # For HSDP, we need to first gather across dp_shard_cp, then gather across dp_inter,
217+ # not the other way around or you'll get wrong zig-zags.
218+ torch .distributed .all_gather_object (grad_shards , grad , group = device_mesh ["dp" ].get_group ())
219+ all_valid_shards = [shard for shard in grad_shards if shard is not None ]
220+ # Megatron-FSDP is always sharded across dim=0.
221+ grads [name ] = torch .cat ([s .to_local ().to (device ) for s in all_valid_shards ], dim = 0 )
217222
218223 del model
219224 torch .cuda .empty_cache ()
0 commit comments