[CUDA] Enable XQA decode for GroupQueryAttention with attention sink#29162

Merged
tianleiwu merged 7 commits into main from tlwu/20260618/xqa_head_sink on Jun 20, 2026

Conversation

@tianleiwu tianleiwu commented Jun 19, 2026

Contributor

Description

This PR enables the XQA decode kernel for the CUDA GroupQueryAttention (GQA) operator when an
attention-sink input (head_sink) is present, the common pattern in GPT-OSS style decode models.
The sink is treated as a smooth-softmax term, and a PrePack step converts a constant head_sink
initializer to a cached FP32 buffer once at session init to avoid a per-step conversion. XQA now
turns on by default for the head_sink decode path while preserving the existing
ORT_ENABLE_XQA opt-in/opt-out semantics for all other non-quantized cases.
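
To make the "sink as a smooth-softmax term" idea concrete, here is a minimal pure-Python sketch. It assumes the usual attention-sink formulation, in which the per-head sink logit joins the softmax denominator but contributes no value vector; the function name `attention_weights_with_sink` is hypothetical and this is not the kernel code from the PR.

```python
import math


def attention_weights_with_sink(scores, sink):
    """Softmax over `scores` with one extra sink logit in the denominator.

    Assumption: the sink only absorbs probability mass; it has no value
    vector, so its own weight is dropped from the returned list.
    """
    m = max(max(scores), sink)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    denom = sum(exps) + math.exp(sink - m)
    return [e / denom for e in exps]


weights = attention_weights_with_sink([1.0, 2.0, 3.0], sink=0.5)
print(weights, sum(weights))  # the weights sum to less than 1
```

With `sink = float("-inf")` the extra term vanishes and the function reduces to an ordinary softmax, which is one way to sanity-check the sketch.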

Summary of Changes

Kernel: XQA dispatch and head_sink handling

| File | Change |
| --- | --- |
| onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc | Add PrePack that caches a constant head_sink initializer as FP32 (xqa_head_sink_); allow XQA when head_sink is present (smooth-softmax via attention sink); default XQA on for the head_sink decode path; add xqa_force_disabled_ so an explicit ORT_ENABLE_XQA=0 always wins; reserve per-launch FP32 scratch when head_sink is dynamic (not prepacked). |
| onnxruntime/contrib_ops/cuda/bert/group_query_attention.h | Add PrePack declaration and members: xqa_head_sink_, xqa_head_sink_count_, xqa_force_disabled_. |
| onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu / .h | Add LaunchConvertHeadSinkToFloat to convert FP16/BF16 head_sink to FP32 for XQA. |
| onnxruntime/contrib_ops/cuda/bert/attention_data.h | Add xqa_head_sink (FP32 sink pointer) and xqa_head_sink_needs_conversion to GroupQueryAttentionData. |
| onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.{cc,h} | Add use_xqa to the debug-info print so SdpaKernel=XQA is reported. |
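
The PrePack behavior in the table above amounts to "convert once when the sink is a constant initializer, otherwise convert on every launch". The sketch below illustrates that pattern in Python under stated assumptions: the class `HeadSinkCacheSketch` and its methods are hypothetical, the input is assumed to be little-endian FP16 bytes, and Python floats (double precision) stand in for the FP32 buffer. It is not the C++ implementation from the PR.

```python
import struct


def fp16_bytes_to_float(raw):
    """Decode little-endian IEEE half-precision values into Python floats."""
    count = len(raw) // 2
    return list(struct.unpack(f"<{count}e", raw))


class HeadSinkCacheSketch:
    """Hypothetical stand-in for the prepacked-sink logic described above."""

    def __init__(self):
        self.cached = None    # plays the role of xqa_head_sink_
        self.conversions = 0  # counts how often a conversion actually runs

    def _convert(self, raw):
        self.conversions += 1
        return fp16_bytes_to_float(raw)

    def prepack(self, raw, is_constant_initializer):
        # Only a constant initializer can be converted once and reused.
        if is_constant_initializer:
            self.cached = self._convert(raw)

    def sink_for_step(self, raw):
        # Prepacked: reuse the cached buffer. Dynamic: convert per launch.
        if self.cached is not None:
            return self.cached
        return self._convert(raw)


raw_sink = struct.pack("<3e", 0.5, -1.0, 2.0)
op = HeadSinkCacheSketch()
op.prepack(raw_sink, is_constant_initializer=True)
for _ in range(4):
    op.sink_for_step(raw_sink)
print(op.conversions)  # conversions performed across init plus four steps
```

In the dynamic case (`is_constant_initializer=False`) the same four steps would each trigger a conversion, which is the per-step cost the prepacked path is meant to avoid.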

XQA loaders: attention-sink plumbing

  • onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader*.{cu,cuh,h} and xqa_impl_gen.cuh — thread the
    FP32 attention-sink pointer through the FP16/BF16 (and int8/fp8 KV) XQA loader entry points.

Tests and docs

| File | Change |
| --- | --- |
| onnxruntime/test/python/transformers/test_gqa.py | Add TestXQAHeadSinkParity with runtime and PrePack (head_sink as initializer) parity cases; add has_xqa() skip guard; setUp clears ORT_ENABLE_XQA to exercise the real default-on behavior. |
| onnxruntime/test/python/transformers/gqa_test_helper.py | Support head_sink plumbing for the new tests. |
| onnxruntime/test/python/transformers/profile_gqa.py | Minor head_sink profiling support. |
| docs/contrib_ops/gqa.md | New document describing the GQA operator, inputs/attributes, and XQA selection defaults (quantized on; non-quantized head_sink on; otherwise opt-in via ORT_ENABLE_XQA; ORT_ENABLE_XQA=0 disables). |
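
The XQA selection defaults summarized in the last row can be read as a small decision function. The sketch below is one interpretation of that summary, not the dispatch code itself: the name `should_use_xqa` and its parameters are hypothetical, and it deliberately ignores other eligibility checks (GPU architecture, decode-only shapes, supported head sizes) that the real operator would also apply.

```python
def should_use_xqa(env_value, is_quantized_kv, has_head_sink):
    """Return True if XQA would be selected under the documented defaults.

    env_value is the value of ORT_ENABLE_XQA, or None when it is unset.
    Assumption: only the strings "0" and "1" are meaningful.
    """
    if env_value == "0":
        return False  # an explicit opt-out always wins
    if is_quantized_kv:
        return True   # quantized KV cache: on by default
    if has_head_sink:
        return True   # non-quantized decode with head_sink: on by default
    return env_value == "1"  # everything else stays opt-in


# Mirrors the kernel-selection checks listed in the Testing section.
print(should_use_xqa(None, False, True))   # head_sink present, env unset
print(should_use_xqa("0", False, True))    # head_sink present, env set to 0
print(should_use_xqa(None, False, False))  # no head_sink, env unset
```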

Testing

  • Run the XQA parity suites on an Ampere+ GPU:
    cd onnxruntime/test/python/transformers
    python -m pytest test_gqa.py -k "TestXQAHeadSinkParity or TestXQAQuantizedParity" -q
    
    All 224 cases pass on H200 (SM90). Tests skip automatically on devices without XQA support.
  • Kernel selection was verified via ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO=1:
    • head_sink present, env unset → SdpaKernel=XQA.
    • head_sink present, ORT_ENABLE_XQA=0 → falls back to SdpaKernel=FLASH_ATTENTION (parity still passes).
    • Non-head_sink decode is unchanged (still Flash / cuDNN).
  • Backward compatibility: Flash/fallback paths keep the original FP16/BF16 head_sink, so XQA is a
    performance default, not a correctness requirement; ORT_ENABLE_XQA=1/0 semantics are preserved.

Motivation and Context

GPT-OSS style decode models use a per-head attention sink. Routing these decode steps through XQA
improves decode latency, and prepacking the constant sink to FP32 removes a per-step conversion.
This PR targets main but depends on PR #29161 (FlashDecode split planning for local-window GQA),
which should merge first. Until then this PR's diff also includes the #29161 commit; once #29161
lands on main, GitHub's merge-base will drop it and this PR's diff will contain only the XQA change.

Checklist

  • Tests added/updated
  • Documentation updated (docs/contrib_ops/gqa.md)
  • No breaking changes (existing ORT_ENABLE_XQA semantics preserved)
  • CI passes

@tianleiwu tianleiwu changed the title Enable XQA decode for GroupQueryAttention with attention sink [CUDA] Enable XQA decode for GroupQueryAttention with attention sink Jun 19, 2026
@tianleiwu tianleiwu changed the base branch from tlwu/20260618/flash_decode_split_local_window to main June 19, 2026 07:29
@tianleiwu tianleiwu requested a review from Copilot June 19, 2026 08:17

Copilot AI left a comment

Contributor

Pull request overview

This PR extends the CUDA contrib com.microsoft::GroupQueryAttention implementation to enable the XQA decode kernel when a per-head attention sink (head_sink) is present, including an init-time PrePack path that caches a constant head_sink as FP32 to avoid per-step conversion. It also threads the attention-sink pointer through the XQA loader entry points and adds parity tests, profiling helpers, and operator documentation.

Changes:

  • Enable XQA for non-quantized decode when head_sink is provided (default-on for that path, while honoring explicit ORT_ENABLE_XQA=0).
  • Add PrePack caching of constant head_sink (FP16/BF16 → FP32) and a per-launch FP32 scratch conversion path for dynamic sinks.
  • Plumb attention_sinks through XQA loaders and add Python parity tests + new GQA documentation and profiling scripts.

Reviewed changes

Copilot reviewed 28 out of 28 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc | XQA dispatch updates, head_sink handling, PrePack caching, local-window split-planning clamp, debug info wiring |
| onnxruntime/contrib_ops/cuda/bert/group_query_attention.h | Add PrePack override and XQA head-sink cache members |
| onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu | Add FP16/BF16 head_sink → FP32 conversion kernel and wire into XQA launch |
| onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h | Declare LaunchConvertHeadSinkToFloat |
| onnxruntime/contrib_ops/cuda/bert/attention_data.h | Add xqa_head_sink pointer + conversion flag to GroupQueryAttentionData |
| onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h | Add use_xqa to debug-info struct |
| onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc | Print SdpaKernel=XQA when applicable |
| onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader.h | Add attention_sinks parameter to XQA launch API |
| onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16.cu | Thread attention_sinks into fp16 XQA dispatch |
| onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_impl.cuh | Propagate sinks to generated kernels; reject sinks for int8/fp8 KV-cache paths |
| onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_int8_impl.cuh | Update int8 launches to new signature (pass nullptr sinks) |
| onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_fp8_impl.cuh | Update fp8 launches to new signature (pass nullptr sinks) |
| onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_64.cu | Update fp16 head-size specialization signature |
| onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_128.cu | Update fp16 head-size specialization signature |
| onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_256.cu | Update fp16 head-size specialization signature |
| onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16.cu | Thread attention_sinks into bf16 XQA dispatch; reject sinks for int8 KV-cache path |
| onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_impl.cuh | Propagate sinks to generated kernels; reject sinks for int8/fp8 KV-cache paths |
| onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_int8_impl.cuh | Update int8 launches to new signature (pass nullptr sinks) |
| onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_fp8_impl.cuh | Update fp8 launches to new signature (pass nullptr sinks) |
| onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_64.cu | Update bf16 head-size specialization signature |
| onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_128.cu | Update bf16 head-size specialization signature |
| onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_256.cu | Update bf16 head-size specialization signature |
| onnxruntime/contrib_ops/cuda/bert/xqa/xqa_impl_gen.cuh | Pass attention_sinks to generated kernel entrypoints |
| onnxruntime/test/python/transformers/test_gqa.py | Add head_sink initializer plumbing, XQA head-sink parity tests, and local-window split-planning regression |
| onnxruntime/test/python/transformers/gqa_test_helper.py | Add head_sink plumbing support for model creation and random feeds |
| onnxruntime/test/python/transformers/profile_gqa.py | New profiling helper (modes, local-window, optional head_sink, optional NVTX) |
| onnxruntime/test/python/transformers/profile_gqa.sh | New nsys wrapper for profiling runs and parsing |
| docs/contrib_ops/gqa.md | New/updated operator documentation including head_sink/XQA defaults and PrePack behavior |

Comment thread onnxruntime/test/python/transformers/test_gqa.py
Comment thread onnxruntime/test/python/transformers/test_gqa.py Outdated
Comment thread docs/contrib_ops/gqa.md Outdated
Comment thread onnxruntime/test/python/transformers/profile_gqa.py
Comment thread docs/contrib_ops/gqa.md
@tianleiwu tianleiwu merged commit 3718f7d into main Jun 20, 2026
86 checks passed
@tianleiwu tianleiwu deleted the tlwu/20260618/xqa_head_sink branch June 20, 2026 00:32
3 participants