Skip to content

Commit 607b4e8

Browse files
feich-msclaude
andcommitted
Fix kv_empty layers failing with sliding window on long sequences
When total_sequence_length exceeds local_window_size, the sliding window check was blocking flash attention for kv_empty (shared KV) layers. This is incorrect because sliding window is irrelevant for these layers — they have no local KV cache and reuse another layer's already-computed cache. Add regression test for this case. Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
1 parent 53a98a1 commit 607b4e8

2 files changed

Lines changed: 67 additions & 1 deletion

File tree

onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,10 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
283283
// Use a sliding window if the total sequence exceeds the window's length.
284284
bool use_sliding_window = (local_window_size_ != -1 && local_window_size_ < parameters.total_sequence_length_);
285285
bool will_use_flash_attention = false;
286-
if (!use_smooth_softmax_ && !use_sliding_window) {
286+
// For kv_empty layers (shared KV), sliding window is irrelevant — there's no new KV to window
287+
// over, the layer reuses another layer's already-computed KV cache. Flash attention is required
288+
// for these layers, so we bypass the sliding window check to allow it.
289+
if (!use_smooth_softmax_ && (!use_sliding_window || kv_empty)) {
287290
// Create a temporary parameters copy with is_packed_qkv_ set to false to check if flash attention can be applied after unpacking
288291
WebgpuAttentionParameters temp_params = parameters;
289292
temp_params.is_packed_qkv_ = false;

onnxruntime/test/contrib_ops/group_query_attention_op_test.cc

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1854,5 +1854,68 @@ TEST(GroupQueryAttentionTest, WebGPU_SharedKV_Rotary) {
18541854
ExpectOutputsMatch(webgpu_output, cpu_output, 0.05f, "SharedKV_Rotary_WebGPU_vs_CPU");
18551855
}
18561856

1857+
// WebGPU: kv_sequence_length=0 with sliding window active (total_seq > local_window_size).
1858+
// Regression test: sliding window must not block flash attention for kv_empty layers.
1859+
TEST(GroupQueryAttentionTest, WebGPU_SharedKV_SlidingWindow) {
1860+
auto webgpu_ep = DefaultWebGpuExecutionProvider();
1861+
if (!webgpu_ep) {
1862+
GTEST_SKIP() << "WebGPU EP not available";
1863+
}
1864+
1865+
constexpr int batch_size = 1;
1866+
constexpr int q_seq_len = 4;
1867+
constexpr int past_seq_len = 32;
1868+
constexpr int num_heads = 2;
1869+
constexpr int kv_num_heads = 1;
1870+
constexpr int head_size = 8;
1871+
constexpr int hidden_size = num_heads * head_size;
1872+
constexpr int kv_hidden_size = kv_num_heads * head_size;
1873+
constexpr int local_window_size = 16; // < past_seq_len to trigger sliding window
1874+
constexpr int total_seq_len = past_seq_len;
1875+
1876+
OpTester tester("GroupQueryAttention", 1, onnxruntime::kMSDomain);
1877+
tester.AddAttribute<int64_t>("num_heads", static_cast<int64_t>(num_heads));
1878+
tester.AddAttribute<int64_t>("kv_num_heads", static_cast<int64_t>(kv_num_heads));
1879+
tester.AddAttribute<int64_t>("local_window_size", static_cast<int64_t>(local_window_size));
1880+
1881+
std::vector<float> query_data(batch_size * q_seq_len * hidden_size);
1882+
std::vector<float> past_key_data(batch_size * kv_num_heads * past_seq_len * head_size);
1883+
std::vector<float> past_value_data(batch_size * kv_num_heads * past_seq_len * head_size);
1884+
for (size_t i = 0; i < query_data.size(); i++) query_data[i] = 0.1f * static_cast<float>(i % 7 + 1);
1885+
for (size_t i = 0; i < past_key_data.size(); i++) past_key_data[i] = 0.2f * static_cast<float>(i % 5 + 1);
1886+
for (size_t i = 0; i < past_value_data.size(); i++) past_value_data[i] = 0.3f * static_cast<float>(i % 3 + 1);
1887+
1888+
tester.AddInput<float>("query", {batch_size, q_seq_len, hidden_size}, query_data);
1889+
tester.AddInput<float>("key", {batch_size, 0, kv_hidden_size}, {});
1890+
tester.AddInput<float>("value", {batch_size, 0, kv_hidden_size}, {});
1891+
tester.AddInput<float>("past_key", {batch_size, kv_num_heads, past_seq_len, head_size}, past_key_data);
1892+
tester.AddInput<float>("past_value", {batch_size, kv_num_heads, past_seq_len, head_size}, past_value_data);
1893+
1894+
std::vector<int32_t> seqlens_k_data(batch_size, static_cast<int32_t>(total_seq_len - 1));
1895+
tester.AddInput<int32_t>("seqlens_k", {batch_size}, seqlens_k_data);
1896+
tester.AddInput<int32_t>("total_sequence_length", {1}, {static_cast<int32_t>(total_seq_len)});
1897+
1898+
tester.AddOptionalInputEdge<float>(); // cos_cache
1899+
tester.AddOptionalInputEdge<float>(); // sin_cache
1900+
tester.AddOptionalInputEdge<int64_t>(); // position_ids
1901+
tester.AddOptionalInputEdge<float>(); // attention_bias
1902+
tester.AddOptionalInputEdge<float>(); // head_sink
1903+
1904+
const int output_size = batch_size * q_seq_len * hidden_size;
1905+
tester.AddOutput<float>("output", {batch_size, q_seq_len, hidden_size},
1906+
std::vector<float>(output_size, 0.0f));
1907+
const int present_size = batch_size * kv_num_heads * past_seq_len * head_size;
1908+
tester.AddOutput<float>("present_key", {batch_size, kv_num_heads, past_seq_len, head_size},
1909+
std::vector<float>(present_size, 0.0f));
1910+
tester.AddOutput<float>("present_value", {batch_size, kv_num_heads, past_seq_len, head_size},
1911+
std::vector<float>(present_size, 0.0f));
1912+
1913+
tester.SetOutputTolerance(1e6f);
1914+
1915+
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
1916+
execution_providers.push_back(DefaultWebGpuExecutionProvider());
1917+
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
1918+
}
1919+
18571920
} // namespace test
18581921
} // namespace onnxruntime

0 commit comments

Comments
 (0)