@@ -2277,7 +2277,12 @@ ggml_tensor * llm_graph_context::build_attn(
     const int64_t padded_v_head = v->ne[0];
     if (padded_v_head != orig_v_head) {
         // Reshape to 4D, extract original head_dim, reshape back to 2D
-        const int64_t n_head_v = hparams.n_head_kv(il);
+        // Fix #78 (bingh0): cur shape post-MHA is (n_embd_head * n_head, n_tokens),
+        // not (n_embd_head * n_head_kv, n_tokens). Reshape needs n_head
+        // (Q-head count) so GQA models with n_head != n_head_kv (e.g.
+        // Qwen2.5-0.5B head_dim=64 padded → 128) don't fail the element
+        // count check in ggml_reshape_3d.
+        const int64_t n_head_v = hparams.n_head(il);
         const int64_t n_tokens_cur = cur->ne[1];
         cur = ggml_reshape_3d(ctx0, cur, padded_v_head, n_head_v, n_tokens_cur);
         // ggml_view_3d to extract first orig_v_head elements per head
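
To see why the wrong head count trips the check: ggml_reshape_3d asserts that the source tensor's element count equals ne0 * ne1 * ne2. Below is a minimal standalone sketch of that arithmetic, assuming Qwen2.5-0.5B-like hyperparameters (n_head = 14, n_head_kv = 2, head_dim 64 padded to 128 — these specific values are this note's assumption, not taken from the commit):

// Sketch of the element-count check that ggml_reshape_3d performs,
// with assumed Qwen2.5-0.5B-like dims (n_head = 14, n_head_kv = 2).
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t n_head        = 14;  // Q heads
    const int64_t n_head_kv     = 2;   // KV heads (GQA)
    const int64_t padded_v_head = 128; // head_dim 64, padded in the V cache
    const int64_t n_tokens      = 8;

    // cur after build_attn_mha is 2D: (padded_v_head * n_head, n_tokens)
    const int64_t nelements = padded_v_head * n_head * n_tokens; // 14336

    // ggml_reshape_3d(ctx0, cur, padded_v_head, n_head_v, n_tokens)
    // requires ggml_nelements(cur) == ne0 * ne1 * ne2.
    const int64_t with_kv = padded_v_head * n_head_kv * n_tokens; // 2048
    const int64_t with_q  = padded_v_head * n_head    * n_tokens; // 14336
    printf("n_head_kv: %lld == %lld ? %s\n", (long long) nelements,
           (long long) with_kv, nelements == with_kv ? "ok" : "FAIL");
    printf("n_head   : %lld == %lld ? %s\n", (long long) nelements,
           (long long) with_q,  nelements == with_q  ? "ok" : "FAIL");
    return 0;
}

With n_head_kv the reshape target holds 2048 elements against cur's 14336, so the assertion fires; with n_head the counts match.
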
@@ -2399,7 +2404,12 @@ ggml_tensor * llm_graph_context::build_attn(
     const int64_t padded_v_head = v->ne[0]; // padded V head_dim in cache
     if (padded_v_head != orig_v_head) {
         // cur is 2D: (padded_v_head * n_head, n_tokens) after build_attn_mha
-        const int64_t n_head_v = hparams.n_head_kv(il);
+        // Fix #78 (bingh0): cur shape post-MHA is (n_embd_head * n_head, n_tokens),
+        // not (n_embd_head * n_head_kv, n_tokens). Reshape needs n_head
+        // (Q-head count) so GQA models with n_head != n_head_kv (e.g.
+        // Qwen2.5-0.5B head_dim=64 padded → 128) don't fail the element
+        // count check in ggml_reshape_3d.
+        const int64_t n_head_v = hparams.n_head(il);
         const int64_t n_tokens_cur = cur->ne[1];
         cur = ggml_reshape_3d(ctx0, cur, padded_v_head, n_head_v, n_tokens_cur);
         cur = ggml_view_3d(ctx0, cur, orig_v_head, n_head_v, n_tokens_cur,
@@ -2514,7 +2524,12 @@ ggml_tensor * llm_graph_context::build_attn(
     const int64_t orig_v_head = hparams.n_embd_head_v(il);
     const int64_t padded_v_head = v->ne[0];
     if (padded_v_head != orig_v_head) {
-        const int64_t n_head_v = hparams.n_head_kv(il);
+        // Fix #78 (bingh0): cur shape post-MHA is (n_embd_head * n_head, n_tokens),
+        // not (n_embd_head * n_head_kv, n_tokens). Reshape needs n_head
+        // (Q-head count) so GQA models with n_head != n_head_kv (e.g.
+        // Qwen2.5-0.5B head_dim=64 padded → 128) don't fail the element
+        // count check in ggml_reshape_3d.
+        const int64_t n_head_v = hparams.n_head(il);
         const int64_t n_tokens_cur = cur->ne[1];
         cur = ggml_reshape_3d(ctx0, cur, padded_v_head, n_head_v, n_tokens_cur);
         cur = ggml_view_3d(ctx0, cur, orig_v_head, n_head_v, n_tokens_cur,
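
For reference, here is a self-contained sketch of the surrounding reshape → view → cont pattern that strips the V-head padding, built against the public ggml API in no_alloc mode (graph construction only; the dims reuse the assumed Qwen2.5-0.5B example above and are illustrative, not copied from the commit):

// Padding-strip pattern around the fixed line, sketched with the public
// ggml API. no_alloc = true: tensors are graph nodes, no data is allocated.
#include "ggml.h"

int main() {
    ggml_init_params params = {
        /*.mem_size   =*/ 16 * 1024 * 1024,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ true,
    };
    ggml_context * ctx0 = ggml_init(params);

    const int64_t orig_v_head   = 64;  // hparams.n_embd_head_v(il)
    const int64_t padded_v_head = 128; // padded V head_dim in cache
    const int64_t n_head        = 14;  // hparams.n_head(il), Q-head count
    const int64_t n_tokens      = 8;

    // cur as it leaves build_attn_mha: (padded_v_head * n_head, n_tokens)
    ggml_tensor * cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32,
                                           padded_v_head * n_head, n_tokens);

    // 1) split the padded rows into heads: (padded_v_head, n_head, n_tokens)
    cur = ggml_reshape_3d(ctx0, cur, padded_v_head, n_head, n_tokens);
    // 2) view only the first orig_v_head elements of each head; keeping the
    //    source strides nb[1]/nb[2] makes the view skip the padding tail
    cur = ggml_view_3d(ctx0, cur, orig_v_head, n_head, n_tokens,
                       cur->nb[1], cur->nb[2], 0);
    // 3) make the strided view contiguous, then flatten back to 2D
    cur = ggml_cont(ctx0, cur);
    cur = ggml_reshape_2d(ctx0, cur, orig_v_head * n_head, n_tokens);

    ggml_free(ctx0);
    return 0;
}

The ggml_cont step is what pays for the padding: reshape requires a contiguous tensor, so the strided view must be copied before it can be flattened back to (orig_v_head * n_head, n_tokens).
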