Update sparse attention op#1195
Conversation
There was a problem hiding this comment.
Code Review
This pull request optimizes the Triton attention kernels by simplifying index calculations, introducing a compile-time constant HD for offset computations, and updating tensor shape comments. Feedback focuses on restoring removed contiguity assertions for input tensors to prevent silent correctness bugs, ensuring unsqueezed tensors are made contiguous, and maintaining the tl.constexpr type annotation for LQ in the backward preprocessing kernel to allow Triton compiler optimizations.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| assert k_block_id.is_contiguous() and lut.is_contiguous() | ||
| assert BLOCK_M == 64 or BLOCK_M == 128 | ||
| assert BLOCK_N == 64 or BLOCK_N == 128 | ||
| # q, k, v: (B, L, H, D) |
There was a problem hiding this comment.
The contiguity assertions for q, k, v, k_block_id, and lut were removed. Since the Triton kernel's pointer arithmetic (e.g., q_offset = idx_b * LQ * HD + idx_h * D) strictly assumes a contiguous layout, passing non-contiguous tensors will lead to silent correctness issues or out-of-bounds memory accesses. Please restore these assertions to ensure safety.
| # q, k, v: (B, L, H, D) | |
| # q, k, v: (B, L, H, D) | |
| assert q.is_contiguous() and k.is_contiguous() and v.is_contiguous() | |
| assert k_block_id.is_contiguous() and lut.is_contiguous() |
| q = q.unsqueeze(0) | ||
| k = k.unsqueeze(0) | ||
| v = v.unsqueeze(0) |
There was a problem hiding this comment.
In apply_triton_ar, q, k, and v are unsqueezed but not made contiguous. If the input tensors are non-contiguous (e.g., due to slicing or transposing prior to this call), it will cause silent correctness bugs in the Triton kernel. Please call .contiguous() on the unsqueezed tensors to guarantee contiguous memory layout.
| q = q.unsqueeze(0) | |
| k = k.unsqueeze(0) | |
| v = v.unsqueeze(0) | |
| q = q.unsqueeze(0).contiguous() | |
| k = k.unsqueeze(0).contiguous() | |
| v = v.unsqueeze(0).contiguous() |
| DELTAS, | ||
| LQ: tl.constexpr, | ||
| H: tl.constexpr, | ||
| LQ, |
There was a problem hiding this comment.
The type annotation : tl.constexpr for LQ was omitted in _attn_bwd_preprocess, unlike in all other kernels (_attn_fwd, _attn_bwd_dq, _attn_bwd_dkdv) where LQ is a constexpr. This inconsistency prevents Triton from optimizing the loop boundaries and masking based on LQ being a compile-time constant. Please restore : tl.constexpr to LQ.
| LQ, | |
| LQ: tl.constexpr, |
No description provided.