@@ -1854,5 +1854,68 @@ TEST(GroupQueryAttentionTest, WebGPU_SharedKV_Rotary) {
18541854 ExpectOutputsMatch (webgpu_output, cpu_output, 0 .05f , " SharedKV_Rotary_WebGPU_vs_CPU" );
18551855}
18561856
1857+ // WebGPU: kv_sequence_length=0 with sliding window active (total_seq > local_window_size).
1858+ // Regression test: sliding window must not block flash attention for kv_empty layers.
1859+ TEST (GroupQueryAttentionTest, WebGPU_SharedKV_SlidingWindow) {
1860+ auto webgpu_ep = DefaultWebGpuExecutionProvider ();
1861+ if (!webgpu_ep) {
1862+ GTEST_SKIP () << " WebGPU EP not available" ;
1863+ }
1864+
1865+ constexpr int batch_size = 1 ;
1866+ constexpr int q_seq_len = 4 ;
1867+ constexpr int past_seq_len = 32 ;
1868+ constexpr int num_heads = 2 ;
1869+ constexpr int kv_num_heads = 1 ;
1870+ constexpr int head_size = 8 ;
1871+ constexpr int hidden_size = num_heads * head_size;
1872+ constexpr int kv_hidden_size = kv_num_heads * head_size;
1873+ constexpr int local_window_size = 16 ; // < past_seq_len to trigger sliding window
1874+ constexpr int total_seq_len = past_seq_len;
1875+
1876+ OpTester tester (" GroupQueryAttention" , 1 , onnxruntime::kMSDomain );
1877+ tester.AddAttribute <int64_t >(" num_heads" , static_cast <int64_t >(num_heads));
1878+ tester.AddAttribute <int64_t >(" kv_num_heads" , static_cast <int64_t >(kv_num_heads));
1879+ tester.AddAttribute <int64_t >(" local_window_size" , static_cast <int64_t >(local_window_size));
1880+
1881+ std::vector<float > query_data (batch_size * q_seq_len * hidden_size);
1882+ std::vector<float > past_key_data (batch_size * kv_num_heads * past_seq_len * head_size);
1883+ std::vector<float > past_value_data (batch_size * kv_num_heads * past_seq_len * head_size);
1884+ for (size_t i = 0 ; i < query_data.size (); i++) query_data[i] = 0 .1f * static_cast <float >(i % 7 + 1 );
1885+ for (size_t i = 0 ; i < past_key_data.size (); i++) past_key_data[i] = 0 .2f * static_cast <float >(i % 5 + 1 );
1886+ for (size_t i = 0 ; i < past_value_data.size (); i++) past_value_data[i] = 0 .3f * static_cast <float >(i % 3 + 1 );
1887+
1888+ tester.AddInput <float >(" query" , {batch_size, q_seq_len, hidden_size}, query_data);
1889+ tester.AddInput <float >(" key" , {batch_size, 0 , kv_hidden_size}, {});
1890+ tester.AddInput <float >(" value" , {batch_size, 0 , kv_hidden_size}, {});
1891+ tester.AddInput <float >(" past_key" , {batch_size, kv_num_heads, past_seq_len, head_size}, past_key_data);
1892+ tester.AddInput <float >(" past_value" , {batch_size, kv_num_heads, past_seq_len, head_size}, past_value_data);
1893+
1894+ std::vector<int32_t > seqlens_k_data (batch_size, static_cast <int32_t >(total_seq_len - 1 ));
1895+ tester.AddInput <int32_t >(" seqlens_k" , {batch_size}, seqlens_k_data);
1896+ tester.AddInput <int32_t >(" total_sequence_length" , {1 }, {static_cast <int32_t >(total_seq_len)});
1897+
1898+ tester.AddOptionalInputEdge <float >(); // cos_cache
1899+ tester.AddOptionalInputEdge <float >(); // sin_cache
1900+ tester.AddOptionalInputEdge <int64_t >(); // position_ids
1901+ tester.AddOptionalInputEdge <float >(); // attention_bias
1902+ tester.AddOptionalInputEdge <float >(); // head_sink
1903+
1904+ const int output_size = batch_size * q_seq_len * hidden_size;
1905+ tester.AddOutput <float >(" output" , {batch_size, q_seq_len, hidden_size},
1906+ std::vector<float >(output_size, 0 .0f ));
1907+ const int present_size = batch_size * kv_num_heads * past_seq_len * head_size;
1908+ tester.AddOutput <float >(" present_key" , {batch_size, kv_num_heads, past_seq_len, head_size},
1909+ std::vector<float >(present_size, 0 .0f ));
1910+ tester.AddOutput <float >(" present_value" , {batch_size, kv_num_heads, past_seq_len, head_size},
1911+ std::vector<float >(present_size, 0 .0f ));
1912+
1913+ tester.SetOutputTolerance (1e6f);
1914+
1915+ std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
1916+ execution_providers.push_back (DefaultWebGpuExecutionProvider ());
1917+ tester.Run (OpTester::ExpectResult::kExpectSuccess , " " , {}, nullptr , &execution_providers);
1918+ }
1919+
18571920} // namespace test
18581921} // namespace onnxruntime
0 commit comments