Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 173 additions & 0 deletions docs/contrib_ops/gqa.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# GroupQueryAttention — Operator Documentation

This document describes the `com.microsoft::GroupQueryAttention` (GQA) contrib operator: its schema,
Comment thread
tianleiwu marked this conversation as resolved.
the CUDA kernel backends and how one is selected, and the attention-sink (`head_sink`) decode path
that is accelerated by the XQA kernel.

For CPU-specific implementation details (including the quantized KV-cache flash path), see
[cpu/gqa.md](cpu/gqa.md).

---

## Table of Contents

1. [Overview](#1-overview)
2. [Operator Schema](#2-operator-schema)
3. [KV Cache and Quantization](#3-kv-cache-and-quantization)
4. [Attention Sink (`head_sink`) and Smooth Softmax](#4-attention-sink-head_sink-and-smooth-softmax)
5. [CUDA Kernel Backends and Dispatch](#5-cuda-kernel-backends-and-dispatch)
6. [XQA Decode Path](#6-xqa-decode-path)
7. [XQA `head_sink` PrePack](#7-xqa-head_sink-prepack)
8. [Environment Variables](#8-environment-variables)
9. [Testing](#9-testing)

---

## 1. Overview

GroupQueryAttention implements causal grouped-query attention with KV-cache (past/present) support.
Grouped-query attention uses fewer key/value heads than query heads: each KV head is shared by a
group of `num_heads / kv_num_heads` query heads. The operator also supports:

- Rotary positional embeddings (RoPE)
- Past/present KV cache with optional in-place (shared) buffer
- Quantized KV cache (int4 / int8 / float8e4m3fn) to reduce memory footprint
- Optional attention bias and local (sliding) window attention
- Smooth softmax, including a per-head attention sink (`head_sink`)

The operator schema is defined in
[onnxruntime/core/graph/contrib_ops/bert_defs.cc](../../onnxruntime/core/graph/contrib_ops/bert_defs.cc).
The CUDA kernel is implemented in
[onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc](../../onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc)
and [group_query_attention_impl.cu](../../onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu).

## 2. Operator Schema

Selected attributes:

| Attribute | Description |
|-----------|-------------|
| `num_heads` | Number of query heads. |
| `kv_num_heads` | Number of key/value heads. `num_heads % kv_num_heads == 0`. |
| `scale` | Softmax scale. Defaults to `1/sqrt(head_size)`. |
| `softcap` | Optional logit soft-capping value. `0` disables it. |
| `local_window_size` | Left window size for local attention. `-1` means global attention. |
| `do_rotary` / `rotary_interleaved` | Enable RoPE and select interleaved vs. half-rotary layout. |
| `smooth_softmax` | Add a smooth factor to the softmax denominator. |
| `k_quant_type` / `v_quant_type` | KV cache quantization mode: `NONE`, `PER_TENSOR`, or `PER_CHANNEL`. |
| `kv_cache_bit_width` | Bit width of the quantized KV cache (`8` or `4`). |

Selected inputs (see the schema for the full list and shapes):

| Index | Name | Notes |
|-------|------|-------|
| 0 | `query` | `(batch, seq, hidden)`, or packed QKV. |
| 1, 2 | `key`, `value` | Optional when QKV is packed into `query`. |
| 3, 4 | `past_key`, `past_value` | BNSH cache. Shares the buffer with `present_*` when in-place. |
| 5 | `seqlens_k` | `total_sequence_lengths - 1` per batch entry. |
| 6 | `total_sequence_length` | Scalar used to distinguish prompt vs. decode. |
| 7, 8 | `cos_cache`, `sin_cache` | RoPE caches. |
| 11 | `head_sink` | `(num_heads,)` per-head attention sink (see §4). |
| 12, 13 | `k_scale`, `v_scale` | FP32 dequant scales for the quantized KV cache. |

Outputs are `output`, `present_key`, `present_value`, and optional `output_qk`.

## 3. KV Cache and Quantization

The past/present KV cache uses BNSH layout `(batch_size, kv_num_heads, cache_sequence_length, head_size)`.
When `past_present_share_buffer` holds (the past and present tensors alias the same memory), the cache
length is the maximum sequence length and new keys/values are appended in place. This shared-buffer mode
is required by the XQA decode path.

When quantization is enabled, `k_quant_type` and `v_quant_type` select `PER_TENSOR` or `PER_CHANNEL`
scaling, and `kv_cache_bit_width` selects 8-bit or 4-bit storage. The `k_scale` / `v_scale` inputs are
always FP32.

## 4. Attention Sink (`head_sink`) and Smooth Softmax

An attention sink adds a learned per-head bias term to the softmax denominator. With sink value `s_h`
for head `h`, the attention weights over `T` cached positions become:

$$
\text{softmax}_i = \frac{e^{x_i - m}}{e^{s_h - m} + \sum_{j} e^{x_j - m}}, \quad m = \max\left(s_h, \max_j x_j\right)
$$

This is equivalent to appending a single extra logit `s_h` (whose value contributes nothing to the
output, only to normalization). GPT-OSS style models use this to let a head attend to "nothing".

In the kernel, providing the `head_sink` input is treated as smooth softmax:
`parameters.use_smooth_softmax = use_smooth_softmax_ || head_sink != nullptr`. The `head_sink` tensor is
1D of shape `(num_heads,)` and matches the operator's floating-point type (`float16` or `bfloat16` on
the XQA path).

## 5. CUDA Kernel Backends and Dispatch

The CUDA EP can route a GQA node to several backends. At runtime it selects the first eligible one:

| Backend | Typical use |
|---------|-------------|
| **XQA** | Single-token global decode (`seq_len == 1`), shared KV buffer. Fastest decode path. |
| **Flash Attention / Flash Decoding** | General prompt and decode, including local window and softcap. |
| **cuDNN SDPA** | Preferred on SM≥90 for non-quantized FP16/BF16 causal attention. |
| **Memory Efficient Attention** | Fallback for FP16/FP32 (and BF16 on SM80+). |
| **Unfused** | Last-resort fallback (e.g. `head_size > 256` with past KV). |

The selected backend is reported in the kernel debug info as `SdpaKernel=...` when debug info is enabled.

## 6. XQA Decode Path

XQA (a highly optimized cross/decode attention kernel) is used only when **all** of the following hold:

1. Compute capability SM 8.0+ (Ampere or newer).
2. Decoding phase (not the first prompt) with `sequence_length == 1`.
3. `kv_sequence_length > 0` (there is a new K/V to append).
4. Past and present KV cache share the same buffer.
5. No softcap.
6. Standard softmax, **or** smooth softmax expressed via a `head_sink` tensor (non-quantized KV cache).
7. No local (sliding) window attention — global attention only.
8. Supported `head_size` (64, 128, or 256) and group size.

`head_sink` (attention sink) is supported on the non-quantized XQA path only. Quantized KV cache
(int8 / fp8) paths explicitly reject a non-null attention sink, so a GQA node with both `head_sink`
and a quantized cache falls back to Flash/Flash-Decoding.

XQA selection defaults are:

- **Quantized KV cache (int8 / fp8):** on by default.
- **Non-quantized with a `head_sink` input:** on by default (GPT-OSS style decode).
- **Non-quantized without `head_sink`:** opt-in via `ORT_ENABLE_XQA=1`.

Setting `ORT_ENABLE_XQA=0` disables XQA for the non-quantized path regardless of `head_sink`.

## 7. XQA `head_sink` PrePack

XQA consumes the attention sink as an FP32 buffer, while the model stores `head_sink` as FP16/BF16. To
avoid converting on every decode step, `GroupQueryAttention::PrePack` converts a **constant-initializer**
`head_sink` once into a cached FP32 device buffer (`xqa_head_sink_`):

- The cached buffer is reused for every launch when XQA is eligible.
- A dynamic / non-initializer `head_sink` is **not** prepacked; the kernel instead reserves a small FP32
scratch buffer and converts the sink per launch (`xqa_head_sink_needs_conversion = true`).
- `PrePack` keeps `is_packed = false` so the original FP16/BF16 `head_sink` is still delivered to the
Flash/fallback paths when XQA is disabled or ineligible.

## 8. Environment Variables

| Variable | Effect |
|----------|--------|
| `ORT_ENABLE_XQA` | `1` enables the XQA decode path for the non-quantized KV cache (default off; default on for quantized). |
| `ORT_DISABLE_FLASH_DECODE` | `1` disables the Flash Decoding split-KV optimization. |

These are read once when the kernel is constructed.

## 9. Testing

CUDA parity tests live in
[onnxruntime/test/python/transformers/test_gqa.py](../../onnxruntime/test/python/transformers/test_gqa.py):

- `TestXQAQuantizedParity` — XQA per-tensor int8 quantized decode parity.
- `TestXQAHeadSinkParity` — non-quantized XQA decode parity with a `head_sink` (attention sink) input.

`TestXQAQuantizedParity` sets `ORT_ENABLE_XQA=1` to force the XQA path. `TestXQAHeadSinkParity`
instead clears `ORT_ENABLE_XQA` to validate that XQA is enabled by default when a `head_sink` input
is present. Both compare against a PyTorch reference (`attention_ref` with `smooth_softmax_ref`).
6 changes: 6 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/attention_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,12 @@ struct GroupQueryAttentionData {
// XQA buffer
void* xqa_buffer = nullptr;
size_t xqa_buffer_bytes = 0;
// FP32 per-head attention sink consumed by the XQA kernel (nullptr when no head_sink input).
// Either points to a PrePack-cached buffer or to scratch that is filled at launch time.
float* xqa_head_sink = nullptr;
// When true, head_sink was not prepacked (e.g. dynamic/non-initializer input) and the FP16/BF16
// head_sink must be converted to xqa_head_sink (FP32 scratch) before launching XQA.
bool xqa_head_sink_needs_conversion = false;

// Unfused fallback buffers (see LaunchUnfusedAttention in unfused_attention.h):
// unfused_q_bnsh : [B, N_q, S_q, H] (Q transposed from BSNH to BNSH)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,9 @@ void AttentionKernelDebugInfo::Print(const char* operator_name,
}

sstream << " SdpaKernel=";
if (use_flash_attention.has_value() && use_flash_attention.value()) {
if (use_xqa.has_value() && use_xqa.value()) {
sstream << "XQA";
} else if (use_flash_attention.has_value() && use_flash_attention.value()) {
sstream << "FLASH_ATTENTION";
#if USE_LEAN_ATTENTION
} else if (use_lean_attention.has_value() && use_lean_attention.value()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

namespace onnxruntime {
struct AttentionKernelDebugInfo {
std::optional<bool> use_xqa = std::nullopt;
std::optional<bool> use_flash_attention = std::nullopt;
std::optional<bool> use_lean_attention = std::nullopt;
std::optional<bool> use_efficient_attention = std::nullopt;
Expand Down
104 changes: 95 additions & 9 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ REGISTER_KERNEL_TYPED(BFloat16, uint8_t)
#endif

constexpr const char* kDisableFlashDecode = "ORT_DISABLE_FLASH_DECODE";
constexpr int kHeadSinkInputIndex = 11;

// Group Query Attention (GQA) Operator
//
Expand Down Expand Up @@ -110,8 +111,15 @@ GroupQueryAttention<T, U>::GroupQueryAttention(const OpKernelInfo& info)
kv_cache_bit_width_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("kv_cache_bit_width", 0));

bool is_quantized = (k_quant_type_ != KVQuantizationType::NONE || v_quant_type_ != KVQuantizationType::NONE);
int default_enable_xqa = is_quantized ? 1 : 0;
enable_xqa_ = (std::is_same_v<T, MLFloat16> || std::is_same_v<T, BFloat16>) && ParseEnvironmentVariableWithDefault<int>("ORT_ENABLE_XQA", default_enable_xqa) != 0;
// XQA enablement:
// - An explicit ORT_ENABLE_XQA overrides everything (1 = on, 0 = off, including the head_sink default-on path).
// - When unset, XQA defaults on for the quantized KV cache path and off for the non-quantized path
// (the non-quantized head_sink decode path is additionally enabled per-Run in ComputeInternal).
constexpr bool kIsFp16OrBf16 = std::is_same_v<T, MLFloat16> || std::is_same_v<T, BFloat16>;
const int xqa_env = ParseEnvironmentVariableWithDefault<int>("ORT_ENABLE_XQA", -1); // -1 means unset
xqa_force_disabled_ = (xqa_env == 0);
const int effective_enable_xqa = (xqa_env == -1) ? (is_quantized ? 1 : 0) : xqa_env;
enable_xqa_ = kIsFp16OrBf16 && (effective_enable_xqa != 0);

kernel_options_ = this->GetAttentionKernelOptions();

Expand All @@ -121,7 +129,6 @@ GroupQueryAttention<T, U>::GroupQueryAttention(const OpKernelInfo& info)
disable_memory_efficient_attention_ = !kernel_options_->UseEfficientAttention();

// cuDNN SDPA (cudnn_frontend) supports FP16 and BF16 and is auto-preferred on SM>=90.
constexpr bool kIsFp16OrBf16 = std::is_same<T, MLFloat16>::value || std::is_same<T, BFloat16>::value;
enable_cudnn_flash_attention_ = kIsFp16OrBf16 && kernel_options_->UseCudnnFlashAttention();
auto_enable_cudnn_flash_attention_ = kIsFp16OrBf16 && kernel_options_->AllowCudnnFlashAttentionAuto();

Expand All @@ -133,6 +140,55 @@ GroupQueryAttention<T, U>::GroupQueryAttention(const OpKernelInfo& info)
disable_flash_decode_ = ParseEnvironmentVariableWithDefault<bool>(kDisableFlashDecode, false);
}

template <typename T, typename U>
Status GroupQueryAttention<T, U>::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
bool& is_packed, PrePackedWeights* prepacked_weights) {
ORT_UNUSED_PARAMETER(prepacked_weights);
// Keep is_packed=false so the original fp16/bf16 head_sink remains available to the Flash/fallback
// paths (which are used when XQA is disabled or ineligible). We only cache an extra FP32 copy for XQA.
is_packed = false;

if (input_idx != kHeadSinkInputIndex) {
return Status::OK();
}

// XQA consumes the attention sink as FP32. When head_sink is a constant initializer, convert it once
// here into a cached device buffer (xqa_head_sink_) to avoid a per-launch conversion. Dynamic /
// non-initializer head_sink inputs are not prepacked and fall back to the per-launch scratch path.
if constexpr (std::is_same_v<T, MLFloat16> || std::is_same_v<T, BFloat16>) {
const auto& shape = tensor.Shape();
ORT_RETURN_IF_NOT(shape.NumDimensions() == 1,
"head_sink must be a 1D tensor, got ", shape.NumDimensions(), " dimensions");
ORT_RETURN_IF_NOT(shape[0] == num_heads_,
"head_sink dimension 0 must be equal to the num heads, got ", shape[0]);
ORT_RETURN_IF_NOT(tensor.IsDataType<T>(), "head_sink type must match GroupQueryAttention input type");

// Derive the element count from the tensor itself (one sink per head) rather than num_heads_.
const int head_sink_count = static_cast<int>(shape.Size());
const size_t head_sink_bytes = tensor.SizeInBytes();
const void* head_sink_data = tensor.DataRaw();
IAllocatorUniquePtr<void> head_sink_gpu;
cudaStream_t stream = cudaStreamLegacy;

if (tensor.Location().device.Type() == OrtDevice::CPU) {
head_sink_gpu = IAllocator::MakeUniquePtr<void>(alloc, head_sink_bytes, true);
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(head_sink_gpu.get(), head_sink_data, head_sink_bytes,
cudaMemcpyHostToDevice, stream));
head_sink_data = head_sink_gpu.get();
}

xqa_head_sink_ = IAllocator::MakeUniquePtr<float>(alloc, static_cast<size_t>(head_sink_count), true);
using CudaT = typename onnxruntime::cuda::OrtToCudaType<T>::type;
ORT_RETURN_IF_ERROR(LaunchConvertHeadSinkToFloat<CudaT>(
reinterpret_cast<const CudaT*>(head_sink_data), xqa_head_sink_.get(), head_sink_count, stream,
GetDeviceProp().maxThreadsPerBlock));
CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream));
xqa_head_sink_count_ = head_sink_count;
}

return Status::OK();
}

// ComputeInternal executes the GQA kernel.
//
// Inputs:
Expand Down Expand Up @@ -338,16 +394,25 @@ Status GroupQueryAttention<T, U>::ComputeInternal(OpKernelContext* context) cons
// 3. Sequence length is 1.
// 4. Past and Present KV cache share the same buffer (required for XQA specific memory access).
// 5. No Softcap (XQA doesn't support softcap).
// 6. Standard Softmax (no smooth softmax).
// 6. Standard Softmax, or smooth softmax represented by a head_sink tensor.
// 7. No local window attention (global attention only).
if (enable_xqa_ &&
const bool use_xqa_attention_sinks = head_sink != nullptr && !is_inputs_quantized;
const bool is_xqa_smooth_softmax_supported = !parameters.use_smooth_softmax || use_xqa_attention_sinks;
// XQA is opt-in for the non-quantized path (ORT_ENABLE_XQA), but a head_sink (attention sink) input
// signals a GPT-OSS style decode model that benefits from XQA, so enable it by default in that case.
// An explicit ORT_ENABLE_XQA=0 (xqa_force_disabled_) still wins and turns XQA off entirely.
// The dtype guard mirrors enable_xqa_ (XQA only supports fp16/bf16); ineligible cases fall back below.
constexpr bool kIsFp16OrBf16 = std::is_same_v<T, MLFloat16> || std::is_same_v<T, BFloat16>;
const bool xqa_enabled_for_run =
!xqa_force_disabled_ && (enable_xqa_ || (kIsFp16OrBf16 && use_xqa_attention_sinks));
if (xqa_enabled_for_run &&
(device_prop.major >= 8) &&
!parameters.is_first_prompt &&
parameters.sequence_length == 1 &&
parameters.kv_sequence_length > 0 && // Shared KV (kv_seq=0) has no new K/V to append
parameters.past_present_share_buffer &&
parameters.softcap == 0.0f &&
!parameters.use_smooth_softmax &&
is_xqa_smooth_softmax_supported &&
parameters.local_window_size == -1) {
int group_size = parameters.num_heads / parameters.kv_num_heads;

Expand Down Expand Up @@ -389,23 +454,43 @@ Status GroupQueryAttention<T, U>::ComputeInternal(OpKernelContext* context) cons
assert(xqa_internal_bytes > 0);
// Calculate additional scratch needed for manual RoPE/Append in ExtremeDecoding
size_t xqa_total_bytes = xqa_internal_bytes;
size_t q_bytes = 0;
size_t k_bytes = 0;
if (parameters.do_rotary) {
// 1. Q_rotated buffer: B * N * H * sizeof(T) (if rotary)
// 2. K_rotated buffer: B * Nk * H * sizeof(T) (if rotary)
size_t element_size = sizeof(CudaT);
size_t q_bytes = parameters.batch_size * parameters.num_heads * parameters.head_size * element_size;
size_t k_bytes = parameters.batch_size * parameters.kv_num_heads * parameters.head_size * element_size;
q_bytes = parameters.batch_size * parameters.num_heads * parameters.head_size * element_size;
k_bytes = parameters.batch_size * parameters.kv_num_heads * parameters.head_size * element_size;
q_bytes = (q_bytes + 255) / 256 * 256;
k_bytes = (k_bytes + 255) / 256 * 256;
xqa_total_bytes += q_bytes + k_bytes;
}
const bool use_prepacked_xqa_head_sink =
use_xqa_attention_sinks && xqa_head_sink_ != nullptr && xqa_head_sink_count_ == parameters.num_heads;
const bool convert_xqa_head_sink = use_xqa_attention_sinks && !use_prepacked_xqa_head_sink;
size_t xqa_head_sink_bytes = 0;
if (convert_xqa_head_sink) {
// No prepacked FP32 head_sink (dynamic input): reserve scratch for the per-launch conversion.
xqa_head_sink_bytes = parameters.num_heads * sizeof(float);
xqa_head_sink_bytes = (xqa_head_sink_bytes + 255) / 256 * 256;
xqa_total_bytes += xqa_head_sink_bytes;
}

xqa_scratch_buffer = this->GetScratchBuffer<void>(xqa_total_bytes, GetComputeStream(context));
data.xqa_buffer = xqa_scratch_buffer.get();
data.xqa_buffer_bytes = xqa_internal_bytes;

char* xqa_extra_buffer = reinterpret_cast<char*>(data.xqa_buffer) + xqa_internal_bytes;
if (parameters.do_rotary) {
data.qkv_buffer = reinterpret_cast<CudaT*>(reinterpret_cast<char*>(data.xqa_buffer) + xqa_internal_bytes);
data.qkv_buffer = reinterpret_cast<CudaT*>(xqa_extra_buffer);
xqa_extra_buffer += q_bytes + k_bytes;
}
if (use_prepacked_xqa_head_sink) {
data.xqa_head_sink = xqa_head_sink_.get();
} else if (convert_xqa_head_sink) {
data.xqa_head_sink = reinterpret_cast<float*>(xqa_extra_buffer);
data.xqa_head_sink_needs_conversion = true;
}
}
}
Expand Down Expand Up @@ -606,6 +691,7 @@ Status GroupQueryAttention<T, U>::ComputeInternal(OpKernelContext* context) cons

if (kernel_options_->AllowDebugInfo()) {
AttentionKernelDebugInfo debug_info;
debug_info.use_xqa = data.use_xqa;
debug_info.use_flash_attention = data.use_flash_attention;
debug_info.use_efficient_attention = data.use_memory_efficient_attention;
debug_info.use_cudnn_flash_attention = data.use_cudnn_sdpa;
Expand Down
Loading
Loading