From 591df5b11914e5ce4636b378e8b3c871da3ec695 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Thu, 11 Jun 2026 16:43:28 +0800 Subject: [PATCH 1/5] webgpu: fix GQA batched right-padded prefill with do_rotary When GenAI runs a batched prefill with prompts of unequal lengths, short prompts are right-padded up to the batch max sequence_length and each batch's real length is reported via seqlens_k[b] = real_len[b] - 1. The WebGPU rotary embedding shaders computed past_seqlen = (seqlens_k[b]+1) - sequence_length per batch, which underflowed u32 for any batch shorter than sequence_length. The resulting astronomically large position_id indexed past the cos/sin caches and produced garbage rotated Q/K, which manifested as gibberish output text for the shorter batches in the batch. Clamp past_seqlen to 0 in all three rotary embedding shaders: RotaryEmbeddingProgram (seqlens variant), FusedQKRotaryEmbeddingProgram, and the split_packed_qkv_with_rotary_embedding template. Also extend CanApplyFlashAttention to bypass FlashAttention for batched cases with per-batch seqlens (which exercise the unpatched and-copykv variant), while still allowing it for shared-KV layers where it is mandatory. Adds a regression test exercising the packed-QKV do_rotary path with three batches of unequal real lengths. --- .../webgpu/bert/flash_attention.cc | 6 +- .../contrib_ops/webgpu/bert/flash_attention.h | 2 +- .../webgpu/bert/group_query_attention.cc | 2 +- .../webgpu/bert/rotary_embedding.cc | 6 +- ...ed_qkv_with_rotary_embedding.wgsl.template | 3 +- .../group_query_attention_op_test.cc | 193 ++++++++++++++++++ 6 files changed, 205 insertions(+), 7 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index 02e764d01e05e..9be6a047cea9c 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -606,8 +606,10 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co return Status::OK(); } -bool CanApplyFlashAttention(const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) { - return !parameters.is_packed_qkv_ && +bool CanApplyFlashAttention(const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) { + const bool kv_empty = parameters.kv_sequence_length_ == 0; + return (parameters.batch_size_ == 1 || seqlen_k == nullptr || kv_empty) && + !parameters.is_packed_qkv_ && parameters.head_size_ == parameters.v_head_size_ && ((context.AdapterInfo().vendor == std::string_view{"qualcomm"} && parameters.head_size_ % 8 == 0) || parameters.head_size_ % 4 == 0); } diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h index 3da6b33b4dc0e..218baf926173f 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h @@ -205,7 +205,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k = nullptr, const Tensor* cos_cache = nullptr, const Tensor* sin_cache = nullptr, const Tensor* head_sink = nullptr); -bool CanApplyFlashAttention(const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context); +bool CanApplyFlashAttention(const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k = nullptr); // Split packed QKV with Q/K rotary embedding and copy KV cache fusion Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::ComputeContext& context, diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 930cb296122ce..36d688c9723fd 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -350,7 +350,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& // Create a temporary parameters copy with is_packed_qkv_ set to false to check if flash attention can be applied after unpacking WebgpuAttentionParameters temp_params = parameters; temp_params.is_packed_qkv_ = false; - will_use_flash_attention = CanApplyFlashAttention(temp_params, context); + will_use_flash_attention = CanApplyFlashAttention(temp_params, context, seqlen_k); } if (kv_empty) { diff --git a/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc index 1b11a69de7824..b4fbdc555a6d5 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc @@ -44,7 +44,8 @@ Status RotaryEmbeddingProgram::GenerateShaderCode(ShaderHelper& shader) const { << " let seqlen_i = " << position_ids_or_seqlens.GetByOffset("batch_idx") << ";\n" << " let seqlen = u32(seqlen_i);\n" " let total_seqlen = seqlen + 1u;\n" - " let past_seqlen = total_seqlen - uniforms.global_shape[1];\n" + " // Right-padded batches with prompt shorter than global_shape[1] would underflow u32; clamp to 0.\n" + " let past_seqlen = select(total_seqlen - uniforms.global_shape[1], 0u, total_seqlen <= uniforms.global_shape[1]);\n" " let position_id = past_seqlen + bsnh[1];\n" << " let i = dot(bsnh, uniforms.input_output_stride) + select(0u, bsnh[3], " << interleaved_str << ");\n" << " let j = i + select(half_rotary_emb_dim, 1u, " << interleaved_str << ");\n" @@ -200,7 +201,8 @@ Status FusedQKRotaryEmbeddingProgram::GenerateShaderCode(ShaderHelper& shader) c << " let seqlen_i = " << seqlens.GetByOffset("batch_idx") << ";\n" << " let seqlen = u32(seqlen_i);\n" << " let total_seqlen = seqlen + 1u;\n" - << " let past_seqlen = total_seqlen - uniforms.q_global_shape[1];\n" + << " // Right-padded batches with prompt shorter than q_global_shape[1] would underflow u32; clamp to 0.\n" + << " let past_seqlen = select(total_seqlen - uniforms.q_global_shape[1], 0u, total_seqlen <= uniforms.q_global_shape[1]);\n" << " let position_id = past_seqlen + sequence_idx;\n" << " let qi = dot(bsnh, uniforms.q_input_output_stride) + select(0u, bsnh[3], " << interleaved_str << ");\n" << " let qj = qi + select(half_rotary_dim, 1u, " << interleaved_str << ");\n" diff --git a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding.wgsl.template index 7fcdfcfddfb25..51eda83d089f1 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding.wgsl.template @@ -30,7 +30,8 @@ $MAIN { let seqlen_i = seqlens.getByOffset(batch_idx); let seqlen = u32(seqlen_i); let total_seqlen = seqlen + 1u; - let past_seqlen = total_seqlen - uniforms.sequence_length; + // Right-padded batches with prompt shorter than sequence_length would underflow u32; clamp to 0. + let past_seqlen = select(total_seqlen - uniforms.sequence_length, 0u, total_seqlen <= uniforms.sequence_length); let position_id = past_seqlen + seq_idx; #if use_multi_rotary_cache_concat let base_position = select(0u, multi_rotary_cache_concat_offset, total_seqlen > multi_rotary_cache_concat_offset); diff --git a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc index 821f43971848a..e342c872fd9b4 100644 --- a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc @@ -2392,5 +2392,198 @@ TEST(GroupQueryAttentionTest, WebGPU_SharedKV_SlidingWindow) { tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } +// --------------------------------------------------------------------------- +// WebGPU: batched right-padded packed-QKV prefill regression +// +// In a multi-batch prefill where individual prompts have different real lengths, +// GenAI right-pads short prompts up to the max sequence_length and reports each +// batch's real length via seqlens_k[b] = real_len[b] - 1. The WebGPU rotary +// embedding shader for packed-QKV computes past_seqlen = (seqlens_k[b] + 1) - +// sequence_length per-batch. For a short batch whose real_len < sequence_length, +// that subtraction underflowed u32, producing astronomically large position_ids +// that read out-of-bounds from cos/sin caches -- garbage values manifesting as +// gibberish output text. The fix clamps past_seqlen to 0 during prefill. +// +// This test exercises the packed-QKV do_rotary path (which dispatches +// SplitPackedQKVWithRotaryEmbeddingProgram). It compares each batch's +// real-last-token output against a single-batch reference for the same prompt. +// --------------------------------------------------------------------------- + +// Builds a packed QKV tensor with deterministic values at real positions and +// zeros at right-padded positions. Layout per token: [Q(hidden), K(kv), V(kv)]. +// Uses values of order ~1.0 (well above the 5e-3 mismatch tolerance) so the +// rotated-vs-unrotated divergence is unambiguously detectable. +static void FillBatchedRightPaddedPackedQKV(int batch_size, + int sequence_length, + int num_heads, + int kv_num_heads, + int head_size, + const std::vector& real_lens, + std::vector& packed_out) { + const int hidden_size = num_heads * head_size; + const int kv_hidden_size = kv_num_heads * head_size; + const int token_size = hidden_size + 2 * kv_hidden_size; + packed_out.assign(batch_size * sequence_length * token_size, 0.0f); + for (int b = 0; b < batch_size; ++b) { + const int real_len = real_lens[b]; + for (int s = 0; s < real_len; ++s) { + float* token = &packed_out[(b * sequence_length + s) * token_size]; + for (int c = 0; c < hidden_size; ++c) { + token[c] = 0.1f + 0.3f * static_cast(((b * 7 + s * 3 + c) % 13) + 1); + } + for (int c = 0; c < kv_hidden_size; ++c) { + token[hidden_size + c] = + 0.1f + 0.25f * static_cast(((b * 5 + s * 2 + c) % 11) + 1); + token[hidden_size + kv_hidden_size + c] = + 0.1f + 0.2f * static_cast(((b * 3 + s + c) % 9) + 1); + } + } + } +} + +// Runs a packed-QKV GQA prefill with do_rotary=1 and the given per-batch +// seqlens_k. Returns the output tensor [batch_size, sequence_length, hidden_size]. +static std::vector RunGQAPackedQKVRotaryPrefill( + int batch_size, + int sequence_length, + int num_heads, + int kv_num_heads, + int head_size, + const std::vector& seqlens_k_data, + const std::vector& packed_qkv_data) { + const int hidden_size = num_heads * head_size; + const int kv_hidden_size = kv_num_heads * head_size; + const int qkv_hidden = hidden_size + 2 * kv_hidden_size; + const int total_sequence_length = sequence_length; // prefill: no past + const int half_rotary = head_size / 2; + const int max_seq_len = sequence_length + 8; + + OpTester tester("GroupQueryAttention", 1, onnxruntime::kMSDomain); + tester.AddAttribute("num_heads", static_cast(num_heads)); + tester.AddAttribute("kv_num_heads", static_cast(kv_num_heads)); + tester.AddAttribute("do_rotary", static_cast(1)); + + // Packed QKV: pass through `query` input, leave key/value as optional edges. + tester.AddInput("query", {batch_size, sequence_length, qkv_hidden}, packed_qkv_data); + tester.AddOptionalInputEdge(); // key (signals packed) + tester.AddOptionalInputEdge(); // value (signals packed) + + tester.AddOptionalInputEdge(); // past_key + tester.AddOptionalInputEdge(); // past_value + + tester.AddInput("seqlens_k", {batch_size}, seqlens_k_data); + tester.AddInput("total_sequence_length", {1}, {total_sequence_length}, + /*is_initializer=*/true); + + std::vector cos_cache(max_seq_len * half_rotary); + std::vector sin_cache(max_seq_len * half_rotary); + for (int pos = 0; pos < max_seq_len; ++pos) { + for (int d = 0; d < half_rotary; ++d) { + const float freq = 1.0f / std::pow(10000.0f, 2.0f * static_cast(d) / + static_cast(head_size)); + cos_cache[pos * half_rotary + d] = std::cos(static_cast(pos) * freq); + sin_cache[pos * half_rotary + d] = std::sin(static_cast(pos) * freq); + } + } + tester.AddInput("cos_cache", {max_seq_len, half_rotary}, cos_cache); + tester.AddInput("sin_cache", {max_seq_len, half_rotary}, sin_cache); + + tester.AddOptionalInputEdge(); // position_ids + tester.AddOptionalInputEdge(); // attention_bias + tester.AddOptionalInputEdge(); // head_sink + + const int output_size = batch_size * sequence_length * hidden_size; + tester.AddOutput("output", {batch_size, sequence_length, hidden_size}, + std::vector(output_size, 0.0f)); + const int present_size = batch_size * kv_num_heads * total_sequence_length * head_size; + tester.AddOutput("present_key", {batch_size, kv_num_heads, total_sequence_length, head_size}, + std::vector(present_size, 0.0f)); + tester.AddOutput("present_value", {batch_size, kv_num_heads, total_sequence_length, head_size}, + std::vector(present_size, 0.0f)); + + tester.SetOutputTolerance(1e6f); // We fetch and compare outputs ourselves. + + std::vector> execution_providers; + execution_providers.push_back(DefaultWebGpuExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + + auto fetches = tester.GetFetches(); + const float* out_data = fetches[0].Get().Data(); + return std::vector(out_data, out_data + output_size); +} + +// Regression for u32 underflow in WebGPU SplitPackedQKVWithRotaryEmbedding +// shader during right-padded batched prefill. Runs each prompt singly to build +// a reference, then runs all prompts as a right-padded batch and asserts that +// each batch's real-last-token output matches its single-prompt reference. +TEST(GroupQueryAttentionTest, WebGPU_BatchedRightPaddedRotaryPrefill) { + auto webgpu_ep = DefaultWebGpuExecutionProvider(); + if (!webgpu_ep) { + GTEST_SKIP() << "WebGPU EP not available"; + } + + constexpr int batch_size = 3; + constexpr int num_heads = 4; + constexpr int kv_num_heads = 2; + constexpr int head_size = 16; // multiple of 4 for FlashAttention gate; rotary half = 8 + constexpr int hidden_size = num_heads * head_size; + constexpr int kv_hidden_size = kv_num_heads * head_size; + constexpr int qkv_hidden = hidden_size + 2 * kv_hidden_size; + + // Real prompt lengths per batch; max = sequence_length (right-padding extends + // shorter batches up to this length). The bug only manifests when at least + // one batch is shorter than sequence_length. + const std::vector real_lens = {4, 2, 6}; + const int sequence_length = *std::max_element(real_lens.begin(), real_lens.end()); + + std::vector packed_batched; + FillBatchedRightPaddedPackedQKV(batch_size, sequence_length, num_heads, kv_num_heads, + head_size, real_lens, packed_batched); + + // Build single-prompt references by extracting each batch's real-len slice + // and running it as a batch_size=1 prefill (which is known correct). + std::vector> ref_outputs(batch_size); + for (int b = 0; b < batch_size; ++b) { + const int real_len = real_lens[b]; + std::vector packed_single(real_len * qkv_hidden); + for (int s = 0; s < real_len; ++s) { + std::copy_n(&packed_batched[(b * sequence_length + s) * qkv_hidden], qkv_hidden, + &packed_single[s * qkv_hidden]); + } + ref_outputs[b] = RunGQAPackedQKVRotaryPrefill( + /*batch_size=*/1, /*sequence_length=*/real_len, + num_heads, kv_num_heads, head_size, + /*seqlens_k_data=*/{static_cast(real_len - 1)}, + packed_single); + } + + // Now run all batches together with right-padding. + std::vector seqlens_k_data(batch_size); + for (int b = 0; b < batch_size; ++b) { + seqlens_k_data[b] = static_cast(real_lens[b] - 1); + } + const auto batched_output = RunGQAPackedQKVRotaryPrefill( + batch_size, sequence_length, num_heads, kv_num_heads, head_size, + seqlens_k_data, packed_batched); + + // Each batch's real-last-token output (used to predict next token) must match + // its single-prompt reference. The tolerance is loose enough for fp16 rounding + // while still catching the underflow bug (which produces values that differ + // by orders of magnitude or are NaN/Inf). + constexpr float tolerance = 5e-3f; + for (int b = 0; b < batch_size; ++b) { + const int real_len = real_lens[b]; + const int q_last = real_len - 1; + const float* batched_last = + batched_output.data() + (b * sequence_length + q_last) * hidden_size; + const float* ref_last = ref_outputs[b].data() + q_last * hidden_size; + for (int c = 0; c < hidden_size; ++c) { + EXPECT_NEAR(batched_last[c], ref_last[c], tolerance) + << "batch " << b << " real_len=" << real_len + << " channel " << c << " mismatch"; + } + } +} + } // namespace test } // namespace onnxruntime From e979c5588a28795f50449a381524e8fee7d0e944 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Tue, 16 Jun 2026 11:32:32 +0800 Subject: [PATCH 2/5] test: generalize BatchedRightPaddedRotaryPrefill to all available EPs The test added in the previous commit was scoped to WebGPU, but the property it asserts (each batch's real-last-token output equals the single-prompt reference) is generic and applies to any GQA-supporting EP. CPU and CUDA both support packed-QKV with do_rotary, so generalizing the test gives meaningful cross-EP coverage instead of leaving CPU and CUDA uncovered. Mirror the existing convention in this file: the runner takes bool use_cuda, bool use_webgpu defaulting to false (same as RunGQASharedKV / RunGQASharedKVWithRotary), and three thin TEST cases named _CPU / _CUDA / _WebGPU dispatch the inner helper for each EP with runtime availability checks via GTEST_SKIP. No production code touched. Co-Authored-By: Claude Opus 4.6 --- .../group_query_attention_op_test.cc | 69 ++++++++++++------- 1 file changed, 44 insertions(+), 25 deletions(-) diff --git a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc index e342c872fd9b4..7ce5eab1ee6b2 100644 --- a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc @@ -2393,20 +2393,15 @@ TEST(GroupQueryAttentionTest, WebGPU_SharedKV_SlidingWindow) { } // --------------------------------------------------------------------------- -// WebGPU: batched right-padded packed-QKV prefill regression +// Batched right-padded packed-QKV prefill with do_rotary. // // In a multi-batch prefill where individual prompts have different real lengths, // GenAI right-pads short prompts up to the max sequence_length and reports each -// batch's real length via seqlens_k[b] = real_len[b] - 1. The WebGPU rotary -// embedding shader for packed-QKV computes past_seqlen = (seqlens_k[b] + 1) - -// sequence_length per-batch. For a short batch whose real_len < sequence_length, -// that subtraction underflowed u32, producing astronomically large position_ids -// that read out-of-bounds from cos/sin caches -- garbage values manifesting as -// gibberish output text. The fix clamps past_seqlen to 0 during prefill. -// -// This test exercises the packed-QKV do_rotary path (which dispatches -// SplitPackedQKVWithRotaryEmbeddingProgram). It compares each batch's -// real-last-token output against a single-batch reference for the same prompt. +// batch's real length via seqlens_k[b] = real_len[b] - 1. The property under +// test: each batch's real-last-token output (the one used to predict the next +// token) must equal what we get from running that prompt singly as a batch=1 +// prefill. This is a generic correctness check that any GQA-supporting EP +// should satisfy. // --------------------------------------------------------------------------- // Builds a packed QKV tensor with deterministic values at real positions and @@ -2450,7 +2445,9 @@ static std::vector RunGQAPackedQKVRotaryPrefill( int kv_num_heads, int head_size, const std::vector& seqlens_k_data, - const std::vector& packed_qkv_data) { + const std::vector& packed_qkv_data, + bool use_cuda = false, + bool use_webgpu = false) { const int hidden_size = num_heads * head_size; const int kv_hidden_size = kv_num_heads * head_size; const int qkv_hidden = hidden_size + 2 * kv_hidden_size; @@ -2504,7 +2501,13 @@ static std::vector RunGQAPackedQKVRotaryPrefill( tester.SetOutputTolerance(1e6f); // We fetch and compare outputs ourselves. std::vector> execution_providers; - execution_providers.push_back(DefaultWebGpuExecutionProvider()); + if (use_cuda) { + execution_providers.push_back(DefaultCudaExecutionProvider()); + } else if (use_webgpu) { + execution_providers.push_back(DefaultWebGpuExecutionProvider()); + } else { + execution_providers.push_back(DefaultCpuExecutionProvider()); + } tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); auto fetches = tester.GetFetches(); @@ -2512,16 +2515,12 @@ static std::vector RunGQAPackedQKVRotaryPrefill( return std::vector(out_data, out_data + output_size); } -// Regression for u32 underflow in WebGPU SplitPackedQKVWithRotaryEmbedding -// shader during right-padded batched prefill. Runs each prompt singly to build -// a reference, then runs all prompts as a right-padded batch and asserts that -// each batch's real-last-token output matches its single-prompt reference. -TEST(GroupQueryAttentionTest, WebGPU_BatchedRightPaddedRotaryPrefill) { - auto webgpu_ep = DefaultWebGpuExecutionProvider(); - if (!webgpu_ep) { - GTEST_SKIP() << "WebGPU EP not available"; - } - +// Inner helper: builds packed-QKV inputs, computes per-prompt references, runs +// the right-padded batched prefill, and asserts each batch's real-last-token +// output matches its single-prompt reference. Both reference and batched runs +// go through the same EP, so this validates per-batch consistency within each +// EP rather than cross-EP equivalence. +static void RunBatchedRightPaddedRotaryPrefillForEP(bool use_cuda, bool use_webgpu) { constexpr int batch_size = 3; constexpr int num_heads = 4; constexpr int kv_num_heads = 2; @@ -2554,7 +2553,7 @@ TEST(GroupQueryAttentionTest, WebGPU_BatchedRightPaddedRotaryPrefill) { /*batch_size=*/1, /*sequence_length=*/real_len, num_heads, kv_num_heads, head_size, /*seqlens_k_data=*/{static_cast(real_len - 1)}, - packed_single); + packed_single, use_cuda, use_webgpu); } // Now run all batches together with right-padding. @@ -2564,7 +2563,7 @@ TEST(GroupQueryAttentionTest, WebGPU_BatchedRightPaddedRotaryPrefill) { } const auto batched_output = RunGQAPackedQKVRotaryPrefill( batch_size, sequence_length, num_heads, kv_num_heads, head_size, - seqlens_k_data, packed_batched); + seqlens_k_data, packed_batched, use_cuda, use_webgpu); // Each batch's real-last-token output (used to predict next token) must match // its single-prompt reference. The tolerance is loose enough for fp16 rounding @@ -2585,5 +2584,25 @@ TEST(GroupQueryAttentionTest, WebGPU_BatchedRightPaddedRotaryPrefill) { } } +TEST(GroupQueryAttentionTest, BatchedRightPaddedRotaryPrefill_CPU) { + RunBatchedRightPaddedRotaryPrefillForEP(/*use_cuda=*/false, /*use_webgpu=*/false); +} + +TEST(GroupQueryAttentionTest, BatchedRightPaddedRotaryPrefill_CUDA) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (!cuda_ep) { + GTEST_SKIP() << "CUDA EP not available"; + } + RunBatchedRightPaddedRotaryPrefillForEP(/*use_cuda=*/true, /*use_webgpu=*/false); +} + +TEST(GroupQueryAttentionTest, BatchedRightPaddedRotaryPrefill_WebGPU) { + auto webgpu_ep = DefaultWebGpuExecutionProvider(); + if (!webgpu_ep) { + GTEST_SKIP() << "WebGPU EP not available"; + } + RunBatchedRightPaddedRotaryPrefillForEP(/*use_cuda=*/false, /*use_webgpu=*/true); +} + } // namespace test } // namespace onnxruntime From f76854065a47fb2706716eb2b9e4be37545c7995 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Wed, 17 Jun 2026 10:23:31 +0800 Subject: [PATCH 3/5] test: drop BatchedRightPaddedRotaryPrefill_CPU variant The CPU variant of the GQA packed-QKV do_rotary right-padded prefill regression test exposes a separate CPU-side issue tracked and fixed in another PR. Drop the _CPU TEST here so this PR's CI stays green while the CPU fix lands independently. The _CUDA and _WebGPU variants remain and continue to exercise the property under test. Co-Authored-By: Claude Opus 4.6 --- onnxruntime/test/contrib_ops/group_query_attention_op_test.cc | 4 ---- 1 file changed, 4 deletions(-) diff --git a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc index 7ce5eab1ee6b2..d091a34f0deac 100644 --- a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc @@ -2584,10 +2584,6 @@ static void RunBatchedRightPaddedRotaryPrefillForEP(bool use_cuda, bool use_webg } } -TEST(GroupQueryAttentionTest, BatchedRightPaddedRotaryPrefill_CPU) { - RunBatchedRightPaddedRotaryPrefillForEP(/*use_cuda=*/false, /*use_webgpu=*/false); -} - TEST(GroupQueryAttentionTest, BatchedRightPaddedRotaryPrefill_CUDA) { auto cuda_ep = DefaultCudaExecutionProvider(); if (!cuda_ep) { From 1ba89a0971cca1ea45c9132e65e5b80e4f93ad38 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Thu, 18 Jun 2026 14:17:15 +0800 Subject: [PATCH 4/5] Address PR #29002 review feedback for batched right-padded rotary prefill Apply three reviewer suggestions on the GQA WebGPU right-padded prefill fix: - Clamp the 4th packed-QKV rotary shader site against u32 underflow, byte identical to the sibling shader pattern. - Document the disjunction in CanApplyFlashAttention as a positive contract: FlashAttention does not implement right-padded per-batch prefill, so the first disjunction restricts inputs to shapes where padding cannot occur. - Replace the (bool use_cuda, bool use_webgpu) pair across three test helpers with a single GqaTargetEp enum, route EP construction through a central MakeExecutionProviderForGqaTest helper with an ORT_THROW default so a future enumerator cannot silently fall through to an empty provider vector, and migrate every caller. All 56 GroupQueryAttentionTest cases pass locally (48 OK + 8 CUDA-skipped on a machine without CUDA EP). --- .../webgpu/bert/flash_attention.cc | 7 ++ ..._rotary_embedding_and_copykv.wgsl.template | 3 +- .../group_query_attention_op_test.cc | 110 +++++++++--------- 3 files changed, 64 insertions(+), 56 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index 9be6a047cea9c..d9d299d4fd5d9 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -608,6 +608,13 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co bool CanApplyFlashAttention(const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) { const bool kv_empty = parameters.kv_sequence_length_ == 0; + // FlashAttention here does not implement right-padded per-batch prefill, so the + // first disjunction restricts it to inputs where padding cannot occur: + // - batch_size_ == 1: single sequence, no padding possible. + // - seqlen_k == nullptr: no per-batch lengths, padding inexpressible. + // - kv_empty (shared-KV layer): FA is mandatory; that path takes a different shader. + // The remaining conjuncts exclude packed-QKV (handled by a separate rotary kernel), + // mismatched head/value sizes, and head_size alignments unsupported by the kernel. return (parameters.batch_size_ == 1 || seqlen_k == nullptr || kv_empty) && !parameters.is_packed_qkv_ && parameters.head_size_ == parameters.v_head_size_ && diff --git a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template index 7b09a3a6af080..97c610fb90024 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template @@ -31,7 +31,8 @@ $MAIN { let seqlen = u32(seqlen_i); let total_seqlen = seqlen + 1u; - let past_seqlen = total_seqlen - uniforms.sequence_length; + // Right-padded batches with prompt shorter than sequence_length would underflow u32; clamp to 0. + let past_seqlen = select(total_seqlen - uniforms.sequence_length, 0u, total_seqlen <= uniforms.sequence_length); // `position_id` is used to get cos/sin cache and also as the time step index in present_key/present_value let position_id = past_seqlen + seq_idx; #if use_multi_rotary_cache_concat diff --git a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc index d091a34f0deac..bb3ce7db3d96f 100644 --- a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc @@ -17,6 +17,27 @@ namespace onnxruntime { namespace test { +// Selects which EP backs a GQA test helper. Modelled as a single enum (rather +// than two bools) so adding a new EP later does not silently fall through to +// CPU and callers cannot accidentally select two backends at once. +enum class GqaTargetEp { kCpu, kCuda, kWebGpu }; + +// Builds the default EP for the chosen backend. Centralized so that adding a +// new enumerator only requires updating one switch; the `ORT_THROW` default +// turns a missed update into a loud runtime failure instead of a silent +// empty-provider fallback inside OpTester. +static std::unique_ptr MakeExecutionProviderForGqaTest(GqaTargetEp target_ep) { + switch (target_ep) { + case GqaTargetEp::kCuda: + return DefaultCudaExecutionProvider(); + case GqaTargetEp::kWebGpu: + return DefaultWebGpuExecutionProvider(); + case GqaTargetEp::kCpu: + return DefaultCpuExecutionProvider(); + } + ORT_THROW("Unhandled GqaTargetEp"); +} + // Helper to build a minimal GQA OpTester with given seqlens_k and total_seq_len. // Uses num_heads=1, kv_num_heads=1, and head_size=8; past may be provided via // provide_past/past_seq_len. @@ -778,8 +799,7 @@ static std::vector RunGQASharedKV( int num_heads, int kv_num_heads, int head_size, - bool use_cuda = false, - bool use_webgpu = false) { + GqaTargetEp target_ep = GqaTargetEp::kCpu) { const int hidden_size = num_heads * head_size; const int total_seq_len = past_seq_len; // all KV data is in past @@ -822,13 +842,7 @@ static std::vector RunGQASharedKV( tester.SetOutputTolerance(1e6f); // We compare fetched outputs ourselves std::vector> execution_providers; - if (use_cuda) { - execution_providers.push_back(DefaultCudaExecutionProvider()); - } else if (use_webgpu) { - execution_providers.push_back(DefaultWebGpuExecutionProvider()); - } else { - execution_providers.push_back(DefaultCpuExecutionProvider()); - } + execution_providers.push_back(MakeExecutionProviderForGqaTest(target_ep)); tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); auto fetches = tester.GetFetches(); @@ -922,7 +936,7 @@ TEST(GroupQueryAttentionTest, SharedKV_EmptyKV_WithPast_CPU) { auto output = RunGQASharedKV( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false); + num_heads, kv_num_heads, head_size); // Verify non-zero and no NaN bool all_zero = true; @@ -952,7 +966,7 @@ TEST(GroupQueryAttentionTest, SharedKV_EmptyKV_WithPast_Prompt_CPU) { auto output = RunGQASharedKV( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false); + num_heads, kv_num_heads, head_size); bool all_zero = true; for (size_t i = 0; i < output.size(); i++) { @@ -991,7 +1005,7 @@ TEST(GroupQueryAttentionTest, SharedKV_EmptyKV_WithPast_CUDA) { auto cpu_output = RunGQASharedKV( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false); + num_heads, kv_num_heads, head_size); ExpectOutputsMatch(cuda_output, cpu_output, 0.05f, "SharedKV_CUDA_vs_CPU"); } @@ -1015,7 +1029,7 @@ TEST(GroupQueryAttentionTest, SharedKV_EmptyKV_WithPast_LargeHeadSize_CPU) { auto output = RunGQASharedKV( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false); + num_heads, kv_num_heads, head_size); bool all_zero = true; for (size_t i = 0; i < output.size(); i++) { @@ -1044,7 +1058,7 @@ TEST(GroupQueryAttentionTest, SharedKV_EmptyKV_WithPast_GQARatio8_CPU) { auto output = RunGQASharedKV( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false); + num_heads, kv_num_heads, head_size); bool all_zero = true; for (size_t i = 0; i < output.size(); i++) { @@ -1073,7 +1087,7 @@ TEST(GroupQueryAttentionTest, SharedKV_EmptyKV_WithPast_Batched_CPU) { auto output = RunGQASharedKV( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false); + num_heads, kv_num_heads, head_size); bool all_zero = true; for (size_t i = 0; i < output.size(); i++) { @@ -1155,7 +1169,7 @@ TEST(GroupQueryAttentionTest, SharedKV_EmptyKV_WithPast_Prompt_CUDA) { num_heads, kv_num_heads, head_size); auto cpu_output = RunGQASharedKV( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false); + num_heads, kv_num_heads, head_size); ExpectOutputsMatch(cuda_output, cpu_output, 0.05f, "SharedKV_Prompt_CUDA_vs_CPU"); } @@ -1187,7 +1201,7 @@ TEST(GroupQueryAttentionTest, SharedKV_EmptyKV_WithPast_LargeHeadSize_CUDA) { num_heads, kv_num_heads, head_size); auto cpu_output = RunGQASharedKV( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false); + num_heads, kv_num_heads, head_size); ExpectOutputsMatch(cuda_output, cpu_output, 0.05f, "SharedKV_LargeHead_CUDA_vs_CPU"); } @@ -1219,7 +1233,7 @@ TEST(GroupQueryAttentionTest, SharedKV_EmptyKV_WithPast_GQARatio8_CUDA) { num_heads, kv_num_heads, head_size); auto cpu_output = RunGQASharedKV( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false); + num_heads, kv_num_heads, head_size); ExpectOutputsMatch(cuda_output, cpu_output, 0.15f, "SharedKV_GQA8_CUDA_vs_CPU"); } @@ -1240,8 +1254,7 @@ static std::vector RunGQASharedKVWithRotary( int num_heads, int kv_num_heads, int head_size, - bool use_cuda = false, - bool use_webgpu = false) { + GqaTargetEp target_ep = GqaTargetEp::kCpu) { const int hidden_size = num_heads * head_size; const int total_seq_len = past_seq_len; const int rotary_dim = head_size; // full rotary @@ -1307,13 +1320,7 @@ static std::vector RunGQASharedKVWithRotary( tester.SetOutputTolerance(1e6f); std::vector> execution_providers; - if (use_cuda) { - execution_providers.push_back(DefaultCudaExecutionProvider()); - } else if (use_webgpu) { - execution_providers.push_back(DefaultWebGpuExecutionProvider()); - } else { - execution_providers.push_back(DefaultCpuExecutionProvider()); - } + execution_providers.push_back(MakeExecutionProviderForGqaTest(target_ep)); tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); auto fetches = tester.GetFetches(); @@ -1423,12 +1430,12 @@ TEST(GroupQueryAttentionTest, SharedKV_EmptyKV_WithPast_Rotary_CPU) { auto output = RunGQASharedKVWithRotary( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false); + num_heads, kv_num_heads, head_size); // Output with rotary should differ from without rotary (RoPE changes Q projections) auto output_no_rotary = RunGQASharedKV( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false); + num_heads, kv_num_heads, head_size); bool all_zero = true; bool differs_from_no_rotary = false; @@ -1468,7 +1475,7 @@ TEST(GroupQueryAttentionTest, SharedKV_EmptyKV_WithPast_Rotary_CUDA) { num_heads, kv_num_heads, head_size); auto cpu_output = RunGQASharedKVWithRotary( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false); + num_heads, kv_num_heads, head_size); ExpectOutputsMatch(cuda_output, cpu_output, 0.05f, "SharedKV_Rotary_CUDA_vs_CPU"); } @@ -1500,7 +1507,7 @@ TEST(GroupQueryAttentionTest, SharedKV_EmptyKV_WithPast_Rotary_Prompt_CUDA) { num_heads, kv_num_heads, head_size); auto cpu_output = RunGQASharedKVWithRotary( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false); + num_heads, kv_num_heads, head_size); ExpectOutputsMatch(cuda_output, cpu_output, 0.05f, "SharedKV_Rotary_Prompt_CUDA_vs_CPU"); } @@ -2191,10 +2198,10 @@ TEST(GroupQueryAttentionTest, WebGPU_SharedKV_Decode) { auto webgpu_output = RunGQASharedKV( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false, /*use_webgpu=*/true); + num_heads, kv_num_heads, head_size, GqaTargetEp::kWebGpu); auto cpu_output = RunGQASharedKV( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false, /*use_webgpu=*/false); + num_heads, kv_num_heads, head_size, GqaTargetEp::kCpu); ExpectOutputsMatch(webgpu_output, cpu_output, 0.05f, "SharedKV_WebGPU_vs_CPU"); } @@ -2223,10 +2230,10 @@ TEST(GroupQueryAttentionTest, WebGPU_SharedKV_Prefill) { auto webgpu_output = RunGQASharedKV( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false, /*use_webgpu=*/true); + num_heads, kv_num_heads, head_size, GqaTargetEp::kWebGpu); auto cpu_output = RunGQASharedKV( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false, /*use_webgpu=*/false); + num_heads, kv_num_heads, head_size, GqaTargetEp::kCpu); ExpectOutputsMatch(webgpu_output, cpu_output, 0.05f, "SharedKV_Prompt_WebGPU_vs_CPU"); } @@ -2255,10 +2262,10 @@ TEST(GroupQueryAttentionTest, WebGPU_SharedKV_Rotary) { auto webgpu_output = RunGQASharedKVWithRotary( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false, /*use_webgpu=*/true); + num_heads, kv_num_heads, head_size, GqaTargetEp::kWebGpu); auto cpu_output = RunGQASharedKVWithRotary( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false, /*use_webgpu=*/false); + num_heads, kv_num_heads, head_size, GqaTargetEp::kCpu); ExpectOutputsMatch(webgpu_output, cpu_output, 0.05f, "SharedKV_Rotary_WebGPU_vs_CPU"); } @@ -2288,10 +2295,10 @@ TEST(GroupQueryAttentionTest, WebGPU_SharedKV_Rotary_Prefill) { auto webgpu_output = RunGQASharedKVWithRotary( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false, /*use_webgpu=*/true); + num_heads, kv_num_heads, head_size, GqaTargetEp::kWebGpu); auto cpu_output = RunGQASharedKVWithRotary( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false, /*use_webgpu=*/false); + num_heads, kv_num_heads, head_size, GqaTargetEp::kCpu); ExpectOutputsMatch(webgpu_output, cpu_output, 0.05f, "SharedKV_Rotary_Prefill_WebGPU_vs_CPU"); } @@ -2321,10 +2328,10 @@ TEST(GroupQueryAttentionTest, WebGPU_SharedKV_Rotary_MultiBatch) { auto webgpu_output = RunGQASharedKVWithRotary( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false, /*use_webgpu=*/true); + num_heads, kv_num_heads, head_size, GqaTargetEp::kWebGpu); auto cpu_output = RunGQASharedKVWithRotary( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false, /*use_webgpu=*/false); + num_heads, kv_num_heads, head_size, GqaTargetEp::kCpu); ExpectOutputsMatch(webgpu_output, cpu_output, 0.05f, "SharedKV_Rotary_MultiBatch_WebGPU_vs_CPU"); } @@ -2446,8 +2453,7 @@ static std::vector RunGQAPackedQKVRotaryPrefill( int head_size, const std::vector& seqlens_k_data, const std::vector& packed_qkv_data, - bool use_cuda = false, - bool use_webgpu = false) { + GqaTargetEp target_ep = GqaTargetEp::kCpu) { const int hidden_size = num_heads * head_size; const int kv_hidden_size = kv_num_heads * head_size; const int qkv_hidden = hidden_size + 2 * kv_hidden_size; @@ -2501,13 +2507,7 @@ static std::vector RunGQAPackedQKVRotaryPrefill( tester.SetOutputTolerance(1e6f); // We fetch and compare outputs ourselves. std::vector> execution_providers; - if (use_cuda) { - execution_providers.push_back(DefaultCudaExecutionProvider()); - } else if (use_webgpu) { - execution_providers.push_back(DefaultWebGpuExecutionProvider()); - } else { - execution_providers.push_back(DefaultCpuExecutionProvider()); - } + execution_providers.push_back(MakeExecutionProviderForGqaTest(target_ep)); tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); auto fetches = tester.GetFetches(); @@ -2520,7 +2520,7 @@ static std::vector RunGQAPackedQKVRotaryPrefill( // output matches its single-prompt reference. Both reference and batched runs // go through the same EP, so this validates per-batch consistency within each // EP rather than cross-EP equivalence. -static void RunBatchedRightPaddedRotaryPrefillForEP(bool use_cuda, bool use_webgpu) { +static void RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp target_ep) { constexpr int batch_size = 3; constexpr int num_heads = 4; constexpr int kv_num_heads = 2; @@ -2553,7 +2553,7 @@ static void RunBatchedRightPaddedRotaryPrefillForEP(bool use_cuda, bool use_webg /*batch_size=*/1, /*sequence_length=*/real_len, num_heads, kv_num_heads, head_size, /*seqlens_k_data=*/{static_cast(real_len - 1)}, - packed_single, use_cuda, use_webgpu); + packed_single, target_ep); } // Now run all batches together with right-padding. @@ -2563,7 +2563,7 @@ static void RunBatchedRightPaddedRotaryPrefillForEP(bool use_cuda, bool use_webg } const auto batched_output = RunGQAPackedQKVRotaryPrefill( batch_size, sequence_length, num_heads, kv_num_heads, head_size, - seqlens_k_data, packed_batched, use_cuda, use_webgpu); + seqlens_k_data, packed_batched, target_ep); // Each batch's real-last-token output (used to predict next token) must match // its single-prompt reference. The tolerance is loose enough for fp16 rounding @@ -2589,7 +2589,7 @@ TEST(GroupQueryAttentionTest, BatchedRightPaddedRotaryPrefill_CUDA) { if (!cuda_ep) { GTEST_SKIP() << "CUDA EP not available"; } - RunBatchedRightPaddedRotaryPrefillForEP(/*use_cuda=*/true, /*use_webgpu=*/false); + RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp::kCuda); } TEST(GroupQueryAttentionTest, BatchedRightPaddedRotaryPrefill_WebGPU) { @@ -2597,7 +2597,7 @@ TEST(GroupQueryAttentionTest, BatchedRightPaddedRotaryPrefill_WebGPU) { if (!webgpu_ep) { GTEST_SKIP() << "WebGPU EP not available"; } - RunBatchedRightPaddedRotaryPrefillForEP(/*use_cuda=*/false, /*use_webgpu=*/true); + RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp::kWebGpu); } } // namespace test From e8b245458fa5dc6ae5c63e70f725a734ef4d169a Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Thu, 18 Jun 2026 14:27:53 +0800 Subject: [PATCH 5/5] Apply lintrunner fixes for PR #29002 - Spelling: "Modelled" -> "Modeled" (US English, per misspell linter) - clang-format: split the GqaTargetEp enum onto one enumerator per line --- .../test/contrib_ops/group_query_attention_op_test.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc index bb3ce7db3d96f..84d3b18de73fe 100644 --- a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc @@ -17,10 +17,12 @@ namespace onnxruntime { namespace test { -// Selects which EP backs a GQA test helper. Modelled as a single enum (rather +// Selects which EP backs a GQA test helper. Modeled as a single enum (rather // than two bools) so adding a new EP later does not silently fall through to // CPU and callers cannot accidentally select two backends at once. -enum class GqaTargetEp { kCpu, kCuda, kWebGpu }; +enum class GqaTargetEp { kCpu, + kCuda, + kWebGpu }; // Builds the default EP for the chosen backend. Centralized so that adding a // new enumerator only requires updating one switch; the `ORT_THROW` default