Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
ad61684
wip: port MTP architecture
SamuelOliveirads Dec 28, 2025
b75f70e
Refactors `server_slot` to support generic speculative decoding (MTP …
SamuelOliveirads Dec 29, 2025
f9c4f6c
core: enable hybrid outputs (logits + embeddings) for MTP support
SamuelOliveirads Dec 29, 2025
b61daeb
fix(mtp): correct KV-cache slot finding for updates
SamuelOliveirads Jan 1, 2026
c03ae51
fix(mtp): persist hidden states to prevent context corruption during …
SamuelOliveirads Jan 2, 2026
ab6f4bb
refactor(mtp): clean unused code
SamuelOliveirads Feb 5, 2026
ec2d1a0
fix(mtp): update server to new functions name
SamuelOliveirads Feb 7, 2026
9317463
fix(mtp): fix graph and save hidden state
SamuelOliveirads Feb 8, 2026
d3465f1
mtp: refactor integration, context params and kv cache search
SamuelOliveirads Feb 9, 2026
2539f4f
mtp: fix hidden state extraction and speculative acceptance flow
SamuelOliveirads Feb 9, 2026
07e4936
server: fix MTP warmup for long prompts and reset token buffer
SamuelOliveirads Feb 12, 2026
d088faa
llama: refactor MTP operation state to context parameters
SamuelOliveirads Feb 13, 2026
97ec50e
server: fix n_past calculation in MTP acceptance
SamuelOliveirads Feb 13, 2026
573170e
llama: fix mtp enable flags
SamuelOliveirads Feb 13, 2026
5260bf2
Merge branch 'main' into feat-glm-mtp
SamuelOliveirads Feb 20, 2026
b4a2c88
speculative: refactor MTP to use common_speculative interface
SamuelOliveirads Feb 20, 2026
b8f27f3
context: remove unused signatures
SamuelOliveirads Feb 20, 2026
dd684fb
clip: fix deprecated enum-enum conversion warning
SamuelOliveirads Feb 20, 2026
0bcee4e
common: fix format string crash in help message
SamuelOliveirads Feb 20, 2026
1d5b287
context: fix mtp activation logic
SamuelOliveirads Feb 21, 2026
1da0758
llama: always use the extracted embedding
SamuelOliveirads Feb 26, 2026
4d774d0
llama: get all embeddings to kv cache
SamuelOliveirads Feb 27, 2026
dc5ee27
Merge branch 'main' into fix-mtp-embedding
SamuelOliveirads Mar 11, 2026
1ab6327
llama: revert logit to not run mtp for not supported arch
SamuelOliveirads Mar 11, 2026
5eec0d3
llama: allocate all the n_outputs for MTP
SamuelOliveirads Mar 19, 2026
301f3db
wip
SamuelOliveirads Mar 20, 2026
6236fb3
server-context: get only the last embedding for hidden state
SamuelOliveirads Mar 20, 2026
f548ac1
ggml-backend: fix array out of bounds in debug build
SamuelOliveirads Mar 20, 2026
d53dfc7
server-context: run mtp kv update for each prompt batch
SamuelOliveirads Mar 20, 2026
94c8184
revert segmentation fault fixes
SamuelOliveirads Mar 21, 2026
8a47d51
Merge branch 'main' into fix-mtp-embedding
SamuelOliveirads Mar 21, 2026
c81d493
glm-mtp(feat): optimize graph embedding and recursive drafting
SamuelOliveirads Mar 21, 2026
4e4fd95
glm5-mtp(feat): add glm 5 mtp logic
SamuelOliveirads Mar 25, 2026
f978268
Merge branch 'main' into feat/glm5-mtp
SamuelOliveirads Mar 25, 2026
1c8af93
Merge branch 'main' into feat/glm5-mtp
SamuelOliveirads Mar 26, 2026
deb13ea
wip
SamuelOliveirads May 2, 2026
767ebca
glm-mtp: standardize the MTP graph
SamuelOliveirads May 3, 2026
0ead56d
glm 5 mtp: apply post-layer cvec
SamuelOliveirads May 4, 2026
284d754
Merge branch 'main' into feat/glm5-mtp
SamuelOliveirads May 4, 2026
b3a3be0
glm 5 mtp: mark head as mandatory
SamuelOliveirads May 4, 2026
9edded3
Merge remote-tracking branch 'origin/main' into feat/glm5-mtp
SamuelOliveirads May 17, 2026
ee51a7a
get normed embeddings for glm 5
SamuelOliveirads May 18, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
372 changes: 371 additions & 1 deletion src/graphs/build_deepseek2.cpp

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/graphs/build_gemma4.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,7 @@ ggml_cgraph * llm_build_context::build_gemma4_mtp() {
// not required for correct inference — the full-vocab matmul against the tied output
// weight still yields valid per-token logits.
{
logits = build_output(lctx, ctx0, cur, model.output, model.output_norm, cb);
logits = build_output(lctx, ctx0, cur, model.output, model.output_norm, cb, false);
cb(logits, "result_output", -1);
}
ggml_build_forward_expand(gf, logits);
Expand Down
11 changes: 9 additions & 2 deletions src/llama-build-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2048,25 +2048,30 @@ ggml_tensor * llm_build_context::build_output(llama_context & lctx, ggml_context
}

ggml_tensor * llm_build_context::build_output(llama_context & lctx, ggml_context * ctx, ggml_tensor * cur,
ggml_tensor * output, ggml_tensor * output_norm, const llm_build_cb & cb) {
ggml_tensor * output, ggml_tensor * output_norm, const llm_build_cb & cb, bool add_normed_name) {
// lm_head
if (output->extra) {
auto split_output = (ggml_split_tensor_t *)output->extra;
auto split_output_norm = output_norm && output_norm->extra ? (ggml_split_tensor_t *)output_norm->extra : nullptr;
std::vector<ggml_tensor *> o;
o.reserve(split_output->n_device);
ggml_tensor * last_norm = nullptr;
for (int id = 0; id < split_output->n_device; ++id) {
auto split = split_output->splits[id];
if (!split) continue;
if (output_norm) {
auto the_norm = split_output_norm ? split_output_norm->splits[id] : output_norm;
auto cur_normed = llm_build_context::llm_build_norm(ctx, cur, lctx.model.hparams, the_norm, NULL, LLM_NORM_RMS, cb, -1);
last_norm = cur_normed;
cb(cur_normed, "result_norm", 1000*(id+1));
o.push_back(llm_build_context::llm_build_lora_mm(lctx, ctx, split, cur_normed));
} else {
o.push_back(llm_build_context::llm_build_lora_mm(lctx, ctx, split, cur));
}
cb(o.back(), "output", id);
if (add_normed_name && last_norm) {
cb(last_norm, "result_norm", -1);
}
}
GGML_ASSERT(!o.empty());
if (o.size() == 1) {
Expand All @@ -2090,7 +2095,9 @@ ggml_tensor * llm_build_context::build_output(llama_context & lctx, ggml_context
}
if (output_norm) {
cur = llm_build_context::llm_build_norm(ctx, cur, lctx.model.hparams, output_norm, NULL, LLM_NORM_RMS, cb, -1);
cb(cur, "result_norm", -1);
if (add_normed_name) {
cb(cur, "result_norm", -1);
}
}
cur = llm_build_context::llm_build_lora_mm(lctx, ctx, output, cur);
}
Expand Down
10 changes: 9 additions & 1 deletion src/llama-build-context.h
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ llm_expert_gating_func_type gating_op,
static ggml_tensor * build_output(llama_context & lctx, ggml_context * ctx, ggml_tensor * cur, ggml_tensor * output, const llm_build_cb & cb);

static ggml_tensor * build_output(llama_context & lctx, ggml_context * ctx, ggml_tensor * cur,
ggml_tensor * output, ggml_tensor * output_norm, const llm_build_cb & cb);
ggml_tensor * output, ggml_tensor * output_norm, const llm_build_cb & cb, bool add_normed_name = true);

static ggml_tensor * do_split_norm(ggml_context * ctx, ggml_tensor * cur, ggml_tensor * the_norm, const llama_hparams & hparams,
const llm_build_cb & cb, int id, int il_cb, bool is_norm);
Expand All @@ -466,6 +466,14 @@ llm_expert_gating_func_type gating_op,
struct ggml_tensor * rope_cache
);

struct ggml_tensor * build_deepseek2_mtp(
const struct llama_layer & mtp_layer,
struct ggml_tensor * prev_embeddings,
struct ggml_cgraph * gf,
struct ggml_tensor * inp_pos,
struct ggml_tensor * rope_cache
);

struct ggml_tensor * build_qwen35_mtp(
const struct llama_layer & mtp_layer,
struct ggml_tensor * prev_embeddings,
Expand Down
8 changes: 6 additions & 2 deletions src/llama-hparams.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1379,8 +1379,12 @@ void llm_load_hparams(
// NextN/MTP parameters
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);

// TODO: when MTP is implemented, this should probably be updated if needed
hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
if (model.mtp) {
hparams.n_layer_kv_from_start = hparams.n_layer;
}
else {
hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
}

switch (hparams.n_layer) {
case 79: model.type = MODEL_744B_A40B; break;
Expand Down
23 changes: 13 additions & 10 deletions src/llama-load-tensors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2600,8 +2600,11 @@ bool create_tensors_helper::create_glm_dsa_tensors(const LLM_TN & tn) {
static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers;

int flags = 0;
if (is_mtp_layer) {
flags |= llama_model_loader::TENSOR_SKIP | llama_model_loader::TENSOR_NOT_REQUIRED;
// Skip loading MTP layers if the feature is disabled
if (!model.mtp) {
if (is_mtp_layer) {
flags |= llama_model_loader::TENSOR_SKIP | llama_model_loader::TENSOR_NOT_REQUIRED;
}
}
ggml_context * ctx_layer = ctx_for_layer(i);
ggml_context * ctx_split = ctx_for_layer_split(i);
Expand Down Expand Up @@ -2679,14 +2682,14 @@ bool create_tensors_helper::create_glm_dsa_tensors(const LLM_TN & tn) {
}

if (is_mtp_layer) {
layer.nextn.eh_proj = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags);
layer.nextn.enorm = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags);
layer.nextn.hnorm = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags);

// Optional tensors
layer.nextn.embed_tokens = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | llama_model_loader::TENSOR_NOT_REQUIRED);
layer.nextn.shared_head_head = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | llama_model_loader::TENSOR_NOT_REQUIRED);
layer.nextn.shared_head_norm = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | llama_model_loader::TENSOR_NOT_REQUIRED);
layer.nextn.eh_proj = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags);
layer.nextn.enorm = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags);
layer.nextn.hnorm = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags);

// Optional tensors
layer.nextn.embed_tokens = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | llama_model_loader::TENSOR_NOT_REQUIRED);
layer.nextn.shared_head_head = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | llama_model_loader::TENSOR_NOT_REQUIRED);
layer.nextn.shared_head_norm = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags);
}
}
return use_mmap_buffer;
Expand Down
9 changes: 6 additions & 3 deletions src/llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -891,6 +891,7 @@ static bool llama_kv_cache_init(
}

int n_mla = 0;
int n_kv_active_layers = 0;
const int64_t n_mtp_first_layer = n_layer - hparams.nextn_predict_layers;
for (int i = 0; i < (int) n_layer; i++) {
// For MTP-only context, skip KV allocation for non-MTP layers
Expand All @@ -901,6 +902,7 @@ static bool llama_kv_cache_init(
}
continue;
}
n_kv_active_layers++;
const bool qnext_recurrent = llama_is_recurrent_layer(hparams, i);
const uint32_t n_embd_v_row = llama_kv_v_row_embd(model, hparams, i);
const uint32_t n_head_kv = hparams.n_head_kv(i);
Expand Down Expand Up @@ -1061,8 +1063,8 @@ static bool llama_kv_cache_init(
cache.v_l.push_back(v);
}
}
if (is_mla_attn && cparams.mla_attn && n_mla < n_layer && n_mla > 0) {
LLAMA_LOG_ERROR("%s: unexpected situation with %d out of %d layers having MLA enabled\n", __func__, n_mla, int(n_layer));
if (is_mla_attn && cparams.mla_attn && n_mla < n_kv_active_layers && n_mla > 0) {
LLAMA_LOG_ERROR("%s: unexpected situation with %d out of %d active KV layers having MLA enabled\n", __func__, n_mla, n_kv_active_layers);
LLAMA_LOG_ERROR("%s: bailing out\n", __func__);
GGML_ABORT("fatal error");
}
Expand Down Expand Up @@ -6258,7 +6260,8 @@ struct llama_context * llama_init_from_model(

if (model->arch != LLM_ARCH_GLM4_MOE && model->arch != LLM_ARCH_QWEN35 &&
model->arch != LLM_ARCH_QWEN35MOE && model->arch != LLM_ARCH_GEMMA4 &&
model->arch != LLM_ARCH_GEMMA4_MTP && cparams.mtp != 0) {
model->arch != LLM_ARCH_GEMMA4_MTP && model->arch != LLM_ARCH_GLM_DSA &&
cparams.mtp != 0) {
cparams.mtp = 0;
}

Expand Down