Skip to content

Commit 44e6274

Browse files
jan-wassenberg authored and copybara-github committed
1.07x speedup: merge MQA parallel sections as suggested by @veluca93
PiperOrigin-RevId: 621772392
1 parent ede337f commit 44e6274

1 file changed

Lines changed: 1 addition & 5 deletions

File tree

gemma.cc

Lines changed: 1 addition & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -405,18 +405,14 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
405405
});
406406
} else {
407407
// Multi-Query Attention
408-
pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
409-
ProjQ(head, head * kQKVDim * kModelDim);
410-
});
411-
412408
constexpr const size_t q_offset = kHeads * kQKVDim * kModelDim;
413409
constexpr const size_t k_offset = q_offset + 0 * kQKVDim * kModelDim;
414410
constexpr const size_t v_offset = q_offset + 1 * kQKVDim * kModelDim;
415411
const size_t kv_offset = pos * kCachePosSize + layer * kCacheLayerSize;
416-
417412
ProjKV(k_offset, v_offset, kv_offset);
418413

419414
pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
415+
ProjQ(head, head * kQKVDim * kModelDim);
420416
Attn(head, 0);
421417
});
422418
}

0 commit comments

Comments
 (0)