Summary
When running a pretraining job with Torch FSDP2 enabled, training fails during gradient finalization with:
TypeError: _BaseDataParallel.finish_grad_sync() got an unexpected keyword argument 'force_all_reduce'
This appears to be an interface mismatch introduced after force_all_reduce was added to the gradient finalization path.
Analysis
finalize_model_grads() now calls:
model_chunk.finish_grad_sync(force_all_reduce=force_all_reduce)
This works for DistributedDataParallel, whose method signature is:
def finish_grad_sync(self, force_all_reduce: Optional[bool] = False):
It also works for Megatron FSDP, which accepts the same compatibility argument.
However, TorchFullyShardedDataParallel does not define its own finish_grad_sync() method, so it inherits the base implementation from _BaseDataParallel:
def finish_grad_sync(self):
    pass
As a result, any Torch FSDP2 run that reaches finalize_model_grads() passes force_all_reduce to a method that does not accept it, and the job fails with the TypeError above before the optimizer step.
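The mismatch can be reproduced without Megatron installed. The sketch below uses stand-in classes (BaseDataParallel, DDPLike, TorchFSDP2Like and finalize_grads are hypothetical names, not the real Megatron classes) that mirror only the method signatures quoted above:

```python
from typing import Optional


class BaseDataParallel:
    """Stand-in for _BaseDataParallel: finish_grad_sync() takes no extra arguments."""

    def finish_grad_sync(self):
        pass


class DDPLike(BaseDataParallel):
    """Stand-in for DistributedDataParallel, which accepts the new keyword."""

    def finish_grad_sync(self, force_all_reduce: Optional[bool] = False):
        pass


class TorchFSDP2Like(BaseDataParallel):
    """Stand-in for TorchFullyShardedDataParallel: no override, inherits the base method."""


def finalize_grads(model_chunks, force_all_reduce=False):
    """Mimics the call made inside finalize_model_grads()."""
    for model_chunk in model_chunks:
        model_chunk.finish_grad_sync(force_all_reduce=force_all_reduce)


# The wrapper that overrides the method handles the keyword without complaint.
finalize_grads([DDPLike()])

# The wrapper that inherits the base method raises TypeError.
try:
    finalize_grads([TorchFSDP2Like()])
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(error_message)
```

The printed message should name force_all_reduce as the unexpected keyword argument, matching the shape of the error in the Summary (the exact class prefix depends on the Python version).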
Expected behavior
Torch FSDP2 should be able to ignore force_all_reduce if it does not need it, instead of failing due to the base class method signature.
Suggested fix
Update the base data-parallel interface in:
megatron/core/distributed/data_parallel_base.py
to accept the optional compatibility argument:
diff --git a/megatron/core/distributed/data_parallel_base.py b/megatron/core/distributed/data_parallel_base.py
index ... .. ...
--- a/megatron/core/distributed/data_parallel_base.py
+++ b/megatron/core/distributed/data_parallel_base.py
@@ -1,6 +1,7 @@
 # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
 from contextlib import contextmanager
+from typing import Optional
 import torch
@@ -46,7 +47,7 @@ class _BaseDataParallel(MegatronModule):
     def scale_gradients(self, scaling_factor: float) -> None:
         """Scale all gradients inside the buffers by `scaling_factor`."""
         pass
-    def finish_grad_sync(self):
+    def finish_grad_sync(self, force_all_reduce: Optional[bool] = False):
         """
         Finishes grad sync (all-reduce or reduce-scatter) communication operations
         for all model gradients.
@@ -55,6 +56,10 @@ class _BaseDataParallel(MegatronModule):
         calls to complete. When overlap_grad_reduce is set to False, calls synchronous
         communication ops.
+
+        Args:
+            force_all_reduce: Optional compatibility argument used by implementations
+                that may need to force all-reduce instead of reduce-scatter.
         """
         pass
This keeps the base class interface consistent with the current call site and with the existing DDP/Megatron-FSDP implementations. It also allows Torch FSDP2 to continue using the base no-op behavior without crashing.
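To guard against the same drift in future, a small signature check could be run over every data-parallel wrapper class. The following is only a sketch: accepts_force_all_reduce and the stand-in classes are hypothetical, and a real test would import the actual wrapper classes instead.

```python
import inspect
from typing import Optional


def accepts_force_all_reduce(cls) -> bool:
    """Return True if cls.finish_grad_sync can be called with force_all_reduce=..."""
    params = inspect.signature(cls.finish_grad_sync).parameters
    if "force_all_reduce" in params:
        return True
    # A **kwargs catch-all would also tolerate the keyword.
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


class OldBase:
    """Stand-in for _BaseDataParallel before the suggested change."""

    def finish_grad_sync(self):
        pass


class PatchedBase:
    """Stand-in for _BaseDataParallel after the suggested change."""

    def finish_grad_sync(self, force_all_reduce: Optional[bool] = False):
        # No-op: the argument is accepted only for interface compatibility.
        pass


class OldTorchFSDP2Like(OldBase):
    """Inherits the old base method unchanged."""


class PatchedTorchFSDP2Like(PatchedBase):
    """Inherits the patched base method unchanged."""


old_ok = accepts_force_all_reduce(OldTorchFSDP2Like)
patched_ok = accepts_force_all_reduce(PatchedTorchFSDP2Like)

# The patched wrapper can be called exactly as the call site does.
PatchedTorchFSDP2Like().finish_grad_sync(force_all_reduce=True)

print(old_ok, patched_ok)
```

Checking the signature rather than calling the method keeps such a test independent of distributed setup, since no process group or model is needed.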