@@ -354,7 +354,7 @@ static void ExpectOutputsMatch(const std::vector<float>& a, const std::vector<fl
354354
355355// ---------------------------------------------------------------------------
356356// Tests for kv_sequence_length=0 with borrowed past_key/past_value
357- // (shared KV pattern: empty K/V inputs, all KV data in past buffer)
357+ // (Gemma4 shared KV pattern: empty K/V inputs, all KV data in past buffer)
358358// ---------------------------------------------------------------------------
359359
360360// Helper: run GQA with empty K/V and past_key/past_value (shared KV pattern).
@@ -616,7 +616,7 @@ TEST(GroupQueryAttentionTest, SharedKV_EmptyKV_WithPast_LargeHeadSize_CPU) {
616616 EXPECT_FALSE (all_zero) << " Output should not be all zeros" ;
617617}
618618
619- // CPU: GQA ratio num_heads=8, kv_num_heads=1.
619+ // CPU: GQA ratio num_heads=8, kv_num_heads=1 (matches Gemma4 config) .
620620TEST (GroupQueryAttentionTest, SharedKV_EmptyKV_WithPast_GQARatio8_CPU) {
621621 constexpr int batch_size = 1 ;
622622 constexpr int q_seq_len = 1 ;
@@ -816,7 +816,7 @@ TEST(GroupQueryAttentionTest, SharedKV_EmptyKV_WithPast_GQARatio8_CUDA) {
816816}
817817
818818// ---------------------------------------------------------------------------
819- // Shared KV tests with do_rotary=1
819+ // Shared KV tests with do_rotary=1 (Gemma4 primary use case)
820820// ---------------------------------------------------------------------------
821821
822822// Helper: run GQA with empty K/V, past_key/past_value, and do_rotary=1.
@@ -1754,7 +1754,7 @@ TEST(GroupQueryAttentionTest, SeqlensKExceedsCosCache_MultiBatch) {
17541754}
17551755
17561756// ---------------------------------------------------------------------------
1757- // WebGPU: shared KV tests (kv_sequence_length=0 pattern)
1757+ // WebGPU: shared KV tests (Gemma4 kv_sequence_length=0 pattern)
17581758// Each test cross-checks WebGPU against CPU for correctness.
17591759// ---------------------------------------------------------------------------
17601760
0 commit comments