From 1e86cf4e5c094c251f3575b8cc3ba8c5ae1d4155 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 19 Jun 2026 05:56:11 +0000 Subject: [PATCH 1/6] Optimize FlashDecode split planning for local-window GQA --- .../cuda/bert/group_query_attention.cc | 7 +- .../test/python/transformers/profile_gqa.py | 225 ++++++++++++++++++ .../test/python/transformers/profile_gqa.sh | 178 ++++++++++++++ .../test/python/transformers/test_gqa.py | 44 ++++ 4 files changed, 453 insertions(+), 1 deletion(-) create mode 100644 onnxruntime/test/python/transformers/profile_gqa.py create mode 100644 onnxruntime/test/python/transformers/profile_gqa.sh diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 3c09a25edd2af..c2ec19147fc9f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -452,8 +452,13 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons size_t softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(parameters.sequence_length, parameters.batch_size, parameters.num_heads); int num_heads_for_split = data.use_flash_attention_fast_decode ? parameters.kv_num_heads : parameters.num_heads; + size_t sequence_length_for_split = static_cast(parameters.total_sequence_length); + if (data.use_flash_attention_fast_decode && parameters.local_window_size > 0) { + sequence_length_for_split = std::min(sequence_length_for_split, static_cast(parameters.local_window_size)); + } + auto [num_splits, softmax_lse_accum_bytes, out_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes( - parameters.batch_size, parameters.sequence_length, parameters.total_sequence_length, num_heads_for_split, + parameters.batch_size, parameters.sequence_length, sequence_length_for_split, num_heads_for_split, parameters.head_size, device_prop.multiProcessorCount); parameters.num_splits = static_cast(num_splits); diff --git a/onnxruntime/test/python/transformers/profile_gqa.py b/onnxruntime/test/python/transformers/profile_gqa.py new file mode 100644 index 0000000000000..e11184f31673e --- /dev/null +++ b/onnxruntime/test/python/transformers/profile_gqa.py @@ -0,0 +1,225 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +""" +Simple profiling script for GroupQueryAttention with quantized KV cache. + +Usage: + cd /onnxruntime/test/python/transformers + python profile_gqa.py + + # Profile with Nsight Compute (kernel-level analysis) + ncu --set full -o gqa_fp16 python profile_gqa.py --mode fp16 --warmup 5 --repeat 1 + ncu --set full -o gqa_int8 python profile_gqa.py --mode int8 --warmup 5 --repeat 1 + + # Profile with Nsight Systems (timeline analysis) and extract kernel timings + nsys profile -o gqa_int8 --export=sqlite python profile_gqa.py --mode int8 --warmup 5 --repeat 10 + python parse_nsys.py gqa_int8.sqlite +""" + +import argparse +import time + +import torch +from test_sparse_attention import GroupQueryAttentionConfig, OrtGroupQueryAttention + +# Optional NVTX support for nsys range markers +try: + import nvtx + + HAS_NVTX = True +except ImportError: + HAS_NVTX = False + + # Dummy context manager when NVTX is not available + class DummyNvtxRange: + def __init__(self, name): + pass + + def __enter__(self): + return self + + def __exit__(self, *args): + pass + + class nvtx: # noqa: N801 + @staticmethod + def annotate(name, color=None): + return DummyNvtxRange(name) + + +def create_gqa_config( + mode: str = "fp16", + batch_size: int = 1, + sequence_length: int = 1, + past_sequence_length: int = 2048, + max_sequence_length: int = 4096, + num_heads: int = 32, + kv_num_heads: int = 8, + head_size: int = 128, + local_window_size: int = -1, + is_packed_qkv: bool = False, + do_rotary: bool = True, + device: str = "cuda", + share_kv_scale: bool = False, +) -> GroupQueryAttentionConfig: + """Create a GQA config based on the mode.""" + if mode == "fp16": + k_quant_type = "NONE" + v_quant_type = "NONE" + kv_cache_type = "float16" + dtype = torch.float16 + elif mode == "bf16": + k_quant_type = "NONE" + v_quant_type = "NONE" + kv_cache_type = "bfloat16" + dtype = torch.bfloat16 + elif mode == "int8": + k_quant_type = "PER_TENSOR" + v_quant_type = "PER_TENSOR" + kv_cache_type = "int8" + dtype = torch.float16 + elif mode == "int4": + k_quant_type = "PER_CHANNEL" + v_quant_type = "PER_CHANNEL" + kv_cache_type = "int4" + dtype = torch.float16 + else: + raise ValueError(f"Unknown mode: {mode}") + + config = GroupQueryAttentionConfig( + batch_size=batch_size, + sequence_length=sequence_length, + max_sequence_length=max_sequence_length, + past_sequence_length=past_sequence_length, + num_heads=num_heads, + kv_num_heads=kv_num_heads, + head_size=head_size, + local_window_size=local_window_size, + do_rotary=do_rotary, + rotary_interleaved=False, + dtype=dtype, + is_packed_qkv=is_packed_qkv, + use_smooth_softmax=False, + device=device, + k_quant_type=k_quant_type, + v_quant_type=v_quant_type, + kv_cache_type=kv_cache_type, + share_kv_scale=share_kv_scale, + ) + return config + + +def benchmark_gqa(config: GroupQueryAttentionConfig, warmup: int = 50, repeat: int = 100, mode: str = ""): + """Run benchmark and return average time in ms.""" + obj = OrtGroupQueryAttention(config) + + # Warmup phase with NVTX annotation + with nvtx.annotate(f"warmup_{mode}", color="yellow"): + for _ in range(warmup): + obj.infer() + torch.cuda.synchronize() + + # Benchmark phase with NVTX annotation + with nvtx.annotate(f"benchmark_{mode}", color="green"): + start = time.perf_counter() + for _ in range(repeat): + obj.infer() + torch.cuda.synchronize() + end = time.perf_counter() + + avg_ms = (end - start) * 1000 / repeat + return avg_ms + + +def run_comparison(args): + """Compare FP16/BF16 vs quantized performance.""" + # Auto-adjust max_sequence_length to be at least total_sequence_length + total_sequence_length = args.past_sequence_length + args.sequence_length + if args.max_sequence_length < total_sequence_length: + args.max_sequence_length = total_sequence_length + print(f"Note: max_sequence_length auto-adjusted to {args.max_sequence_length}") + + print(f"\n{'=' * 70}") + print("GQA Performance Comparison") + print(f"{'=' * 70}") + print(f"Config: batch={args.batch_size}, seq_len={args.sequence_length}, past_seq={args.past_sequence_length}") + print(f" num_heads={args.num_heads}, kv_heads={args.kv_num_heads}, head_size={args.head_size}") + print(f" warmup={args.warmup}, repeat={args.repeat}") + print(f"{'=' * 70}\n") + + modes = ["fp16", "bf16", "int8", "int4"] if args.mode == "all" else [args.mode] + results = {} + + for mode in modes: + config = create_gqa_config( + mode=mode, + batch_size=args.batch_size, + sequence_length=args.sequence_length, + past_sequence_length=args.past_sequence_length, + max_sequence_length=args.max_sequence_length, + num_heads=args.num_heads, + kv_num_heads=args.kv_num_heads, + head_size=args.head_size, + local_window_size=args.local_window_size, + is_packed_qkv=args.is_packed_qkv, + do_rotary=not args.no_rotary, + share_kv_scale=args.share_kv_scale, + ) + avg_ms = benchmark_gqa(config, warmup=args.warmup, repeat=args.repeat, mode=mode) + results[mode] = avg_ms + print(f" {mode.upper():6s} (dtype={config.dtype}): {avg_ms:.4f} ms") + + # Print comparison if we have baseline + baseline = "fp16" if "fp16" in results else ("bf16" if "bf16" in results else None) + if baseline and len(results) > 1: + print(f"\n Relative to {baseline.upper()}:") + for mode, ms in results.items(): + if mode != baseline: + ratio = ms / results[baseline] + print(f" {mode.upper()}: {ratio:.2f}x slower") + + +def main(): + parser = argparse.ArgumentParser(description="Profile GQA with quantized KV cache") + parser.add_argument( + "--mode", choices=["fp16", "bf16", "int8", "int4", "all"], default="all", help="Quantization mode to test" + ) + parser.add_argument("--batch-size", type=int, default=1, help="Batch size") + parser.add_argument("--sequence-length", type=int, default=1, help="Query sequence length (1 for token generation)") + parser.add_argument("--past-sequence-length", type=int, default=2048, help="Past KV cache sequence length") + parser.add_argument("--max-sequence-length", type=int, default=4096, help="Max sequence length for KV cache buffer") + parser.add_argument("--num-heads", type=int, default=32, help="Number of query heads") + parser.add_argument("--kv-num-heads", type=int, default=8, help="Number of KV heads") + parser.add_argument("--head-size", type=int, default=128, help="Head dimension") + parser.add_argument( + "--local-window-size", + type=int, + default=-1, + help="Local attention window size (-1 disables sliding window, e.g. gpt-oss uses 128)", + ) + parser.add_argument("--warmup", type=int, default=50, help="Warmup iterations") + parser.add_argument("--repeat", type=int, default=100, help="Benchmark iterations") + parser.add_argument("--is-packed-qkv", action="store_true", help="Use packed QKV") + + parser.add_argument("--no-rotary", action="store_true", help="Disable rotary embeddings") + parser.add_argument("--share-kv-scale", action="store_true", help="Share KV scale tensor for XQA") + + args = parser.parse_args() + + # Check CUDA + if not torch.cuda.is_available(): + print("CUDA not available!") + return + + major, minor = torch.cuda.get_device_capability() + print(f"GPU: {torch.cuda.get_device_name()} (SM{major}{minor})") + + with torch.cuda.stream(torch.cuda.Stream()), torch.no_grad(): + run_comparison(args) + + +if __name__ == "__main__": + main() diff --git a/onnxruntime/test/python/transformers/profile_gqa.sh b/onnxruntime/test/python/transformers/profile_gqa.sh new file mode 100644 index 0000000000000..3fe10eaf35ad2 --- /dev/null +++ b/onnxruntime/test/python/transformers/profile_gqa.sh @@ -0,0 +1,178 @@ +#!/bin/bash +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# +# Profile the CUDA GroupQueryAttention decode path with nsys. +# +# Usage: +# ./profile_gqa.sh --all +# ./profile_gqa.sh --fp16 --int8 +# ./profile_gqa.sh --fp16 --past-sequence-length 8192 --local-window-size 128 +# ./profile_gqa.sh --bf16 --num-heads 64 --kv-num-heads 8 +# CUDA_VISIBLE_DEVICES=1 PYTHON=python3 ./profile_gqa.sh --int4 +# + +set -e +set -o pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PY="${PYTHON:-python}" + +# Parse arguments +RUN_FP16=false +RUN_INT8=false +RUN_INT4=false +RUN_INT8_QUANT=false +RUN_BF16=false + +# Profile parameters to pass through to profile_gqa.py +BATCH_SIZE="" +SEQUENCE_LENGTH="" +PAST_SEQUENCE_LENGTH="" +PACKED_QKV="" +SHARE_KV_SCALE="" +NUM_HEADS="" +KV_NUM_HEADS="" +LOCAL_WINDOW_SIZE="" +while [[ "$#" -gt 0 ]]; do + case $1 in + --fp16) + RUN_FP16=true + echo "==== ๐Ÿš€ FP16 run enabled ====" + ;; + --int8) + RUN_INT8=true + echo "==== ๐Ÿš€ INT8 run enabled ====" + ;; + --int4) + RUN_INT4=true + echo "==== ๐Ÿš€ INT4 run enabled ====" + ;; + --int8_quant) + RUN_INT8_QUANT=true + echo "==== ๐Ÿš€ INT8 Quant run enabled ====" + ;; + --bf16) + RUN_BF16=true + echo "==== ๐Ÿš€ BF16 run enabled ====" + ;; + --all) + RUN_FP16=true + RUN_INT8=true + RUN_INT4=true + RUN_INT8_QUANT=true + RUN_BF16=true + echo "==== ๐Ÿš€ All runs enabled ====" + ;; + -b|--batch-size) + BATCH_SIZE="--batch-size $2" + echo "==== Batch size: $2 ====" + shift + ;; + -s|--sequence-length) + SEQUENCE_LENGTH="--sequence-length $2" + echo "==== Sequence length: $2 ====" + shift + ;; + -p|--past-sequence-length) + PAST_SEQUENCE_LENGTH="--past-sequence-length $2" + echo "==== Past sequence length: $2 ====" + shift + ;; + --qkv) + PACKED_QKV="--is-packed-qkv" + echo "==== Packed QKV enabled ====" + ;; + --share-kv-scale) + SHARE_KV_SCALE="--share-kv-scale" + echo "==== Share KV scale enabled ====" + ;; + --num-heads) + NUM_HEADS="--num-heads $2" + echo "==== Num Heads: $2 ====" + shift + ;; + --kv-num-heads) + KV_NUM_HEADS="--kv-num-heads $2" + echo "==== KV Num Heads: $2 ====" + shift + ;; + -w|--local-window-size) + LOCAL_WINDOW_SIZE="--local-window-size $2" + echo "==== Local window size: $2 ====" + shift + ;; + *) + echo "Unknown option: $1" + exit 1 + ;; + esac + shift +done + +# Build extra args string +EXTRA_ARGS="${BATCH_SIZE} ${SEQUENCE_LENGTH} ${PAST_SEQUENCE_LENGTH} ${PACKED_QKV} ${SHARE_KV_SCALE} ${NUM_HEADS} ${KV_NUM_HEADS} ${LOCAL_WINDOW_SIZE}" + +if ! command -v nsys >/dev/null; then + echo "Error: nsys not found. Install NVIDIA Nsight Systems or add it to PATH." + exit 1 +fi + +HAVE_NVTX=0 +if "${PY}" -c "import nvtx" 2>/dev/null; then + HAVE_NVTX=1 +else + echo "Note: 'nvtx' package not installed. NVTX range markers will be disabled." + echo " Install with: pip install nvtx" + echo " Falling back to --skip-first to exclude warmup-like first calls." +fi + +# profile_one [env_var=value ...] +profile_one() { + local mode="$1" + local tag="$2" + local base="$3" + shift 3 + + local env_args=() + local e + for e in "$@"; do + env_args+=(-e "${e}") + done + + echo "" + echo "---- Profiling ${mode} ----" + rm -f "${base}.nsys-rep" "${base}.sqlite" + nsys profile -t cuda,nvtx --force-overwrite true "${env_args[@]}" -o "${base}" --export=sqlite \ + "${PY}" "${SCRIPT_DIR}/profile_gqa.py" --mode "${mode}" --warmup 5 --repeat 100 ${EXTRA_ARGS} + + echo "" + echo "---- Kernel results (${mode}) ----" + if [[ "${HAVE_NVTX}" -eq 1 ]]; then + "${PY}" "${SCRIPT_DIR}/parse_nsys.py" "${base}.sqlite" --nvtx-range "benchmark_${mode}" --tag "${tag}" + else + "${PY}" "${SCRIPT_DIR}/parse_nsys.py" "${base}.sqlite" --skip-first 5 --tag "${tag}" + fi +} + +if [ "$RUN_FP16" = true ]; then + profile_one fp16 Fp16 gqa_fp16 +fi + +if [ "$RUN_BF16" = true ]; then + profile_one bf16 Bf16 gqa_bf16 +fi + +if [ "$RUN_INT8" = true ]; then + profile_one int8 Int8 gqa_int8 ORT_FLASH_ATTENTION_QUERY_DYNAMIC_QUANT=0 +fi + +if [ "$RUN_INT8_QUANT" = true ]; then + profile_one int8 Int8Q gqa_int8_quant ORT_FLASH_ATTENTION_QUERY_DYNAMIC_QUANT=1 +fi + +if [ "$RUN_INT4" = true ]; then + profile_one int4 Int4 gqa_int4 +fi diff --git a/onnxruntime/test/python/transformers/test_gqa.py b/onnxruntime/test/python/transformers/test_gqa.py index 8a98c39d61eb9..310018c3395ae 100644 --- a/onnxruntime/test/python/transformers/test_gqa.py +++ b/onnxruntime/test/python/transformers/test_gqa.py @@ -2422,6 +2422,50 @@ def test_gqa_int8_large_seq_batch4(self): atol=5e-2, ) + def test_gqa_local_window_large_context_decode(self): + """ + Regression test for FlashDecode split planning with a local attention window. + + Mirrors a gpt-oss-style decode step: a large past KV context combined with a small + sliding (local) window. The split-K planning is clamped to the local window length, + so only the windowed portion of the KV cache participates in the decode. This verifies + that the narrowed split planning still produces correct results. + """ + if not has_flash_attention(): + self.skipTest("Flash Attention is not available") + + # Decode (q_sequence_length=1) with a large past context but a small local window. + config = GQAConfig( + batch_size=2, + num_heads=64, + kv_num_heads=8, + head_size=64, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=4096, + buffer_sequence_length=4096 + 8, + local_window_size=128, + rotary=True, + rotary_interleaved=False, + share_buffer=True, + ) + + torch_type = torch.float16 + ort_type = TensorProto.FLOAT16 + device = "cuda" + + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device=device, + torch_type=torch_type, + ort_type=ort_type, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + @unittest.skipIf(not has_cuda_device(89) or not has_fp8_kv_cache, "FP8 KV cache is not available, skipping tests.") def test_gqa_fp8_kv_cache(self): """ From b1ee7231695c8450f4e6b87b6a94ab27127a160c Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 19 Jun 2026 06:08:33 +0000 Subject: [PATCH 2/6] Enable XQA for GQA head_sink decode --- .../contrib_ops/cuda/bert/attention_data.h | 2 + .../cuda/bert/attention_kernel_options.cc | 4 +- .../cuda/bert/attention_kernel_options.h | 1 + .../cuda/bert/group_query_attention.cc | 75 +++++++++++++++++-- .../cuda/bert/group_query_attention.h | 5 ++ .../cuda/bert/group_query_attention_impl.cu | 42 +++++++++++ .../cuda/bert/group_query_attention_impl.h | 8 ++ .../cuda/bert/xqa/xqa_impl_gen.cuh | 3 +- .../contrib_ops/cuda/bert/xqa/xqa_loader.h | 7 +- .../cuda/bert/xqa/xqa_loader_bf16.cu | 11 ++- .../cuda/bert/xqa/xqa_loader_bf16_128.cu | 1 + .../cuda/bert/xqa/xqa_loader_bf16_256.cu | 1 + .../cuda/bert/xqa/xqa_loader_bf16_64.cu | 1 + .../bert/xqa/xqa_loader_bf16_fp8_impl.cuh | 8 +- .../cuda/bert/xqa/xqa_loader_bf16_impl.cuh | 16 ++-- .../bert/xqa/xqa_loader_bf16_int8_impl.cuh | 8 +- .../cuda/bert/xqa/xqa_loader_fp16.cu | 11 ++- .../cuda/bert/xqa/xqa_loader_fp16_128.cu | 1 + .../cuda/bert/xqa/xqa_loader_fp16_256.cu | 1 + .../cuda/bert/xqa/xqa_loader_fp16_64.cu | 1 + .../bert/xqa/xqa_loader_fp16_fp8_impl.cuh | 8 +- .../cuda/bert/xqa/xqa_loader_fp16_impl.cuh | 15 ++-- .../bert/xqa/xqa_loader_fp16_int8_impl.cuh | 8 +- .../python/transformers/gqa_test_helper.py | 11 ++- .../test/python/transformers/profile_gqa.py | 5 ++ 25 files changed, 209 insertions(+), 45 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_data.h b/onnxruntime/contrib_ops/cuda/bert/attention_data.h index 7d5c9bc1e221e..6ae814641125a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_data.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_data.h @@ -206,6 +206,8 @@ struct GroupQueryAttentionData { // XQA buffer void* xqa_buffer = nullptr; size_t xqa_buffer_bytes = 0; + float* xqa_head_sink = nullptr; + bool xqa_head_sink_needs_conversion = false; // Unfused fallback buffers (see LaunchUnfusedAttention in unfused_attention.h): // unfused_q_bnsh : [B, N_q, S_q, H] (Q transposed from BSNH to BNSH) diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc index 0723a48528dfe..ab8f0ad6ca375 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc @@ -156,7 +156,9 @@ void AttentionKernelDebugInfo::Print(const char* operator_name, } sstream << " SdpaKernel="; - if (use_flash_attention.has_value() && use_flash_attention.value()) { + if (use_xqa.has_value() && use_xqa.value()) { + sstream << "XQA"; + } else if (use_flash_attention.has_value() && use_flash_attention.value()) { sstream << "FLASH_ATTENTION"; #if USE_LEAN_ATTENTION } else if (use_lean_attention.has_value() && use_lean_attention.value()) { diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h index d94eb5ec06e5c..652a7e0e910c2 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h @@ -8,6 +8,7 @@ namespace onnxruntime { struct AttentionKernelDebugInfo { + std::optional use_xqa = std::nullopt; std::optional use_flash_attention = std::nullopt; std::optional use_lean_attention = std::nullopt; std::optional use_efficient_attention = std::nullopt; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index c2ec19147fc9f..ca5b738ebb4c5 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -76,6 +76,7 @@ REGISTER_KERNEL_TYPED(BFloat16, uint8_t) #endif constexpr const char* kDisableFlashDecode = "ORT_DISABLE_FLASH_DECODE"; +constexpr int kHeadSinkInputIndex = 11; // Group Query Attention (GQA) Operator // @@ -133,6 +134,48 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) disable_flash_decode_ = ParseEnvironmentVariableWithDefault(kDisableFlashDecode, false); } +template +Status GroupQueryAttention::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool& is_packed, PrePackedWeights* prepacked_weights) { + ORT_UNUSED_PARAMETER(prepacked_weights); + is_packed = false; + + if (input_idx != kHeadSinkInputIndex) { + return Status::OK(); + } + + if constexpr (std::is_same_v || std::is_same_v) { + const auto& shape = tensor.Shape(); + ORT_RETURN_IF_NOT(shape.NumDimensions() == 1, + "head_sink must be a 1D tensor, got ", shape.NumDimensions(), " dimensions"); + ORT_RETURN_IF_NOT(shape[0] == num_heads_, + "head_sink dimension 0 must be equal to the num heads, got ", shape[0]); + ORT_RETURN_IF_NOT(tensor.IsDataType(), "head_sink type must match GroupQueryAttention input type"); + + const size_t head_sink_bytes = tensor.SizeInBytes(); + const void* head_sink_data = tensor.DataRaw(); + IAllocatorUniquePtr head_sink_gpu; + cudaStream_t stream = cudaStreamLegacy; + + if (tensor.Location().device.Type() == OrtDevice::CPU) { + head_sink_gpu = IAllocator::MakeUniquePtr(alloc, head_sink_bytes, true); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(head_sink_gpu.get(), head_sink_data, head_sink_bytes, + cudaMemcpyHostToDevice, stream)); + head_sink_data = head_sink_gpu.get(); + } + + xqa_head_sink_ = IAllocator::MakeUniquePtr(alloc, static_cast(num_heads_), true); + using CudaT = typename onnxruntime::cuda::OrtToCudaType::type; + ORT_RETURN_IF_ERROR(LaunchConvertHeadSinkToFloat( + reinterpret_cast(head_sink_data), xqa_head_sink_.get(), num_heads_, stream, + GetDeviceProp().maxThreadsPerBlock)); + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); + xqa_head_sink_count_ = num_heads_; + } + + return Status::OK(); +} + // ComputeInternal executes the GQA kernel. // // Inputs: @@ -338,8 +381,10 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons // 3. Sequence length is 1. // 4. Past and Present KV cache share the same buffer (required for XQA specific memory access). // 5. No Softcap (XQA doesn't support softcap). - // 6. Standard Softmax (no smooth softmax). + // 6. Standard Softmax, or smooth softmax represented by a head_sink tensor. // 7. No local window attention (global attention only). + const bool use_xqa_attention_sinks = head_sink != nullptr && !is_inputs_quantized; + const bool is_xqa_smooth_softmax_supported = !parameters.use_smooth_softmax || use_xqa_attention_sinks; if (enable_xqa_ && (device_prop.major >= 8) && !parameters.is_first_prompt && @@ -347,7 +392,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons parameters.kv_sequence_length > 0 && // Shared KV (kv_seq=0) has no new K/V to append parameters.past_present_share_buffer && parameters.softcap == 0.0f && - !parameters.use_smooth_softmax && + is_xqa_smooth_softmax_supported && parameters.local_window_size == -1) { int group_size = parameters.num_heads / parameters.kv_num_heads; @@ -389,23 +434,42 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons assert(xqa_internal_bytes > 0); // Calculate additional scratch needed for manual RoPE/Append in ExtremeDecoding size_t xqa_total_bytes = xqa_internal_bytes; + size_t q_bytes = 0; + size_t k_bytes = 0; if (parameters.do_rotary) { // 1. Q_rotated buffer: B * N * H * sizeof(T) (if rotary) // 2. K_rotated buffer: B * Nk * H * sizeof(T) (if rotary) size_t element_size = sizeof(CudaT); - size_t q_bytes = parameters.batch_size * parameters.num_heads * parameters.head_size * element_size; - size_t k_bytes = parameters.batch_size * parameters.kv_num_heads * parameters.head_size * element_size; + q_bytes = parameters.batch_size * parameters.num_heads * parameters.head_size * element_size; + k_bytes = parameters.batch_size * parameters.kv_num_heads * parameters.head_size * element_size; q_bytes = (q_bytes + 255) / 256 * 256; k_bytes = (k_bytes + 255) / 256 * 256; xqa_total_bytes += q_bytes + k_bytes; } + const bool use_prepacked_xqa_head_sink = + use_xqa_attention_sinks && xqa_head_sink_ != nullptr && xqa_head_sink_count_ == parameters.num_heads; + const bool convert_xqa_head_sink = use_xqa_attention_sinks && !use_prepacked_xqa_head_sink; + size_t xqa_head_sink_bytes = 0; + if (convert_xqa_head_sink) { + xqa_head_sink_bytes = parameters.num_heads * sizeof(float); + xqa_head_sink_bytes = (xqa_head_sink_bytes + 255) / 256 * 256; + xqa_total_bytes += xqa_head_sink_bytes; + } xqa_scratch_buffer = this->GetScratchBuffer(xqa_total_bytes, GetComputeStream(context)); data.xqa_buffer = xqa_scratch_buffer.get(); data.xqa_buffer_bytes = xqa_internal_bytes; + char* xqa_extra_buffer = reinterpret_cast(data.xqa_buffer) + xqa_internal_bytes; if (parameters.do_rotary) { - data.qkv_buffer = reinterpret_cast(reinterpret_cast(data.xqa_buffer) + xqa_internal_bytes); + data.qkv_buffer = reinterpret_cast(xqa_extra_buffer); + xqa_extra_buffer += q_bytes + k_bytes; + } + if (use_prepacked_xqa_head_sink) { + data.xqa_head_sink = xqa_head_sink_.get(); + } else if (convert_xqa_head_sink) { + data.xqa_head_sink = reinterpret_cast(xqa_extra_buffer); + data.xqa_head_sink_needs_conversion = true; } } } @@ -606,6 +670,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons if (kernel_options_->AllowDebugInfo()) { AttentionKernelDebugInfo debug_info; + debug_info.use_xqa = data.use_xqa; debug_info.use_flash_attention = data.use_flash_attention; debug_info.use_efficient_attention = data.use_memory_efficient_attention; debug_info.use_cudnn_flash_attention = data.use_cudnn_sdpa; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h index 34847983ad7de..9c22de951a6f1 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h @@ -21,6 +21,9 @@ class GroupQueryAttention final : public CudaKernel { GroupQueryAttention(const OpKernelInfo& info); Status ComputeInternal(OpKernelContext* context) const override; + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool& is_packed, PrePackedWeights* prepacked_weights) override; + protected: int num_heads_; // number of attention heads int kv_num_heads_; // different for k and v for group query attention @@ -45,6 +48,8 @@ class GroupQueryAttention final : public CudaKernel { static constexpr int kZerosCount = 256; // In prompt case we create a zero buffer of size 256 for seqlen (assume batch_size <= 256) IAllocatorUniquePtr zeros_; + IAllocatorUniquePtr xqa_head_sink_; + int xqa_head_sink_count_ = 0; const AttentionKernelOptions* kernel_options_; }; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index 1c39de01fef66..6a55f18bd939a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -68,6 +68,26 @@ namespace cuda { // QKV Preprocessing Helpers // ============================================================================ +template +__global__ void ConvertHeadSinkToFloatKernel(const T* input, float* output, int count) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < count) { + output[i] = static_cast(input[i]); + } +} + +template +Status LaunchConvertHeadSinkToFloat( + const T* input, + float* output, + int count, + cudaStream_t stream, + int max_threads_per_block) { + int blocks = (count + max_threads_per_block - 1) / max_threads_per_block; + ConvertHeadSinkToFloatKernel<<>>(input, output, count); + return CUDA_CALL(cudaGetLastError()); +} + // Internal helper to get Q, K, V pointers, handling packed input // // This function orchestrates the preparation of Q, K, and V tensors for attention kernels. @@ -655,6 +675,13 @@ Status ExtremeDecoding( void* xqa_workspace = data.xqa_buffer; size_t xqa_workspace_size = data.xqa_buffer_bytes; + if (data.xqa_head_sink_needs_conversion) { + ORT_ENFORCE(data.xqa_head_sink != nullptr, "XQA head_sink conversion buffer was not allocated."); + ORT_ENFORCE(data.head_sink != nullptr, "XQA head_sink input was not available for conversion."); + ORT_RETURN_IF_ERROR(LaunchConvertHeadSinkToFloat( + data.head_sink, data.xqa_head_sink, num_heads, stream, device_prop.maxThreadsPerBlock)); + } + constexpr bool is_fp8 = std::is_same::value; using onnxruntime::contrib::cuda::XqaQuantType; // 5. Launch XQA @@ -673,6 +700,7 @@ Status ExtremeDecoding( scale, past_bsnh, data.past_seq_lens, + data.xqa_head_sink, data.k_scale, // kv_cache_scale // Map cache type to XqaQuantType: NONE->kNone, Float8E4M3FN->kFp8, int8->kInt8 (parameters.k_quant_type == KVQuantizationType::NONE) ? XqaQuantType::kNone : (is_fp8 ? XqaQuantType::kFp8 : XqaQuantType::kInt8), @@ -1316,6 +1344,20 @@ template struct GroupQueryAttentionData; template struct GroupQueryAttentionData<__nv_bfloat16, __nv_bfloat16>; template struct GroupQueryAttentionData; +template Status LaunchConvertHeadSinkToFloat( + const half* input, + float* output, + int count, + cudaStream_t stream, + int max_threads_per_block); + +template Status LaunchConvertHeadSinkToFloat<__nv_bfloat16>( + const __nv_bfloat16* input, + float* output, + int count, + cudaStream_t stream, + int max_threads_per_block); + template Status QkvToContext( const cudaDeviceProp& device_prop, cublasHandle_t& cublas, diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h index 89945b20fcfb3..348dc0832d3ba 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h @@ -28,6 +28,14 @@ Status LaunchUnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* unp const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size, cudaStream_t stream, const int max_threads_per_block); +template +Status LaunchConvertHeadSinkToFloat( + const T* input, + float* output, + int count, + cudaStream_t stream, + int max_threads_per_block); + // ============================================================================ // GQABufferRequirements: Centralized buffer size calculation // ============================================================================ diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_impl_gen.cuh b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_impl_gen.cuh index d132fba85988c..cd4088cf757ba 100644 --- a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_impl_gen.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_impl_gen.cuh @@ -56,6 +56,7 @@ inline Status Launch( [[maybe_unused]] const float scale, [[maybe_unused]] const bool is_bsnh, [[maybe_unused]] const int* past_seq_lens, + [[maybe_unused]] const float* attention_sinks, [[maybe_unused]] const float* kv_cache_scale, [[maybe_unused]] void* workspace, [[maybe_unused]] size_t workspace_size) { @@ -97,7 +98,7 @@ inline Status Launch( scale, out_ptr, q_ptr, - nullptr, // attentionSinks + attention_sinks, k_ptr, v_ptr, is_bsnh, diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader.h b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader.h index 8439c19687097..ee4fbc88982f8 100644 --- a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader.h +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader.h @@ -34,9 +34,10 @@ Status LaunchXQAKernel( const int head_size, const int max_seq_len, // Max sequence length of cache const float scale, - const bool is_bsnh, // Layout of KV cache - const int* past_seq_lens, // Past sequence lengths [BatchSize] - const float* kv_cache_scale, // KV cache dequant scale (nullptr for FP16/BF16, per-tensor float for INT8) + const bool is_bsnh, // Layout of KV cache + const int* past_seq_lens, // Past sequence lengths [BatchSize] + const float* attention_sinks, // Attention sink per query head, nullptr if not used + const float* kv_cache_scale, // KV cache dequant scale (nullptr for FP16/BF16, per-tensor float for INT8) const XqaQuantType kv_quant_type, void* workspace = nullptr, // Scratch memory size_t workspace_size = 0 // Size of scratch memory diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16.cu b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16.cu index 4c6731b10fe77..4a2d22938d48d 100644 --- a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16.cu +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16.cu @@ -26,6 +26,7 @@ Status LaunchXQAKernelImpl( const float scale, const bool is_bsnh, const int* past_seq_lens, + const float* attention_sinks, const float* kv_cache_scale, const XqaQuantType kv_quant_type, void* workspace, @@ -49,6 +50,7 @@ Status LaunchXQAKernelImpl( const float scale, const bool is_bsnh, const int* past_seq_lens, + const float* attention_sinks, const float* kv_cache_scale, const XqaQuantType kv_quant_type, void* workspace, @@ -72,6 +74,7 @@ Status LaunchXQAKernelImpl( const float scale, const bool is_bsnh, const int* past_seq_lens, + const float* attention_sinks, const float* kv_cache_scale, const XqaQuantType kv_quant_type, void* workspace, @@ -118,12 +121,14 @@ Status LaunchXQAKernel<__nv_bfloat16>( const float scale, const bool is_bsnh, const int* past_seq_lens, + const float* attention_sinks, const float* kv_cache_scale, const XqaQuantType kv_quant_type, void* workspace, size_t workspace_size) { // Dispatch to INT8 path if requested if (kv_quant_type == XqaQuantType::kInt8) { + ORT_RETURN_IF(attention_sinks != nullptr, "XQA attention sinks are not supported with INT8 KV cache."); return LaunchXQAInt8KernelBF16(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); } @@ -131,15 +136,15 @@ Status LaunchXQAKernel<__nv_bfloat16>( if (head_size == 256) { return H256::LaunchXQAKernelImpl<__nv_bfloat16>( device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, - max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, kv_quant_type, workspace, workspace_size); + max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, kv_quant_type, workspace, workspace_size); } else if (head_size == 128) { return H128::LaunchXQAKernelImpl<__nv_bfloat16>( device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, - max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, kv_quant_type, workspace, workspace_size); + max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, kv_quant_type, workspace, workspace_size); } else if (head_size == 64) { return H64::LaunchXQAKernelImpl<__nv_bfloat16>( device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, - max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, kv_quant_type, workspace, workspace_size); + max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, kv_quant_type, workspace, workspace_size); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA only supports head_size=64, 128, or 256. Input has ", head_size); } diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_128.cu b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_128.cu index 7572986d14632..a8ea76ab23b8b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_128.cu +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_128.cu @@ -26,6 +26,7 @@ template Status HEAD_DIM_NAMESPACE::LaunchXQAKernelImpl<__nv_bfloat16>( const float scale, const bool is_bsnh, const int* past_seq_lens, + const float* attention_sinks, const float* kv_cache_scale, const XqaQuantType kv_quant_type, void* workspace, diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_256.cu b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_256.cu index 2706a9de32b14..79ddc2d0d7c34 100644 --- a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_256.cu +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_256.cu @@ -26,6 +26,7 @@ template Status HEAD_DIM_NAMESPACE::LaunchXQAKernelImpl<__nv_bfloat16>( const float scale, const bool is_bsnh, const int* past_seq_lens, + const float* attention_sinks, const float* kv_cache_scale, const XqaQuantType kv_quant_type, void* workspace, diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_64.cu b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_64.cu index 7bd8897fdfd93..c94f6b5fc0695 100644 --- a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_64.cu +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_64.cu @@ -26,6 +26,7 @@ template Status HEAD_DIM_NAMESPACE::LaunchXQAKernelImpl<__nv_bfloat16>( const float scale, const bool is_bsnh, const int* past_seq_lens, + const float* attention_sinks, const float* kv_cache_scale, const XqaQuantType kv_quant_type, void* workspace, diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_fp8_impl.cuh b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_fp8_impl.cuh index 481fcb63c1f8c..773b4810b6b30 100644 --- a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_fp8_impl.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_fp8_impl.cuh @@ -102,13 +102,13 @@ Status LaunchXQAFp8KernelBF16( int group_size = num_heads / kv_num_heads; switch (group_size) { case 4: - return grp4_bf16_fp8::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp4_bf16_fp8::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, nullptr, kv_cache_scale, workspace, workspace_size); case 8: - return grp8_bf16_fp8::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp8_bf16_fp8::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, nullptr, kv_cache_scale, workspace, workspace_size); case 16: - return grp16_bf16_fp8::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp16_bf16_fp8::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, nullptr, kv_cache_scale, workspace, workspace_size); case 32: - return grp32_bf16_fp8::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp32_bf16_fp8::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, nullptr, kv_cache_scale, workspace, workspace_size); default: return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA FP8 only supports group_size 4, 8, 16, 32. Input has ", group_size); } diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_impl.cuh b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_impl.cuh index c2d9c057c6e50..6a84d452f1384 100644 --- a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_impl.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_impl.cuh @@ -158,6 +158,7 @@ Status LaunchXQAKernelImpl( const float scale, const bool is_bsnh, const int* past_seq_lens, + const float* attention_sinks, const float* kv_cache_scale, const XqaQuantType kv_quant_type, void* workspace, @@ -179,6 +180,7 @@ Status LaunchXQAKernelImpl<__nv_bfloat16>( const float scale, const bool is_bsnh, const int* past_seq_lens, + const float* attention_sinks, const float* kv_cache_scale, const XqaQuantType kv_quant_type, void* workspace, @@ -187,6 +189,7 @@ Status LaunchXQAKernelImpl<__nv_bfloat16>( // Dispatch to INT8 path if requested if (kv_quant_type == XqaQuantType::kInt8) { + ORT_RETURN_IF(attention_sinks != nullptr, "XQA attention sinks are not supported with INT8 KV cache."); return LaunchXQAInt8KernelBF16(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, @@ -196,6 +199,7 @@ Status LaunchXQAKernelImpl<__nv_bfloat16>( #ifdef USE_FP8_KV_CACHE // Dispatch to FP8 path if requested if (kv_quant_type == XqaQuantType::kFp8) { + ORT_RETURN_IF(attention_sinks != nullptr, "XQA attention sinks are not supported with FP8 KV cache."); return LaunchXQAFp8KernelBF16(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, @@ -206,17 +210,17 @@ Status LaunchXQAKernelImpl<__nv_bfloat16>( int group_size = num_heads / kv_num_heads; switch (group_size) { case 1: - return grp1_bf16::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp1_bf16::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size); case 2: - return grp2_bf16::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp2_bf16::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size); case 4: - return grp4_bf16::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp4_bf16::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size); case 8: - return grp8_bf16::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp8_bf16::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size); case 16: - return grp16_bf16::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp16_bf16::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size); case 32: - return grp32_bf16::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp32_bf16::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size); default: return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA supports group_size 1, 2, 4, 8, 16, 32. Input has ", group_size); } diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_int8_impl.cuh b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_int8_impl.cuh index acec9aeed9973..0ad18e99c5841 100644 --- a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_int8_impl.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_int8_impl.cuh @@ -102,13 +102,13 @@ Status LaunchXQAInt8KernelBF16( int group_size = num_heads / kv_num_heads; switch (group_size) { case 4: - return grp4_bf16_int8::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp4_bf16_int8::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, nullptr, kv_cache_scale, workspace, workspace_size); case 8: - return grp8_bf16_int8::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp8_bf16_int8::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, nullptr, kv_cache_scale, workspace, workspace_size); case 16: - return grp16_bf16_int8::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp16_bf16_int8::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, nullptr, kv_cache_scale, workspace, workspace_size); case 32: - return grp32_bf16_int8::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp32_bf16_int8::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, nullptr, kv_cache_scale, workspace, workspace_size); default: return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA INT8 only supports group_size 4, 8, 16, 32. Input has ", group_size); } diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16.cu b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16.cu index 37b974a8a3e60..b8171392e0f50 100644 --- a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16.cu +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16.cu @@ -28,6 +28,7 @@ Status LaunchXQAKernelImpl( const float scale, const bool is_bsnh, const int* past_seq_lens, + const float* attention_sinks, const float* kv_cache_scale, const XqaQuantType kv_quant_type, void* workspace, @@ -52,6 +53,7 @@ Status LaunchXQAKernelImpl( const float scale, const bool is_bsnh, const int* past_seq_lens, + const float* attention_sinks, const float* kv_cache_scale, const XqaQuantType kv_quant_type, void* workspace, @@ -76,6 +78,7 @@ Status LaunchXQAKernelImpl( const float scale, const bool is_bsnh, const int* past_seq_lens, + const float* attention_sinks, const float* kv_cache_scale, const XqaQuantType kv_quant_type, void* workspace, @@ -101,6 +104,7 @@ Status LaunchXQAKernel( const float scale, const bool is_bsnh, const int* past_seq_lens, + const float* attention_sinks, const float* kv_cache_scale, const XqaQuantType kv_quant_type, void* workspace, @@ -112,15 +116,15 @@ Status LaunchXQAKernel( if (head_size == 256) { return H256::LaunchXQAKernelImpl( device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, - max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, kv_quant_type, workspace, workspace_size); + max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, kv_quant_type, workspace, workspace_size); } else if (head_size == 128) { return H128::LaunchXQAKernelImpl( device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, - max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, kv_quant_type, workspace, workspace_size); + max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, kv_quant_type, workspace, workspace_size); } else if (head_size == 64) { return H64::LaunchXQAKernelImpl( device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, - max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, kv_quant_type, workspace, workspace_size); + max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, kv_quant_type, workspace, workspace_size); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA only supports head_size=64, 128, or 256. Input has ", head_size); } @@ -186,6 +190,7 @@ template Status LaunchXQAKernel( const float scale, const bool is_bsnh, const int* past_seq_lens, + const float* attention_sinks, const float* kv_cache_scale, const XqaQuantType kv_quant_type, void* workspace, diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_128.cu b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_128.cu index 87304cfd1adc2..06c8b0ce0ea2a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_128.cu +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_128.cu @@ -26,6 +26,7 @@ template Status HEAD_DIM_NAMESPACE::LaunchXQAKernelImpl( const float scale, const bool is_bsnh, const int* past_seq_lens, + const float* attention_sinks, const float* kv_cache_scale, const XqaQuantType kv_quant_type, void* workspace, diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_256.cu b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_256.cu index 3d070a87f87a8..756cc61cb9720 100644 --- a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_256.cu +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_256.cu @@ -26,6 +26,7 @@ template Status HEAD_DIM_NAMESPACE::LaunchXQAKernelImpl( const float scale, const bool is_bsnh, const int* past_seq_lens, + const float* attention_sinks, const float* kv_cache_scale, const XqaQuantType kv_quant_type, void* workspace, diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_64.cu b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_64.cu index 1664122dbc6d3..4b5b0fe4f17c9 100644 --- a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_64.cu +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_64.cu @@ -26,6 +26,7 @@ template Status HEAD_DIM_NAMESPACE::LaunchXQAKernelImpl( const float scale, const bool is_bsnh, const int* past_seq_lens, + const float* attention_sinks, const float* kv_cache_scale, const XqaQuantType kv_quant_type, void* workspace, diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_fp8_impl.cuh b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_fp8_impl.cuh index 5e18d21defb79..0a613ead5c16e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_fp8_impl.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_fp8_impl.cuh @@ -101,13 +101,13 @@ Status LaunchXQAFp8Kernel( int group_size = num_heads / kv_num_heads; switch (group_size) { case 4: - return grp4_fp8::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp4_fp8::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, nullptr, kv_cache_scale, workspace, workspace_size); case 8: - return grp8_fp8::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp8_fp8::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, nullptr, kv_cache_scale, workspace, workspace_size); case 16: - return grp16_fp8::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp16_fp8::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, nullptr, kv_cache_scale, workspace, workspace_size); case 32: - return grp32_fp8::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp32_fp8::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, nullptr, kv_cache_scale, workspace, workspace_size); default: return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA FP8 only supports group_size 4, 8, 16, 32. Input has ", group_size); } diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_impl.cuh b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_impl.cuh index 675beb3c92d0f..269b7956c0999 100644 --- a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_impl.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_impl.cuh @@ -158,6 +158,7 @@ Status LaunchXQAKernelImpl( const float scale, const bool is_bsnh, const int* past_seq_lens, + const float* attention_sinks, const float* kv_cache_scale, const XqaQuantType kv_quant_type, void* workspace, @@ -166,6 +167,7 @@ Status LaunchXQAKernelImpl( // Dispatch to INT8 path if requested if (kv_quant_type == XqaQuantType::kInt8) { + ORT_RETURN_IF(attention_sinks != nullptr, "XQA attention sinks are not supported with INT8 KV cache."); if constexpr (std::is_same::value) { return LaunchXQAInt8Kernel(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); } else { @@ -177,6 +179,7 @@ Status LaunchXQAKernelImpl( #ifdef USE_FP8_KV_CACHE // Dispatch to FP8 path if requested if (kv_quant_type == XqaQuantType::kFp8) { + ORT_RETURN_IF(attention_sinks != nullptr, "XQA attention sinks are not supported with FP8 KV cache."); if constexpr (std::is_same::value) { return LaunchXQAFp8Kernel(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); } else { @@ -189,17 +192,17 @@ Status LaunchXQAKernelImpl( int group_size = num_heads / kv_num_heads; switch (group_size) { case 1: - return grp1_fp16::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp1_fp16::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size); case 2: - return grp2_fp16::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp2_fp16::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size); case 4: - return grp4_fp16::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp4_fp16::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size); case 8: - return grp8_fp16::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp8_fp16::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size); case 16: - return grp16_fp16::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp16_fp16::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size); case 32: - return grp32_fp16::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp32_fp16::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size); default: return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA supports group_size 1, 2, 4, 8, 16, 32. Input has ", group_size); } diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_int8_impl.cuh b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_int8_impl.cuh index f3a1fcd8a8e63..ebeccfb60c7ba 100644 --- a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_int8_impl.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_int8_impl.cuh @@ -101,13 +101,13 @@ Status LaunchXQAInt8Kernel( int group_size = num_heads / kv_num_heads; switch (group_size) { case 4: - return grp4_int8::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp4_int8::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, nullptr, kv_cache_scale, workspace, workspace_size); case 8: - return grp8_int8::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp8_int8::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, nullptr, kv_cache_scale, workspace, workspace_size); case 16: - return grp16_int8::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp16_int8::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, nullptr, kv_cache_scale, workspace, workspace_size); case 32: - return grp32_int8::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + return grp32_int8::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, nullptr, kv_cache_scale, workspace, workspace_size); default: return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA INT8 only supports group_size 4, 8, 16, 32. Input has ", group_size); } diff --git a/onnxruntime/test/python/transformers/gqa_test_helper.py b/onnxruntime/test/python/transformers/gqa_test_helper.py index 7f0d50a7ac8ed..d3dd86ea9bbc6 100644 --- a/onnxruntime/test/python/transformers/gqa_test_helper.py +++ b/onnxruntime/test/python/transformers/gqa_test_helper.py @@ -310,6 +310,7 @@ def __init__( v_quant_type: str = "NONE", kv_cache_type: str = "float16", share_kv_scale: bool = False, + has_head_sink: bool = False, ): super().__init__( "GroupQueryAttention", @@ -341,6 +342,7 @@ def __init__( self.k_quant_type = k_quant_type self.v_quant_type = v_quant_type self.share_kv_scale = share_kv_scale + self.has_head_sink = has_head_sink # Determine bit width from cache type if applicable if kv_cache_type == "int4": self.kv_cache_bit_width = 4 @@ -359,6 +361,8 @@ def shape_dict(self): "seqlens_k": (self.batch_size,), } ) + if self.has_head_sink: + shapes["head_sink"] = (self.num_heads,) # Note: We don't adjust shapes for int4 here because the parent's random_inputs # creates float tensors first, then quantization will pack them return shapes @@ -371,6 +375,8 @@ def random_inputs(self): "seqlens_k": k_seqlens - 1, } ) + if self.has_head_sink: + feeds["head_sink"] = torch.rand((self.num_heads,), device=self.device, dtype=self.dtype) # Generate quantized cache and scales if quantization is enabled if self.k_quant_type != "NONE": @@ -423,7 +429,7 @@ def create_group_query_attention_onnx_model(config: GroupQueryAttentionConfig): "sin_cache" if config.do_rotary else "", "", # position_ids (optional, not used in benchmark) "", # attention_bias (optional, not used in benchmark) - "", # head_sink (optional, not used in benchmark) + "head_sink" if config.has_head_sink else "", "k_scale" if config.k_quant_type != "NONE" else "", "v_scale" if config.v_quant_type != "NONE" else "", ] @@ -512,6 +518,9 @@ def create_group_query_attention_onnx_model(config: GroupQueryAttentionConfig): helper.make_tensor_value_info("sin_cache", float_type, list(shape_dict["sin_cache"])), ] + if config.has_head_sink: + graph_input.append(helper.make_tensor_value_info("head_sink", float_type, list(shape_dict["head_sink"]))) + # Add scale inputs for quantization # Shape depends on quantization type: # - PER_TENSOR: [1] diff --git a/onnxruntime/test/python/transformers/profile_gqa.py b/onnxruntime/test/python/transformers/profile_gqa.py index e11184f31673e..6c36577133600 100644 --- a/onnxruntime/test/python/transformers/profile_gqa.py +++ b/onnxruntime/test/python/transformers/profile_gqa.py @@ -62,6 +62,7 @@ def create_gqa_config( local_window_size: int = -1, is_packed_qkv: bool = False, do_rotary: bool = True, + has_head_sink: bool = False, device: str = "cuda", share_kv_scale: bool = False, ) -> GroupQueryAttentionConfig: @@ -103,6 +104,7 @@ def create_gqa_config( dtype=dtype, is_packed_qkv=is_packed_qkv, use_smooth_softmax=False, + has_head_sink=has_head_sink, device=device, k_quant_type=k_quant_type, v_quant_type=v_quant_type, @@ -147,6 +149,7 @@ def run_comparison(args): print(f"{'=' * 70}") print(f"Config: batch={args.batch_size}, seq_len={args.sequence_length}, past_seq={args.past_sequence_length}") print(f" num_heads={args.num_heads}, kv_heads={args.kv_num_heads}, head_size={args.head_size}") + print(f" packed_qkv={args.is_packed_qkv}, rotary={not args.no_rotary}, head_sink={args.head_sink}") print(f" warmup={args.warmup}, repeat={args.repeat}") print(f"{'=' * 70}\n") @@ -166,6 +169,7 @@ def run_comparison(args): local_window_size=args.local_window_size, is_packed_qkv=args.is_packed_qkv, do_rotary=not args.no_rotary, + has_head_sink=args.head_sink, share_kv_scale=args.share_kv_scale, ) avg_ms = benchmark_gqa(config, warmup=args.warmup, repeat=args.repeat, mode=mode) @@ -203,6 +207,7 @@ def main(): parser.add_argument("--warmup", type=int, default=50, help="Warmup iterations") parser.add_argument("--repeat", type=int, default=100, help="Benchmark iterations") parser.add_argument("--is-packed-qkv", action="store_true", help="Use packed QKV") + parser.add_argument("--head-sink", action="store_true", help="Add a head_sink input") parser.add_argument("--no-rotary", action="store_true", help="Disable rotary embeddings") parser.add_argument("--share-kv-scale", action="store_true", help="Share KV scale tensor for XQA") From 2d726419e4dc5dea233bc386b89916ae1345b60a Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 19 Jun 2026 06:46:59 +0000 Subject: [PATCH 3/6] enable xqa by default when head sink exists --- .../contrib_ops/cuda/bert/attention_data.h | 4 +++ .../cuda/bert/group_query_attention.cc | 35 +++++++++++++++---- .../cuda/bert/group_query_attention.h | 4 ++- 3 files changed, 35 insertions(+), 8 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_data.h b/onnxruntime/contrib_ops/cuda/bert/attention_data.h index 6ae814641125a..74aeaf1285e8a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_data.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_data.h @@ -206,7 +206,11 @@ struct GroupQueryAttentionData { // XQA buffer void* xqa_buffer = nullptr; size_t xqa_buffer_bytes = 0; + // FP32 per-head attention sink consumed by the XQA kernel (nullptr when no head_sink input). + // Either points to a PrePack-cached buffer or to scratch that is filled at launch time. float* xqa_head_sink = nullptr; + // When true, head_sink was not prepacked (e.g. dynamic/non-initializer input) and the FP16/BF16 + // head_sink must be converted to xqa_head_sink (FP32 scratch) before launching XQA. bool xqa_head_sink_needs_conversion = false; // Unfused fallback buffers (see LaunchUnfusedAttention in unfused_attention.h): diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index ca5b738ebb4c5..60235024b9118 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -111,8 +111,15 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) kv_cache_bit_width_ = static_cast(info.GetAttrOrDefault("kv_cache_bit_width", 0)); bool is_quantized = (k_quant_type_ != KVQuantizationType::NONE || v_quant_type_ != KVQuantizationType::NONE); - int default_enable_xqa = is_quantized ? 1 : 0; - enable_xqa_ = (std::is_same_v || std::is_same_v) && ParseEnvironmentVariableWithDefault("ORT_ENABLE_XQA", default_enable_xqa) != 0; + // XQA enablement: + // - An explicit ORT_ENABLE_XQA overrides everything (1 = on, 0 = off, including the head_sink default-on path). + // - When unset, XQA defaults on for the quantized KV cache path and off for the non-quantized path + // (the non-quantized head_sink decode path is additionally enabled per-Run in ComputeInternal). + constexpr bool kIsFp16OrBf16 = std::is_same_v || std::is_same_v; + const int xqa_env = ParseEnvironmentVariableWithDefault("ORT_ENABLE_XQA", -1); // -1 means unset + xqa_force_disabled_ = (xqa_env == 0); + const int effective_enable_xqa = (xqa_env == -1) ? (is_quantized ? 1 : 0) : xqa_env; + enable_xqa_ = kIsFp16OrBf16 && (effective_enable_xqa != 0); kernel_options_ = this->GetAttentionKernelOptions(); @@ -122,7 +129,6 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) disable_memory_efficient_attention_ = !kernel_options_->UseEfficientAttention(); // cuDNN SDPA (cudnn_frontend) supports FP16 and BF16 and is auto-preferred on SM>=90. - constexpr bool kIsFp16OrBf16 = std::is_same::value || std::is_same::value; enable_cudnn_flash_attention_ = kIsFp16OrBf16 && kernel_options_->UseCudnnFlashAttention(); auto_enable_cudnn_flash_attention_ = kIsFp16OrBf16 && kernel_options_->AllowCudnnFlashAttentionAuto(); @@ -138,12 +144,17 @@ template Status GroupQueryAttention::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed, PrePackedWeights* prepacked_weights) { ORT_UNUSED_PARAMETER(prepacked_weights); + // Keep is_packed=false so the original fp16/bf16 head_sink remains available to the Flash/fallback + // paths (which are used when XQA is disabled or ineligible). We only cache an extra FP32 copy for XQA. is_packed = false; if (input_idx != kHeadSinkInputIndex) { return Status::OK(); } + // XQA consumes the attention sink as FP32. When head_sink is a constant initializer, convert it once + // here into a cached device buffer (xqa_head_sink_) to avoid a per-launch conversion. Dynamic / + // non-initializer head_sink inputs are not prepacked and fall back to the per-launch scratch path. if constexpr (std::is_same_v || std::is_same_v) { const auto& shape = tensor.Shape(); ORT_RETURN_IF_NOT(shape.NumDimensions() == 1, @@ -152,6 +163,8 @@ Status GroupQueryAttention::PrePack(const Tensor& tensor, int input_idx, A "head_sink dimension 0 must be equal to the num heads, got ", shape[0]); ORT_RETURN_IF_NOT(tensor.IsDataType(), "head_sink type must match GroupQueryAttention input type"); + // Derive the element count from the tensor itself (one sink per head) rather than num_heads_. + const int head_sink_count = static_cast(shape.Size()); const size_t head_sink_bytes = tensor.SizeInBytes(); const void* head_sink_data = tensor.DataRaw(); IAllocatorUniquePtr head_sink_gpu; @@ -164,13 +177,13 @@ Status GroupQueryAttention::PrePack(const Tensor& tensor, int input_idx, A head_sink_data = head_sink_gpu.get(); } - xqa_head_sink_ = IAllocator::MakeUniquePtr(alloc, static_cast(num_heads_), true); + xqa_head_sink_ = IAllocator::MakeUniquePtr(alloc, static_cast(head_sink_count), true); using CudaT = typename onnxruntime::cuda::OrtToCudaType::type; ORT_RETURN_IF_ERROR(LaunchConvertHeadSinkToFloat( - reinterpret_cast(head_sink_data), xqa_head_sink_.get(), num_heads_, stream, + reinterpret_cast(head_sink_data), xqa_head_sink_.get(), head_sink_count, stream, GetDeviceProp().maxThreadsPerBlock)); CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); - xqa_head_sink_count_ = num_heads_; + xqa_head_sink_count_ = head_sink_count; } return Status::OK(); @@ -385,7 +398,14 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons // 7. No local window attention (global attention only). const bool use_xqa_attention_sinks = head_sink != nullptr && !is_inputs_quantized; const bool is_xqa_smooth_softmax_supported = !parameters.use_smooth_softmax || use_xqa_attention_sinks; - if (enable_xqa_ && + // XQA is opt-in for the non-quantized path (ORT_ENABLE_XQA), but a head_sink (attention sink) input + // signals a GPT-OSS style decode model that benefits from XQA, so enable it by default in that case. + // An explicit ORT_ENABLE_XQA=0 (xqa_force_disabled_) still wins and turns XQA off entirely. + // The dtype guard mirrors enable_xqa_ (XQA only supports fp16/bf16); ineligible cases fall back below. + constexpr bool kIsFp16OrBf16 = std::is_same_v || std::is_same_v; + const bool xqa_enabled_for_run = + !xqa_force_disabled_ && (enable_xqa_ || (kIsFp16OrBf16 && use_xqa_attention_sinks)); + if (xqa_enabled_for_run && (device_prop.major >= 8) && !parameters.is_first_prompt && parameters.sequence_length == 1 && @@ -451,6 +471,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons const bool convert_xqa_head_sink = use_xqa_attention_sinks && !use_prepacked_xqa_head_sink; size_t xqa_head_sink_bytes = 0; if (convert_xqa_head_sink) { + // No prepacked FP32 head_sink (dynamic input): reserve scratch for the per-launch conversion. xqa_head_sink_bytes = parameters.num_heads * sizeof(float); xqa_head_sink_bytes = (xqa_head_sink_bytes + 255) / 256 * 256; xqa_total_bytes += xqa_head_sink_bytes; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h index 9c22de951a6f1..d5b980bdca290 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h @@ -39,6 +39,7 @@ class GroupQueryAttention final : public CudaKernel { bool disable_memory_efficient_attention_; bool disable_flash_decode_; bool enable_xqa_; + bool xqa_force_disabled_; // True when ORT_ENABLE_XQA=0 is explicitly set (overrides default-on paths). bool enable_cudnn_flash_attention_; // cuDNN SDPA explicitly enabled (env / sdpa_kernel) bool auto_enable_cudnn_flash_attention_; // auto-prefer cuDNN SDPA on SM>=90 when no explicit kernel pinned @@ -48,8 +49,9 @@ class GroupQueryAttention final : public CudaKernel { static constexpr int kZerosCount = 256; // In prompt case we create a zero buffer of size 256 for seqlen (assume batch_size <= 256) IAllocatorUniquePtr zeros_; + // FP32 head_sink cached in PrePack for the XQA path (empty when head_sink is not a constant initializer). IAllocatorUniquePtr xqa_head_sink_; - int xqa_head_sink_count_ = 0; + int xqa_head_sink_count_ = 0; // Number of elements in xqa_head_sink_ (0 when not prepacked). const AttentionKernelOptions* kernel_options_; }; From 00ba6003063cab863798c65a1f954ffaee17144a Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 19 Jun 2026 06:47:23 +0000 Subject: [PATCH 4/6] add test and doc --- docs/contrib_ops/gqa.md | 172 ++++++++++++++++++ .../test/python/transformers/test_gqa.py | 162 ++++++++++++++++- 2 files changed, 326 insertions(+), 8 deletions(-) create mode 100644 docs/contrib_ops/gqa.md diff --git a/docs/contrib_ops/gqa.md b/docs/contrib_ops/gqa.md new file mode 100644 index 0000000000000..e6d640d1a0d8f --- /dev/null +++ b/docs/contrib_ops/gqa.md @@ -0,0 +1,172 @@ +# GroupQueryAttention โ€” Operator Documentation + +This document describes the `com.microsoft::GroupQueryAttention` (GQA) contrib operator: its schema, +the CUDA kernel backends and how one is selected, and the attention-sink (`head_sink`) decode path +that is accelerated by the XQA kernel. + +For CPU-specific implementation details (including the quantized KV-cache flash path), see +[cpu/gqa.md](cpu/gqa.md). + +--- + +## Table of Contents + +1. [Overview](#1-overview) +2. [Operator Schema](#2-operator-schema) +3. [KV Cache and Quantization](#3-kv-cache-and-quantization) +4. [Attention Sink (`head_sink`) and Smooth Softmax](#4-attention-sink-head_sink-and-smooth-softmax) +5. [CUDA Kernel Backends and Dispatch](#5-cuda-kernel-backends-and-dispatch) +6. [XQA Decode Path](#6-xqa-decode-path) +7. [XQA `head_sink` PrePack](#7-xqa-head_sink-prepack) +8. [Environment Variables](#8-environment-variables) +9. [Testing](#9-testing) + +--- + +## 1. Overview + +GroupQueryAttention implements causal grouped-query attention with KV-cache (past/present) support. +Grouped-query attention uses fewer key/value heads than query heads: each KV head is shared by a +group of `num_heads / kv_num_heads` query heads. The operator also supports: + +- Rotary positional embeddings (RoPE) +- Past/present KV cache with optional in-place (shared) buffer +- Quantized KV cache (int4 / int8 / float8e4m3fn) to reduce memory footprint +- Optional attention bias and local (sliding) window attention +- Smooth softmax, including a per-head attention sink (`head_sink`) + +The operator schema is defined in +[onnxruntime/core/graph/contrib_ops/bert_defs.cc](../../onnxruntime/core/graph/contrib_ops/bert_defs.cc). +The CUDA kernel is implemented in +[onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc](../../onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc) +and [group_query_attention_impl.cu](../../onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu). + +## 2. Operator Schema + +Selected attributes: + +| Attribute | Description | +|-----------|-------------| +| `num_heads` | Number of query heads. | +| `kv_num_heads` | Number of key/value heads. `num_heads % kv_num_heads == 0`. | +| `scale` | Softmax scale. Defaults to `1/sqrt(head_size)`. | +| `softcap` | Optional logit soft-capping value. `0` disables it. | +| `local_window_size` | Left window size for local attention. `-1` means global attention. | +| `do_rotary` / `rotary_interleaved` | Enable RoPE and select interleaved vs. half-rotary layout. | +| `smooth_softmax` | Add a smooth factor to the softmax denominator. | +| `k_quant_type` / `v_quant_type` | KV cache quantization mode: `NONE`, `PER_TENSOR`, or `PER_CHANNEL`. | +| `kv_cache_bit_width` | Bit width of the quantized KV cache (`8` or `4`). | + +Selected inputs (see the schema for the full list and shapes): + +| Index | Name | Notes | +|-------|------|-------| +| 0 | `query` | `(batch, seq, hidden)`, or packed QKV. | +| 1, 2 | `key`, `value` | Optional when QKV is packed into `query`. | +| 3, 4 | `past_key`, `past_value` | BNSH cache. Shares the buffer with `present_*` when in-place. | +| 5 | `seqlens_k` | `total_sequence_lengths - 1` per batch entry. | +| 6 | `total_sequence_length` | Scalar used to distinguish prompt vs. decode. | +| 7, 8 | `cos_cache`, `sin_cache` | RoPE caches. | +| 11 | `head_sink` | `(num_heads,)` per-head attention sink (see ยง4). | +| 12, 13 | `k_scale`, `v_scale` | FP32 dequant scales for the quantized KV cache. | + +Outputs are `output`, `present_key`, `present_value`, and optional `output_qk`. + +## 3. KV Cache and Quantization + +The past/present KV cache uses BNSH layout `(batch_size, kv_num_heads, cache_sequence_length, head_size)`. +When `past_present_share_buffer` holds (the past and present tensors alias the same memory), the cache +length is the maximum sequence length and new keys/values are appended in place. This shared-buffer mode +is required by the XQA decode path. + +When quantization is enabled, `k_quant_type` and `v_quant_type` select `PER_TENSOR` or `PER_CHANNEL` +scaling, and `kv_cache_bit_width` selects 8-bit or 4-bit storage. The `k_scale` / `v_scale` inputs are +always FP32. + +## 4. Attention Sink (`head_sink`) and Smooth Softmax + +An attention sink adds a learned per-head bias term to the softmax denominator. With sink value `s_h` +for head `h`, the attention weights over `T` cached positions become: + +$$ +\text{softmax}_i = \frac{e^{x_i - m}}{e^{s_h - m} + \sum_{j} e^{x_j - m}}, \quad m = \max\left(s_h, \max_j x_j\right) +$$ + +This is equivalent to appending a single extra logit `s_h` (whose value contributes nothing to the +output, only to normalization). GPT-OSS style models use this to let a head attend to "nothing". + +In the kernel, providing the `head_sink` input is treated as smooth softmax: +`parameters.use_smooth_softmax = use_smooth_softmax_ || head_sink != nullptr`. The `head_sink` tensor is +1D of shape `(num_heads,)` and matches the operator's floating-point type (`float16` or `bfloat16` on +the XQA path). + +## 5. CUDA Kernel Backends and Dispatch + +The CUDA EP can route a GQA node to several backends. At runtime it selects the first eligible one: + +| Backend | Typical use | +|---------|-------------| +| **XQA** | Single-token global decode (`seq_len == 1`), shared KV buffer. Fastest decode path. | +| **Flash Attention / Flash Decoding** | General prompt and decode, including local window and softcap. | +| **cuDNN SDPA** | Preferred on SMโ‰ฅ90 for non-quantized FP16/BF16 causal attention. | +| **Memory Efficient Attention** | Fallback for FP16/FP32 (and BF16 on SM80+). | +| **Unfused** | Last-resort fallback (e.g. `head_size > 256` with past KV). | + +The selected backend is reported in the kernel debug info as `SdpaKernel=...` when debug info is enabled. + +## 6. XQA Decode Path + +XQA (a highly optimized cross/decode attention kernel) is used only when **all** of the following hold: + +1. Compute capability SM 8.0+ (Ampere or newer). +2. Decoding phase (not the first prompt) with `sequence_length == 1`. +3. `kv_sequence_length > 0` (there is a new K/V to append). +4. Past and present KV cache share the same buffer. +5. No softcap. +6. Standard softmax, **or** smooth softmax expressed via a `head_sink` tensor (non-quantized KV cache). +7. No local (sliding) window attention โ€” global attention only. +8. Supported `head_size` (64, 128, or 256) and group size. + +`head_sink` (attention sink) is supported on the non-quantized XQA path only. Quantized KV cache +(int8 / fp8) paths explicitly reject a non-null attention sink, so a GQA node with both `head_sink` +and a quantized cache falls back to Flash/Flash-Decoding. + +XQA selection defaults are: + +- **Quantized KV cache (int8 / fp8):** on by default. +- **Non-quantized with a `head_sink` input:** on by default (GPT-OSS style decode). +- **Non-quantized without `head_sink`:** opt-in via `ORT_ENABLE_XQA=1`. + +Setting `ORT_ENABLE_XQA=0` disables XQA for the non-quantized path regardless of `head_sink`. + +## 7. XQA `head_sink` PrePack + +XQA consumes the attention sink as an FP32 buffer, while the model stores `head_sink` as FP16/BF16. To +avoid converting on every decode step, `GroupQueryAttention::PrePack` converts a **constant-initializer** +`head_sink` once into a cached FP32 device buffer (`xqa_head_sink_`): + +- The cached buffer is reused for every launch when XQA is eligible. +- A dynamic / non-initializer `head_sink` is **not** prepacked; the kernel instead reserves a small FP32 + scratch buffer and converts the sink per launch (`xqa_head_sink_needs_conversion = true`). +- `PrePack` keeps `is_packed = false` so the original FP16/BF16 `head_sink` is still delivered to the + Flash/fallback paths when XQA is disabled or ineligible. + +## 8. Environment Variables + +| Variable | Effect | +|----------|--------| +| `ORT_ENABLE_XQA` | `1` enables the XQA decode path for the non-quantized KV cache (default off; default on for quantized). | +| `ORT_DISABLE_FLASH_DECODE` | `1` disables the Flash Decoding split-KV optimization. | + +These are read once when the kernel is constructed. + +## 9. Testing + +CUDA parity tests live in +[onnxruntime/test/python/transformers/test_gqa.py](../../onnxruntime/test/python/transformers/test_gqa.py): + +- `TestXQAQuantizedParity` โ€” XQA per-tensor int8 quantized decode parity. +- `TestXQAHeadSinkParity` โ€” non-quantized XQA decode parity with a `head_sink` (attention sink) input. + +Both classes set `ORT_ENABLE_XQA=1` so the XQA path is exercised, and compare against a PyTorch +reference (`attention_ref` with `smooth_softmax_ref`). diff --git a/onnxruntime/test/python/transformers/test_gqa.py b/onnxruntime/test/python/transformers/test_gqa.py index 310018c3395ae..7229166443f93 100644 --- a/onnxruntime/test/python/transformers/test_gqa.py +++ b/onnxruntime/test/python/transformers/test_gqa.py @@ -87,6 +87,10 @@ class GQAConfig: softcap: float = 0.0 use_smooth_softmax: bool = False has_head_sink: bool = False + # When True, head_sink is baked into the model as a constant initializer (instead of a runtime + # input). This exercises the GroupQueryAttention::PrePack path that converts the constant + # head_sink to a cached FP32 XQA buffer. + head_sink_as_initializer: bool = False kv_cache_type: str = "" share_buffer: bool = True share_kv_scale: bool = False @@ -190,12 +194,23 @@ def apply_rotary_embedding(x, cos, sin, pos, interleaved, device="cpu"): # ################################################################################################# +def make_head_sink_initializer(head_sink, ort_type, num_heads): + """Build a constant head_sink initializer (fp16/bf16) so GroupQueryAttention::PrePack runs. + + The 16-bit float bits are reinterpreted as uint16 and stored as raw bytes, which works for + both float16 and bfloat16 without relying on numpy bfloat16 support. + """ + raw = head_sink.detach().reshape(num_heads).cpu().contiguous().view(torch.uint16).numpy().tobytes() + return helper.make_tensor(name="head_sink", data_type=ort_type, dims=[num_heads], vals=raw, raw=True) + + def create_gqa_node_and_io( config: GQAConfig, ort_type, share_buffer=True, is_past=False, output_qk: int = 0, # CUDA does not support output_qk for GQA + head_sink_values=None, ): if is_past: if share_buffer: @@ -211,6 +226,8 @@ def create_gqa_node_and_io( if not config.kv_cache_type: config.kv_cache_type = "float16" if ort_type == TensorProto.FLOAT16 else "bfloat16" + initializers = [] + # --- Node Definition --- outputs = [ "output", @@ -348,7 +365,11 @@ def create_gqa_node_and_io( ) ) if config.has_head_sink: - graph_input.append(helper.make_tensor_value_info("head_sink", ort_type, [config.num_heads])) + if config.head_sink_as_initializer and head_sink_values is not None: + # Constant initializer (not a graph input) so ORT treats it as a constant and PrePack runs. + initializers.append(make_head_sink_initializer(head_sink_values, ort_type, config.num_heads)) + else: + graph_input.append(helper.make_tensor_value_info("head_sink", ort_type, [config.num_heads])) # --- Graph Outputs --- output_k_shape = [config.batch_size, config.kv_num_heads, present_kv_seqlen, config.head_size] @@ -372,19 +393,23 @@ def create_gqa_node_and_io( ) ) - return node, graph_input, graph_output + return node, graph_input, graph_output, initializers def create_group_query_attention_graph_prompt(config: GQAConfig, ort_type, share_buffer=True): - node, graph_input, graph_output = create_gqa_node_and_io(config, ort_type, share_buffer, is_past=False) - graph = helper.make_graph([node], "GroupQueryAttention_Graph", graph_input, graph_output) + node, graph_input, graph_output, initializers = create_gqa_node_and_io( + config, ort_type, share_buffer, is_past=False + ) + graph = helper.make_graph([node], "GroupQueryAttention_Graph", graph_input, graph_output, initializer=initializers) model = helper.make_model(graph) return model.SerializeToString() -def create_group_query_attention_graph_past(config: GQAConfig, ort_type, share_buffer=True): - node, graph_input, graph_output = create_gqa_node_and_io(config, ort_type, share_buffer, is_past=True) - graph = helper.make_graph([node], "GroupQueryAttention_Graph", graph_input, graph_output) +def create_group_query_attention_graph_past(config: GQAConfig, ort_type, share_buffer=True, head_sink_values=None): + node, graph_input, graph_output, initializers = create_gqa_node_and_io( + config, ort_type, share_buffer, is_past=True, head_sink_values=head_sink_values + ) + graph = helper.make_graph([node], "GroupQueryAttention_Graph", graph_input, graph_output, initializer=initializers) model = helper.make_model(graph) return model.SerializeToString() @@ -605,10 +630,12 @@ def gqa_past_func( if not config.kv_cache_type: config.kv_cache_type = "float16" if ort_type == TensorProto.FLOAT16 else "bfloat16" + head_sink_as_initializer = config.has_head_sink and config.head_sink_as_initializer and head_sink is not None onnx_model_str = create_group_query_attention_graph_past( config=config, ort_type=ort_type, share_buffer=share_buffer, + head_sink_values=head_sink if head_sink_as_initializer else None, ) q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1)) @@ -671,7 +698,7 @@ def gqa_past_func( if config.has_attention_bias and attention_bias is not None: bind_tensor(io_binding, "attention_bias", attention_bias, device, ort_type) - if config.has_head_sink and head_sink is not None: + if config.has_head_sink and head_sink is not None and not head_sink_as_initializer: bind_tensor(io_binding, "head_sink", head_sink, device, ort_type) # 6. Quantization @@ -1948,6 +1975,11 @@ def has_flash_attention(bf16=False): return True +def has_xqa(): + # The XQA decode kernels require Ampere (SM 8.0) or newer. + return has_cuda_device(80) + + rtol = { "fp16": 5e-3, "bf16": 5e-2, @@ -2335,6 +2367,120 @@ def test_xqa_quantized_parity(self, name, config, torch_type, ort_type): ) +def gqa_xqa_head_sink_test_cases(): + # Non-quantized global decode with a head_sink (attention sink) input. + # These configs exercise the XQA attention-sink path added for GPT-OSS style models: + # seq_len=1, shared KV buffer, no softcap, no local window, head_size in {64, 128, 256}, + # and 64 % group_size == 0. + for torch_type, ort_type in [(torch.float16, TensorProto.FLOAT16), (torch.bfloat16, TensorProto.BFLOAT16)]: + for group_size in [1, 4, 8]: + for head_size in [64, 128]: + for rotary in [False, True]: + kv_num_heads = 4 + num_heads = kv_num_heads * group_size + config = GQAConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + num_heads=num_heads, + kv_num_heads=kv_num_heads, + head_size=head_size, + past_kv_sequence_length=4, + buffer_sequence_length=4 + 128, + rotary=rotary, + packed=False, + share_buffer=True, + has_head_sink=True, + ) + type_str = "bf16" if torch_type == torch.bfloat16 else "fp16" + rot_str = "rot" if rotary else "norot" + name = f"{type_str}_g{group_size}_h{head_size}_{rot_str}" + yield name, config, torch_type, ort_type + + +def gqa_xqa_head_sink_prepack_test_cases(): + # Same XQA attention-sink decode path as gqa_xqa_head_sink_test_cases(), but head_sink is baked + # into the model as a constant initializer. This exercises GroupQueryAttention::PrePack, which + # converts the constant head_sink once into the cached FP32 XQA buffer (use_prepacked_xqa_head_sink), + # instead of the per-launch conversion scratch path used for runtime head_sink inputs. + for torch_type, ort_type in [(torch.float16, TensorProto.FLOAT16), (torch.bfloat16, TensorProto.BFLOAT16)]: + for group_size in [1, 4]: + for rotary in [False, True]: + kv_num_heads = 4 + num_heads = kv_num_heads * group_size + config = GQAConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + num_heads=num_heads, + kv_num_heads=kv_num_heads, + head_size=128, + past_kv_sequence_length=4, + buffer_sequence_length=4 + 128, + rotary=rotary, + packed=False, + share_buffer=True, + has_head_sink=True, + head_sink_as_initializer=True, + ) + type_str = "bf16" if torch_type == torch.bfloat16 else "fp16" + rot_str = "rot" if rotary else "norot" + name = f"{type_str}_g{group_size}_h128_{rot_str}_prepack" + yield name, config, torch_type, ort_type + + +@unittest.skipIf(not has_xqa(), "XQA is not available, skipping tests.") +@unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping tests.") +class TestXQAHeadSinkParity(unittest.TestCase): + """Verify the non-quantized XQA attention-sink (head_sink) decode path matches the reference.""" + + def setUp(self): + # XQA is enabled by default when a head_sink input is present, so this path is exercised + # without ORT_ENABLE_XQA. Clear it (saving the previous value) to test the real default. + self._prev_enable_xqa = os.environ.pop("ORT_ENABLE_XQA", None) + + def tearDown(self): + # Restore the environment so other tests run with the default XQA setting. + if self._prev_enable_xqa is None: + os.environ.pop("ORT_ENABLE_XQA", None) + else: + os.environ["ORT_ENABLE_XQA"] = self._prev_enable_xqa + if torch.cuda.is_available(): + torch.cuda.synchronize() + torch.cuda.empty_cache() + gc.collect() + + @parameterized.expand(gqa_xqa_head_sink_test_cases()) + def test_xqa_head_sink_parity(self, name, config, torch_type, ort_type): + """Test XQA non-quantized parity with a head_sink (attention sink) input.""" + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch_type, + ort_type=ort_type, + causal=True, + rtol=rtol["bf16"] if torch_type == torch.bfloat16 else rtol["fp16"], + atol=atol["bf16"] if torch_type == torch.bfloat16 else atol["fp16"], + std=0.1, + ) + + @parameterized.expand(gqa_xqa_head_sink_prepack_test_cases()) + def test_xqa_head_sink_prepack_parity(self, name, config, torch_type, ort_type): + """Test XQA parity when head_sink is a constant initializer (exercises PrePack).""" + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch_type, + ort_type=ort_type, + causal=True, + rtol=rtol["bf16"] if torch_type == torch.bfloat16 else rtol["fp16"], + atol=atol["bf16"] if torch_type == torch.bfloat16 else atol["fp16"], + std=0.1, + ) + + @unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping tests.") @unittest.skipIf(not has_quantized_kv_cache(), "Quantized KV Cache is not available, skipping tests.") class TestGQARegressions(unittest.TestCase): From 8c8df9b73a94e56612a9d4331ea0ee03ea5a4af1 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 19 Jun 2026 19:15:48 +0000 Subject: [PATCH 5/6] scoped env var --- .../python/transformers/env_var_helper.py | 25 ++ .../test/python/transformers/test_gqa.py | 254 +++++++++--------- .../test/python/transformers/test_qmoe_cpu.py | 16 +- 3 files changed, 153 insertions(+), 142 deletions(-) create mode 100644 onnxruntime/test/python/transformers/env_var_helper.py diff --git a/onnxruntime/test/python/transformers/env_var_helper.py b/onnxruntime/test/python/transformers/env_var_helper.py new file mode 100644 index 0000000000000..77d42291ce12e --- /dev/null +++ b/onnxruntime/test/python/transformers/env_var_helper.py @@ -0,0 +1,25 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import os +from contextlib import contextmanager + + +@contextmanager +def scoped_env_var(name: str, value: str): + """Temporarily set an environment variable, restoring the previous value on exit. + + Keeps tests order-independent by ensuring env-var mutations do not leak into + later tests running in the same process. + """ + previous = os.environ.get(name) + os.environ[name] = value + try: + yield + finally: + if previous is None: + os.environ.pop(name, None) + else: + os.environ[name] = previous diff --git a/onnxruntime/test/python/transformers/test_gqa.py b/onnxruntime/test/python/transformers/test_gqa.py index 7229166443f93..9d169f8ba6b7f 100644 --- a/onnxruntime/test/python/transformers/test_gqa.py +++ b/onnxruntime/test/python/transformers/test_gqa.py @@ -22,6 +22,7 @@ import torch from cuda_plugin_ep_helper import get_cuda_provider_name, resolve_cuda_plugin_ep from einops import rearrange, repeat +from env_var_helper import scoped_env_var # --- ONNX and Torch/Numpy Dtype Mappings --- from gqa_test_helper import ( @@ -2020,17 +2021,17 @@ def test_gqa_prompt_flash_attention(self, name, config): print("-" * 20) print(f"test_case: {name}\n{config}") - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" - parity_check_gqa_prompt( - config=config, - ep="CUDAExecutionProvider", - device="cuda", - torch_type=torch.float16, - ort_type=TensorProto.FLOAT16, - causal=True, - rtol=rtol["fp16"], - atol=atol["fp16"], - ) + with scoped_env_var("ORT_DISABLE_FLASH_ATTENTION", "0"): + parity_check_gqa_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) @parameterized.expand(gqa_cuda_past_test_cases()) def test_gqa_past_flash_attention(self, name, config): @@ -2038,17 +2039,17 @@ def test_gqa_past_flash_attention(self, name, config): print("-" * 20) print(f"test_case: {name}\n{config}") - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" - parity_check_gqa_past( - config=config, - ep="CUDAExecutionProvider", - device="cuda", - torch_type=torch.float16, - ort_type=TensorProto.FLOAT16, - causal=True, - rtol=rtol["fp16"], - atol=atol["fp16"], - ) + with scoped_env_var("ORT_DISABLE_FLASH_ATTENTION", "0"): + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) @unittest.skipIf(not has_flash_attention(bf16=True), "Flash Attention is not available, skipping tests.") @@ -2069,17 +2070,17 @@ def test_gqa_prompt_flash_attention_bf16(self, name, config): print(f"test_case: {name}\n{config}") config.kv_cache_type = "bfloat16" - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" - parity_check_gqa_prompt( - config=config, - ep="CUDAExecutionProvider", - device="cuda", - torch_type=torch.bfloat16, - ort_type=TensorProto.BFLOAT16, - causal=True, - rtol=rtol["bf16"], - atol=atol["bf16"], - ) + with scoped_env_var("ORT_DISABLE_FLASH_ATTENTION", "0"): + parity_check_gqa_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.bfloat16, + ort_type=TensorProto.BFLOAT16, + causal=True, + rtol=rtol["bf16"], + atol=atol["bf16"], + ) @parameterized.expand(gqa_cuda_past_test_cases()) def test_gqa_past_flash_attention_bf16(self, name, config): @@ -2091,17 +2092,17 @@ def test_gqa_past_flash_attention_bf16(self, name, config): print(f"test_case: {name}\n{config}") config.kv_cache_type = "bfloat16" - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" - parity_check_gqa_past( - config=config, - ep="CUDAExecutionProvider", - device="cuda", - torch_type=torch.bfloat16, - ort_type=TensorProto.BFLOAT16, - causal=True, - rtol=rtol["bf16"], - atol=atol["bf16"], - ) + with scoped_env_var("ORT_DISABLE_FLASH_ATTENTION", "0"): + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.bfloat16, + ort_type=TensorProto.BFLOAT16, + causal=True, + rtol=rtol["bf16"], + atol=atol["bf16"], + ) @unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping tests.") @@ -2132,17 +2133,17 @@ def test_gqa_quantized_prompt_bf16(self, name, config): self.manual_seed() - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" - parity_check_gqa_prompt( - config=config, - ep="CUDAExecutionProvider", - device="cuda", - torch_type=torch.bfloat16, - ort_type=TensorProto.BFLOAT16, - causal=True, - rtol=rtol[f"{config.kv_cache_type}_bf16"], - atol=atol[f"{config.kv_cache_type}_bf16"], - ) + with scoped_env_var("ORT_DISABLE_FLASH_ATTENTION", "0"): + parity_check_gqa_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.bfloat16, + ort_type=TensorProto.BFLOAT16, + causal=True, + rtol=rtol[f"{config.kv_cache_type}_bf16"], + atol=atol[f"{config.kv_cache_type}_bf16"], + ) @parameterized.expand(gqa_cuda_quantized_test_cases(is_past=True)) def test_gqa_quantized_past_bf16(self, name, config): @@ -2152,17 +2153,17 @@ def test_gqa_quantized_past_bf16(self, name, config): self.manual_seed() - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" - parity_check_gqa_past( - config=config, - ep="CUDAExecutionProvider", - device="cuda", - torch_type=torch.bfloat16, - ort_type=TensorProto.BFLOAT16, - causal=True, - rtol=rtol[f"{config.kv_cache_type}_bf16"], - atol=atol[f"{config.kv_cache_type}_bf16"], - ) + with scoped_env_var("ORT_DISABLE_FLASH_ATTENTION", "0"): + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.bfloat16, + ort_type=TensorProto.BFLOAT16, + causal=True, + rtol=rtol[f"{config.kv_cache_type}_bf16"], + atol=atol[f"{config.kv_cache_type}_bf16"], + ) @unittest.skipIf(not has_cuda_device(53), "Memory Efficient Attention is not available, skipping tests.") @@ -2173,17 +2174,17 @@ def test_gqa_prompt_memory_efficient(self, name, config): print("-" * 20) print(f"test_case: {name}\n{config}") - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" - parity_check_gqa_prompt( - config=config, - ep="CUDAExecutionProvider", - device="cuda", - torch_type=torch.float16, - ort_type=TensorProto.FLOAT16, - causal=True, - rtol=rtol["fp16"], - atol=atol["fp16"], - ) + with scoped_env_var("ORT_DISABLE_FLASH_ATTENTION", "1"): + parity_check_gqa_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) @parameterized.expand(gqa_cuda_past_test_cases(allow_head_sink=False)) def test_gqa_past_memory_efficient(self, name, config): @@ -2191,17 +2192,17 @@ def test_gqa_past_memory_efficient(self, name, config): print("-" * 20) print(f"test_case: {name}\n{config}") - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" - parity_check_gqa_past( - config=config, - ep="CUDAExecutionProvider", - device="cuda", - torch_type=torch.float16, - ort_type=TensorProto.FLOAT16, - causal=True, - rtol=rtol["fp16"], - atol=atol["fp16"], - ) + with scoped_env_var("ORT_DISABLE_FLASH_ATTENTION", "1"): + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) @unittest.skipIf(not has_cuda_device(80), "BF16 requires Ampere or higher GPU, skipping tests.") @@ -2212,17 +2213,17 @@ def test_gqa_past_memory_efficient_bf16(self, name, config): print("-" * 20) print(f"test_case: {name}\n{config}") - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" - parity_check_gqa_past( - config=config, - ep="CUDAExecutionProvider", - device="cuda", - torch_type=torch.bfloat16, - ort_type=TensorProto.BFLOAT16, - causal=True, - rtol=rtol["bf16"], - atol=atol["bf16"], - ) + with scoped_env_var("ORT_DISABLE_FLASH_ATTENTION", "1"): + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.bfloat16, + ort_type=TensorProto.BFLOAT16, + causal=True, + rtol=rtol["bf16"], + atol=atol["bf16"], + ) @unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping tests.") @@ -2232,8 +2233,8 @@ def test_gqa_padding_prompt_flash_attention(self): print("-" * 20) print("test_case: test_gqa_padding_prompt_flash_attention") - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" - parity_test_gqa_padding_prompt() + with scoped_env_var("ORT_DISABLE_FLASH_ATTENTION", "0"): + parity_test_gqa_padding_prompt() @unittest.skipIf(not has_cuda_device(53), "Memory Efficient Attention is not available, skipping tests.") @@ -2243,8 +2244,8 @@ def test_gqa_padding_prompt_memory_efficient_attention(self): print("-" * 20) print("test_case: test_gqa_padding_prompt_memory_efficient_attention") - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" - parity_test_gqa_padding_prompt() + with scoped_env_var("ORT_DISABLE_FLASH_ATTENTION", "1"): + parity_test_gqa_padding_prompt() # ################################################################################################# @@ -2352,19 +2353,18 @@ def tearDown(self): @parameterized.expand(gqa_xqa_test_cases()) def test_xqa_quantized_parity(self, name, config, torch_type, ort_type): """Test XQA per-tensor INT8 quantized parity.""" - os.environ["ORT_ENABLE_XQA"] = "1" - - parity_check_gqa_past( - config=config, - ep="CUDAExecutionProvider", - device="cuda", - torch_type=torch_type, - ort_type=ort_type, - causal=True, - rtol=rtol["int8_bf16"] if torch_type == torch.bfloat16 else rtol["int8_fp16"], - atol=atol["int8_bf16"] if torch_type == torch.bfloat16 else atol["int8_fp16"], - std=0.1, - ) + with scoped_env_var("ORT_ENABLE_XQA", "1"): + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch_type, + ort_type=ort_type, + causal=True, + rtol=rtol["int8_bf16"] if torch_type == torch.bfloat16 else rtol["int8_fp16"], + atol=atol["int8_bf16"] if torch_type == torch.bfloat16 else atol["int8_fp16"], + std=0.1, + ) def gqa_xqa_head_sink_test_cases(): @@ -2600,17 +2600,17 @@ def test_gqa_local_window_large_context_decode(self): ort_type = TensorProto.FLOAT16 device = "cuda" - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" - parity_check_gqa_past( - config=config, - ep="CUDAExecutionProvider", - device=device, - torch_type=torch_type, - ort_type=ort_type, - causal=True, - rtol=rtol["fp16"], - atol=atol["fp16"], - ) + with scoped_env_var("ORT_DISABLE_FLASH_ATTENTION", "0"): + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device=device, + torch_type=torch_type, + ort_type=ort_type, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) @unittest.skipIf(not has_cuda_device(89) or not has_fp8_kv_cache, "FP8 KV cache is not available, skipping tests.") def test_gqa_fp8_kv_cache(self): diff --git a/onnxruntime/test/python/transformers/test_qmoe_cpu.py b/onnxruntime/test/python/transformers/test_qmoe_cpu.py index 2a3f7eedf4cba..62041d8a432dc 100644 --- a/onnxruntime/test/python/transformers/test_qmoe_cpu.py +++ b/onnxruntime/test/python/transformers/test_qmoe_cpu.py @@ -25,15 +25,14 @@ # normalization on the selected experts. This provides proper weight distribution # while maintaining computational efficiency. # -------------------------------------------------------------------------- -import os import time import unittest from collections import OrderedDict -from contextlib import contextmanager import numpy import torch import torch.nn.functional as F +from env_var_helper import scoped_env_var from onnx import helper from parameterized import parameterized from torch import nn @@ -1196,19 +1195,6 @@ def with_mlas_q4_mode(test_cases): return expanded_cases -@contextmanager -def scoped_env_var(name: str, value: str): - previous = os.environ.get(name) - os.environ[name] = value - try: - yield - finally: - if previous is None: - os.environ.pop(name, None) - else: - os.environ[name] = previous - - def run_parity_with_mlas_q4_mode(test_runner, enable_mlas_q4_gemm: bool | None): if enable_mlas_q4_gemm is None: # No env var test_runner() From 73062182ef2736ede7a69f2ee3c3451490f40555 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 19 Jun 2026 19:19:43 +0000 Subject: [PATCH 6/6] Address PR review feedback: fix head_size comment, profile_gqa import, and XQA testing doc --- docs/contrib_ops/gqa.md | 5 +++-- onnxruntime/test/python/transformers/profile_gqa.py | 10 +++++++++- onnxruntime/test/python/transformers/test_gqa.py | 2 +- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/docs/contrib_ops/gqa.md b/docs/contrib_ops/gqa.md index e6d640d1a0d8f..08596ff4b5dd9 100644 --- a/docs/contrib_ops/gqa.md +++ b/docs/contrib_ops/gqa.md @@ -168,5 +168,6 @@ CUDA parity tests live in - `TestXQAQuantizedParity` โ€” XQA per-tensor int8 quantized decode parity. - `TestXQAHeadSinkParity` โ€” non-quantized XQA decode parity with a `head_sink` (attention sink) input. -Both classes set `ORT_ENABLE_XQA=1` so the XQA path is exercised, and compare against a PyTorch -reference (`attention_ref` with `smooth_softmax_ref`). +`TestXQAQuantizedParity` sets `ORT_ENABLE_XQA=1` to force the XQA path. `TestXQAHeadSinkParity` +instead clears `ORT_ENABLE_XQA` to validate that XQA is enabled by default when a `head_sink` input +is present. Both compare against a PyTorch reference (`attention_ref` with `smooth_softmax_ref`). diff --git a/onnxruntime/test/python/transformers/profile_gqa.py b/onnxruntime/test/python/transformers/profile_gqa.py index 6c36577133600..49ce26b5126b8 100644 --- a/onnxruntime/test/python/transformers/profile_gqa.py +++ b/onnxruntime/test/python/transformers/profile_gqa.py @@ -20,10 +20,18 @@ """ import argparse +import os import time import torch -from test_sparse_attention import GroupQueryAttentionConfig, OrtGroupQueryAttention + +try: + from gqa_test_helper import GroupQueryAttentionConfig, OrtGroupQueryAttention +except ImportError: + import sys + + sys.path.insert(0, os.path.dirname(__file__)) + from gqa_test_helper import GroupQueryAttentionConfig, OrtGroupQueryAttention # Optional NVTX support for nsys range markers try: diff --git a/onnxruntime/test/python/transformers/test_gqa.py b/onnxruntime/test/python/transformers/test_gqa.py index 9d169f8ba6b7f..529eae1494e94 100644 --- a/onnxruntime/test/python/transformers/test_gqa.py +++ b/onnxruntime/test/python/transformers/test_gqa.py @@ -2370,7 +2370,7 @@ def test_xqa_quantized_parity(self, name, config, torch_type, ort_type): def gqa_xqa_head_sink_test_cases(): # Non-quantized global decode with a head_sink (attention sink) input. # These configs exercise the XQA attention-sink path added for GPT-OSS style models: - # seq_len=1, shared KV buffer, no softcap, no local window, head_size in {64, 128, 256}, + # seq_len=1, shared KV buffer, no softcap, no local window, head_size in {64, 128}, # and 64 % group_size == 0. for torch_type, ort_type in [(torch.float16, TensorProto.FLOAT16), (torch.bfloat16, TensorProto.BFLOAT16)]: for group_size in [1, 4, 8]: