[WIP] [LinalgExt] Fix attention NaN for fully-masked rows#24447
Draft
keshavvinayak01 wants to merge 20 commits into
Draft
[WIP] [LinalgExt] Fix attention NaN for fully-masked rows#24447keshavvinayak01 wants to merge 20 commits into
keshavvinayak01 wants to merge 20 commits into
Conversation
Softmax normalization after attention computes (1/sum) * x (OnlineAttention finalization in TileAttention.cpp) and P/sum (standalone AttentionOp in AggregatedOpInterfaceImpl.cpp). A fully-masked row has sum == 0 and x == 0, so the unguarded divide produces 0/0 == NaN instead of the expected zero. Match PyTorch's scaled_dot_product_attention semantics by adding the smallest-normal float (FLT_MIN for f32) to `sum` before the divide: 0 / FLT_MIN = 0 rescues the fully-masked row, and for any non-fully-masked row the softmax invariant guarantees sum >= 1 (at least one element has exp(S - max) = exp(0) = 1), so `sum + FLT_MIN` rounds back to `sum` exactly in f32 and the result is unperturbed. The guard is factored into createSafeDivide in LinalgExt/Utils and emitted only when the input iree_linalg_ext.attention has a mask operand — both the standalone AttentionOp decomposition and the OnlineAttention finalization gate on mask presence, so unmasked attention continues to use plain divf. Tests: - tests/e2e/linalg_ext_ops/attention.mlir: f32 fully-masked and per-row partial-mask cases. - tests/e2e/linalg_ext_ops/attention_f16_mask.mlir: f16 and bf16 fully-masked cases (bf16 is the dtype the original bug report was filed against; vmvx excluded — doesn't legalize f16/bf16 arith.constant). - Lit tests updated in convert_to_online_attention.mlir and decompose_aggregate_op.mlir (new @attention_f16_masked pins the guarded path; @attention_f16 pins the unguarded plain divf). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Co-authored-by: GPT-5 <noreply@openai.com> Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Co-authored-by: GPT-5 <noreply@openai.com> Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Co-authored-by: GPT-5 <noreply@openai.com> Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Co-authored-by: GPT-5 <noreply@openai.com> Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Keep NaN protection for fully-masked rows while avoiding a separate row-level reciprocal generic; remove unused reciprocal helper and update conversion checks. Co-authored-by: GPT-5.5 <noreply@openai.com> Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Keep the existing reciprocal-and-multiply finalization for unmasked online attention while retaining the masked sum==0 guard. Update the conversion check to assert the unmasked mulf remains present. Co-authored-by: GPT-5 <noreply@openai.com> Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
…vinayak01/attention-bug-24175
be056b1 to
0993695
Compare
Detect fully masked attention rows explicitly, normalize integer masks to i1 for row predicates, and zero those rows after the softmax division instead of clamping the denominator. Co-authored-by: GPT-5 Codex <noreply@openai.com> Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
d785f03 to
1e8d55e
Compare
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
This reverts commit a02e77b. Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
This reverts commit d894313. Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
936478a to
d0f20c8
Compare
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
4857bc0 to
0c758f4
Compare
0c758f4 to
418faa2
Compare
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
418faa2 to
30428cc
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Attention softmax normalization produced NaN for fully-masked rows.
The failing case is:
masked scores are all -inf
softmax numerator becomes 0
softmax denominator becomes 0
final normalization computes 0 / 0
PyTorch SDPA uses _safe_softmax, which explicitly zeroes fully-masked rows, so IREE should produce 0 here instead of NaN.
This PR handles that in both attention lowering paths:
Standalone iree_linalg_ext.attention decomposition clamps the row softmax denominator with max(sum, 1) before P / sum.
Online attention finalization keeps the existing unmasked (1 / sum) * x IR unchanged.
Masked online attention guards the existing finalization loop so sum == 0 yields 0 instead of NaN, avoiding an extra row-level pass.
For non-fully-masked rows, the softmax denominator is unchanged: after max subtraction, at least one term is exp(0) = 1, so sum >= 1.
Fixes #24175
Reopened this to check for regression.