Fix incorrect LSE gradients in cached FlashAttention for Eagle training #536
uygnef wants to merge 4 commits into sgl-project:main from
Conversation
Code Review
This pull request introduces global gradient norm logging for distributed training and refactors the Llama3 Eagle Flash Attention implementation to use a custom autograd function for cached merging. Feedback recommends utilizing the new global norm calculation for manual gradient clipping to ensure stability in distributed environments, updating the Flash Attention availability checks to include the specific functions used, and optimizing the handling of negative infinity constants in the forward pass.
```python
if self.log_grad_norm:
    self.last_grad_norm = self._compute_global_grad_norm()
else:
    self.last_grad_norm = None
```
In a distributed training environment (like FSDP used in this repository), torch.nn.utils.clip_grad_norm_ (called on line 76) will only compute the norm of local shards, leading to incorrect gradient clipping. You should always compute the global norm using _compute_global_grad_norm() and use it to perform manual clipping of the gradients. This ensures that the clipping factor is consistent across all ranks and training remains stable.
Suggested change:

```python
global_grad_norm = self._compute_global_grad_norm()
self.last_grad_norm = global_grad_norm if self.log_grad_norm else None
```

```python
out0_expanded = out0.view(
    bsz, q_len, num_kv_heads, num_groups, head_dim
).float()
neg_inf = torch.tensor(float("-inf"), device=q.device, dtype=torch.float32)
```
Creating a new tensor for negative infinity on every forward pass is slightly inefficient. You can use the scalar float("-inf") directly in torch.where calls, or define a constant outside the function.
Suggested change:

```diff
- neg_inf = torch.tensor(float("-inf"), device=q.device, dtype=torch.float32)
+ neg_inf = float("-inf")
```
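For illustration (a hedged sketch; `mask` and `scores` are placeholder names, not identifiers from this diff), the plain Python scalar broadcasts inside `torch.where`, so no tensor needs to be allocated per call:

```python
import torch

mask = torch.tensor([True, False, True])
scores = torch.tensor([0.3, 0.8, -0.2])
# float("-inf") is broadcast by torch.where as a scalar, so there is no need
# to materialize a negative-infinity tensor on every forward pass.
masked_scores = torch.where(mask, scores, float("-inf"))
```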
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Motivation
This PR mainly fixes the cached LlamaFlashAttention path used in Eagle training.
The issue shows up in the cached / iterative decoding path, where the FA implementation needs to merge the main causal attention block with later cached branches using their log-sum-exp (LSE) statistics. In the previous implementation, this merge path could become numerically inconsistent with the reference attention behavior, especially with padding and incremental decoding. In practice, this could cause a forward / backward mismatch and make end-to-end training behavior harder to reason about.
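For context, partial attention outputs computed by FlashAttention are typically combined through their LSE statistics. A minimal sketch of such a merge (illustrative only, not this PR's implementation; the function name and shape conventions are assumptions):

```python
import torch

def merge_partial_attention(out0, lse0, out1, lse1):
    # out*: [batch, q_len, num_heads, head_dim] partial attention outputs
    # lse*: [batch, num_heads, q_len] log-sum-exp of the attention logits
    lse0 = lse0.transpose(1, 2).unsqueeze(-1)  # -> [batch, q_len, num_heads, 1]
    lse1 = lse1.transpose(1, 2).unsqueeze(-1)
    max_lse = torch.maximum(lse0, lse1)        # subtract the max for stability
    w0 = torch.exp(lse0 - max_lse)
    w1 = torch.exp(lse1 - max_lse)
    # Each block contributes proportionally to exp(lse); fully masked rows
    # (lse == -inf in both blocks) need explicit handling to avoid NaNs.
    merged = (w0 * out0.float() + w1 * out1.float()) / (w0 + w1)
    merged_lse = (max_lse + torch.log(w0 + w1)).squeeze(-1).transpose(1, 2)
    return merged, merged_lse
```

If either the forward merge or its backward pass (the LSE gradients) weights the two blocks differently from the reference attention, the combined output and its gradients drift from the baseline, which matches the symptoms described above.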
Modifications
- Refactor the Llama3 Eagle Flash Attention cached-merge path to use a custom autograd function, so that the merged output and its LSE gradients stay consistent with the reference attention behavior.
- Add optional global gradient norm logging (train/pre_clip_grad_norm) for distributed training, guarded by a dedicated flag and disabled by default.
Related Issues
N/A
Accuracy Test
End-to-end training comparison:
From the TensorBoard comparison, the fixed FA run behaves closer to the flex baseline than the old FA run, and train/pre_clip_grad_norm provides an additional signal for comparing training dynamics besides loss / accuracy.




Additional observation:
The current training loss is averaged over all tokens after masked positions are zeroed out, instead of being normalized by the number of valid unmasked tokens. This may dilute the gradient contribution from the affected positions, which may partially explain why the previous FA implementation did not show a larger end-to-end regression immediately.
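As a hedged illustration of the normalization difference (names and shapes are assumptions, not the repository's loss code):

```python
import torch
import torch.nn.functional as F

def masked_lm_loss(logits, targets, loss_mask, normalize_by_valid=False):
    # logits: [batch, seq, vocab], targets: [batch, seq], loss_mask: [batch, seq]
    per_token = F.cross_entropy(
        logits.flatten(0, 1), targets.flatten(), reduction="none"
    ).view_as(targets)
    per_token = per_token * loss_mask            # zero out masked positions
    if normalize_by_valid:
        # Average only over the valid (unmasked) tokens.
        return per_token.sum() / loss_mask.sum().clamp(min=1)
    # Average over all tokens, including the zeroed-out masked ones, which
    # dilutes the per-token gradient from the valid positions.
    return per_token.mean()
```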
However, replaying an extreme spike case still shows a large discrepancy for the old FA path:
In this case, the fixed FA replay is much closer to the flex attention baseline in both loss and gradient scale.
Benchmark & Profiling
This PR is mainly for correctness / numerical behavior. It is not intended as a performance optimization.
One practical note is that logging pre_clip_grad_norm introduces roughly 5% overhead in training, so it is guarded by a dedicated flag and is disabled by default.
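As a hedged sketch of where that overhead comes from (illustrative only; the PR's _compute_global_grad_norm may be implemented differently), the logging has to touch every gradient and perform a cross-rank reduction:

```python
import torch
import torch.distributed as dist

def compute_global_grad_norm(parameters):
    # L2 norm over all gradients across all ranks: accumulate the squared
    # local norms, all-reduce the sum, then take the square root so every
    # rank logs the same value.
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    if not grads:
        return torch.tensor(0.0)
    local_sq = torch.stack([g.float().pow(2).sum() for g in grads]).sum()
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(local_sq, op=dist.ReduceOp.SUM)
    return local_sq.sqrt()
```

Iterating over every gradient and synchronizing across ranks adds a small amount of work per step, which is consistent with the ~5% figure and with keeping the logging behind an opt-in flag.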
Checklist
- [ ] Update documentation as needed, per Writing Documentation & Running Docs CI (https://docs.sglang.ai/references/contribution_guide.html#writing-documentation-running-docs-ci).
- [ ] Provide benchmark / profiling and Accuracy Results as needed (https://docs.sglang.ai/references/accuracy_evaluation.html).