Skip to content

Commit ef6547e

Browse files
committed
update comments
1 parent 607b4e8 commit ef6547e

1 file changed

Lines changed: 4 additions & 4 deletions

File tree

onnxruntime/test/contrib_ops/group_query_attention_op_test.cc

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -354,7 +354,7 @@ static void ExpectOutputsMatch(const std::vector<float>& a, const std::vector<fl
354354

355355
// ---------------------------------------------------------------------------
356356
// Tests for kv_sequence_length=0 with borrowed past_key/past_value
357-
// (shared KV pattern: empty K/V inputs, all KV data in past buffer)
357+
// (Gemma4 shared KV pattern: empty K/V inputs, all KV data in past buffer)
358358
// ---------------------------------------------------------------------------
359359

360360
// Helper: run GQA with empty K/V and past_key/past_value (shared KV pattern).
@@ -616,7 +616,7 @@ TEST(GroupQueryAttentionTest, SharedKV_EmptyKV_WithPast_LargeHeadSize_CPU) {
616616
EXPECT_FALSE(all_zero) << "Output should not be all zeros";
617617
}
618618

619-
// CPU: GQA ratio num_heads=8, kv_num_heads=1.
619+
// CPU: GQA ratio num_heads=8, kv_num_heads=1 (matches Gemma4 config).
620620
TEST(GroupQueryAttentionTest, SharedKV_EmptyKV_WithPast_GQARatio8_CPU) {
621621
constexpr int batch_size = 1;
622622
constexpr int q_seq_len = 1;
@@ -816,7 +816,7 @@ TEST(GroupQueryAttentionTest, SharedKV_EmptyKV_WithPast_GQARatio8_CUDA) {
816816
}
817817

818818
// ---------------------------------------------------------------------------
819-
// Shared KV tests with do_rotary=1
819+
// Shared KV tests with do_rotary=1 (Gemma4 primary use case)
820820
// ---------------------------------------------------------------------------
821821

822822
// Helper: run GQA with empty K/V, past_key/past_value, and do_rotary=1.
@@ -1754,7 +1754,7 @@ TEST(GroupQueryAttentionTest, SeqlensKExceedsCosCache_MultiBatch) {
17541754
}
17551755

17561756
// ---------------------------------------------------------------------------
1757-
// WebGPU: shared KV tests (kv_sequence_length=0 pattern)
1757+
// WebGPU: shared KV tests (Gemma4 kv_sequence_length=0 pattern)
17581758
// Each test cross-checks WebGPU against CPU for correctness.
17591759
// ---------------------------------------------------------------------------
17601760

0 commit comments

Comments
 (0)