Skip to content

[WIP] [LinalgExt] Fix attention NaN for fully-masked rows#24447

Draft
keshavvinayak01 wants to merge 20 commits into
iree-org:mainfrom
keshavvinayak01:users/keshavvinayak01/attention-bug-24175
Draft

[WIP] [LinalgExt] Fix attention NaN for fully-masked rows#24447
keshavvinayak01 wants to merge 20 commits into
iree-org:mainfrom
keshavvinayak01:users/keshavvinayak01/attention-bug-24175

Conversation

@keshavvinayak01
Copy link
Copy Markdown
Contributor

Attention softmax normalization produced NaN for fully-masked rows.

The failing case is:

masked scores are all -inf
softmax numerator becomes 0
softmax denominator becomes 0
final normalization computes 0 / 0
PyTorch SDPA uses _safe_softmax, which explicitly zeroes fully-masked rows, so IREE should produce 0 here instead of NaN.

This PR handles that in both attention lowering paths:

Standalone iree_linalg_ext.attention decomposition clamps the row softmax denominator with max(sum, 1) before P / sum.
Online attention finalization keeps the existing unmasked (1 / sum) * x IR unchanged.
Masked online attention guards the existing finalization loop so sum == 0 yields 0 instead of NaN, avoiding an extra row-level pass.
For non-fully-masked rows, the softmax denominator is unchanged: after max subtraction, at least one term is exp(0) = 1, so sum >= 1.

Fixes #24175

Reopened this to check for regression.

keshavvinayak01 and others added 10 commits April 24, 2026 08:27
Softmax normalization after attention computes (1/sum) * x (OnlineAttention
finalization in TileAttention.cpp) and P/sum (standalone AttentionOp in
AggregatedOpInterfaceImpl.cpp). A fully-masked row has sum == 0 and x == 0,
so the unguarded divide produces 0/0 == NaN instead of the expected zero.
Match PyTorch's scaled_dot_product_attention semantics by adding the
smallest-normal float (FLT_MIN for f32) to `sum` before the divide:
0 / FLT_MIN = 0 rescues the fully-masked row, and for any non-fully-masked
row the softmax invariant guarantees sum >= 1 (at least one element has
exp(S - max) = exp(0) = 1), so `sum + FLT_MIN` rounds back to `sum`
exactly in f32 and the result is unperturbed.

The guard is factored into createSafeDivide in LinalgExt/Utils and emitted
only when the input iree_linalg_ext.attention has a mask operand — both
the standalone AttentionOp decomposition and the OnlineAttention
finalization gate on mask presence, so unmasked attention continues to
use plain divf.

Tests:
- tests/e2e/linalg_ext_ops/attention.mlir: f32 fully-masked and per-row
  partial-mask cases.
- tests/e2e/linalg_ext_ops/attention_f16_mask.mlir: f16 and bf16
  fully-masked cases (bf16 is the dtype the original bug report was
  filed against; vmvx excluded — doesn't legalize f16/bf16
  arith.constant).
- Lit tests updated in convert_to_online_attention.mlir and
  decompose_aggregate_op.mlir (new @attention_f16_masked pins the
  guarded path; @attention_f16 pins the unguarded plain divf).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Co-authored-by: GPT-5 <noreply@openai.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Co-authored-by: GPT-5 <noreply@openai.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Co-authored-by: GPT-5 <noreply@openai.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Co-authored-by: GPT-5 <noreply@openai.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Keep NaN protection for fully-masked rows while avoiding a separate row-level reciprocal generic; remove unused reciprocal helper and update conversion checks.

Co-authored-by: GPT-5.5 <noreply@openai.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Keep the existing reciprocal-and-multiply finalization for unmasked online attention while retaining the masked sum==0 guard. Update the conversion check to assert the unmasked mulf remains present.

Co-authored-by: GPT-5 <noreply@openai.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
@keshavvinayak01 keshavvinayak01 force-pushed the users/keshavvinayak01/attention-bug-24175 branch from be056b1 to 0993695 Compare May 12, 2026 09:36
keshavvinayak01 and others added 2 commits May 12, 2026 18:30
Detect fully masked attention rows explicitly, normalize integer masks to i1 for row predicates, and zero those rows after the softmax division instead of clamping the denominator.

Co-authored-by: GPT-5 Codex <noreply@openai.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
@keshavvinayak01 keshavvinayak01 force-pushed the users/keshavvinayak01/attention-bug-24175 branch from d785f03 to 1e8d55e Compare May 12, 2026 16:20
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
This reverts commit a02e77b.

Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
This reverts commit d894313.

Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
@keshavvinayak01 keshavvinayak01 force-pushed the users/keshavvinayak01/attention-bug-24175 branch 7 times, most recently from 936478a to d0f20c8 Compare May 13, 2026 12:09
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
@keshavvinayak01 keshavvinayak01 force-pushed the users/keshavvinayak01/attention-bug-24175 branch 2 times, most recently from 4857bc0 to 0c758f4 Compare May 13, 2026 12:55
@keshavvinayak01 keshavvinayak01 force-pushed the users/keshavvinayak01/attention-bug-24175 branch from 0c758f4 to 418faa2 Compare May 13, 2026 12:55
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
@keshavvinayak01 keshavvinayak01 force-pushed the users/keshavvinayak01/attention-bug-24175 branch from 418faa2 to 30428cc Compare May 13, 2026 13:19
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[SDPA] iree_linalg_ext.attention produces NaN for fully masked inputs, instead of zeros

1 participant