Skip to content

Commit dbf3d40

Browse files
TheTomclaude
andcommitted
fix(llama-graph): n_head_v reshape uses Q-head count, not KV-head count (#78)
Post-attention V-padded reshape in build_attn was using hparams.n_head_kv(il), but cur returned from build_attn_mha has shape (n_embd_head * n_head, n_tokens) — n_head is the Q-head count. On GQA models where n_head != n_head_kv (e.g. Qwen2.5-0.5B with head_dim=64 padded → 128, n_head=14, n_head_kv=2), the reshape element count fails the assertion in ggml_reshape_3d and the process aborts. Symptom: GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2) at ggml.c:3656. Reported and diagnosed by @bingh0 in #78. Verified locally on Qwen2.5-7B (head_dim=128, no padding, regression check passes) and on AMD MI300X with Qwen2.5-0.5B (head_dim=64, was crashing pre-fix). Three sites fixed (lines 2285, 2412, 2532 — same idiom in three build_attn overloads). Closes #78. Likely also closes #108 (speculative decoding hits the same assertion). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 157f27f commit dbf3d40

1 file changed

Lines changed: 18 additions & 3 deletions

File tree

src/llama-graph.cpp

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2277,7 +2277,12 @@ ggml_tensor * llm_graph_context::build_attn(
22772277
const int64_t padded_v_head = v->ne[0];
22782278
if (padded_v_head != orig_v_head) {
22792279
// Reshape to 4D, extract original head_dim, reshape back to 2D
2280-
const int64_t n_head_v = hparams.n_head_kv(il);
2280+
// Fix #78 (bingh0): cur shape post-MHA is (n_embd_head * n_head, n_tokens),
2281+
// not (n_embd_head * n_head_kv, n_tokens). Reshape needs n_head
2282+
// (Q-head count) so GQA models with n_head != n_head_kv (e.g.
2283+
// Qwen2.5-0.5B head_dim=64 padded → 128) don't fail the element
2284+
// count check in ggml_reshape_3d.
2285+
const int64_t n_head_v = hparams.n_head(il);
22812286
const int64_t n_tokens_cur = cur->ne[1];
22822287
cur = ggml_reshape_3d(ctx0, cur, padded_v_head, n_head_v, n_tokens_cur);
22832288
// ggml_view_3d to extract first orig_v_head elements per head
@@ -2399,7 +2404,12 @@ ggml_tensor * llm_graph_context::build_attn(
23992404
const int64_t padded_v_head = v->ne[0]; // padded V head_dim in cache
24002405
if (padded_v_head != orig_v_head) {
24012406
// cur is 2D: (padded_v_head * n_head, n_tokens) after build_attn_mha
2402-
const int64_t n_head_v = hparams.n_head_kv(il);
2407+
// Fix #78 (bingh0): cur shape post-MHA is (n_embd_head * n_head, n_tokens),
2408+
// not (n_embd_head * n_head_kv, n_tokens). Reshape needs n_head
2409+
// (Q-head count) so GQA models with n_head != n_head_kv (e.g.
2410+
// Qwen2.5-0.5B head_dim=64 padded → 128) don't fail the element
2411+
// count check in ggml_reshape_3d.
2412+
const int64_t n_head_v = hparams.n_head(il);
24032413
const int64_t n_tokens_cur = cur->ne[1];
24042414
cur = ggml_reshape_3d(ctx0, cur, padded_v_head, n_head_v, n_tokens_cur);
24052415
cur = ggml_view_3d(ctx0, cur, orig_v_head, n_head_v, n_tokens_cur,
@@ -2514,7 +2524,12 @@ ggml_tensor * llm_graph_context::build_attn(
25142524
const int64_t orig_v_head = hparams.n_embd_head_v(il);
25152525
const int64_t padded_v_head = v->ne[0];
25162526
if (padded_v_head != orig_v_head) {
2517-
const int64_t n_head_v = hparams.n_head_kv(il);
2527+
// Fix #78 (bingh0): cur shape post-MHA is (n_embd_head * n_head, n_tokens),
2528+
// not (n_embd_head * n_head_kv, n_tokens). Reshape needs n_head
2529+
// (Q-head count) so GQA models with n_head != n_head_kv (e.g.
2530+
// Qwen2.5-0.5B head_dim=64 padded → 128) don't fail the element
2531+
// count check in ggml_reshape_3d.
2532+
const int64_t n_head_v = hparams.n_head(il);
25182533
const int64_t n_tokens_cur = cur->ne[1];
25192534
cur = ggml_reshape_3d(ctx0, cur, padded_v_head, n_head_v, n_tokens_cur);
25202535
cur = ggml_view_3d(ctx0, cur, orig_v_head, n_head_v, n_tokens_cur,

0 commit comments

Comments
 (0)