Commit cde3e1a

Merge pull request #116 from TheTom/fix/issue-78-gqa-reshape
fix: build_attn V-padded reshape uses Q-head count, not KV-head count (#78)
2 parents (11a241d + b6f8e7f), commit cde3e1a

1 file changed, 18 additions and 3 deletions

src/llama-graph.cpp
@@ -2277,7 +2277,12 @@ ggml_tensor * llm_graph_context::build_attn(
     const int64_t padded_v_head = v->ne[0];
     if (padded_v_head != orig_v_head) {
         // Reshape to 4D, extract original head_dim, reshape back to 2D
-        const int64_t n_head_v = hparams.n_head_kv(il);
+        // Fix #78 (bingh0): cur shape post-MHA is (n_embd_head * n_head, n_tokens),
+        // not (n_embd_head * n_head_kv, n_tokens). Reshape needs n_head
+        // (Q-head count) so GQA models with n_head != n_head_kv (e.g.
+        // Qwen2.5-0.5B head_dim=64 padded → 128) don't fail the element
+        // count check in ggml_reshape_3d.
+        const int64_t n_head_v = hparams.n_head(il);
         const int64_t n_tokens_cur = cur->ne[1];
         cur = ggml_reshape_3d(ctx0, cur, padded_v_head, n_head_v, n_tokens_cur);
         // ggml_view_3d to extract first orig_v_head elements per head
@@ -2399,7 +2404,12 @@ ggml_tensor * llm_graph_context::build_attn(
     const int64_t padded_v_head = v->ne[0]; // padded V head_dim in cache
     if (padded_v_head != orig_v_head) {
         // cur is 2D: (padded_v_head * n_head, n_tokens) after build_attn_mha
-        const int64_t n_head_v = hparams.n_head_kv(il);
+        // Fix #78 (bingh0): cur shape post-MHA is (n_embd_head * n_head, n_tokens),
+        // not (n_embd_head * n_head_kv, n_tokens). Reshape needs n_head
+        // (Q-head count) so GQA models with n_head != n_head_kv (e.g.
+        // Qwen2.5-0.5B head_dim=64 padded → 128) don't fail the element
+        // count check in ggml_reshape_3d.
+        const int64_t n_head_v = hparams.n_head(il);
         const int64_t n_tokens_cur = cur->ne[1];
         cur = ggml_reshape_3d(ctx0, cur, padded_v_head, n_head_v, n_tokens_cur);
         cur = ggml_view_3d(ctx0, cur, orig_v_head, n_head_v, n_tokens_cur,
@@ -2514,7 +2524,12 @@ ggml_tensor * llm_graph_context::build_attn(
     const int64_t orig_v_head = hparams.n_embd_head_v(il);
     const int64_t padded_v_head = v->ne[0];
     if (padded_v_head != orig_v_head) {
-        const int64_t n_head_v = hparams.n_head_kv(il);
+        // Fix #78 (bingh0): cur shape post-MHA is (n_embd_head * n_head, n_tokens),
+        // not (n_embd_head * n_head_kv, n_tokens). Reshape needs n_head
+        // (Q-head count) so GQA models with n_head != n_head_kv (e.g.
+        // Qwen2.5-0.5B head_dim=64 padded → 128) don't fail the element
+        // count check in ggml_reshape_3d.
+        const int64_t n_head_v = hparams.n_head(il);
         const int64_t n_tokens_cur = cur->ne[1];
         cur = ggml_reshape_3d(ctx0, cur, padded_v_head, n_head_v, n_tokens_cur);
         cur = ggml_view_3d(ctx0, cur, orig_v_head, n_head_v, n_tokens_cur,
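For context, a minimal standalone sketch (not part of this commit or of llama.cpp) of the element-count condition that ggml_reshape_3d enforces and that the old n_head_kv-based reshape violated. The Qwen2.5-0.5B-like numbers (14 Q heads, 2 KV heads, head_dim 64 padded to 128, 8 tokens) are assumptions for illustration only.

// sketch of the reshape element-count check, with assumed GQA shapes
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t padded_v_head = 128; // V head_dim after padding (assumed 64 -> 128)
    const int64_t n_head        = 14;  // Q-head count (assumed)
    const int64_t n_head_kv     = 2;   // KV-head count (assumed)
    const int64_t n_tokens      = 8;   // batch of tokens (assumed)

    // After build_attn_mha, cur is 2D: (padded_v_head * n_head, n_tokens).
    const int64_t n_elements = padded_v_head * n_head * n_tokens;

    // Old code: reshape target built from n_head_kv -> element counts disagree,
    // so the element-count check inside ggml_reshape_3d would fail.
    const int64_t old_target = padded_v_head * n_head_kv * n_tokens;
    printf("old target %lld vs actual %lld -> %s\n",
           (long long) old_target, (long long) n_elements,
           old_target == n_elements ? "ok" : "mismatch");

    // Fixed code: reshape target built from n_head -> counts match.
    const int64_t new_target = padded_v_head * n_head * n_tokens;
    assert(new_target == n_elements);
    printf("new target %lld vs actual %lld -> ok\n",
           (long long) new_target, (long long) n_elements);
    return 0;
}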
