
Fix incorrect LSE gradients in cached FlashAttention for Eagle training#536

Open
uygnef wants to merge 4 commits into sgl-project:main from uygnef:fix/fa

Conversation

@uygnef uygnef (Collaborator) commented Apr 16, 2026

Motivation

This PR mainly fixes the cached LlamaFlashAttention path used in Eagle training.

The issue shows up in the cached / iterative decoding path, where the FA implementation has to merge the main causal attention block with later cached branches via their log-sum-exp (LSE) statistics. In the previous implementation, this merge could become numerically inconsistent with the reference attention behavior, especially under padding and incremental decoding, and its gradients could be incorrect. In practice, this could cause forward / backward mismatches and make end-to-end training behavior harder to reason about.
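As background, the usual way two partial attention results are combined is through their LSE statistics. A minimal sketch of that merge is below; the function name, shapes, and masked-row caveat are illustrative, and the actual implementation touched by this PR lives in specforge/modeling/draft/llama3_eagle.py.

```python
import torch

def merge_with_lse(out0, lse0, out1, lse1):
    # Sketch of the standard LSE-based merge of two partial attention blocks
    # (e.g. the main causal block and a cached branch).
    # out*: [bsz, q_len, num_heads, head_dim], lse*: [bsz, q_len, num_heads]
    lse = torch.logaddexp(lse0, lse1)             # combined log-sum-exp
    w0 = torch.exp(lse0 - lse).unsqueeze(-1)      # weight of block 0
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)      # weight of block 1
    # Rows where both LSEs are -inf (fully masked) would produce NaNs here and
    # need an explicit guard (e.g. torch.where) in real code.
    return out0 * w0 + out1 * w1, lse
```

Getting both this combination and its backward exactly right is what keeps the cached path consistent with reference attention.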

Modifications

  • Fix the cached LlamaFlashAttention merge path so that its forward output and backward gradients stay aligned with the reference attention behavior.
  • Extend the FlashAttention test with an end-to-end metric comparison.
  • Add train/pre_clip_grad_norm logging as an additional validation signal for training comparison.
  • Add a switch to control grad norm logging (a minimal sketch follows this list): enabling the metric adds about 5% overhead in our runs, so it is disabled by default and only turned on when needed for validation.
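A minimal sketch of how the switch gates the metric; the flag wiring and helper name are illustrative, not the exact ones in this PR, and it only assumes the optimizer exposes last_grad_norm as in specforge/optimizer.py.

```python
from torch.utils.tensorboard import SummaryWriter

def maybe_log_grad_norm(writer: SummaryWriter, optimizer, step: int, enabled: bool):
    # Hypothetical helper: log the pre-clip grad norm only when the switch is
    # enabled, so the extra ~5% cost of computing it is opt-in.
    if enabled and getattr(optimizer, "last_grad_norm", None) is not None:
        writer.add_scalar("train/pre_clip_grad_norm", optimizer.last_grad_norm, step)
```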

Related Issues

N/A

Accuracy Test

End-to-end training comparison:

  • outputs1/fa_old: FA before the fix
  • outputs1/fa_fix: FA after the fix
  • outputs1/flex: flex attention baseline

From TensorBoard comparison, the fixed FA run behaves closer to the flex baseline than the old FA run, and train/pre_clip_grad_norm provides an additional signal for comparing training dynamics
besides loss / accuracy.
(TensorBoard screenshots: training curves for outputs1/fa_old, outputs1/fa_fix, and outputs1/flex, including train/pre_clip_grad_norm.)

Additional observation:

The current training loss is averaged over all tokens after masked positions are zeroed out, instead of being normalized by the number of valid unmasked tokens. This may dilute the gradient
contribution from affected positions and partially explain why the previous FA implementation did not show a larger end-to-end regression immediately.
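A toy example of the difference between the two normalizations (numbers are illustrative):

```python
import torch

per_token_loss = torch.tensor([2.0, 3.0, 4.0, 5.0])  # loss per position
mask = torch.tensor([1.0, 1.0, 0.0, 0.0])            # 1 = valid token

masked = per_token_loss * mask
avg_over_all   = masked.sum() / mask.numel()  # averaged over all tokens -> 1.25
avg_over_valid = masked.sum() / mask.sum()    # averaged over valid tokens -> 2.5
```

With many masked positions, averaging over all tokens scales the loss (and its gradients) down, which can hide localized numerical issues.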

However, replaying an extreme spike case still shows a large discrepancy for the old FA path:

| Case | Grad Norm | Reported Loss |
| --- | --- | --- |
| dump | 1422.276855 | 14.712839 |
| fixed FA replay | 22.8646 | 14.633358 |
| flex attention replay | 22.7233 | 14.633938 |

In this case, the fixed FA replay is much closer to the flex attention baseline in both loss and gradient scale.

Benchmark & Profiling

This PR is mainly for correctness / numerical behavior. It is not intended as a performance optimization.

One practical note is that logging pre_clip_grad_norm introduces roughly 5% overhead in training, so it is guarded by a dedicated flag and is disabled by default.

Checklist

@gemini-code-assist gemini-code-assist (Bot, Contributor) left a comment

Code Review

This pull request introduces global gradient norm logging for distributed training and refactors the Llama3 Eagle Flash Attention implementation to use a custom autograd function for cached merging. Feedback recommends utilizing the new global norm calculation for manual gradient clipping to ensure stability in distributed environments, updating the Flash Attention availability checks to include the specific functions used, and optimizing the handling of negative infinity constants in the forward pass.

Comment thread specforge/optimizer.py
Comment on lines +67 to +70
if self.log_grad_norm:
self.last_grad_norm = self._compute_global_grad_norm()
else:
self.last_grad_norm = None


Severity: high

In a distributed training environment (like FSDP used in this repository), torch.nn.utils.clip_grad_norm_ (called on line 76) will only compute the norm of local shards, leading to incorrect gradient clipping. You should always compute the global norm using _compute_global_grad_norm() and use it to perform manual clipping of the gradients. This ensures that the clipping factor is consistent across all ranks and training remains stable.

        global_grad_norm = self._compute_global_grad_norm()
        self.last_grad_norm = global_grad_norm if self.log_grad_norm else None
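For illustration, a manual clip against a globally reduced norm, along the lines of this suggestion, could look roughly like the sketch below. It assumes each rank holds a disjoint gradient shard (FSDP full sharding), so summing squared local norms across ranks gives the global squared norm; the helper name and signature are illustrative, not the actual _compute_global_grad_norm.

```python
import torch
import torch.distributed as dist

def clip_grad_by_global_norm(parameters, max_norm: float, eps: float = 1e-6):
    grads = [p.grad for p in parameters if p.grad is not None]
    if not grads:
        return torch.tensor(0.0)
    # Sum of squared local gradient norms on this rank.
    local_sq = torch.stack([g.detach().float().pow(2).sum() for g in grads]).sum()
    dist.all_reduce(local_sq, op=dist.ReduceOp.SUM)   # global sum of squares
    global_norm = local_sq.sqrt()
    clip_coef = max_norm / (global_norm + eps)
    if clip_coef < 1.0:                               # only ever scale down
        for g in grads:
            g.mul_(clip_coef.to(g.dtype))
    return global_norm
```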

Comment thread specforge/modeling/draft/llama3_eagle.py Outdated
out0_expanded = out0.view(
bsz, q_len, num_kv_heads, num_groups, head_dim
).float()
neg_inf = torch.tensor(float("-inf"), device=q.device, dtype=torch.float32)


Severity: medium

Creating a new tensor for negative infinity on every forward pass is slightly inefficient. You can use the scalar float("-inf") directly in torch.where calls, or define a constant outside the function.

Suggested change:
  - neg_inf = torch.tensor(float("-inf"), device=q.device, dtype=torch.float32)
  + neg_inf = float("-inf")
