Skip to content

Commit a66d505

Browse files
graph: guard iswa kq_mask on its own buffer (ggml-org#24294)
A SWA-only draft head (e.g. StepFun MTP) leaves the base sub-cache empty, so the base kq_mask buffer stays null and an assertion fails at load. Guard each mask (base and SWA) on its own buffer in both set_input and can_reuse. Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
1 parent 1705d43 commit a66d505

1 file changed

Lines changed: 13 additions & 4 deletions

File tree

src/llama-graph.cpp

Lines changed: 13 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -567,15 +567,20 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
567567
mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
568568
}
569569

570-
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
570+
// the kq mask guards on its own buffer: shared cells leave idxs unbacked while the mask stays live
571+
if (self_kq_mask && self_kq_mask->buffer) {
572+
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
573+
}
571574

572575
// swa tensors may not be allocated if there are no SWA attention layers
573576
if (self_k_idxs_swa && self_k_idxs_swa->buffer) {
574577
mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
575578
mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
576579
}
577580

578-
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
581+
if (self_kq_mask_swa && self_kq_mask_swa->buffer) {
582+
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
583+
}
579584

580585
if (self_k_rot) {
581586
mctx->get_base()->set_input_k_rot(self_k_rot);
@@ -607,15 +612,19 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
607612
//res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
608613
}
609614

610-
res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams);
615+
if (self_kq_mask && self_kq_mask->buffer) {
616+
res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams);
617+
}
611618

612619
// swa tensors may not be allocated if there are no SWA attention layers
613620
if (self_k_idxs_swa && self_k_idxs_swa->buffer) {
614621
res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
615622
//res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
616623
}
617624

618-
res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams);
625+
if (self_kq_mask_swa && self_kq_mask_swa->buffer) {
626+
res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams);
627+
}
619628

620629
return res;
621630
}

0 commit comments

Comments
 (0)