[CUDA] Enable XQA decode for GroupQueryAttention with attention sink #29162
Merged
Pull request overview
This PR extends the CUDA contrib com.microsoft::GroupQueryAttention implementation to enable the XQA decode kernel when a per-head attention sink (head_sink) is present, including an init-time PrePack path that caches a constant head_sink as FP32 to avoid per-step conversion. It also threads the attention-sink pointer through the XQA loader entry points and adds parity tests, profiling helpers, and operator documentation.
Changes:
- Enable XQA for non-quantized decode when head_sink is provided (default-on for that path, while honoring explicit ORT_ENABLE_XQA=0).
- Add PrePack caching of constant head_sink (FP16/BF16 → FP32) and a per-launch FP32 scratch conversion path for dynamic sinks.
- Plumb attention_sinks through XQA loaders and add Python parity tests + new GQA documentation and profiling scripts.
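The default-on rule in the first bullet can be sketched as a small decision function. This is a simplified illustration, not the operator's code: `should_use_xqa` is a hypothetical name, the function only models non-quantized decode, and treating any value other than "0" or "1" as "unset" is an assumption.

```python
from typing import Optional


def should_use_xqa(env_value: Optional[str], has_head_sink: bool) -> bool:
    """Hypothetical model of the XQA on/off choice for non-quantized decode.

    env_value is the raw ORT_ENABLE_XQA string, or None when unset.
    """
    # An explicit ORT_ENABLE_XQA=0 always wins, even with head_sink present.
    if env_value == "0":
        return False
    # An explicit ORT_ENABLE_XQA=1 opts in regardless of head_sink.
    if env_value == "1":
        return True
    # Unset (or, by assumption, any other value): on only for head_sink decode.
    return has_head_sink


if __name__ == "__main__":
    for env in (None, "0", "1"):
        for sink in (True, False):
            print(f"ORT_ENABLE_XQA={env!r:5} head_sink={sink!s:5} -> {should_use_xqa(env, sink)}")
```

Other eligibility conditions (decode step, supported head size, KV-cache type) are deliberately left out of this sketch.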
Reviewed changes
Copilot reviewed 28 out of 28 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc | XQA dispatch updates, head_sink handling, PrePack caching, local-window split-planning clamp, debug info wiring |
| onnxruntime/contrib_ops/cuda/bert/group_query_attention.h | Add PrePack override and XQA head-sink cache members |
| onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu | Add FP16/BF16 head_sink → FP32 conversion kernel and wire into XQA launch |
| onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h | Declare LaunchConvertHeadSinkToFloat |
| onnxruntime/contrib_ops/cuda/bert/attention_data.h | Add xqa_head_sink pointer + conversion flag to GroupQueryAttentionData |
| onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h | Add use_xqa to debug-info struct |
| onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc | Print SdpaKernel=XQA when applicable |
| onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader.h | Add attention_sinks parameter to XQA launch API |
| onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16.cu | Thread attention_sinks into fp16 XQA dispatch |
| onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_impl.cuh | Propagate sinks to generated kernels; reject sinks for int8/fp8 KV-cache paths |
| onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_int8_impl.cuh | Update int8 launches to new signature (pass nullptr sinks) |
| onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_fp8_impl.cuh | Update fp8 launches to new signature (pass nullptr sinks) |
| onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_64.cu | Update fp16 head-size specialization signature |
| onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_128.cu | Update fp16 head-size specialization signature |
| onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_256.cu | Update fp16 head-size specialization signature |
| onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16.cu | Thread attention_sinks into bf16 XQA dispatch; reject sinks for int8 KV-cache path |
| onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_impl.cuh | Propagate sinks to generated kernels; reject sinks for int8/fp8 KV-cache paths |
| onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_int8_impl.cuh | Update int8 launches to new signature (pass nullptr sinks) |
| onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_fp8_impl.cuh | Update fp8 launches to new signature (pass nullptr sinks) |
| onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_64.cu | Update bf16 head-size specialization signature |
| onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_128.cu | Update bf16 head-size specialization signature |
| onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_256.cu | Update bf16 head-size specialization signature |
| onnxruntime/contrib_ops/cuda/bert/xqa/xqa_impl_gen.cuh | Pass attention_sinks to generated kernel entrypoints |
| onnxruntime/test/python/transformers/test_gqa.py | Add head_sink initializer plumbing, XQA head-sink parity tests, and local-window split-planning regression |
| onnxruntime/test/python/transformers/gqa_test_helper.py | Add head_sink plumbing support for model creation and random feeds |
| onnxruntime/test/python/transformers/profile_gqa.py | New profiling helper (modes, local-window, optional head_sink, optional NVTX) |
| onnxruntime/test/python/transformers/profile_gqa.sh | New nsys wrapper for profiling runs and parsing |
| docs/contrib_ops/gqa.md | New/updated operator documentation including head_sink/XQA defaults and PrePack behavior |
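The FP16/BF16 → FP32 conversion listed for group_query_attention_impl.cu is a widening conversion, so it loses no precision. A minimal host-side sketch of the same idea, working on raw 16-bit patterns; the function names here are illustrative and are not the CUDA kernel's API:

```python
import struct
from typing import List, Sequence


def fp16_bits_to_float(bits: int) -> float:
    # struct's "e" format is IEEE 754 binary16 (half precision).
    return struct.unpack("<e", struct.pack("<H", bits))[0]


def bf16_bits_to_float(bits: int) -> float:
    # BF16 is the upper 16 bits of an FP32 value, so widening is a left shift.
    return struct.unpack("<f", struct.pack("<I", bits << 16))[0]


def head_sink_to_fp32(raw_bits: Sequence[int], dtype: str) -> List[float]:
    """Convert one 16-bit sink value per head to a Python float (exact)."""
    convert = {"fp16": fp16_bits_to_float, "bf16": bf16_bits_to_float}[dtype]
    return [convert(b) for b in raw_bits]


if __name__ == "__main__":
    print(head_sink_to_fp32([0x3C00, 0xC000], "fp16"))  # [1.0, -2.0]
    print(head_sink_to_fp32([0x3F80, 0xC000], "bf16"))  # [1.0, -2.0]
```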
kunal-vaishnavi
approved these changes
Jun 19, 2026
Description
This PR enables the XQA decode kernel for the CUDA GroupQueryAttention (GQA) operator when an attention-sink input (head_sink) is present, the common pattern in GPT-OSS style decode models. The sink is treated as a smooth-softmax term, and a PrePack step converts a constant head_sink initializer to a cached FP32 buffer once at session init to avoid a per-step conversion. XQA now turns on by default for the head_sink decode path while preserving the existing ORT_ENABLE_XQA opt-in/opt-out semantics for all other non-quantized cases.
Summary of Changes
Kernel: XQA dispatch and head_sink handling
| File | Description |
|---|---|
| onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc | Add PrePack that caches a constant head_sink initializer as FP32 (xqa_head_sink_); allow XQA when head_sink is present (smooth-softmax via attention sink); default XQA on for the head_sink decode path; add xqa_force_disabled_ so an explicit ORT_ENABLE_XQA=0 always wins; reserve per-launch FP32 scratch when head_sink is dynamic (not prepacked). |
| onnxruntime/contrib_ops/cuda/bert/group_query_attention.h | Add PrePack declaration and members: xqa_head_sink_, xqa_head_sink_count_, xqa_force_disabled_. |
| onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu/.h | Add LaunchConvertHeadSinkToFloat to convert FP16/BF16 head_sink to FP32 for XQA. |
| onnxruntime/contrib_ops/cuda/bert/attention_data.h | Add xqa_head_sink (FP32 sink pointer) and xqa_head_sink_needs_conversion to GroupQueryAttentionData. |
| onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.{cc,h} | Add use_xqa to the debug-info print so SdpaKernel=XQA is reported. |
XQA loaders: attention-sink plumbing
- onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader*.{cu,cuh,h} and xqa_impl_gen.cuh: thread the FP32 attention-sink pointer through the FP16/BF16 (and int8/fp8 KV) XQA loader entry points.
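To make concrete what the sink pointer feeds into, here is a reference sketch of softmax with a per-head sink. It assumes the usual formulation in which the sink logit adds one extra term to the softmax denominator and contributes no value vector; `softmax_with_sink` is an illustrative name, not a function from this PR.

```python
import math
from typing import List


def softmax_with_sink(scores: List[float], sink: float) -> List[float]:
    """Softmax over `scores` with one extra sink logit in the denominator."""
    # Subtract the running max (sink included) for numerical stability.
    m = max(max(scores), sink)
    exps = [math.exp(s - m) for s in scores]
    denom = sum(exps) + math.exp(sink - m)
    # The sink has no value vector, so only the real positions get weights.
    return [e / denom for e in exps]


if __name__ == "__main__":
    probs = softmax_with_sink([0.0, 0.0], sink=0.0)
    print(probs, sum(probs))  # weights sum to less than 1
```

With a sink of negative infinity the extra term vanishes and the result matches an ordinary softmax, which is a convenient sanity check for parity tests.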
Tests and docs
| File | Description |
|---|---|
| onnxruntime/test/python/transformers/test_gqa.py | Add TestXQAHeadSinkParity with runtime and PrePack (head_sink as initializer) parity cases; add has_xqa() skip guard; setUp clears ORT_ENABLE_XQA to exercise the real default-on behavior. |
| onnxruntime/test/python/transformers/gqa_test_helper.py | Add head_sink plumbing for the new tests. |
| onnxruntime/test/python/transformers/profile_gqa.py | Add head_sink profiling support. |
| docs/contrib_ops/gqa.md | XQA defaults (head_sink on; otherwise opt-in via ORT_ENABLE_XQA; ORT_ENABLE_XQA=0 disables). |
Testing
- With ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO=1:
  - head_sink present, env unset → SdpaKernel=XQA.
  - head_sink present, ORT_ENABLE_XQA=0 → falls back to SdpaKernel=FLASH_ATTENTION (parity still passes).
- […] head_sink decode is unchanged (still Flash / cuDNN).
- […] head_sink, so XQA is a performance default, not a correctness requirement; ORT_ENABLE_XQA=1/0 semantics are preserved.
Motivation and Context
GPT-OSS style decode models use a per-head attention sink. Routing these decode steps through XQA
improves decode latency, and prepacking the constant sink to FP32 removes a per-step conversion.
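The saving from prepacking can be illustrated with a toy cache that counts conversions. This is a hedged sketch of the idea only: `HeadSinkCache` and its methods are hypothetical names, and the "conversion" is a plain float cast standing in for the FP16/BF16 → FP32 kernel.

```python
from typing import List, Optional, Sequence


class HeadSinkCache:
    """Toy model: convert a constant sink once, a dynamic sink every launch."""

    def __init__(self) -> None:
        self._cached: Optional[List[float]] = None
        self.conversions = 0

    def _to_fp32(self, sink: Sequence[float]) -> List[float]:
        # Stand-in for the real half-to-float conversion kernel.
        self.conversions += 1
        return [float(v) for v in sink]

    def prepack(self, constant_sink: Sequence[float]) -> None:
        # Called once at session init when head_sink is a constant initializer.
        self._cached = self._to_fp32(constant_sink)

    def sink_for_launch(self, runtime_sink: Optional[Sequence[float]] = None) -> List[float]:
        if self._cached is not None:
            return self._cached  # no per-step work
        if runtime_sink is None:
            raise ValueError("a dynamic head_sink must be supplied at launch")
        return self._to_fp32(runtime_sink)  # per-launch scratch conversion


if __name__ == "__main__":
    constant, dynamic = HeadSinkCache(), HeadSinkCache()
    constant.prepack([0.5, -1.0])
    for _ in range(3):  # three decode steps
        constant.sink_for_launch()
        dynamic.sink_for_launch([0.5, -1.0])
    print(constant.conversions, dynamic.conversions)  # 1 3
```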
This PR targets main but depends on PR #29161 (FlashDecode split planning for local-window GQA), which should merge first. Until then this PR's diff also includes the #29161 commit; once #29161 lands on main, GitHub's merge-base will drop it and this PR's diff will contain only the XQA change.
Checklist
- […] (docs/contrib_ops/gqa.md)
- […] (ORT_ENABLE_XQA semantics preserved)