diff --git a/docs/contrib_ops/gqa.md b/docs/contrib_ops/gqa.md new file mode 100644 index 0000000000000..08596ff4b5dd9 --- /dev/null +++ b/docs/contrib_ops/gqa.md @@ -0,0 +1,173 @@ +# GroupQueryAttention — Operator Documentation + +This document describes the `com.microsoft::GroupQueryAttention` (GQA) contrib operator: its schema, +the CUDA kernel backends and how one is selected, and the attention-sink (`head_sink`) decode path +that is accelerated by the XQA kernel. + +For CPU-specific implementation details (including the quantized KV-cache flash path), see +[cpu/gqa.md](cpu/gqa.md). + +--- + +## Table of Contents + +1. [Overview](#1-overview) +2. [Operator Schema](#2-operator-schema) +3. [KV Cache and Quantization](#3-kv-cache-and-quantization) +4. [Attention Sink (`head_sink`) and Smooth Softmax](#4-attention-sink-head_sink-and-smooth-softmax) +5. [CUDA Kernel Backends and Dispatch](#5-cuda-kernel-backends-and-dispatch) +6. [XQA Decode Path](#6-xqa-decode-path) +7. [XQA `head_sink` PrePack](#7-xqa-head_sink-prepack) +8. [Environment Variables](#8-environment-variables) +9. [Testing](#9-testing) + +--- + +## 1. Overview + +GroupQueryAttention implements causal grouped-query attention with KV-cache (past/present) support. +Grouped-query attention uses fewer key/value heads than query heads: each KV head is shared by a +group of `num_heads / kv_num_heads` query heads. The operator also supports: + +- Rotary positional embeddings (RoPE) +- Past/present KV cache with optional in-place (shared) buffer +- Quantized KV cache (int4 / int8 / float8e4m3fn) to reduce memory footprint +- Optional attention bias and local (sliding) window attention +- Smooth softmax, including a per-head attention sink (`head_sink`) + +The operator schema is defined in +[onnxruntime/core/graph/contrib_ops/bert_defs.cc](../../onnxruntime/core/graph/contrib_ops/bert_defs.cc). +The CUDA kernel is implemented in +[onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc](../../onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc) +and [group_query_attention_impl.cu](../../onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu). + +## 2. Operator Schema + +Selected attributes: + +| Attribute | Description | +|-----------|-------------| +| `num_heads` | Number of query heads. | +| `kv_num_heads` | Number of key/value heads. `num_heads % kv_num_heads == 0`. | +| `scale` | Softmax scale. Defaults to `1/sqrt(head_size)`. | +| `softcap` | Optional logit soft-capping value. `0` disables it. | +| `local_window_size` | Left window size for local attention. `-1` means global attention. | +| `do_rotary` / `rotary_interleaved` | Enable RoPE and select interleaved vs. half-rotary layout. | +| `smooth_softmax` | Add a smooth factor to the softmax denominator. | +| `k_quant_type` / `v_quant_type` | KV cache quantization mode: `NONE`, `PER_TENSOR`, or `PER_CHANNEL`. | +| `kv_cache_bit_width` | Bit width of the quantized KV cache (`8` or `4`). | + +Selected inputs (see the schema for the full list and shapes): + +| Index | Name | Notes | +|-------|------|-------| +| 0 | `query` | `(batch, seq, hidden)`, or packed QKV. | +| 1, 2 | `key`, `value` | Optional when QKV is packed into `query`. | +| 3, 4 | `past_key`, `past_value` | BNSH cache. Shares the buffer with `present_*` when in-place. | +| 5 | `seqlens_k` | `total_sequence_lengths - 1` per batch entry. | +| 6 | `total_sequence_length` | Scalar used to distinguish prompt vs. decode. | +| 7, 8 | `cos_cache`, `sin_cache` | RoPE caches. | +| 11 | `head_sink` | `(num_heads,)` per-head attention sink (see §4). | +| 12, 13 | `k_scale`, `v_scale` | FP32 dequant scales for the quantized KV cache. | + +Outputs are `output`, `present_key`, `present_value`, and optional `output_qk`. + +## 3. KV Cache and Quantization + +The past/present KV cache uses BNSH layout `(batch_size, kv_num_heads, cache_sequence_length, head_size)`. +When `past_present_share_buffer` holds (the past and present tensors alias the same memory), the cache +length is the maximum sequence length and new keys/values are appended in place. This shared-buffer mode +is required by the XQA decode path. + +When quantization is enabled, `k_quant_type` and `v_quant_type` select `PER_TENSOR` or `PER_CHANNEL` +scaling, and `kv_cache_bit_width` selects 8-bit or 4-bit storage. The `k_scale` / `v_scale` inputs are +always FP32. + +## 4. Attention Sink (`head_sink`) and Smooth Softmax + +An attention sink adds a learned per-head bias term to the softmax denominator. With sink value `s_h` +for head `h`, the attention weights over `T` cached positions become: + +$$ +\text{softmax}_i = \frac{e^{x_i - m}}{e^{s_h - m} + \sum_{j} e^{x_j - m}}, \quad m = \max\left(s_h, \max_j x_j\right) +$$ + +This is equivalent to appending a single extra logit `s_h` (whose value contributes nothing to the +output, only to normalization). GPT-OSS style models use this to let a head attend to "nothing". + +In the kernel, providing the `head_sink` input is treated as smooth softmax: +`parameters.use_smooth_softmax = use_smooth_softmax_ || head_sink != nullptr`. The `head_sink` tensor is +1D of shape `(num_heads,)` and matches the operator's floating-point type (`float16` or `bfloat16` on +the XQA path). + +## 5. CUDA Kernel Backends and Dispatch + +The CUDA EP can route a GQA node to several backends. At runtime it selects the first eligible one: + +| Backend | Typical use | +|---------|-------------| +| **XQA** | Single-token global decode (`seq_len == 1`), shared KV buffer. Fastest decode path. | +| **Flash Attention / Flash Decoding** | General prompt and decode, including local window and softcap. | +| **cuDNN SDPA** | Preferred on SM≥90 for non-quantized FP16/BF16 causal attention. | +| **Memory Efficient Attention** | Fallback for FP16/FP32 (and BF16 on SM80+). | +| **Unfused** | Last-resort fallback (e.g. `head_size > 256` with past KV). | + +The selected backend is reported in the kernel debug info as `SdpaKernel=...` when debug info is enabled. + +## 6. XQA Decode Path + +XQA (a highly optimized cross/decode attention kernel) is used only when **all** of the following hold: + +1. Compute capability SM 8.0+ (Ampere or newer). +2. Decoding phase (not the first prompt) with `sequence_length == 1`. +3. `kv_sequence_length > 0` (there is a new K/V to append). +4. Past and present KV cache share the same buffer. +5. No softcap. +6. Standard softmax, **or** smooth softmax expressed via a `head_sink` tensor (non-quantized KV cache). +7. No local (sliding) window attention — global attention only. +8. Supported `head_size` (64, 128, or 256) and group size. + +`head_sink` (attention sink) is supported on the non-quantized XQA path only. Quantized KV cache +(int8 / fp8) paths explicitly reject a non-null attention sink, so a GQA node with both `head_sink` +and a quantized cache falls back to Flash/Flash-Decoding. + +XQA selection defaults are: + +- **Quantized KV cache (int8 / fp8):** on by default. +- **Non-quantized with a `head_sink` input:** on by default (GPT-OSS style decode). +- **Non-quantized without `head_sink`:** opt-in via `ORT_ENABLE_XQA=1`. + +Setting `ORT_ENABLE_XQA=0` disables XQA for the non-quantized path regardless of `head_sink`. + +## 7. XQA `head_sink` PrePack + +XQA consumes the attention sink as an FP32 buffer, while the model stores `head_sink` as FP16/BF16. To +avoid converting on every decode step, `GroupQueryAttention::PrePack` converts a **constant-initializer** +`head_sink` once into a cached FP32 device buffer (`xqa_head_sink_`): + +- The cached buffer is reused for every launch when XQA is eligible. +- A dynamic / non-initializer `head_sink` is **not** prepacked; the kernel instead reserves a small FP32 + scratch buffer and converts the sink per launch (`xqa_head_sink_needs_conversion = true`). +- `PrePack` keeps `is_packed = false` so the original FP16/BF16 `head_sink` is still delivered to the + Flash/fallback paths when XQA is disabled or ineligible. + +## 8. Environment Variables + +| Variable | Effect | +|----------|--------| +| `ORT_ENABLE_XQA` | `1` enables the XQA decode path for the non-quantized KV cache (default off; default on for quantized). | +| `ORT_DISABLE_FLASH_DECODE` | `1` disables the Flash Decoding split-KV optimization. | + +These are read once when the kernel is constructed. + +## 9. Testing + +CUDA parity tests live in +[onnxruntime/test/python/transformers/test_gqa.py](../../onnxruntime/test/python/transformers/test_gqa.py): + +- `TestXQAQuantizedParity` — XQA per-tensor int8 quantized decode parity. +- `TestXQAHeadSinkParity` — non-quantized XQA decode parity with a `head_sink` (attention sink) input. + +`TestXQAQuantizedParity` sets `ORT_ENABLE_XQA=1` to force the XQA path. `TestXQAHeadSinkParity` +instead clears `ORT_ENABLE_XQA` to validate that XQA is enabled by default when a `head_sink` input +is present. Both compare against a PyTorch reference (`attention_ref` with `smooth_softmax_ref`). diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_data.h b/onnxruntime/contrib_ops/cuda/bert/attention_data.h index 7d5c9bc1e221e..74aeaf1285e8a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_data.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_data.h @@ -206,6 +206,12 @@ struct GroupQueryAttentionData { // XQA buffer void* xqa_buffer = nullptr; size_t xqa_buffer_bytes = 0; + // FP32 per-head attention sink consumed by the XQA kernel (nullptr when no head_sink input). + // Either points to a PrePack-cached buffer or to scratch that is filled at launch time. + float* xqa_head_sink = nullptr; + // When true, head_sink was not prepacked (e.g. dynamic/non-initializer input) and the FP16/BF16 + // head_sink must be converted to xqa_head_sink (FP32 scratch) before launching XQA. + bool xqa_head_sink_needs_conversion = false; // Unfused fallback buffers (see LaunchUnfusedAttention in unfused_attention.h): // unfused_q_bnsh : [B, N_q, S_q, H] (Q transposed from BSNH to BNSH) diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc index 3aa6351e457e6..cc820b046f966 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc @@ -151,7 +151,9 @@ void AttentionKernelDebugInfo::Print(const char* operator_name, } sstream << " SdpaKernel="; - if (use_flash_attention.has_value() && use_flash_attention.value()) { + if (use_xqa.has_value() && use_xqa.value()) { + sstream << "XQA"; + } else if (use_flash_attention.has_value() && use_flash_attention.value()) { sstream << "FLASH_ATTENTION"; #if USE_LEAN_ATTENTION } else if (use_lean_attention.has_value() && use_lean_attention.value()) { diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h index 99b7437238d53..1f4737207cb69 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h @@ -8,6 +8,7 @@ namespace onnxruntime { struct AttentionKernelDebugInfo { + std::optional use_xqa = std::nullopt; std::optional use_flash_attention = std::nullopt; std::optional use_lean_attention = std::nullopt; std::optional use_efficient_attention = std::nullopt; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index c2ec19147fc9f..60235024b9118 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -76,6 +76,7 @@ REGISTER_KERNEL_TYPED(BFloat16, uint8_t) #endif constexpr const char* kDisableFlashDecode = "ORT_DISABLE_FLASH_DECODE"; +constexpr int kHeadSinkInputIndex = 11; // Group Query Attention (GQA) Operator // @@ -110,8 +111,15 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) kv_cache_bit_width_ = static_cast(info.GetAttrOrDefault("kv_cache_bit_width", 0)); bool is_quantized = (k_quant_type_ != KVQuantizationType::NONE || v_quant_type_ != KVQuantizationType::NONE); - int default_enable_xqa = is_quantized ? 1 : 0; - enable_xqa_ = (std::is_same_v || std::is_same_v) && ParseEnvironmentVariableWithDefault("ORT_ENABLE_XQA", default_enable_xqa) != 0; + // XQA enablement: + // - An explicit ORT_ENABLE_XQA overrides everything (1 = on, 0 = off, including the head_sink default-on path). + // - When unset, XQA defaults on for the quantized KV cache path and off for the non-quantized path + // (the non-quantized head_sink decode path is additionally enabled per-Run in ComputeInternal). + constexpr bool kIsFp16OrBf16 = std::is_same_v || std::is_same_v; + const int xqa_env = ParseEnvironmentVariableWithDefault("ORT_ENABLE_XQA", -1); // -1 means unset + xqa_force_disabled_ = (xqa_env == 0); + const int effective_enable_xqa = (xqa_env == -1) ? (is_quantized ? 1 : 0) : xqa_env; + enable_xqa_ = kIsFp16OrBf16 && (effective_enable_xqa != 0); kernel_options_ = this->GetAttentionKernelOptions(); @@ -121,7 +129,6 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) disable_memory_efficient_attention_ = !kernel_options_->UseEfficientAttention(); // cuDNN SDPA (cudnn_frontend) supports FP16 and BF16 and is auto-preferred on SM>=90. - constexpr bool kIsFp16OrBf16 = std::is_same::value || std::is_same::value; enable_cudnn_flash_attention_ = kIsFp16OrBf16 && kernel_options_->UseCudnnFlashAttention(); auto_enable_cudnn_flash_attention_ = kIsFp16OrBf16 && kernel_options_->AllowCudnnFlashAttentionAuto(); @@ -133,6 +140,55 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) disable_flash_decode_ = ParseEnvironmentVariableWithDefault(kDisableFlashDecode, false); } +template +Status GroupQueryAttention::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool& is_packed, PrePackedWeights* prepacked_weights) { + ORT_UNUSED_PARAMETER(prepacked_weights); + // Keep is_packed=false so the original fp16/bf16 head_sink remains available to the Flash/fallback + // paths (which are used when XQA is disabled or ineligible). We only cache an extra FP32 copy for XQA. + is_packed = false; + + if (input_idx != kHeadSinkInputIndex) { + return Status::OK(); + } + + // XQA consumes the attention sink as FP32. When head_sink is a constant initializer, convert it once + // here into a cached device buffer (xqa_head_sink_) to avoid a per-launch conversion. Dynamic / + // non-initializer head_sink inputs are not prepacked and fall back to the per-launch scratch path. + if constexpr (std::is_same_v || std::is_same_v) { + const auto& shape = tensor.Shape(); + ORT_RETURN_IF_NOT(shape.NumDimensions() == 1, + "head_sink must be a 1D tensor, got ", shape.NumDimensions(), " dimensions"); + ORT_RETURN_IF_NOT(shape[0] == num_heads_, + "head_sink dimension 0 must be equal to the num heads, got ", shape[0]); + ORT_RETURN_IF_NOT(tensor.IsDataType(), "head_sink type must match GroupQueryAttention input type"); + + // Derive the element count from the tensor itself (one sink per head) rather than num_heads_. + const int head_sink_count = static_cast(shape.Size()); + const size_t head_sink_bytes = tensor.SizeInBytes(); + const void* head_sink_data = tensor.DataRaw(); + IAllocatorUniquePtr head_sink_gpu; + cudaStream_t stream = cudaStreamLegacy; + + if (tensor.Location().device.Type() == OrtDevice::CPU) { + head_sink_gpu = IAllocator::MakeUniquePtr(alloc, head_sink_bytes, true); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(head_sink_gpu.get(), head_sink_data, head_sink_bytes, + cudaMemcpyHostToDevice, stream)); + head_sink_data = head_sink_gpu.get(); + } + + xqa_head_sink_ = IAllocator::MakeUniquePtr(alloc, static_cast(head_sink_count), true); + using CudaT = typename onnxruntime::cuda::OrtToCudaType::type; + ORT_RETURN_IF_ERROR(LaunchConvertHeadSinkToFloat( + reinterpret_cast(head_sink_data), xqa_head_sink_.get(), head_sink_count, stream, + GetDeviceProp().maxThreadsPerBlock)); + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); + xqa_head_sink_count_ = head_sink_count; + } + + return Status::OK(); +} + // ComputeInternal executes the GQA kernel. // // Inputs: @@ -338,16 +394,25 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons // 3. Sequence length is 1. // 4. Past and Present KV cache share the same buffer (required for XQA specific memory access). // 5. No Softcap (XQA doesn't support softcap). - // 6. Standard Softmax (no smooth softmax). + // 6. Standard Softmax, or smooth softmax represented by a head_sink tensor. // 7. No local window attention (global attention only). - if (enable_xqa_ && + const bool use_xqa_attention_sinks = head_sink != nullptr && !is_inputs_quantized; + const bool is_xqa_smooth_softmax_supported = !parameters.use_smooth_softmax || use_xqa_attention_sinks; + // XQA is opt-in for the non-quantized path (ORT_ENABLE_XQA), but a head_sink (attention sink) input + // signals a GPT-OSS style decode model that benefits from XQA, so enable it by default in that case. + // An explicit ORT_ENABLE_XQA=0 (xqa_force_disabled_) still wins and turns XQA off entirely. + // The dtype guard mirrors enable_xqa_ (XQA only supports fp16/bf16); ineligible cases fall back below. + constexpr bool kIsFp16OrBf16 = std::is_same_v || std::is_same_v; + const bool xqa_enabled_for_run = + !xqa_force_disabled_ && (enable_xqa_ || (kIsFp16OrBf16 && use_xqa_attention_sinks)); + if (xqa_enabled_for_run && (device_prop.major >= 8) && !parameters.is_first_prompt && parameters.sequence_length == 1 && parameters.kv_sequence_length > 0 && // Shared KV (kv_seq=0) has no new K/V to append parameters.past_present_share_buffer && parameters.softcap == 0.0f && - !parameters.use_smooth_softmax && + is_xqa_smooth_softmax_supported && parameters.local_window_size == -1) { int group_size = parameters.num_heads / parameters.kv_num_heads; @@ -389,23 +454,43 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons assert(xqa_internal_bytes > 0); // Calculate additional scratch needed for manual RoPE/Append in ExtremeDecoding size_t xqa_total_bytes = xqa_internal_bytes; + size_t q_bytes = 0; + size_t k_bytes = 0; if (parameters.do_rotary) { // 1. Q_rotated buffer: B * N * H * sizeof(T) (if rotary) // 2. K_rotated buffer: B * Nk * H * sizeof(T) (if rotary) size_t element_size = sizeof(CudaT); - size_t q_bytes = parameters.batch_size * parameters.num_heads * parameters.head_size * element_size; - size_t k_bytes = parameters.batch_size * parameters.kv_num_heads * parameters.head_size * element_size; + q_bytes = parameters.batch_size * parameters.num_heads * parameters.head_size * element_size; + k_bytes = parameters.batch_size * parameters.kv_num_heads * parameters.head_size * element_size; q_bytes = (q_bytes + 255) / 256 * 256; k_bytes = (k_bytes + 255) / 256 * 256; xqa_total_bytes += q_bytes + k_bytes; } + const bool use_prepacked_xqa_head_sink = + use_xqa_attention_sinks && xqa_head_sink_ != nullptr && xqa_head_sink_count_ == parameters.num_heads; + const bool convert_xqa_head_sink = use_xqa_attention_sinks && !use_prepacked_xqa_head_sink; + size_t xqa_head_sink_bytes = 0; + if (convert_xqa_head_sink) { + // No prepacked FP32 head_sink (dynamic input): reserve scratch for the per-launch conversion. + xqa_head_sink_bytes = parameters.num_heads * sizeof(float); + xqa_head_sink_bytes = (xqa_head_sink_bytes + 255) / 256 * 256; + xqa_total_bytes += xqa_head_sink_bytes; + } xqa_scratch_buffer = this->GetScratchBuffer(xqa_total_bytes, GetComputeStream(context)); data.xqa_buffer = xqa_scratch_buffer.get(); data.xqa_buffer_bytes = xqa_internal_bytes; + char* xqa_extra_buffer = reinterpret_cast(data.xqa_buffer) + xqa_internal_bytes; if (parameters.do_rotary) { - data.qkv_buffer = reinterpret_cast(reinterpret_cast(data.xqa_buffer) + xqa_internal_bytes); + data.qkv_buffer = reinterpret_cast(xqa_extra_buffer); + xqa_extra_buffer += q_bytes + k_bytes; + } + if (use_prepacked_xqa_head_sink) { + data.xqa_head_sink = xqa_head_sink_.get(); + } else if (convert_xqa_head_sink) { + data.xqa_head_sink = reinterpret_cast(xqa_extra_buffer); + data.xqa_head_sink_needs_conversion = true; } } } @@ -606,6 +691,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons if (kernel_options_->AllowDebugInfo()) { AttentionKernelDebugInfo debug_info; + debug_info.use_xqa = data.use_xqa; debug_info.use_flash_attention = data.use_flash_attention; debug_info.use_efficient_attention = data.use_memory_efficient_attention; debug_info.use_cudnn_flash_attention = data.use_cudnn_sdpa; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h index 34847983ad7de..d5b980bdca290 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h @@ -21,6 +21,9 @@ class GroupQueryAttention final : public CudaKernel { GroupQueryAttention(const OpKernelInfo& info); Status ComputeInternal(OpKernelContext* context) const override; + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool& is_packed, PrePackedWeights* prepacked_weights) override; + protected: int num_heads_; // number of attention heads int kv_num_heads_; // different for k and v for group query attention @@ -36,6 +39,7 @@ class GroupQueryAttention final : public CudaKernel { bool disable_memory_efficient_attention_; bool disable_flash_decode_; bool enable_xqa_; + bool xqa_force_disabled_; // True when ORT_ENABLE_XQA=0 is explicitly set (overrides default-on paths). bool enable_cudnn_flash_attention_; // cuDNN SDPA explicitly enabled (env / sdpa_kernel) bool auto_enable_cudnn_flash_attention_; // auto-prefer cuDNN SDPA on SM>=90 when no explicit kernel pinned @@ -45,6 +49,9 @@ class GroupQueryAttention final : public CudaKernel { static constexpr int kZerosCount = 256; // In prompt case we create a zero buffer of size 256 for seqlen (assume batch_size <= 256) IAllocatorUniquePtr zeros_; + // FP32 head_sink cached in PrePack for the XQA path (empty when head_sink is not a constant initializer). + IAllocatorUniquePtr xqa_head_sink_; + int xqa_head_sink_count_ = 0; // Number of elements in xqa_head_sink_ (0 when not prepacked). const AttentionKernelOptions* kernel_options_; }; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index 1c39de01fef66..6a55f18bd939a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -68,6 +68,26 @@ namespace cuda { // QKV Preprocessing Helpers // ============================================================================ +template +__global__ void ConvertHeadSinkToFloatKernel(const T* input, float* output, int count) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < count) { + output[i] = static_cast(input[i]); + } +} + +template +Status LaunchConvertHeadSinkToFloat( + const T* input, + float* output, + int count, + cudaStream_t stream, + int max_threads_per_block) { + int blocks = (count + max_threads_per_block - 1) / max_threads_per_block; + ConvertHeadSinkToFloatKernel<<>>(input, output, count); + return CUDA_CALL(cudaGetLastError()); +} + // Internal helper to get Q, K, V pointers, handling packed input // // This function orchestrates the preparation of Q, K, and V tensors for attention kernels. @@ -655,6 +675,13 @@ Status ExtremeDecoding( void* xqa_workspace = data.xqa_buffer; size_t xqa_workspace_size = data.xqa_buffer_bytes; + if (data.xqa_head_sink_needs_conversion) { + ORT_ENFORCE(data.xqa_head_sink != nullptr, "XQA head_sink conversion buffer was not allocated."); + ORT_ENFORCE(data.head_sink != nullptr, "XQA head_sink input was not available for conversion."); + ORT_RETURN_IF_ERROR(LaunchConvertHeadSinkToFloat( + data.head_sink, data.xqa_head_sink, num_heads, stream, device_prop.maxThreadsPerBlock)); + } + constexpr bool is_fp8 = std::is_same::value; using onnxruntime::contrib::cuda::XqaQuantType; // 5. Launch XQA @@ -673,6 +700,7 @@ Status ExtremeDecoding( scale, past_bsnh, data.past_seq_lens, + data.xqa_head_sink, data.k_scale, // kv_cache_scale // Map cache type to XqaQuantType: NONE->kNone, Float8E4M3FN->kFp8, int8->kInt8 (parameters.k_quant_type == KVQuantizationType::NONE) ? XqaQuantType::kNone : (is_fp8 ? XqaQuantType::kFp8 : XqaQuantType::kInt8), @@ -1316,6 +1344,20 @@ template struct GroupQueryAttentionData; template struct GroupQueryAttentionData<__nv_bfloat16, __nv_bfloat16>; template struct GroupQueryAttentionData; +template Status LaunchConvertHeadSinkToFloat( + const half* input, + float* output, + int count, + cudaStream_t stream, + int max_threads_per_block); + +template Status LaunchConvertHeadSinkToFloat<__nv_bfloat16>( + const __nv_bfloat16* input, + float* output, + int count, + cudaStream_t stream, + int max_threads_per_block); + template Status QkvToContext( const cudaDeviceProp& device_prop, cublasHandle_t& cublas, diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h index 89945b20fcfb3..348dc0832d3ba 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h @@ -28,6 +28,14 @@ Status LaunchUnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* unp const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size, cudaStream_t stream, const int max_threads_per_block); +template +Status LaunchConvertHeadSinkToFloat( + const T* input, + float* output, + int count, + cudaStream_t stream, + int max_threads_per_block); + // ============================================================================ // GQABufferRequirements: Centralized buffer size calculation // ============================================================================ diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_impl_gen.cuh b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_impl_gen.cuh index d132fba85988c..cd4088cf757ba 100644 --- a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_impl_gen.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_impl_gen.cuh @@ -56,6 +56,7 @@ inline Status Launch( [[maybe_unused]] const float scale, [[maybe_unused]] const bool is_bsnh, [[maybe_unused]] const int* past_seq_lens, + [[maybe_unused]] const float* attention_sinks, [[maybe_unused]] const float* kv_cache_scale, [[maybe_unused]] void* workspace, [[maybe_unused]] size_t workspace_size) { @@ -97,7 +98,7 @@ inline Status Launch( scale, out_ptr, q_ptr, - nullptr, // attentionSinks + attention_sinks, k_ptr, v_ptr, is_bsnh, diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader.h b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader.h index 8439c19687097..ee4fbc88982f8 100644 --- a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader.h +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader.h @@ -34,9 +34,10 @@ Status LaunchXQAKernel( const int head_size, const int max_seq_len, // Max sequence length of cache const float scale, - const bool is_bsnh, // Layout of KV cache - const int* past_seq_lens, // Past sequence lengths [BatchSize] - const float* kv_cache_scale, // KV cache dequant scale (nullptr for FP16/BF16, per-tensor float for INT8) + const bool is_bsnh, // Layout of KV cache + const int* past_seq_lens, // Past sequence lengths [BatchSize] + const float* attention_sinks, // Attention sink per query head, nullptr if not used + const float* kv_cache_scale, // KV cache dequant scale (nullptr for FP16/BF16, per-tensor float for INT8) const XqaQuantType kv_quant_type, void* workspace = nullptr, // Scratch memory size_t workspace_size = 0 // Size of scratch memory diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16.cu b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16.cu index 4c6731b10fe77..4a2d22938d48d 100644 --- a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16.cu +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16.cu @@ -26,6 +26,7 @@ Status LaunchXQAKernelImpl( const float scale, const bool is_bsnh, const int* past_seq_lens, + const float* attention_sinks, const float* kv_cache_scale, const XqaQuantType kv_quant_type, void* workspace, @@ -49,6 +50,7 @@ Status LaunchXQAKernelImpl( const float scale, const bool is_bsnh, const int* past_seq_lens, + const float* attention_sinks, const float* kv_cache_scale, const XqaQuantType kv_quant_type, void* workspace, @@ -72,6 +74,7 @@ Status LaunchXQAKernelImpl( const float scale, const bool is_bsnh, const int* past_seq_lens, + const float* attention_sinks, const float* kv_cache_scale, const XqaQuantType kv_quant_type, void* workspace, @@ -118,12 +121,14 @@ Status LaunchXQAKernel<__nv_bfloat16>( const float scale, const bool is_bsnh, const int* past_seq_lens, + const float* attention_sinks, const float* kv_cache_scale, const XqaQuantType kv_quant_type, void* workspace, size_t workspace_size) { // Dispatch to INT8 path if requested if (kv_quant_type == XqaQuantType::kInt8) { + ORT_RETURN_IF(attention_sinks != nullptr, "XQA attention sinks are not supported with INT8 KV cache."); return LaunchXQAInt8KernelBF16(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); } @@ -131,15 +136,15 @@ Status LaunchXQAKernel<__nv_bfloat16>( if (head_size == 256) { return H256::LaunchXQAKernelImpl<__nv_bfloat16>( device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, - max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, kv_quant_type, workspace, workspace_size); + max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, kv_quant_type, workspace, workspace_size); } else if (head_size == 128) { return H128::LaunchXQAKernelImpl<__nv_bfloat16>( device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, - max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, kv_quant_type, workspace, workspace_size); + max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, kv_quant_type, workspace, workspace_size); } else if (head_size == 64) { return H64::LaunchXQAKernelImpl<__nv_bfloat16>( device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, - max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, kv_quant_type, workspace, workspace_size); + max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, kv_quant_type, workspace, workspace_size); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA only supports head_size=64, 128, or 256. Input has ", head_size); } diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_128.cu b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_128.cu index 7572986d14632..a8ea76ab23b8b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_128.cu +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_128.cu @@ -26,6 +26,7 @@ template Status HEAD_DIM_NAMESPACE::LaunchXQAKernelImpl<__nv_bfloat16>( const float scale, const bool is_bsnh, const int* past_seq_lens, + const float* attention_sinks, const float* kv_cache_scale, const XqaQuantType kv_quant_type, void* workspace, diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_256.cu b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_256.cu index 2706a9de32b14..79ddc2d0d7c34 100644 --- a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_256.cu +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_256.cu @@ -26,6 +26,7 @@ template Status HEAD_DIM_NAMESPACE::LaunchXQAKernelImpl<__nv_bfloat16>( const float scale, const bool is_bsnh, const int* past_seq_lens, + const float* attention_sinks, const float* kv_cache_scale, const XqaQuantType kv_quant_type, void* workspace, diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_64.cu b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_64.cu index 7bd8897fdfd93..c94f6b5fc0695 100644 --- a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_64.cu +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_64.cu @@ -26,6 +26,7 @@ template Status HEAD_DIM_NAMESPACE::LaunchXQAKernelImpl<__nv_bfloat16>( const float scale, const bool is_bsnh, const int* past_seq_lens, + const float* attention_sinks, const float* kv_cache_scale, const XqaQuantType kv_quant_type, void* workspace, diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_fp8_impl.cuh b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_fp8_impl.cuh index 481fcb63c1f8c..773b4810b6b30 100644 --- a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_fp8_impl.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_fp8_impl.cuh @@ -102,13 +102,13 @@ Status LaunchXQAFp8KernelBF16( int group_size = num_heads / kv_num_heads; switch (group_size) { case 4: - return grp4_bf16_fp8::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp4_bf16_fp8::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, nullptr, kv_cache_scale, workspace, workspace_size); case 8: - return grp8_bf16_fp8::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp8_bf16_fp8::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, nullptr, kv_cache_scale, workspace, workspace_size); case 16: - return grp16_bf16_fp8::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp16_bf16_fp8::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, nullptr, kv_cache_scale, workspace, workspace_size); case 32: - return grp32_bf16_fp8::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp32_bf16_fp8::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, nullptr, kv_cache_scale, workspace, workspace_size); default: return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA FP8 only supports group_size 4, 8, 16, 32. Input has ", group_size); } diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_impl.cuh b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_impl.cuh index c2d9c057c6e50..6a84d452f1384 100644 --- a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_impl.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_impl.cuh @@ -158,6 +158,7 @@ Status LaunchXQAKernelImpl( const float scale, const bool is_bsnh, const int* past_seq_lens, + const float* attention_sinks, const float* kv_cache_scale, const XqaQuantType kv_quant_type, void* workspace, @@ -179,6 +180,7 @@ Status LaunchXQAKernelImpl<__nv_bfloat16>( const float scale, const bool is_bsnh, const int* past_seq_lens, + const float* attention_sinks, const float* kv_cache_scale, const XqaQuantType kv_quant_type, void* workspace, @@ -187,6 +189,7 @@ Status LaunchXQAKernelImpl<__nv_bfloat16>( // Dispatch to INT8 path if requested if (kv_quant_type == XqaQuantType::kInt8) { + ORT_RETURN_IF(attention_sinks != nullptr, "XQA attention sinks are not supported with INT8 KV cache."); return LaunchXQAInt8KernelBF16(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, @@ -196,6 +199,7 @@ Status LaunchXQAKernelImpl<__nv_bfloat16>( #ifdef USE_FP8_KV_CACHE // Dispatch to FP8 path if requested if (kv_quant_type == XqaQuantType::kFp8) { + ORT_RETURN_IF(attention_sinks != nullptr, "XQA attention sinks are not supported with FP8 KV cache."); return LaunchXQAFp8KernelBF16(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, @@ -206,17 +210,17 @@ Status LaunchXQAKernelImpl<__nv_bfloat16>( int group_size = num_heads / kv_num_heads; switch (group_size) { case 1: - return grp1_bf16::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp1_bf16::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size); case 2: - return grp2_bf16::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp2_bf16::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size); case 4: - return grp4_bf16::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp4_bf16::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size); case 8: - return grp8_bf16::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp8_bf16::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size); case 16: - return grp16_bf16::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp16_bf16::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size); case 32: - return grp32_bf16::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp32_bf16::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size); default: return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA supports group_size 1, 2, 4, 8, 16, 32. Input has ", group_size); } diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_int8_impl.cuh b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_int8_impl.cuh index acec9aeed9973..0ad18e99c5841 100644 --- a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_int8_impl.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_int8_impl.cuh @@ -102,13 +102,13 @@ Status LaunchXQAInt8KernelBF16( int group_size = num_heads / kv_num_heads; switch (group_size) { case 4: - return grp4_bf16_int8::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp4_bf16_int8::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, nullptr, kv_cache_scale, workspace, workspace_size); case 8: - return grp8_bf16_int8::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp8_bf16_int8::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, nullptr, kv_cache_scale, workspace, workspace_size); case 16: - return grp16_bf16_int8::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp16_bf16_int8::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, nullptr, kv_cache_scale, workspace, workspace_size); case 32: - return grp32_bf16_int8::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp32_bf16_int8::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, nullptr, kv_cache_scale, workspace, workspace_size); default: return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA INT8 only supports group_size 4, 8, 16, 32. Input has ", group_size); } diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16.cu b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16.cu index 37b974a8a3e60..b8171392e0f50 100644 --- a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16.cu +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16.cu @@ -28,6 +28,7 @@ Status LaunchXQAKernelImpl( const float scale, const bool is_bsnh, const int* past_seq_lens, + const float* attention_sinks, const float* kv_cache_scale, const XqaQuantType kv_quant_type, void* workspace, @@ -52,6 +53,7 @@ Status LaunchXQAKernelImpl( const float scale, const bool is_bsnh, const int* past_seq_lens, + const float* attention_sinks, const float* kv_cache_scale, const XqaQuantType kv_quant_type, void* workspace, @@ -76,6 +78,7 @@ Status LaunchXQAKernelImpl( const float scale, const bool is_bsnh, const int* past_seq_lens, + const float* attention_sinks, const float* kv_cache_scale, const XqaQuantType kv_quant_type, void* workspace, @@ -101,6 +104,7 @@ Status LaunchXQAKernel( const float scale, const bool is_bsnh, const int* past_seq_lens, + const float* attention_sinks, const float* kv_cache_scale, const XqaQuantType kv_quant_type, void* workspace, @@ -112,15 +116,15 @@ Status LaunchXQAKernel( if (head_size == 256) { return H256::LaunchXQAKernelImpl( device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, - max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, kv_quant_type, workspace, workspace_size); + max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, kv_quant_type, workspace, workspace_size); } else if (head_size == 128) { return H128::LaunchXQAKernelImpl( device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, - max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, kv_quant_type, workspace, workspace_size); + max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, kv_quant_type, workspace, workspace_size); } else if (head_size == 64) { return H64::LaunchXQAKernelImpl( device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, - max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, kv_quant_type, workspace, workspace_size); + max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, kv_quant_type, workspace, workspace_size); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA only supports head_size=64, 128, or 256. Input has ", head_size); } @@ -186,6 +190,7 @@ template Status LaunchXQAKernel( const float scale, const bool is_bsnh, const int* past_seq_lens, + const float* attention_sinks, const float* kv_cache_scale, const XqaQuantType kv_quant_type, void* workspace, diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_128.cu b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_128.cu index 87304cfd1adc2..06c8b0ce0ea2a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_128.cu +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_128.cu @@ -26,6 +26,7 @@ template Status HEAD_DIM_NAMESPACE::LaunchXQAKernelImpl( const float scale, const bool is_bsnh, const int* past_seq_lens, + const float* attention_sinks, const float* kv_cache_scale, const XqaQuantType kv_quant_type, void* workspace, diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_256.cu b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_256.cu index 3d070a87f87a8..756cc61cb9720 100644 --- a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_256.cu +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_256.cu @@ -26,6 +26,7 @@ template Status HEAD_DIM_NAMESPACE::LaunchXQAKernelImpl( const float scale, const bool is_bsnh, const int* past_seq_lens, + const float* attention_sinks, const float* kv_cache_scale, const XqaQuantType kv_quant_type, void* workspace, diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_64.cu b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_64.cu index 1664122dbc6d3..4b5b0fe4f17c9 100644 --- a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_64.cu +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_64.cu @@ -26,6 +26,7 @@ template Status HEAD_DIM_NAMESPACE::LaunchXQAKernelImpl( const float scale, const bool is_bsnh, const int* past_seq_lens, + const float* attention_sinks, const float* kv_cache_scale, const XqaQuantType kv_quant_type, void* workspace, diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_fp8_impl.cuh b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_fp8_impl.cuh index 5e18d21defb79..0a613ead5c16e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_fp8_impl.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_fp8_impl.cuh @@ -101,13 +101,13 @@ Status LaunchXQAFp8Kernel( int group_size = num_heads / kv_num_heads; switch (group_size) { case 4: - return grp4_fp8::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp4_fp8::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, nullptr, kv_cache_scale, workspace, workspace_size); case 8: - return grp8_fp8::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp8_fp8::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, nullptr, kv_cache_scale, workspace, workspace_size); case 16: - return grp16_fp8::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp16_fp8::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, nullptr, kv_cache_scale, workspace, workspace_size); case 32: - return grp32_fp8::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp32_fp8::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, nullptr, kv_cache_scale, workspace, workspace_size); default: return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA FP8 only supports group_size 4, 8, 16, 32. Input has ", group_size); } diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_impl.cuh b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_impl.cuh index 675beb3c92d0f..269b7956c0999 100644 --- a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_impl.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_impl.cuh @@ -158,6 +158,7 @@ Status LaunchXQAKernelImpl( const float scale, const bool is_bsnh, const int* past_seq_lens, + const float* attention_sinks, const float* kv_cache_scale, const XqaQuantType kv_quant_type, void* workspace, @@ -166,6 +167,7 @@ Status LaunchXQAKernelImpl( // Dispatch to INT8 path if requested if (kv_quant_type == XqaQuantType::kInt8) { + ORT_RETURN_IF(attention_sinks != nullptr, "XQA attention sinks are not supported with INT8 KV cache."); if constexpr (std::is_same::value) { return LaunchXQAInt8Kernel(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); } else { @@ -177,6 +179,7 @@ Status LaunchXQAKernelImpl( #ifdef USE_FP8_KV_CACHE // Dispatch to FP8 path if requested if (kv_quant_type == XqaQuantType::kFp8) { + ORT_RETURN_IF(attention_sinks != nullptr, "XQA attention sinks are not supported with FP8 KV cache."); if constexpr (std::is_same::value) { return LaunchXQAFp8Kernel(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); } else { @@ -189,17 +192,17 @@ Status LaunchXQAKernelImpl( int group_size = num_heads / kv_num_heads; switch (group_size) { case 1: - return grp1_fp16::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp1_fp16::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size); case 2: - return grp2_fp16::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp2_fp16::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size); case 4: - return grp4_fp16::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp4_fp16::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size); case 8: - return grp8_fp16::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp8_fp16::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size); case 16: - return grp16_fp16::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp16_fp16::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size); case 32: - return grp32_fp16::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp32_fp16::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size); default: return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA supports group_size 1, 2, 4, 8, 16, 32. Input has ", group_size); } diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_int8_impl.cuh b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_int8_impl.cuh index f3a1fcd8a8e63..ebeccfb60c7ba 100644 --- a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_int8_impl.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_int8_impl.cuh @@ -101,13 +101,13 @@ Status LaunchXQAInt8Kernel( int group_size = num_heads / kv_num_heads; switch (group_size) { case 4: - return grp4_int8::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp4_int8::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, nullptr, kv_cache_scale, workspace, workspace_size); case 8: - return grp8_int8::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp8_int8::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, nullptr, kv_cache_scale, workspace, workspace_size); case 16: - return grp16_int8::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp16_int8::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, nullptr, kv_cache_scale, workspace, workspace_size); case 32: - return grp32_int8::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp32_int8::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, nullptr, kv_cache_scale, workspace, workspace_size); default: return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA INT8 only supports group_size 4, 8, 16, 32. Input has ", group_size); } diff --git a/onnxruntime/test/python/transformers/env_var_helper.py b/onnxruntime/test/python/transformers/env_var_helper.py new file mode 100644 index 0000000000000..77d42291ce12e --- /dev/null +++ b/onnxruntime/test/python/transformers/env_var_helper.py @@ -0,0 +1,25 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import os +from contextlib import contextmanager + + +@contextmanager +def scoped_env_var(name: str, value: str): + """Temporarily set an environment variable, restoring the previous value on exit. + + Keeps tests order-independent by ensuring env-var mutations do not leak into + later tests running in the same process. + """ + previous = os.environ.get(name) + os.environ[name] = value + try: + yield + finally: + if previous is None: + os.environ.pop(name, None) + else: + os.environ[name] = previous diff --git a/onnxruntime/test/python/transformers/gqa_test_helper.py b/onnxruntime/test/python/transformers/gqa_test_helper.py index 7f0d50a7ac8ed..d3dd86ea9bbc6 100644 --- a/onnxruntime/test/python/transformers/gqa_test_helper.py +++ b/onnxruntime/test/python/transformers/gqa_test_helper.py @@ -310,6 +310,7 @@ def __init__( v_quant_type: str = "NONE", kv_cache_type: str = "float16", share_kv_scale: bool = False, + has_head_sink: bool = False, ): super().__init__( "GroupQueryAttention", @@ -341,6 +342,7 @@ def __init__( self.k_quant_type = k_quant_type self.v_quant_type = v_quant_type self.share_kv_scale = share_kv_scale + self.has_head_sink = has_head_sink # Determine bit width from cache type if applicable if kv_cache_type == "int4": self.kv_cache_bit_width = 4 @@ -359,6 +361,8 @@ def shape_dict(self): "seqlens_k": (self.batch_size,), } ) + if self.has_head_sink: + shapes["head_sink"] = (self.num_heads,) # Note: We don't adjust shapes for int4 here because the parent's random_inputs # creates float tensors first, then quantization will pack them return shapes @@ -371,6 +375,8 @@ def random_inputs(self): "seqlens_k": k_seqlens - 1, } ) + if self.has_head_sink: + feeds["head_sink"] = torch.rand((self.num_heads,), device=self.device, dtype=self.dtype) # Generate quantized cache and scales if quantization is enabled if self.k_quant_type != "NONE": @@ -423,7 +429,7 @@ def create_group_query_attention_onnx_model(config: GroupQueryAttentionConfig): "sin_cache" if config.do_rotary else "", "", # position_ids (optional, not used in benchmark) "", # attention_bias (optional, not used in benchmark) - "", # head_sink (optional, not used in benchmark) + "head_sink" if config.has_head_sink else "", "k_scale" if config.k_quant_type != "NONE" else "", "v_scale" if config.v_quant_type != "NONE" else "", ] @@ -512,6 +518,9 @@ def create_group_query_attention_onnx_model(config: GroupQueryAttentionConfig): helper.make_tensor_value_info("sin_cache", float_type, list(shape_dict["sin_cache"])), ] + if config.has_head_sink: + graph_input.append(helper.make_tensor_value_info("head_sink", float_type, list(shape_dict["head_sink"]))) + # Add scale inputs for quantization # Shape depends on quantization type: # - PER_TENSOR: [1] diff --git a/onnxruntime/test/python/transformers/profile_gqa.py b/onnxruntime/test/python/transformers/profile_gqa.py index e11184f31673e..49ce26b5126b8 100644 --- a/onnxruntime/test/python/transformers/profile_gqa.py +++ b/onnxruntime/test/python/transformers/profile_gqa.py @@ -20,10 +20,18 @@ """ import argparse +import os import time import torch -from test_sparse_attention import GroupQueryAttentionConfig, OrtGroupQueryAttention + +try: + from gqa_test_helper import GroupQueryAttentionConfig, OrtGroupQueryAttention +except ImportError: + import sys + + sys.path.insert(0, os.path.dirname(__file__)) + from gqa_test_helper import GroupQueryAttentionConfig, OrtGroupQueryAttention # Optional NVTX support for nsys range markers try: @@ -62,6 +70,7 @@ def create_gqa_config( local_window_size: int = -1, is_packed_qkv: bool = False, do_rotary: bool = True, + has_head_sink: bool = False, device: str = "cuda", share_kv_scale: bool = False, ) -> GroupQueryAttentionConfig: @@ -103,6 +112,7 @@ def create_gqa_config( dtype=dtype, is_packed_qkv=is_packed_qkv, use_smooth_softmax=False, + has_head_sink=has_head_sink, device=device, k_quant_type=k_quant_type, v_quant_type=v_quant_type, @@ -147,6 +157,7 @@ def run_comparison(args): print(f"{'=' * 70}") print(f"Config: batch={args.batch_size}, seq_len={args.sequence_length}, past_seq={args.past_sequence_length}") print(f" num_heads={args.num_heads}, kv_heads={args.kv_num_heads}, head_size={args.head_size}") + print(f" packed_qkv={args.is_packed_qkv}, rotary={not args.no_rotary}, head_sink={args.head_sink}") print(f" warmup={args.warmup}, repeat={args.repeat}") print(f"{'=' * 70}\n") @@ -166,6 +177,7 @@ def run_comparison(args): local_window_size=args.local_window_size, is_packed_qkv=args.is_packed_qkv, do_rotary=not args.no_rotary, + has_head_sink=args.head_sink, share_kv_scale=args.share_kv_scale, ) avg_ms = benchmark_gqa(config, warmup=args.warmup, repeat=args.repeat, mode=mode) @@ -203,6 +215,7 @@ def main(): parser.add_argument("--warmup", type=int, default=50, help="Warmup iterations") parser.add_argument("--repeat", type=int, default=100, help="Benchmark iterations") parser.add_argument("--is-packed-qkv", action="store_true", help="Use packed QKV") + parser.add_argument("--head-sink", action="store_true", help="Add a head_sink input") parser.add_argument("--no-rotary", action="store_true", help="Disable rotary embeddings") parser.add_argument("--share-kv-scale", action="store_true", help="Share KV scale tensor for XQA") diff --git a/onnxruntime/test/python/transformers/test_gqa.py b/onnxruntime/test/python/transformers/test_gqa.py index 310018c3395ae..529eae1494e94 100644 --- a/onnxruntime/test/python/transformers/test_gqa.py +++ b/onnxruntime/test/python/transformers/test_gqa.py @@ -22,6 +22,7 @@ import torch from cuda_plugin_ep_helper import get_cuda_provider_name, resolve_cuda_plugin_ep from einops import rearrange, repeat +from env_var_helper import scoped_env_var # --- ONNX and Torch/Numpy Dtype Mappings --- from gqa_test_helper import ( @@ -87,6 +88,10 @@ class GQAConfig: softcap: float = 0.0 use_smooth_softmax: bool = False has_head_sink: bool = False + # When True, head_sink is baked into the model as a constant initializer (instead of a runtime + # input). This exercises the GroupQueryAttention::PrePack path that converts the constant + # head_sink to a cached FP32 XQA buffer. + head_sink_as_initializer: bool = False kv_cache_type: str = "" share_buffer: bool = True share_kv_scale: bool = False @@ -190,12 +195,23 @@ def apply_rotary_embedding(x, cos, sin, pos, interleaved, device="cpu"): # ################################################################################################# +def make_head_sink_initializer(head_sink, ort_type, num_heads): + """Build a constant head_sink initializer (fp16/bf16) so GroupQueryAttention::PrePack runs. + + The 16-bit float bits are reinterpreted as uint16 and stored as raw bytes, which works for + both float16 and bfloat16 without relying on numpy bfloat16 support. + """ + raw = head_sink.detach().reshape(num_heads).cpu().contiguous().view(torch.uint16).numpy().tobytes() + return helper.make_tensor(name="head_sink", data_type=ort_type, dims=[num_heads], vals=raw, raw=True) + + def create_gqa_node_and_io( config: GQAConfig, ort_type, share_buffer=True, is_past=False, output_qk: int = 0, # CUDA does not support output_qk for GQA + head_sink_values=None, ): if is_past: if share_buffer: @@ -211,6 +227,8 @@ def create_gqa_node_and_io( if not config.kv_cache_type: config.kv_cache_type = "float16" if ort_type == TensorProto.FLOAT16 else "bfloat16" + initializers = [] + # --- Node Definition --- outputs = [ "output", @@ -348,7 +366,11 @@ def create_gqa_node_and_io( ) ) if config.has_head_sink: - graph_input.append(helper.make_tensor_value_info("head_sink", ort_type, [config.num_heads])) + if config.head_sink_as_initializer and head_sink_values is not None: + # Constant initializer (not a graph input) so ORT treats it as a constant and PrePack runs. + initializers.append(make_head_sink_initializer(head_sink_values, ort_type, config.num_heads)) + else: + graph_input.append(helper.make_tensor_value_info("head_sink", ort_type, [config.num_heads])) # --- Graph Outputs --- output_k_shape = [config.batch_size, config.kv_num_heads, present_kv_seqlen, config.head_size] @@ -372,19 +394,23 @@ def create_gqa_node_and_io( ) ) - return node, graph_input, graph_output + return node, graph_input, graph_output, initializers def create_group_query_attention_graph_prompt(config: GQAConfig, ort_type, share_buffer=True): - node, graph_input, graph_output = create_gqa_node_and_io(config, ort_type, share_buffer, is_past=False) - graph = helper.make_graph([node], "GroupQueryAttention_Graph", graph_input, graph_output) + node, graph_input, graph_output, initializers = create_gqa_node_and_io( + config, ort_type, share_buffer, is_past=False + ) + graph = helper.make_graph([node], "GroupQueryAttention_Graph", graph_input, graph_output, initializer=initializers) model = helper.make_model(graph) return model.SerializeToString() -def create_group_query_attention_graph_past(config: GQAConfig, ort_type, share_buffer=True): - node, graph_input, graph_output = create_gqa_node_and_io(config, ort_type, share_buffer, is_past=True) - graph = helper.make_graph([node], "GroupQueryAttention_Graph", graph_input, graph_output) +def create_group_query_attention_graph_past(config: GQAConfig, ort_type, share_buffer=True, head_sink_values=None): + node, graph_input, graph_output, initializers = create_gqa_node_and_io( + config, ort_type, share_buffer, is_past=True, head_sink_values=head_sink_values + ) + graph = helper.make_graph([node], "GroupQueryAttention_Graph", graph_input, graph_output, initializer=initializers) model = helper.make_model(graph) return model.SerializeToString() @@ -605,10 +631,12 @@ def gqa_past_func( if not config.kv_cache_type: config.kv_cache_type = "float16" if ort_type == TensorProto.FLOAT16 else "bfloat16" + head_sink_as_initializer = config.has_head_sink and config.head_sink_as_initializer and head_sink is not None onnx_model_str = create_group_query_attention_graph_past( config=config, ort_type=ort_type, share_buffer=share_buffer, + head_sink_values=head_sink if head_sink_as_initializer else None, ) q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1)) @@ -671,7 +699,7 @@ def gqa_past_func( if config.has_attention_bias and attention_bias is not None: bind_tensor(io_binding, "attention_bias", attention_bias, device, ort_type) - if config.has_head_sink and head_sink is not None: + if config.has_head_sink and head_sink is not None and not head_sink_as_initializer: bind_tensor(io_binding, "head_sink", head_sink, device, ort_type) # 6. Quantization @@ -1948,6 +1976,11 @@ def has_flash_attention(bf16=False): return True +def has_xqa(): + # The XQA decode kernels require Ampere (SM 8.0) or newer. + return has_cuda_device(80) + + rtol = { "fp16": 5e-3, "bf16": 5e-2, @@ -1988,17 +2021,17 @@ def test_gqa_prompt_flash_attention(self, name, config): print("-" * 20) print(f"test_case: {name}\n{config}") - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" - parity_check_gqa_prompt( - config=config, - ep="CUDAExecutionProvider", - device="cuda", - torch_type=torch.float16, - ort_type=TensorProto.FLOAT16, - causal=True, - rtol=rtol["fp16"], - atol=atol["fp16"], - ) + with scoped_env_var("ORT_DISABLE_FLASH_ATTENTION", "0"): + parity_check_gqa_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) @parameterized.expand(gqa_cuda_past_test_cases()) def test_gqa_past_flash_attention(self, name, config): @@ -2006,17 +2039,17 @@ def test_gqa_past_flash_attention(self, name, config): print("-" * 20) print(f"test_case: {name}\n{config}") - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" - parity_check_gqa_past( - config=config, - ep="CUDAExecutionProvider", - device="cuda", - torch_type=torch.float16, - ort_type=TensorProto.FLOAT16, - causal=True, - rtol=rtol["fp16"], - atol=atol["fp16"], - ) + with scoped_env_var("ORT_DISABLE_FLASH_ATTENTION", "0"): + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) @unittest.skipIf(not has_flash_attention(bf16=True), "Flash Attention is not available, skipping tests.") @@ -2037,17 +2070,17 @@ def test_gqa_prompt_flash_attention_bf16(self, name, config): print(f"test_case: {name}\n{config}") config.kv_cache_type = "bfloat16" - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" - parity_check_gqa_prompt( - config=config, - ep="CUDAExecutionProvider", - device="cuda", - torch_type=torch.bfloat16, - ort_type=TensorProto.BFLOAT16, - causal=True, - rtol=rtol["bf16"], - atol=atol["bf16"], - ) + with scoped_env_var("ORT_DISABLE_FLASH_ATTENTION", "0"): + parity_check_gqa_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.bfloat16, + ort_type=TensorProto.BFLOAT16, + causal=True, + rtol=rtol["bf16"], + atol=atol["bf16"], + ) @parameterized.expand(gqa_cuda_past_test_cases()) def test_gqa_past_flash_attention_bf16(self, name, config): @@ -2059,17 +2092,17 @@ def test_gqa_past_flash_attention_bf16(self, name, config): print(f"test_case: {name}\n{config}") config.kv_cache_type = "bfloat16" - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" - parity_check_gqa_past( - config=config, - ep="CUDAExecutionProvider", - device="cuda", - torch_type=torch.bfloat16, - ort_type=TensorProto.BFLOAT16, - causal=True, - rtol=rtol["bf16"], - atol=atol["bf16"], - ) + with scoped_env_var("ORT_DISABLE_FLASH_ATTENTION", "0"): + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.bfloat16, + ort_type=TensorProto.BFLOAT16, + causal=True, + rtol=rtol["bf16"], + atol=atol["bf16"], + ) @unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping tests.") @@ -2100,17 +2133,17 @@ def test_gqa_quantized_prompt_bf16(self, name, config): self.manual_seed() - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" - parity_check_gqa_prompt( - config=config, - ep="CUDAExecutionProvider", - device="cuda", - torch_type=torch.bfloat16, - ort_type=TensorProto.BFLOAT16, - causal=True, - rtol=rtol[f"{config.kv_cache_type}_bf16"], - atol=atol[f"{config.kv_cache_type}_bf16"], - ) + with scoped_env_var("ORT_DISABLE_FLASH_ATTENTION", "0"): + parity_check_gqa_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.bfloat16, + ort_type=TensorProto.BFLOAT16, + causal=True, + rtol=rtol[f"{config.kv_cache_type}_bf16"], + atol=atol[f"{config.kv_cache_type}_bf16"], + ) @parameterized.expand(gqa_cuda_quantized_test_cases(is_past=True)) def test_gqa_quantized_past_bf16(self, name, config): @@ -2120,17 +2153,17 @@ def test_gqa_quantized_past_bf16(self, name, config): self.manual_seed() - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" - parity_check_gqa_past( - config=config, - ep="CUDAExecutionProvider", - device="cuda", - torch_type=torch.bfloat16, - ort_type=TensorProto.BFLOAT16, - causal=True, - rtol=rtol[f"{config.kv_cache_type}_bf16"], - atol=atol[f"{config.kv_cache_type}_bf16"], - ) + with scoped_env_var("ORT_DISABLE_FLASH_ATTENTION", "0"): + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.bfloat16, + ort_type=TensorProto.BFLOAT16, + causal=True, + rtol=rtol[f"{config.kv_cache_type}_bf16"], + atol=atol[f"{config.kv_cache_type}_bf16"], + ) @unittest.skipIf(not has_cuda_device(53), "Memory Efficient Attention is not available, skipping tests.") @@ -2141,17 +2174,17 @@ def test_gqa_prompt_memory_efficient(self, name, config): print("-" * 20) print(f"test_case: {name}\n{config}") - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" - parity_check_gqa_prompt( - config=config, - ep="CUDAExecutionProvider", - device="cuda", - torch_type=torch.float16, - ort_type=TensorProto.FLOAT16, - causal=True, - rtol=rtol["fp16"], - atol=atol["fp16"], - ) + with scoped_env_var("ORT_DISABLE_FLASH_ATTENTION", "1"): + parity_check_gqa_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) @parameterized.expand(gqa_cuda_past_test_cases(allow_head_sink=False)) def test_gqa_past_memory_efficient(self, name, config): @@ -2159,17 +2192,17 @@ def test_gqa_past_memory_efficient(self, name, config): print("-" * 20) print(f"test_case: {name}\n{config}") - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" - parity_check_gqa_past( - config=config, - ep="CUDAExecutionProvider", - device="cuda", - torch_type=torch.float16, - ort_type=TensorProto.FLOAT16, - causal=True, - rtol=rtol["fp16"], - atol=atol["fp16"], - ) + with scoped_env_var("ORT_DISABLE_FLASH_ATTENTION", "1"): + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) @unittest.skipIf(not has_cuda_device(80), "BF16 requires Ampere or higher GPU, skipping tests.") @@ -2180,17 +2213,17 @@ def test_gqa_past_memory_efficient_bf16(self, name, config): print("-" * 20) print(f"test_case: {name}\n{config}") - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" - parity_check_gqa_past( - config=config, - ep="CUDAExecutionProvider", - device="cuda", - torch_type=torch.bfloat16, - ort_type=TensorProto.BFLOAT16, - causal=True, - rtol=rtol["bf16"], - atol=atol["bf16"], - ) + with scoped_env_var("ORT_DISABLE_FLASH_ATTENTION", "1"): + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.bfloat16, + ort_type=TensorProto.BFLOAT16, + causal=True, + rtol=rtol["bf16"], + atol=atol["bf16"], + ) @unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping tests.") @@ -2200,8 +2233,8 @@ def test_gqa_padding_prompt_flash_attention(self): print("-" * 20) print("test_case: test_gqa_padding_prompt_flash_attention") - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" - parity_test_gqa_padding_prompt() + with scoped_env_var("ORT_DISABLE_FLASH_ATTENTION", "0"): + parity_test_gqa_padding_prompt() @unittest.skipIf(not has_cuda_device(53), "Memory Efficient Attention is not available, skipping tests.") @@ -2211,8 +2244,8 @@ def test_gqa_padding_prompt_memory_efficient_attention(self): print("-" * 20) print("test_case: test_gqa_padding_prompt_memory_efficient_attention") - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" - parity_test_gqa_padding_prompt() + with scoped_env_var("ORT_DISABLE_FLASH_ATTENTION", "1"): + parity_test_gqa_padding_prompt() # ################################################################################################# @@ -2320,8 +2353,106 @@ def tearDown(self): @parameterized.expand(gqa_xqa_test_cases()) def test_xqa_quantized_parity(self, name, config, torch_type, ort_type): """Test XQA per-tensor INT8 quantized parity.""" - os.environ["ORT_ENABLE_XQA"] = "1" + with scoped_env_var("ORT_ENABLE_XQA", "1"): + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch_type, + ort_type=ort_type, + causal=True, + rtol=rtol["int8_bf16"] if torch_type == torch.bfloat16 else rtol["int8_fp16"], + atol=atol["int8_bf16"] if torch_type == torch.bfloat16 else atol["int8_fp16"], + std=0.1, + ) + +def gqa_xqa_head_sink_test_cases(): + # Non-quantized global decode with a head_sink (attention sink) input. + # These configs exercise the XQA attention-sink path added for GPT-OSS style models: + # seq_len=1, shared KV buffer, no softcap, no local window, head_size in {64, 128}, + # and 64 % group_size == 0. + for torch_type, ort_type in [(torch.float16, TensorProto.FLOAT16), (torch.bfloat16, TensorProto.BFLOAT16)]: + for group_size in [1, 4, 8]: + for head_size in [64, 128]: + for rotary in [False, True]: + kv_num_heads = 4 + num_heads = kv_num_heads * group_size + config = GQAConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + num_heads=num_heads, + kv_num_heads=kv_num_heads, + head_size=head_size, + past_kv_sequence_length=4, + buffer_sequence_length=4 + 128, + rotary=rotary, + packed=False, + share_buffer=True, + has_head_sink=True, + ) + type_str = "bf16" if torch_type == torch.bfloat16 else "fp16" + rot_str = "rot" if rotary else "norot" + name = f"{type_str}_g{group_size}_h{head_size}_{rot_str}" + yield name, config, torch_type, ort_type + + +def gqa_xqa_head_sink_prepack_test_cases(): + # Same XQA attention-sink decode path as gqa_xqa_head_sink_test_cases(), but head_sink is baked + # into the model as a constant initializer. This exercises GroupQueryAttention::PrePack, which + # converts the constant head_sink once into the cached FP32 XQA buffer (use_prepacked_xqa_head_sink), + # instead of the per-launch conversion scratch path used for runtime head_sink inputs. + for torch_type, ort_type in [(torch.float16, TensorProto.FLOAT16), (torch.bfloat16, TensorProto.BFLOAT16)]: + for group_size in [1, 4]: + for rotary in [False, True]: + kv_num_heads = 4 + num_heads = kv_num_heads * group_size + config = GQAConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + num_heads=num_heads, + kv_num_heads=kv_num_heads, + head_size=128, + past_kv_sequence_length=4, + buffer_sequence_length=4 + 128, + rotary=rotary, + packed=False, + share_buffer=True, + has_head_sink=True, + head_sink_as_initializer=True, + ) + type_str = "bf16" if torch_type == torch.bfloat16 else "fp16" + rot_str = "rot" if rotary else "norot" + name = f"{type_str}_g{group_size}_h128_{rot_str}_prepack" + yield name, config, torch_type, ort_type + + +@unittest.skipIf(not has_xqa(), "XQA is not available, skipping tests.") +@unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping tests.") +class TestXQAHeadSinkParity(unittest.TestCase): + """Verify the non-quantized XQA attention-sink (head_sink) decode path matches the reference.""" + + def setUp(self): + # XQA is enabled by default when a head_sink input is present, so this path is exercised + # without ORT_ENABLE_XQA. Clear it (saving the previous value) to test the real default. + self._prev_enable_xqa = os.environ.pop("ORT_ENABLE_XQA", None) + + def tearDown(self): + # Restore the environment so other tests run with the default XQA setting. + if self._prev_enable_xqa is None: + os.environ.pop("ORT_ENABLE_XQA", None) + else: + os.environ["ORT_ENABLE_XQA"] = self._prev_enable_xqa + if torch.cuda.is_available(): + torch.cuda.synchronize() + torch.cuda.empty_cache() + gc.collect() + + @parameterized.expand(gqa_xqa_head_sink_test_cases()) + def test_xqa_head_sink_parity(self, name, config, torch_type, ort_type): + """Test XQA non-quantized parity with a head_sink (attention sink) input.""" parity_check_gqa_past( config=config, ep="CUDAExecutionProvider", @@ -2329,8 +2460,23 @@ def test_xqa_quantized_parity(self, name, config, torch_type, ort_type): torch_type=torch_type, ort_type=ort_type, causal=True, - rtol=rtol["int8_bf16"] if torch_type == torch.bfloat16 else rtol["int8_fp16"], - atol=atol["int8_bf16"] if torch_type == torch.bfloat16 else atol["int8_fp16"], + rtol=rtol["bf16"] if torch_type == torch.bfloat16 else rtol["fp16"], + atol=atol["bf16"] if torch_type == torch.bfloat16 else atol["fp16"], + std=0.1, + ) + + @parameterized.expand(gqa_xqa_head_sink_prepack_test_cases()) + def test_xqa_head_sink_prepack_parity(self, name, config, torch_type, ort_type): + """Test XQA parity when head_sink is a constant initializer (exercises PrePack).""" + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch_type, + ort_type=ort_type, + causal=True, + rtol=rtol["bf16"] if torch_type == torch.bfloat16 else rtol["fp16"], + atol=atol["bf16"] if torch_type == torch.bfloat16 else atol["fp16"], std=0.1, ) @@ -2454,17 +2600,17 @@ def test_gqa_local_window_large_context_decode(self): ort_type = TensorProto.FLOAT16 device = "cuda" - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" - parity_check_gqa_past( - config=config, - ep="CUDAExecutionProvider", - device=device, - torch_type=torch_type, - ort_type=ort_type, - causal=True, - rtol=rtol["fp16"], - atol=atol["fp16"], - ) + with scoped_env_var("ORT_DISABLE_FLASH_ATTENTION", "0"): + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device=device, + torch_type=torch_type, + ort_type=ort_type, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) @unittest.skipIf(not has_cuda_device(89) or not has_fp8_kv_cache, "FP8 KV cache is not available, skipping tests.") def test_gqa_fp8_kv_cache(self): diff --git a/onnxruntime/test/python/transformers/test_qmoe_cpu.py b/onnxruntime/test/python/transformers/test_qmoe_cpu.py index 2a3f7eedf4cba..62041d8a432dc 100644 --- a/onnxruntime/test/python/transformers/test_qmoe_cpu.py +++ b/onnxruntime/test/python/transformers/test_qmoe_cpu.py @@ -25,15 +25,14 @@ # normalization on the selected experts. This provides proper weight distribution # while maintaining computational efficiency. # -------------------------------------------------------------------------- -import os import time import unittest from collections import OrderedDict -from contextlib import contextmanager import numpy import torch import torch.nn.functional as F +from env_var_helper import scoped_env_var from onnx import helper from parameterized import parameterized from torch import nn @@ -1196,19 +1195,6 @@ def with_mlas_q4_mode(test_cases): return expanded_cases -@contextmanager -def scoped_env_var(name: str, value: str): - previous = os.environ.get(name) - os.environ[name] = value - try: - yield - finally: - if previous is None: - os.environ.pop(name, None) - else: - os.environ[name] = previous - - def run_parity_with_mlas_q4_mode(test_runner, enable_mlas_q4_gemm: bool | None): if enable_mlas_q4_gemm is None: # No env var test_runner()