
fix: build_attn V-padded reshape uses Q-head count, not KV-head count (#78)#116

Merged
TheTom merged 1 commit into feature/turboquant-kv-cache from fix/issue-78-gqa-reshape
May 1, 2026

Conversation

TheTom (Owner) commented on May 1, 2026

Summary

Fixes #78 (and likely #108 — same assertion in the speculative decoding path). One-liner diagnosed by @bingh0:

> hparams.n_head_kv(il) is used to reshape the head dimension, which fails for models where n_head is not equal to n_head_kv. Switching to hparams.n_head(il) fixes the reshape.

The post-MHA reshape in build_attn uses hparams.n_head_kv(il), but cur returned from build_attn_mha has shape (n_embd_head * n_head, n_tokens), where n_head is the Q-head count, not the KV-head count. On GQA models where n_head != n_head_kv, the element count fails the assertion in ggml_reshape_3d and the process aborts with:

```
GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2) failed
in ggml_reshape_3d, called from llm_graph_context::build_attn
```

The bug only triggers when V padding is also active (head_dim < 128 → padded to 128) AND n_head ≠ n_head_kv, which is why it didn't surface on Qwen2.5-7B (head_dim=128, no padding) but does on Qwen2.5-0.5B (head_dim=64 → padded → reshape executes → crash).
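To make the element-count mismatch concrete, here is a minimal standalone ggml sketch using the Qwen2.5-0.5B numbers from the commit message below (n_head = 14, n_head_kv = 2, head_dim 64 padded to 128). n_tokens = 512 is an arbitrary stand-in, and this illustrates the failing assertion rather than the actual build_attn code path:

```cpp
#include "ggml.h"

int main() {
    // small no-alloc context: we only build tensor metadata, no data buffers
    struct ggml_init_params params = { 16 * 1024 * 1024, NULL, true };
    struct ggml_context * ctx = ggml_init(params);

    const int64_t head_dim_padded = 128;  // head_dim 64, V-padded to 128
    const int64_t n_head          = 14;   // Q-head count
    const int64_t n_head_kv      = 2;     // KV-head count (GQA)
    const int64_t n_tokens       = 512;   // arbitrary

    // cur as returned by build_attn_mha: shape (n_embd_head * n_head, n_tokens)
    struct ggml_tensor * cur =
        ggml_new_tensor_2d(ctx, GGML_TYPE_F32, head_dim_padded * n_head, n_tokens);

    // buggy reshape: asks for 128*2*512 elements but cur holds 128*14*512,
    // so GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2) fires and aborts:
    // ggml_reshape_3d(ctx, cur, head_dim_padded, n_head_kv, n_tokens);

    // fixed reshape: element counts agree, no assertion
    ggml_reshape_3d(ctx, cur, head_dim_padded, n_head, n_tokens);

    ggml_free(ctx);
    return 0;
}
```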

The same idiom appears at three sites, one in each of the three build_attn overloads (lines 2285, 2412, 2532); all three are fixed, as sketched below.
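The change itself, as a minimal sketch. Variable names like ctx0, il, and the padded head size n_embd_head_v_p are assumptions standing in for whatever each overload actually uses:

```cpp
// before: reshape with the KV-head count; on GQA models this undercounts
// the elements in cur and trips the GGML_ASSERT in ggml_reshape_3d
cur = ggml_reshape_3d(ctx0, cur, n_embd_head_v_p, hparams.n_head_kv(il), n_tokens);

// after: reshape with the Q-head count, matching the shape build_attn_mha returns
cur = ggml_reshape_3d(ctx0, cur, n_embd_head_v_p, hparams.n_head(il), n_tokens);
```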

Validation

AMD MI300X (gfx942, ROCm 7.2)

| Test | Before | After |
| --- | --- | --- |
| Qwen2.5-0.5B-Instruct Q4_K_M, sym turbo3 (the crash case) | GGML_ASSERT abort, core dumped | ✅ runs to completion |
| Qwen3-8B BF16, q8_0/turbo3 speed regression | 122.61 ± 0.19 t/s | 121.82 ± 0.14 t/s (−0.6%, noise floor) |

Metal (M5 Max), regression check

| Test | Before | After |
| --- | --- | --- |
| Qwen2.5-7B-Instruct Q8_0, sym turbo3 (auto-asym fires) | PPL 6.6594 | PPL 6.6594 (exact) |
| Mistral-Small-24B Q4_K_M, sym turbo3 (no auto-asym) | PPL 6.2792 | PPL 6.2792 (exact) |

The PPLs match exactly: the fix corrects an element count, it does not change any numerics.
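If anyone wants to rerun these checks with stock llama.cpp tooling, something along these lines should work. The fork-specific turboquant/sym-turbo3 cache flags are deliberately omitted because I am not certain of their exact spelling, and the model/dataset paths are placeholders:

```sh
# PPL regression check (expects identical values before/after this fix)
./llama-perplexity -m Qwen2.5-7B-Instruct-Q8_0.gguf -f wikitext-2-raw/wiki.test.raw

# throughput check; llama-bench reports t/s as mean ± stddev, as in the table above
./llama-bench -m Qwen3-8B-BF16.gguf
```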

Caveats

  • Qwen2.5-0.5B post-fix lands at PPL 24538. That's a separate model/quant compatibility issue (turbo3 V on a 64-dim padded head with sym turbo doesn't recover well even with auto-asym), not one introduced by this fix. The point of this PR is that the process no longer aborts before producing a result; on this specific small model the result happens to be poor, but it is a result instead of a SIGABRT.
  • @bingh0 also reported reproducing the fix locally on LFM2, gemma-4-e4b-it, and gemma-4-e2b-it. I haven't independently re-validated those models; flagging this in case anyone wants to confirm.

Closes

#78, and likely #108 (the speculative decoding path hits the same assertion).

Credit: @bingh0 for the diagnosis and the proposed fix.

🤖 Generated with Claude Code

fix: build_attn V-padded reshape uses Q-head count, not KV-head count (#78)

Post-attention V-padded reshape in build_attn was using hparams.n_head_kv(il),
but cur returned from build_attn_mha has shape (n_embd_head * n_head, n_tokens) —
n_head is the Q-head count. On GQA models where n_head != n_head_kv (e.g.
Qwen2.5-0.5B with head_dim=64 padded → 128, n_head=14, n_head_kv=2), the
reshape element count fails the assertion in ggml_reshape_3d and the process
aborts.

Symptom: GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2) at ggml.c:3656.

Reported and diagnosed by @bingh0 in #78. Verified
locally on Qwen2.5-7B (head_dim=128, no padding, regression check passes) and
on AMD MI300X with Qwen2.5-0.5B (head_dim=64, was crashing pre-fix).

Three sites fixed (lines 2285, 2412, 2532 — same idiom in three build_attn
overloads).

Closes #78. Likely also closes #108 (speculative decoding hits the same
assertion).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
TheTom force-pushed the fix/issue-78-gqa-reshape branch from dbf3d40 to b6f8e7f on May 1, 2026 at 13:33
TheTom merged commit cde3e1a into feature/turboquant-kv-cache on May 1, 2026
22 of 50 checks passed
TheTom deleted the fix/issue-78-gqa-reshape branch on May 1, 2026 at 13:36