Skip to content

Commit 0dbf74d

Browse files
authored
Merge pull request #26 from boxwrench/feature/turboquant-kv-cache
fix(gemma4-mtp): resolve PARALLEL=2 multi-slot crash in Gemma 4 MTP speculative decoding
2 parents c419fd5 + 6564568 commit 0dbf74d

4 files changed

Lines changed: 12 additions & 11 deletions

File tree

src/llama-context.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1279,14 +1279,6 @@ bool llama_context::ensure_sched_mtp() {
12791279
return false;
12801280
}
12811281

1282-
llama_memory_context_ptr mctx = memory->init_full();
1283-
if (!mctx) {
1284-
LLAMA_LOG_ERROR("%s: failed to init memory context for MTP reserve\n", __func__);
1285-
sched_mtp.reset();
1286-
gf_res_prev_mtp.reset();
1287-
return false;
1288-
}
1289-
12901282
const uint32_t n_bb = model.mtp_assistant->hparams.n_embd_backbone;
12911283
auto data = std::make_shared<llama_ubatch::data_t>();
12921284
data->token.resize(1);
@@ -1321,6 +1313,14 @@ bool llama_context::ensure_sched_mtp() {
13211313
ub.output = data->output.data();
13221314
ub.data = data;
13231315

1316+
llama_memory_context_ptr mctx = kv_iswa->init_mtp(0, ub);
1317+
if (!mctx) {
1318+
LLAMA_LOG_ERROR("%s: failed to init memory context for MTP reserve\n", __func__);
1319+
sched_mtp.reset();
1320+
gf_res_prev_mtp.reset();
1321+
return false;
1322+
}
1323+
13241324
const uint32_t save_n_outputs = n_outputs;
13251325
n_outputs = 1;
13261326

src/llama-graph.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -947,6 +947,7 @@ void llm_graph_result::set_params(const llm_graph_params & params) {
947947

948948
llm_graph_context::llm_graph_context(const llm_graph_params & params) :
949949
arch (params.arch),
950+
gtype (params.gtype),
950951
hparams (params.hparams),
951952
cparams (params.cparams),
952953
ubatch (params.ubatch),
@@ -1899,7 +1900,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
18991900
const bool v_trans = v->nb[1] > v->nb[2];
19001901

19011902
// split the batch into streams if needed
1902-
const auto n_stream = k->ne[3];
1903+
const auto n_stream = (gtype == LLM_GRAPH_TYPE_MTP) ? 1 : k->ne[3];
19031904

19041905
q = ggml_view_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream, q->nb[1], q->nb[2], q->nb[3]/n_stream, 0);
19051906

@@ -1930,7 +1931,6 @@ ggml_tensor * llm_graph_context::build_attn_mha(
19301931
if (v->type == GGML_TYPE_F32) {
19311932
v = ggml_cast(ctx0, v, GGML_TYPE_F16);
19321933
}
1933-
19341934
cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
19351935
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
19361936
cb(cur, LLAMA_TENSOR_NAME_FATTN, il);

src/llama-graph.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -742,6 +742,7 @@ using llm_graph_get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_
742742

743743
struct llm_graph_context {
744744
const llm_arch arch;
745+
const llm_graph_type gtype;
745746

746747
const llama_hparams & hparams;
747748
const llama_cparams & cparams;

src/models/gemma4-assistant.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ static void gemma4_mtp_build_one_step(
106106
ggml_tensor * Qcur = gctx.build_lora_mm(mtp.layers[il].wq, cur);
107107
cb(Qcur, "Qcur", il);
108108

109-
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
109+
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, 1);
110110

111111
Qcur = gctx.build_norm(Qcur, mtp.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
112112
cb(Qcur, "Qcur_normed", il);

0 commit comments

Comments
 (0)