From b6f8e7f7289e1fd9991686582d3d0003106b792f Mon Sep 17 00:00:00 2001 From: TheTom Date: Fri, 1 May 2026 08:28:31 -0500 Subject: [PATCH] fix(llama-graph): n_head_v reshape uses Q-head count, not KV-head count (#78) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Post-attention V-padded reshape in build_attn was using hparams.n_head_kv(il), but cur returned from build_attn_mha has shape (n_embd_head * n_head, n_tokens) — n_head is the Q-head count. On GQA models where n_head != n_head_kv (e.g. Qwen2.5-0.5B with head_dim=64 padded → 128, n_head=14, n_head_kv=2), the reshape element count fails the assertion in ggml_reshape_3d and the process aborts. Symptom: GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2) at ggml.c:3656. Reported and diagnosed by @bingh0 in TheTom/llama-cpp-turboquant#78. Verified locally on Qwen2.5-7B (head_dim=128, no padding, regression check passes) and on AMD MI300X with Qwen2.5-0.5B (head_dim=64, was crashing pre-fix). Three sites fixed (lines 2285, 2412, 2532 — same idiom in three build_attn overloads). Closes #78. Likely also closes #108 (speculative decoding hits the same assertion). Co-Authored-By: Claude Opus 4.7 (1M context) --- src/llama-graph.cpp | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 08731f5381c0..5e5b9c9054e1 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -2277,7 +2277,12 @@ ggml_tensor * llm_graph_context::build_attn( const int64_t padded_v_head = v->ne[0]; if (padded_v_head != orig_v_head) { // Reshape to 4D, extract original head_dim, reshape back to 2D - const int64_t n_head_v = hparams.n_head_kv(il); + // Fix #78 (bingh0): cur shape post-MHA is (n_embd_head * n_head, n_tokens), + // not (n_embd_head * n_head_kv, n_tokens). Reshape needs n_head + // (Q-head count) so GQA models with n_head != n_head_kv (e.g. + // Qwen2.5-0.5B head_dim=64 padded → 128) don't fail the element + // count check in ggml_reshape_3d. + const int64_t n_head_v = hparams.n_head(il); const int64_t n_tokens_cur = cur->ne[1]; cur = ggml_reshape_3d(ctx0, cur, padded_v_head, n_head_v, n_tokens_cur); // ggml_view_3d to extract first orig_v_head elements per head @@ -2399,7 +2404,12 @@ ggml_tensor * llm_graph_context::build_attn( const int64_t padded_v_head = v->ne[0]; // padded V head_dim in cache if (padded_v_head != orig_v_head) { // cur is 2D: (padded_v_head * n_head, n_tokens) after build_attn_mha - const int64_t n_head_v = hparams.n_head_kv(il); + // Fix #78 (bingh0): cur shape post-MHA is (n_embd_head * n_head, n_tokens), + // not (n_embd_head * n_head_kv, n_tokens). Reshape needs n_head + // (Q-head count) so GQA models with n_head != n_head_kv (e.g. + // Qwen2.5-0.5B head_dim=64 padded → 128) don't fail the element + // count check in ggml_reshape_3d. + const int64_t n_head_v = hparams.n_head(il); const int64_t n_tokens_cur = cur->ne[1]; cur = ggml_reshape_3d(ctx0, cur, padded_v_head, n_head_v, n_tokens_cur); cur = ggml_view_3d(ctx0, cur, orig_v_head, n_head_v, n_tokens_cur, @@ -2514,7 +2524,12 @@ ggml_tensor * llm_graph_context::build_attn( const int64_t orig_v_head = hparams.n_embd_head_v(il); const int64_t padded_v_head = v->ne[0]; if (padded_v_head != orig_v_head) { - const int64_t n_head_v = hparams.n_head_kv(il); + // Fix #78 (bingh0): cur shape post-MHA is (n_embd_head * n_head, n_tokens), + // not (n_embd_head * n_head_kv, n_tokens). Reshape needs n_head + // (Q-head count) so GQA models with n_head != n_head_kv (e.g. + // Qwen2.5-0.5B head_dim=64 padded → 128) don't fail the element + // count check in ggml_reshape_3d. + const int64_t n_head_v = hparams.n_head(il); const int64_t n_tokens_cur = cur->ne[1]; cur = ggml_reshape_3d(ctx0, cur, padded_v_head, n_head_v, n_tokens_cur); cur = ggml_view_3d(ctx0, cur, orig_v_head, n_head_v, n_tokens_cur,