src/llama-graph.cpp (21 changes: 18 additions & 3 deletions)
```diff
@@ -2277,7 +2277,12 @@ ggml_tensor * llm_graph_context::build_attn(
     const int64_t padded_v_head = v->ne[0];
     if (padded_v_head != orig_v_head) {
         // Reshape to 4D, extract original head_dim, reshape back to 2D
-        const int64_t n_head_v = hparams.n_head_kv(il);
+        // Fix #78 (bingh0): cur shape post-MHA is (n_embd_head * n_head, n_tokens),
+        // not (n_embd_head * n_head_kv, n_tokens). Reshape needs n_head
+        // (Q-head count) so GQA models with n_head != n_head_kv (e.g.
+        // Qwen2.5-0.5B head_dim=64 padded → 128) don't fail the element
+        // count check in ggml_reshape_3d.
+        const int64_t n_head_v = hparams.n_head(il);
         const int64_t n_tokens_cur = cur->ne[1];
         cur = ggml_reshape_3d(ctx0, cur, padded_v_head, n_head_v, n_tokens_cur);
         // ggml_view_3d to extract first orig_v_head elements per head
```
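Reviewer note: the new comment's claim is easy to verify with plain arithmetic. `ggml_reshape_3d` asserts that the target shape has exactly `ggml_nelements(cur)` elements, so the old `n_head_kv` value can only work when `n_head == n_head_kv`. A minimal standalone sketch, assuming Qwen2.5-0.5B-like parameters (`n_head = 14`, `n_head_kv = 2`, `head_dim` 64 padded to 128; illustrative numbers, not read from a model):

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
    // Illustrative GQA parameters (Qwen2.5-0.5B-like; assumed for the example).
    const int64_t n_head        = 14;  // query heads
    const int64_t n_head_kv     = 2;   // key/value heads
    const int64_t padded_v_head = 128; // V head_dim after padding 64 -> 128
    const int64_t n_tokens      = 7;

    // After build_attn_mha, cur is 2D: (padded_v_head * n_head, n_tokens).
    const int64_t cur_elems = padded_v_head * n_head * n_tokens;

    // ggml_reshape_3d requires ne0*ne1*ne2 == ggml_nelements(a).
    assert(padded_v_head * n_head_kv * n_tokens != cur_elems); // old code: assert fires in ggml
    assert(padded_v_head * n_head    * n_tokens == cur_elems); // fixed code: reshape succeeds

    printf("cur: %lld elems, n_head_kv target: %lld, n_head target: %lld\n",
           (long long) cur_elems,
           (long long) (padded_v_head * n_head_kv * n_tokens),
           (long long) (padded_v_head * n_head * n_tokens));
    return 0;
}
```

For MHA models the two expressions coincide (`n_head == n_head_kv`), which is why the bug only surfaced on GQA models.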
```diff
@@ -2399,7 +2404,12 @@ ggml_tensor * llm_graph_context::build_attn(
     const int64_t padded_v_head = v->ne[0]; // padded V head_dim in cache
     if (padded_v_head != orig_v_head) {
         // cur is 2D: (padded_v_head * n_head, n_tokens) after build_attn_mha
-        const int64_t n_head_v = hparams.n_head_kv(il);
+        // Fix #78 (bingh0): cur shape post-MHA is (n_embd_head * n_head, n_tokens),
+        // not (n_embd_head * n_head_kv, n_tokens). Reshape needs n_head
+        // (Q-head count) so GQA models with n_head != n_head_kv (e.g.
+        // Qwen2.5-0.5B head_dim=64 padded → 128) don't fail the element
+        // count check in ggml_reshape_3d.
+        const int64_t n_head_v = hparams.n_head(il);
         const int64_t n_tokens_cur = cur->ne[1];
         cur = ggml_reshape_3d(ctx0, cur, padded_v_head, n_head_v, n_tokens_cur);
         cur = ggml_view_3d(ctx0, cur, orig_v_head, n_head_v, n_tokens_cur,
```
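The hunk cuts off mid-call, showing only the reshape and the start of the `ggml_view_3d`. A self-contained sketch of the full trim pattern, using only stock ggml ops (`ggml_reshape_3d`, `ggml_view_3d`, `ggml_cont`, `ggml_reshape_2d`) and byte strides from `ggml_element_size`; this is a reconstruction of the idea, not the PR's exact code:

```cpp
#include "ggml.h"

// Sketch: trim a padded V head_dim back to its original size.
// Assumes cur is 2D (padded_v_head * n_head, n_tokens) in ggml's layout.
static ggml_tensor * trim_padded_v_head(
        ggml_context * ctx0, ggml_tensor * cur,
        int64_t orig_v_head, int64_t padded_v_head, int64_t n_head) {
    const int64_t n_tokens = cur->ne[1];

    // (padded*n_head, n_tokens) -> (padded, n_head, n_tokens)
    cur = ggml_reshape_3d(ctx0, cur, padded_v_head, n_head, n_tokens);

    // Keep the first orig_v_head values of each head; the view inherits the
    // padded source's strides, so heads are still padded_v_head apart.
    const size_t es = ggml_element_size(cur);
    cur = ggml_view_3d(ctx0, cur, orig_v_head, n_head, n_tokens,
                       padded_v_head * es,          // nb1: byte stride between heads
                       padded_v_head * n_head * es, // nb2: byte stride between tokens
                       0);                          // offset

    // Make the strided view contiguous, then flatten back to 2D.
    cur = ggml_cont(ctx0, cur);
    return ggml_reshape_2d(ctx0, cur, orig_v_head * n_head, n_tokens);
}
```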
```diff
@@ -2514,7 +2524,12 @@ ggml_tensor * llm_graph_context::build_attn(
     const int64_t orig_v_head = hparams.n_embd_head_v(il);
     const int64_t padded_v_head = v->ne[0];
     if (padded_v_head != orig_v_head) {
-        const int64_t n_head_v = hparams.n_head_kv(il);
+        // Fix #78 (bingh0): cur shape post-MHA is (n_embd_head * n_head, n_tokens),
+        // not (n_embd_head * n_head_kv, n_tokens). Reshape needs n_head
+        // (Q-head count) so GQA models with n_head != n_head_kv (e.g.
+        // Qwen2.5-0.5B head_dim=64 padded → 128) don't fail the element
+        // count check in ggml_reshape_3d.
+        const int64_t n_head_v = hparams.n_head(il);
         const int64_t n_tokens_cur = cur->ne[1];
         cur = ggml_reshape_3d(ctx0, cur, padded_v_head, n_head_v, n_tokens_cur);
         cur = ggml_view_3d(ctx0, cur, orig_v_head, n_head_v, n_tokens_cur,
```
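All three call sites get the identical one-line fix. One way to catch this class of mismatch before ggml's own element-count assert is to derive the head count from `cur` itself; a hypothetical guard, not part of this PR:

```cpp
// Hypothetical sanity check before the reshape (not in the PR): the flattened
// row length of cur must be an exact multiple of the padded head size, and
// the quotient must equal the Q-head count used for the reshape.
GGML_ASSERT(cur->ne[0] % padded_v_head == 0);
GGML_ASSERT(cur->ne[0] / padded_v_head == hparams.n_head(il));
```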