diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index 02e764d01e05e..d9d299d4fd5d9 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -606,8 +606,17 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co return Status::OK(); } -bool CanApplyFlashAttention(const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) { - return !parameters.is_packed_qkv_ && +bool CanApplyFlashAttention(const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) { + const bool kv_empty = parameters.kv_sequence_length_ == 0; + // FlashAttention here does not implement right-padded per-batch prefill, so the + // first disjunction restricts it to inputs where padding cannot occur: + // - batch_size_ == 1: single sequence, no padding possible. + // - seqlen_k == nullptr: no per-batch lengths, padding inexpressible. + // - kv_empty (shared-KV layer): FA is mandatory; that path takes a different shader. + // The remaining conjuncts exclude packed-QKV (handled by a separate rotary kernel), + // mismatched head/value sizes, and head_size alignments unsupported by the kernel. + return (parameters.batch_size_ == 1 || seqlen_k == nullptr || kv_empty) && + !parameters.is_packed_qkv_ && parameters.head_size_ == parameters.v_head_size_ && ((context.AdapterInfo().vendor == std::string_view{"qualcomm"} && parameters.head_size_ % 8 == 0) || parameters.head_size_ % 4 == 0); } diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h index 3da6b33b4dc0e..218baf926173f 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h @@ -205,7 +205,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k = nullptr, const Tensor* cos_cache = nullptr, const Tensor* sin_cache = nullptr, const Tensor* head_sink = nullptr); -bool CanApplyFlashAttention(const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context); +bool CanApplyFlashAttention(const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k = nullptr); // Split packed QKV with Q/K rotary embedding and copy KV cache fusion Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::ComputeContext& context, diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 930cb296122ce..36d688c9723fd 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -350,7 +350,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& // Create a temporary parameters copy with is_packed_qkv_ set to false to check if flash attention can be applied after unpacking WebgpuAttentionParameters temp_params = parameters; temp_params.is_packed_qkv_ = false; - will_use_flash_attention = CanApplyFlashAttention(temp_params, context); + will_use_flash_attention = CanApplyFlashAttention(temp_params, context, seqlen_k); } if (kv_empty) { diff --git a/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc index 1b11a69de7824..b4fbdc555a6d5 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc @@ -44,7 +44,8 @@ Status RotaryEmbeddingProgram::GenerateShaderCode(ShaderHelper& shader) const { << " let seqlen_i = " << position_ids_or_seqlens.GetByOffset("batch_idx") << ";\n" << " let seqlen = u32(seqlen_i);\n" " let total_seqlen = seqlen + 1u;\n" - " let past_seqlen = total_seqlen - uniforms.global_shape[1];\n" + " // Right-padded batches with prompt shorter than global_shape[1] would underflow u32; clamp to 0.\n" + " let past_seqlen = select(total_seqlen - uniforms.global_shape[1], 0u, total_seqlen <= uniforms.global_shape[1]);\n" " let position_id = past_seqlen + bsnh[1];\n" << " let i = dot(bsnh, uniforms.input_output_stride) + select(0u, bsnh[3], " << interleaved_str << ");\n" << " let j = i + select(half_rotary_emb_dim, 1u, " << interleaved_str << ");\n" @@ -200,7 +201,8 @@ Status FusedQKRotaryEmbeddingProgram::GenerateShaderCode(ShaderHelper& shader) c << " let seqlen_i = " << seqlens.GetByOffset("batch_idx") << ";\n" << " let seqlen = u32(seqlen_i);\n" << " let total_seqlen = seqlen + 1u;\n" - << " let past_seqlen = total_seqlen - uniforms.q_global_shape[1];\n" + << " // Right-padded batches with prompt shorter than q_global_shape[1] would underflow u32; clamp to 0.\n" + << " let past_seqlen = select(total_seqlen - uniforms.q_global_shape[1], 0u, total_seqlen <= uniforms.q_global_shape[1]);\n" << " let position_id = past_seqlen + sequence_idx;\n" << " let qi = dot(bsnh, uniforms.q_input_output_stride) + select(0u, bsnh[3], " << interleaved_str << ");\n" << " let qj = qi + select(half_rotary_dim, 1u, " << interleaved_str << ");\n" diff --git a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding.wgsl.template index 7fcdfcfddfb25..51eda83d089f1 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding.wgsl.template @@ -30,7 +30,8 @@ $MAIN { let seqlen_i = seqlens.getByOffset(batch_idx); let seqlen = u32(seqlen_i); let total_seqlen = seqlen + 1u; - let past_seqlen = total_seqlen - uniforms.sequence_length; + // Right-padded batches with prompt shorter than sequence_length would underflow u32; clamp to 0. + let past_seqlen = select(total_seqlen - uniforms.sequence_length, 0u, total_seqlen <= uniforms.sequence_length); let position_id = past_seqlen + seq_idx; #if use_multi_rotary_cache_concat let base_position = select(0u, multi_rotary_cache_concat_offset, total_seqlen > multi_rotary_cache_concat_offset); diff --git a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template index 7b09a3a6af080..97c610fb90024 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template @@ -31,7 +31,8 @@ $MAIN { let seqlen = u32(seqlen_i); let total_seqlen = seqlen + 1u; - let past_seqlen = total_seqlen - uniforms.sequence_length; + // Right-padded batches with prompt shorter than sequence_length would underflow u32; clamp to 0. + let past_seqlen = select(total_seqlen - uniforms.sequence_length, 0u, total_seqlen <= uniforms.sequence_length); // `position_id` is used to get cos/sin cache and also as the time step index in present_key/present_value let position_id = past_seqlen + seq_idx; #if use_multi_rotary_cache_concat diff --git a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc index 821f43971848a..84d3b18de73fe 100644 --- a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc @@ -17,6 +17,29 @@ namespace onnxruntime { namespace test { +// Selects which EP backs a GQA test helper. Modeled as a single enum (rather +// than two bools) so adding a new EP later does not silently fall through to +// CPU and callers cannot accidentally select two backends at once. +enum class GqaTargetEp { kCpu, + kCuda, + kWebGpu }; + +// Builds the default EP for the chosen backend. Centralized so that adding a +// new enumerator only requires updating one switch; the `ORT_THROW` default +// turns a missed update into a loud runtime failure instead of a silent +// empty-provider fallback inside OpTester. +static std::unique_ptr MakeExecutionProviderForGqaTest(GqaTargetEp target_ep) { + switch (target_ep) { + case GqaTargetEp::kCuda: + return DefaultCudaExecutionProvider(); + case GqaTargetEp::kWebGpu: + return DefaultWebGpuExecutionProvider(); + case GqaTargetEp::kCpu: + return DefaultCpuExecutionProvider(); + } + ORT_THROW("Unhandled GqaTargetEp"); +} + // Helper to build a minimal GQA OpTester with given seqlens_k and total_seq_len. // Uses num_heads=1, kv_num_heads=1, and head_size=8; past may be provided via // provide_past/past_seq_len. @@ -778,8 +801,7 @@ static std::vector RunGQASharedKV( int num_heads, int kv_num_heads, int head_size, - bool use_cuda = false, - bool use_webgpu = false) { + GqaTargetEp target_ep = GqaTargetEp::kCpu) { const int hidden_size = num_heads * head_size; const int total_seq_len = past_seq_len; // all KV data is in past @@ -822,13 +844,7 @@ static std::vector RunGQASharedKV( tester.SetOutputTolerance(1e6f); // We compare fetched outputs ourselves std::vector> execution_providers; - if (use_cuda) { - execution_providers.push_back(DefaultCudaExecutionProvider()); - } else if (use_webgpu) { - execution_providers.push_back(DefaultWebGpuExecutionProvider()); - } else { - execution_providers.push_back(DefaultCpuExecutionProvider()); - } + execution_providers.push_back(MakeExecutionProviderForGqaTest(target_ep)); tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); auto fetches = tester.GetFetches(); @@ -922,7 +938,7 @@ TEST(GroupQueryAttentionTest, SharedKV_EmptyKV_WithPast_CPU) { auto output = RunGQASharedKV( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false); + num_heads, kv_num_heads, head_size); // Verify non-zero and no NaN bool all_zero = true; @@ -952,7 +968,7 @@ TEST(GroupQueryAttentionTest, SharedKV_EmptyKV_WithPast_Prompt_CPU) { auto output = RunGQASharedKV( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false); + num_heads, kv_num_heads, head_size); bool all_zero = true; for (size_t i = 0; i < output.size(); i++) { @@ -991,7 +1007,7 @@ TEST(GroupQueryAttentionTest, SharedKV_EmptyKV_WithPast_CUDA) { auto cpu_output = RunGQASharedKV( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false); + num_heads, kv_num_heads, head_size); ExpectOutputsMatch(cuda_output, cpu_output, 0.05f, "SharedKV_CUDA_vs_CPU"); } @@ -1015,7 +1031,7 @@ TEST(GroupQueryAttentionTest, SharedKV_EmptyKV_WithPast_LargeHeadSize_CPU) { auto output = RunGQASharedKV( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false); + num_heads, kv_num_heads, head_size); bool all_zero = true; for (size_t i = 0; i < output.size(); i++) { @@ -1044,7 +1060,7 @@ TEST(GroupQueryAttentionTest, SharedKV_EmptyKV_WithPast_GQARatio8_CPU) { auto output = RunGQASharedKV( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false); + num_heads, kv_num_heads, head_size); bool all_zero = true; for (size_t i = 0; i < output.size(); i++) { @@ -1073,7 +1089,7 @@ TEST(GroupQueryAttentionTest, SharedKV_EmptyKV_WithPast_Batched_CPU) { auto output = RunGQASharedKV( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false); + num_heads, kv_num_heads, head_size); bool all_zero = true; for (size_t i = 0; i < output.size(); i++) { @@ -1155,7 +1171,7 @@ TEST(GroupQueryAttentionTest, SharedKV_EmptyKV_WithPast_Prompt_CUDA) { num_heads, kv_num_heads, head_size); auto cpu_output = RunGQASharedKV( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false); + num_heads, kv_num_heads, head_size); ExpectOutputsMatch(cuda_output, cpu_output, 0.05f, "SharedKV_Prompt_CUDA_vs_CPU"); } @@ -1187,7 +1203,7 @@ TEST(GroupQueryAttentionTest, SharedKV_EmptyKV_WithPast_LargeHeadSize_CUDA) { num_heads, kv_num_heads, head_size); auto cpu_output = RunGQASharedKV( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false); + num_heads, kv_num_heads, head_size); ExpectOutputsMatch(cuda_output, cpu_output, 0.05f, "SharedKV_LargeHead_CUDA_vs_CPU"); } @@ -1219,7 +1235,7 @@ TEST(GroupQueryAttentionTest, SharedKV_EmptyKV_WithPast_GQARatio8_CUDA) { num_heads, kv_num_heads, head_size); auto cpu_output = RunGQASharedKV( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false); + num_heads, kv_num_heads, head_size); ExpectOutputsMatch(cuda_output, cpu_output, 0.15f, "SharedKV_GQA8_CUDA_vs_CPU"); } @@ -1240,8 +1256,7 @@ static std::vector RunGQASharedKVWithRotary( int num_heads, int kv_num_heads, int head_size, - bool use_cuda = false, - bool use_webgpu = false) { + GqaTargetEp target_ep = GqaTargetEp::kCpu) { const int hidden_size = num_heads * head_size; const int total_seq_len = past_seq_len; const int rotary_dim = head_size; // full rotary @@ -1307,13 +1322,7 @@ static std::vector RunGQASharedKVWithRotary( tester.SetOutputTolerance(1e6f); std::vector> execution_providers; - if (use_cuda) { - execution_providers.push_back(DefaultCudaExecutionProvider()); - } else if (use_webgpu) { - execution_providers.push_back(DefaultWebGpuExecutionProvider()); - } else { - execution_providers.push_back(DefaultCpuExecutionProvider()); - } + execution_providers.push_back(MakeExecutionProviderForGqaTest(target_ep)); tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); auto fetches = tester.GetFetches(); @@ -1423,12 +1432,12 @@ TEST(GroupQueryAttentionTest, SharedKV_EmptyKV_WithPast_Rotary_CPU) { auto output = RunGQASharedKVWithRotary( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false); + num_heads, kv_num_heads, head_size); // Output with rotary should differ from without rotary (RoPE changes Q projections) auto output_no_rotary = RunGQASharedKV( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false); + num_heads, kv_num_heads, head_size); bool all_zero = true; bool differs_from_no_rotary = false; @@ -1468,7 +1477,7 @@ TEST(GroupQueryAttentionTest, SharedKV_EmptyKV_WithPast_Rotary_CUDA) { num_heads, kv_num_heads, head_size); auto cpu_output = RunGQASharedKVWithRotary( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false); + num_heads, kv_num_heads, head_size); ExpectOutputsMatch(cuda_output, cpu_output, 0.05f, "SharedKV_Rotary_CUDA_vs_CPU"); } @@ -1500,7 +1509,7 @@ TEST(GroupQueryAttentionTest, SharedKV_EmptyKV_WithPast_Rotary_Prompt_CUDA) { num_heads, kv_num_heads, head_size); auto cpu_output = RunGQASharedKVWithRotary( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false); + num_heads, kv_num_heads, head_size); ExpectOutputsMatch(cuda_output, cpu_output, 0.05f, "SharedKV_Rotary_Prompt_CUDA_vs_CPU"); } @@ -2191,10 +2200,10 @@ TEST(GroupQueryAttentionTest, WebGPU_SharedKV_Decode) { auto webgpu_output = RunGQASharedKV( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false, /*use_webgpu=*/true); + num_heads, kv_num_heads, head_size, GqaTargetEp::kWebGpu); auto cpu_output = RunGQASharedKV( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false, /*use_webgpu=*/false); + num_heads, kv_num_heads, head_size, GqaTargetEp::kCpu); ExpectOutputsMatch(webgpu_output, cpu_output, 0.05f, "SharedKV_WebGPU_vs_CPU"); } @@ -2223,10 +2232,10 @@ TEST(GroupQueryAttentionTest, WebGPU_SharedKV_Prefill) { auto webgpu_output = RunGQASharedKV( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false, /*use_webgpu=*/true); + num_heads, kv_num_heads, head_size, GqaTargetEp::kWebGpu); auto cpu_output = RunGQASharedKV( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false, /*use_webgpu=*/false); + num_heads, kv_num_heads, head_size, GqaTargetEp::kCpu); ExpectOutputsMatch(webgpu_output, cpu_output, 0.05f, "SharedKV_Prompt_WebGPU_vs_CPU"); } @@ -2255,10 +2264,10 @@ TEST(GroupQueryAttentionTest, WebGPU_SharedKV_Rotary) { auto webgpu_output = RunGQASharedKVWithRotary( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false, /*use_webgpu=*/true); + num_heads, kv_num_heads, head_size, GqaTargetEp::kWebGpu); auto cpu_output = RunGQASharedKVWithRotary( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false, /*use_webgpu=*/false); + num_heads, kv_num_heads, head_size, GqaTargetEp::kCpu); ExpectOutputsMatch(webgpu_output, cpu_output, 0.05f, "SharedKV_Rotary_WebGPU_vs_CPU"); } @@ -2288,10 +2297,10 @@ TEST(GroupQueryAttentionTest, WebGPU_SharedKV_Rotary_Prefill) { auto webgpu_output = RunGQASharedKVWithRotary( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false, /*use_webgpu=*/true); + num_heads, kv_num_heads, head_size, GqaTargetEp::kWebGpu); auto cpu_output = RunGQASharedKVWithRotary( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false, /*use_webgpu=*/false); + num_heads, kv_num_heads, head_size, GqaTargetEp::kCpu); ExpectOutputsMatch(webgpu_output, cpu_output, 0.05f, "SharedKV_Rotary_Prefill_WebGPU_vs_CPU"); } @@ -2321,10 +2330,10 @@ TEST(GroupQueryAttentionTest, WebGPU_SharedKV_Rotary_MultiBatch) { auto webgpu_output = RunGQASharedKVWithRotary( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false, /*use_webgpu=*/true); + num_heads, kv_num_heads, head_size, GqaTargetEp::kWebGpu); auto cpu_output = RunGQASharedKVWithRotary( batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data, - num_heads, kv_num_heads, head_size, /*use_cuda=*/false, /*use_webgpu=*/false); + num_heads, kv_num_heads, head_size, GqaTargetEp::kCpu); ExpectOutputsMatch(webgpu_output, cpu_output, 0.05f, "SharedKV_Rotary_MultiBatch_WebGPU_vs_CPU"); } @@ -2392,5 +2401,206 @@ TEST(GroupQueryAttentionTest, WebGPU_SharedKV_SlidingWindow) { tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } +// --------------------------------------------------------------------------- +// Batched right-padded packed-QKV prefill with do_rotary. +// +// In a multi-batch prefill where individual prompts have different real lengths, +// GenAI right-pads short prompts up to the max sequence_length and reports each +// batch's real length via seqlens_k[b] = real_len[b] - 1. The property under +// test: each batch's real-last-token output (the one used to predict the next +// token) must equal what we get from running that prompt singly as a batch=1 +// prefill. This is a generic correctness check that any GQA-supporting EP +// should satisfy. +// --------------------------------------------------------------------------- + +// Builds a packed QKV tensor with deterministic values at real positions and +// zeros at right-padded positions. Layout per token: [Q(hidden), K(kv), V(kv)]. +// Uses values of order ~1.0 (well above the 5e-3 mismatch tolerance) so the +// rotated-vs-unrotated divergence is unambiguously detectable. +static void FillBatchedRightPaddedPackedQKV(int batch_size, + int sequence_length, + int num_heads, + int kv_num_heads, + int head_size, + const std::vector& real_lens, + std::vector& packed_out) { + const int hidden_size = num_heads * head_size; + const int kv_hidden_size = kv_num_heads * head_size; + const int token_size = hidden_size + 2 * kv_hidden_size; + packed_out.assign(batch_size * sequence_length * token_size, 0.0f); + for (int b = 0; b < batch_size; ++b) { + const int real_len = real_lens[b]; + for (int s = 0; s < real_len; ++s) { + float* token = &packed_out[(b * sequence_length + s) * token_size]; + for (int c = 0; c < hidden_size; ++c) { + token[c] = 0.1f + 0.3f * static_cast(((b * 7 + s * 3 + c) % 13) + 1); + } + for (int c = 0; c < kv_hidden_size; ++c) { + token[hidden_size + c] = + 0.1f + 0.25f * static_cast(((b * 5 + s * 2 + c) % 11) + 1); + token[hidden_size + kv_hidden_size + c] = + 0.1f + 0.2f * static_cast(((b * 3 + s + c) % 9) + 1); + } + } + } +} + +// Runs a packed-QKV GQA prefill with do_rotary=1 and the given per-batch +// seqlens_k. Returns the output tensor [batch_size, sequence_length, hidden_size]. +static std::vector RunGQAPackedQKVRotaryPrefill( + int batch_size, + int sequence_length, + int num_heads, + int kv_num_heads, + int head_size, + const std::vector& seqlens_k_data, + const std::vector& packed_qkv_data, + GqaTargetEp target_ep = GqaTargetEp::kCpu) { + const int hidden_size = num_heads * head_size; + const int kv_hidden_size = kv_num_heads * head_size; + const int qkv_hidden = hidden_size + 2 * kv_hidden_size; + const int total_sequence_length = sequence_length; // prefill: no past + const int half_rotary = head_size / 2; + const int max_seq_len = sequence_length + 8; + + OpTester tester("GroupQueryAttention", 1, onnxruntime::kMSDomain); + tester.AddAttribute("num_heads", static_cast(num_heads)); + tester.AddAttribute("kv_num_heads", static_cast(kv_num_heads)); + tester.AddAttribute("do_rotary", static_cast(1)); + + // Packed QKV: pass through `query` input, leave key/value as optional edges. + tester.AddInput("query", {batch_size, sequence_length, qkv_hidden}, packed_qkv_data); + tester.AddOptionalInputEdge(); // key (signals packed) + tester.AddOptionalInputEdge(); // value (signals packed) + + tester.AddOptionalInputEdge(); // past_key + tester.AddOptionalInputEdge(); // past_value + + tester.AddInput("seqlens_k", {batch_size}, seqlens_k_data); + tester.AddInput("total_sequence_length", {1}, {total_sequence_length}, + /*is_initializer=*/true); + + std::vector cos_cache(max_seq_len * half_rotary); + std::vector sin_cache(max_seq_len * half_rotary); + for (int pos = 0; pos < max_seq_len; ++pos) { + for (int d = 0; d < half_rotary; ++d) { + const float freq = 1.0f / std::pow(10000.0f, 2.0f * static_cast(d) / + static_cast(head_size)); + cos_cache[pos * half_rotary + d] = std::cos(static_cast(pos) * freq); + sin_cache[pos * half_rotary + d] = std::sin(static_cast(pos) * freq); + } + } + tester.AddInput("cos_cache", {max_seq_len, half_rotary}, cos_cache); + tester.AddInput("sin_cache", {max_seq_len, half_rotary}, sin_cache); + + tester.AddOptionalInputEdge(); // position_ids + tester.AddOptionalInputEdge(); // attention_bias + tester.AddOptionalInputEdge(); // head_sink + + const int output_size = batch_size * sequence_length * hidden_size; + tester.AddOutput("output", {batch_size, sequence_length, hidden_size}, + std::vector(output_size, 0.0f)); + const int present_size = batch_size * kv_num_heads * total_sequence_length * head_size; + tester.AddOutput("present_key", {batch_size, kv_num_heads, total_sequence_length, head_size}, + std::vector(present_size, 0.0f)); + tester.AddOutput("present_value", {batch_size, kv_num_heads, total_sequence_length, head_size}, + std::vector(present_size, 0.0f)); + + tester.SetOutputTolerance(1e6f); // We fetch and compare outputs ourselves. + + std::vector> execution_providers; + execution_providers.push_back(MakeExecutionProviderForGqaTest(target_ep)); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + + auto fetches = tester.GetFetches(); + const float* out_data = fetches[0].Get().Data(); + return std::vector(out_data, out_data + output_size); +} + +// Inner helper: builds packed-QKV inputs, computes per-prompt references, runs +// the right-padded batched prefill, and asserts each batch's real-last-token +// output matches its single-prompt reference. Both reference and batched runs +// go through the same EP, so this validates per-batch consistency within each +// EP rather than cross-EP equivalence. +static void RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp target_ep) { + constexpr int batch_size = 3; + constexpr int num_heads = 4; + constexpr int kv_num_heads = 2; + constexpr int head_size = 16; // multiple of 4 for FlashAttention gate; rotary half = 8 + constexpr int hidden_size = num_heads * head_size; + constexpr int kv_hidden_size = kv_num_heads * head_size; + constexpr int qkv_hidden = hidden_size + 2 * kv_hidden_size; + + // Real prompt lengths per batch; max = sequence_length (right-padding extends + // shorter batches up to this length). The bug only manifests when at least + // one batch is shorter than sequence_length. + const std::vector real_lens = {4, 2, 6}; + const int sequence_length = *std::max_element(real_lens.begin(), real_lens.end()); + + std::vector packed_batched; + FillBatchedRightPaddedPackedQKV(batch_size, sequence_length, num_heads, kv_num_heads, + head_size, real_lens, packed_batched); + + // Build single-prompt references by extracting each batch's real-len slice + // and running it as a batch_size=1 prefill (which is known correct). + std::vector> ref_outputs(batch_size); + for (int b = 0; b < batch_size; ++b) { + const int real_len = real_lens[b]; + std::vector packed_single(real_len * qkv_hidden); + for (int s = 0; s < real_len; ++s) { + std::copy_n(&packed_batched[(b * sequence_length + s) * qkv_hidden], qkv_hidden, + &packed_single[s * qkv_hidden]); + } + ref_outputs[b] = RunGQAPackedQKVRotaryPrefill( + /*batch_size=*/1, /*sequence_length=*/real_len, + num_heads, kv_num_heads, head_size, + /*seqlens_k_data=*/{static_cast(real_len - 1)}, + packed_single, target_ep); + } + + // Now run all batches together with right-padding. + std::vector seqlens_k_data(batch_size); + for (int b = 0; b < batch_size; ++b) { + seqlens_k_data[b] = static_cast(real_lens[b] - 1); + } + const auto batched_output = RunGQAPackedQKVRotaryPrefill( + batch_size, sequence_length, num_heads, kv_num_heads, head_size, + seqlens_k_data, packed_batched, target_ep); + + // Each batch's real-last-token output (used to predict next token) must match + // its single-prompt reference. The tolerance is loose enough for fp16 rounding + // while still catching the underflow bug (which produces values that differ + // by orders of magnitude or are NaN/Inf). + constexpr float tolerance = 5e-3f; + for (int b = 0; b < batch_size; ++b) { + const int real_len = real_lens[b]; + const int q_last = real_len - 1; + const float* batched_last = + batched_output.data() + (b * sequence_length + q_last) * hidden_size; + const float* ref_last = ref_outputs[b].data() + q_last * hidden_size; + for (int c = 0; c < hidden_size; ++c) { + EXPECT_NEAR(batched_last[c], ref_last[c], tolerance) + << "batch " << b << " real_len=" << real_len + << " channel " << c << " mismatch"; + } + } +} + +TEST(GroupQueryAttentionTest, BatchedRightPaddedRotaryPrefill_CUDA) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (!cuda_ep) { + GTEST_SKIP() << "CUDA EP not available"; + } + RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp::kCuda); +} + +TEST(GroupQueryAttentionTest, BatchedRightPaddedRotaryPrefill_WebGPU) { + auto webgpu_ep = DefaultWebGpuExecutionProvider(); + if (!webgpu_ep) { + GTEST_SKIP() << "WebGPU EP not available"; + } + RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp::kWebGpu); +} + } // namespace test } // namespace onnxruntime