Skip to content

Attempt to fix #2220: Enable Flash Attention in KV-cache path#2248

Open
woaiwang wants to merge 1 commit into
Lightning-AI:mainfrom
woaiwang:fix-issue-2220
Open

Attempt to fix #2220: Enable Flash Attention in KV-cache path#2248
woaiwang wants to merge 1 commit into
Lightning-AI:mainfrom
woaiwang:fix-issue-2220

Conversation

@woaiwang
Copy link
Copy Markdown

@woaiwang woaiwang commented May 9, 2026

Fixes #2220

What I did:
I noticed in #2220 that passing an explicit attn_mask disables the PyTorch SDPA Flash Attention fast path. I tried to drop the mask during the decoding phase (q.size(2) == 1) in CausalSelfAttention.scaled_dot_product_attention to re-enable it.

The Issue:
Running pytest tests/test_model.py results in AssertionError: Tensor-likes are not close! for a few tests (e.g., test_against_gpt_neox_model).
I realize that simply dropping the mask causes the query to attend to the uninitialized/padded parts of the k and v tensors in the KVCache if input_pos_maxp1 is not aggressively slicing them.

Question for reviewers:
Could you guide me on the safest way to slice the KV-cache or formulate the is_causal flag here so that we can safely drop the explicit mask without reading uninitialized memory? I would love to finish this PR with your guidance!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

mask_cache in kv-cache path seems to force attn_mask, preventing flash attention

1 participant