From 54d2e452ecff8300c1a66045ef83d166468707a4 Mon Sep 17 00:00:00 2001
From: Zakir Jiwani <108548454+JiwaniZakir@users.noreply.github.com>
Date: Sun, 5 Apr 2026 15:34:53 +0000
Subject: [PATCH] Fix grad norm hang when LoRA frozen ranks have no gradients

Ranks with no gradients (e.g. frozen non-LoRA params) previously returned
0.0 immediately, skipping the all_reduce. Ranks that do have gradients then
hang waiting for the collective to complete. Move device init before the
empty-grads check and make zero-grad ranks still participate in all_reduce
with a zero-valued tensor.

Co-Authored-By: Claude Sonnet 4.6
---
 areal/engine/fsdp_utils/grad.py | 17 ++++++++++++++---
 tests/test_fsdp_grad.py         | 11 ++++++++++-
 2 files changed, 24 insertions(+), 4 deletions(-)

diff --git a/areal/engine/fsdp_utils/grad.py b/areal/engine/fsdp_utils/grad.py
index 48fd5c6e7c..290db79812 100644
--- a/areal/engine/fsdp_utils/grad.py
+++ b/areal/engine/fsdp_utils/grad.py
@@ -99,11 +99,22 @@ def get_grad_norm_fp32(
     norm_type = float(norm_type)
     total_norm = 0.0
 
-    if not grads_for_norm:
-        return 0.0
-
     device = current_platform.current_device()
 
+    if not grads_for_norm:
+        # Still participate in all_reduce with zero contribution so that
+        # ranks with grads don't hang waiting for this rank (e.g. LoRA frozen ranks).
+        total_norm_cuda = torch.tensor(0.0, dtype=torch.float, device=device)
+        reduce_op = dist.ReduceOp.MAX if norm_type == torch.inf else dist.ReduceOp.SUM
+        if data_parallel_group:
+            dist.all_reduce(total_norm_cuda, op=reduce_op, group=data_parallel_group)
+        if model_parallel_group is not None:
+            dist.all_reduce(total_norm_cuda, op=reduce_op, group=model_parallel_group)
+        total_norm = float(total_norm_cuda.item())
+        if norm_type != torch.inf and total_norm > 0:
+            total_norm = total_norm ** (1.0 / norm_type)
+        return total_norm
+
     if norm_type == torch.inf:
         norms = [grad.abs().max() for grad in grads_for_norm]
         total_norm = torch.max(torch.stack(norms)) if norms else 0.0
diff --git a/tests/test_fsdp_grad.py b/tests/test_fsdp_grad.py
index bed14750dd..519bc9088d 100644
--- a/tests/test_fsdp_grad.py
+++ b/tests/test_fsdp_grad.py
@@ -112,8 +112,17 @@ def mock_process_groups(self):
         return dp_group, mp_group
 
     def test_empty_grads_returns_zero(self, mock_process_groups):
+        # Empty grads must still participate in all_reduce (e.g. LoRA frozen ranks)
+        # so that ranks with real grads don't hang.
         dp_group, mp_group = mock_process_groups
-        result = get_grad_norm_fp32([], dp_group, mp_group)
+        with patch("torch.distributed.all_reduce") as mock_allreduce:
+            result = get_grad_norm_fp32([], dp_group, mp_group)
+        assert result == 0.0
+        assert mock_allreduce.call_count == 2  # called for dp_group and mp_group
+
+    def test_empty_grads_participates_in_allreduce_no_groups(self):
+        # With no process groups, empty grads should still return 0.0 without hanging.
+        result = get_grad_norm_fp32([], None, None)
         assert result == 0.0
 
     @pytest.mark.parametrize("norm_type", [1.0, 2.0, 3.0, float("inf")])
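
Note (not part of the patch): a minimal, self-contained sketch of the hang the
commit message describes, assuming a CPU-only "gloo" backend and two local
processes. The helper name local_grad_norm_sq, the port, and the world size are
illustrative and not taken from the repo; the pattern shown is simply that every
rank must enter the collective, contributing zero when it has no gradients.

    import os

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp


    def local_grad_norm_sq(grads):
        # Sum of squared gradient elements on this rank; zero if the rank
        # holds no trainable (e.g. non-LoRA, frozen) parameters.
        if not grads:
            return torch.tensor(0.0)
        return torch.stack([g.pow(2).sum() for g in grads]).sum()


    def run(rank, world_size):
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = "29531"
        dist.init_process_group("gloo", rank=rank, world_size=world_size)

        # Rank 1 plays the frozen role: it owns no gradients at all.
        grads = [torch.ones(4)] if rank == 0 else []

        # Every rank enters the collective, contributing zero if needed.
        # If rank 1 returned early instead, rank 0 would block here forever.
        total_sq = local_grad_norm_sq(grads)
        dist.all_reduce(total_sq, op=dist.ReduceOp.SUM)
        print(f"rank {rank}: global grad norm = {total_sq.sqrt().item():.4f}")

        dist.destroy_process_group()


    if __name__ == "__main__":
        mp.spawn(run, args=(2,), nprocs=2)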