
Torch FSDP2 crashes in finalize_model_grads because _BaseDataParallel.finish_grad_sync() does not accept force_all_reduce #4669

@Jayoprell


Summary

When running a pretraining job with Torch FSDP2 enabled, training fails during gradient finalization with:

TypeError: _BaseDataParallel.finish_grad_sync() got an unexpected keyword argument 'force_all_reduce'

This appears to be an interface mismatch introduced after force_all_reduce was added to the gradient finalization path.

Analysis

finalize_model_grads() now calls:

model_chunk.finish_grad_sync(force_all_reduce=force_all_reduce)

This works for DistributedDataParallel, whose method signature is:

def finish_grad_sync(self, force_all_reduce: Optional[bool] = False):

It also works for Megatron FSDP, which accepts the same compatibility argument.

However, TorchFullyShardedDataParallel does not define its own finish_grad_sync() method, so it inherits the base implementation from _BaseDataParallel:

  def finish_grad_sync(self):
      pass
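A quick way to confirm the mismatch is to inspect each wrapper's finish_grad_sync signature with Python's standard inspect module. This is a minimal sketch. accepts_kwarg is a hypothetical helper, and the two classes are simplified stand-ins that only copy the signatures quoted above. They are not the real Megatron classes.

```python
import inspect
from typing import Callable, Optional


def accepts_kwarg(fn: Callable, name: str) -> bool:
    """Return True if `fn` can be called with keyword argument `name` (hypothetical helper)."""
    for p in inspect.signature(fn).parameters.values():
        # A **kwargs parameter accepts any keyword argument.
        if p.kind is inspect.Parameter.VAR_KEYWORD:
            return True
        # Positional-only parameters cannot be passed by keyword.
        if p.name == name and p.kind is not inspect.Parameter.POSITIONAL_ONLY:
            return True
    return False


# Hypothetical stand-ins mirroring the signatures quoted in this issue.
class BaseDataParallelStandIn:
    def finish_grad_sync(self):
        pass


class DDPStandIn(BaseDataParallelStandIn):
    def finish_grad_sync(self, force_all_reduce: Optional[bool] = False):
        pass


class TorchFSDP2StandIn(BaseDataParallelStandIn):
    # No override: inherits the base signature, like TorchFullyShardedDataParallel.
    pass


for cls in (DDPStandIn, TorchFSDP2StandIn):
    print(cls.__name__, accepts_kwarg(cls.finish_grad_sync, "force_all_reduce"))
```

With a real Megatron-LM install, the same helper could be pointed at the actual TorchFullyShardedDataParallel.finish_grad_sync. It would be expected to return False before the fix and True after it.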

As a result, any Torch FSDP2 run that reaches finalize_model_grads() receives the unexpected keyword argument and fails before the optimizer step.
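The failure can be reproduced without GPUs or a distributed setup. This minimal sketch uses hypothetical stand-in classes and a simplified finalize_grads loop that mimics only the single call quoted above. The real finalize_model_grads() does considerably more.

```python
from typing import List, Optional


class BaseDataParallelStandIn:
    """Stand-in for _BaseDataParallel before the fix."""

    def finish_grad_sync(self):
        pass


class DDPStandIn(BaseDataParallelStandIn):
    """Stand-in for DistributedDataParallel, which accepts the argument."""

    def finish_grad_sync(self, force_all_reduce: Optional[bool] = False):
        pass


class TorchFSDP2StandIn(BaseDataParallelStandIn):
    """Stand-in for TorchFullyShardedDataParallel: no override of finish_grad_sync."""


def finalize_grads(model: List[BaseDataParallelStandIn], force_all_reduce: bool = False) -> None:
    # Mimics the call site in finalize_model_grads(): the keyword is always passed.
    for model_chunk in model:
        model_chunk.finish_grad_sync(force_all_reduce=force_all_reduce)


finalize_grads([DDPStandIn()])  # succeeds

try:
    finalize_grads([TorchFSDP2StandIn()])
except TypeError as exc:
    # Same "unexpected keyword argument 'force_all_reduce'" error as in the report.
    print(exc)
```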

Expected behavior

Torch FSDP2 should accept force_all_reduce and ignore it if it does not need it, rather than failing because of the base-class method signature.

Suggested fix

Update the base data-parallel interface in:

megatron/core/distributed/data_parallel_base.py

to accept the optional compatibility argument:

  diff --git a/megatron/core/distributed/data_parallel_base.py b/megatron/core/distributed/data_parallel_base.py
  index ... .. ...
  --- a/megatron/core/distributed/data_parallel_base.py
  +++ b/megatron/core/distributed/data_parallel_base.py
  @@ -1,6 +1,7 @@
   # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

   from contextlib import contextmanager
  +from typing import Optional

   import torch

  @@ -46,7 +47,7 @@ class _BaseDataParallel(MegatronModule):
       def scale_gradients(self, scaling_factor: float) -> None:
           """Scale all gradients inside the buffers by `scaling_factor`."""
           pass

  -    def finish_grad_sync(self):
  +    def finish_grad_sync(self, force_all_reduce: Optional[bool] = False):
           """
           Finishes grad sync (all-reduce or reduce-scatter) communication operations
           for all model gradients.
  @@ -55,6 +56,10 @@ class _BaseDataParallel(MegatronModule):
           calls to complete. When overlap_grad_reduce is set to False, calls synchronous
           communication ops.
  +
  +        Args:
  +            force_all_reduce: Optional compatibility argument used by implementations
  +                that may need to force all-reduce instead of reduce-scatter.
           """
           pass

This keeps the base-class interface consistent with the current call site and with the existing DDP/Megatron-FSDP implementations. It also allows Torch FSDP2 to continue using the base no-op behavior without crashing.
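To check the proposed change in isolation, this sketch applies the same signature change to a stand-in base class. It confirms that a subclass without its own override now tolerates the keyword. The class names are hypothetical stand-ins; only the finish_grad_sync signature comes from the suggested diff.

```python
from typing import Optional


class BaseDataParallelStandIn:
    """Stand-in for _BaseDataParallel with the suggested signature."""

    def finish_grad_sync(self, force_all_reduce: Optional[bool] = False):
        # No-op, as in the base implementation; the argument is accepted
        # only for call-site compatibility and is otherwise ignored.
        pass


class TorchFSDP2StandIn(BaseDataParallelStandIn):
    """Stand-in for TorchFullyShardedDataParallel: still no override."""


chunk = TorchFSDP2StandIn()
chunk.finish_grad_sync(force_all_reduce=True)  # no longer raises TypeError
chunk.finish_grad_sync()                       # existing no-argument calls still work
```

Because the parameter has a default, callers that pass no arguments keep working, so the change is backward compatible.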
