Skip to content

Update sparse attention op#1195

Merged
helloyongyang merged 2 commits into
mainfrom
sparse
Jun 29, 2026
Merged

Update sparse attention op#1195
helloyongyang merged 2 commits into
mainfrom
sparse

Conversation

@chengtao-lv

Copy link
Copy Markdown
Contributor

No description provided.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request optimizes the Triton attention kernels by simplifying index calculations, introducing a compile-time constant HD for offset computations, and updating tensor shape comments. Feedback focuses on restoring removed contiguity assertions for input tensors to prevent silent correctness bugs, ensuring unsqueezed tensors are made contiguous, and maintaining the tl.constexpr type annotation for LQ in the backward preprocessing kernel to allow Triton compiler optimizations.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

assert k_block_id.is_contiguous() and lut.is_contiguous()
assert BLOCK_M == 64 or BLOCK_M == 128
assert BLOCK_N == 64 or BLOCK_N == 128
# q, k, v: (B, L, H, D)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The contiguity assertions for q, k, v, k_block_id, and lut were removed. Since the Triton kernel's pointer arithmetic (e.g., q_offset = idx_b * LQ * HD + idx_h * D) strictly assumes a contiguous layout, passing non-contiguous tensors will lead to silent correctness issues or out-of-bounds memory accesses. Please restore these assertions to ensure safety.

Suggested change
# q, k, v: (B, L, H, D)
# q, k, v: (B, L, H, D)
assert q.is_contiguous() and k.is_contiguous() and v.is_contiguous()
assert k_block_id.is_contiguous() and lut.is_contiguous()

Comment on lines 116 to 118
q = q.unsqueeze(0)
k = k.unsqueeze(0)
v = v.unsqueeze(0)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

In apply_triton_ar, q, k, and v are unsqueezed but not made contiguous. If the input tensors are non-contiguous (e.g., due to slicing or transposing prior to this call), it will cause silent correctness bugs in the Triton kernel. Please call .contiguous() on the unsqueezed tensors to guarantee contiguous memory layout.

Suggested change
q = q.unsqueeze(0)
k = k.unsqueeze(0)
v = v.unsqueeze(0)
q = q.unsqueeze(0).contiguous()
k = k.unsqueeze(0).contiguous()
v = v.unsqueeze(0).contiguous()

DELTAS,
LQ: tl.constexpr,
H: tl.constexpr,
LQ,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The type annotation : tl.constexpr for LQ was omitted in _attn_bwd_preprocess, unlike in all other kernels (_attn_fwd, _attn_bwd_dq, _attn_bwd_dkdv) where LQ is a constexpr. This inconsistency prevents Triton from optimizing the loop boundaries and masking based on LQ being a compile-time constant. Please restore : tl.constexpr to LQ.

Suggested change
LQ,
LQ: tl.constexpr,

@helloyongyang helloyongyang changed the title Sparse Update sparse attention op Jun 29, 2026
@helloyongyang helloyongyang merged commit bf75223 into main Jun 29, 2026
2 checks passed
@helloyongyang helloyongyang deleted the sparse branch June 29, 2026 12:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants