From a55493bbda346d0d65ae720182e5074902537c49 Mon Sep 17 00:00:00 2001
From: Aman Gupta
Date: Mon, 11 May 2026 11:18:17 +0800
Subject: [PATCH 01/18] spec: support MTP

---
 common/common.h                 |   1 +
 common/speculative.cpp          | 339 +++++++++++++++++++++++++++++++-
 convert_hf_to_gguf.py           |  61 +++++-
 gguf-py/gguf/constants.py       |  18 +-
 include/llama.h                 |   3 +
 src/llama-arch.cpp              |  19 +-
 src/llama-arch.h                |   2 +
 src/llama-context.cpp           | 117 +++++++++--
 src/llama-context.h             |   9 +
 src/llama-cparams.h             |   1 +
 src/llama-ext.h                 |  16 ++
 src/llama-graph.h               |   2 +
 src/llama-hparams.cpp           |   6 +
 src/llama-hparams.h             |   2 +
 src/llama-model-loader.cpp      |  13 +-
 src/llama-model-loader.h        |   2 +-
 src/llama-model.cpp             |  19 +-
 src/models/models.h             |  26 +++
 src/models/qwen35.cpp           |  32 ++-
 src/models/qwen35_mtp.cpp       | 207 +++++++++++++++++++
 src/models/qwen35moe.cpp        |  32 ++-
 src/models/qwen35moe_mtp.cpp    | 252 ++++++++++++++++++++++++
 tests/test-llama-archs.cpp      |   3 +
 tools/server/server-context.cpp |  73 ++++++-
 24 files changed, 1210 insertions(+), 45 deletions(-)
 create mode 100644 src/models/qwen35_mtp.cpp
 create mode 100644 src/models/qwen35moe_mtp.cpp

diff --git a/common/common.h b/common/common.h
index a3cd1743957..b39095e7233 100644
--- a/common/common.h
+++ b/common/common.h
@@ -159,6 +159,7 @@ enum common_speculative_type {
     COMMON_SPECULATIVE_TYPE_NONE,          // no speculative decoding
     COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE,  // standalone draft model speculative decoding
     COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3,  // Eagle3 speculative decoding
+    COMMON_SPECULATIVE_TYPE_MTP,           // multi-token prediction head loaded from the target GGUF
     COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE,  // simple self-speculative decoding based on n-grams
     COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K,   // self-speculative decoding with n-gram keys only
     COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values
diff --git a/common/speculative.cpp b/common/speculative.cpp
index 0eebcb3dcfe..c60082ae68a 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -3,6 +3,7 @@
 #include "common.h"
 #include "ggml.h"
 #include "llama.h"
+#include "../src/llama-ext.h" // staging API: llama_set_embeddings_pre_norm / llama_get_embeddings_pre_norm_ith (used by MTP)
 #include "log.h"
 #include "ngram-cache.h"
 #include "ngram-map.h"
@@ -23,6 +24,7 @@ const std::map<std::string, common_speculative_type> common_speculative_type_from_str = {
     {"none",          COMMON_SPECULATIVE_TYPE_NONE},
     {"draft-simple",  COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE},
     {"draft-eagle3",  COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3},
+    {"mtp",           COMMON_SPECULATIVE_TYPE_MTP},
     {"ngram-simple",  COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
     {"ngram-map-k",   COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
     {"ngram-map-k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V},
@@ -364,6 +366,330 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
     }
 };
 
+struct common_speculative_state_mtp : public common_speculative_impl {
+    common_params_speculative_draft params; // reuses the draft-model params slot (ctx_tgt/ctx_dft)
+
+    llama_batch batch;
+
+    std::vector<common_sampler_ptr> smpls;
+
+    int32_t n_embd = 0;
+
+    // Per-sequence cross-batch carryover: pair (h_p, x_{p+1}) at MTP pos p+1.
+    // The last h-row of one process() call needs the first token of the NEXT
+    // call to pair with, so it's stashed here until that next call fires.
+    std::vector<std::vector<float>> pending_h;   // [n_seq][n_embd]
+    std::vector<llama_pos>          pending_pos; // [n_seq]
+
+    std::vector<uint16_t> last_n_drafted;
+    std::vector<int32_t>  last_n_accepted;
+
+    // Number of trunk output rows produced by the most recent process() call.
+    // Used by draft() for the first AR step (when last_n_accepted is -1) to
+    // pick the last prefill row out of ctx_tgt's pre-norm buffer.
+    std::vector<int32_t> last_trunk_n_outputs;
+
+    common_speculative_state_mtp(const common_params_speculative & params, uint32_t n_seq)
+        : common_speculative_impl(COMMON_SPECULATIVE_TYPE_MTP, n_seq)
+        , params(params.draft)
+    {
+        GGML_ASSERT(n_seq == 1 && "MTP currently supports only single-sequence speculation");
+
+        auto * ctx_tgt = this->params.ctx_tgt;
+        auto * ctx_dft = this->params.ctx_dft;
+        GGML_ASSERT(ctx_tgt && ctx_dft && "MTP requires ctx_tgt and ctx_dft to be set");
+
+        n_embd = llama_model_n_embd(llama_get_model(ctx_dft));
+
+        const int32_t n_ub = (int32_t) llama_n_ubatch(ctx_dft);
+        batch = llama_batch_init(/*n_tokens=*/ n_ub, /*embd=*/ n_embd, /*n_seq_max=*/ 1);
+        // llama_batch_init allocates only one of token/embd; MTP needs both.
+        // TODO: fix: let llama_batch_init allocate both so this manual malloc can go away
+        batch.token = (llama_token *) malloc(sizeof(llama_token) * n_ub);
+
+        smpls.resize(n_seq);
+        for (auto & s : smpls) {
+            common_params_sampling sparams;
+            sparams.no_perf  = false;
+            sparams.top_k    = 1;
+            sparams.samplers = { COMMON_SAMPLER_TYPE_TOP_K };
+            s.reset(common_sampler_init(llama_get_model(ctx_dft), sparams));
+        }
+
+        llama_set_embeddings_pre_norm(ctx_tgt, true);
+        llama_set_embeddings_pre_norm(ctx_dft, true);
+
+        pending_h.assign(n_seq, std::vector<float>(n_embd, 0.0f));
+        pending_pos.assign(n_seq, -1);
+
+        last_n_drafted.assign(n_seq, 0);
+        last_n_accepted.assign(n_seq, -1);
+        last_trunk_n_outputs.assign(n_seq, 0);
+    }
+
+    ~common_speculative_state_mtp() override {
+        if (batch.token != nullptr) {
+            free(batch.token);
+            batch.token = nullptr;
+        }
+        llama_batch_free(batch);
+    }
+
+    void begin(llama_seq_id seq_id, const llama_tokens & prompt) override {
+        GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < pending_pos.size());
+
+        last_n_accepted[seq_id] = -1;
+        last_n_drafted [seq_id] = 0;
+        pending_pos    [seq_id] = -1;
+
+        const int32_t N = (int32_t) prompt.size();
+        if (N <= 0) {
+            return;
+        }
+        auto * ctx_dft = this->params.ctx_dft;
+        const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
+        if (pos_max < N - 1) {
+            LOG_WRN("%s: ctx_dft pos_max=%d < N-1=%d — "
+                    "process() hook may not have run on every prefill ubatch "
+                    "(need_embd / logits=1 on every prompt position?). "
+                    "Drafts may degrade.\n",
+                    __func__, (int) pos_max, N - 1);
+        }
+    }
+
+    bool process(const llama_batch & batch_in) override {
+        if (batch_in.n_tokens <= 0) {
+            return true;
+        }
+
+        // Single-seq for now (asserted in ctor). Future: bucket by seq_id.
+        const llama_seq_id seq_id = 0;
+
+        // TODO: how to make it work with vision tokens?
+        if (batch_in.token == nullptr || batch_in.embd != nullptr) {
+            pending_pos[seq_id] = -1;
+            return true;
+        }
+
+        auto * ctx_tgt = this->params.ctx_tgt;
+        auto * ctx_dft = this->params.ctx_dft;
+
+        const int32_t   n_rows    = batch_in.n_tokens;
+        const llama_pos pos_start = batch_in.pos[0];
+
+        const llama_pos pos_max_dft = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
+        if (pos_start <= pos_max_dft) {
+            return true;
+        }
+
+        // Stale pending: discard if the new batch doesn't start one past it.
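+        // Illustrative trace (positions hypothetical): if the previous call
+        // covered pos 0..511 and stashed (h_511, pending_pos = 511), a batch
+        // starting at pos_start == 512 lets the pair (h_511, token_512) be
+        // decoded now; a batch starting anywhere else (e.g. after a cache
+        // rewind) no longer lines up, so the stashed row is dropped.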
+        const bool pending_continues =
+            pending_pos[seq_id] >= 0 && pending_pos[seq_id] + 1 == pos_start;
+        if (pending_pos[seq_id] >= 0 && !pending_continues) {
+            pending_pos[seq_id] = -1;
+        }
+
+        // Build a paired hook batch:
+        //   row 0            = (pending_h, batch_in.token[0]) at pos_start if pending_continues
+        //   rows 1..n_rows-1 = (h_k from this batch, batch_in.token[k+1]) at pos[k+1]
+        // The last h-row (h_{n_rows-1}) is stashed as the new pending and is *not*
+        // decoded this call — it waits for the next batch's first token to pair.
+        const size_t row_bytes = (size_t) n_embd * sizeof(float);
+
+        common_batch_clear(batch);
+        int out_idx = 0;
+
+        auto add_pair = [&](const float * h_row, llama_token tok, llama_pos pos) {
+            std::memcpy(batch.embd + (size_t) out_idx * n_embd, h_row, row_bytes);
+            batch.token   [out_idx]    = tok;
+            batch.pos     [out_idx]    = pos;
+            batch.n_seq_id[out_idx]    = 1;
+            batch.seq_id  [out_idx][0] = seq_id;
+            batch.logits  [out_idx]    = 0;
+            ++out_idx;
+        };
+
+        if (pending_continues) {
+            add_pair(pending_h[seq_id].data(), batch_in.token[0], pos_start);
+        }
+
+        // TODO: is there a fast way to build this batch?
+        for (int k = 0; k + 1 < n_rows; ++k) {
+            if (batch_in.logits[k] == 0) {
+                LOG_WRN("%s: batch_in.logits[%d] == 0 (need_embd / logits=1 missing on prefill); stopping hook at this row\n",
+                        __func__, k);
+                break;
+            }
+            const float * h_k = llama_get_embeddings_pre_norm_ith(ctx_tgt, k);
+            if (h_k == nullptr) {
+                LOG_WRN("%s: ctx_tgt has no pre-norm row at i=%d; stopping hook\n", __func__, k);
+                break;
+            }
+            add_pair(h_k, batch_in.token[k + 1], batch_in.pos[k + 1]);
+        }
+
+        if (out_idx > 0) {
+            batch.n_tokens = out_idx;
+            const int32_t rc = llama_decode(ctx_dft, batch);
+            if (rc != 0) {
+                LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d, n=%d)\n",
+                        __func__, (int) rc, (int) pos_start, out_idx);
+                return false;
+            }
+        }
+
+        // Record the trunk output row count so that draft()'s first AR step (when
+        // last_n_accepted < 0) can find the last pre-norm row of this batch.
+        // We assume every batch position has logits=1 (server sets need_embd
+        // for MTP slots) → n_outputs == n_tokens.
+        last_trunk_n_outputs[seq_id] = n_rows;
+
+        // Stash the last h-row (h_{n_rows-1}) as the new pending for the next
+        // process() call's first token to pair with.
+        if (batch_in.logits[n_rows - 1] != 0) {
+            const float * h_last = llama_get_embeddings_pre_norm_ith(ctx_tgt, n_rows - 1);
+            if (h_last != nullptr) {
+                std::memcpy(pending_h[seq_id].data(), h_last, row_bytes);
+                pending_pos[seq_id] = batch_in.pos[n_rows - 1];
+            } else {
+                pending_pos[seq_id] = -1;
+            }
+        } else {
+            // No trunk output at the tail — can't carry over.
+            pending_pos[seq_id] = -1;
+        }
+
+        return true;
+    }
+
+    void draft(common_speculative_draft_params_vec & dparams) override {
+        // Single-seq for now (asserted in ctor). Future: iterate over dparams.
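+        // Rollback sketch (numbers hypothetical): if the previous draft()
+        // decoded 4 tokens into the head's KV at pos 100..103, only the first
+        // cell is reusable as conditioning state, so the cleanup below removes
+        // the last_n_drafted - 1 = 3 newest cells before drafting again.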
+ const llama_seq_id seq_id = 0; + if ((size_t) seq_id >= dparams.size()) { + return; + } + auto & dp = dparams[seq_id]; + if (!dp.drafting) { + return; + } + + auto * ctx_tgt = this->params.ctx_tgt; + auto * ctx_dft = this->params.ctx_dft; + auto * smpl = smpls[seq_id].get(); + + GGML_ASSERT(dp.result != nullptr); + auto & draft_tokens = *dp.result; + draft_tokens.clear(); + + if (last_n_drafted[seq_id] > 0) { + const int32_t n_to_drop = (int32_t) last_n_drafted[seq_id] - 1; + if (n_to_drop > 0) { + const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id); + if (pos_max >= 0) { + const llama_pos drop_from = pos_max - n_to_drop + 1; + llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, drop_from, -1); + } + } + last_n_drafted[seq_id] = 0; + last_n_accepted[seq_id] = 0; + } + + // Effective draft length: min(global cap, per-sequence override). + int32_t n_max = std::max(1, params.n_max); + if (dp.n_max > 0) { + n_max = std::min(n_max, dp.n_max); + } + + const size_t row_bytes = (size_t) n_embd * sizeof(float); + + common_sampler_reset(smpl); + + llama_token cond_tok = dp.id_last; + llama_pos pos = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id) + 1; + + for (int32_t k = 0; k < n_max; ++k) { + const float * h_row = nullptr; + + if (k == 0) { + // Condition on the trunk's pre-norm row. + int32_t row_idx; + if (last_n_accepted[seq_id] < 0) { + // First draft after begin(): use the last prefill row. + row_idx = last_trunk_n_outputs[seq_id] - 1; + } else { + // After accept(n_accepted): row of the next conditioning + // position in the trunk's verify batch. + row_idx = last_n_accepted[seq_id]; + } + if (row_idx < 0) { + LOG_WRN("%s: no trunk pre-norm row available (row_idx=%d); stopping chain\n", + __func__, row_idx); + break; + } + h_row = llama_get_embeddings_pre_norm_ith(ctx_tgt, row_idx); + } else { + // AR step: condition on the MTP head's own pre-norm row from + // the just-completed single-token decode. n_outputs=1 there, + // so the row is at batch position 0. + h_row = llama_get_embeddings_pre_norm_ith(ctx_dft, 0); + } + + if (h_row == nullptr) { + LOG_WRN("%s: missing pre-norm row at k=%d; stopping chain\n", __func__, k); + break; + } + + // 1-token batch carrying both (token, h_pre_norm). 
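+            // Row layout sketch: unlike a normal decode batch, both arrays are
+            // filled for the same row:
+            //   batch.token[0]        -> token embedded through enorm
+            //   batch.embd[0..n_embd) -> hidden state h fed through hnorm
+            // The MTP graph norms both and concatenates them for eh_proj.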
+ common_batch_clear(batch); + std::memcpy(batch.embd, h_row, row_bytes); + batch.token [0] = cond_tok; + batch.pos [0] = pos; + batch.n_seq_id[0] = 1; + batch.seq_id [0][0] = seq_id; + batch.logits [0] = 1; // need logits for sampling + batch.n_tokens = 1; + + const int32_t rc = llama_decode(ctx_dft, batch); + if (rc != 0) { + LOG_WRN("%s: llama_decode(ctx_dft) failed rc=%d at k=%d; stopping chain\n", + __func__, rc, k); + break; + } + + const llama_token best = common_sampler_sample(smpl, ctx_dft, 0); + common_sampler_accept(smpl, best, /*is_generated=*/ false); + draft_tokens.push_back(best); + cond_tok = best; + ++pos; + } + + last_n_drafted[seq_id] = (uint16_t) draft_tokens.size(); + } + + void accept(llama_seq_id seq_id, uint16_t n_accepted) override { + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < last_n_drafted.size()); + + auto * ctx_dft = this->params.ctx_dft; + + const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id); + const int32_t n_drafted_last = (int32_t) last_n_drafted[seq_id]; + + const int32_t n_to_drop = std::max(0, n_drafted_last - (int32_t) n_accepted - 1); + + if (pos_max < 0) { + last_n_accepted[seq_id] = (int32_t) n_accepted; + return; + } + + if (n_to_drop > 0) { + const llama_pos drop_from = pos_max - n_to_drop + 1; + llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, drop_from, -1); + } + + last_n_drafted [seq_id] = 0; + last_n_accepted[seq_id] = (int32_t) n_accepted; + } +}; + // state of self-speculation (simple implementation, not ngram-map) struct common_speculative_impl_ngram_simple : public common_speculative_impl { common_params_speculative_ngram_map params; @@ -820,6 +1146,7 @@ std::string common_speculative_type_to_str(common_speculative_type type) { case COMMON_SPECULATIVE_TYPE_NONE: return "none"; case COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE: return "draft-simple"; case COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3: return "draft-eagle3"; + case COMMON_SPECULATIVE_TYPE_MTP: return "mtp"; case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram-simple"; case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram-map-k"; case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram-map-k4v"; @@ -875,8 +1202,8 @@ common_speculative * common_speculative_init(common_params_speculative & params, bool has_draft_model_path = !params.draft.mparams.path.empty(); bool has_draft_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE)); - // bool has_mtp = false; // TODO: add MTP here bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3 + bool has_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_MTP)) && params.draft.ctx_dft != nullptr; bool has_ngram_cache = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_CACHE)); bool has_ngram_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE)); @@ -885,7 +1212,7 @@ common_speculative * common_speculative_init(common_params_speculative & params, bool has_ngram_mod = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_MOD)); // when adding a new type - update here the logic above - static_assert(COMMON_SPECULATIVE_TYPE_COUNT == 8); + static_assert(COMMON_SPECULATIVE_TYPE_COUNT == 9); // this list here defines the priority of the speculators // the one with highest priority are listed first @@ -919,10 +1246,12 @@ common_speculative * common_speculative_init(common_params_speculative & params, if (has_draft_simple) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE, params)); } - // TODO: add MTP here 
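+        // MTP is appended after eagle3 below. Enabling it is the caller's job
+        // (sketch, exact call site hypothetical): set the
+        // COMMON_SPECULATIVE_TYPE_MTP bit in the enabled configs and point
+        // params.draft.ctx_dft at a second context created from the *same*
+        // GGUF with override_arch = "qwen35_mtp" / "qwen35moe_mtp".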
if (has_draft_eagle3) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, params)); } + if (has_mtp) { + configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_MTP, params)); + } } std::vector> impls = {}; @@ -940,6 +1269,10 @@ common_speculative * common_speculative_init(common_params_speculative & params, impls.push_back(std::make_unique(config.params, n_seq)); break; } + case COMMON_SPECULATIVE_TYPE_MTP: { + impls.push_back(std::make_unique(config.params, n_seq)); + break; + } case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: { common_ngram_map ngram_map = get_common_ngram_map(config.type, config.params.ngram_simple); diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 5cff8684856..2cdbd9c1dba 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -5549,13 +5549,70 @@ def set_gguf_parameters(self): self.gguf_writer.add_rope_dimension_sections(self._QWEN35_DEFAULT_MROPE_SECTION) +class _Qwen35MtpMixin: + """Shared MTP wiring for Qwen3.5/3.6 text variants. The HF config carries + the MTP block under `mtp_num_hidden_layers` and the tensors under + `mtp.*`; we extend block_count, emit the nextn metadata key, and remap + `mtp.*` to the standard layer-indexed nextn naming so the existing + tensor_map handles them.""" + + # Class-level annotations so the type checker understands the attributes + # available on the concrete subclasses in the MRO + hparams: dict[str, Any] + model_arch: gguf.MODEL_ARCH + gguf_writer: gguf.GGUFWriter + block_count: int + tensor_map: gguf.TensorNameMap + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.block_count = self.hparams["num_hidden_layers"] + self.hparams.get("mtp_num_hidden_layers", 0) + self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) + + def set_gguf_parameters(self): + super().set_gguf_parameters() # ty: ignore[unresolved-attribute] + if (n := self.hparams.get("mtp_num_hidden_layers", 0)) > 0: + self.gguf_writer.add_nextn_predict_layers(n) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # Multimodal Qwen3.5/3.6 wrap the text model under `model.language_model.*`. + if name.startswith("model.language_model."): + name = "model." + name[len("model.language_model."):] + elif name.startswith("language_model."): + name = name[len("language_model."):] + + # Remap MTP block tensors to llama.cpp's layer-indexed nextn naming. 
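+        # Concrete example of the remapping below (counts hypothetical):
+        # with num_hidden_layers = 48 and mtp_num_hidden_layers = 1,
+        #   mtp.layers.0.input_layernorm.weight -> model.layers.48.input_layernorm.weight
+        #   mtp.fc.weight                       -> model.layers.48.eh_proj.weight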
+ # HF: mtp.layers.0.* (transformer block at MTP slot 0) + # mtp.fc / mtp.pre_fc_norm_embedding / mtp.pre_fc_norm_hidden / mtp.norm + if name.startswith("mtp."): + n_layer = self.hparams["num_hidden_layers"] + if name.find("layers.") != -1: + assert bid is not None + name = name.replace(f"mtp.layers.{bid}", f"model.layers.{bid + n_layer}") + else: + remapper = { + "mtp.fc": "model.layers.{bid}.eh_proj", + "mtp.pre_fc_norm_embedding": "model.layers.{bid}.enorm", + "mtp.pre_fc_norm_hidden": "model.layers.{bid}.hnorm", + "mtp.norm": "model.layers.{bid}.shared_head.norm", + } + stem = Path(name).stem + suffix = Path(name).suffix + tmpl = remapper[stem] + suffix + for b in range(n_layer, self.block_count): + yield from super().modify_tensors(data_torch, tmpl.format(bid=b), b) # ty: ignore[unresolved-attribute] + return + + yield from super().modify_tensors(data_torch, name, bid) # ty: ignore[unresolved-attribute] + + @ModelBase.register("Qwen3_5ForConditionalGeneration", "Qwen3_5ForCausalLM") -class Qwen3_5TextModel(_Qwen35MRopeMixin, _LinearAttentionVReorderBase): +class Qwen3_5TextModel(_Qwen35MtpMixin, _Qwen35MRopeMixin, _LinearAttentionVReorderBase): model_arch = gguf.MODEL_ARCH.QWEN35 @ModelBase.register("Qwen3_5MoeForConditionalGeneration", "Qwen3_5MoeForCausalLM") -class Qwen3_5MoeTextModel(_Qwen35MRopeMixin, _LinearAttentionVReorderBase): +class Qwen3_5MoeTextModel(_Qwen35MtpMixin, _Qwen35MRopeMixin, _LinearAttentionVReorderBase): model_arch = gguf.MODEL_ARCH.QWEN35MOE diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 4055ec2873a..c25f217f990 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -2114,7 +2114,14 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_NORM, MODEL_TENSOR.SSM_BETA, MODEL_TENSOR.SSM_ALPHA, - MODEL_TENSOR.SSM_OUT + MODEL_TENSOR.SSM_OUT, + # NextN/MTP tensors - preserved but unused + MODEL_TENSOR.NEXTN_EH_PROJ, + MODEL_TENSOR.NEXTN_EMBED_TOKENS, + MODEL_TENSOR.NEXTN_ENORM, + MODEL_TENSOR.NEXTN_HNORM, + MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD, + MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM, ], MODEL_ARCH.QWEN35MOE: [ MODEL_TENSOR.TOKEN_EMBD, @@ -2145,7 +2152,14 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_NORM, MODEL_TENSOR.SSM_BETA, MODEL_TENSOR.SSM_ALPHA, - MODEL_TENSOR.SSM_OUT + MODEL_TENSOR.SSM_OUT, + # NextN/MTP tensors - preserved but unused + MODEL_TENSOR.NEXTN_EH_PROJ, + MODEL_TENSOR.NEXTN_EMBED_TOKENS, + MODEL_TENSOR.NEXTN_ENORM, + MODEL_TENSOR.NEXTN_HNORM, + MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD, + MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM, ], MODEL_ARCH.PLAMO: [ MODEL_TENSOR.TOKEN_EMBD, diff --git a/include/llama.h b/include/llama.h index 308e8ba9dbd..1b896944735 100644 --- a/include/llama.h +++ b/include/llama.h @@ -310,6 +310,9 @@ extern "C" { // override key-value pairs of the model meta data const struct llama_model_kv_override * kv_overrides; + // override architecture from GGUF (e.g. load the MTP head of a Qwen3.5 GGUF as "qwen35_mtp") + const char * override_arch; + // Keep the booleans together to avoid misalignment during copy-by-value. 
bool vocab_only; // only load the vocabulary, no weights bool use_mmap; // use mmap if possible diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 59dde99e362..794666d09a4 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -41,6 +41,8 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_QWEN3VLMOE, "qwen3vlmoe" }, { LLM_ARCH_QWEN35, "qwen35" }, { LLM_ARCH_QWEN35MOE, "qwen35moe" }, + { LLM_ARCH_QWEN35_MTP, "qwen35_mtp" }, + { LLM_ARCH_QWEN35MOE_MTP, "qwen35moe_mtp" }, { LLM_ARCH_PHI2, "phi2" }, { LLM_ARCH_PHI3, "phi3" }, { LLM_ARCH_PHIMOE, "phimoe" }, @@ -757,14 +759,15 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - // NextN/MTP tensors are currently ignored (reserved for future MTP support) - // These tensors only exist in the last layer(s) and are treated as output tensors - {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, - {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + // NextN/MTP tensors are stored per-block (blk.%d.nextn.*) even though only the + // last nextn_predict_layers blocks carry them. Classify as LAYER_REPEATING so + // the model loader doesn't fault on the block index. + {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // Nemotron 3 Super {LLM_TENSOR_FFN_LATENT_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_FFN_LATENT_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, diff --git a/src/llama-arch.h b/src/llama-arch.h index e37d548c98e..71c2ca6e6b3 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -45,6 +45,8 @@ enum llm_arch { LLM_ARCH_QWEN3VLMOE, LLM_ARCH_QWEN35, LLM_ARCH_QWEN35MOE, + LLM_ARCH_QWEN35_MTP, + LLM_ARCH_QWEN35MOE_MTP, LLM_ARCH_PHI2, LLM_ARCH_PHI3, LLM_ARCH_PHIMOE, diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 3d9714ab166..aea8a0a4e81 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -49,6 +49,7 @@ llama_context::llama_context( cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast; cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? 
params.yarn_beta_slow : hparams.yarn_beta_slow; cparams.embeddings = params.embeddings; + cparams.embeddings_pre_norm = false; cparams.offload_kqv = params.offload_kqv; cparams.no_perf = params.no_perf; cparams.pooling_type = params.pooling_type; @@ -860,6 +861,33 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) { return it->second.data(); } +float * llama_context::get_embeddings_pre_norm() { + output_reorder(); + + return embd_pre_norm.data; +} + +float * llama_context::get_embeddings_pre_norm_ith(int32_t i) { + output_reorder(); + + try { + if (embd_pre_norm.data == nullptr) { + throw std::runtime_error("no pre-norm embeddings"); + } + + const int64_t j = output_resolve_row(i); + const uint32_t n_embd = model.hparams.n_embd; + return embd_pre_norm.data + j*n_embd; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid pre-norm embeddings id %d, reason: %s\n", __func__, i, err.what()); +#ifndef NDEBUG + GGML_ABORT("fatal error"); +#else + return nullptr; +#endif + } +} + llama_token llama_context::get_sampled_token_ith(int32_t idx) { output_reorder(); @@ -1040,6 +1068,12 @@ void llama_context::set_embeddings(bool value) { //sched_need_reserve = true; } +void llama_context::set_embeddings_pre_norm(bool value) { + LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); + + cparams.embeddings_pre_norm = value; +} + void llama_context::set_causal_attn(bool value) { LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); @@ -1241,7 +1275,9 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll } int llama_context::encode(const llama_batch & batch_inp) { - GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT + // MTP hook batches carry both token (next-token id) and embd (h_pre_norm row), + // so accept either present rather than requiring exactly one. + GGML_ASSERT(batch_inp.token || batch_inp.embd); if (batch_inp.n_tokens == 0) { LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); @@ -1312,8 +1348,9 @@ int llama_context::encode(const llama_batch & batch_inp) { } } - auto * t_logits = res->get_logits(); - auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); + auto * t_logits = res->get_logits(); + auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); + auto * t_h_pre_norm = cparams.embeddings_pre_norm ? res->get_h_pre_norm() : nullptr; // extract logits if (logits.data && t_logits) { @@ -1379,6 +1416,16 @@ int llama_context::encode(const llama_batch & batch_inp) { } } + // extract pre-norm embeddings (hidden state before the final output norm) + if (embd_pre_norm.data && t_h_pre_norm && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { + ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm); + GGML_ASSERT(backend_h != nullptr); + + const uint32_t n_embd = hparams.n_embd; + GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_pre_norm.size); + ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm.data, 0, n_tokens*n_embd*sizeof(float)); + } + // TODO: hacky solution if (model.arch == LLM_ARCH_T5 && t_embd) { //cross.t_embd = t_embd; @@ -1531,7 +1578,9 @@ static bool needs_raw_logits(const llama_ubatch & ubatch, const std::mapget_logits(); - auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr; + auto * t_logits = res->get_logits(); + auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr; + auto * t_h_pre_norm = cparams.embeddings_pre_norm ? 
res->get_h_pre_norm() : nullptr; if (t_embd && res->get_embd_pooled()) { t_embd = res->get_embd_pooled(); @@ -1809,6 +1859,20 @@ int llama_context::decode(const llama_batch & batch_inp) { } } + // extract pre-norm embeddings (hidden state before the final output norm) + // only meaningful in LLAMA_POOLING_TYPE_NONE (per-token); other pooling modes are ignored. + if (embd_pre_norm.data && t_h_pre_norm && n_outputs > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { + ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm); + GGML_ASSERT(backend_h != nullptr); + + const uint32_t n_embd = hparams.n_embd; + float * embd_pre_norm_out = embd_pre_norm.data + n_outputs_prev*n_embd; + + GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); + GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_pre_norm.size); + ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm_out, 0, n_outputs*n_embd*sizeof(float)); + } + // Copy backend sampling output if this ubatch produced any sampling tensors. if (has_samplers && (!res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty())) { const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev); @@ -1893,10 +1957,12 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { const auto n_batch = cparams.n_batch; const auto n_vocab = vocab.n_tokens(); + const auto n_embd = hparams.n_embd; const auto n_embd_out = hparams.n_embd_out(); - bool has_logits = true; - bool has_embd = cparams.embeddings; + bool has_logits = true; + bool has_embd = cparams.embeddings; + bool has_embd_pre_norm = cparams.embeddings_pre_norm; // TODO: hacky enc-dec support if (model.arch == LLM_ARCH_T5) { @@ -1908,8 +1974,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { size_t backend_float_count = 0; size_t backend_token_count = 0; - logits.size = has_logits ? n_vocab*n_outputs_max : 0; - embd.size = has_embd ? n_embd_out*n_outputs_max : 0; + logits.size = has_logits ? n_vocab*n_outputs_max : 0; + embd.size = has_embd ? n_embd_out*n_outputs_max : 0; + embd_pre_norm.size = has_embd_pre_norm ? n_embd*n_outputs_max : 0; // Allocate backend sampling output buffers if there are backend samplers configured. const bool has_sampling = !sampling.samplers.empty(); @@ -1925,8 +1992,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0; const size_t new_size = - (logits.size + embd.size + backend_float_count) * sizeof(float) + - ( backend_token_count) * sizeof(llama_token); + (logits.size + embd.size + embd_pre_norm.size + backend_float_count) * sizeof(float) + + ( backend_token_count) * sizeof(llama_token); // alloc only when more than the current capacity is required // TODO: also consider shrinking the buffer @@ -1942,6 +2009,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { buf_output = nullptr; logits.data = nullptr; embd.data = nullptr; + embd_pre_norm.data = nullptr; } auto * buft = ggml_backend_cpu_buffer_type(); @@ -1970,6 +2038,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { embd = has_embd ? buffer_view{(float *) (base + offset), embd.size} : buffer_view{nullptr, 0}; offset += embd.size * sizeof(float); + embd_pre_norm = has_embd_pre_norm ? 
buffer_view{(float *) (base + offset), embd_pre_norm.size} : buffer_view{nullptr, 0}; + offset += embd_pre_norm.size * sizeof(float); + if (has_sampling) { sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)}; offset += sampling.logits.size * sizeof(float); @@ -2034,6 +2105,12 @@ void llama_context::output_reorder() { } } + if (embd_pre_norm.size > 0) { + for (uint64_t k = 0; k < n_embd; k++) { + std::swap(embd_pre_norm.data[i0*n_embd + k], embd_pre_norm.data[i1*n_embd + k]); + } + } + if (!sampling.samplers.empty()) { assert(sampling.logits.size > 0); assert(sampling.probs.size > 0); @@ -3436,6 +3513,22 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) { return ctx->get_embeddings_seq(seq_id); } +void llama_set_embeddings_pre_norm(llama_context * ctx, bool value) { + ctx->set_embeddings_pre_norm(value); +} + +float * llama_get_embeddings_pre_norm(llama_context * ctx) { + ctx->synchronize(); + + return ctx->get_embeddings_pre_norm(); +} + +float * llama_get_embeddings_pre_norm_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return ctx->get_embeddings_pre_norm_ith(i); +} + bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) { return ctx->set_sampler(seq_id, smpl); } diff --git a/src/llama-context.h b/src/llama-context.h index 92d1b0cf95a..e16ac4c618b 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -84,6 +84,9 @@ struct llama_context { float * get_embeddings_ith(int32_t i); float * get_embeddings_seq(llama_seq_id seq_id); + float * get_embeddings_pre_norm(); + float * get_embeddings_pre_norm_ith(int32_t i); + llama_token * get_sampled_tokens() const; llama_token get_sampled_token_ith(int32_t idx); @@ -107,6 +110,7 @@ struct llama_context { void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data); void set_embeddings (bool value); + void set_embeddings_pre_norm(bool value); void set_causal_attn(bool value); void set_warmup(bool value); @@ -278,6 +282,11 @@ struct llama_context { // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE buffer_view embd = {nullptr, 0}; + // hidden state before the final output norm (2-dimensional array: [n_outputs][n_embd]) + // populated only when cparams.embeddings_pre_norm is enabled and the model graph + // sets llm_graph_result::t_h_pre_norm + buffer_view embd_pre_norm = {nullptr, 0}; + struct sampling_info { // !samplers.empty() to check if any samplers are active std::map samplers; diff --git a/src/llama-cparams.h b/src/llama-cparams.h index 9d359474132..1e4e9e29ed8 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -27,6 +27,7 @@ struct llama_cparams { float yarn_beta_slow; bool embeddings; + bool embeddings_pre_norm; // also extract the hidden state before the final output norm bool causal_attn; bool offload_kqv; bool flash_attn; diff --git a/src/llama-ext.h b/src/llama-ext.h index 8ce29d217cb..11f1986676a 100644 --- a/src/llama-ext.h +++ b/src/llama-ext.h @@ -88,3 +88,19 @@ LLAMA_API int32_t llama_model_n_devices(const struct llama_model * model); LLAMA_API ggml_backend_dev_t llama_model_get_device(const struct llama_model * model, int i); LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_context * ctx); + +// +// pre-norm embeddings (hidden state before the final output norm) +// + +// mirrors: +// LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings); +LLAMA_API void llama_set_embeddings_pre_norm(struct llama_context * ctx, bool 
value); + +// mirrors: +// LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); +LLAMA_API float * llama_get_embeddings_pre_norm(struct llama_context * ctx); + +// mirrors: +// LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i); +LLAMA_API float * llama_get_embeddings_pre_norm_ith(struct llama_context * ctx, int32_t i); diff --git a/src/llama-graph.h b/src/llama-graph.h index 5cb1756c6a9..d3cd69a674c 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -644,6 +644,7 @@ class llm_graph_result { ggml_tensor * get_logits() const { return t_logits; } ggml_tensor * get_embd() const { return t_embd; } ggml_tensor * get_embd_pooled() const { return t_embd_pooled; } + ggml_tensor * get_h_pre_norm() const { return t_h_pre_norm; } ggml_cgraph * get_gf() const { return gf; } ggml_context * get_ctx() const { return ctx_compute.get(); } @@ -672,6 +673,7 @@ class llm_graph_result { ggml_tensor * t_logits = nullptr; ggml_tensor * t_embd = nullptr; ggml_tensor * t_embd_pooled = nullptr; + ggml_tensor * t_h_pre_norm = nullptr; // [n_embd, n_outputs] hidden state before final output norm std::map t_sampled_logits; std::map t_candidates; diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 002d15d415f..2239309c8fb 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -229,6 +229,12 @@ uint32_t llama_hparams::n_embd_head_v_mla() const { } bool llama_hparams::has_kv(uint32_t il) const { + if (kv_only_nextn) { + // MTP head: only the trailing nextn_predict_layers blocks own a KV cache; + // the leading trunk blocks are not executed in this graph. + return nextn_predict_layers > 0 && il >= (n_layer - nextn_predict_layers); + } + if (n_layer_kv_from_start >= 0) { if (il < (uint32_t) n_layer_kv_from_start) { return true; diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 0160a89caa2..e2d051edc6c 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -92,6 +92,8 @@ struct llama_hparams { uint32_t moe_latent_size = 0; uint32_t nextn_predict_layers = 0; + bool kv_only_nextn = false; // if true, only the last nextn_predict_layers blocks have a KV cache (MTP head arches) + float f_norm_eps; float f_norm_rms_eps; float f_norm_group_eps; diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp index 4e65a45a50d..c645d0785ab 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp @@ -1312,9 +1312,16 @@ struct ggml_tensor * llama_model_loader::create_tensor_as_view(struct ggml_conte return tensor; } -void llama_model_loader::done_getting_tensors() const { - if (n_created != n_tensors) { - throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created)); +void llama_model_loader::done_getting_tensors(bool partial) const { + if (n_created > n_tensors) { + throw std::runtime_error(format("%s: too many tensors created; expected %d, got %d", __func__, n_tensors, n_created)); + } + if (n_created < n_tensors) { + if (!partial) { + throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created)); + } + LLAMA_LOG_INFO("%s: partial load — used %d of %d tensors in the file (rest belong to a sibling model on the same .gguf)\n", + __func__, n_created, n_tensors); } if (n_tensors_moved > 0) { LLAMA_LOG_DEBUG("%s: tensor '%s' (%s) (and %zu others) cannot be used with preferred buffer type %s, using %s instead\n", diff --git a/src/llama-model-loader.h b/src/llama-model-loader.h index 7b3d6703c03..c476026d3e5 100644 
--- a/src/llama-model-loader.h +++ b/src/llama-model-loader.h @@ -184,7 +184,7 @@ struct llama_model_loader { struct ggml_tensor * create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::initializer_list & ne, size_t offset, bool required = true); - void done_getting_tensors() const; + void done_getting_tensors(bool partial = false) const; void init_mappings(bool prefetch = true, llama_mlocks * mlock_mmaps = nullptr); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index ff30a2ae7a6..6abfbfb3e3b 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -276,6 +276,10 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params return new llama_model_qwen35(params); case LLM_ARCH_QWEN35MOE: return new llama_model_qwen35moe(params); + case LLM_ARCH_QWEN35_MTP: + return new llama_model_qwen35_mtp(params); + case LLM_ARCH_QWEN35MOE_MTP: + return new llama_model_qwen35moe_mtp(params); case LLM_ARCH_MISTRAL3: return new llama_model_mistral3(params); case LLM_ARCH_MIMO2: @@ -309,6 +313,15 @@ llama_model * llama_model_create(llama_model_loader & ml, const llama_model_para if (arch == LLM_ARCH_UNKNOWN) { throw std::runtime_error("unknown model architecture: '" + ml.get_arch_name() + "'"); } + if (params.override_arch != nullptr && params.override_arch[0] != '\0') { + const llm_arch override = llm_arch_from_string(params.override_arch); + if (override == LLM_ARCH_UNKNOWN) { + throw std::runtime_error(std::string("unknown override architecture: '") + params.override_arch + "'"); + } + LLAMA_LOG_INFO("%s: overriding architecture %s -> %s\n", + __func__, llm_arch_name(arch), llm_arch_name(override)); + arch = override; + } return llama_model_create(arch, params); } @@ -1396,7 +1409,8 @@ bool llama_model_base::load_tensors(llama_model_loader & ml) { } } - ml.done_getting_tensors(); + const bool partial_load = (arch == LLM_ARCH_QWEN35_MTP || arch == LLM_ARCH_QWEN35MOE_MTP); + ml.done_getting_tensors(partial_load); // populate tensors_by_name for (auto & [_, ctx_ptr] : ml.ctx_map) { @@ -2089,6 +2103,7 @@ llama_model_params llama_model_default_params() { /*.progress_callback =*/ nullptr, /*.progress_callback_user_data =*/ nullptr, /*.kv_overrides =*/ nullptr, + /*.override_arch =*/ nullptr, /*.vocab_only =*/ false, /*.use_mmap =*/ true, /*.use_direct_io =*/ false, @@ -2313,6 +2328,8 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_QWEN3VLMOE: case LLM_ARCH_QWEN35: case LLM_ARCH_QWEN35MOE: + case LLM_ARCH_QWEN35_MTP: + case LLM_ARCH_QWEN35MOE_MTP: return LLAMA_ROPE_TYPE_IMROPE; case LLM_ARCH_GLM4: diff --git a/src/models/models.h b/src/models/models.h index 6d5f18a8e20..1f04d313d13 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -1785,6 +1785,32 @@ struct llama_model_qwen35moe : public llama_model_base { }; +struct llama_model_qwen35_mtp : public llama_model_base { + llama_model_qwen35_mtp(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_qwen35moe_mtp : public llama_model_base { + llama_model_qwen35moe_mtp(const struct llama_model_params & params) : llama_model_base(params) {} + void 
load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; +}; + + struct llama_model_mistral3 : public llama_model_base { llama_model_mistral3(const struct llama_model_params & params) : llama_model_base(params) {} void load_arch_hparams(llama_model_loader & ml) override; diff --git a/src/models/qwen35.cpp b/src/models/qwen35.cpp index f276be61ba8..79fdd8f679b 100644 --- a/src/models/qwen35.cpp +++ b/src/models/qwen35.cpp @@ -12,16 +12,23 @@ void llama_model_qwen35::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - // Mark recurrent layers (linear attention layers) + // NextN/MTP (Qwen3.5/3.6): extra decoder block appended beyond the main stack + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + + // Mark recurrent layers (linear attention layers). MTP layers are dense + // attention-only and must be flagged non-recurrent. { + const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; uint32_t full_attn_interval = 4; ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + hparams.recurrent_layer_arr[i] = (i < n_main) && ((i + 1) % full_attn_interval != 0); } } - switch (hparams.n_layer) { + switch (hparams.n_layer - hparams.nextn_predict_layers) { case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_8B : LLM_TYPE_2B; break; case 32: type = hparams.n_embd == 2560 ? 
LLM_TYPE_4B : LLM_TYPE_9B; break; case 64: type = LLM_TYPE_27B; break; @@ -83,6 +90,16 @@ void llama_model_qwen35::load_arch_tensors(llama_model_loader &) { layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // NextN/MTP tensors (preserved but unused) - only bound on MTP layers + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, TENSOR_NOT_REQUIRED); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); + } } } @@ -111,7 +128,9 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_out_ids = build_inp_out_ids(); - for (int il = 0; il < n_layer; ++il) { + // MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass. + const int n_transformer_layers = n_layer - (int) hparams.nextn_predict_layers; + for (int il = 0; il < n_transformer_layers; ++il) { ggml_tensor * inpSA = inpL; cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); @@ -128,7 +147,7 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il); } - if (il == n_layer - 1 && inp_out_ids) { + if (il == n_transformer_layers - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -160,6 +179,9 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para } cur = inpL; + cb(cur, "h_pre_norm", -1); + res->t_h_pre_norm = cur; + // Final norm cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); diff --git a/src/models/qwen35_mtp.cpp b/src/models/qwen35_mtp.cpp new file mode 100644 index 00000000000..83039e98db5 --- /dev/null +++ b/src/models/qwen35_mtp.cpp @@ -0,0 +1,207 @@ +#include "models.h" + +void llama_model_qwen35_mtp::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); + + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35_MTP requires nextn_predict_layers > 0"); + GGML_ASSERT(hparams.nextn_predict_layers <= hparams.n_layer); + + // only the MTP layers get a KV cache, trunk layers are skipped. 
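+    // e.g. (counts hypothetical): n_layer = 49 with nextn_predict_layers = 1
+    // makes has_kv(il) true only for il == 48; the 48 trunk blocks never
+    // allocate KV cells in this context.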
+ hparams.kv_only_nextn = true; + hparams.n_layer_kv_from_start = -1; + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = false; + } + + type = LLM_TYPE_UNKNOWN; +} + +void llama_model_qwen35_mtp::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, TENSOR_NOT_REQUIRED); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + if (output == nullptr) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + } + + const uint32_t n_main = n_layer - hparams.nextn_predict_layers; + for (int i = 0; i < n_layer; ++i) { + if (static_cast(i) < n_main) { + continue; // trunk layer — owned by the sibling QWEN35 model + } + + auto & layer = layers[i]; + + // MTP block looks like a full-attention Qwen3.5 decoder block. + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // NextN-specific tensors that define the MTP block. + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, 0); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, 0); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, 0); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); + } +} + +std::unique_ptr llama_model_qwen35_mtp::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +// LLM_ARCH_QWEN35_MTP draft head for Qwen3.5/3.6 dense series +llama_model_qwen35_mtp::graph::graph(const llama_model & model, const llm_graph_params & params) + : llm_graph_context(params) { + GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35_MTP requires nextn_predict_layers > 0"); + GGML_ASSERT(hparams.nextn_predict_layers == 1 && "QWEN35_MTP currently only supports a single MTP block"); + + const int64_t n_embd_head = hparams.n_embd_head_v(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + + // The MTP block lives at the source file's original layer index. 
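+    // e.g. a 48-block trunk with one MTP block stores it as blk.48.* in the
+    // GGUF, so il = 49 - 1 = 48 here (counts hypothetical).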
+ const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers; + const auto & layer = model.layers[il]; + + GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj"); + GGML_ASSERT(layer.nextn.enorm && "MTP block missing nextn.enorm"); + GGML_ASSERT(layer.nextn.hnorm && "MTP block missing nextn.hnorm"); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + auto inp = std::make_unique(hparams.n_embd); + + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_input(inp->tokens); + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + ggml_set_input(inp->embd); + ggml_set_name(inp->embd, "mtp_h_input"); + + ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd; + + ggml_tensor * h_input = inp->embd; + ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens); + cb(tok_embd, "mtp_tok_embd", il); + + res->add_input(std::move(inp)); + + ggml_tensor * inp_pos = build_inp_pos(); + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); + cb(h_norm, "mtp_hnorm", il); + + ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il); + cb(e_norm, "mtp_enorm", il); + + ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, /*dim=*/ 0); + cb(concat, "mtp_concat", il); + + ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat); + cb(cur, "mtp_eh_proj", il); + + ggml_tensor * inpSA = cur; + + cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_norm", il); + + ggml_tensor * Qcur_full = build_lora_mm(layer.wq, cur, layer.wq_s); + cb(Qcur_full, "mtp_Qcur_full", il); + + ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, + n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + 0); + Qcur = build_norm(Qcur, layer.attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "mtp_Qcur_normed", il); + + ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full, + n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + ggml_element_size(Qcur_full) * n_embd_head); + gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); + cb(gate, "mtp_gate", il); + + ggml_tensor * Kcur = build_lora_mm(layer.wk, cur, layer.wk_s); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = build_norm(Kcur, layer.attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "mtp_Kcur_normed", il); + + ggml_tensor * Vcur = build_lora_mm(layer.wv, cur, layer.wv_s); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + cb(Vcur, "mtp_Vcur", il); + + Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + const float kq_scale = hparams.f_attention_scale == 0.0f + ? 
1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + cur = build_attn(inp_attn, + nullptr, nullptr, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "mtp_attn_pregate", il); + + cur = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate)); + cur = build_lora_mm(layer.wo, cur, layer.wo_s); + cb(cur, "mtp_attn_out", il); + + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "mtp_attn_residual", il); + + ggml_tensor * ffn_residual = cur; + cur = build_norm(cur, layer.attn_post_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_post_norm", il); + + cur = build_ffn(cur, + layer.ffn_up, nullptr, layer.ffn_up_s, + layer.ffn_gate, nullptr, layer.ffn_gate_s, + layer.ffn_down, nullptr, layer.ffn_down_s, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "mtp_ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_residual); + cb(cur, "mtp_post_ffn", il); + + // Pre-norm hidden state: used by the AR draft loop to seed the next MTP step. + // (In the trunk graph this is `t_h_pre_norm`; the MTP head reuses the same slot.) + cb(cur, "h_pre_norm", -1); + res->t_h_pre_norm = cur; + + ggml_tensor * head_norm_w = layer.nextn.shared_head_norm + ? layer.nextn.shared_head_norm + : model.output_norm; + GGML_ASSERT(head_norm_w && "QWEN35_MTP: missing both nextn.shared_head_norm and output_norm"); + cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1); + cb(cur, "mtp_shared_head_norm", -1); + + ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output; + GGML_ASSERT(head_w && "QWEN35_MTP: missing LM head (nextn.shared_head_head or model.output)"); + cur = build_lora_mm(head_w, cur); + cb(cur, "result_output", -1); + + res->t_logits = cur; + ggml_build_forward_expand(gf, cur); +} diff --git a/src/models/qwen35moe.cpp b/src/models/qwen35moe.cpp index cf05dc9d61c..5912aa38153 100644 --- a/src/models/qwen35moe.cpp +++ b/src/models/qwen35moe.cpp @@ -15,16 +15,23 @@ void llama_model_qwen35moe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - // Mark recurrent layers (linear attention layers) + // NextN/MTP (Qwen3.5/3.6): extra decoder block appended beyond the main stack + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + + // Mark recurrent layers (linear attention layers). MTP layers are dense + // attention-only and must be flagged non-recurrent. 
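+    // e.g. with full_attn_interval = 4 and one MTP block on a 48-layer trunk
+    // (counts hypothetical): trunk layers 3, 7, ..., 47 keep full attention,
+    // the remaining trunk layers are recurrent, and the MTP block at index 48
+    // is forced non-recurrent here.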
{ + const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; uint32_t full_attn_interval = 4; ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + hparams.recurrent_layer_arr[i] = (i < n_main) && ((i + 1) % full_attn_interval != 0); } } - switch (hparams.n_layer) { + switch (hparams.n_layer - hparams.nextn_predict_layers) { case 40: type = LLM_TYPE_35B_A3B; break; case 48: type = LLM_TYPE_122B_A10B; break; case 60: type = LLM_TYPE_397B_A17B; break; @@ -96,6 +103,16 @@ void llama_model_qwen35moe::load_arch_tensors(llama_model_loader &) { layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0); + + // NextN/MTP tensors (preserved but unused) - only bound on MTP layers + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, TENSOR_NOT_REQUIRED); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); + } } } @@ -124,7 +141,9 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_out_ids = build_inp_out_ids(); - for (int il = 0; il < n_layer; ++il) { + // MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass. 
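+    // e.g. n_layer = 49 with nextn_predict_layers = 1 runs il = 0..47 here;
+    // blk.48 is only executed by the sibling qwen35moe_mtp graph (counts
+    // hypothetical).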
+ const int n_transformer_layers = n_layer - (int) hparams.nextn_predict_layers; + for (int il = 0; il < n_transformer_layers; ++il) { ggml_tensor * inpSA = inpL; cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); @@ -141,7 +160,7 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il); } - if (il == n_layer - 1 && inp_out_ids) { + if (il == n_transformer_layers - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -173,6 +192,9 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p } cur = inpL; + cb(cur, "h_pre_norm", -1); + res->t_h_pre_norm = cur; + // Final norm cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); diff --git a/src/models/qwen35moe_mtp.cpp b/src/models/qwen35moe_mtp.cpp new file mode 100644 index 00000000000..9f662213bee --- /dev/null +++ b/src/models/qwen35moe_mtp.cpp @@ -0,0 +1,252 @@ +#include "models.h" + +void llama_model_qwen35moe_mtp::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); + + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35MOE_MTP requires nextn_predict_layers > 0"); + GGML_ASSERT(hparams.nextn_predict_layers <= hparams.n_layer); + GGML_ASSERT(hparams.n_expert > 0 && "QWEN35MOE_MTP requires n_expert > 0"); + + // only the MTP layers get a KV cache, trunk layers are skipped. + hparams.kv_only_nextn = true; + hparams.n_layer_kv_from_start = -1; + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = false; + } + + type = LLM_TYPE_UNKNOWN; +} + +void llama_model_qwen35moe_mtp::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, TENSOR_NOT_REQUIRED); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + if (output == nullptr) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + } + + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff; + + const uint32_t n_main = n_layer - hparams.nextn_predict_layers; + for (int i = 0; i < n_layer; ++i) { + if (static_cast(i) < n_main) { + continue; // trunk layer — owned by the sibling QWEN35MOE model + } + + auto & layer = layers[i]; + + // MTP block looks like a full-attention Qwen3.5 decoder block with MoE FFN. 
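For orientation before the tensor list below: the MTP block's distinguishing step is its input combine. It RMS-normalizes the next token's embedding and the trunk's hidden state separately, concatenates them, and projects 2*n_embd back to n_embd via eh_proj. A scalar sketch under stand-in dimensions and weights (the averaging projection is a placeholder; the real eh_proj is a learned matrix):

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// sketch of the enorm/hnorm -> concat -> eh_proj combine; all values are stand-ins
static std::vector<float> rms_norm(const std::vector<float> & x, float eps = 1e-6f) {
    float ss = 0.0f;
    for (float v : x) ss += v * v;
    const float s = 1.0f / std::sqrt(ss / x.size() + eps);
    std::vector<float> y(x.size());
    for (size_t i = 0; i < x.size(); ++i) y[i] = x[i] * s;
    return y;
}

int main() {
    const size_t n_embd = 4;
    std::vector<float> e = {0.5f, -1.0f, 0.25f, 2.0f}; // embedding of token t_{p+1} (enorm input)
    std::vector<float> h = {1.0f,  0.0f, -0.5f, 0.5f}; // trunk hidden state at pos p (hnorm input)

    std::vector<float> cat = rms_norm(e);        // mtp_enorm
    std::vector<float> hn  = rms_norm(h);        // mtp_hnorm
    cat.insert(cat.end(), hn.begin(), hn.end()); // mtp_concat: [2*n_embd]

    // eh_proj stand-in: average the two halves (the model applies a learned 2*n_embd x n_embd matrix)
    for (size_t i = 0; i < n_embd; ++i) {
        std::printf("x[%zu] = %f\n", i, 0.5f * (cat[i] + cat[i + n_embd]));
    }
}
```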
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + + // Routed experts + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); + create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); + + // Shared experts + layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0); + + // NextN-specific tensors that define the MTP block. + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, 0); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, 0); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, 0); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); + } +} + +std::unique_ptr llama_model_qwen35moe_mtp::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +// LLM_ARCH_QWEN35MOE_MTP draft head for Qwen3.5/3.6 MoE +llama_model_qwen35moe_mtp::graph::graph(const llama_model & model, const llm_graph_params & params) + : llm_graph_context(params) { + GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35MOE_MTP requires nextn_predict_layers > 0"); + GGML_ASSERT(hparams.nextn_predict_layers == 1 && "QWEN35MOE_MTP currently only supports a single MTP block"); + + const int64_t n_embd_head = hparams.n_embd_head_v(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + + const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers; + const auto & layer = model.layers[il]; + + GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj"); + GGML_ASSERT(layer.nextn.enorm && "MTP block missing nextn.enorm"); + GGML_ASSERT(layer.nextn.hnorm && "MTP block missing nextn.hnorm"); + GGML_ASSERT(layer.ffn_gate_inp && "MTP block missing ffn_gate_inp"); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + auto inp = std::make_unique(hparams.n_embd); + + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_input(inp->tokens); + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, 
n_tokens); + ggml_set_input(inp->embd); + ggml_set_name(inp->embd, "mtp_h_input"); + + ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd; + + ggml_tensor * h_input = inp->embd; + ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens); + cb(tok_embd, "mtp_tok_embd", il); + + res->add_input(std::move(inp)); + + ggml_tensor * inp_pos = build_inp_pos(); + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); + cb(h_norm, "mtp_hnorm", il); + + ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il); + cb(e_norm, "mtp_enorm", il); + + ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, /*dim=*/ 0); + cb(concat, "mtp_concat", il); + + ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat); + cb(cur, "mtp_eh_proj", il); + + ggml_tensor * inpSA = cur; + + cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_norm", il); + + ggml_tensor * Qcur_full = build_lora_mm(layer.wq, cur, layer.wq_s); + cb(Qcur_full, "mtp_Qcur_full", il); + + ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, + n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + 0); + Qcur = build_norm(Qcur, layer.attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "mtp_Qcur_normed", il); + + ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full, + n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + ggml_element_size(Qcur_full) * n_embd_head); + gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); + cb(gate, "mtp_gate", il); + + ggml_tensor * Kcur = build_lora_mm(layer.wk, cur, layer.wk_s); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = build_norm(Kcur, layer.attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "mtp_Kcur_normed", il); + + ggml_tensor * Vcur = build_lora_mm(layer.wv, cur, layer.wv_s); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + cb(Vcur, "mtp_Vcur", il); + + Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + const float kq_scale = hparams.f_attention_scale == 0.0f + ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + cur = build_attn(inp_attn, + nullptr, nullptr, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "mtp_attn_pregate", il); + + cur = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate)); + cur = build_lora_mm(layer.wo, cur, layer.wo_s); + cb(cur, "mtp_attn_out", il); + + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "mtp_attn_residual", il); + + ggml_tensor * ffn_residual = cur; + cur = build_norm(cur, layer.attn_post_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_post_norm", il); + + // MoE FFN — routed experts plus gated shared expert (mirrors qwen35moe). 
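The combine built in the hunk below reduces to out = moe(x) + sigmoid(w_gate . x) * shexp(x). A scalar sketch with stand-in values for the three contributions:

```cpp
#include <cmath>
#include <cstdio>

// shared-expert gating sketch; all inputs are illustrative scalars
int main() {
    const float moe_out    = 0.80f; // routed experts' contribution
    const float shexp_out  = 0.30f; // shared expert's contribution
    const float gate_logit = 1.25f; // ffn_gate_inp_shexp . x

    const float gate = 1.0f / (1.0f + std::exp(-gate_logit)); // mtp_shared_expert_gate_sigmoid
    std::printf("out = %f\n", moe_out + gate * shexp_out);    // mtp_ffn_out, before the residual add
}
```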
+ ggml_tensor * moe_out = + build_moe_ffn(cur, + layer.ffn_gate_inp, + layer.ffn_up_exps, + layer.ffn_gate_exps, + layer.ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + hparams.expert_weights_scale, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il, + nullptr, layer.ffn_gate_up_exps, + layer.ffn_up_exps_s, + layer.ffn_gate_exps_s, + layer.ffn_down_exps_s); + cb(moe_out, "mtp_ffn_moe_out", il); + + if (layer.ffn_up_shexp != nullptr) { + ggml_tensor * ffn_shexp = + build_ffn(cur, + layer.ffn_up_shexp, nullptr, layer.ffn_up_shexp_s, + layer.ffn_gate_shexp, nullptr, layer.ffn_gate_shexp_s, + layer.ffn_down_shexp, nullptr, layer.ffn_down_shexp_s, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "mtp_ffn_shexp", il); + + ggml_tensor * shared_gate = build_lora_mm(layer.ffn_gate_inp_shexp, cur); + shared_gate = ggml_sigmoid(ctx0, shared_gate); + cb(shared_gate, "mtp_shared_expert_gate_sigmoid", il); + + ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate); + cb(ffn_shexp, "mtp_ffn_shexp_gated", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + } else { + cur = moe_out; + } + cb(cur, "mtp_ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_residual); + cb(cur, "mtp_post_ffn", il); + + // Pre-norm hidden state: used by the AR draft loop to seed the next MTP step. + cb(cur, "h_pre_norm", -1); + res->t_h_pre_norm = cur; + + ggml_tensor * head_norm_w = layer.nextn.shared_head_norm + ? layer.nextn.shared_head_norm + : model.output_norm; + GGML_ASSERT(head_norm_w && "QWEN35MOE_MTP: missing both nextn.shared_head_norm and output_norm"); + cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1); + cb(cur, "mtp_shared_head_norm", -1); + + ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output; + GGML_ASSERT(head_w && "QWEN35MOE_MTP: missing LM head (nextn.shared_head_head or model.output)"); + cur = build_lora_mm(head_w, cur); + cb(cur, "result_output", -1); + + res->t_logits = cur; + ggml_build_forward_expand(gf, cur); +} diff --git a/tests/test-llama-archs.cpp b/tests/test-llama-archs.cpp index 16af11a2862..fd0d3696d77 100644 --- a/tests/test-llama-archs.cpp +++ b/tests/test-llama-archs.cpp @@ -406,6 +406,9 @@ static bool arch_supported(const llm_arch arch) { if (arch == LLM_ARCH_DEEPSEEK2OCR) { return false; } + if (arch == LLM_ARCH_QWEN35_MTP || arch == LLM_ARCH_QWEN35MOE_MTP) { + return false; // MTP-only arch; requires a sibling trunk model and cannot run standalone. + } // FIXME some models are segfaulting with WebGPU: #ifdef GGML_USE_WEBGPU diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index ce743e6656d..008a640bcbd 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -57,6 +57,11 @@ struct server_slot { llama_context * ctx_tgt = nullptr; llama_context * ctx_dft = nullptr; + // True when this slot's speculative impl is MTP (ctx_dft is the MTP head). + // MTP needs every prefill position to carry logits=1 so the streaming + // hook in common_speculative_state_mtp::process() can read t_h_pre_norm. 
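Concretely, the only difference between a plain prefill batch and an MTP prefill batch is the per-token logits flag. A sketch against the common helpers this patch already uses (the `add_prompt` wrapper itself is hypothetical):

```cpp
#include "common.h"
#include <vector>

// hypothetical helper: fill a prefill batch, flipping the logits flag per mode
static void add_prompt(llama_batch & batch, const std::vector<llama_token> & prompt,
                       llama_seq_id seq, bool mtp_enabled) {
    for (size_t i = 0; i < prompt.size(); ++i) {
        const bool last = i + 1 == prompt.size();
        // plain prefill: logits only on the last token;
        // MTP prefill: logits on every position, so the process() hook gets
        // one pre-norm row per prompt token
        common_batch_add(batch, prompt[i], (llama_pos) i, { seq }, mtp_enabled || last);
    }
}
```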
+ bool is_mtp_enabled = false; + // multimodal mtmd_context * mctx = nullptr; @@ -237,8 +242,20 @@ struct server_slot { (ggml_time_us() - t_start) / 1000.0, n_text, (int) prompt.tokens.size()); } + bool is_mtp() const { return is_mtp_enabled; } + + // The trunk needs to emit logits at every prefill position when either: + // - the task asked for embeddings, or + // - MTP is enabled for this slot (the streaming hook in process() reads + // h_pre_norm at every prompt position). + bool need_embd() const { + GGML_ASSERT(task); + return task->need_embd() || is_mtp(); + } + // if the context does not have a memory module then all embeddings have to be computed within a single ubatch // also we cannot split if the pooling would require any past tokens + // (MTP supports splitting — uses task->need_embd() not need_embd()) bool can_split() const { GGML_ASSERT(task); @@ -743,6 +760,53 @@ struct server_context_impl { ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get()); + params_base.speculative.draft.ctx_tgt = ctx_tgt; + params_base.speculative.draft.ctx_dft = ctx_dft.get(); + } else if (params_base.speculative.type == COMMON_SPECULATIVE_TYPE_MTP) { + // MTP head lives in the *target* GGUF — load it as a sibling model + // with override_arch and feed it through the existing ctx_dft slot. + char trunk_arch[64] = {0}; + llama_model_meta_val_str(model_tgt, "general.architecture", trunk_arch, sizeof(trunk_arch)); + + const char * mtp_arch = nullptr; + if (std::string(trunk_arch) == "qwen35") { + mtp_arch = "qwen35_mtp"; + } else if (std::string(trunk_arch) == "qwen35moe") { + mtp_arch = "qwen35moe_mtp"; + } else { + SRV_ERR("MTP not supported for trunk architecture '%s'\n", trunk_arch); + return false; + } + + if (params_base.n_parallel > 1) { + SRV_ERR("MTP currently supports only n_parallel=1; got %d\n", params_base.n_parallel); + return false; + } + + SRV_INF("loading MTP head from '%s' (override_arch=%s)\n", + params_base.model.path.c_str(), mtp_arch); + + auto mparams_mtp = common_model_params_to_llama(params_base); + mparams_mtp.override_arch = mtp_arch; + + model_dft.reset(llama_model_load_from_file(params_base.model.path.c_str(), mparams_mtp)); + if (model_dft == nullptr) { + SRV_ERR("failed to load MTP head from '%s'\n", params_base.model.path.c_str()); + return false; + } + + auto cparams_mtp = common_context_params_to_llama(params_base); + cparams_mtp.n_ctx = llama_n_ctx_seq(ctx_tgt); + cparams_mtp.n_seq_max = 1; + + ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams_mtp)); + if (ctx_dft == nullptr) { + SRV_ERR("%s", "failed to create MTP context\n"); + return false; + } + + ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get()); + params_base.speculative.draft.ctx_tgt = ctx_tgt; params_base.speculative.draft.ctx_dft = ctx_dft.get(); } @@ -855,6 +919,7 @@ struct server_context_impl { slot.ctx_tgt = ctx_tgt; slot.ctx_dft = ctx_dft.get(); slot.spec = spec.get(); + slot.is_mtp_enabled = (params_base.speculative.type == COMMON_SPECULATIVE_TYPE_MTP) && (ctx_dft != nullptr); slot.n_ctx = n_ctx_slot; slot.mctx = mctx; @@ -2716,12 +2781,14 @@ struct server_context_impl { break; } - // embedding requires all tokens in the batch to be output + // embedding requires all tokens in the batch to be output; + // MTP also wants logits at every prompt position so the + // streaming hook can mirror t_h_pre_norm into ctx_dft. 
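The staging API this hook reads from (declared in src/llama-ext.h earlier in this series) follows the usual opt-in-then-read pattern of llama.cpp's embedding getters. A sketch of the read side, assuming `llama_set_embeddings_pre_norm(ctx_tgt, true)` was set before a decode in which every position carried logits=1:

```cpp
#include "../src/llama-ext.h"
#include <cstdio>

// sketch: read one pre-norm hidden row per batch position after a decode
static void dump_pre_norm_rows(llama_context * ctx_tgt, int32_t n_tokens, int32_t n_embd) {
    for (int32_t i = 0; i < n_tokens; ++i) {
        const float * h = llama_get_embeddings_pre_norm_ith(ctx_tgt, i);
        if (h == nullptr) {
            break; // position i produced no output row (its logits flag was 0)
        }
        // h points at n_embd floats: the hidden state *before* the final norm,
        // which is exactly what the MTP head conditions on
        std::printf("row %d: h[0] = %f (of %d floats)\n", i, h[0], n_embd);
    }
}
```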
common_batch_add(batch, cur_tok, slot.prompt.tokens.pos_next(), { slot.id }, - slot.task->need_embd()); + slot.need_embd()); slot.prompt.tokens.push_back(cur_tok); slot.n_prompt_tokens_processed++; @@ -2838,7 +2905,7 @@ struct server_context_impl { slot_batched->lora[alora_disabled_id].scale = alora_scale; } - llama_set_embeddings(ctx_tgt, slot_batched->task->need_embd()); + llama_set_embeddings(ctx_tgt, slot_batched->need_embd()); } if (batch.n_tokens == 0) { From edcaf120d9737c5a6f1cae99cdd4ae5862e1ed16 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Mon, 11 May 2026 12:22:37 +0800 Subject: [PATCH 02/18] fix batch size --- common/speculative.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index c60082ae68a..132313becd6 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -401,11 +401,11 @@ struct common_speculative_state_mtp : public common_speculative_impl { n_embd = llama_model_n_embd(llama_get_model(ctx_dft)); - const int32_t n_ub = (int32_t) llama_n_ubatch(ctx_dft); - batch = llama_batch_init(/*n_tokens=*/ n_ub, /*embd=*/ n_embd, /*n_seq_max=*/ 1); + const int32_t n_b = (int32_t) llama_n_batch(ctx_dft); + batch = llama_batch_init(/*n_tokens=*/ n_b, /*embd=*/ n_embd, /*n_seq_max=*/ 1); // llama_batch_init allocates only one of token/embd; MTP needs both. // TODO: fix, how to call without malloc - batch.token = (llama_token *) malloc(sizeof(llama_token) * n_ub); + batch.token = (llama_token *) malloc(sizeof(llama_token) * n_b); smpls.resize(n_seq); for (auto & s : smpls) { From 180f8ff346c0b583a12693e6d310018e12902632 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Mon, 11 May 2026 17:09:05 +0800 Subject: [PATCH 03/18] rename files --- src/models/{qwen35_mtp.cpp => qwen35-mtp.cpp} | 0 src/models/{qwen35moe_mtp.cpp => qwen35moe-mtp.cpp} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename src/models/{qwen35_mtp.cpp => qwen35-mtp.cpp} (100%) rename src/models/{qwen35moe_mtp.cpp => qwen35moe-mtp.cpp} (100%) diff --git a/src/models/qwen35_mtp.cpp b/src/models/qwen35-mtp.cpp similarity index 100% rename from src/models/qwen35_mtp.cpp rename to src/models/qwen35-mtp.cpp diff --git a/src/models/qwen35moe_mtp.cpp b/src/models/qwen35moe-mtp.cpp similarity index 100% rename from src/models/qwen35moe_mtp.cpp rename to src/models/qwen35moe-mtp.cpp From 5c58cc5bddb000b5267de1f8189f8ae0a03150fe Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 11 May 2026 17:26:03 +0300 Subject: [PATCH 04/18] cont : simplify (#7) --- common/speculative.cpp | 344 ++++++++++++++------------------ tools/server/server-context.cpp | 14 +- 2 files changed, 157 insertions(+), 201 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index 132313becd6..94088c65439 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -379,22 +379,14 @@ struct common_speculative_state_mtp : public common_speculative_impl { // The last h-row of one process() call needs the first token of the NEXT // call to pair with, so it's stashed here until that next call fires. std::vector> pending_h; // [n_seq][n_embd] - std::vector pending_pos; // [n_seq] - std::vector last_n_drafted; - std::vector last_n_accepted; - - // Number of trunk output rows produced by the most recent process() call. - // Used by draft() for the first AR step (when last_n_accepted is -1) to - // pick the last prefill row out of ctx_tgt's pre-norm buffer. 
- std::vector last_trunk_n_outputs; + std::vector i_batch_beg; + std::vector i_batch_end; common_speculative_state_mtp(const common_params_speculative & params, uint32_t n_seq) : common_speculative_impl(COMMON_SPECULATIVE_TYPE_MTP, n_seq) , params(params.draft) { - GGML_ASSERT(n_seq == 1 && "MTP currently supports only single-sequence speculation"); - auto * ctx_tgt = this->params.ctx_tgt; auto * ctx_dft = this->params.ctx_dft; GGML_ASSERT(ctx_tgt && ctx_dft && "MTP requires ctx_tgt and ctx_dft to be set"); @@ -411,7 +403,7 @@ struct common_speculative_state_mtp : public common_speculative_impl { for (auto & s : smpls) { common_params_sampling sparams; sparams.no_perf = false; - sparams.top_k = 1; + sparams.top_k = 10; sparams.samplers = { COMMON_SAMPLER_TYPE_TOP_K }; s.reset(common_sampler_init(llama_get_model(ctx_dft), sparams)); } @@ -420,11 +412,9 @@ struct common_speculative_state_mtp : public common_speculative_impl { llama_set_embeddings_pre_norm(ctx_dft, true); pending_h.assign(n_seq, std::vector(n_embd, 0.0f)); - pending_pos.assign(n_seq, -1); - last_n_drafted.assign(n_seq, 0); - last_n_accepted.assign(n_seq, -1); - last_trunk_n_outputs.assign(n_seq, 0); + i_batch_beg.assign(n_seq, -1); + i_batch_end.assign(n_seq, -1); } ~common_speculative_state_mtp() override { @@ -436,12 +426,6 @@ struct common_speculative_state_mtp : public common_speculative_impl { } void begin(llama_seq_id seq_id, const llama_tokens & prompt) override { - GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < pending_pos.size()); - - last_n_accepted[seq_id] = -1; - last_n_drafted [seq_id] = 0; - pending_pos [seq_id] = -1; - const int32_t N = (int32_t) prompt.size(); if (N <= 0) { return; @@ -462,231 +446,207 @@ struct common_speculative_state_mtp : public common_speculative_impl { return true; } - // Single-seq for now (asserted in ctor). Future: bucket by seq_id. - // TODO: how to make it work with vision tokens? if (batch_in.token == nullptr || batch_in.embd != nullptr) { - pending_pos[seq_id] = -1; return true; } - auto * ctx_tgt = this->params.ctx_tgt; - auto * ctx_dft = this->params.ctx_dft; + const int32_t n_tokens = batch_in.n_tokens; - const int32_t n_rows = batch_in.n_tokens; - const llama_pos pos_start = batch_in.pos[0]; + // remember the first and last batch index for each sequence + std::fill(i_batch_beg.begin(), i_batch_beg.end(), -1); + std::fill(i_batch_end.begin(), i_batch_end.end(), -1); - const llama_pos pos_max_dft = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id); - if (pos_start <= pos_max_dft) { - return true; - } + for (int k = 0; k < n_tokens; ++k) { + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + GGML_ASSERT(batch_in.n_seq_id[k] == 1); - // Stale pending: discard if the new batch doesn't start one past it. - const bool pending_continues = - pending_pos[seq_id] >= 0 && pending_pos[seq_id] + 1 == pos_start; - if (pending_pos[seq_id] >= 0 && !pending_continues) { - pending_pos[seq_id] = -1; + if (batch_in.seq_id[k][0] == seq_id) { + i_batch_end[seq_id] = k; + if (i_batch_beg[seq_id] < 0) { + i_batch_beg[seq_id] = k; + } + } + } } - // Build a paired hook batch: - // row 0 = (pending_h, batch_in.token[0]) at pos_start if pending_continues - // rows 1..n_rows-1 = (h_k from this batch, batch_in.token[k+1]) at pos[k+1] - // The last h-row (h_{n_rows-1}) is stashed as the new pending and is *not* - // decoded this call — it waits for the next batch's first token to pair. 
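The i_batch_beg/i_batch_end scan above is easiest to see on a toy batch; note that the shift-by-one embedding copy in the continuation below relies on each sequence's rows being contiguous:

```cpp
#include <cstdio>
#include <vector>

// sketch of the first/last-row bookkeeping, with an illustrative batch layout
int main() {
    const std::vector<int> seq_of_row = {0, 0, 0, 1, 1}; // seq_id per batch row, contiguous per seq
    const int n_seq = 2;

    std::vector<int> beg(n_seq, -1), end(n_seq, -1);
    for (int k = 0; k < (int) seq_of_row.size(); ++k) {
        const int s = seq_of_row[k];
        if (beg[s] < 0) beg[s] = k;
        end[s] = k;
    }
    for (int s = 0; s < n_seq; ++s) {
        // row beg[s] receives the carried-over pending_h;
        // row end[s] refills pending_h for the next process() call
        std::printf("seq %d: rows %d..%d\n", s, beg[s], end[s]);
    }
}
```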
+ auto * ctx_tgt = this->params.ctx_tgt; + auto * ctx_dft = this->params.ctx_dft; + const size_t row_bytes = (size_t) n_embd * sizeof(float); common_batch_clear(batch); - int out_idx = 0; - - auto add_pair = [&](const float * h_row, llama_token tok, llama_pos pos) { - std::memcpy(batch.embd + (size_t) out_idx * n_embd, h_row, row_bytes); - batch.token [out_idx] = tok; - batch.pos [out_idx] = pos; - batch.n_seq_id[out_idx] = 1; - batch.seq_id [out_idx][0] = seq_id; - batch.logits [out_idx] = 0; - ++out_idx; - }; - if (pending_continues) { - add_pair(pending_h[seq_id].data(), batch_in.token[0], pos_start); + for (int k = 0; k < n_tokens; ++k) { + common_batch_add(batch, batch_in.token[k], batch_in.pos[k], { batch_in.seq_id[k][0] }, 0); } - // TODO: is there is a fast way to build this batch? - for (int k = 0; k + 1 < n_rows; ++k) { - if (batch_in.logits[k] == 0) { - LOG_WRN("%s: batch_in.logits[%d] == 0 (need_embd / logits=1 missing on prefill); stopping hook at this row\n", - __func__, k); - break; - } - const float * h_k = llama_get_embeddings_pre_norm_ith(ctx_tgt, k); - if (h_k == nullptr) { - LOG_WRN("%s: ctx_tgt has no pre-norm row at i=%d; stopping hook\n", __func__, k); - break; - } - add_pair(h_k, batch_in.token[k + 1], batch_in.pos[k + 1]); - } + // shift the tgt embeddings to the right by one position + // assumes that the tokens in the batch are sequential for each sequence + // i.e. we cannot have seq_id like this: [0, 0, 0, 1, 1, 0, 1, 1] + // ^--- this is a problem + // TODO: this is generally true, but would be nice to assert it + { + const float * h_tgt = llama_get_embeddings_pre_norm(ctx_tgt); + std::memcpy(batch.embd + (size_t) 1 * n_embd, h_tgt, row_bytes * (n_tokens-1)); + + //{ + // // string with seq_ids in the batch + // std::stringstream ss; + // for (int i = 0; i < n_tokens; ++i) { + // ss << batch_in.seq_id[i][0] << ","; + // } + // LOG_WRN("%s: batch_in.seq_id = %s\n", __func__, ss.str().c_str()); + //} + } + + // fill the pending embeddings from a previous run + auto set_h = [&](int idx, const float * h_row) { + std::memcpy(batch.embd + (size_t) idx * n_embd, h_row, row_bytes); + }; - if (out_idx > 0) { - batch.n_tokens = out_idx; - const int32_t rc = llama_decode(ctx_dft, batch); - if (rc != 0) { - LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d, n=%d)\n", - __func__, (int) rc, (int) pos_start, out_idx); - return false; + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + if (i_batch_beg[seq_id] < 0) { + continue; } + + set_h(i_batch_beg[seq_id], pending_h[seq_id].data()); } - // so that draft() (when last_n_accepted < 0) can find the last pre-norm row of this batch. - // We assume every batch position has logits=1 (server sets need_embd - // for MTP slots) → n_outputs == n_tokens. - last_trunk_n_outputs[seq_id] = n_rows; + const int32_t rc = llama_decode(ctx_dft, batch); + if (rc != 0) { + LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d)\n", __func__, (int) rc, (int) batch_in.pos[0]); + return false; + } - // Stash the last h-row (h_{n_rows-1}) as the new pending for the next - // process() call's first token to pair with. 
- if (batch_in.logits[n_rows - 1] != 0) { - const float * h_last = llama_get_embeddings_pre_norm_ith(ctx_tgt, n_rows - 1); - if (h_last != nullptr) { - std::memcpy(pending_h[seq_id].data(), h_last, row_bytes); - pending_pos[seq_id] = batch_in.pos[n_rows - 1]; - } else { - pending_pos[seq_id] = -1; + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + if (i_batch_end[seq_id] < 0) { + continue; } - } else { - // No trunk output at the tail — can't carry over. - pending_pos[seq_id] = -1; + + const float * h_last = llama_get_embeddings_pre_norm_ith(ctx_tgt, i_batch_end[seq_id]); + std::memcpy(pending_h[seq_id].data(), h_last, row_bytes); } return true; } void draft(common_speculative_draft_params_vec & dparams) override { - // Single-seq for now (asserted in ctor). Future: iterate over dparams. - const llama_seq_id seq_id = 0; - if ((size_t) seq_id >= dparams.size()) { - return; + auto & ctx_dft = params.ctx_dft; + + common_batch_clear(batch); + + // keep track of which sequences are still drafting + int n_drafting = 0; + std::vector drafting(n_seq); + + const float * h_row = nullptr; + const size_t row_bytes = (size_t) n_embd * sizeof(float); + + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + auto & dp = dparams[seq_id]; + + if (!dp.drafting) { + continue; + } + + n_drafting++; + drafting[seq_id] = true; + common_sampler_reset(smpls[seq_id].get()); + + common_batch_add(batch, dp.id_last, dp.n_past, { seq_id }, true); + + h_row = pending_h[seq_id].data(); + std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes); } - auto & dp = dparams[seq_id]; - if (!dp.drafting) { + + int ret = llama_decode(ctx_dft, batch); + if (ret != 0) { + LOG_WRN("%s: llama_decode returned %d\n", __func__, ret); return; } - auto * ctx_tgt = this->params.ctx_tgt; - auto * ctx_dft = this->params.ctx_dft; - auto * smpl = smpls[seq_id].get(); - - GGML_ASSERT(dp.result != nullptr); - auto & draft_tokens = *dp.result; - draft_tokens.clear(); - - if (last_n_drafted[seq_id] > 0) { - const int32_t n_to_drop = (int32_t) last_n_drafted[seq_id] - 1; - if (n_to_drop > 0) { - const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id); - if (pos_max >= 0) { - const llama_pos drop_from = pos_max - n_to_drop + 1; - llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, drop_from, -1); + int i = 0; + + while (n_drafting > 0) { + int i_batch = 0; + + common_batch_clear(batch); + + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + if (!drafting[seq_id]) { + continue; } - } - last_n_drafted[seq_id] = 0; - last_n_accepted[seq_id] = 0; - } - // Effective draft length: min(global cap, per-sequence override). 
- int32_t n_max = std::max(1, params.n_max); - if (dp.n_max > 0) { - n_max = std::min(n_max, dp.n_max); - } + auto * smpl = smpls[seq_id].get(); - const size_t row_bytes = (size_t) n_embd * sizeof(float); + common_sampler_sample(smpl, ctx_dft, i_batch, true); + h_row = llama_get_embeddings_pre_norm_ith(ctx_dft, i_batch); + ++i_batch; - common_sampler_reset(smpl); + const auto * cur_p = common_sampler_get_candidates(smpl, true); - llama_token cond_tok = dp.id_last; - llama_pos pos = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id) + 1; + for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { + LOG_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", + seq_id, k, i, cur_p->data[k].id, cur_p->data[k].p, + common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str()); + } - for (int32_t k = 0; k < n_max; ++k) { - const float * h_row = nullptr; + // add drafted token for each sequence + const llama_token id = cur_p->data[0].id; - if (k == 0) { - // Condition on the trunk's pre-norm row. - int32_t row_idx; - if (last_n_accepted[seq_id] < 0) { - // First draft after begin(): use the last prefill row. - row_idx = last_trunk_n_outputs[seq_id] - 1; - } else { - // After accept(n_accepted): row of the next conditioning - // position in the trunk's verify batch. - row_idx = last_n_accepted[seq_id]; + // only collect very high-confidence draft tokens + if (cur_p->data[0].p < params.p_min) { + drafting[seq_id] = false; + n_drafting--; + + continue; } - if (row_idx < 0) { - LOG_WRN("%s: no trunk pre-norm row available (row_idx=%d); stopping chain\n", - __func__, row_idx); - break; + + common_sampler_accept(smpl, id, true); + + auto & dp = dparams.at(seq_id); + auto & result = *dp.result; + + result.push_back(id); + + if ((params.n_max <= (int) result.size()) || + (dp.n_max > 0 && dp.n_max <= (int) result.size())) { + drafting[seq_id] = false; + n_drafting--; + continue; } - h_row = llama_get_embeddings_pre_norm_ith(ctx_tgt, row_idx); - } else { - // AR step: condition on the MTP head's own pre-norm row from - // the just-completed single-token decode. n_outputs=1 there, - // so the row is at batch position 0. - h_row = llama_get_embeddings_pre_norm_ith(ctx_dft, 0); + + common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true); + std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes); } - if (h_row == nullptr) { - LOG_WRN("%s: missing pre-norm row at k=%d; stopping chain\n", __func__, k); + if (batch.n_tokens == 0) { break; } - // 1-token batch carrying both (token, h_pre_norm). 
- common_batch_clear(batch); - std::memcpy(batch.embd, h_row, row_bytes); - batch.token [0] = cond_tok; - batch.pos [0] = pos; - batch.n_seq_id[0] = 1; - batch.seq_id [0][0] = seq_id; - batch.logits [0] = 1; // need logits for sampling - batch.n_tokens = 1; - - const int32_t rc = llama_decode(ctx_dft, batch); - if (rc != 0) { - LOG_WRN("%s: llama_decode(ctx_dft) failed rc=%d at k=%d; stopping chain\n", - __func__, rc, k); + // evaluate the drafted tokens on the draft model + ret = llama_decode(ctx_dft, batch); + if (ret != 0) { + LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret); break; } - const llama_token best = common_sampler_sample(smpl, ctx_dft, 0); - common_sampler_accept(smpl, best, /*is_generated=*/ false); - draft_tokens.push_back(best); - cond_tok = best; - ++pos; + ++i; } - last_n_drafted[seq_id] = (uint16_t) draft_tokens.size(); - } - - void accept(llama_seq_id seq_id, uint16_t n_accepted) override { - GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < last_n_drafted.size()); - - auto * ctx_dft = this->params.ctx_dft; - - const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id); - const int32_t n_drafted_last = (int32_t) last_n_drafted[seq_id]; - - const int32_t n_to_drop = std::max(0, n_drafted_last - (int32_t) n_accepted - 1); - - if (pos_max < 0) { - last_n_accepted[seq_id] = (int32_t) n_accepted; - return; - } + for (auto & dp : dparams) { + if (!dp.drafting) { + continue; + } - if (n_to_drop > 0) { - const llama_pos drop_from = pos_max - n_to_drop + 1; - llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, drop_from, -1); + if (dp.result->size() < (size_t) params.n_min) { + dp.result->clear(); + } } + } - last_n_drafted [seq_id] = 0; - last_n_accepted[seq_id] = (int32_t) n_accepted; + void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { } }; diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 008a640bcbd..9430c31b4c5 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -762,7 +762,8 @@ struct server_context_impl { params_base.speculative.draft.ctx_tgt = ctx_tgt; params_base.speculative.draft.ctx_dft = ctx_dft.get(); - } else if (params_base.speculative.type == COMMON_SPECULATIVE_TYPE_MTP) { + } else if (std::find(params_base.speculative.types.begin(), params_base.speculative.types.end(), + COMMON_SPECULATIVE_TYPE_MTP) != params_base.speculative.types.end()) { // MTP head lives in the *target* GGUF — load it as a sibling model // with override_arch and feed it through the existing ctx_dft slot. 
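In outline, the setup performed here opens the same GGUF twice: once normally for the trunk, and once with the architecture overridden so only the NextN tensors and the shared embeddings are bound. A sketch against the APIs this series adds, with error handling elided and the arch string chosen per trunk, as in the hunk below (note that PATCH 05 later replaces the per-arch override with a context type):

```cpp
#include "llama.h"

// sketch: load the MTP head as a sibling model from the trunk's own GGUF
static llama_context * load_mtp_head(const char * gguf_path) {
    llama_model_params mparams = llama_model_default_params();
    mparams.override_arch = "qwen35moe_mtp"; // per-trunk-arch, assumed here for a MoE trunk

    llama_model * model_mtp = llama_model_load_from_file(gguf_path, mparams);
    if (model_mtp == nullptr) {
        return nullptr;
    }

    llama_context_params cparams = llama_context_default_params();
    return llama_init_from_model(model_mtp, cparams); // becomes ctx_dft for the MTP impl
}
```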
char trunk_arch[64] = {0}; @@ -778,11 +779,6 @@ struct server_context_impl { return false; } - if (params_base.n_parallel > 1) { - SRV_ERR("MTP currently supports only n_parallel=1; got %d\n", params_base.n_parallel); - return false; - } - SRV_INF("loading MTP head from '%s' (override_arch=%s)\n", params_base.model.path.c_str(), mtp_arch); @@ -796,8 +792,6 @@ struct server_context_impl { } auto cparams_mtp = common_context_params_to_llama(params_base); - cparams_mtp.n_ctx = llama_n_ctx_seq(ctx_tgt); - cparams_mtp.n_seq_max = 1; ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams_mtp)); if (ctx_dft == nullptr) { @@ -919,7 +913,9 @@ struct server_context_impl { slot.ctx_tgt = ctx_tgt; slot.ctx_dft = ctx_dft.get(); slot.spec = spec.get(); - slot.is_mtp_enabled = (params_base.speculative.type == COMMON_SPECULATIVE_TYPE_MTP) && (ctx_dft != nullptr); + slot.is_mtp_enabled = (std::find(params_base.speculative.types.begin(), params_base.speculative.types.end(), + COMMON_SPECULATIVE_TYPE_MTP) != params_base.speculative.types.end()) + && (ctx_dft != nullptr); slot.n_ctx = n_ctx_slot; slot.mctx = mctx; From 3c3aebaaa055773ff5231d2ecdf074619661d83c Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Wed, 13 May 2026 11:12:20 +0800 Subject: [PATCH 05/18] MTP: clean-up (#9) * MTP: clean-up * review: use llama_context_type instead of llama_graph_type * review: remove llama_model_has_mtp * review: fix convert issues * convert: fix pycheck * review: formatting * use `mtp-` for identifying mtp models * convert: fix mtp conversion --- common/arg.cpp | 27 ++- common/download.cpp | 55 ++++-- common/download.h | 7 +- common/speculative.cpp | 2 +- convert_hf_to_gguf.py | 74 +++++++- include/llama.h | 6 + src/llama-arch.cpp | 2 - src/llama-arch.h | 2 - src/llama-context.cpp | 27 ++- src/llama-cparams.h | 1 + src/llama-graph.h | 1 + src/llama-memory.h | 3 + src/llama-model.cpp | 36 ++-- src/models/models.h | 30 +--- src/models/qwen35-mtp.cpp | 207 --------------------- src/models/qwen35.cpp | 254 +++++++++++++++++++++----- src/models/qwen35moe-mtp.cpp | 252 -------------------------- src/models/qwen35moe.cpp | 306 +++++++++++++++++++++++++++----- tests/test-llama-archs.cpp | 6 +- tools/server/server-context.cpp | 39 ++-- 20 files changed, 689 insertions(+), 648 deletions(-) delete mode 100644 src/models/qwen35-mtp.cpp delete mode 100644 src/models/qwen35moe-mtp.cpp diff --git a/common/arg.cpp b/common/arg.cpp index f1f4c12a3ce..3fc7bd88d39 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -335,11 +335,15 @@ static bool common_params_handle_remote_preset(common_params & params, llama_exa struct handle_model_result { bool found_mmproj = false; common_params_model mmproj; + + bool found_mtp = false; + common_params_model mtp; }; static handle_model_result common_params_handle_model(struct common_params_model & model, const std::string & bearer_token, - bool offline) { + bool offline, + bool search_mtp = false) { handle_model_result result; if (!model.docker_repo.empty()) { @@ -354,7 +358,7 @@ static handle_model_result common_params_handle_model(struct common_params_model common_download_opts opts; opts.bearer_token = bearer_token; opts.offline = offline; - auto download_result = common_download_model(model, opts, true); + auto download_result = common_download_model(model, opts, true, search_mtp); if (download_result.model_path.empty()) { LOG_ERR("error: failed to download model from Hugging Face\n"); @@ -368,6 +372,11 @@ static handle_model_result common_params_handle_model(struct common_params_model 
result.found_mmproj = true; result.mmproj.path = download_result.mmproj_path; } + + if (!download_result.mtp_path.empty()) { + result.found_mtp = true; + result.mtp.path = download_result.mtp_path; + } } else if (!model.url.empty()) { if (model.path.empty()) { auto f = string_split(model.url, '#').front(); @@ -436,7 +445,11 @@ static bool parse_bool_value(const std::string & value) { // void common_params_handle_models(common_params & params, llama_example curr_ex) { - auto res = common_params_handle_model(params.model, params.hf_token, params.offline); + const bool spec_type_mtp = std::find(params.speculative.types.begin(), + params.speculative.types.end(), + COMMON_SPECULATIVE_TYPE_MTP) != params.speculative.types.end(); + + auto res = common_params_handle_model(params.model, params.hf_token, params.offline, spec_type_mtp); if (params.no_mmproj) { params.mmproj = {}; } else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) { @@ -450,6 +463,14 @@ void common_params_handle_models(common_params & params, llama_example curr_ex) break; } } + // when --spec-type mtp is set and no draft model was provided explicitly, + // fall back to the MTP head discovered alongside the -hf model + if (spec_type_mtp && res.found_mtp && + params.speculative.draft.mparams.path.empty() && + params.speculative.draft.mparams.hf_repo.empty() && + params.speculative.draft.mparams.url.empty()) { + params.speculative.draft.mparams.path = res.mtp.path; + } common_params_handle_model(params.speculative.draft.mparams, params.hf_token, params.offline); common_params_handle_model(params.vocoder.model, params.hf_token, params.offline); } diff --git a/common/download.cpp b/common/download.cpp index d6d47b2d2fc..c1cbe2033aa 100644 --- a/common/download.cpp +++ b/common/download.cpp @@ -566,8 +566,11 @@ static hf_cache::hf_files get_split_files(const hf_cache::hf_files & files, return result; } -static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files, - const std::string & model) { +// pick the best sibling GGUF whose filename contains `keyword` (e.g. 
"mmproj" / "mtp"), +// preferring deeper shared directory prefix with the model, then closest quantization +static hf_cache::hf_file find_best_sibling(const hf_cache::hf_files & files, + const std::string & model, + const std::string & keyword) { hf_cache::hf_file best; size_t best_depth = 0; int best_diff = 0; @@ -579,20 +582,20 @@ static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files, for (const auto & f : files) { if (!string_ends_with(f.path, ".gguf") || - f.path.find("mmproj") == std::string::npos) { + f.path.find(keyword) == std::string::npos) { continue; } - auto mmproj_parts = string_split(f.path, '/'); - auto mmproj_dir = mmproj_parts.end() - 1; + auto sib_parts = string_split(f.path, '/'); + auto sib_dir = sib_parts.end() - 1; auto [_, dir] = std::mismatch(model_parts.begin(), model_dir, - mmproj_parts.begin(), mmproj_dir); - if (dir != mmproj_dir) { + sib_parts.begin(), sib_dir); + if (dir != sib_dir) { continue; } - size_t depth = dir - mmproj_parts.begin(); + size_t depth = dir - sib_parts.begin(); auto bits = extract_quant_bits(f.path); auto diff = std::abs(bits - model_bits); @@ -606,6 +609,16 @@ static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files, return best; } +static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files, + const std::string & model) { + return find_best_sibling(files, model, "mmproj"); +} + +static hf_cache::hf_file find_best_mtp(const hf_cache::hf_files & files, + const std::string & model) { + return find_best_sibling(files, model, "mtp-"); +} + static bool gguf_filename_is_model(const std::string & filepath) { if (!string_ends_with(filepath, ".gguf")) { return false; @@ -617,7 +630,8 @@ static bool gguf_filename_is_model(const std::string & filepath) { } return filename.find("mmproj") == std::string::npos && - filename.find("imatrix") == std::string::npos; + filename.find("imatrix") == std::string::npos && + filename.find("mtp-") == std::string::npos; } static hf_cache::hf_file find_best_model(const hf_cache::hf_files & files, @@ -673,11 +687,13 @@ struct hf_plan { hf_cache::hf_file primary; hf_cache::hf_files model_files; hf_cache::hf_file mmproj; + hf_cache::hf_file mtp; }; static hf_plan get_hf_plan(const common_params_model & model, const common_download_opts & opts, - bool download_mmproj) { + bool download_mmproj, + bool download_mtp) { hf_plan plan; hf_cache::hf_files all; @@ -723,6 +739,10 @@ static hf_plan get_hf_plan(const common_params_model & model, plan.mmproj = find_best_mmproj(all, primary.path); } + if (download_mtp) { + plan.mtp = find_best_mtp(all, primary.path); + } + return plan; } @@ -756,7 +776,8 @@ static std::vector get_url_tasks(const common_params_model & mode common_download_model_result common_download_model(const common_params_model & model, const common_download_opts & opts, - bool download_mmproj) { + bool download_mmproj, + bool download_mtp) { common_download_model_result result; std::vector tasks; hf_plan hf; @@ -764,13 +785,16 @@ common_download_model_result common_download_model(const common_params_model & bool is_hf = !model.hf_repo.empty(); if (is_hf) { - hf = get_hf_plan(model, opts, download_mmproj); + hf = get_hf_plan(model, opts, download_mmproj, download_mtp); for (const auto & f : hf.model_files) { tasks.push_back({f.url, f.local_path}); } if (!hf.mmproj.path.empty()) { tasks.push_back({hf.mmproj.url, hf.mmproj.local_path}); } + if (!hf.mtp.path.empty()) { + tasks.push_back({hf.mtp.url, hf.mtp.local_path}); + } } else if (!model.url.empty()) { tasks = 
get_url_tasks(model); } else { @@ -807,6 +831,10 @@ common_download_model_result common_download_model(const common_params_model & if (!hf.mmproj.path.empty()) { result.mmproj_path = hf_cache::finalize_file(hf.mmproj); } + + if (!hf.mtp.path.empty()) { + result.mtp_path = hf_cache::finalize_file(hf.mtp); + } } else { result.model_path = model.path; } @@ -946,7 +974,8 @@ std::vector common_list_cached_models() { for (const auto & f : files) { auto split = get_gguf_split_info(f.path); if (split.index != 1 || split.tag.empty() || - split.prefix.find("mmproj") != std::string::npos) { + split.prefix.find("mmproj") != std::string::npos || + split.prefix.find("mtp-") != std::string::npos) { continue; } if (seen.insert(f.repo_id + ":" + split.tag).second) { diff --git a/common/download.h b/common/download.h index edc3e9f1a71..4a169ef7796 100644 --- a/common/download.h +++ b/common/download.h @@ -59,6 +59,7 @@ struct common_download_opts { struct common_download_model_result { std::string model_path; std::string mmproj_path; + std::string mtp_path; }; // Download model from HuggingFace repo or URL @@ -83,12 +84,14 @@ struct common_download_model_result { // when opts.offline=true, no network requests are made // when download_mmproj=true, searches for mmproj in same directory as model or any parent directory // then with the closest quantization bits +// when download_mtp=true, applies the same sibling search for an MTP-head GGUF // -// returns result with model_path and mmproj_path (empty on failure) +// returns result with model_path, mmproj_path and mtp_path (empty when not found / on failure) common_download_model_result common_download_model( const common_params_model & model, const common_download_opts & opts = {}, - bool download_mmproj = false + bool download_mmproj = false, + bool download_mtp = false ); // returns list of cached models diff --git a/common/speculative.cpp b/common/speculative.cpp index 94088c65439..65d4b9ceaa4 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -1198,7 +1198,7 @@ common_speculative * common_speculative_init(common_params_speculative & params, LOG_WRN("%s: draft model is not specified - cannot use 'draft' type\n", __func__); has_draft_simple = false; } - } else if (has_draft_model_path) { + } else if (has_draft_model_path && !has_mtp && !has_draft_eagle3) { LOG_WRN("%s: draft model is specified but 'draft' speculative type is not explicitly enabled - enabling it\n", __func__); has_draft_simple = true; } diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 2cdbd9c1dba..b68e1e2a17e 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -95,6 +95,7 @@ class ModelBase: gguf_writer: gguf.GGUFWriter model_name: str | None metadata_override: Path | None + metadata: gguf.Metadata dir_model_card: Path remote_hf_model_id: str | None @@ -5564,26 +5565,59 @@ class _Qwen35MtpMixin: block_count: int tensor_map: gguf.TensorNameMap + mtp_only: bool = False + no_mtp: bool = False + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.block_count = self.hparams["num_hidden_layers"] + self.hparams.get("mtp_num_hidden_layers", 0) + self.block_count = self.hparams["num_hidden_layers"] + if not self.no_mtp: + self.block_count += self.hparams.get("mtp_num_hidden_layers", 0) self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) + @classmethod + def filter_tensors(cls, item): + name, _ = item + if name.startswith("mtp."): + if cls.no_mtp: + return None + return item + if cls.mtp_only: + # In 
--mtp mode, drop trunk weights and keep only the shared embeddings/output + # tensors that the standalone MTP graph references at inference time. + canonical = name.replace("language_model.", "") + keep = canonical in ( + "model.embed_tokens.weight", "model.norm.weight", "lm_head.weight", + "embed_tokens.weight", "norm.weight", + ) + if not keep: + return None + return super().filter_tensors(item) # ty: ignore[unresolved-attribute] + def set_gguf_parameters(self): super().set_gguf_parameters() # ty: ignore[unresolved-attribute] + if self.no_mtp: + return if (n := self.hparams.get("mtp_num_hidden_layers", 0)) > 0: self.gguf_writer.add_nextn_predict_layers(n) - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - # Multimodal Qwen3.5/3.6 wrap the text model under `model.language_model.*`. - if name.startswith("model.language_model."): - name = "model." + name[len("model.language_model."):] - elif name.startswith("language_model."): - name = name[len("language_model."):] + def prepare_metadata(self, vocab_only: bool): + # TextModel.prepare_metadata resolves a directory fname_out into a concrete + # file path, so snapshot is_dir() first to decide whether to apply the mtp- prefix. + from_dir = self.fname_out.is_dir() + super().prepare_metadata(vocab_only=vocab_only) # ty: ignore[unresolved-attribute] + + if not self.mtp_only or not from_dir: + return + output_type: str = self.ftype.name.partition("_")[2] # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute] + fname_default: str = gguf.naming_convention( + self.metadata.name, self.metadata.basename, self.metadata.finetune, # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute] + self.metadata.version, size_label=None, output_type=output_type, model_type=None) # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute] + self.fname_out = self.fname_out.parent / f"mtp-{fname_default}.gguf" + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: # Remap MTP block tensors to llama.cpp's layer-indexed nextn naming. - # HF: mtp.layers.0.* (transformer block at MTP slot 0) - # mtp.fc / mtp.pre_fc_norm_embedding / mtp.pre_fc_norm_hidden / mtp.norm if name.startswith("mtp."): n_layer = self.hparams["num_hidden_layers"] if name.find("layers.") != -1: @@ -14109,6 +14143,14 @@ def parse_args() -> argparse.Namespace: "--mmproj", action="store_true", help="(Experimental) Export multimodal projector (mmproj) for vision models. This will only work on some vision models. A prefix 'mmproj-' will be added to the output file name.", ) + parser.add_argument( + "--mtp", action="store_true", + help="(Experimental) Export only the multi-token prediction (MTP) head as a separate GGUF, suitable for use as a speculative draft. A prefix 'mtp-' will be added to the output file name.", + ) + parser.add_argument( + "--no-mtp", action="store_true", + help="(Experimental) Exclude the multi-token prediction (MTP) head from the converted GGUF. Pair with --mtp on a second run to publish trunk and MTP as two files. 
Note: the split form duplicates embeddings, so the bundled default is more space-efficient overall.", + ) parser.add_argument( "--mistral-format", action="store_true", help="Whether the model is stored following the Mistral format.", @@ -14268,6 +14310,20 @@ def main() -> None: else: model_class = MistralModel + if args.mtp and args.no_mtp: + logger.error("--mtp and --no-mtp are mutually exclusive") + sys.exit(1) + + if (args.mtp or args.no_mtp) and not issubclass(model_class, _Qwen35MtpMixin): + logger.error("--mtp / --no-mtp are only supported for Qwen3.5/3.6 text variants today") + sys.exit(1) + + # set on the class so __init__ / filter_tensors see the correct mode + if args.no_mtp: + model_class.no_mtp = True # ty: ignore[unresolved-attribute] + if args.mtp: + model_class.mtp_only = True # ty: ignore[unresolved-attribute] + model_instance = model_class(dir_model, output_type, fname_out, is_big_endian=args.bigendian, use_temp_file=args.use_temp_file, eager=args.no_lazy, diff --git a/include/llama.h b/include/llama.h index 1b896944735..b814e2c58de 100644 --- a/include/llama.h +++ b/include/llama.h @@ -198,6 +198,11 @@ extern "C" { LLAMA_SPLIT_MODE_TENSOR = 3, }; + enum llama_context_type { + LLAMA_CONTEXT_TYPE_DEFAULT = 0, + LLAMA_CONTEXT_TYPE_MTP = 1, + }; + // TODO: simplify (https://github.com/ggml-org/llama.cpp/pull/9294#pullrequestreview-2286561979) typedef struct llama_token_data { llama_token id; // token id @@ -339,6 +344,7 @@ extern "C" { int32_t n_threads; // number of threads to use for generation int32_t n_threads_batch; // number of threads to use for batch processing + enum llama_context_type ctx_type; // set the context type (e.g. MTP) enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type` enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id enum llama_attention_type attention_type; // attention type to use for embeddings diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 794666d09a4..ab4334da79b 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -41,8 +41,6 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_QWEN3VLMOE, "qwen3vlmoe" }, { LLM_ARCH_QWEN35, "qwen35" }, { LLM_ARCH_QWEN35MOE, "qwen35moe" }, - { LLM_ARCH_QWEN35_MTP, "qwen35_mtp" }, - { LLM_ARCH_QWEN35MOE_MTP, "qwen35moe_mtp" }, { LLM_ARCH_PHI2, "phi2" }, { LLM_ARCH_PHI3, "phi3" }, { LLM_ARCH_PHIMOE, "phimoe" }, diff --git a/src/llama-arch.h b/src/llama-arch.h index 71c2ca6e6b3..e37d548c98e 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -45,8 +45,6 @@ enum llm_arch { LLM_ARCH_QWEN3VLMOE, LLM_ARCH_QWEN35, LLM_ARCH_QWEN35MOE, - LLM_ARCH_QWEN35_MTP, - LLM_ARCH_QWEN35MOE_MTP, LLM_ARCH_PHI2, LLM_ARCH_PHI3, LLM_ARCH_PHIMOE, diff --git a/src/llama-context.cpp b/src/llama-context.cpp index aea8a0a4e81..6ecbe1b6083 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -2,6 +2,7 @@ #include "ggml.h" #include "llama-arch.h" +#include "llama-graph.h" #include "llama-impl.h" #include "llama-batch.h" #include "llama-io.h" @@ -21,6 +22,14 @@ // llama_context // +static llm_graph_type ctx_type_to_graph_type(llama_context_type ctx_type) { + switch (ctx_type) { + case LLAMA_CONTEXT_TYPE_DEFAULT: return LLM_GRAPH_TYPE_DEFAULT; + case LLAMA_CONTEXT_TYPE_MTP : return LLM_GRAPH_TYPE_DECODER_MTP; + } + throw std::runtime_error("Unsupported ctx type"); +} + llama_context::llama_context( const llama_model & model, llama_context_params params) : @@ -66,6 +75,8 @@ llama_context::llama_context( 
cparams.cb_eval = params.cb_eval; cparams.cb_eval_user_data = params.cb_eval_user_data; + cparams.ctx_type = params.ctx_type; + // Initialize backend samplers here so they are part of the sampling graph // before the reserve passes run later in this function. This avoids a later // re-reserve when graph nodes change. @@ -279,6 +290,7 @@ llama_context::llama_context( /*.type_k =*/ params.type_k, /*.type_v =*/ params.type_v, /*.swa_full =*/ params.swa_full, + /*.ctx_type= */ cparams.ctx_type, }; memory.reset(model.create_memory(params_mem, cparams)); @@ -1738,7 +1750,8 @@ int llama_context::decode(const llama_batch & batch_inp) { } ggml_status status; - const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status); + + const auto * res = process_ubatch(ubatch, ctx_type_to_graph_type(cparams.ctx_type), mctx.get(), status); if (!res) { // the last ubatch failed or was aborted -> remove all positions of that ubatch from the memory module @@ -2198,7 +2211,7 @@ ggml_cgraph * llama_context::graph_reserve( auto * res = gf_res_reserve.get(); - const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT); + const auto gparams = graph_params(res, ubatch, mctx, ctx_type_to_graph_type(cparams.ctx_type)); res->reset(); @@ -3177,7 +3190,7 @@ void llama_context::opt_epoch_iter( auto * res = gf_res_prev.get(); - const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT); + const auto gparams = graph_params(res, ubatch, mctx.get(), ctx_type_to_graph_type(cparams.ctx_type)); res->reset(); @@ -3280,6 +3293,7 @@ llama_context_params llama_context_default_params() { /*.n_seq_max =*/ 1, /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS, + /*.ctx_type =*/ LLAMA_CONTEXT_TYPE_DEFAULT, /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED, /*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED, /*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED, @@ -3383,6 +3397,13 @@ llama_context * llama_init_from_model( model->hparams.pooling_type, params.pooling_type); } + if (params.ctx_type == LLAMA_CONTEXT_TYPE_MTP && + model->hparams.nextn_predict_layers == 0) { + LLAMA_LOG_WARN("%s: context type MTP requested but model doesn't contain MTP layers\n", __func__); + return nullptr; + } + + try { auto * ctx = new llama_context(*model, params); return ctx; diff --git a/src/llama-cparams.h b/src/llama-cparams.h index 1e4e9e29ed8..cbf74eba63e 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -41,6 +41,7 @@ struct llama_cparams { bool kv_unified; bool pipeline_parallel; + enum llama_context_type ctx_type; enum llama_pooling_type pooling_type; ggml_backend_sched_eval_callback cb_eval; diff --git a/src/llama-graph.h b/src/llama-graph.h index d3cd69a674c..9e55d0a675e 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -32,6 +32,7 @@ enum llm_graph_type { LLM_GRAPH_TYPE_DEFAULT, LLM_GRAPH_TYPE_ENCODER, LLM_GRAPH_TYPE_DECODER, + LLM_GRAPH_TYPE_DECODER_MTP, }; enum llm_ffn_op_type { diff --git a/src/llama-memory.h b/src/llama-memory.h index 4a157b91fdb..4ad1612e45b 100644 --- a/src/llama-memory.h +++ b/src/llama-memory.h @@ -1,6 +1,7 @@ #pragma once #include "llama.h" +#include "llama-graph.h" #include #include @@ -20,6 +21,8 @@ struct llama_memory_params { // use full-size SWA cache bool swa_full; + + llama_context_type ctx_type; }; enum llama_memory_status { diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 6abfbfb3e3b..2a157eb1299 100644 --- a/src/llama-model.cpp +++ 
b/src/llama-model.cpp @@ -276,10 +276,6 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params return new llama_model_qwen35(params); case LLM_ARCH_QWEN35MOE: return new llama_model_qwen35moe(params); - case LLM_ARCH_QWEN35_MTP: - return new llama_model_qwen35_mtp(params); - case LLM_ARCH_QWEN35MOE_MTP: - return new llama_model_qwen35moe_mtp(params); case LLM_ARCH_MISTRAL3: return new llama_model_mistral3(params); case LLM_ARCH_MIMO2: @@ -1409,8 +1405,7 @@ bool llama_model_base::load_tensors(llama_model_loader & ml) { } } - const bool partial_load = (arch == LLM_ARCH_QWEN35_MTP || arch == LLM_ARCH_QWEN35MOE_MTP); - ml.done_getting_tensors(partial_load); + ml.done_getting_tensors(); // populate tensors_by_name for (auto & [_, ctx_ptr] : ml.ctx_map) { @@ -1948,6 +1943,12 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, // checks default: { + // The MTP head is dense-attention only on hybrid Qwen3.5/3.6, so use a plain + // attention KV cache for the MTP context instead of the hybrid wrapper. + const bool mtp_on_hybrid_qwen35 = + params.ctx_type == LLAMA_CONTEXT_TYPE_MTP && + (arch == LLM_ARCH_QWEN35 || arch == LLM_ARCH_QWEN35MOE); + if (llm_arch_is_recurrent(arch)) { res = new llama_memory_recurrent( *this, @@ -1957,7 +1958,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, std::max((uint32_t) 1, cparams.n_seq_max), cparams.n_seq_max, nullptr); - } else if (llm_arch_is_hybrid(arch)) { + } else if (llm_arch_is_hybrid(arch) && !mtp_on_hybrid_qwen35) { // The main difference between hybrid architectures is the // layer filters, so pick the right one here llama_memory_hybrid::layer_filter_cb filter_attn = nullptr; @@ -1972,6 +1973,14 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, filter_recr = [&](int32_t il) { return hparams.is_recurrent(il) && hparams.n_ff(il) == 0; }; + } else if (arch == LLM_ARCH_QWEN35 || arch == LLM_ARCH_QWEN35MOE) { + const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; + filter_attn = [&, n_main](int32_t il) { + return (uint32_t)il < n_main && !hparams.is_recurrent(il); + }; + filter_recr = [&, n_main](int32_t il) { + return (uint32_t)il < n_main && hparams.is_recurrent(il); + }; } if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { @@ -2014,6 +2023,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, } } else { llama_memory_i::layer_reuse_cb reuse = nullptr; + llama_kv_cache::layer_filter_cb filter = nullptr; if (arch == LLM_ARCH_GEMMA3N || arch == LLM_ARCH_GEMMA4) { reuse = [&](int32_t il) { @@ -2025,6 +2035,11 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, }; } + if (mtp_on_hybrid_qwen35) { + const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; + filter = [n_main](int32_t il) { return (uint32_t)il >= n_main; }; + } + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { GGML_ASSERT(hparams.is_swa_any()); @@ -2040,7 +2055,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, cparams.n_seq_max, cparams.n_ubatch, 1, - nullptr, + filter, reuse); } else { GGML_ASSERT(!hparams.is_swa_any()); @@ -2057,7 +2072,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, 1, hparams.n_swa, hparams.swa_type, - nullptr, + filter, nullptr); } } @@ -2161,6 +2176,7 @@ int32_t llama_model_n_swa(const llama_model * model) { return model->hparams.n_swa; } + uint32_t llama_model_n_cls_out(const struct 
llama_model * model) { return model->hparams.n_cls_out; } @@ -2328,8 +2344,6 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_QWEN3VLMOE: case LLM_ARCH_QWEN35: case LLM_ARCH_QWEN35MOE: - case LLM_ARCH_QWEN35_MTP: - case LLM_ARCH_QWEN35MOE_MTP: return LLAMA_ROPE_TYPE_IMROPE; case LLM_ARCH_GLM4: diff --git a/src/models/models.h b/src/models/models.h index 1f04d313d13..fe95b9b89ad 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -1739,6 +1739,10 @@ struct llama_model_qwen35 : public llama_model_base { const llama_model & model; }; + struct graph_mtp : public llm_graph_context { + graph_mtp(const llama_model & model, const llm_graph_params & params); + }; + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; @@ -1781,30 +1785,8 @@ struct llama_model_qwen35moe : public llama_model_base { const llama_model & model; }; - std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; -}; - - -struct llama_model_qwen35_mtp : public llama_model_base { - llama_model_qwen35_mtp(const struct llama_model_params & params) : llama_model_base(params) {} - void load_arch_hparams(llama_model_loader & ml) override; - void load_arch_tensors(llama_model_loader & ml) override; - - struct graph : public llm_graph_context { - graph(const llama_model & model, const llm_graph_params & params); - }; - - std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; -}; - - -struct llama_model_qwen35moe_mtp : public llama_model_base { - llama_model_qwen35moe_mtp(const struct llama_model_params & params) : llama_model_base(params) {} - void load_arch_hparams(llama_model_loader & ml) override; - void load_arch_tensors(llama_model_loader & ml) override; - - struct graph : public llm_graph_context { - graph(const llama_model & model, const llm_graph_params & params); + struct graph_mtp : public llm_graph_context { + graph_mtp(const llama_model & model, const llm_graph_params & params); }; std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; diff --git a/src/models/qwen35-mtp.cpp b/src/models/qwen35-mtp.cpp deleted file mode 100644 index 83039e98db5..00000000000 --- a/src/models/qwen35-mtp.cpp +++ /dev/null @@ -1,207 +0,0 @@ -#include "models.h" - -void llama_model_qwen35_mtp::load_arch_hparams(llama_model_loader & ml) { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); - - ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); - GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35_MTP requires nextn_predict_layers > 0"); - GGML_ASSERT(hparams.nextn_predict_layers <= hparams.n_layer); - - // only the MTP layers get a KV cache, trunk layers are skipped. 
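The deleted flag assignment that follows implemented this policy globally via hparams.kv_only_nextn; in the replacement scheme the same policy is expressed as a per-layer filter callback on the unified KV cache (see the create_memory() changes earlier in this file's diff). A minimal, self-contained C++ sketch of that filter; the 49-layer depth and single NextN block are assumed toy values, not taken from a real checkpoint:

    #include <cassert>
    #include <cstdint>
    #include <functional>

    int main() {
        const uint32_t n_layer              = 49; // assumed: 48 trunk layers + 1 MTP block
        const uint32_t nextn_predict_layers = 1;
        const uint32_t n_main = n_layer - nextn_predict_layers;

        // same shape as the filter installed for LLAMA_CONTEXT_TYPE_MTP above
        std::function<bool(int32_t)> filter = [n_main](int32_t il) {
            return (uint32_t) il >= n_main;
        };

        assert(!filter(0));  // trunk layers get no slot in the MTP KV cache
        assert(!filter(47)); // last trunk layer
        assert( filter(48)); // only the trailing NextN block is cached
        return 0;
    }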
- hparams.kv_only_nextn = true; - hparams.n_layer_kv_from_start = -1; - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = false; - } - - type = LLM_TYPE_UNKNOWN; -} - -void llama_model_qwen35_mtp::load_arch_tensors(llama_model_loader &) { - LLAMA_LOAD_LOCALS; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, TENSOR_NOT_REQUIRED); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); - if (output == nullptr) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); - } - - const uint32_t n_main = n_layer - hparams.nextn_predict_layers; - for (int i = 0; i < n_layer; ++i) { - if (static_cast(i) < n_main) { - continue; // trunk layer — owned by the sibling QWEN35 model - } - - auto & layer = layers[i]; - - // MTP block looks like a full-attention Qwen3.5 decoder block. - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - - // NextN-specific tensors that define the MTP block. - layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, 0); - layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, 0); - layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, 0); - layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); - layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); - layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); - } -} - -std::unique_ptr llama_model_qwen35_mtp::build_arch_graph(const llm_graph_params & params) const { - return std::make_unique(*this, params); -} - -// LLM_ARCH_QWEN35_MTP draft head for Qwen3.5/3.6 dense series -llama_model_qwen35_mtp::graph::graph(const llama_model & model, const llm_graph_params & params) - : llm_graph_context(params) { - GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35_MTP requires nextn_predict_layers > 0"); - GGML_ASSERT(hparams.nextn_predict_layers == 1 && "QWEN35_MTP currently only supports a single MTP block"); - - const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); - - // The MTP block lives at the source file's original layer index. 
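The graph constructor deleted below (and its re-homed copy in qwen35.cpp later in this patch) implements the NextN combine: the trunk's pre-norm hidden state at position t is fused with the embedding of token t+1 through two RMS norms, a concat along the feature dimension, and the eh_proj matmul. A plain C++ sketch of that recurrence, with assumed dense shapes and the learned enorm/hnorm gains omitted for brevity; this is a stand-in for the ggml ops, not the actual implementation:

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Stand-in for build_norm(..., LLM_NORM_RMS, ...) without the learned gain.
    static std::vector<float> rms_norm(const std::vector<float> & x, float eps = 1e-6f) {
        double ss = 0.0;
        for (float v : x) ss += (double) v * v;
        const float scale = 1.0f / std::sqrt((float) (ss / x.size()) + eps);
        std::vector<float> y(x.size());
        for (size_t i = 0; i < x.size(); ++i) y[i] = x[i] * scale;
        return y;
    }

    // x = eh_proj * concat(RMSNorm(e_next), RMSNorm(h_t)); eh_proj is [n_embd][2*n_embd].
    static std::vector<float> nextn_combine(const std::vector<float> & e_next,   // emb of token t+1
                                            const std::vector<float> & h_t,     // pre-norm state at t
                                            const std::vector<std::vector<float>> & eh_proj) {
        std::vector<float> cat = rms_norm(e_next);
        const std::vector<float> hn = rms_norm(h_t);
        cat.insert(cat.end(), hn.begin(), hn.end());

        std::vector<float> out(eh_proj.size(), 0.0f);
        for (size_t r = 0; r < eh_proj.size(); ++r)
            for (size_t c = 0; c < cat.size(); ++c)
                out[r] += eh_proj[r][c] * cat[c];
        return out;
    }

    int main() {
        const std::vector<std::vector<float>> eh_proj(2, std::vector<float>(4, 0.5f)); // toy 2x4
        const auto out = nextn_combine({1.0f, 2.0f}, {3.0f, 4.0f}, eh_proj);
        return out.size() == 2 ? 0 : 1;
    }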
- const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers; - const auto & layer = model.layers[il]; - - GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj"); - GGML_ASSERT(layer.nextn.enorm && "MTP block missing nextn.enorm"); - GGML_ASSERT(layer.nextn.hnorm && "MTP block missing nextn.hnorm"); - - int sections[4]; - std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); - - auto inp = std::make_unique(hparams.n_embd); - - inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); - ggml_set_input(inp->tokens); - - inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); - ggml_set_input(inp->embd); - ggml_set_name(inp->embd, "mtp_h_input"); - - ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd; - - ggml_tensor * h_input = inp->embd; - ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens); - cb(tok_embd, "mtp_tok_embd", il); - - res->add_input(std::move(inp)); - - ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv(); - - ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); - cb(h_norm, "mtp_hnorm", il); - - ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il); - cb(e_norm, "mtp_enorm", il); - - ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, /*dim=*/ 0); - cb(concat, "mtp_concat", il); - - ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat); - cb(cur, "mtp_eh_proj", il); - - ggml_tensor * inpSA = cur; - - cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il); - cb(cur, "mtp_attn_norm", il); - - ggml_tensor * Qcur_full = build_lora_mm(layer.wq, cur, layer.wq_s); - cb(Qcur_full, "mtp_Qcur_full", il); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, - n_embd_head, n_head, n_tokens, - ggml_element_size(Qcur_full) * n_embd_head * 2, - ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, - 0); - Qcur = build_norm(Qcur, layer.attn_q_norm, nullptr, LLM_NORM_RMS, il); - cb(Qcur, "mtp_Qcur_normed", il); - - ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full, - n_embd_head, n_head, n_tokens, - ggml_element_size(Qcur_full) * n_embd_head * 2, - ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, - ggml_element_size(Qcur_full) * n_embd_head); - gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); - cb(gate, "mtp_gate", il); - - ggml_tensor * Kcur = build_lora_mm(layer.wk, cur, layer.wk_s); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Kcur = build_norm(Kcur, layer.attn_k_norm, nullptr, LLM_NORM_RMS, il); - cb(Kcur, "mtp_Kcur_normed", il); - - ggml_tensor * Vcur = build_lora_mm(layer.wv, cur, layer.wv_s); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - cb(Vcur, "mtp_Vcur", il); - - Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr, - n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow); - Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr, - n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow); - - const float kq_scale = hparams.f_attention_scale == 0.0f - ? 
1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; - - cur = build_attn(inp_attn, - nullptr, nullptr, nullptr, - Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); - cb(cur, "mtp_attn_pregate", il); - - cur = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate)); - cur = build_lora_mm(layer.wo, cur, layer.wo_s); - cb(cur, "mtp_attn_out", il); - - cur = ggml_add(ctx0, cur, inpSA); - cb(cur, "mtp_attn_residual", il); - - ggml_tensor * ffn_residual = cur; - cur = build_norm(cur, layer.attn_post_norm, nullptr, LLM_NORM_RMS, il); - cb(cur, "mtp_attn_post_norm", il); - - cur = build_ffn(cur, - layer.ffn_up, nullptr, layer.ffn_up_s, - layer.ffn_gate, nullptr, layer.ffn_gate_s, - layer.ffn_down, nullptr, layer.ffn_down_s, - nullptr, - LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(cur, "mtp_ffn_out", il); - - cur = ggml_add(ctx0, cur, ffn_residual); - cb(cur, "mtp_post_ffn", il); - - // Pre-norm hidden state: used by the AR draft loop to seed the next MTP step. - // (In the trunk graph this is `t_h_pre_norm`; the MTP head reuses the same slot.) - cb(cur, "h_pre_norm", -1); - res->t_h_pre_norm = cur; - - ggml_tensor * head_norm_w = layer.nextn.shared_head_norm - ? layer.nextn.shared_head_norm - : model.output_norm; - GGML_ASSERT(head_norm_w && "QWEN35_MTP: missing both nextn.shared_head_norm and output_norm"); - cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1); - cb(cur, "mtp_shared_head_norm", -1); - - ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output; - GGML_ASSERT(head_w && "QWEN35_MTP: missing LM head (nextn.shared_head_head or model.output)"); - cur = build_lora_mm(head_w, cur); - cb(cur, "result_output", -1); - - res->t_logits = cur; - ggml_build_forward_expand(gf, cur); -} diff --git a/src/models/qwen35.cpp b/src/models/qwen35.cpp index 79fdd8f679b..ca4297e94f3 100644 --- a/src/models/qwen35.cpp +++ b/src/models/qwen35.cpp @@ -15,7 +15,6 @@ void llama_model_qwen35::load_arch_hparams(llama_model_loader & ml) { // NextN/MTP (Qwen3.5/3.6): extra decoder block appended beyond the main stack ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); - hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; // Mark recurrent layers (linear attention layers). MTP layers are dense // attention-only and must be flagged non-recurrent. @@ -36,9 +35,14 @@ void llama_model_qwen35::load_arch_hparams(llama_model_loader & ml) { } } -void llama_model_qwen35::load_arch_tensors(llama_model_loader &) { +void llama_model_qwen35::load_arch_tensors(llama_model_loader & ml) { LLAMA_LOAD_LOCALS; + const uint32_t n_main = n_layer - hparams.nextn_predict_layers; + const bool mtp_only = (hparams.nextn_predict_layers > 0) && + (ml.get_weight("blk.0.attn_norm.weight") == nullptr); + const int trunk_flags = mtp_only ? 
TENSOR_NOT_REQUIRED : 0; + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); // output @@ -50,60 +54,85 @@ void llama_model_qwen35::load_arch_tensors(llama_model_loader &) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); } - // Calculate dimensions from hyperparameters - const int64_t head_k_dim = hparams.ssm_d_state; - const int64_t head_v_dim = hparams.ssm_d_state; - const int64_t n_k_heads = hparams.ssm_n_group; - const int64_t n_v_heads = hparams.ssm_dt_rank; - const int64_t key_dim = head_k_dim * n_k_heads; - const int64_t value_dim = head_v_dim * n_v_heads; - const int64_t conv_dim = key_dim * 2 + value_dim; + auto load_block_trunk = [&](int il, int flags) { + auto & layer = layers[il]; - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; + // Calculate dimensions from hyperparameters + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + const int64_t conv_dim = key_dim * 2 + value_dim; - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, flags); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, flags); - if (!hparams.is_recurrent(i)) { + if (!hparams.is_recurrent(il)) { // Attention layers - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, flags); // Q/K normalization for attention layers - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, flags); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, flags); } else { // Linear attention (gated delta net) specific tensors // Create tensors with calculated dimensions - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); - layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); - layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); - layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0); - layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0); - layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); - 
layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", il), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", il), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", il), { hparams.ssm_d_conv, conv_dim }, flags); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", il), { hparams.ssm_dt_rank }, flags); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, il), { hparams.ssm_dt_rank }, flags); + layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", il), { n_embd, n_v_heads }, flags); + layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", il), { n_embd, n_v_heads }, flags); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", il), { head_v_dim }, flags); + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", il), { value_dim, n_embd }, flags); } - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - - // NextN/MTP tensors (preserved but unused) - only bound on MTP layers - if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { - layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, TENSOR_NOT_REQUIRED); - layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); - layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); - layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); - layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); - layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); - } + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", il), {n_embd, n_ff}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", il), { n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", il), {n_embd, n_ff}, flags); + }; + + auto load_block_mtp = [&](int il) { + auto & layer = layers[il]; + + // MTP block looks like a full-attention Qwen3.5 decoder block. 
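Before the MTP block's tensors are created below, it is worth pinning down how the mtp_only probe a few dozen lines up degrades gracefully: a hypothetical MTP-only split GGUF carries no blk.0 trunk tensors, so a single lookup flips every trunk tensor to optional while the NextN block stays mandatory. A self-contained toy model of that decision; toy_loader and the flag value are stand-ins, only the probe key is the one actually used above:

    #include <cassert>
    #include <string>
    #include <unordered_set>

    // Toy stand-in for llama_model_loader: tracks which tensor names exist in a GGUF.
    struct toy_loader {
        std::unordered_set<std::string> names;
        bool has(const std::string & n) const { return names.count(n) > 0; } // ~ get_weight() != nullptr
    };

    int main() {
        // Hypothetical MTP-only split: only the NextN block's tensors are present.
        toy_loader ml;
        ml.names.insert("blk.48.nextn.eh_proj.weight");

        const bool mtp_only = !ml.has("blk.0.attn_norm.weight"); // the probe used above
        const int  TENSOR_NOT_REQUIRED = 1;                      // stand-in flag value
        const int  trunk_flags = mtp_only ? TENSOR_NOT_REQUIRED : 0;

        assert(mtp_only && trunk_flags == TENSOR_NOT_REQUIRED);
        return 0;
    }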
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, 0); + + create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", il), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", il), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", il), {n_embd, n_ff}, 0); + + // NextN-specific tensors that define the MTP block. + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", il), { 2 * n_embd, n_embd }, 0); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", il), { n_embd }, 0); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", il), { n_embd }, 0); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", il), { n_embd }, TENSOR_NOT_REQUIRED); + }; + + for (int i = 0; i < (int) n_main; ++i) { + load_block_trunk(i, trunk_flags); + } + for (int i = (int) n_main; i < n_layer; ++i) { + load_block_mtp(i); } } std::unique_ptr llama_model_qwen35::build_arch_graph(const llm_graph_params & params) const { + if (params.gtype == LLM_GRAPH_TYPE_DECODER_MTP) { + return std::make_unique(*this, params); + } return std::make_unique(*this, params); } @@ -493,3 +522,146 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_ffn(ggml_tensor * cur, cons return cur; } + +// LLM_GRAPH_TYPE_DECODER_MTP draft head for Qwen3.5/3.6 dense series +llama_model_qwen35::graph_mtp::graph_mtp(const llama_model & model, const llm_graph_params & params) + : llm_graph_context(params) { + GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35 MTP requires nextn_predict_layers > 0"); + GGML_ASSERT(hparams.nextn_predict_layers == 1 && "QWEN35 MTP currently only supports a single MTP block"); + + const int64_t n_embd_head = hparams.n_embd_head_v(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + + // The MTP block lives at the source file's original layer index. 
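The graph built further down taps the fused Q projection twice: wq emits 2*n_embd_head floats per head (query half followed by gate half), and the two ggml_view_3d calls recover each half via a byte offset and a doubled row stride. A self-contained index-math sketch of that layout; the sizes are toy values, not real hparams:

    #include <cassert>
    #include <cstddef>

    int main() {
        const size_t n_embd_head = 4, n_head = 3; // toy sizes, not real hparams

        // per-token layout: [q_h0 | gate_h0 | q_h1 | gate_h1 | q_h2 | gate_h2]
        auto q_off    = [&](size_t h) { return h * 2 * n_embd_head; };
        auto gate_off = [&](size_t h) { return h * 2 * n_embd_head + n_embd_head; };

        assert(q_off(0)    == 0);               // matches the view's offset 0 in the diff
        assert(gate_off(0) == n_embd_head);     // matches the view's offset of one head width
        assert(q_off(1)    == 2 * n_embd_head); // the row stride skips the gate block
        assert(gate_off(n_head - 1) + n_embd_head == 2 * n_embd_head * n_head);
        return 0;
    }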
+ const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers; + const auto & layer = model.layers[il]; + + GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj"); + GGML_ASSERT(layer.nextn.enorm && "MTP block missing nextn.enorm"); + GGML_ASSERT(layer.nextn.hnorm && "MTP block missing nextn.hnorm"); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + auto inp = std::make_unique(hparams.n_embd); + + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_input(inp->tokens); + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + ggml_set_input(inp->embd); + ggml_set_name(inp->embd, "mtp_h_input"); + + ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd; + + ggml_tensor * h_input = inp->embd; + ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens); + cb(tok_embd, "mtp_tok_embd", il); + + res->add_input(std::move(inp)); + + ggml_tensor * inp_pos = build_inp_pos(); + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); + cb(h_norm, "mtp_hnorm", il); + + ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il); + cb(e_norm, "mtp_enorm", il); + + ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, /*dim=*/ 0); + cb(concat, "mtp_concat", il); + + ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat); + cb(cur, "mtp_eh_proj", il); + + ggml_tensor * inpSA = cur; + + cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_norm", il); + + ggml_tensor * Qcur_full = build_lora_mm(layer.wq, cur, layer.wq_s); + cb(Qcur_full, "mtp_Qcur_full", il); + + ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, + n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + 0); + Qcur = build_norm(Qcur, layer.attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "mtp_Qcur_normed", il); + + ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full, + n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + ggml_element_size(Qcur_full) * n_embd_head); + gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); + cb(gate, "mtp_gate", il); + + ggml_tensor * Kcur = build_lora_mm(layer.wk, cur, layer.wk_s); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = build_norm(Kcur, layer.attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "mtp_Kcur_normed", il); + + ggml_tensor * Vcur = build_lora_mm(layer.wv, cur, layer.wv_s); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + cb(Vcur, "mtp_Vcur", il); + + Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + const float kq_scale = hparams.f_attention_scale == 0.0f + ? 
1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + cur = build_attn(inp_attn, + nullptr, nullptr, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "mtp_attn_pregate", il); + + cur = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate)); + cur = build_lora_mm(layer.wo, cur, layer.wo_s); + cb(cur, "mtp_attn_out", il); + + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "mtp_attn_residual", il); + + ggml_tensor * ffn_residual = cur; + cur = build_norm(cur, layer.attn_post_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_post_norm", il); + + cur = build_ffn(cur, + layer.ffn_up, nullptr, layer.ffn_up_s, + layer.ffn_gate, nullptr, layer.ffn_gate_s, + layer.ffn_down, nullptr, layer.ffn_down_s, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "mtp_ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_residual); + cb(cur, "mtp_post_ffn", il); + + // Pre-norm hidden state: used by the AR draft loop to seed the next MTP step. + // (In the trunk graph this is `t_h_pre_norm`; the MTP head reuses the same slot.) + cb(cur, "h_pre_norm", -1); + res->t_h_pre_norm = cur; + + ggml_tensor * head_norm_w = layer.nextn.shared_head_norm + ? layer.nextn.shared_head_norm + : model.output_norm; + GGML_ASSERT(head_norm_w && "QWEN35 MTP: missing both nextn.shared_head_norm and output_norm"); + cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1); + cb(cur, "mtp_shared_head_norm", -1); + + ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output; + GGML_ASSERT(head_w && "QWEN35 MTP: missing LM head (nextn.shared_head_head or model.output)"); + cur = build_lora_mm(head_w, cur); + cb(cur, "result_output", -1); + + res->t_logits = cur; + ggml_build_forward_expand(gf, cur); +} diff --git a/src/models/qwen35moe-mtp.cpp b/src/models/qwen35moe-mtp.cpp deleted file mode 100644 index 9f662213bee..00000000000 --- a/src/models/qwen35moe-mtp.cpp +++ /dev/null @@ -1,252 +0,0 @@ -#include "models.h" - -void llama_model_qwen35moe_mtp::load_arch_hparams(llama_model_loader & ml) { - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); - ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); - - ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); - GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35MOE_MTP requires nextn_predict_layers > 0"); - GGML_ASSERT(hparams.nextn_predict_layers <= hparams.n_layer); - GGML_ASSERT(hparams.n_expert > 0 && "QWEN35MOE_MTP requires n_expert > 0"); - - // only the MTP layers get a KV cache, trunk layers are skipped. - hparams.kv_only_nextn = true; - hparams.n_layer_kv_from_start = -1; - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = false; - } - - type = LLM_TYPE_UNKNOWN; -} - -void llama_model_qwen35moe_mtp::load_arch_tensors(llama_model_loader &) { - LLAMA_LOAD_LOCALS; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, TENSOR_NOT_REQUIRED); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); - if (output == nullptr) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); - } - - const int64_t n_ff_exp = hparams.n_ff_exp ? 
hparams.n_ff_exp : n_ff / n_expert_used; - const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff; - - const uint32_t n_main = n_layer - hparams.nextn_predict_layers; - for (int i = 0; i < n_layer; ++i) { - if (static_cast(i) < n_main) { - continue; // trunk layer — owned by the sibling QWEN35MOE model - } - - auto & layer = layers[i]; - - // MTP block looks like a full-attention Qwen3.5 decoder block with MoE FFN. - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); - - // Routed experts - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); - create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); - - // Shared experts - layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0); - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0); - - // NextN-specific tensors that define the MTP block. 
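A concrete number for the n_ff_exp fallback computed above, before the deleted NextN tensor creations continue: when the GGUF omits expert_feed_forward_length, the per-expert FFN width defaults to the dense width divided by the number of active experts. A tiny worked example with assumed sizes:

    #include <cassert>
    #include <cstdint>

    int main() {
        const int64_t n_ff          = 18432; // hypothetical dense FFN width
        const int64_t n_expert_used = 8;     // hypothetical active experts per token
        const int64_t n_ff_exp_meta = 0;     // 0 = key absent from the GGUF

        const int64_t n_ff_exp = n_ff_exp_meta ? n_ff_exp_meta : n_ff / n_expert_used;
        assert(n_ff_exp == 2304); // 18432 / 8
        return 0;
    }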
- layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, 0); - layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, 0); - layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, 0); - layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); - layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); - layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); - } -} - -std::unique_ptr llama_model_qwen35moe_mtp::build_arch_graph(const llm_graph_params & params) const { - return std::make_unique(*this, params); -} - -// LLM_ARCH_QWEN35MOE_MTP draft head for Qwen3.5/3.6 MoE -llama_model_qwen35moe_mtp::graph::graph(const llama_model & model, const llm_graph_params & params) - : llm_graph_context(params) { - GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35MOE_MTP requires nextn_predict_layers > 0"); - GGML_ASSERT(hparams.nextn_predict_layers == 1 && "QWEN35MOE_MTP currently only supports a single MTP block"); - - const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); - - const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers; - const auto & layer = model.layers[il]; - - GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj"); - GGML_ASSERT(layer.nextn.enorm && "MTP block missing nextn.enorm"); - GGML_ASSERT(layer.nextn.hnorm && "MTP block missing nextn.hnorm"); - GGML_ASSERT(layer.ffn_gate_inp && "MTP block missing ffn_gate_inp"); - - int sections[4]; - std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); - - auto inp = std::make_unique(hparams.n_embd); - - inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); - ggml_set_input(inp->tokens); - - inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); - ggml_set_input(inp->embd); - ggml_set_name(inp->embd, "mtp_h_input"); - - ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? 
layer.nextn.embed_tokens : model.tok_embd; - - ggml_tensor * h_input = inp->embd; - ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens); - cb(tok_embd, "mtp_tok_embd", il); - - res->add_input(std::move(inp)); - - ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv(); - - ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); - cb(h_norm, "mtp_hnorm", il); - - ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il); - cb(e_norm, "mtp_enorm", il); - - ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, /*dim=*/ 0); - cb(concat, "mtp_concat", il); - - ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat); - cb(cur, "mtp_eh_proj", il); - - ggml_tensor * inpSA = cur; - - cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il); - cb(cur, "mtp_attn_norm", il); - - ggml_tensor * Qcur_full = build_lora_mm(layer.wq, cur, layer.wq_s); - cb(Qcur_full, "mtp_Qcur_full", il); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, - n_embd_head, n_head, n_tokens, - ggml_element_size(Qcur_full) * n_embd_head * 2, - ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, - 0); - Qcur = build_norm(Qcur, layer.attn_q_norm, nullptr, LLM_NORM_RMS, il); - cb(Qcur, "mtp_Qcur_normed", il); - - ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full, - n_embd_head, n_head, n_tokens, - ggml_element_size(Qcur_full) * n_embd_head * 2, - ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, - ggml_element_size(Qcur_full) * n_embd_head); - gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); - cb(gate, "mtp_gate", il); - - ggml_tensor * Kcur = build_lora_mm(layer.wk, cur, layer.wk_s); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Kcur = build_norm(Kcur, layer.attn_k_norm, nullptr, LLM_NORM_RMS, il); - cb(Kcur, "mtp_Kcur_normed", il); - - ggml_tensor * Vcur = build_lora_mm(layer.wv, cur, layer.wv_s); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - cb(Vcur, "mtp_Vcur", il); - - Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr, - n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow); - Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr, - n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow); - - const float kq_scale = hparams.f_attention_scale == 0.0f - ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; - - cur = build_attn(inp_attn, - nullptr, nullptr, nullptr, - Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); - cb(cur, "mtp_attn_pregate", il); - - cur = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate)); - cur = build_lora_mm(layer.wo, cur, layer.wo_s); - cb(cur, "mtp_attn_out", il); - - cur = ggml_add(ctx0, cur, inpSA); - cb(cur, "mtp_attn_residual", il); - - ggml_tensor * ffn_residual = cur; - cur = build_norm(cur, layer.attn_post_norm, nullptr, LLM_NORM_RMS, il); - cb(cur, "mtp_attn_post_norm", il); - - // MoE FFN — routed experts plus gated shared expert (mirrors qwen35moe). 
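Before the deleted MoE body continues, a side note on the kq_scale chosen a few lines up: with f_attention_scale unset (0.0f), attention logits fall back to the conventional 1/sqrt(d_head) scaling. A worked value for an assumed head size:

    #include <cassert>
    #include <cmath>

    int main() {
        const float f_attention_scale = 0.0f; // unset in the GGUF
        const int   n_embd_head       = 128;  // assumed head dimension

        const float kq_scale = f_attention_scale == 0.0f
            ? 1.0f / std::sqrt((float) n_embd_head) : f_attention_scale;

        assert(std::fabs(kq_scale - 0.088388f) < 1e-4f); // 1/sqrt(128)
        return 0;
    }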
- ggml_tensor * moe_out = - build_moe_ffn(cur, - layer.ffn_gate_inp, - layer.ffn_up_exps, - layer.ffn_gate_exps, - layer.ffn_down_exps, - nullptr, - n_expert, n_expert_used, - LLM_FFN_SILU, true, - hparams.expert_weights_scale, - LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il, - nullptr, layer.ffn_gate_up_exps, - layer.ffn_up_exps_s, - layer.ffn_gate_exps_s, - layer.ffn_down_exps_s); - cb(moe_out, "mtp_ffn_moe_out", il); - - if (layer.ffn_up_shexp != nullptr) { - ggml_tensor * ffn_shexp = - build_ffn(cur, - layer.ffn_up_shexp, nullptr, layer.ffn_up_shexp_s, - layer.ffn_gate_shexp, nullptr, layer.ffn_gate_shexp_s, - layer.ffn_down_shexp, nullptr, layer.ffn_down_shexp_s, - nullptr, - LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(ffn_shexp, "mtp_ffn_shexp", il); - - ggml_tensor * shared_gate = build_lora_mm(layer.ffn_gate_inp_shexp, cur); - shared_gate = ggml_sigmoid(ctx0, shared_gate); - cb(shared_gate, "mtp_shared_expert_gate_sigmoid", il); - - ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate); - cb(ffn_shexp, "mtp_ffn_shexp_gated", il); - - cur = ggml_add(ctx0, moe_out, ffn_shexp); - } else { - cur = moe_out; - } - cb(cur, "mtp_ffn_out", il); - - cur = ggml_add(ctx0, cur, ffn_residual); - cb(cur, "mtp_post_ffn", il); - - // Pre-norm hidden state: used by the AR draft loop to seed the next MTP step. - cb(cur, "h_pre_norm", -1); - res->t_h_pre_norm = cur; - - ggml_tensor * head_norm_w = layer.nextn.shared_head_norm - ? layer.nextn.shared_head_norm - : model.output_norm; - GGML_ASSERT(head_norm_w && "QWEN35MOE_MTP: missing both nextn.shared_head_norm and output_norm"); - cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1); - cb(cur, "mtp_shared_head_norm", -1); - - ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output; - GGML_ASSERT(head_w && "QWEN35MOE_MTP: missing LM head (nextn.shared_head_head or model.output)"); - cur = build_lora_mm(head_w, cur); - cb(cur, "result_output", -1); - - res->t_logits = cur; - ggml_build_forward_expand(gf, cur); -} diff --git a/src/models/qwen35moe.cpp b/src/models/qwen35moe.cpp index 5912aa38153..a4c7cb6ad14 100644 --- a/src/models/qwen35moe.cpp +++ b/src/models/qwen35moe.cpp @@ -18,7 +18,6 @@ void llama_model_qwen35moe::load_arch_hparams(llama_model_loader & ml) { // NextN/MTP (Qwen3.5/3.6): extra decoder block appended beyond the main stack ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); - hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; // Mark recurrent layers (linear attention layers). MTP layers are dense // attention-only and must be flagged non-recurrent. @@ -39,9 +38,14 @@ void llama_model_qwen35moe::load_arch_hparams(llama_model_loader & ml) { } } -void llama_model_qwen35moe::load_arch_tensors(llama_model_loader &) { +void llama_model_qwen35moe::load_arch_tensors(llama_model_loader & ml) { LLAMA_LOAD_LOCALS; + const uint32_t n_main = n_layer - hparams.nextn_predict_layers; + const bool mtp_only = (hparams.nextn_predict_layers > 0) && + (ml.get_weight("blk.0.attn_norm.weight") == nullptr); + const int trunk_flags = mtp_only ? 
TENSOR_NOT_REQUIRED : 0; + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); // output @@ -53,70 +57,105 @@ void llama_model_qwen35moe::load_arch_tensors(llama_model_loader &) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); } - const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + auto load_block_trunk = [&](int il, int flags) { + auto & layer = layers[il]; - // Calculate dimensions from hyperparameters - const int64_t head_k_dim = hparams.ssm_d_state; - const int64_t head_v_dim = hparams.ssm_d_state; - const int64_t n_k_heads = hparams.ssm_n_group; - const int64_t n_v_heads = hparams.ssm_dt_rank; - const int64_t key_dim = head_k_dim * n_k_heads; - const int64_t value_dim = head_v_dim * n_v_heads; - const int64_t conv_dim = key_dim * 2 + value_dim; + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff; - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; + // Calculate dimensions from hyperparameters + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + const int64_t conv_dim = key_dim * 2 + value_dim; - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, flags); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, flags); - if (!hparams.is_recurrent(i)) { + if (!hparams.is_recurrent(il)) { // Attention layers - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, flags); // Q/K normalization for attention layers - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, flags); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, flags); } else { // Linear attention (gated delta net) specific tensors // Create tensors with calculated dimensions - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); - layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); - layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); - layer.ssm_beta = 
create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0); - layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0); - layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", il), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", il), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", il), { hparams.ssm_d_conv, conv_dim }, flags); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", il), { hparams.ssm_dt_rank }, flags); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, il), { hparams.ssm_dt_rank }, flags); + layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", il), { n_embd, n_v_heads }, flags); + layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", il), { n_embd, n_v_heads }, flags); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", il), { head_v_dim }, flags); + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", il), { value_dim, n_embd }, flags); } - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); - create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); + // Routed experts + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", il), { n_embd, n_expert }, flags); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", il), { n_ff_exp, n_embd, n_expert }, flags); + create_tensor_gate_up_exps(layer, il, n_embd, n_ff_exp, n_expert, flags); // Shared experts + layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", il), { n_embd }, flags); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", il), { n_embd, n_ff_shexp }, flags); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", il), { n_embd, n_ff_shexp }, flags); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", il), { n_ff_shexp, n_embd }, flags); + }; + + auto load_block_mtp = [&](int il) { + auto & layer = layers[il]; + + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; const int64_t n_ff_shexp = hparams.n_ff_shexp ? 
hparams.n_ff_shexp : n_ff; - layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0); - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0); - - // NextN/MTP tensors (preserved but unused) - only bound on MTP layers - if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { - layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, TENSOR_NOT_REQUIRED); - layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); - layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); - layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); - layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); - layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); - } + // MTP block looks like a full-attention Qwen3.5 decoder block with MoE FFN. + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, 0); + + create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, 0); + + // Routed experts + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", il), { n_embd, n_expert }, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", il), { n_ff_exp, n_embd, n_expert }, 0); + create_tensor_gate_up_exps(layer, il, n_embd, n_ff_exp, n_expert, 0); + + // Shared experts + layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", il), { n_embd }, 0); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", il), { n_embd, n_ff_shexp }, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", il), { n_embd, n_ff_shexp }, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", il), { n_ff_shexp, n_embd }, 0); + + // NextN-specific tensors that define the MTP block. 
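Three of the NextN tensors created below are optional, and the graph resolves each against a trunk fallback (embed_tokens falls back to tok_embd, shared_head_norm to output_norm, shared_head_head to output). A condensed, self-contained sketch of that resolution order; the pointer names mirror the diff but the weights here are dummies:

    #include <cassert>

    int main() {
        int tok_embd_w = 0, output_norm_w = 0, output_w = 0; // dummy trunk weights

        // NextN extras absent from the GGUF load as nullptr (TENSOR_NOT_REQUIRED).
        const int * nextn_embed_tokens     = nullptr;
        const int * nextn_shared_head_norm = nullptr;
        const int * nextn_shared_head_head = nullptr;

        // resolution order used by the MTP graph:
        const int * emb  = nextn_embed_tokens     ? nextn_embed_tokens     : &tok_embd_w;
        const int * norm = nextn_shared_head_norm ? nextn_shared_head_norm : &output_norm_w;
        const int * head = nextn_shared_head_head ? nextn_shared_head_head : &output_w;

        assert(emb == &tok_embd_w && norm == &output_norm_w && head == &output_w);
        return 0;
    }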
+ layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", il), { 2 * n_embd, n_embd }, 0); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", il), { n_embd }, 0); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", il), { n_embd }, 0); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", il), { n_embd }, TENSOR_NOT_REQUIRED); + }; + + for (int i = 0; i < (int) n_main; ++i) { + load_block_trunk(i, trunk_flags); + } + for (int i = (int) n_main; i < n_layer; ++i) { + load_block_mtp(i); } } std::unique_ptr llama_model_qwen35moe::build_arch_graph(const llm_graph_params & params) const { + if (params.gtype == LLM_GRAPH_TYPE_DECODER_MTP) { + return std::make_unique(*this, params); + } return std::make_unique(*this, params); } @@ -547,3 +586,178 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_ffn(ggml_tensor * cur, c return cur; } + +// LLM_GRAPH_TYPE_DECODER_MTP draft head for Qwen3.5/3.6 MoE +llama_model_qwen35moe::graph_mtp::graph_mtp(const llama_model & model, const llm_graph_params & params) + : llm_graph_context(params) { + GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35MOE MTP requires nextn_predict_layers > 0"); + GGML_ASSERT(hparams.nextn_predict_layers == 1 && "QWEN35MOE MTP currently only supports a single MTP block"); + + const int64_t n_embd_head = hparams.n_embd_head_v(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + + const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers; + const auto & layer = model.layers[il]; + + GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj"); + GGML_ASSERT(layer.nextn.enorm && "MTP block missing nextn.enorm"); + GGML_ASSERT(layer.nextn.hnorm && "MTP block missing nextn.hnorm"); + GGML_ASSERT(layer.ffn_gate_inp && "MTP block missing ffn_gate_inp"); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + auto inp = std::make_unique(hparams.n_embd); + + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_input(inp->tokens); + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + ggml_set_input(inp->embd); + ggml_set_name(inp->embd, "mtp_h_input"); + + ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? 
layer.nextn.embed_tokens : model.tok_embd; + + ggml_tensor * h_input = inp->embd; + ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens); + cb(tok_embd, "mtp_tok_embd", il); + + res->add_input(std::move(inp)); + + ggml_tensor * inp_pos = build_inp_pos(); + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); + cb(h_norm, "mtp_hnorm", il); + + ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il); + cb(e_norm, "mtp_enorm", il); + + ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, /*dim=*/ 0); + cb(concat, "mtp_concat", il); + + ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat); + cb(cur, "mtp_eh_proj", il); + + ggml_tensor * inpSA = cur; + + cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_norm", il); + + ggml_tensor * Qcur_full = build_lora_mm(layer.wq, cur, layer.wq_s); + cb(Qcur_full, "mtp_Qcur_full", il); + + ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, + n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + 0); + Qcur = build_norm(Qcur, layer.attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "mtp_Qcur_normed", il); + + ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full, + n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + ggml_element_size(Qcur_full) * n_embd_head); + gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); + cb(gate, "mtp_gate", il); + + ggml_tensor * Kcur = build_lora_mm(layer.wk, cur, layer.wk_s); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = build_norm(Kcur, layer.attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "mtp_Kcur_normed", il); + + ggml_tensor * Vcur = build_lora_mm(layer.wv, cur, layer.wv_s); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + cb(Vcur, "mtp_Vcur", il); + + Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + const float kq_scale = hparams.f_attention_scale == 0.0f + ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + cur = build_attn(inp_attn, + nullptr, nullptr, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "mtp_attn_pregate", il); + + cur = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate)); + cur = build_lora_mm(layer.wo, cur, layer.wo_s); + cb(cur, "mtp_attn_out", il); + + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "mtp_attn_residual", il); + + ggml_tensor * ffn_residual = cur; + cur = build_norm(cur, layer.attn_post_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_post_norm", il); + + // MoE FFN — routed experts plus gated shared expert (mirrors qwen35moe). 
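The MoE FFN assembled below adds a sigmoid-gated shared expert on top of the routed experts, exactly as the trunk layers do. A scalar toy computation of that combine; every number here is made up for illustration:

    #include <cmath>
    #include <cstdio>

    int main() {
        const float moe_out   = 0.80f; // routed experts' contribution (made up)
        const float shexp_out = 0.50f; // shared-expert FFN output (made up)
        const float gate_lin  = 0.25f; // ffn_gate_inp_shexp . x (made up)

        const float gate = 1.0f / (1.0f + std::exp(-gate_lin)); // sigmoid ~= 0.562
        const float out  = moe_out + gate * shexp_out;          // ~= 1.081

        std::printf("gate=%.3f out=%.3f\n", gate, out);
        return 0;
    }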
+ ggml_tensor * moe_out = + build_moe_ffn(cur, + layer.ffn_gate_inp, + layer.ffn_up_exps, + layer.ffn_gate_exps, + layer.ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + hparams.expert_weights_scale, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il, + nullptr, layer.ffn_gate_up_exps, + layer.ffn_up_exps_s, + layer.ffn_gate_exps_s, + layer.ffn_down_exps_s); + cb(moe_out, "mtp_ffn_moe_out", il); + + if (layer.ffn_up_shexp != nullptr) { + ggml_tensor * ffn_shexp = + build_ffn(cur, + layer.ffn_up_shexp, nullptr, layer.ffn_up_shexp_s, + layer.ffn_gate_shexp, nullptr, layer.ffn_gate_shexp_s, + layer.ffn_down_shexp, nullptr, layer.ffn_down_shexp_s, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "mtp_ffn_shexp", il); + + ggml_tensor * shared_gate = build_lora_mm(layer.ffn_gate_inp_shexp, cur); + shared_gate = ggml_sigmoid(ctx0, shared_gate); + cb(shared_gate, "mtp_shared_expert_gate_sigmoid", il); + + ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate); + cb(ffn_shexp, "mtp_ffn_shexp_gated", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + } else { + cur = moe_out; + } + cb(cur, "mtp_ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_residual); + cb(cur, "mtp_post_ffn", il); + + // Pre-norm hidden state: used by the AR draft loop to seed the next MTP step. + cb(cur, "h_pre_norm", -1); + res->t_h_pre_norm = cur; + + ggml_tensor * head_norm_w = layer.nextn.shared_head_norm + ? layer.nextn.shared_head_norm + : model.output_norm; + GGML_ASSERT(head_norm_w && "QWEN35MOE MTP: missing both nextn.shared_head_norm and output_norm"); + cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1); + cb(cur, "mtp_shared_head_norm", -1); + + ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output; + GGML_ASSERT(head_w && "QWEN35MOE MTP: missing LM head (nextn.shared_head_head or model.output)"); + cur = build_lora_mm(head_w, cur); + cb(cur, "result_output", -1); + + res->t_logits = cur; + ggml_build_forward_expand(gf, cur); +} diff --git a/tests/test-llama-archs.cpp b/tests/test-llama-archs.cpp index fd0d3696d77..03d7c19c78b 100644 --- a/tests/test-llama-archs.cpp +++ b/tests/test-llama-archs.cpp @@ -406,11 +406,7 @@ static bool arch_supported(const llm_arch arch) { if (arch == LLM_ARCH_DEEPSEEK2OCR) { return false; } - if (arch == LLM_ARCH_QWEN35_MTP || arch == LLM_ARCH_QWEN35MOE_MTP) { - return false; // MTP-only arch; requires a sibling trunk model and cannot run standalone. 
- } - - // FIXME some models are segfaulting with WebGPU: +// FIXME some models are segfaulting with WebGPU: #ifdef GGML_USE_WEBGPU if (arch == LLM_ARCH_QWEN3NEXT || arch == LLM_ARCH_QWEN35 || arch == LLM_ARCH_QWEN35MOE || arch == LLM_ARCH_KIMI_LINEAR) { return false; diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 9430c31b4c5..76b4294f897 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -756,6 +756,14 @@ struct server_context_impl { } auto cparams = common_context_params_to_llama(params_dft); + + const bool spec_mtp = std::find(params_base.speculative.types.begin(), + params_base.speculative.types.end(), + COMMON_SPECULATIVE_TYPE_MTP) != params_base.speculative.types.end(); + if (spec_mtp) { + cparams.ctx_type = LLAMA_CONTEXT_TYPE_MTP; + } + ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams)); ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get()); @@ -764,36 +772,13 @@ struct server_context_impl { params_base.speculative.draft.ctx_dft = ctx_dft.get(); } else if (std::find(params_base.speculative.types.begin(), params_base.speculative.types.end(), COMMON_SPECULATIVE_TYPE_MTP) != params_base.speculative.types.end()) { - // MTP head lives in the *target* GGUF — load it as a sibling model - // with override_arch and feed it through the existing ctx_dft slot. - char trunk_arch[64] = {0}; - llama_model_meta_val_str(model_tgt, "general.architecture", trunk_arch, sizeof(trunk_arch)); - - const char * mtp_arch = nullptr; - if (std::string(trunk_arch) == "qwen35") { - mtp_arch = "qwen35_mtp"; - } else if (std::string(trunk_arch) == "qwen35moe") { - mtp_arch = "qwen35moe_mtp"; - } else { - SRV_ERR("MTP not supported for trunk architecture '%s'\n", trunk_arch); - return false; - } - - SRV_INF("loading MTP head from '%s' (override_arch=%s)\n", - params_base.model.path.c_str(), mtp_arch); - - auto mparams_mtp = common_model_params_to_llama(params_base); - mparams_mtp.override_arch = mtp_arch; - - model_dft.reset(llama_model_load_from_file(params_base.model.path.c_str(), mparams_mtp)); - if (model_dft == nullptr) { - SRV_ERR("failed to load MTP head from '%s'\n", params_base.model.path.c_str()); - return false; - } + SRV_INF("creating MTP draft context against the target model '%s'\n", + params_base.model.path.c_str()); auto cparams_mtp = common_context_params_to_llama(params_base); + cparams_mtp.ctx_type = LLAMA_CONTEXT_TYPE_MTP; - ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams_mtp)); + ctx_dft.reset(llama_init_from_model(model_tgt, cparams_mtp)); if (ctx_dft == nullptr) { SRV_ERR("%s", "failed to create MTP context\n"); return false; From f2200a3a7752f55b2b12e61410d6d1c90107c040 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Wed, 13 May 2026 14:43:12 +0800 Subject: [PATCH 06/18] mtp -> draft-mtp --- common/arg.cpp | 8 ++++---- common/common.h | 2 +- common/speculative.cpp | 20 ++++++++++---------- tools/server/server-context.cpp | 6 +++--- 4 files changed, 18 insertions(+), 18 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 3fc7bd88d39..ebcd4aefeca 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -445,11 +445,11 @@ static bool parse_bool_value(const std::string & value) { // void common_params_handle_models(common_params & params, llama_example curr_ex) { - const bool spec_type_mtp = std::find(params.speculative.types.begin(), + const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(), params.speculative.types.end(), - COMMON_SPECULATIVE_TYPE_MTP) != 
params.speculative.types.end(); + COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params.speculative.types.end(); - auto res = common_params_handle_model(params.model, params.hf_token, params.offline, spec_type_mtp); + auto res = common_params_handle_model(params.model, params.hf_token, params.offline, spec_type_draft_mtp); if (params.no_mmproj) { params.mmproj = {}; } else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) { @@ -465,7 +465,7 @@ void common_params_handle_models(common_params & params, llama_example curr_ex) } // when --spec-type mtp is set and no draft model was provided explicitly, // fall back to the MTP head discovered alongside the -hf model - if (spec_type_mtp && res.found_mtp && + if (spec_type_draft_mtp && res.found_mtp && params.speculative.draft.mparams.path.empty() && params.speculative.draft.mparams.hf_repo.empty() && params.speculative.draft.mparams.url.empty()) { diff --git a/common/common.h b/common/common.h index b39095e7233..987a31d0bb3 100644 --- a/common/common.h +++ b/common/common.h @@ -159,7 +159,7 @@ enum common_speculative_type { COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE, // standalone draft model speculative decoding COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, // Eagle3 speculative decoding - COMMON_SPECULATIVE_TYPE_MTP, // multi-token prediction head loaded from the target GGUF + COMMON_SPECULATIVE_TYPE_DRAFT_MTP, // multi-token prediction head loaded from the target GGUF COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding based on n-grams COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values diff --git a/common/speculative.cpp b/common/speculative.cpp index 65d4b9ceaa4..29e4b315633 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -24,7 +24,7 @@ const std::map common_speculative_type_fro {"none", COMMON_SPECULATIVE_TYPE_NONE}, {"draft-simple", COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE}, {"draft-eagle3", COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3}, - {"mtp", COMMON_SPECULATIVE_TYPE_MTP}, + {"draft-mtp", COMMON_SPECULATIVE_TYPE_DRAFT_MTP}, {"ngram-simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE}, {"ngram-map-k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K}, {"ngram-map-k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V}, @@ -366,7 +366,7 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl { } }; -struct common_speculative_state_mtp : public common_speculative_impl { +struct common_speculative_state_draft_mtp : public common_speculative_impl { common_params_speculative_draft params; // reuses the draft-model params slot (ctx_tgt/ctx_dft) llama_batch batch; @@ -383,8 +383,8 @@ struct common_speculative_state_mtp : public common_speculative_impl { std::vector i_batch_beg; std::vector i_batch_end; - common_speculative_state_mtp(const common_params_speculative & params, uint32_t n_seq) - : common_speculative_impl(COMMON_SPECULATIVE_TYPE_MTP, n_seq) + common_speculative_state_draft_mtp(const common_params_speculative & params, uint32_t n_seq) + : common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, n_seq) , params(params.draft) { auto * ctx_tgt = this->params.ctx_tgt; @@ -417,7 +417,7 @@ struct common_speculative_state_mtp : public common_speculative_impl { i_batch_end.assign(n_seq, -1); } - ~common_speculative_state_mtp() override { + ~common_speculative_state_draft_mtp() override { if (batch.token != 
nullptr) {
 free(batch.token);
 batch.token = nullptr;
@@ -1106,7 +1106,7 @@ std::string common_speculative_type_to_str(common_speculative_type type) {
 case COMMON_SPECULATIVE_TYPE_NONE: return "none";
 case COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE: return "draft-simple";
 case COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3: return "draft-eagle3";
- case COMMON_SPECULATIVE_TYPE_MTP: return "mtp";
+ case COMMON_SPECULATIVE_TYPE_DRAFT_MTP: return "draft-mtp";
 case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram-simple";
 case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram-map-k";
 case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram-map-k4v";
@@ -1163,7 +1163,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
 bool has_draft_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE));
 bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3
- bool has_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_MTP)) && params.draft.ctx_dft != nullptr;
+ bool has_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_MTP)) && params.draft.ctx_dft != nullptr;
 bool has_ngram_cache = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_CACHE));
 bool has_ngram_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE));
@@ -1210,7 +1210,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
 configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, params));
 }
 if (has_mtp) {
- configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_MTP, params));
+ configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, params));
 }
 }
@@ -1229,8 +1229,8 @@ common_speculative * common_speculative_init(common_params_speculative & params,
 impls.push_back(std::make_unique<common_speculative_impl_draft_eagle3>(config.params, n_seq));
 break;
 }
- case COMMON_SPECULATIVE_TYPE_MTP: {
- impls.push_back(std::make_unique<common_speculative_state_mtp>(config.params, n_seq));
+ case COMMON_SPECULATIVE_TYPE_DRAFT_MTP: {
+ impls.push_back(std::make_unique<common_speculative_state_draft_mtp>(config.params, n_seq));
 break;
 }
 case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: {
diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index 76b4294f897..37776e10f8b 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -759,7 +759,7 @@ struct server_context_impl {
 const bool spec_mtp = std::find(params_base.speculative.types.begin(),
 params_base.speculative.types.end(),
- COMMON_SPECULATIVE_TYPE_MTP) != params_base.speculative.types.end();
+ COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end();
 if (spec_mtp) {
 cparams.ctx_type = LLAMA_CONTEXT_TYPE_MTP;
 }
@@ -771,7 +771,7 @@ struct server_context_impl {
 params_base.speculative.draft.ctx_tgt = ctx_tgt;
 params_base.speculative.draft.ctx_dft = ctx_dft.get();
 } else if (std::find(params_base.speculative.types.begin(), params_base.speculative.types.end(),
- COMMON_SPECULATIVE_TYPE_MTP) != params_base.speculative.types.end()) {
+ COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end()) {
 SRV_INF("creating MTP draft context against the target model '%s'\n",
 params_base.model.path.c_str());
@@ -899,7 +899,7 @@ struct server_context_impl {
 slot.ctx_dft = ctx_dft.get();
 slot.spec = spec.get();
 slot.is_mtp_enabled = (std::find(params_base.speculative.types.begin(), params_base.speculative.types.end(),
- COMMON_SPECULATIVE_TYPE_MTP) != params_base.speculative.types.end())
+ COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end()) 
&& (ctx_dft != nullptr); slot.n_ctx = n_ctx_slot; From 19dd00b0e41249385048ad548eb888f910251954 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Wed, 13 May 2026 15:00:30 +0800 Subject: [PATCH 07/18] remove unused llama_arch --- include/llama.h | 3 --- src/llama-model.cpp | 10 ---------- 2 files changed, 13 deletions(-) diff --git a/include/llama.h b/include/llama.h index b814e2c58de..470add516c8 100644 --- a/include/llama.h +++ b/include/llama.h @@ -315,9 +315,6 @@ extern "C" { // override key-value pairs of the model meta data const struct llama_model_kv_override * kv_overrides; - // override architecture from GGUF (e.g. load the MTP head of a Qwen3.5 GGUF as "qwen35_mtp") - const char * override_arch; - // Keep the booleans together to avoid misalignment during copy-by-value. bool vocab_only; // only load the vocabulary, no weights bool use_mmap; // use mmap if possible diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 2a157eb1299..f9601eae094 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -309,15 +309,6 @@ llama_model * llama_model_create(llama_model_loader & ml, const llama_model_para if (arch == LLM_ARCH_UNKNOWN) { throw std::runtime_error("unknown model architecture: '" + ml.get_arch_name() + "'"); } - if (params.override_arch != nullptr && params.override_arch[0] != '\0') { - const llm_arch override = llm_arch_from_string(params.override_arch); - if (override == LLM_ARCH_UNKNOWN) { - throw std::runtime_error(std::string("unknown override architecture: '") + params.override_arch + "'"); - } - LLAMA_LOG_INFO("%s: overriding architecture %s -> %s\n", - __func__, llm_arch_name(arch), llm_arch_name(override)); - arch = override; - } return llama_model_create(arch, params); } @@ -2118,7 +2109,6 @@ llama_model_params llama_model_default_params() { /*.progress_callback =*/ nullptr, /*.progress_callback_user_data =*/ nullptr, /*.kv_overrides =*/ nullptr, - /*.override_arch =*/ nullptr, /*.vocab_only =*/ false, /*.use_mmap =*/ true, /*.use_direct_io =*/ false, From e7b4848151377395b1693d326d1cda3fcd61c2d9 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Wed, 13 May 2026 15:09:00 +0800 Subject: [PATCH 08/18] add need_embd in speculative --- common/speculative.cpp | 45 +++++++++++++++++++++++++++++++++ common/speculative.h | 3 +++ tools/server/server-context.cpp | 16 +----------- 3 files changed, 49 insertions(+), 15 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index 29e4b315633..3914a5e245d 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -145,6 +145,9 @@ struct common_speculative_impl { virtual void draft(common_speculative_draft_params_vec & dparams) = 0; virtual void accept(llama_seq_id seq_id, uint16_t n_accepted) = 0; + + // true if this implementation requires the target context to extract embeddings + virtual bool need_embd() const = 0; }; struct common_speculative_impl_draft_simple : public common_speculative_impl { @@ -340,6 +343,10 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl { void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { // noop } + + bool need_embd() const override { + return false; + } }; struct common_speculative_impl_draft_eagle3 : public common_speculative_impl { @@ -364,6 +371,10 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl { void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { // noop } + + bool need_embd() const override { + return false; + } }; struct common_speculative_state_draft_mtp 
: public common_speculative_impl { @@ -648,6 +659,10 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl { void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { } + + bool need_embd() const override { + return true; + } }; // state of self-speculation (simple implementation, not ngram-map) @@ -689,6 +704,10 @@ struct common_speculative_impl_ngram_simple : public common_speculative_impl { void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { // noop } + + bool need_embd() const override { + return false; + } }; struct common_speculative_impl_ngram_map_k : public common_speculative_impl { @@ -737,6 +756,10 @@ struct common_speculative_impl_ngram_map_k : public common_speculative_impl { common_ngram_map_accept(config[seq_id], n_accepted); } + + bool need_embd() const override { + return false; + } }; struct common_speculative_impl_ngram_mod : public common_speculative_impl { @@ -905,6 +928,10 @@ struct common_speculative_impl_ngram_mod : public common_speculative_impl { } } } + + bool need_embd() const override { + return false; + } }; struct common_speculative_impl_ngram_cache : public common_speculative_impl { @@ -1038,6 +1065,10 @@ struct common_speculative_impl_ngram_cache : public common_speculative_impl { void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { // noop } + + bool need_embd() const override { + return false; + } }; struct common_speculative { @@ -1333,6 +1364,20 @@ bool common_speculative_process(common_speculative * spec, const llama_batch & b return result; } +bool common_speculative_need_embd(common_speculative * spec) { + if (spec == nullptr) { + return false; + } + + for (auto & impl : spec->impls) { + if (impl->need_embd()) { + return true; + } + } + + return false; +} + void common_speculative_draft(common_speculative * spec) { if (spec == nullptr) { return; diff --git a/common/speculative.h b/common/speculative.h index 51f0b059fa4..614db9b1b50 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -53,6 +53,9 @@ void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, co // process the batch and update the internal state of the speculative context bool common_speculative_process(common_speculative * spec, const llama_batch & batch); +// true if any implementation requires target embeddings to be extracted +bool common_speculative_need_embd(common_speculative * spec); + // generate drafts for the sequences specified with `common_speculative_get_draft_params` void common_speculative_draft(common_speculative * spec); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 37776e10f8b..16bcd0bd490 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -57,11 +57,6 @@ struct server_slot { llama_context * ctx_tgt = nullptr; llama_context * ctx_dft = nullptr; - // True when this slot's speculative impl is MTP (ctx_dft is the MTP head). - // MTP needs every prefill position to carry logits=1 so the streaming - // hook in common_speculative_state_mtp::process() can read t_h_pre_norm. 
- bool is_mtp_enabled = false;
-
 // multimodal
 mtmd_context * mctx = nullptr;
@@ -242,15 +237,9 @@ struct server_slot {
 (ggml_time_us() - t_start) / 1000.0, n_text, (int) prompt.tokens.size());
 }
- bool is_mtp() const { return is_mtp_enabled; }
-
- // The trunk needs to emit logits at every prefill position when either:
- // - the task asked for embeddings, or
- // - MTP is enabled for this slot (the streaming hook in process() reads
- // h_pre_norm at every prompt position).
 bool need_embd() const {
 GGML_ASSERT(task);
- return task->need_embd() || is_mtp();
+ return task->need_embd() || (spec && common_speculative_need_embd(spec));
 }
 
 // if the context does not have a memory module then all embeddings have to be computed within a single ubatch
@@ -898,9 +887,6 @@ struct server_context_impl {
 slot.ctx_tgt = ctx_tgt;
 slot.ctx_dft = ctx_dft.get();
 slot.spec = spec.get();
- slot.is_mtp_enabled = (std::find(params_base.speculative.types.begin(), params_base.speculative.types.end(),
- COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end())
- && (ctx_dft != nullptr);
 slot.n_ctx = n_ctx_slot;
 slot.mctx = mctx;

From 4e732e0a6c49d41e5aaf3e460f9685a63529adb5 Mon Sep 17 00:00:00 2001
From: Aman Gupta
Date: Sun, 26 Apr 2026 00:42:04 +0800
Subject: [PATCH 09/18] llama: allow partial seq_rm for GDN models for
 speculative decoding

Currently speculative decoding needs to restart from a checkpoint after
some draft tokens are rejected, which wastes work by running the target
again. This PR adds the ability to roll back up to `draft_max` tokens
by storing the GDN intermediates.
---
 common/common.cpp | 16 +++
 common/common.h | 7 +-
 ggml/include/ggml.h | 7 ++
 ggml/src/ggml-backend-meta.cpp | 4 +-
 ggml/src/ggml-cpu/ggml-cpu.c | 4 +-
 ggml/src/ggml-cpu/ops.cpp | 43 ++++++--
 ggml/src/ggml-cuda/gated_delta_net.cu | 88 +++++++++----
 ggml/src/ggml.c | 12 ++-
 include/llama.h | 2 +
 src/llama-arch.cpp | 10 ++
 src/llama-arch.h | 1 +
 src/llama-context.cpp | 12 +++
 src/llama-cparams.h | 1 +
 src/llama-graph.cpp | 3 +-
 src/llama-memory-hybrid-iswa.cpp | 2 +
 src/llama-memory-hybrid-iswa.h | 1 +
 src/llama-memory-hybrid.cpp | 2 +
 src/llama-memory-hybrid.h | 1 +
 src/llama-memory-recurrent.cpp | 47 +++++++--
 src/llama-memory-recurrent.h | 8 ++
 src/llama-model.cpp | 3 +
 src/models/delta-net-base.cpp | 144 +++++++++++++++++++++++-
 src/models/models.h | 28 ++++-
 src/models/qwen35.cpp | 46 +------
 src/models/qwen35moe.cpp | 46 +------
 src/models/qwen3next.cpp | 44 +------
 tests/test-backend-ops.cpp | 21 +++-
 tools/server/server-context.cpp | 7 +-
 28 files changed, 422 insertions(+), 188 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index 352af0b1787..6af85e7ab7b 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1420,6 +1420,11 @@ common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) {
 goto done;
 }
+
+ if (llama_n_rs_seq(ctx) > 0) {
+ res = COMMON_CONTEXT_SEQ_RM_TYPE_PART_BOUNDED;
+ goto done;
+ }
+
 // try to remove the last tokens
 if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
 LOG_WRN("%s: the context does not support partial sequence removal\n", __func__);
@@ -1490,6 +1495,17 @@ struct llama_context_params common_context_params_to_llama(const common_params &
 cparams.n_ctx = params.n_ctx;
 cparams.n_seq_max = params.n_parallel;
+ {
+ // TODO: add for MTP
+ bool has_spec = params.speculative.has_dft();
+ for (auto t : params.speculative.types) {
+ if (t != COMMON_SPECULATIVE_TYPE_NONE) {
+ has_spec = true;
+ break;
+ }
+ }
+ cparams.n_rs_seq = has_spec ? 
(uint32_t) params.speculative.draft.n_max : 0u; + } cparams.n_batch = params.n_batch; cparams.n_ubatch = params.n_ubatch; cparams.n_threads = params.cpuparams.n_threads; diff --git a/common/common.h b/common/common.h index 987a31d0bb3..e33a8f57fd4 100644 --- a/common/common.h +++ b/common/common.h @@ -872,9 +872,10 @@ std::string common_get_model_endpoint(); // enum common_context_seq_rm_type { - COMMON_CONTEXT_SEQ_RM_TYPE_NO = 0, // seq_rm not supported (e.g. no memory module) - COMMON_CONTEXT_SEQ_RM_TYPE_PART = 1, // can seq_rm partial sequences - COMMON_CONTEXT_SEQ_RM_TYPE_FULL = 2, // can seq_rm full sequences only + COMMON_CONTEXT_SEQ_RM_TYPE_NO = 0, // seq_rm not supported (e.g. no memory module) + COMMON_CONTEXT_SEQ_RM_TYPE_PART = 1, // can seq_rm partial sequences + COMMON_CONTEXT_SEQ_RM_TYPE_FULL = 2, // can seq_rm full sequences only + COMMON_CONTEXT_SEQ_RM_TYPE_PART_BOUNDED = 3, // can seq_rm partial sequences, bounded by n_rs_seq }; // check if the llama_context can remove sequences diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 3357a0d9985..be1b6be4cb2 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -2541,6 +2541,13 @@ extern "C" { // TODO: add ggml_gated_delta_net_set_bcast() to be able to configure Q, K broadcast type: tiled vs interleaved [TAG_GGML_GDN_BCAST] // ref: https://github.com/ggml-org/llama.cpp/pull/19468#discussion_r2786394306 + // + // state is a 3D tensor of shape (S_v*S_v*H, K, n_seqs). K is the snapshot slot count: + // K == 1 → output carries the final state only. + // K > 1 → output carries K snapshot slots; the kernel writes the last min(n_tokens, K) + // per-token snapshots into the trailing slots (earlier slots are left untouched + // when n_tokens < K). + // Only slot 0 (state[:, 0, :]) is read as the initial state; the rest is shape signal. GGML_API struct ggml_tensor * ggml_gated_delta_net( struct ggml_context * ctx, struct ggml_tensor * q, diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index c0ffd9a048b..6fd401caa6f 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -753,7 +753,9 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(co GGML_ASSERT(src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_1); GGML_ASSERT(src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_1); GGML_ASSERT(src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_1); - GGML_ASSERT(src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_2); + // state shape is (S_v*S_v*H, K, n_seqs); the heads dim is nested inside axis 0, + // so a head-aligned split on the input cache reshapes to axis 0 here (not axis 2). + GGML_ASSERT(src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_2 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_0); return {GGML_BACKEND_SPLIT_AXIS_0, {0}, 1}; }; diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 8b7acafdaa8..cdad92eecd4 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2939,7 +2939,9 @@ struct ggml_cplan ggml_graph_plan( case GGML_OP_GATED_DELTA_NET: { const int64_t S_v = node->src[2]->ne[0]; - cur = S_v * sizeof(float) * n_tasks; + const int64_t K = node->src[5]->ne[1]; // state is (D, K, n_seqs) + const int64_t per_thread = S_v + (K > 1 ? 
S_v * S_v : 0); + cur = per_thread * sizeof(float) * n_tasks; } break; case GGML_OP_COUNT: { diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 6bc8dc150ce..7485ba4fc86 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -10513,19 +10513,30 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( const bool kda = (neg0 == S_v); - // scratch layout per thread: [delta(S_v)] - const int64_t scratch_per_thread = S_v; + // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. + const int64_t K = src_state->ne[1]; + GGML_ASSERT(K >= 1); + // per-seq stride in floats (slot 0 of seq s lives at state + s * seq_stride) + const int64_t state_seq_stride = src_state->nb[2] / sizeof(float); + + const int64_t per_thread = S_v + (K > 1 ? S_v * S_v : 0); const int ith = params->ith; - float * delta = (float *)params->wdata + ith * scratch_per_thread + CACHE_LINE_SIZE_F32; + float * delta = (float *)params->wdata + ith * per_thread + CACHE_LINE_SIZE_F32; + float * state_work = K > 1 ? (delta + S_v) : nullptr; // output layout: [attn_scores | new_states] - // attn_scores: S_v * H * n_tokens * n_seqs floats - // new_states: S_v * S_v * H * n_seqs floats - const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs; + // attn_scores: S_v * H * n_tokens * n_seqs floats + // new_states: S_v * S_v * H * n_seqs * K floats (K snapshot slots; last min(n_tokens, K)) + const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs; + const int64_t state_size_per_snap = S_v * S_v * H * n_seqs; float * attn_out_base = (float *)dst->data; float * state_out_base = (float *)dst->data + attn_score_elems; + // snapshot slot mapping: target_slot = t - shift. When n_tokens < K only the last + // n_tokens slots are written; earlier slots are left untouched (caller-owned). + const int64_t shift = n_tokens - K; + const float * state_in_base = (const float *)src_state->data; //const int64_t rq1 = nev1 / neq1; @@ -10545,10 +10556,15 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( const int64_t iq3 = iv3 / rq3; const int64_t ik3 = iv3 / rk3; - float * s_out = state_out_base + (iv3 * H + iv1) * S_v * S_v; + // For K=1, write directly to the single output slot to avoid an extra memcpy at the end. + // For K>1, work in scratch and copy out per-token when the slot is in range. + float * s_out = (K > 1) + ? state_work + : state_out_base + (iv3 * H + iv1) * S_v * S_v; - // copy input state into output buffer and operate in-place - const float * s_in = state_in_base + (iv3 * H + iv1) * S_v * S_v; + // copy input state into the working buffer and operate in-place + // state layout (D, K, n_seqs): slot 0 of seq iv3 starts at iv3 * state_seq_stride. 
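+ // e.g. with hypothetical sizes S_v = 64, H = 8, K = 4: D = 64*64*8 floats, the contiguous
+ // seq stride is K*D, so (iv3 = 1, iv1 = 3) reads its slot-0 block at 1*(4*D) + 3*64*64.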
+ const float * s_in = state_in_base + iv3 * state_seq_stride + iv1 * S_v * S_v;
 memcpy(s_out, s_in, S_v * S_v * sizeof(float));
 
 // attn output pointer for first token of this (head, seq)
@@ -10598,6 +10614,15 @@
 }
 
 attn_data += S_v * H; // advance to next token
+
+ if (K > 1) {
+ const int64_t target_slot = t - shift;
+ if (target_slot >= 0 && target_slot < K) {
+ float * curr_state_o = state_out_base + target_slot * state_size_per_snap
+ + (iv3 * H + iv1) * S_v * S_v;
+ memcpy(curr_state_o, s_out, S_v * S_v * sizeof(float));
+ }
+ }
 }
 }
 }
diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu
index 6b44bec7317..ef4bb29c8ad 100644
--- a/ggml/src/ggml-cuda/gated_delta_net.cu
+++ b/ggml/src/ggml-cuda/gated_delta_net.cu
@@ -1,6 +1,6 @@
 #include "gated_delta_net.cuh"
 
-template <int S_v, bool KDA>
+template <int S_v, bool KDA, bool keep_intermediates_t>
 __global__ void __launch_bounds__((ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v) * 4, 2)
 gated_delta_net_cuda(const float * q,
 const float * k,
 int64_t sb3,
 const uint3 neqk1_magic,
 const uint3 rq3_magic,
- float scale) {
+ float scale,
+ int K) {
 const uint32_t h_idx = blockIdx.x;
 const uint32_t sequence = blockIdx.y;
 // each warp owns one column, using warp-level primitives to reduce across rows
@@ -37,9 +38,13 @@ gated_delta_net_cuda(const float * q,
 float * attn_data = dst;
 float * state = dst + attn_score_elems;
- const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v;
- state += state_offset;
- curr_state += state_offset + col * S_v;
+ // input state layout (D, K, n_seqs) — seq stride is K * D = K * H * S_v * S_v.
+ // output state layout (per-slot D * n_seqs) — same per-(seq,head) offset as before.
+ const int64_t state_in_offset = sequence * K * H * S_v * S_v + h_idx * S_v * S_v;
+ const int64_t state_out_offset = (sequence * H + h_idx) * S_v * S_v;
+ const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // per-slot stride in output
+ state += state_out_offset;
+ curr_state += state_in_offset + col * S_v;
 attn_data += (sequence * n_tokens * H + h_idx) * S_v;
 constexpr int warp_size = ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v;
@@ -54,6 +59,10 @@
 s_shard[r] = curr_state[i];
 }
+
+ // slot mapping: target_slot = t - shift. When n_tokens < K only the last n_tokens slots
+ // are written; earlier slots are left untouched (caller-owned). 
+ const int shift = (int) n_tokens - K;
+
 for (int t = 0; t < n_tokens; t++) {
 const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1;
 const float * k_t = k + iq3 * sq3 + t * sq2 + iq1 * sq1;
@@ -135,17 +144,30 @@
 }
 
 attn_data += S_v * H;
+
+ if constexpr (keep_intermediates_t) {
+ const int target_slot = t - shift;
+ if (target_slot >= 0 && target_slot < K) {
+ float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset;
+#pragma unroll
+ for (int r = 0; r < rows_per_lane; r++) {
+ const int i = r * warp_size + lane;
+ curr_state[col * S_v + i] = s_shard[r];
+ }
+ }
+ }
 }
 
- // Write state back to global memory (transposed layout)
+ if constexpr (!keep_intermediates_t) {
 #pragma unroll
- for (int r = 0; r < rows_per_lane; r++) {
- const int i = r * warp_size + lane;
- state[col * S_v + i] = s_shard[r];
+ for (int r = 0; r < rows_per_lane; r++) {
+ const int i = r * warp_size + lane;
+ state[col * S_v + i] = s_shard[r];
+ }
 }
 }
 
-template <bool KDA>
+template <bool KDA, bool keep_intermediates_t>
 static void launch_gated_delta_net(
 const float * q_d, const float * k_d, const float * v_d,
 const float * g_d, const float * b_d, const float * s_d,
 float * dst_d, int64_t S_v, int64_t H, int64_t n_tokens, int64_t n_seqs,
 int64_t sq1, int64_t sq2, int64_t sq3,
 int64_t sv1, int64_t sv2, int64_t sv3,
 int64_t sb1, int64_t sb2, int64_t sb3,
 int64_t neqk1, int64_t rq3,
- float scale, cudaStream_t stream) {
+ float scale, int K, cudaStream_t stream) {
 //TODO: Add chunked kernel for even faster pre-fill
 const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
 const int num_warps = 4;
@@ -169,29 +191,29 @@
 switch (S_v) {
 case 16:
- gated_delta_net_cuda<16, KDA><<>>(
+ gated_delta_net_cuda<16, KDA, keep_intermediates_t><<>>(
 q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
- sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
+ sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
 break;
 case 32:
- gated_delta_net_cuda<32, KDA><<>>(
+ gated_delta_net_cuda<32, KDA, keep_intermediates_t><<>>(
 q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
- sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
+ sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
 break;
 case 64: {
- gated_delta_net_cuda<64, KDA><<>>(
+ gated_delta_net_cuda<64, KDA, keep_intermediates_t><<>>(
 q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
- sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
+ sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
 break;
 }
 case 128: {
- gated_delta_net_cuda<128, KDA><<>>(
+ gated_delta_net_cuda<128, KDA, keep_intermediates_t><<>>(
 q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
- sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
+ sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
 break;
 }
 default:
@@ -261,13 +283,29 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
 cudaStream_t stream = ctx.stream();
 
+ // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. 
+ const int K = (int) src_state->ne[1];
+ const bool keep_intermediates = K > 1;
+
 if (kda) {
- launch_gated_delta_net<true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
- S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
- sb1, sb2, sb3, neqk1, rq3, scale, stream);
+ if (keep_intermediates) {
+ launch_gated_delta_net<true, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
+ S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
+ sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
+ } else {
+ launch_gated_delta_net<true, false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
+ S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
+ sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
+ }
 } else {
- launch_gated_delta_net<false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
- S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
- sb1, sb2, sb3, neqk1, rq3, scale, stream);
+ if (keep_intermediates) {
+ launch_gated_delta_net<false, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
+ S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
+ sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
+ } else {
+ launch_gated_delta_net<false, false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
+ S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
+ sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
+ }
 }
 }
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 191cf2fa106..476c3079795 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -6210,11 +6210,13 @@ struct ggml_tensor * ggml_gated_delta_net(
 GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v);
 GGML_ASSERT(beta->ne[0] == 1);
- GGML_ASSERT(ggml_nelements(state) == S_v * S_v * H * n_seqs);
-
- // concat output and new_state into a single tensor
- // output: S_v * H * n_tokens * n_seqs, state: S_v * S_v * H * n_seqs
- const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + S_v * n_seqs, 1, 1 };
+ // state is a 3D tensor (S_v*S_v*H, K, n_seqs). K is the snapshot slot count.
+ GGML_ASSERT(state->ne[0] == S_v * S_v * H);
+ GGML_ASSERT(state->ne[2] == n_seqs);
+ GGML_ASSERT(state->ne[3] == 1);
+ const int64_t K = state->ne[1];
+ const int64_t state_rows = K * S_v * n_seqs;
+ const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + state_rows, 1, 1 };
 struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 result->op = GGML_OP_GATED_DELTA_NET;
diff --git a/include/llama.h b/include/llama.h
index 470add516c8..0f711ab29e3 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -338,6 +338,7 @@ extern "C" {
 uint32_t n_batch; // logical maximum batch size that can be submitted to llama_decode
 uint32_t n_ubatch; // physical maximum batch size
 uint32_t n_seq_max; // max number of sequences (i.e. 
distinct states for recurrent models) + uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback (0 = no rollback) [EXPERIMENTAL] int32_t n_threads; // number of threads to use for generation int32_t n_threads_batch; // number of threads to use for batch processing @@ -536,6 +537,7 @@ extern "C" { LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx); + LLAMA_API uint32_t llama_n_rs_seq (const struct llama_context * ctx); DEPRECATED(LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model), "use llama_model_n_ctx_train instead"); DEPRECATED(LLAMA_API int32_t llama_n_embd (const struct llama_model * model), "use llama_model_n_embd instead"); diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index ab4334da79b..c95f341b07d 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -878,6 +878,16 @@ bool llm_arch_is_diffusion(const llm_arch & arch) { } } +bool llm_arch_supports_recurrent_partial_rollback(const llm_arch & arch) { + switch (arch) { + case LLM_ARCH_QWEN35: + case LLM_ARCH_QWEN35MOE: + return true; + default: + return false; + } +} + bool llm_arch_supports_sm_tensor(const llm_arch & arch) { switch (arch) { case LLM_ARCH_GROK: diff --git a/src/llama-arch.h b/src/llama-arch.h index e37d548c98e..e3765dba3da 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -637,3 +637,4 @@ bool llm_arch_is_recurrent (const llm_arch & arch); bool llm_arch_is_hybrid (const llm_arch & arch); bool llm_arch_is_diffusion (const llm_arch & arch); bool llm_arch_supports_sm_tensor(const llm_arch & arch); +bool llm_arch_supports_recurrent_partial_rollback(const llm_arch & arch); diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 6ecbe1b6083..4a000f59f1c 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -51,6 +51,13 @@ llama_context::llama_context( throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ)); } + cparams.n_rs_seq = params.n_rs_seq; + if (cparams.n_rs_seq > 0 && !llm_arch_supports_recurrent_partial_rollback(model.arch)) { + LLAMA_LOG_DEBUG("%s: n_rs_seq=%u requested but model arch does not support recurrent partial rollback; clamping to 0\n", + __func__, cparams.n_rs_seq); + cparams.n_rs_seq = 0; + } + cparams.n_threads = params.n_threads; cparams.n_threads_batch = params.n_threads_batch; cparams.yarn_ext_factor = params.yarn_ext_factor >= 0.0f ? 
params.yarn_ext_factor : hparams.yarn_ext_factor; @@ -3291,6 +3298,7 @@ llama_context_params llama_context_default_params() { /*.n_batch =*/ 2048, /*.n_ubatch =*/ 512, /*.n_seq_max =*/ 1, + /*.n_rs_seq =*/ 0, /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS, /*.ctx_type =*/ LLAMA_CONTEXT_TYPE_DEFAULT, @@ -3445,6 +3453,10 @@ uint32_t llama_n_seq_max(const llama_context * ctx) { return ctx->n_seq_max(); } +uint32_t llama_n_rs_seq(const llama_context * ctx) { + return ctx->get_cparams().n_rs_seq; +} + const llama_model * llama_get_model(const llama_context * ctx) { return &ctx->get_model(); } diff --git a/src/llama-cparams.h b/src/llama-cparams.h index cbf74eba63e..5898a1c38d5 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -12,6 +12,7 @@ struct llama_cparams { uint32_t n_batch; uint32_t n_ubatch; uint32_t n_seq_max; + uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback int32_t n_threads; // number of threads to use for generation int32_t n_threads_batch; // number of threads to use for batch processing diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index fe155c92dea..858c297dd76 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -2528,7 +2528,8 @@ ggml_tensor * llm_graph_context::build_rs( int32_t rs_zero, const llm_graph_get_rows_fn & get_state_rows) const { - ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, rs_size); + GGML_UNUSED(rs_size); + ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, s->ne[1]); // Clear a single state which will then be copied to the other cleared states. // Note that this is a no-op when the view is zero-sized. diff --git a/src/llama-memory-hybrid-iswa.cpp b/src/llama-memory-hybrid-iswa.cpp index 10e6b459797..a59561ea54d 100644 --- a/src/llama-memory-hybrid-iswa.cpp +++ b/src/llama-memory-hybrid-iswa.cpp @@ -24,6 +24,7 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa( uint32_t rs_size, /* common */ uint32_t n_seq_max, + uint32_t n_rs_seq, bool offload, bool unified, /* layer filters */ @@ -54,6 +55,7 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa( offload, rs_size, n_seq_max, + n_rs_seq, filter_recr == nullptr ? [&](int32_t il) { return hparams.is_recurrent(il); } : filter_recr diff --git a/src/llama-memory-hybrid-iswa.h b/src/llama-memory-hybrid-iswa.h index 807c8aac96c..c9d3f9f57c5 100644 --- a/src/llama-memory-hybrid-iswa.h +++ b/src/llama-memory-hybrid-iswa.h @@ -34,6 +34,7 @@ class llama_memory_hybrid_iswa : public llama_memory_i { uint32_t rs_size, /* common */ uint32_t n_seq_max, + uint32_t n_rs_seq, bool offload, bool unified, /* layer filters */ diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp index 4ce1af592c1..fd305cab79c 100644 --- a/src/llama-memory-hybrid.cpp +++ b/src/llama-memory-hybrid.cpp @@ -24,6 +24,7 @@ llama_memory_hybrid::llama_memory_hybrid( uint32_t rs_size, /* common */ uint32_t n_seq_max, + uint32_t n_rs_seq, bool offload, bool unified, /* layer filters */ @@ -54,6 +55,7 @@ llama_memory_hybrid::llama_memory_hybrid( offload, rs_size, n_seq_max, + n_rs_seq, filter_recr == nullptr ? 
[&](int32_t il) { return hparams.is_recurrent(il); } : filter_recr diff --git a/src/llama-memory-hybrid.h b/src/llama-memory-hybrid.h index 558cafdf984..484eafb7499 100644 --- a/src/llama-memory-hybrid.h +++ b/src/llama-memory-hybrid.h @@ -34,6 +34,7 @@ class llama_memory_hybrid : public llama_memory_i { uint32_t rs_size, /* common */ uint32_t n_seq_max, + uint32_t n_rs_seq, bool offload, bool unified, /* layer filters */ diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index c07f1d969cb..1913e9414a0 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -24,6 +24,7 @@ llama_memory_recurrent::llama_memory_recurrent( bool offload, uint32_t mem_size, uint32_t n_seq_max, + uint32_t n_rs_seq, const layer_filter_cb & filter) : hparams(model.hparams), n_seq_max(n_seq_max) { const int32_t n_layer = hparams.n_layer; @@ -31,6 +32,9 @@ llama_memory_recurrent::llama_memory_recurrent( size = mem_size; used = 0; + this->n_rs_seq = n_rs_seq; + rs_idx.assign(n_seq_max, 0); + cells.clear(); cells.resize(mem_size); @@ -92,8 +96,9 @@ llama_memory_recurrent::llama_memory_recurrent( throw std::runtime_error("failed to create ggml context for rs cache"); } - ggml_tensor * r = ggml_new_tensor_2d(ctx, type_r, hparams.n_embd_r(), mem_size); - ggml_tensor * s = ggml_new_tensor_2d(ctx, type_s, hparams.n_embd_s(), mem_size); + const uint32_t n_rows = mem_size * (1 + n_rs_seq); + ggml_tensor * r = ggml_new_tensor_2d(ctx, type_r, hparams.n_embd_r(), n_rows); + ggml_tensor * s = ggml_new_tensor_2d(ctx, type_s, hparams.n_embd_s(), n_rows); ggml_format_name(r, "cache_r_l%d", i); ggml_format_name(s, "cache_s_l%d", i); r_l[i] = r; @@ -141,7 +146,6 @@ void llama_memory_recurrent::clear(bool data) { } bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - //printf("[DEBUG] calling llama_memory_recurrent::seq_rm` with `seq_id=%d, p0=%d, p1=%d`\n", seq_id, p0, p1); uint32_t new_head = size; if (p0 < 0) { @@ -161,10 +165,16 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos if (0 <= seq_id) { int32_t & tail_id = cells[seq_id].tail; if (tail_id >= 0) { - const auto & cell = cells[tail_id]; - // partial intersection is invalid if it includes the final pos + auto & cell = cells[tail_id]; + + // partial rollback via per-token snapshot index (bounded by n_rs_seq) if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) { - //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false, p0 = %d, cell.pos = %d, p1 = %d\n", p0, cell.pos, p1); + const llama_pos rollback = cell.pos - (p0 - 1); + if (rollback >= 1 && rollback <= (llama_pos) n_rs_seq) { + set_rs_idx(seq_id, (uint32_t) rollback); + cell.pos = p0 - 1; + return true; + } return false; } // invalidate tails which will be cleared @@ -368,6 +378,13 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const { return result; } +void llama_memory_recurrent::set_rs_idx(llama_seq_id seq_id, uint32_t idx) { + if (seq_id < 0 || (size_t) seq_id >= rs_idx.size()) { + return; + } + rs_idx[seq_id] = (idx > n_rs_seq) ? 
n_rs_seq : idx; +} + std::map llama_memory_recurrent::memory_breakdown() const { std::map ret; for (const auto & [_, buf] : ctxs_bufs) { @@ -1163,5 +1180,21 @@ ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const { } int32_t llama_memory_recurrent_context::s_copy(int i) const { - return mem->cells[i + mem->head].src0; + const uint32_t cell_idx = i + mem->head; + const int32_t src0 = mem->cells[cell_idx].src0; + + if (mem->n_rs_seq == 0) { + return src0; + } + + uint32_t idx = 0; + if (!mem->cells[cell_idx].seq_id.empty()) { + const llama_seq_id seq = *mem->cells[cell_idx].seq_id.begin(); + if (seq >= 0 && (size_t) seq < mem->rs_idx.size()) { + idx = mem->rs_idx[seq]; + // reset rollback idx + mem->rs_idx[seq] = 0; + } + } + return (int32_t)(idx * mem->size) + src0; } diff --git a/src/llama-memory-recurrent.h b/src/llama-memory-recurrent.h index 47f01d73912..29c58afc9c2 100644 --- a/src/llama-memory-recurrent.h +++ b/src/llama-memory-recurrent.h @@ -23,6 +23,7 @@ class llama_memory_recurrent : public llama_memory_i { bool offload, uint32_t mem_size, uint32_t n_seq_max, + uint32_t n_rs_seq, const layer_filter_cb & filter); ~llama_memory_recurrent() = default; @@ -69,6 +70,13 @@ class llama_memory_recurrent : public llama_memory_i { uint32_t size = 0; // total number of cells, shared across all sequences uint32_t used = 0; // used cells (i.e. at least one seq_id) + // number of recurrent-state snapshots per seq for rollback; tensors are widened to (1 + n_rs_seq) groups + uint32_t n_rs_seq = 0; + // per-seq rollback index + std::vector rs_idx; + + void set_rs_idx(llama_seq_id seq_id, uint32_t idx); + // computed before each graph build uint32_t n = 0; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index f9601eae094..960e0786142 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1948,6 +1948,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, cparams.offload_kqv, std::max((uint32_t) 1, cparams.n_seq_max), cparams.n_seq_max, + cparams.n_rs_seq, nullptr); } else if (llm_arch_is_hybrid(arch) && !mtp_on_hybrid_qwen35) { // The main difference between hybrid architectures is the @@ -1989,6 +1990,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* recurrent_type_s */ GGML_TYPE_F32, /* recurrent_rs_size */ std::max((uint32_t) 1, cparams.n_seq_max), /* n_seq_max */ cparams.n_seq_max, + /* n_rs_seq */ cparams.n_rs_seq, /* offload */ cparams.offload_kqv, /* unified */ cparams.kv_unified, /* filter_attn */ std::move(filter_attn), @@ -2007,6 +2009,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* recurrent_type_v */ GGML_TYPE_F32, /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max), /* n_seq_max */ cparams.n_seq_max, + /* n_rs_seq */ cparams.n_rs_seq, /* offload */ cparams.offload_kqv, /* unified */ cparams.kv_unified, /* filter_attn */ std::move(filter_attn), diff --git a/src/models/delta-net-base.cpp b/src/models/delta-net-base.cpp index 6bc989c9509..081f490c5a9 100644 --- a/src/models/delta-net-base.cpp +++ b/src/models/delta-net-base.cpp @@ -1,6 +1,7 @@ #include "models.h" #include "llama-impl.h" +#include "llama-memory-recurrent.h" // utility to get one slice from the third dimension // input dim: [x, y, c, b] @@ -397,7 +398,9 @@ std::pair llm_build_delta_net_base::build_delta_ne GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && 
s->ne[3] == n_seqs); - ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s); + // K=1 (final state only): reshape to 3D (S_v*S_v*H_v, 1, n_seqs) for ggml_gated_delta_net. + ggml_tensor * s_3d = ggml_reshape_3d(ctx0, s, S_v * S_v * H_v, 1, n_seqs); + ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s_3d); if (n_tokens == 1) { cb(result, LLAMA_TENSOR_NAME_FGDN_AR, il); } else { @@ -443,3 +446,142 @@ std::pair llm_build_delta_net_base::build_delta_ne return build_delta_net_chunking(q, k, v, g, b, s, il); } + +bool llm_build_delta_net_base::keep_intermediates() const { + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + return cparams.n_rs_seq > 0 + && n_seq_tokens > 1 + && (uint32_t) n_seq_tokens <= 1 + cparams.n_rs_seq; +} + +ggml_tensor * llm_build_delta_net_base::build_conv_state( + llm_graph_input_rs * inp, + ggml_tensor * conv_states_all, + ggml_tensor * qkv_mixed, + int64_t conv_kernel_size, + int64_t conv_channels, + int il) { + const auto * mctx_cur = inp->mctx; + const auto kv_head = mctx_cur->get_head(); + const uint32_t mem_size = mctx_cur->get_size(); + const int64_t n_seqs = ubatch.n_seqs; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + const bool keep = keep_intermediates(); + + ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); + cb(conv_states, "conv_states", il); + + conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); + cb(conv_states, "conv_states_reshaped", il); + + qkv_mixed = ggml_transpose(ctx0, qkv_mixed); + cb(qkv_mixed, "qkv_mixed_transposed", il); + + ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); + cb(conv_input, "conv_input", il); + + if (!keep) { + ggml_tensor * last_conv_states = + ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], + conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input)); + cb(last_conv_states, "last_conv_states", il); + + ggml_tensor * state_update_target = + ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1], + kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); + cb(state_update_target, "state_update_target", il); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); + } else { + const int64_t row_count = (conv_kernel_size - 1) * conv_channels; + const size_t row_size = row_count * ggml_element_size(conv_states_all); + for (int64_t t = 1; t <= n_seq_tokens; ++t) { + const uint32_t slot = (uint32_t)(n_seq_tokens - t); + ggml_tensor * src = + ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, + conv_input->nb[1], conv_input->nb[2], + t * ggml_element_size(conv_input)); + ggml_tensor * dst = + ggml_view_2d(ctx0, conv_states_all, row_count, n_seqs, + conv_states_all->nb[1], + ((size_t) slot * mem_size + kv_head) * row_size); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, src, dst)); + } + } + + return conv_input; +} + +ggml_tensor * llm_build_delta_net_base::build_recurrent_attn( + llm_graph_input_rs * inp, + ggml_tensor * ssm_states_all, + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il) { + const auto * mctx_cur = inp->mctx; + const auto kv_head = mctx_cur->get_head(); + const uint32_t mem_size = mctx_cur->get_size(); + + const int64_t S_v = s->ne[0]; + const int64_t H_v = s->ne[2]; + const int64_t 
n_seqs = s->ne[3]; + const int64_t n_seq_tokens = q->ne[2]; + + if (!keep_intermediates()) { + auto attn_out = build_delta_net(q, k, v, g, b, s, il); + ggml_tensor * output = attn_out.first; + ggml_tensor * new_state = attn_out.second; + cb(output, "attn_output", il); + cb(new_state, "new_state", il); + + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, new_state, + ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], + kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + + return output; + } + + const int64_t D = S_v * S_v * H_v; + const int64_t K = (int64_t) cparams.n_rs_seq; + + ggml_tensor * state_3d = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, D, K, n_seqs); + ggml_tensor * slot_0 = ggml_view_2d(ctx0, state_3d, D, n_seqs, state_3d->nb[2], 0); + ggml_tensor * state_in_2d = ggml_reshape_2d(ctx0, s, D, n_seqs); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, state_in_2d, slot_0)); + + ggml_tensor * gdn_out = ggml_gated_delta_net(ctx0, q, k, v, g, b, state_3d); + cb(gdn_out, LLAMA_TENSOR_NAME_FGDN_CH, il); + + const int64_t attn_score_elems = S_v * H_v * n_seq_tokens * n_seqs; + const int64_t state_size_per_snap = S_v * S_v * H_v * n_seqs; + + ggml_tensor * output = ggml_view_4d(ctx0, gdn_out, + S_v, H_v, n_seq_tokens, n_seqs, + ggml_row_size(gdn_out->type, S_v), + ggml_row_size(gdn_out->type, S_v * H_v), + ggml_row_size(gdn_out->type, S_v * H_v * n_seq_tokens), + 0); + cb(output, "attn_output", il); + + const size_t row_size = hparams.n_embd_s() * ggml_element_size(ssm_states_all); + for (int64_t k_i = 0; k_i < K; ++k_i) { + const uint32_t cache_slot = (uint32_t) (K - 1 - k_i); + ggml_tensor * src = ggml_view_4d(ctx0, gdn_out, + S_v, S_v, H_v, n_seqs, + ggml_row_size(gdn_out->type, S_v), + ggml_row_size(gdn_out->type, S_v * S_v), + ggml_row_size(gdn_out->type, S_v * S_v * H_v), + ggml_row_size(gdn_out->type, attn_score_elems + k_i * state_size_per_snap)); + ggml_tensor * dst = ggml_view_2d(ctx0, ssm_states_all, + hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], + ((size_t) cache_slot * mem_size + kv_head) * row_size); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, src, dst)); + } + + return output; +} diff --git a/src/models/models.h b/src/models/models.h index fe95b9b89ad..8bef479754a 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -46,7 +46,7 @@ struct llm_build_delta_net_base : public llm_graph_context { ggml_tensor * s, int il); - // use the ggml_gated_delta_net fused operator + // use the ggml_gated_delta_net fused operator (K=1; state has shape (D, 1, n_seqs)) std::pair build_delta_net_fused( ggml_tensor * q, ggml_tensor * k, @@ -65,6 +65,32 @@ struct llm_build_delta_net_base : public llm_graph_context { ggml_tensor * b, ggml_tensor * s, int il); + + // true when speculative rollback is enabled and the batch fits in the rs cache + bool keep_intermediates() const; + + // read conv state from cache, concat with qkv_mixed, write back (single slot or per-token) + // qkv_mixed: (qkv_dim, n_seq_tokens, n_seqs); returns conv_input: (kernel_size + n_seq_tokens - 1, channels, n_seqs) + ggml_tensor * build_conv_state( + llm_graph_input_rs * inp, + ggml_tensor * conv_states_all, + ggml_tensor * qkv_mixed, + int64_t conv_kernel_size, + int64_t conv_channels, + int il); + + // run delta-net attention and write the new recurrent state(s) back to ssm_states_all + // s: (head_v_dim, head_v_dim, num_v_heads, n_seqs); returns output: (head_v_dim, num_v_heads, n_seq_tokens, n_seqs) + ggml_tensor * build_recurrent_attn( + llm_graph_input_rs 
* inp, + ggml_tensor * ssm_states_all, + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il); }; struct llm_build_rwkv6_base : public llm_graph_context { diff --git a/src/models/qwen35.cpp b/src/models/qwen35.cpp index ca4297e94f3..4fc36dfccba 100644 --- a/src/models/qwen35.cpp +++ b/src/models/qwen35.cpp @@ -348,8 +348,6 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear( const int64_t head_v_dim = d_inner / num_v_heads; const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const auto kv_head = mctx_cur->get_head(); - GGML_ASSERT(n_seqs != 0); GGML_ASSERT(ubatch.equal_seqs()); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); @@ -379,41 +377,14 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear( gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs); - // Get convolution states from cache ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); - // Build the convolution states tensor - ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); - cb(conv_states, "conv_states", il); - - // Calculate convolution kernel size ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; const int64_t conv_kernel_size = conv_kernel->ne[0]; const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state; - conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); - cb(conv_states, "conv_states_reshaped", il); - - qkv_mixed = ggml_transpose(ctx0, qkv_mixed); - cb(qkv_mixed, "qkv_mixed_transposed", il); - - ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); - cb(conv_input, "conv_input", il); - - // Update convolution state cache - // Extract the last (conv_kernel_size - 1) states from conv_input - ggml_tensor * last_conv_states = - ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], - conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input)); - cb(last_conv_states, "last_conv_states", il); - - ggml_tensor * state_update_target = - ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1], - kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); - cb(state_update_target, "state_update_target", il); - - ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); + ggml_tensor * conv_input = build_conv_state(inp, conv_states_all, qkv_mixed, conv_kernel_size, conv_channels, il); ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs); @@ -464,7 +435,7 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear( //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); // if head keys and value keys are different, repeat to force tensors into matching shapes - // note: need explicit repeat only if we are not using the fused GDN + // note: need explicit repeat only if we are not using the fused GDN. 
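+    // (presumably the fused ggml_gated_delta_net op performs this head broadcast
+    //  internally; cf. the Q/K bcast TODO tagged TAG_GGML_GDN_BCAST in ggml.h)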
if (num_k_heads != num_v_heads && (!cparams.fused_gdn_ar || !cparams.fused_gdn_ch)) { GGML_ASSERT(num_v_heads % num_k_heads == 0); q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); @@ -475,18 +446,7 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear( cb(k_conv, "k_conv_predelta", il); cb(v_conv, "v_conv_predelta", il); - auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il); - - ggml_tensor * output = attn_out.first; - ggml_tensor * new_state = attn_out.second; - cb(output, "attn_output", il); - cb(new_state, "new_state", il); - - // Update the recurrent states - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, new_state, - ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], - kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + ggml_tensor * output = build_recurrent_attn(inp, ssm_states_all, q_conv, k_conv, v_conv, gate, beta, state, il); // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); diff --git a/src/models/qwen35moe.cpp b/src/models/qwen35moe.cpp index a4c7cb6ad14..43325279eae 100644 --- a/src/models/qwen35moe.cpp +++ b/src/models/qwen35moe.cpp @@ -371,8 +371,6 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear( const int64_t head_v_dim = d_inner / num_v_heads; const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const auto kv_head = mctx_cur->get_head(); - GGML_ASSERT(n_seqs != 0); GGML_ASSERT(ubatch.equal_seqs()); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); @@ -402,41 +400,14 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear( gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs); - // Get convolution states from cache ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); - // Build the convolution states tensor - ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); - cb(conv_states, "conv_states", il); - - // Calculate convolution kernel size ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; const int64_t conv_kernel_size = conv_kernel->ne[0]; const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state; - conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); - cb(conv_states, "conv_states_reshaped", il); - - qkv_mixed = ggml_transpose(ctx0, qkv_mixed); - cb(qkv_mixed, "qkv_mixed_transposed", il); - - ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); - cb(conv_input, "conv_input", il); - - // Update convolution state cache - // Extract the last (conv_kernel_size - 1) states from conv_input - ggml_tensor * last_conv_states = - ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], - conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input)); - cb(last_conv_states, "last_conv_states", il); - - ggml_tensor * state_update_target = - ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1], - kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); - cb(state_update_target, "state_update_target", il); - - ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); + ggml_tensor * conv_input = 
build_conv_state(inp, conv_states_all, qkv_mixed, conv_kernel_size, conv_channels, il); ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs); @@ -487,7 +458,7 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear( //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); // if head keys and value keys are different, repeat to force tensors into matching shapes - // note: need explicit repeat only if we are not using the fused GDN + // note: need explicit repeat only if we are not using the fused GDN. if (num_k_heads != num_v_heads && (!cparams.fused_gdn_ar || !cparams.fused_gdn_ch)) { GGML_ASSERT(num_v_heads % num_k_heads == 0); q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); @@ -498,18 +469,7 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear( cb(k_conv, "k_conv_predelta", il); cb(v_conv, "v_conv_predelta", il); - auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il); - - ggml_tensor * output = attn_out.first; - ggml_tensor * new_state = attn_out.second; - cb(output, "attn_output", il); - cb(new_state, "new_state", il); - - // Update the recurrent states - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, new_state, - ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], - kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + ggml_tensor * output = build_recurrent_attn(inp, ssm_states_all, q_conv, k_conv, v_conv, gate, beta, state, il); // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); diff --git a/src/models/qwen3next.cpp b/src/models/qwen3next.cpp index cb1b4814caf..8ca1ba9ce1c 100644 --- a/src/models/qwen3next.cpp +++ b/src/models/qwen3next.cpp @@ -378,8 +378,6 @@ ggml_tensor * llama_model_qwen3next::graph::build_layer_attn_linear( const int64_t head_v_dim = d_inner / num_v_heads; const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const auto kv_head = mctx_cur->get_head(); - GGML_ASSERT(n_seqs != 0); GGML_ASSERT(ubatch.equal_seqs()); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); @@ -429,41 +427,14 @@ ggml_tensor * llama_model_qwen3next::graph::build_layer_attn_linear( beta = ggml_reshape_4d(ctx0, beta, 1, num_v_heads, n_seq_tokens, n_seqs); gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs); - // Get convolution states from cache ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); - // Build the convolution states tensor - ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); - cb(conv_states, "conv_states", il); - - // Calculate convolution kernel size ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; const int64_t conv_kernel_size = conv_kernel->ne[0]; const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state; - conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); - cb(conv_states, "conv_states_reshaped", il); - - qkv_mixed = ggml_transpose(ctx0, qkv_mixed); - cb(qkv_mixed, "qkv_mixed_transposed", il); - - ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); - cb(conv_input, "conv_input", il); - - // Update convolution state cache - // Extract the last 
(conv_kernel_size - 1) states from conv_input - ggml_tensor * last_conv_states = - ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], - conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input)); - cb(last_conv_states, "last_conv_states", il); - - ggml_tensor * state_update_target = - ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1], - kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); - cb(state_update_target, "state_update_target", il); - - ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); + ggml_tensor * conv_input = build_conv_state(inp, conv_states_all, qkv_mixed, conv_kernel_size, conv_channels, il); ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs); @@ -540,18 +511,7 @@ ggml_tensor * llama_model_qwen3next::graph::build_layer_attn_linear( cb(k_conv, "k_conv_predelta", il); cb(v_conv, "v_conv_predelta", il); - auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il); - - ggml_tensor * output = attn_out.first; - ggml_tensor * new_state = attn_out.second; - cb(output, "attn_output", il); - cb(new_state, "new_state", il); - - // Update the recurrent states - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, new_state, - ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], - kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + ggml_tensor * output = build_recurrent_attn(inp, ssm_states_all, q_conv, k_conv, v_conv, gate, beta, state, il); // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 33311948669..13abc24e53c 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -3805,16 +3805,17 @@ struct test_gated_delta_net : public test_case { const int v_repeat; const bool permuted; const bool kda; + const int64_t K; // snapshot slot count: 1 = final-only, >1 = last K states std::string vars() override { - return VARS_TO_STR8(type, head_count, head_size, n_seq_tokens, n_seqs, v_repeat, permuted, kda); + return VARS_TO_STR9(type, head_count, head_size, n_seq_tokens, n_seqs, v_repeat, permuted, kda, K); } test_gated_delta_net(ggml_type type = GGML_TYPE_F32, int64_t head_count = 4, int64_t head_size = 16, int64_t n_seq_tokens = 1, int64_t n_seqs = 1, - int v_repeat = 1, bool permuted = false, bool kda = false) + int v_repeat = 1, bool permuted = false, bool kda = false, int64_t K = 1) : type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs), - v_repeat(v_repeat), permuted(permuted), kda(kda) {} + v_repeat(v_repeat), permuted(permuted), kda(kda), K(K) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * q; @@ -3836,7 +3837,7 @@ struct test_gated_delta_net : public test_case { const int64_t g_ne0 = kda ? 
head_size : 1;
         ggml_tensor * g    = ggml_new_tensor_4d(ctx, type, g_ne0, head_count * v_repeat, n_seq_tokens, n_seqs);
         ggml_tensor * beta = ggml_new_tensor_4d(ctx, type, 1, head_count * v_repeat, n_seq_tokens, n_seqs);
-        ggml_tensor * state = ggml_new_tensor_2d(ctx, type, head_size * v_repeat * head_size * head_count, n_seqs);
+        ggml_tensor * state = ggml_new_tensor_3d(ctx, type, head_size * v_repeat * head_size * head_count, K, n_seqs);
         ggml_set_name(g, "g");
         ggml_set_name(beta, "beta");
         ggml_set_name(state, "state");
@@ -8995,6 +8996,18 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 33, 1, 1, false, true));
     test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 100, 1, 1, false, true));
 
+    // K > 1: output keeps the last min(n_tokens, K) per-token snapshots in the trailing K-token region.
+    // exact-match cases (K == n_seq_tokens):
+    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 16, 2, 1, 1, false, false, /*K=*/2));
+    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 32, 4, 1, 1, false, false, /*K=*/4));
+    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2, 1, false, false, /*K=*/4));
+    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 8, 128, 4, 1, 1, false, false, /*K=*/4));
+    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2, 1, false, true, /*K=*/4));
+    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 8, 32, 4, 2, 2, false, true, /*K=*/4));
+    // overflow: n_tokens > K — only the last K snapshots kept.
+    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 32, 8, 1, 1, false, false, /*K=*/3));
+    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 16, 2, 1, false, false, /*K=*/4));
+
 #if 0
     // these tests are disabled to save execution time, but they can be handy for debugging
     test_cases.emplace_back(new test_llama(2, true));
diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index 16bcd0bd490..4be9c706ed1 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -753,6 +753,7 @@ struct server_context_impl {
             cparams.ctx_type = LLAMA_CONTEXT_TYPE_MTP;
         }
 
+        cparams.n_rs_seq = 0;
         ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams));
 
         ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get());
@@ -2693,9 +2694,11 @@ struct server_context_impl {
         // checkpoints are created only if:
         //  - the model does not support partial sequence removal
         //  - the model uses SWA (and we are not using `swa_full`)
+        //  - the model supports partial sequence removal but only up to a fixed bound
         do_checkpoint = do_checkpoint && (
-                (ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) ||
-                (n_swa > 0));
+                ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL ||
+                ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_PART_BOUNDED ||
+                n_swa > 0);
 
         bool has_mtmd = false;

From 9a3a48722a22abae17baacefe909ce95b389fdf8 Mon Sep 17 00:00:00 2001
From: Aman Gupta
Date: Thu, 14 May 2026 14:32:42 +0800
Subject: [PATCH 10/18] fix pending state

---
 common/common.cpp                             | 18 +++---
 common/speculative.cpp                        | 64 +++++++++++++------
 .../speculative-simple/speculative-simple.cpp |  3 +-
 include/llama.h                               |  2 +-
 src/llama-context.cpp                         |  1 +
 src/llama-memory-recurrent.cpp                |  4 +-
 src/models/delta-net-base.cpp                 |  5 +-
 7 files changed, 64 insertions(+), 33 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index
6af85e7ab7b..6030ead189f 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1421,6 +1421,7 @@ common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) {
     }
 
     if (llama_n_rs_seq(ctx) > 0) {
+        LOG_INF("%s: the context supports bounded partial sequence removal\n", __func__);
         res = COMMON_CONTEXT_SEQ_RM_TYPE_PART_BOUNDED;
         goto done;
     }
@@ -1496,15 +1497,14 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.n_ctx     = params.n_ctx;
     cparams.n_seq_max = params.n_parallel;
     {
-        // TODO: add for MTP
-        bool has_spec = params.speculative.has_dft();
-        for (auto t : params.speculative.types) {
-            if (t != COMMON_SPECULATIVE_TYPE_NONE) {
-                has_spec = true;
-                break;
-            }
-        }
-        cparams.n_rs_seq = has_spec ? (uint32_t) params.speculative.draft.n_max : 0u;
+        // Since MTP has a low number of draft tokens, enable recurrent checkpointing
+        // for hybrid attn models
+        // TODO: figure out how to make it play nicely with other speculative techniques
+        bool has_mtp = std::any_of(params.speculative.types.begin(),
+            params.speculative.types.end(), [&](auto t) {
+                return t == COMMON_SPECULATIVE_TYPE_DRAFT_MTP;
+            });
+        cparams.n_rs_seq = has_mtp ? (uint32_t) params.speculative.draft.n_max : 0u;
     }
     cparams.n_batch   = params.n_batch;
     cparams.n_ubatch  = params.n_ubatch;
diff --git a/common/speculative.cpp b/common/speculative.cpp
index 3914a5e245d..b932c86ded7 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -394,6 +394,16 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl {
     std::vector<int32_t> i_batch_beg;
     std::vector<int32_t> i_batch_end;
 
+    // Hidden rows from the most recent target verification batch, grouped by seq.
+    // Row 0 corresponds to the sampled token, row N to the Nth accepted draft token.
+    std::vector<std::vector<float>> verify_h;
+    std::vector<int32_t> verify_h_rows;
+
+    // Per-seq draft length from the last draft() call, used in accept() to
+    // roll back ctx_dft's recurrent state past the AR draft's redundant
+    // pre-advancement before process() mirrored the verify batch.
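+    // Illustrative intent (numbers hypothetical): a 4-token AR draft advances
+    // ctx_dft's recurrent state 4 steps past the last verified position;
+    // process() then replays those same tokens as part of the verify batch,
+    // so last_n_drafted = 4 records how many redundant steps have to be
+    // rewound through the bounded rs snapshots first.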
+    std::vector<uint16_t> last_n_drafted;
+
     common_speculative_state_draft_mtp(const common_params_speculative & params, uint32_t n_seq)
         : common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, n_seq)
         , params(params.draft)
@@ -414,7 +424,7 @@
         for (auto & s : smpls) {
             common_params_sampling sparams;
             sparams.no_perf = false;
-            sparams.top_k = 10;
+            sparams.top_k = 1;
             sparams.samplers = { COMMON_SAMPLER_TYPE_TOP_K };
             s.reset(common_sampler_init(llama_get_model(ctx_dft), sparams));
         }
@@ -426,6 +436,11 @@
 
         i_batch_beg.assign(n_seq, -1);
         i_batch_end.assign(n_seq, -1);
+
+        verify_h.assign(n_seq, {});
+        verify_h_rows.assign(n_seq, 0);
+
+        last_n_drafted.assign(n_seq, 0);
     }
 
     ~common_speculative_state_draft_mtp() override {
@@ -535,8 +550,17 @@
                 continue;
             }
 
-            const float * h_last = llama_get_embeddings_pre_norm_ith(ctx_tgt, i_batch_end[seq_id]);
-            std::memcpy(pending_h[seq_id].data(), h_last, row_bytes);
+            const int32_t n_rows = i_batch_end[seq_id] - i_batch_beg[seq_id] + 1;
+            verify_h_rows[seq_id] = n_rows;
+            verify_h[seq_id].resize((size_t) n_rows * n_embd);
+
+            for (int32_t i = 0; i < n_rows; ++i) {
+                const float * h = llama_get_embeddings_pre_norm_ith(ctx_tgt, i_batch_beg[seq_id] + i);
+                std::memcpy(verify_h[seq_id].data() + (size_t) i * n_embd, h, row_bytes);
+            }
+
+            std::memcpy(pending_h[seq_id].data(),
+                verify_h[seq_id].data() + (size_t) (n_rows - 1) * n_embd, row_bytes);
         }
 
         return true;
@@ -606,14 +630,6 @@
             // add drafted token for each sequence
             const llama_token id = cur_p->data[0].id;
 
-            // only collect very high-confidence draft tokens
-            if (cur_p->data[0].p < params.p_min) {
-                drafting[seq_id] = false;
-                n_drafting--;
-
-                continue;
-            }
-
             common_sampler_accept(smpl, id, true);
 
             auto & dp = dparams.at(seq_id);
@@ -621,8 +637,7 @@
 
             result.push_back(id);
 
-            if ((params.n_max <= (int) result.size()) ||
-                (dp.n_max > 0 && dp.n_max <= (int) result.size())) {
+            if (params.n_max <= (int) result.size()) {
                 drafting[seq_id] = false;
                 n_drafting--;
                 continue;
@@ -646,7 +661,8 @@
             ++i;
         }
 
-        for (auto & dp : dparams) {
+        for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
+            auto & dp = dparams[seq_id];
             if (!dp.drafting) {
                 continue;
             }
@@ -654,10 +670,24 @@
             if (dp.result->size() < (size_t) params.n_min) {
                 dp.result->clear();
             }
+
+            last_n_drafted[seq_id] = (uint16_t) dp.result->size();
         }
     }
 
-    void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override {
+    void accept(llama_seq_id seq_id, uint16_t n_accepted) override {
+        if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq) {
+            return;
+        }
+
+        const int32_t n_rows = verify_h_rows[seq_id];
+        if (n_rows <= 0) {
+            return;
+        }
+
+        // note: explicit <int32_t> needed since n_accepted is uint16_t
+        const int32_t i_h = std::min<int32_t>(n_accepted, n_rows - 1);
+        const size_t row_bytes = (size_t) n_embd * sizeof(float);
+        std::memcpy(pending_h[seq_id].data(), verify_h[seq_id].data() + (size_t) i_h * n_embd, row_bytes);
     }
 
     bool need_embd() const override {
@@ -1460,10 +1490,6 @@ void common_speculative_draft(common_speculative * spec) {
 }
 
 void common_speculative_accept(common_speculative * spec,
llama_seq_id seq_id, uint16_t n_accepted) { - if (n_accepted == 0) { - return; - } - common_speculative_impl * impl = spec->impl_last[seq_id]; GGML_ASSERT(impl); diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index 5325bcc9e3f..6848de988d0 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -82,7 +82,8 @@ int main(int argc, char ** argv) { } // check if the context supports partial sequence removal - const bool use_ckpt_tgt = (common_context_can_seq_rm(ctx_tgt) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL); + const bool use_ckpt_tgt = (common_context_can_seq_rm(ctx_tgt) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) + || (common_context_can_seq_rm(ctx_tgt) == COMMON_CONTEXT_SEQ_RM_TYPE_PART_BOUNDED); const bool use_ckpt_dft = (common_context_can_seq_rm(ctx_dft.get()) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL); if (use_ckpt_tgt) { diff --git a/include/llama.h b/include/llama.h index 0f711ab29e3..75095b22d08 100644 --- a/include/llama.h +++ b/include/llama.h @@ -537,7 +537,7 @@ extern "C" { LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx); - LLAMA_API uint32_t llama_n_rs_seq (const struct llama_context * ctx); + LLAMA_API uint32_t llama_n_rs_seq (const struct llama_context * ctx); DEPRECATED(LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model), "use llama_model_n_ctx_train instead"); DEPRECATED(LLAMA_API int32_t llama_n_embd (const struct llama_model * model), "use llama_model_n_embd instead"); diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 4a000f59f1c..a481757e388 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -225,6 +225,7 @@ llama_context::llama_context( LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? 
"true" : "false"); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); + LLAMA_LOG_INFO("%s: n_rs_seq = %u\n", __func__, cparams.n_rs_seq); if (cparams.n_ctx_seq < hparams.n_ctx_train) { LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n", diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index 1913e9414a0..084c5d9ea4f 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -120,8 +120,8 @@ llama_memory_recurrent::llama_memory_recurrent( const size_t memory_size_r = size_r_bytes(); const size_t memory_size_s = size_s_bytes(); - LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__, - (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f), mem_size, n_layer, n_seq_max, + LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs %2u rs_seq), R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__, + (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f), mem_size, n_layer, n_seq_max, n_rs_seq, ggml_type_name(type_r), (float)memory_size_r / (1024.0f * 1024.0f), ggml_type_name(type_s), (float)memory_size_s / (1024.0f * 1024.0f)); } diff --git a/src/models/delta-net-base.cpp b/src/models/delta-net-base.cpp index 081f490c5a9..852d5bc88ef 100644 --- a/src/models/delta-net-base.cpp +++ b/src/models/delta-net-base.cpp @@ -547,7 +547,10 @@ ggml_tensor * llm_build_delta_net_base::build_recurrent_attn( } const int64_t D = S_v * S_v * H_v; - const int64_t K = (int64_t) cparams.n_rs_seq; + // Memory has 1 + n_rs_seq slots (slot 0 = current, slots 1..n_rs_seq = rollback distances). + // The snapshot buffer must match — otherwise the deepest rollback (= n_rs_seq) reads + // uninitialized memory and corrupts the recurrent state. + const int64_t K = (int64_t) cparams.n_rs_seq + 1; ggml_tensor * state_3d = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, D, K, n_seqs); ggml_tensor * slot_0 = ggml_view_2d(ctx0, state_3d, D, n_seqs, state_3d->nb[2], 0); From 8c05923630110223669f069af2000e9cf10c02bc Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 14 May 2026 08:41:23 +0200 Subject: [PATCH 11/18] vulkan: add GDN partial rollback --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 8 ++++- .../vulkan-shaders/gated_delta_net.comp | 29 ++++++++++++++++--- 2 files changed, 32 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index a0a556206d5..b0ea4fd6450 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1506,6 +1506,7 @@ struct vk_op_gated_delta_net_push_constants { uint32_t sb1, sb2, sb3; uint32_t neq1, rq3; float scale; + uint32_t K; }; struct vk_op_ssm_scan_push_constants { @@ -10767,6 +10768,7 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s const ggml_tensor * src_q = dst->src[0]; const ggml_tensor * src_v = dst->src[2]; const ggml_tensor * src_beta = dst->src[4]; + const ggml_tensor * src_state = dst->src[5]; GGML_ASSERT(dst->buffer != nullptr); @@ -10775,6 +10777,9 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s const uint32_t n_tokens = (uint32_t)src_v->ne[2]; const uint32_t n_seqs = (uint32_t)src_v->ne[3]; + // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. 
+ const uint32_t K = (uint32_t)src_state->ne[1]; + const uint32_t s_off = S_v * H * n_tokens * n_seqs; vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op); @@ -10808,7 +10813,8 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s sv1, sv2, sv3, sb1, sb2, sb3, neq1, rq3, - scale + scale, + K }; ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp index 5e9f8308c1d..33c3202dbb7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp @@ -31,6 +31,7 @@ layout(push_constant) uniform Parameters { uint sb1, sb2, sb3; uint neq1, rq3; float scale; + uint K; }; layout(binding = 0) readonly buffer QBuf { FLOAT_TYPE data_q[]; }; @@ -101,13 +102,21 @@ void main() { const uint iq3 = seq_id / rq3; const uint state_size = S_V * S_V; - const uint state_base = (seq_id * H + head_id) * state_size; + // input state layout (D, K, n_seqs): per-seq stride is K*H*D; we read slot 0. + const uint state_in_base = (seq_id * K * H + head_id) * state_size; + // output state layout per slot: same per-(seq,head) offset as the single-slot case. + const uint state_out_base = (seq_id * H + head_id) * state_size; + const uint state_size_per_snap = state_size * H * n_seqs; FLOAT_TYPE s_shard[ROWS_PER_LANE]; [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { - s_shard[r] = FLOAT_TYPE(data_state[state_base + col * S_V + r * LANES_PER_COLUMN + lane]); + s_shard[r] = FLOAT_TYPE(data_state[state_in_base + col * S_V + r * LANES_PER_COLUMN + lane]); } + // snapshot slot mapping: target_slot = t - shift. When n_tokens < K, only the last + // n_tokens slots are written; earlier slots are left untouched (caller-owned). + const int shift = int(n_tokens) - int(K); + uint attn_off = (seq_id * n_tokens * H + head_id) * S_V; for (uint t = 0; t < n_tokens; t++) { @@ -161,9 +170,21 @@ void main() { } attn_off += S_V * H; + + if (K > 1u) { + const int target_slot = int(t) - shift; + if (target_slot >= 0 && target_slot < int(K)) { + const uint slot_base = s_off + uint(target_slot) * state_size_per_snap + state_out_base; + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + data_dst[slot_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r]; + } + } + } } - [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { - data_dst[s_off + state_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r]; + if (K == 1u) { + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + data_dst[s_off + state_out_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r]; + } } } From a25be1b96147c603cff865b5e41ea46e65c33c5c Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 14 May 2026 15:55:53 +0800 Subject: [PATCH 12/18] meta: extend check to axis 1 --- ggml/src/ggml-backend-meta.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index 6fd401caa6f..df0f405ed9f 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -755,7 +755,7 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(co GGML_ASSERT(src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_1); // state shape is (S_v*S_v*H, K, n_seqs); the heads dim is nested inside axis 0, // so a head-aligned split on the input cache reshapes to axis 0 here (not axis 2). 
- GGML_ASSERT(src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_2 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_0); + GGML_ASSERT(src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_2 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_1 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_0); return {GGML_BACKEND_SPLIT_AXIS_0, {0}, 1}; }; @@ -2142,4 +2142,3 @@ ggml_backend_t ggml_backend_meta_simple_backend(ggml_backend_t meta_backend, siz const ggml_backend_meta_context * backend_ctx = (const ggml_backend_meta_context *) meta_backend->context; return backend_ctx->backend_configs[index].backend; } - From 1bd883d55939c9f136217eb6c1abf970f1923060 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 14 May 2026 10:24:09 +0300 Subject: [PATCH 13/18] metal: add GDN partial rollback Extend the gated delta net kernel to store intermediate states for partial rollback support on the Metal backend. - Add K (snapshot slot count) as a function constant - Read input state from slot 0 of the 3D state tensor - Write intermediate states to different slots during token loop - For K=1, maintain backward-compatible single-slot behavior Ref: https://github.com/ggml-org/llama.cpp/commit/8c05923630110223669f069af2000e9cf10c02bc Assisted-by: llama.cpp:local pi --- ggml/src/ggml-metal/ggml-metal-device.cpp | 5 ++- ggml/src/ggml-metal/ggml-metal.metal | 46 ++++++++++++++++++----- 2 files changed, 41 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index f0147af84c1..e288a27f992 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -590,6 +590,8 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net( const int ne20 = op->src[2]->ne[0]; // S_v const int ne21 = op->src[2]->ne[1]; // H const int ne30 = op->src[3]->ne[0]; // G + // state is src[5], 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. 
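+    // note: baking K into the pipeline name and a function constant below gives
+    // separate kernel specializations for the K == 1 decode path and the K > 1
+    // rollback path, rather than a per-token runtime branch in the kernel.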
+ const int K = op->src[5]->ne[1]; const int nsg = op->src[2]->ne[0]/32; @@ -598,7 +600,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net( GGML_ASSERT(ne20 % 32 == 0); snprintf(base, 256, "kernel_gated_delta_net_%s_%d", ggml_type_name(op->src[0]->type), nsg); - snprintf(name, 256, "%s_ne20=%d_ne30=%d", base, ne20, ne30); + snprintf(name, 256, "%s_ne20=%d_ne30=%d_K=%d", base, ne20, ne30, K); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { @@ -606,6 +608,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net( ggml_metal_cv_set_int16(cv, ne20, FC_GATED_DELTA_NET + 0); ggml_metal_cv_set_int16(cv, ne30, FC_GATED_DELTA_NET + 1); + ggml_metal_cv_set_int16(cv, K, FC_GATED_DELTA_NET + 2); res = ggml_metal_library_compile_pipeline(lib, base, name, cv); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 3882b955847..deb616105ae 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2531,6 +2531,7 @@ kernel void kernel_rwkv_wkv7_f32( constant short FC_gated_delta_net_ne20 [[function_constant(FC_GATED_DELTA_NET + 0)]]; constant short FC_gated_delta_net_ne30 [[function_constant(FC_GATED_DELTA_NET + 1)]]; +constant short FC_gated_delta_net_K [[function_constant(FC_GATED_DELTA_NET + 2)]]; #if 1 template @@ -2552,17 +2553,21 @@ kernel void kernel_gated_delta_net_impl( const uint tx = tpitg.x; const uint ty = tpitg.y; - const uint i23 = tgpig.z; // B - const uint i21 = tgpig.y; // H - const uint i20 = tgpig.x*NSG + ty; + const uint i23 = tgpig.z; // B (n_seqs) + const uint i21 = tgpig.y; // H (head) + const uint i20 = tgpig.x*NSG + ty; // row within S_v const uint i01 = i21 % args.ne01; const uint i11 = i21 % args.ne11; const float scale = 1.0f / sqrt((float)S_v); + const uint K = FC_gated_delta_net_K; + + // input state layout (D, K, n_seqs): per-seq stride is K*H*D; we read slot 0. // state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous - device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v; + const uint state_in_base = (i23*K*args.ne21 + i21)*S_v*S_v + i20*S_v; + device const float * s_ptr = (device const float *) (s) + state_in_base; float ls[NSG]; @@ -2580,6 +2585,17 @@ kernel void kernel_gated_delta_net_impl( device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21); device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G; + // snapshot slot mapping: target_slot = t - shift. When n_tokens < K, only the last + // n_tokens slots are written; earlier slots are left untouched (caller-owned). 
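+    // e.g. n_tokens = 2, K = 4 gives shift = -2, so t = 0 writes slot 2 and
+    // t = 1 writes slot 3, while slots 0..1 keep their caller-staged contents
+    // (illustrative values, matching the K > 1 branch below).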
+ const int shift = (int)args.ne22 - (int)K; + + // output state base offset: after attention scores + const uint attn_size = args.ne22 * args.ne21 * S_v * args.ne23; + // output state per-slot size: S_v * S_v * H * n_seqs + const uint state_size_per_snap = S_v * S_v * args.ne21 * args.ne23; + // per-(seq,head) offset within a slot + const uint state_out_base = (i23*args.ne21 + i21)*S_v*S_v + i20*S_v; + for (short t = 0; t < args.ne22; t++) { float s_k = 0.0f; @@ -2627,13 +2643,25 @@ kernel void kernel_gated_delta_net_impl( b_ptr += args.ne21; g_ptr += args.ne21*G; - } - device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v; + if (K > 1u) { + const int target_slot = (int)t - shift; + if (target_slot >= 0 && target_slot < (int)K) { + device float * dst_state = (device float *) (dst) + attn_size + (uint)target_slot * state_size_per_snap + state_out_base; + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + dst_state[is] = ls[j]; + } + } + } + } - FOR_UNROLL (short j = 0; j < NSG; j++) { - const short is = tx*NSG + j; - dst_state[is] = ls[j]; + if (K == 1u) { + device float * dst_state = (device float *) (dst) + attn_size + state_out_base; + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + dst_state[is] = ls[j]; + } } #undef S_v From b9c2a4f2add6d56d374a3d176aa88f54973a225d Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 14 May 2026 10:59:33 +0200 Subject: [PATCH 14/18] delta_net_base: use ggml_pad instead of new_tensor --- src/models/delta-net-base.cpp | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/models/delta-net-base.cpp b/src/models/delta-net-base.cpp index 852d5bc88ef..e261ac54e20 100644 --- a/src/models/delta-net-base.cpp +++ b/src/models/delta-net-base.cpp @@ -547,15 +547,11 @@ ggml_tensor * llm_build_delta_net_base::build_recurrent_attn( } const int64_t D = S_v * S_v * H_v; - // Memory has 1 + n_rs_seq slots (slot 0 = current, slots 1..n_rs_seq = rollback distances). - // The snapshot buffer must match — otherwise the deepest rollback (= n_rs_seq) reads - // uninitialized memory and corrupts the recurrent state. 
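+    // sketch of the intended layout: only slot 0 carries the incoming state;
+    // the ggml_pad below zero-fills the K - 1 extra slots, and the kernel
+    // overwrites the trailing min(n_tokens, K) of them with fresh snapshots.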
     const int64_t K = (int64_t) cparams.n_rs_seq + 1;
 
-    ggml_tensor * state_3d    = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, D, K, n_seqs);
-    ggml_tensor * slot_0      = ggml_view_2d(ctx0, state_3d, D, n_seqs, state_3d->nb[2], 0);
-    ggml_tensor * state_in_2d = ggml_reshape_2d(ctx0, s, D, n_seqs);
-    ggml_build_forward_expand(gf, ggml_cpy(ctx0, state_in_2d, slot_0));
+    // TODO: remove pad + simplify
+    ggml_tensor * state_in_3d = ggml_reshape_3d(ctx0, s, D, 1, n_seqs);
+    ggml_tensor * state_3d    = ggml_pad(ctx0, state_in_3d, 0, K - 1, 0, 0);
 
     ggml_tensor * gdn_out = ggml_gated_delta_net(ctx0, q, k, v, g, b, state_3d);
     cb(gdn_out, LLAMA_TENSOR_NAME_FGDN_CH, il);

From 757be51e2178468f8e0000767165a9e30427832b Mon Sep 17 00:00:00 2001
From: Aman Gupta
Date: Thu, 14 May 2026 17:20:50 +0800
Subject: [PATCH 15/18] review: add need_rs_seq

---
 common/common.cpp | 11 +----------
 common/common.h   |  9 +++++++++
 2 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index 6030ead189f..7c5f16a4c8d 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1496,16 +1496,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.n_ctx     = params.n_ctx;
     cparams.n_seq_max = params.n_parallel;
-    {
-        // Since MTP has a low number of draft tokens, enable recurrent checkpointing
-        // for hybrid attn models
-        // TODO: figure out how to make it play nicely with other speculative techniques
-        bool has_mtp = std::any_of(params.speculative.types.begin(),
-            params.speculative.types.end(), [&](auto t) {
-                return t == COMMON_SPECULATIVE_TYPE_DRAFT_MTP;
-            });
-        cparams.n_rs_seq = has_mtp ? (uint32_t) params.speculative.draft.n_max : 0u;
-    }
+    cparams.n_rs_seq = params.speculative.need_n_rs_seq();
     cparams.n_batch   = params.n_batch;
     cparams.n_ubatch  = params.n_ubatch;
     cparams.n_threads = params.cpuparams.n_threads;
diff --git a/common/common.h b/common/common.h
index e33a8f57fd4..5bf7a845faf 100644
--- a/common/common.h
+++ b/common/common.h
@@ -13,6 +13,7 @@
 #include
 #include
 #include
+#include <algorithm>
 
 #if defined(_WIN32) && !defined(_WIN32_WINNT)
 #define _WIN32_WINNT 0x0A00
@@ -356,6 +357,14 @@ struct common_params_speculative {
     bool has_dft() const {
         return !draft.mparams.path.empty() || !draft.mparams.hf_repo.empty();
     }
+
+    uint32_t need_n_rs_seq() const {
+        bool needs_rs_seq = std::any_of(types.begin(), types.end(), [&](auto t) {
+            return t == COMMON_SPECULATIVE_TYPE_DRAFT_MTP;
+        });
+
+        return needs_rs_seq ?
draft.n_max : 0u; + } }; struct common_params_vocoder { From 7cb69f5fdc9913026e4444c38af6c4c32c1b6525 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 14 May 2026 17:22:52 +0800 Subject: [PATCH 16/18] review: rename part_bounded to n_rs --- common/common.cpp | 2 +- common/common.h | 4 ++-- examples/speculative-simple/speculative-simple.cpp | 1 - tools/server/server-context.cpp | 4 ++-- 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 7c5f16a4c8d..b215c7f2d66 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1422,7 +1422,7 @@ common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) { if (llama_n_rs_seq(ctx) > 0) { LOG_INF("%s: the context supports bounded partial sequence removal\n", __func__); - res = COMMON_CONTEXT_SEQ_RM_TYPE_PART_BOUNDED; + res = COMMON_CONTEXT_SEQ_RM_TYPE_RS; goto done; } diff --git a/common/common.h b/common/common.h index 5bf7a845faf..7dc7f15cbdb 100644 --- a/common/common.h +++ b/common/common.h @@ -160,7 +160,7 @@ enum common_speculative_type { COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE, // standalone draft model speculative decoding COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, // Eagle3 speculative decoding - COMMON_SPECULATIVE_TYPE_DRAFT_MTP, // multi-token prediction head loaded from the target GGUF + COMMON_SPECULATIVE_TYPE_DRAFT_MTP, // Multi-token prediction COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding based on n-grams COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values @@ -884,7 +884,7 @@ enum common_context_seq_rm_type { COMMON_CONTEXT_SEQ_RM_TYPE_NO = 0, // seq_rm not supported (e.g. 
no memory module) COMMON_CONTEXT_SEQ_RM_TYPE_PART = 1, // can seq_rm partial sequences COMMON_CONTEXT_SEQ_RM_TYPE_FULL = 2, // can seq_rm full sequences only - COMMON_CONTEXT_SEQ_RM_TYPE_PART_BOUNDED = 3, // can seq_rm partial sequences, bounded by n_rs_seq + COMMON_CONTEXT_SEQ_RM_TYPE_RS = 3, // can seq_rm partial sequences, bounded by n_rs_seq }; // check if the llama_context can remove sequences diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index 6848de988d0..868fd145fdc 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -83,7 +83,6 @@ int main(int argc, char ** argv) { // check if the context supports partial sequence removal const bool use_ckpt_tgt = (common_context_can_seq_rm(ctx_tgt) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) - || (common_context_can_seq_rm(ctx_tgt) == COMMON_CONTEXT_SEQ_RM_TYPE_PART_BOUNDED); const bool use_ckpt_dft = (common_context_can_seq_rm(ctx_dft.get()) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL); if (use_ckpt_tgt) { diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 4be9c706ed1..b07b96183ef 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -2696,8 +2696,8 @@ struct server_context_impl { // - the model uses SWA (and we are not using `swa_full`) // - the model supports partial sequence removal but only up to a fixed bound do_checkpoint = do_checkpoint && ( - ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL || - ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_PART_BOUNDED || + ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL || + ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_RS || n_swa > 0); bool has_mtmd = false; From 66e47d1366d4d70b7af9e383554fa1ce49337401 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 14 May 2026 17:24:34 +0800 Subject: [PATCH 17/18] review: deslop comments --- ggml/include/ggml.h | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index be1b6be4cb2..41566d41aef 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -2542,12 +2542,10 @@ extern "C" { // TODO: add ggml_gated_delta_net_set_bcast() to be able to configure Q, K broadcast type: tiled vs interleaved [TAG_GGML_GDN_BCAST] // ref: https://github.com/ggml-org/llama.cpp/pull/19468#discussion_r2786394306 // - // state is a 3D tensor of shape (S_v*S_v*H, K, n_seqs). K is the snapshot slot count: - // K == 1 → output carries the final state only. - // K > 1 → output carries K snapshot slots; the kernel writes the last min(n_tokens, K) - // per-token snapshots into the trailing slots (earlier slots are left untouched - // when n_tokens < K). - // Only slot 0 (state[:, 0, :]) is read as the initial state; the rest is shape signal. + // state is a 3D tensor of shape (S_v*S_v*H, K, n_seqs): + // K == 1: output carries the final state only. 
+    //   K > 1: output carries K snapshot slots; the kernel writes the last min(n_tokens, K)
+    //          per-token snapshots into the trailing slots.
     GGML_API struct ggml_tensor * ggml_gated_delta_net(
             struct ggml_context * ctx,
             struct ggml_tensor  * q,

From 5060c924107aa2d80b136c7ae0b6502298951662 Mon Sep 17 00:00:00 2001
From: Aman Gupta
Date: Thu, 14 May 2026 17:40:55 +0800
Subject: [PATCH 18/18] review: rename, add asserts

---
 .../speculative-simple/speculative-simple.cpp |  2 +-
 ggml/src/ggml-cuda/gated_delta_net.cu         | 22 +++++++++----------
 src/llama-arch.cpp                            |  2 +-
 src/llama-arch.h                              |  2 +-
 src/llama-context.cpp                         |  2 +-
 src/llama-memory-recurrent.cpp                | 10 ++++-----
 src/models/delta-net-base.cpp                 |  6 ++---
 src/models/models.h                           |  2 +-
 8 files changed, 23 insertions(+), 25 deletions(-)

diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp
index 868fd145fdc..5325bcc9e3f 100644
--- a/examples/speculative-simple/speculative-simple.cpp
+++ b/examples/speculative-simple/speculative-simple.cpp
@@ -82,7 +82,7 @@ int main(int argc, char ** argv) {
     }
 
     // check if the context supports partial sequence removal
-    const bool use_ckpt_tgt = (common_context_can_seq_rm(ctx_tgt) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL)
+    const bool use_ckpt_tgt = (common_context_can_seq_rm(ctx_tgt) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL);
     const bool use_ckpt_dft = (common_context_can_seq_rm(ctx_dft.get()) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL);
 
     if (use_ckpt_tgt) {
diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu
index ef4bb29c8ad..b4c9845e7a7 100644
--- a/ggml/src/ggml-cuda/gated_delta_net.cu
+++ b/ggml/src/ggml-cuda/gated_delta_net.cu
@@ -1,6 +1,6 @@
 #include "gated_delta_net.cuh"
 
-template <int S_v, bool KDA, bool keep_intermediates_t>
+template <int S_v, bool KDA, bool keep_rs_t>
 __global__ void __launch_bounds__((ggml_cuda_get_physical_warp_size() < S_v ?
ggml_cuda_get_physical_warp_size() : S_v) * 4, 2)
    gated_delta_net_cuda(const float * q,
                         const float * k,
@@ -145,7 +145,7 @@ gated_delta_net_cuda(const float * q,
 
         attn_data += S_v * H;
 
-        if constexpr (keep_intermediates_t) {
+        if constexpr (keep_rs_t) {
             const int target_slot = t - shift;
             if (target_slot >= 0 && target_slot < K) {
                 float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset;
@@ -158,7 +158,7 @@
         }
     }
 
-    if constexpr (!keep_intermediates_t) {
+    if constexpr (!keep_rs_t) {
 #pragma unroll
         for (int r = 0; r < rows_per_lane; r++) {
             const int i = r * warp_size + lane;
@@ -167,7 +167,7 @@
     }
 }
 
-template <bool KDA, bool keep_intermediates_t>
+template <bool KDA, bool keep_rs_t>
 static void launch_gated_delta_net(
     const float * q_d, const float * k_d, const float * v_d, const float * g_d, const float * b_d, const float * s_d,
@@ -191,26 +191,26 @@ static void launch_gated_delta_net(
 
     switch (S_v) {
         case 16:
-            gated_delta_net_cuda<16, KDA, keep_intermediates_t><<>>(
+            gated_delta_net_cuda<16, KDA, keep_rs_t><<>>(
                 q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
                 sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
             break;
         case 32:
-            gated_delta_net_cuda<32, KDA, keep_intermediates_t><<>>(
+            gated_delta_net_cuda<32, KDA, keep_rs_t><<>>(
                 q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
                 sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
             break;
        case 64: {
-            gated_delta_net_cuda<64, KDA, keep_intermediates_t><<>>(
+            gated_delta_net_cuda<64, KDA, keep_rs_t><<>>(
                 q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
                 sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
             break;
        }
        case 128: {
-            gated_delta_net_cuda<128, KDA, keep_intermediates_t><<>>(
+            gated_delta_net_cuda<128, KDA, keep_rs_t><<>>(
                 q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
                 sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
@@ -285,10 +285,10 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
 
     // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count.
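+    // note: keep_rs selects a template instantiation of the kernel, so the
+    // per-token snapshot stores compile away entirely on the common K == 1
+    // path (see the `if constexpr (keep_rs_t)` branches above).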
     const int K = (int) src_state->ne[1];
-    const bool keep_intermediates = K > 1;
+    const bool keep_rs = K > 1;
 
     if (kda) {
-        if (keep_intermediates) {
+        if (keep_rs) {
             launch_gated_delta_net<true, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, S_v, H, n_tokens, n_seqs,
                 sq1, sq2, sq3, sv1, sv2, sv3,
                 sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
@@ -298,7 +298,7 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
                 sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
         }
     } else {
-        if (keep_intermediates) {
+        if (keep_rs) {
             launch_gated_delta_net<false, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, S_v, H, n_tokens, n_seqs,
                 sq1, sq2, sq3, sv1, sv2, sv3,
                 sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index c95f341b07d..4bee6fbe651 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -878,7 +878,7 @@ bool llm_arch_is_diffusion(const llm_arch & arch) {
     }
 }
 
-bool llm_arch_supports_recurrent_partial_rollback(const llm_arch & arch) {
+bool llm_arch_supports_rs_rollback(const llm_arch & arch) {
     switch (arch) {
         case LLM_ARCH_QWEN35:
         case LLM_ARCH_QWEN35MOE:
diff --git a/src/llama-arch.h b/src/llama-arch.h
index e3765dba3da..89cf16cc37c 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -637,4 +637,4 @@ bool llm_arch_is_recurrent (const llm_arch & arch);
 bool llm_arch_is_hybrid    (const llm_arch & arch);
 bool llm_arch_is_diffusion (const llm_arch & arch);
 bool llm_arch_supports_sm_tensor(const llm_arch & arch);
-bool llm_arch_supports_recurrent_partial_rollback(const llm_arch & arch);
+bool llm_arch_supports_rs_rollback(const llm_arch & arch);
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index a481757e388..d62abc4009b 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -52,7 +52,7 @@ llama_context::llama_context(
     }
 
     cparams.n_rs_seq = params.n_rs_seq;
-    if (cparams.n_rs_seq > 0 && !llm_arch_supports_recurrent_partial_rollback(model.arch)) {
+    if (cparams.n_rs_seq > 0 && !llm_arch_supports_rs_rollback(model.arch)) {
        LLAMA_LOG_DEBUG("%s: n_rs_seq=%u requested but model arch does not support recurrent partial rollback; clamping to 0\n",
            __func__, cparams.n_rs_seq);
        cparams.n_rs_seq = 0;
diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp
index 084c5d9ea4f..109a77be404 100644
--- a/src/llama-memory-recurrent.cpp
+++ b/src/llama-memory-recurrent.cpp
@@ -170,12 +170,10 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
     // partial rollback via per-token snapshot index (bounded by n_rs_seq)
     if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) {
         const llama_pos rollback = cell.pos - (p0 - 1);
-        if (rollback >= 1 && rollback <= (llama_pos) n_rs_seq) {
-            set_rs_idx(seq_id, (uint32_t) rollback);
-            cell.pos = p0 - 1;
-            return true;
-        }
-        return false;
+        GGML_ASSERT(rollback >= 1 && rollback <= (llama_pos) n_rs_seq);
+        set_rs_idx(seq_id, (uint32_t) rollback);
+        cell.pos = p0 - 1;
+        return true;
     }
     // invalidate tails which will be cleared
     if (p0 <= cell.pos && cell.pos < p1) {
diff --git a/src/models/delta-net-base.cpp b/src/models/delta-net-base.cpp
index e261ac54e20..2a4e00384e9 100644
--- a/src/models/delta-net-base.cpp
+++ b/src/models/delta-net-base.cpp
@@ -447,7 +447,7 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
     return build_delta_net_chunking(q, k, v, g, b, s, il);
 }
 
-bool llm_build_delta_net_base::keep_intermediates() const {
+bool llm_build_delta_net_base::keep_rs() const {
     const int64_t n_seq_tokens = ubatch.n_seq_tokens;
     return cparams.n_rs_seq > 0 && n_seq_tokens > 1
@@ -466,7 +466,7 @@ ggml_tensor * llm_build_delta_net_base::build_conv_state( const uint32_t mem_size = mctx_cur->get_size(); const int64_t n_seqs = ubatch.n_seqs; const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const bool keep = keep_intermediates(); + const bool keep = keep_rs(); ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); cb(conv_states, "conv_states", il); @@ -531,7 +531,7 @@ ggml_tensor * llm_build_delta_net_base::build_recurrent_attn( const int64_t n_seqs = s->ne[3]; const int64_t n_seq_tokens = q->ne[2]; - if (!keep_intermediates()) { + if (!keep_rs()) { auto attn_out = build_delta_net(q, k, v, g, b, s, il); ggml_tensor * output = attn_out.first; ggml_tensor * new_state = attn_out.second; diff --git a/src/models/models.h b/src/models/models.h index 8bef479754a..4e40536a5ea 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -67,7 +67,7 @@ struct llm_build_delta_net_base : public llm_graph_context { int il); // true when speculative rollback is enabled and the batch fits in the rs cache - bool keep_intermediates() const; + bool keep_rs() const; // read conv state from cache, concat with qkv_mixed, write back (single slot or per-token) // qkv_mixed: (qkv_dim, n_seq_tokens, n_seqs); returns conv_input: (kernel_size + n_seq_tokens - 1, channels, n_seqs)
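+    // e.g. the qwen3.5-family call sites pass conv_kernel_size = ssm_conv1d->ne[0]
+    // and conv_channels = d_inner + 2 * ssm_n_group * ssm_d_state
+    // (see qwen35.cpp, qwen35moe.cpp and qwen3next.cpp above)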