diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 08731f5381c..5e5b9c9054e 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -2277,7 +2277,12 @@ ggml_tensor * llm_graph_context::build_attn(
     const int64_t padded_v_head = v->ne[0];
     if (padded_v_head != orig_v_head) {
         // Reshape to 4D, extract original head_dim, reshape back to 2D
-        const int64_t n_head_v = hparams.n_head_kv(il);
+        // Fix #78 (bingh0): cur shape post-MHA is (n_embd_head * n_head, n_tokens),
+        // not (n_embd_head * n_head_kv, n_tokens). Reshape needs n_head
+        // (Q-head count) so GQA models with n_head != n_head_kv (e.g.
+        // Qwen2.5-0.5B head_dim=64 padded → 128) don't fail the element
+        // count check in ggml_reshape_3d.
+        const int64_t n_head_v = hparams.n_head(il);
         const int64_t n_tokens_cur = cur->ne[1];
         cur = ggml_reshape_3d(ctx0, cur, padded_v_head, n_head_v, n_tokens_cur);
         // ggml_view_3d to extract first orig_v_head elements per head
@@ -2399,7 +2404,12 @@ ggml_tensor * llm_graph_context::build_attn(
     const int64_t padded_v_head = v->ne[0]; // padded V head_dim in cache
     if (padded_v_head != orig_v_head) {
         // cur is 2D: (padded_v_head * n_head, n_tokens) after build_attn_mha
-        const int64_t n_head_v = hparams.n_head_kv(il);
+        // Fix #78 (bingh0): cur shape post-MHA is (n_embd_head * n_head, n_tokens),
+        // not (n_embd_head * n_head_kv, n_tokens). Reshape needs n_head
+        // (Q-head count) so GQA models with n_head != n_head_kv (e.g.
+        // Qwen2.5-0.5B head_dim=64 padded → 128) don't fail the element
+        // count check in ggml_reshape_3d.
+        const int64_t n_head_v = hparams.n_head(il);
         const int64_t n_tokens_cur = cur->ne[1];
         cur = ggml_reshape_3d(ctx0, cur, padded_v_head, n_head_v, n_tokens_cur);
         cur = ggml_view_3d(ctx0, cur, orig_v_head, n_head_v, n_tokens_cur,
@@ -2514,7 +2524,12 @@ ggml_tensor * llm_graph_context::build_attn(
     const int64_t orig_v_head = hparams.n_embd_head_v(il);
     const int64_t padded_v_head = v->ne[0];
     if (padded_v_head != orig_v_head) {
-        const int64_t n_head_v = hparams.n_head_kv(il);
+        // Fix #78 (bingh0): cur shape post-MHA is (n_embd_head * n_head, n_tokens),
+        // not (n_embd_head * n_head_kv, n_tokens). Reshape needs n_head
+        // (Q-head count) so GQA models with n_head != n_head_kv (e.g.
+        // Qwen2.5-0.5B head_dim=64 padded → 128) don't fail the element
+        // count check in ggml_reshape_3d.
+        const int64_t n_head_v = hparams.n_head(il);
         const int64_t n_tokens_cur = cur->ne[1];
         cur = ggml_reshape_3d(ctx0, cur, padded_v_head, n_head_v, n_tokens_cur);
         cur = ggml_view_3d(ctx0, cur, orig_v_head, n_head_v, n_tokens_cur,
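
To make the element-count argument concrete, here is a minimal standalone C++ sketch of the arithmetic that `ggml_reshape_3d` effectively enforces (the total number of elements must not change across a reshape). The concrete dimensions (`n_head = 14`, `n_head_kv = 2`, head_dim 64 padded to 128, 8 tokens) are assumed Qwen2.5-0.5B-style GQA values used only for illustration, not read from a model or from ggml itself.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t n_head        = 14;   // Q-head count (assumed, Qwen2.5-0.5B-style)
    const int64_t n_head_kv     = 2;    // KV-head count (assumed, GQA)
    const int64_t orig_v_head   = 64;   // original V head_dim
    const int64_t padded_v_head = 128;  // padded V head_dim in the cache
    const int64_t n_tokens      = 8;    // arbitrary token count

    // After build_attn_mha, cur is 2D: (padded_v_head * n_head, n_tokens).
    const int64_t cur_elems = padded_v_head * n_head * n_tokens;

    // Target 3D shape using n_head_kv: element count no longer matches cur,
    // so the reshape's consistency check would fail.
    const int64_t elems_with_kv_heads = padded_v_head * n_head_kv * n_tokens;

    // Target 3D shape using n_head: element count matches, reshape is valid.
    const int64_t elems_with_q_heads  = padded_v_head * n_head * n_tokens;

    printf("cur = %lld, with n_head_kv = %lld, with n_head = %lld\n",
           (long long) cur_elems,
           (long long) elems_with_kv_heads,
           (long long) elems_with_q_heads);

    assert(elems_with_q_heads  == cur_elems);  // 128 * 14 * 8 == 128 * 14 * 8
    assert(elems_with_kv_heads != cur_elems);  // 128 *  2 * 8 != 128 * 14 * 8
    (void) orig_v_head; // the view back to orig_v_head happens after the reshape
    return 0;
}
```

With `n_head_kv`, the target shape only accounts for 2048 of the 14336 elements in `cur`, which is exactly the mismatch the reshape check rejects; using `n_head` keeps the totals equal, and the subsequent `ggml_view_3d` then trims each head back to `orig_v_head`.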