Commit cc59fcb

Fix tensor reshaping logic in build_attn to correctly handle head dimensions for MHA and GQA scenarios, addressing potential assertion failures in sparse-attention layouts.
1 parent 0dbf74d commit cc59fcb

1 file changed

Lines changed: 8 additions & 2 deletions

src/llama-graph.cpp

@@ -2232,9 +2232,15 @@ ggml_tensor * llm_graph_context::build_attn(
     // cur is 2D: (n_embd_head * n_head, n_tokens) after build_attn_mha
     const int64_t padded_v_head = v->ne[0];
     if (padded_v_head != orig_v_head) {
-        // Reshape to 4D, extract original head_dim, reshape back to 2D
-        const int64_t n_head_v = hparams.n_head_kv(il);
+        // Reshape to 3D, extract original head_dim, reshape back to 2D.
+        // The MHA output carries one slice per QUERY head (n_head), not per
+        // KV head. Under GQA (n_head != n_head_kv) using n_head_kv(il) here
+        // mis-sizes the reshape and trips GGML_ASSERT — e.g. the lfm2moe
+        // sparse-attention layout (ATO-137). Derive the head count from the
+        // tensor so it is correct for both MHA and GQA.
         const int64_t n_tokens_cur = cur->ne[1];
+        GGML_ASSERT(cur->ne[0] % padded_v_head == 0);
+        const int64_t n_head_v = cur->ne[0] / padded_v_head;
         cur = ggml_reshape_3d(ctx0, cur, padded_v_head, n_head_v, n_tokens_cur);
         // ggml_view_3d to extract first orig_v_head elements per head
         cur = ggml_view_3d(ctx0, cur, orig_v_head, n_head_v, n_tokens_cur,
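
Note on the arithmetic behind this change: the sketch below is a standalone illustration, not code from llama.cpp. It assumes that a 3D reshape is only accepted when the total element count is preserved, which is the condition the commit comment says GGML_ASSERT trips on. The helper names (derive_n_head, reshape_3d_is_valid, gqa_example_holds) and the example sizes (32 query heads, 8 KV heads, padded head size 128, 4 tokens) are hypothetical and chosen only to make the mismatch visible.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>

// Hypothetical helper (not part of llama.cpp): recover the number of head
// slices from the flattened first dimension of the attention output,
// mirroring the divisibility check and division added in this commit.
inline int64_t derive_n_head(int64_t ne0, int64_t padded_v_head) {
    if (padded_v_head <= 0 || ne0 % padded_v_head != 0) {
        throw std::invalid_argument("ne0 is not a multiple of padded_v_head");
    }
    return ne0 / padded_v_head;
}

// Hypothetical helper: a (ne0, ne1) tensor can be reshaped to (d0, d1, d2)
// only if the element count is unchanged. This models the assumed
// precondition of a 3D reshape; it does not call ggml.
inline bool reshape_3d_is_valid(int64_t ne0, int64_t ne1,
                                int64_t d0, int64_t d1, int64_t d2) {
    return ne0 * ne1 == d0 * d1 * d2;
}

// Illustrative GQA configuration; the numbers are made up for the example.
inline bool gqa_example_holds() {
    const int64_t n_head        = 32;   // query heads
    const int64_t n_head_kv     = 8;    // KV heads (GQA: fewer than n_head)
    const int64_t padded_v_head = 128;  // padded per-head value size
    const int64_t n_tokens      = 4;

    // The attention output is laid out per query head.
    const int64_t ne0 = n_head * padded_v_head;

    const int64_t derived = derive_n_head(ne0, padded_v_head);

    return derived == n_head
        // New behaviour: the derived head count preserves the element count.
        && reshape_3d_is_valid(ne0, n_tokens, padded_v_head, derived, n_tokens)
        // Old behaviour: n_head_kv under-sizes the reshape, so it is rejected.
        && !reshape_3d_is_valid(ne0, n_tokens, padded_v_head, n_head_kv, n_tokens);
}
```

With plain MHA (n_head == n_head_kv) both choices give the same head count, which would explain why the old code only failed on GQA layouts.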
