17 changes: 14 additions & 3 deletions areal/engine/fsdp_utils/grad.py
@@ -99,11 +99,22 @@ def get_grad_norm_fp32(
     norm_type = float(norm_type)
     total_norm = 0.0
 
-    if not grads_for_norm:
-        return 0.0
-
     device = current_platform.current_device()
 
+    if not grads_for_norm:
+        # Still participate in all_reduce with zero contribution so that
+        # ranks with grads don't hang waiting for this rank (e.g. LoRA frozen ranks).
+        total_norm_cuda = torch.tensor(0.0, dtype=torch.float, device=device)
+        reduce_op = dist.ReduceOp.MAX if norm_type == torch.inf else dist.ReduceOp.SUM
+        if data_parallel_group:
+            dist.all_reduce(total_norm_cuda, op=reduce_op, group=data_parallel_group)
+        if model_parallel_group is not None:
+            dist.all_reduce(total_norm_cuda, op=reduce_op, group=model_parallel_group)
Comment on lines +109 to +112
medium

There is an inconsistency in how the existence of process groups is checked. data_parallel_group is checked using truthiness, while model_parallel_group is checked using an explicit is not None comparison. Per PEP 8, it is generally preferred to use is not None when checking if an optional argument was provided, to avoid potential issues with objects that might define a custom __bool__ method. Using a consistent check also improves readability.

Suggested change
-        if data_parallel_group:
-            dist.all_reduce(total_norm_cuda, op=reduce_op, group=data_parallel_group)
-        if model_parallel_group is not None:
-            dist.all_reduce(total_norm_cuda, op=reduce_op, group=model_parallel_group)
+        if data_parallel_group is not None:
+            dist.all_reduce(total_norm_cuda, op=reduce_op, group=data_parallel_group)
+        if model_parallel_group is not None:
+            dist.all_reduce(total_norm_cuda, op=reduce_op, group=model_parallel_group)
References
  1. PEP 8 recommends using 'is not None' for comparisons to singletons like None, rather than relying on truthiness. (link)
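
For context, a tiny standalone illustration of the distinction (hypothetical FakeGroup class, not a torch API): an object can define __bool__ and be falsy even though it is not None, so a truthiness check could silently skip the all_reduce.

# Hypothetical example, for illustration only -- not code from this PR.
class FakeGroup:
    def __bool__(self):
        return False  # e.g. a placeholder/empty group object


group = FakeGroup()
print(bool(group))        # False -> "if group:" would skip the collective
print(group is not None)  # True  -> "if group is not None:" still runs it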

+        total_norm = float(total_norm_cuda.item())
+        if norm_type != torch.inf and total_norm > 0:
+            total_norm = total_norm ** (1.0 / norm_type)
+        return total_norm
Comment on lines +104 to +116
medium

The logic for participating in the all_reduce collective and calculating the final norm for empty gradient lists is largely duplicated from the rest of the function (specifically the logic found in lines 118-174). While this correctly addresses the distributed deadlock, it introduces a maintainability risk. Any future changes to the distributed synchronization logic (e.g., adding support for new process groups like expert parallel groups) would now need to be updated in three separate locations within this function.

Consider refactoring the function to unify the synchronization and post-processing steps. For instance, you could initialize total_norm_cuda to zero when grads_for_norm is empty and then use conditional blocks to skip only the local gradient accumulation parts, allowing the existing all_reduce and root-calculation logic to handle the final result.
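
For illustration, a minimal sketch of that unified-path refactor (not part of this PR; the signature is simplified, the device lookup stands in for current_platform, and FSDP-specific details such as sharded gradients are ignored):

# Sketch only: every rank builds a local contribution (zero when it has no grads)
# and then goes through a single all_reduce path, so no branch can skip the collectives.
import torch
import torch.distributed as dist


def get_grad_norm_fp32_sketch(grads_for_norm, data_parallel_group=None,
                              model_parallel_group=None, norm_type=2.0):
    norm_type = float(norm_type)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Local accumulation; an empty list simply contributes zero (e.g. LoRA frozen ranks).
    if norm_type == torch.inf:
        local = max((g.abs().max() for g in grads_for_norm), default=torch.tensor(0.0))
        reduce_op = dist.ReduceOp.MAX
    else:
        local = sum((g.abs() ** norm_type).sum() for g in grads_for_norm)
        reduce_op = dist.ReduceOp.SUM
    total_norm_cuda = torch.as_tensor(local, dtype=torch.float, device=device)

    # Single synchronization path shared by empty and non-empty ranks.
    if data_parallel_group is not None:
        dist.all_reduce(total_norm_cuda, op=reduce_op, group=data_parallel_group)
    if model_parallel_group is not None:
        dist.all_reduce(total_norm_cuda, op=reduce_op, group=model_parallel_group)

    total_norm = float(total_norm_cuda.item())
    if norm_type != torch.inf and total_norm > 0:
        total_norm = total_norm ** (1.0 / norm_type)
    return total_norm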


     if norm_type == torch.inf:
         norms = [grad.abs().max() for grad in grads_for_norm]
         total_norm = torch.max(torch.stack(norms)) if norms else 0.0
11 changes: 10 additions & 1 deletion tests/test_fsdp_grad.py
@@ -112,8 +112,17 @@ def mock_process_groups(self):
         return dp_group, mp_group
 
     def test_empty_grads_returns_zero(self, mock_process_groups):
+        # Empty grads must still participate in all_reduce (e.g. LoRA frozen ranks)
+        # so that ranks with real grads don't hang.
         dp_group, mp_group = mock_process_groups
-        result = get_grad_norm_fp32([], dp_group, mp_group)
+        with patch("torch.distributed.all_reduce") as mock_allreduce:
+            result = get_grad_norm_fp32([], dp_group, mp_group)
         assert result == 0.0
+        assert mock_allreduce.call_count == 2  # called for dp_group and mp_group
+
+    def test_empty_grads_participates_in_allreduce_no_groups(self):
+        # With no process groups, empty grads should still return 0.0 without hanging.
+        result = get_grad_norm_fp32([], None, None)
+        assert result == 0.0
 
     @pytest.mark.parametrize("norm_type", [1.0, 2.0, 3.0, float("inf")])
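
Beyond the mocked test above, a standalone two-process sketch (gloo backend; the port/address and values are assumptions, and this is not part of the PR's test suite) of why the participation matters: the rank with no grads contributes a zero to the same collective, so the rank that does hold grads still gets a consistent global norm instead of blocking.

# Illustrative sketch only -- not code from this PR.
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank: int, world_size: int):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29511"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Rank 0 plays a LoRA-frozen rank (no grads -> zero contribution);
    # rank 1 holds a gradient whose squared L2 sum is 9.0.
    local_sq_sum = torch.tensor(0.0) if rank == 0 else torch.tensor(9.0)
    dist.all_reduce(local_sq_sum, op=dist.ReduceOp.SUM)  # both ranks must reach this

    print(f"rank {rank}: global grad norm = {local_sq_sum.sqrt().item():.1f}")  # 3.0 on both ranks
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)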