Skip to content

Commit 14a6c9e

Browse files
authored
Fix GroupQueryAttention right-padded rotary prefill CUDA test (#29218)
### Description The `GroupQueryAttentionTest.BatchedRightPaddedRotaryPrefill_CUDA` test (added in #29002) fed **fp32** inputs via `AddInput<float>`. The CUDA (and WebGPU) GroupQueryAttention kernels only register for `MLFloat16`/`BFloat16`, so the fp32 node silently fell back to the **CPU EP** — the `_CUDA` test never actually exercised the CUDA kernel it is named for. This surfaced as a CI failure on the CUDA test leg after #29002 and #29046 merged. This PR makes `RunGQAPackedQKVRotaryPrefill` feed **fp16** tensors when targeting CUDA EP, matching the existing `RunGQASharedKVFp16` convention and the test's own "loose enough for fp16 rounding" tolerance. The CPU code path is unchanged. ### Key Changes - `RunGQAPackedQKVRotaryPrefill` now branches on the target EP: - CUDA EP: inputs/outputs use `MLFloat16` (converted via `ToFloat16`), so the node is placed on the real GPU kernel. - WebGPU/CPU EP: unchanged (`float`). - Output is converted back to `float` for the existing comparison logic. ### Testing - `onnxruntime_provider_test --gtest_filter='GroupQueryAttentionTest.BatchedRightPaddedRotaryPrefill_CUDA'` → **PASSED** (now runs on the CUDA fp16 kernel). - Full `GroupQueryAttentionTest.*` suite → 47 passed, WebGPU-only tests skipped locally (no WebGPU EP), no regressions. ### Motivation and Context Restores genuine CUDA kernel coverage for the right-padded rotary prefill scenario and fixes the CI failure. Related: #29002, #29046.
1 parent 8c856fb commit 14a6c9e

1 file changed

Lines changed: 60 additions & 18 deletions

File tree

onnxruntime/test/contrib_ops/group_query_attention_op_test.cc

Lines changed: 60 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2463,18 +2463,32 @@ static std::vector<float> RunGQAPackedQKVRotaryPrefill(
24632463
const int half_rotary = head_size / 2;
24642464
const int max_seq_len = sequence_length + 8;
24652465

2466+
// The CUDA GQA kernel only registers for MLFloat16/BFloat16, so float inputs
2467+
// silently fall back to the CPU EP. Feed fp16 tensors when targeting CUDA so
2468+
// the *_CUDA test genuinely exercises the CUDA kernel. The CPU and WebGPU
2469+
// kernels both support float (WebGpuSupportedFloatTypes = {float, MLFloat16}),
2470+
// so those paths keep fp32 for tighter numeric comparison.
2471+
const bool use_fp16 = target_ep == GqaTargetEp::kCuda;
2472+
24662473
OpTester tester("GroupQueryAttention", 1, onnxruntime::kMSDomain);
24672474
tester.AddAttribute<int64_t>("num_heads", static_cast<int64_t>(num_heads));
24682475
tester.AddAttribute<int64_t>("kv_num_heads", static_cast<int64_t>(kv_num_heads));
24692476
tester.AddAttribute<int64_t>("do_rotary", static_cast<int64_t>(1));
24702477

24712478
// Packed QKV: pass through `query` input, leave key/value as optional edges.
2472-
tester.AddInput<float>("query", {batch_size, sequence_length, qkv_hidden}, packed_qkv_data);
2473-
tester.AddOptionalInputEdge<float>(); // key (signals packed)
2474-
tester.AddOptionalInputEdge<float>(); // value (signals packed)
2475-
2476-
tester.AddOptionalInputEdge<float>(); // past_key
2477-
tester.AddOptionalInputEdge<float>(); // past_value
2479+
if (use_fp16) {
2480+
tester.AddInput<MLFloat16>("query", {batch_size, sequence_length, qkv_hidden}, ToFloat16(packed_qkv_data));
2481+
tester.AddOptionalInputEdge<MLFloat16>(); // key (signals packed)
2482+
tester.AddOptionalInputEdge<MLFloat16>(); // value (signals packed)
2483+
tester.AddOptionalInputEdge<MLFloat16>(); // past_key
2484+
tester.AddOptionalInputEdge<MLFloat16>(); // past_value
2485+
} else {
2486+
tester.AddInput<float>("query", {batch_size, sequence_length, qkv_hidden}, packed_qkv_data);
2487+
tester.AddOptionalInputEdge<float>(); // key (signals packed)
2488+
tester.AddOptionalInputEdge<float>(); // value (signals packed)
2489+
tester.AddOptionalInputEdge<float>(); // past_key
2490+
tester.AddOptionalInputEdge<float>(); // past_value
2491+
}
24782492

24792493
tester.AddInput<int32_t>("seqlens_k", {batch_size}, seqlens_k_data);
24802494
tester.AddInput<int32_t>("total_sequence_length", {1}, {total_sequence_length},
@@ -2490,21 +2504,40 @@ static std::vector<float> RunGQAPackedQKVRotaryPrefill(
24902504
sin_cache[pos * half_rotary + d] = std::sin(static_cast<float>(pos) * freq);
24912505
}
24922506
}
2493-
tester.AddInput<float>("cos_cache", {max_seq_len, half_rotary}, cos_cache);
2494-
tester.AddInput<float>("sin_cache", {max_seq_len, half_rotary}, sin_cache);
2507+
if (use_fp16) {
2508+
tester.AddInput<MLFloat16>("cos_cache", {max_seq_len, half_rotary}, ToFloat16(cos_cache));
2509+
tester.AddInput<MLFloat16>("sin_cache", {max_seq_len, half_rotary}, ToFloat16(sin_cache));
2510+
} else {
2511+
tester.AddInput<float>("cos_cache", {max_seq_len, half_rotary}, cos_cache);
2512+
tester.AddInput<float>("sin_cache", {max_seq_len, half_rotary}, sin_cache);
2513+
}
24952514

24962515
tester.AddOptionalInputEdge<int64_t>(); // position_ids
2497-
tester.AddOptionalInputEdge<float>(); // attention_bias
2498-
tester.AddOptionalInputEdge<float>(); // head_sink
2516+
if (use_fp16) {
2517+
tester.AddOptionalInputEdge<MLFloat16>(); // attention_bias
2518+
tester.AddOptionalInputEdge<MLFloat16>(); // head_sink
2519+
} else {
2520+
tester.AddOptionalInputEdge<float>(); // attention_bias
2521+
tester.AddOptionalInputEdge<float>(); // head_sink
2522+
}
24992523

25002524
const int output_size = batch_size * sequence_length * hidden_size;
2501-
tester.AddOutput<float>("output", {batch_size, sequence_length, hidden_size},
2502-
std::vector<float>(output_size, 0.0f));
25032525
const int present_size = batch_size * kv_num_heads * total_sequence_length * head_size;
2504-
tester.AddOutput<float>("present_key", {batch_size, kv_num_heads, total_sequence_length, head_size},
2505-
std::vector<float>(present_size, 0.0f));
2506-
tester.AddOutput<float>("present_value", {batch_size, kv_num_heads, total_sequence_length, head_size},
2507-
std::vector<float>(present_size, 0.0f));
2526+
if (use_fp16) {
2527+
tester.AddOutput<MLFloat16>("output", {batch_size, sequence_length, hidden_size},
2528+
std::vector<MLFloat16>(output_size, MLFloat16(0.0f)));
2529+
tester.AddOutput<MLFloat16>("present_key", {batch_size, kv_num_heads, total_sequence_length, head_size},
2530+
std::vector<MLFloat16>(present_size, MLFloat16(0.0f)));
2531+
tester.AddOutput<MLFloat16>("present_value", {batch_size, kv_num_heads, total_sequence_length, head_size},
2532+
std::vector<MLFloat16>(present_size, MLFloat16(0.0f)));
2533+
} else {
2534+
tester.AddOutput<float>("output", {batch_size, sequence_length, hidden_size},
2535+
std::vector<float>(output_size, 0.0f));
2536+
tester.AddOutput<float>("present_key", {batch_size, kv_num_heads, total_sequence_length, head_size},
2537+
std::vector<float>(present_size, 0.0f));
2538+
tester.AddOutput<float>("present_value", {batch_size, kv_num_heads, total_sequence_length, head_size},
2539+
std::vector<float>(present_size, 0.0f));
2540+
}
25082541

25092542
tester.SetOutputTolerance(1e6f); // We fetch and compare outputs ourselves.
25102543

@@ -2513,8 +2546,17 @@ static std::vector<float> RunGQAPackedQKVRotaryPrefill(
25132546
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
25142547

25152548
auto fetches = tester.GetFetches();
2516-
const float* out_data = fetches[0].Get<Tensor>().Data<float>();
2517-
return std::vector<float>(out_data, out_data + output_size);
2549+
std::vector<float> result(output_size);
2550+
if (use_fp16) {
2551+
const MLFloat16* out_data = fetches[0].Get<Tensor>().Data<MLFloat16>();
2552+
for (int i = 0; i < output_size; ++i) {
2553+
result[i] = out_data[i].ToFloat();
2554+
}
2555+
} else {
2556+
const float* out_data = fetches[0].Get<Tensor>().Data<float>();
2557+
std::copy_n(out_data, output_size, result.begin());
2558+
}
2559+
return result;
25182560
}
25192561

25202562
// Inner helper: builds packed-QKV inputs, computes per-prompt references, runs

0 commit comments

Comments
 (0)