From 1428004808efc1acbf1f76904b275acdc980e3d2 Mon Sep 17 00:00:00 2001 From: viggy <70774793+vignesh191@users.noreply.github.com> Date: Sat, 16 May 2026 03:00:46 -0700 Subject: [PATCH 01/10] webui : [ChatFormActionAdd][a11y] fix accessibility issues in add menu trigger and items (#22736) * fix tab order on attach button, and dont focus on disabled mennu item * add a11y tests --- .../ChatFormActionAddDropdown.svelte | 105 +++++++++++------- .../ChatFormActionsAdd.svelte | 6 +- .../ChatScreenForm.a11y.stories.svelte | 50 +++++++++ 3 files changed, 117 insertions(+), 44 deletions(-) create mode 100644 tools/ui/tests/stories/ChatScreenForm.a11y.stories.svelte diff --git a/tools/ui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionAdd/ChatFormActionAddDropdown.svelte b/tools/ui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionAdd/ChatFormActionAddDropdown.svelte index e053e6f836b..a6bb0fc2a80 100644 --- a/tools/ui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionAdd/ChatFormActionAddDropdown.svelte +++ b/tools/ui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionAdd/ChatFormActionAddDropdown.svelte @@ -1,11 +1,14 @@ + + { + const textarea = await canvas.findByRole('textbox'); + await userEvent.clear(textarea); + await userEvent.type(textarea, 'What is the meaning of life?'); + + const trigger = await canvas.findByRole('button', { name: ATTACHMENT_TOOLTIP_TEXT }); + + trigger.focus(); + await expect(trigger).toHaveFocus(); + + await userEvent.tab(); + + await expect(trigger).not.toHaveFocus(); + }} +/> + + { + const trigger = await canvas.findByRole('button', { name: ATTACHMENT_TOOLTIP_TEXT }); + + trigger.focus(); + await userEvent.keyboard('{Enter}'); + await screen.findByRole('menu'); + + await waitFor(() => { + expect(document.activeElement).toHaveTextContent('Text Files'); + }); + }} +/> From b81c2cdd748dc2704d5989cf03936325554c12d3 Mon Sep 17 00:00:00 2001 From: kubawoo Date: Sat, 16 May 2026 
13:25:41 +0200 Subject: [PATCH 02/10] ui: Fix handling of MCP resource template parameters (#23117) * Fix handling of MCP resource template parameters * Fix formatting for uri-template.test.ts --------- Co-authored-by: kuba --- tools/ui/src/lib/utils/uri-template.ts | 2 +- tools/ui/tests/unit/uri-template.test.ts | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/tools/ui/src/lib/utils/uri-template.ts b/tools/ui/src/lib/utils/uri-template.ts index 7665c98c929..eb8dbfb3632 100644 --- a/tools/ui/src/lib/utils/uri-template.ts +++ b/tools/ui/src/lib/utils/uri-template.ts @@ -160,7 +160,7 @@ export function expandTemplate(template: string, values: Record) (name: string, i: number) => `${encodeURIComponent(name)}=${encodeURIComponent(expandedParts[i])}` ) - .join(URI_TEMPLATE_SEPARATORS.COMMA) + .join(URI_TEMPLATE_SEPARATORS.QUERY_CONTINUATION) ); case URI_TEMPLATE_OPERATORS.FORM_CONTINUATION: // Form-style query continuation diff --git a/tools/ui/tests/unit/uri-template.test.ts b/tools/ui/tests/unit/uri-template.test.ts index 6221279231a..23645af3f5f 100644 --- a/tools/ui/tests/unit/uri-template.test.ts +++ b/tools/ui/tests/unit/uri-template.test.ts @@ -107,6 +107,14 @@ describe('expandTemplate', () => { expect(result).toBe('http://example.com?q=search%20term'); }); + it('expands multiple query parameters', () => { + const result = expandTemplate('http://example.com{?q,sort}', { + q: 'search term', + sort: 'descending' + }); + expect(result).toBe('http://example.com?q=search%20term&sort=descending'); + }); + it('keeps static parts unchanged', () => { const result = expandTemplate('http://example.com/static', {}); expect(result).toBe('http://example.com/static'); From 255582687b8dd211fdbc582e43ab842491554e94 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Sat, 16 May 2026 20:06:23 +0800 Subject: [PATCH 03/10] llama + spec: MTP Support (#22673) * spec: support MTP * fix batch size * rename files * cont : simplify (#7) * MTP: clean-up (#9) * MTP: 
clean-up * review: use llama_context_type instead of llama_graph_type * review: remove llama_model_has_mtp * review: fix convert issues * convert: fix pycheck * review: formatting * use `mtp-` for identifying mtp models * convert: fix mtp conversion * mtp -> draft-mtp * remove unused llama_arch * add need_embd in speculative * llama: allow partial seq_rm for GDN models for speculative decoding Currently speculative checkpoint needs to restart from a checkpoint after some draft tokens are not accepted, this leads to some wastage in running the target again. This PR adds the ability to rollback upto `draft_max` by storing the GDN intermediates. * fix pending state * vulkan: add GDN partial rollback * meta: extend check to axis 1 * metal: add GDN partial rollback Extend the gated delta net kernel to store intermediate states for partial rollback support on the Metal backend. - Add K (snapshot slot count) as a function constant - Read input state from slot 0 of the 3D state tensor - Write intermediate states to different slots during token loop - For K=1, maintain backward-compatible single-slot behavior Ref: https://github.com/ggml-org/llama.cpp/commit/8c05923630110223669f069af2000e9cf10c02bc Assisted-by: llama.cpp:local pi * delta_net_base: use ggml_pad instead of new_tensor * review: add need_rs_seq * review: rename part_bounded to n_rs * review: deslop comments * review: rename, add asserts * server : adjust checkpoint logic (#11) * server : adjust checkpoint logic * cont : rm asserts * server-context: fix early exit * spec : fix compatibility with n-gram and add TODOs (#13) * metal : cleanup * llama : fix faulty bitwise check in recurrent memory * server : disable RS-based MTP in combination with other spec types * spec : add TODOs * cont : fix comment * cont : update comment * common : fix logic for ngram + mtp compat * llama-memory: enable checkpointing with partial rollback * cont: add test-case for loading into a dirty ctx * llama-memory-recurrent: clear 
rs_idx in clear * download: fix mtp path * llama-arch: fix enorm op * docs: update docs * conversion: fix type annotations --------- Co-authored-by: Georgi Gerganov --- common/arg.cpp | 34 +- common/common.cpp | 56 +++ common/common.h | 26 +- common/download.cpp | 55 ++- common/download.h | 7 +- common/speculative.cpp | 385 +++++++++++++++++- common/speculative.h | 3 + conversion/base.py | 6 + conversion/qwen.py | 89 +++- convert_hf_to_gguf.py | 22 + ggml/include/ggml.h | 5 + ggml/src/ggml-backend-meta.cpp | 5 +- ggml/src/ggml-cpu/ggml-cpu.c | 4 +- ggml/src/ggml-cpu/ops.cpp | 43 +- ggml/src/ggml-cuda/gated_delta_net.cu | 88 ++-- ggml/src/ggml-metal/ggml-metal-device.cpp | 5 +- ggml/src/ggml-metal/ggml-metal.metal | 46 ++- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 8 +- .../vulkan-shaders/gated_delta_net.comp | 29 +- ggml/src/ggml.c | 12 +- gguf-py/gguf/constants.py | 18 +- include/llama.h | 8 + src/llama-arch.cpp | 27 +- src/llama-arch.h | 1 + src/llama-context.cpp | 157 ++++++- src/llama-context.h | 9 + src/llama-cparams.h | 3 + src/llama-ext.h | 16 + src/llama-graph.cpp | 3 +- src/llama-graph.h | 3 + src/llama-hparams.cpp | 6 + src/llama-hparams.h | 2 + src/llama-memory-hybrid-iswa.cpp | 2 + src/llama-memory-hybrid-iswa.h | 1 + src/llama-memory-hybrid.cpp | 2 + src/llama-memory-hybrid.h | 1 + src/llama-memory-recurrent.cpp | 119 +++++- src/llama-memory-recurrent.h | 8 + src/llama-memory.h | 3 + src/llama-model-loader.cpp | 13 +- src/llama-model-loader.h | 2 +- src/llama-model.cpp | 30 +- src/models/delta-net-base.cpp | 143 ++++++- src/models/models.h | 36 +- src/models/qwen35.cpp | 310 ++++++++++---- src/models/qwen35moe.cpp | 362 ++++++++++++---- src/models/qwen3next.cpp | 44 +- tests/CMakeLists.txt | 3 + tests/test-backend-ops.cpp | 21 +- tests/test-recurrent-state-rollback.cpp | 185 +++++++++ tools/cli/README.md | 7 +- tools/completion/README.md | 5 +- tools/server/README.md | 19 +- tools/server/server-context.cpp | 141 ++++--- 54 files changed, 2226 
insertions(+), 412 deletions(-) create mode 100644 tests/test-recurrent-state-rollback.cpp diff --git a/common/arg.cpp b/common/arg.cpp index 2129a9c7266..84b3c8f962d 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -337,11 +337,15 @@ static bool common_params_handle_remote_preset(common_params & params, llama_exa struct handle_model_result { bool found_mmproj = false; common_params_model mmproj; + + bool found_mtp = false; + common_params_model mtp; }; static handle_model_result common_params_handle_model(struct common_params_model & model, const std::string & bearer_token, - bool offline) { + bool offline, + bool search_mtp = false) { handle_model_result result; if (!model.docker_repo.empty()) { @@ -356,7 +360,7 @@ static handle_model_result common_params_handle_model(struct common_params_model common_download_opts opts; opts.bearer_token = bearer_token; opts.offline = offline; - auto download_result = common_download_model(model, opts, true); + auto download_result = common_download_model(model, opts, true, search_mtp); if (download_result.model_path.empty()) { throw std::runtime_error("failed to download model from Hugging Face"); @@ -369,6 +373,11 @@ static handle_model_result common_params_handle_model(struct common_params_model result.found_mmproj = true; result.mmproj.path = download_result.mmproj_path; } + + if (!download_result.mtp_path.empty()) { + result.found_mtp = true; + result.mtp.path = download_result.mtp_path; + } } else if (!model.url.empty()) { if (model.path.empty()) { auto f = string_split(model.url, '#').front(); @@ -436,7 +445,11 @@ static bool parse_bool_value(const std::string & value) { // void common_params_handle_models(common_params & params, llama_example curr_ex) { - auto res = common_params_handle_model(params.model, params.hf_token, params.offline); + const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(), + params.speculative.types.end(), + COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != 
params.speculative.types.end(); + + auto res = common_params_handle_model(params.model, params.hf_token, params.offline, spec_type_draft_mtp); if (params.no_mmproj) { params.mmproj = {}; } else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) { @@ -450,6 +463,14 @@ void common_params_handle_models(common_params & params, llama_example curr_ex) break; } } + // when --spec-type mtp is set and no draft model was provided explicitly, + // fall back to the MTP head discovered alongside the -hf model + if (spec_type_draft_mtp && res.found_mtp && + params.speculative.draft.mparams.path.empty() && + params.speculative.draft.mparams.hf_repo.empty() && + params.speculative.draft.mparams.url.empty()) { + params.speculative.draft.mparams.path = res.mtp.path; + } common_params_handle_model(params.speculative.draft.mparams, params.hf_token, params.offline); common_params_handle_model(params.vocoder.model, params.hf_token, params.offline); } @@ -3608,8 +3629,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("comma-separated list of types of speculative decoding to use (default: %s)\n", common_speculative_type_name_str(params.speculative.types).c_str()), [](common_params & params, const std::string & value) { - const auto enabled_types = string_split(value, ','); - params.speculative.types = common_speculative_types_from_names(enabled_types); + const auto types_str = string_split(value, ','); + auto types = common_speculative_types_from_names(types_str); + params.speculative.types.insert(params.speculative.types.end(), types.begin(), types.end()); } ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_TYPE")); add_opt(common_arg( @@ -4098,7 +4120,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--spec-default"}, string_format("enable default speculative decoding config"), [](common_params & 
params) { - params.speculative.types = { COMMON_SPECULATIVE_TYPE_NGRAM_MOD }; + params.speculative.types.push_back(COMMON_SPECULATIVE_TYPE_NGRAM_MOD); params.speculative.ngram_mod.n_match = 24; params.speculative.ngram_mod.n_min = 48; params.speculative.ngram_mod.n_max = 64; diff --git a/common/common.cpp b/common/common.cpp index b701edddb3f..8b6d182f549 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -7,6 +7,7 @@ #include "log.h" #include "llama.h" #include "sampling.h" +#include "speculative.h" #include "unicode.h" #include @@ -1247,6 +1248,29 @@ common_init_result::common_init_result(common_params & params) : cparams.n_samplers = pimpl->samplers_seq_config.size(); } + // [TAG_RS_STATE_ROLLBACK_SUPPORT] + // TODO: ngram speculative methods require checkpointing in addition to partial RS rollback + // currently this is not supported. so we disable the partial rollback + if (cparams.n_rs_seq > 0 && (llama_model_is_recurrent(model) || llama_model_is_hybrid(model))) { + auto & types = params.speculative.types; + + for (int i = 0; i < (int) types.size(); i++) { + if (types[i] == COMMON_SPECULATIVE_TYPE_NONE) { + continue; + } + if (types[i] == COMMON_SPECULATIVE_TYPE_DRAFT_MTP) { + continue; + } + + cparams.n_rs_seq = 0; + + LOG_WRN("%s: recurrent state rollback is not compatible with '%s' - disabling rollback support\n", __func__, + common_speculative_type_to_str(types[i]).c_str()); + + break; + } + } + llama_context * lctx = llama_init_from_model(model, cparams); if (lctx == NULL) { LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str()); @@ -1435,6 +1459,12 @@ common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) { goto done; } + if (llama_n_rs_seq(ctx) > 0) { + LOG_INF("%s: the context supports bounded partial sequence removal\n", __func__); + res = COMMON_CONTEXT_SEQ_RM_TYPE_RS; + goto done; + } + // try to remove the last tokens if (!llama_memory_seq_rm(mem, 0, 1, -1)) { LOG_TRC("%s: the 
context does not support partial sequence removal\n", __func__); @@ -1449,6 +1479,23 @@ common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) { return res; } +void common_context_seq_rm(llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + auto * mem = llama_get_memory(ctx); + if (!llama_memory_seq_rm(mem, seq_id, p0, p1)) { + GGML_ABORT("%s", string_format("failed to remove sequence %d with p0=%d, p1=%d\n", seq_id, p0, p1).c_str()); + } +} + +void common_context_seq_cp(llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + auto * mem = llama_get_memory(ctx); + llama_memory_seq_cp(mem, seq_id_src, seq_id_dst, p0, p1); +} + +void common_context_seq_add(llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { + auto * mem = llama_get_memory(ctx); + llama_memory_seq_add(mem, seq_id, p0, p1, delta); +} + void common_set_adapter_lora(struct llama_context * ctx, std::vector & lora) { std::vector loras; std::vector scales; @@ -1505,6 +1552,7 @@ struct llama_context_params common_context_params_to_llama(const common_params & cparams.n_ctx = params.n_ctx; cparams.n_seq_max = params.n_parallel; + cparams.n_rs_seq = params.speculative.need_n_rs_seq(); cparams.n_batch = params.n_batch; cparams.n_ubatch = params.n_ubatch; cparams.n_threads = params.cpuparams.n_threads; @@ -2074,3 +2122,11 @@ void common_prompt_checkpoint::load_dft( GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", data_dft.size(), n); } } + +void common_prompt_checkpoint::clear_tgt() { + data_tgt.clear(); +} + +void common_prompt_checkpoint::clear_dft() { + data_dft.clear(); +} diff --git a/common/common.h b/common/common.h index c6223c4b515..4cca9d71568 100644 --- a/common/common.h +++ b/common/common.h @@ -13,6 +13,7 @@ #include #include #include +#include #if defined(_WIN32) && !defined(_WIN32_WINNT) #define _WIN32_WINNT 0x0A00 @@ -159,6 +160,7 @@ enum common_speculative_type 
{ COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE, // standalone draft model speculative decoding COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, // Eagle3 speculative decoding + COMMON_SPECULATIVE_TYPE_DRAFT_MTP, // Multi-token prediction COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding based on n-grams COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values @@ -301,7 +303,7 @@ struct common_params_speculative_draft { int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding float p_split = 0.1f; // speculative decoding split probability - float p_min = 0.75f; // minimum speculative decoding probability (greedy) + float p_min = 0.75f; // minimum speculative decoding probability (greedy) // TODO: change default to 0.0f common_params_model mparams; @@ -355,6 +357,14 @@ struct common_params_speculative { bool has_dft() const { return !draft.mparams.path.empty() || !draft.mparams.hf_repo.empty(); } + + uint32_t need_n_rs_seq() const { + bool needs_rs_seq = std::any_of(types.begin(), types.end(), [&](auto t) { + return t == COMMON_SPECULATIVE_TYPE_DRAFT_MTP; + }); + + return needs_rs_seq ? draft.n_max : 0u; + } }; struct common_params_vocoder { @@ -884,15 +894,20 @@ std::string common_get_model_endpoint(); // enum common_context_seq_rm_type { - COMMON_CONTEXT_SEQ_RM_TYPE_NO = 0, // seq_rm not supported (e.g. no memory module) - COMMON_CONTEXT_SEQ_RM_TYPE_PART = 1, // can seq_rm partial sequences - COMMON_CONTEXT_SEQ_RM_TYPE_FULL = 2, // can seq_rm full sequences only + COMMON_CONTEXT_SEQ_RM_TYPE_NO = 0, // seq_rm not supported (e.g. 
no memory module) + COMMON_CONTEXT_SEQ_RM_TYPE_PART = 1, // can seq_rm partial sequences + COMMON_CONTEXT_SEQ_RM_TYPE_FULL = 2, // can seq_rm full sequences only + COMMON_CONTEXT_SEQ_RM_TYPE_RS = 3, // can seq_rm partial sequences, bounded by n_rs_seq }; // check if the llama_context can remove sequences // note: clears the memory of the context common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx); +// aborts execution on failure +void common_context_seq_rm (llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1); +void common_context_seq_add(llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta); +void common_context_seq_cp (llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1); // // Batch utils @@ -1074,4 +1089,7 @@ struct common_prompt_checkpoint { llama_context * ctx, llama_seq_id seq_id, llama_state_seq_flags flags) const; + + void clear_tgt(); + void clear_dft(); }; diff --git a/common/download.cpp b/common/download.cpp index 0bf12ad4a3b..103bc408faf 100644 --- a/common/download.cpp +++ b/common/download.cpp @@ -566,8 +566,11 @@ static hf_cache::hf_files get_split_files(const hf_cache::hf_files & files, return result; } -static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files, - const std::string & model) { +// pick the best sibling GGUF whose filename contains `keyword` (e.g. 
"mmproj" / "mtp"), +// preferring deeper shared directory prefix with the model, then closest quantization +static hf_cache::hf_file find_best_sibling(const hf_cache::hf_files & files, + const std::string & model, + const std::string & keyword) { hf_cache::hf_file best; size_t best_depth = 0; int best_diff = 0; @@ -579,20 +582,20 @@ static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files, for (const auto & f : files) { if (!string_ends_with(f.path, ".gguf") || - f.path.find("mmproj") == std::string::npos) { + f.path.find(keyword) == std::string::npos) { continue; } - auto mmproj_parts = string_split(f.path, '/'); - auto mmproj_dir = mmproj_parts.end() - 1; + auto sib_parts = string_split(f.path, '/'); + auto sib_dir = sib_parts.end() - 1; auto [_, dir] = std::mismatch(model_parts.begin(), model_dir, - mmproj_parts.begin(), mmproj_dir); - if (dir != mmproj_dir) { + sib_parts.begin(), sib_dir); + if (dir != sib_dir) { continue; } - size_t depth = dir - mmproj_parts.begin(); + size_t depth = dir - sib_parts.begin(); auto bits = extract_quant_bits(f.path); auto diff = std::abs(bits - model_bits); @@ -606,6 +609,16 @@ static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files, return best; } +static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files, + const std::string & model) { + return find_best_sibling(files, model, "mmproj"); +} + +static hf_cache::hf_file find_best_mtp(const hf_cache::hf_files & files, + const std::string & model) { + return find_best_sibling(files, model, "mtp-"); +} + static bool gguf_filename_is_model(const std::string & filepath) { if (!string_ends_with(filepath, ".gguf")) { return false; @@ -617,7 +630,8 @@ static bool gguf_filename_is_model(const std::string & filepath) { } return filename.find("mmproj") == std::string::npos && - filename.find("imatrix") == std::string::npos; + filename.find("imatrix") == std::string::npos && + filename.find("mtp-") == std::string::npos; } static 
hf_cache::hf_file find_best_model(const hf_cache::hf_files & files, @@ -673,11 +687,13 @@ struct hf_plan { hf_cache::hf_file primary; hf_cache::hf_files model_files; hf_cache::hf_file mmproj; + hf_cache::hf_file mtp; }; static hf_plan get_hf_plan(const common_params_model & model, const common_download_opts & opts, - bool download_mmproj) { + bool download_mmproj, + bool download_mtp) { hf_plan plan; hf_cache::hf_files all; @@ -723,6 +739,10 @@ static hf_plan get_hf_plan(const common_params_model & model, plan.mmproj = find_best_mmproj(all, primary.path); } + if (download_mtp) { + plan.mtp = find_best_mtp(all, primary.path); + } + return plan; } @@ -756,7 +776,8 @@ static std::vector get_url_tasks(const common_params_model & mode common_download_model_result common_download_model(const common_params_model & model, const common_download_opts & opts, - bool download_mmproj) { + bool download_mmproj, + bool download_mtp) { common_download_model_result result; std::vector tasks; hf_plan hf; @@ -764,13 +785,16 @@ common_download_model_result common_download_model(const common_params_model & bool is_hf = !model.hf_repo.empty(); if (is_hf) { - hf = get_hf_plan(model, opts, download_mmproj); + hf = get_hf_plan(model, opts, download_mmproj, download_mtp); for (const auto & f : hf.model_files) { tasks.push_back({f.url, f.local_path}); } if (!hf.mmproj.path.empty()) { tasks.push_back({hf.mmproj.url, hf.mmproj.local_path}); } + if (!hf.mtp.path.empty()) { + tasks.push_back({hf.mtp.url, hf.mtp.local_path}); + } } else if (!model.url.empty()) { tasks = get_url_tasks(model); } else { @@ -807,6 +831,10 @@ common_download_model_result common_download_model(const common_params_model & if (!hf.mmproj.path.empty()) { result.mmproj_path = hf_cache::finalize_file(hf.mmproj); } + + if (!hf.mtp.path.empty()) { + result.mtp_path = hf_cache::finalize_file(hf.mtp); + } } else { result.model_path = model.path; } @@ -946,7 +974,8 @@ std::vector common_list_cached_models() { for (const auto & f 
: files) { auto split = get_gguf_split_info(f.path); if (split.index != 1 || split.tag.empty() || - split.prefix.find("mmproj") != std::string::npos) { + split.prefix.find("mmproj") != std::string::npos || + split.prefix.find("mtp-") != std::string::npos) { continue; } if (seen.insert(f.repo_id + ":" + split.tag).second) { diff --git a/common/download.h b/common/download.h index edc3e9f1a71..4a169ef7796 100644 --- a/common/download.h +++ b/common/download.h @@ -59,6 +59,7 @@ struct common_download_opts { struct common_download_model_result { std::string model_path; std::string mmproj_path; + std::string mtp_path; }; // Download model from HuggingFace repo or URL @@ -83,12 +84,14 @@ struct common_download_model_result { // when opts.offline=true, no network requests are made // when download_mmproj=true, searches for mmproj in same directory as model or any parent directory // then with the closest quantization bits +// when download_mtp=true, applies the same sibling search for an MTP-head GGUF // -// returns result with model_path and mmproj_path (empty on failure) +// returns result with model_path, mmproj_path and mtp_path (empty when not found / on failure) common_download_model_result common_download_model( const common_params_model & model, const common_download_opts & opts = {}, - bool download_mmproj = false + bool download_mmproj = false, + bool download_mtp = false ); // returns list of cached models diff --git a/common/speculative.cpp b/common/speculative.cpp index 476e1398ed8..3488b9393c5 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -3,6 +3,7 @@ #include "common.h" #include "ggml.h" #include "llama.h" +#include "../src/llama-ext.h" // staging API: llama_set_embeddings_pre_norm / llama_get_embeddings_pre_norm_ith (used by MTP) #include "log.h" #include "ngram-cache.h" #include "ngram-map.h" @@ -23,6 +24,7 @@ const std::map common_speculative_type_fro {"none", COMMON_SPECULATIVE_TYPE_NONE}, {"draft-simple", 
COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE}, {"draft-eagle3", COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3}, + {"draft-mtp", COMMON_SPECULATIVE_TYPE_DRAFT_MTP}, {"ngram-simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE}, {"ngram-map-k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K}, {"ngram-map-k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V}, @@ -143,6 +145,9 @@ struct common_speculative_impl { virtual void draft(common_speculative_draft_params_vec & dparams) = 0; virtual void accept(llama_seq_id seq_id, uint16_t n_accepted) = 0; + + // true if this implementation requires the target context to extract embeddings + virtual bool need_embd() const = 0; }; struct common_speculative_impl_draft_simple : public common_speculative_impl { @@ -338,6 +343,10 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl { void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { // noop } + + bool need_embd() const override { + return false; + } }; struct common_speculative_impl_draft_eagle3 : public common_speculative_impl { @@ -362,6 +371,328 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl { void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { // noop } + + bool need_embd() const override { + return false; + } +}; + +struct common_speculative_state_draft_mtp : public common_speculative_impl { + common_params_speculative_draft params; // reuses the draft-model params slot (ctx_tgt/ctx_dft) + + llama_batch batch; + + std::vector smpls; + + int32_t n_embd = 0; + + // Per-sequence cross-batch carryover: pair (h_p, x_{p+1}) at MTP pos p+1. + // The last h-row of one process() call needs the first token of the NEXT + // call to pair with, so it's stashed here until that next call fires. + std::vector> pending_h; // [n_seq][n_embd] + + std::vector i_batch_beg; + std::vector i_batch_end; + + // Hidden rows from the most recent target verification batch, grouped by seq. 
+ // Row 0 corresponds to the sampled token, row N to the Nth accepted draft token. + std::vector> verify_h; + std::vector verify_h_rows; + + // Per-seq draft length from the last draft() call, used in accept() to + // roll back ctx_dft's recurrent state past the AR draft's redundant + // pre-advancement before process() mirrored the verify batch. + std::vector last_n_drafted; + + common_speculative_state_draft_mtp(const common_params_speculative & params, uint32_t n_seq) + : common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, n_seq) + , params(params.draft) + { + auto * ctx_tgt = this->params.ctx_tgt; + auto * ctx_dft = this->params.ctx_dft; + GGML_ASSERT(ctx_tgt && ctx_dft && "MTP requires ctx_tgt and ctx_dft to be set"); + + n_embd = llama_model_n_embd(llama_get_model(ctx_dft)); + + const int32_t n_b = (int32_t) llama_n_batch(ctx_dft); + batch = llama_batch_init(/*n_tokens=*/ n_b, /*embd=*/ n_embd, /*n_seq_max=*/ 1); + // llama_batch_init allocates only one of token/embd; MTP needs both. 
+ // TODO: fix, how to call without malloc + batch.token = (llama_token *) malloc(sizeof(llama_token) * n_b); + + smpls.resize(n_seq); + for (auto & s : smpls) { + common_params_sampling sparams; + sparams.no_perf = false; + sparams.top_k = 1; // TODO: re-enable top_k == 10 and utilize `p_min` spec param + sparams.samplers = { COMMON_SAMPLER_TYPE_TOP_K }; + s.reset(common_sampler_init(llama_get_model(ctx_dft), sparams)); + } + + llama_set_embeddings_pre_norm(ctx_tgt, true); + llama_set_embeddings_pre_norm(ctx_dft, true); + + pending_h.assign(n_seq, std::vector(n_embd, 0.0f)); + + i_batch_beg.assign(n_seq, -1); + i_batch_end.assign(n_seq, -1); + + verify_h.assign(n_seq, {}); + verify_h_rows.assign(n_seq, 0); + + last_n_drafted.assign(n_seq, 0); + } + + ~common_speculative_state_draft_mtp() override { + if (batch.token != nullptr) { + free(batch.token); + batch.token = nullptr; + } + llama_batch_free(batch); + } + + void begin(llama_seq_id seq_id, const llama_tokens & prompt) override { + const int32_t N = (int32_t) prompt.size(); + if (N <= 0) { + return; + } + auto * ctx_dft = this->params.ctx_dft; + const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id); + if (pos_max < N - 1) { + LOG_WRN("%s: ctx_dft pos_max=%d < N-1=%d — " + "process() hook may not have run on every prefill ubatch " + "(need_embd / logits=1 on every prompt position?). " + "Drafts may degrade.\n", + __func__, (int) pos_max, N - 1); + } + } + + bool process(const llama_batch & batch_in) override { + if (batch_in.n_tokens <= 0) { + return true; + } + + // TODO: how to make it work with vision tokens? 
+ if (batch_in.token == nullptr || batch_in.embd != nullptr) { + return true; + } + + const int32_t n_tokens = batch_in.n_tokens; + + // remember the first and last batch index for each sequence + std::fill(i_batch_beg.begin(), i_batch_beg.end(), -1); + std::fill(i_batch_end.begin(), i_batch_end.end(), -1); + + for (int k = 0; k < n_tokens; ++k) { + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + GGML_ASSERT(batch_in.n_seq_id[k] == 1); + + if (batch_in.seq_id[k][0] == seq_id) { + i_batch_end[seq_id] = k; + if (i_batch_beg[seq_id] < 0) { + i_batch_beg[seq_id] = k; + } + } + } + } + + auto * ctx_tgt = this->params.ctx_tgt; + auto * ctx_dft = this->params.ctx_dft; + + const size_t row_bytes = (size_t) n_embd * sizeof(float); + + common_batch_clear(batch); + + for (int k = 0; k < n_tokens; ++k) { + common_batch_add(batch, batch_in.token[k], batch_in.pos[k], { batch_in.seq_id[k][0] }, 0); + } + + // shift the tgt embeddings to the right by one position + // assumes that the tokens in the batch are sequential for each sequence + // i.e. 
we cannot have seq_id like this: [0, 0, 0, 1, 1, 0, 1, 1] + // ^--- this is a problem + // TODO:this is generally true, but would be nice to assert it + { + const float * h_tgt = llama_get_embeddings_pre_norm(ctx_tgt); + std::memcpy(batch.embd + (size_t) 1 * n_embd, h_tgt, row_bytes * (n_tokens-1)); + + //{ + // // string with seq_ids in the batch + // std::stringstream ss; + // for (int i = 0; i < n_tokens; ++i) { + // ss << batch_in.seq_id[i][0] << ","; + // } + // LOG_WRN("%s: batch_in.seq_id = %s\n", __func__, ss.str().c_str()); + //} + } + + // fill the pending embeddings from a previous run + auto set_h = [&](int idx, const float * h_row) { + std::memcpy(batch.embd + (size_t) idx * n_embd, h_row, row_bytes); + }; + + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + if (i_batch_beg[seq_id] < 0) { + continue; + } + + set_h(i_batch_beg[seq_id], pending_h[seq_id].data()); + } + + const int32_t rc = llama_decode(ctx_dft, batch); + if (rc != 0) { + LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d)\n", __func__, (int) rc, (int) batch_in.pos[0]); + return false; + } + + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + if (i_batch_end[seq_id] < 0) { + continue; + } + + const int32_t n_rows = i_batch_end[seq_id] - i_batch_beg[seq_id] + 1; + verify_h_rows[seq_id] = n_rows; + verify_h[seq_id].resize((size_t) n_rows * n_embd); + + for (int32_t i = 0; i < n_rows; ++i) { + const float * h = llama_get_embeddings_pre_norm_ith(ctx_tgt, i_batch_beg[seq_id] + i); + std::memcpy(verify_h[seq_id].data() + (size_t) i * n_embd, h, row_bytes); + } + + std::memcpy(pending_h[seq_id].data(), + verify_h[seq_id].data() + (size_t) (n_rows - 1) * n_embd, row_bytes); + } + + return true; + } + + void draft(common_speculative_draft_params_vec & dparams) override { + auto & ctx_dft = params.ctx_dft; + + common_batch_clear(batch); + + // keep track of which sequences are still drafting + int n_drafting = 0; + std::vector drafting(n_seq); 
+ + const float * h_row = nullptr; + const size_t row_bytes = (size_t) n_embd * sizeof(float); + + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + auto & dp = dparams[seq_id]; + + if (!dp.drafting) { + continue; + } + + n_drafting++; + drafting[seq_id] = true; + common_sampler_reset(smpls[seq_id].get()); + + common_batch_add(batch, dp.id_last, dp.n_past, { seq_id }, true); + + h_row = pending_h[seq_id].data(); + std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes); + } + + int ret = llama_decode(ctx_dft, batch); + if (ret != 0) { + LOG_WRN("%s: llama_decode returned %d\n", __func__, ret); + return; + } + + int i = 0; + + while (n_drafting > 0) { + int i_batch = 0; + + common_batch_clear(batch); + + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + if (!drafting[seq_id]) { + continue; + } + + auto * smpl = smpls[seq_id].get(); + + common_sampler_sample(smpl, ctx_dft, i_batch, true); + h_row = llama_get_embeddings_pre_norm_ith(ctx_dft, i_batch); + ++i_batch; + + const auto * cur_p = common_sampler_get_candidates(smpl, true); + + for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { + LOG_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", + seq_id, k, i, cur_p->data[k].id, cur_p->data[k].p, + common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str()); + } + + // add drafted token for each sequence + const llama_token id = cur_p->data[0].id; + + common_sampler_accept(smpl, id, true); + + auto & dp = dparams.at(seq_id); + auto & result = *dp.result; + + result.push_back(id); + + if (params.n_max <= (int) result.size()) { + drafting[seq_id] = false; + n_drafting--; + continue; + } + + common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true); + std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes); + } + + if (batch.n_tokens == 0) { + break; + } + + // evaluate the drafted tokens on the draft model + ret = llama_decode(ctx_dft, batch); + if (ret != 0) { 
+ LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret); + break; + } + + ++i; + } + + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + auto & dp = dparams[seq_id]; + if (!dp.drafting) { + continue; + } + + if (dp.result->size() < (size_t) params.n_min) { + dp.result->clear(); + } + + last_n_drafted[seq_id] = (uint16_t) dp.result->size(); + } + } + + void accept(llama_seq_id seq_id, uint16_t n_accepted) override { + if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq) { + return; + } + + const int32_t n_rows = verify_h_rows[seq_id]; + if (n_rows <= 0) { + return; + } + + const int32_t i_h = std::min(n_accepted, n_rows - 1); + const size_t row_bytes = (size_t) n_embd * sizeof(float); + std::memcpy(pending_h[seq_id].data(), verify_h[seq_id].data() + (size_t) i_h * n_embd, row_bytes); + } + + bool need_embd() const override { + return true; + } }; // state of self-speculation (simple implementation, not ngram-map) @@ -403,6 +734,10 @@ struct common_speculative_impl_ngram_simple : public common_speculative_impl { void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { // noop } + + bool need_embd() const override { + return false; + } }; struct common_speculative_impl_ngram_map_k : public common_speculative_impl { @@ -451,6 +786,10 @@ struct common_speculative_impl_ngram_map_k : public common_speculative_impl { common_ngram_map_accept(config[seq_id], n_accepted); } + + bool need_embd() const override { + return false; + } }; struct common_speculative_impl_ngram_mod : public common_speculative_impl { @@ -619,6 +958,10 @@ struct common_speculative_impl_ngram_mod : public common_speculative_impl { } } } + + bool need_embd() const override { + return false; + } }; struct common_speculative_impl_ngram_cache : public common_speculative_impl { @@ -752,6 +1095,10 @@ struct common_speculative_impl_ngram_cache : public common_speculative_impl { void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { // noop } + + 
bool need_embd() const override { + return false; + } }; struct common_speculative { @@ -820,6 +1167,7 @@ std::string common_speculative_type_to_str(common_speculative_type type) { case COMMON_SPECULATIVE_TYPE_NONE: return "none"; case COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE: return "draft-simple"; case COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3: return "draft-eagle3"; + case COMMON_SPECULATIVE_TYPE_DRAFT_MTP: return "draft-mtp"; case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram-simple"; case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram-map-k"; case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram-map-k4v"; @@ -875,8 +1223,8 @@ common_speculative * common_speculative_init(common_params_speculative & params, bool has_draft_model_path = !params.draft.mparams.path.empty(); bool has_draft_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE)); - // bool has_mtp = false; // TODO: add MTP here bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3 + bool has_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_MTP)) && params.draft.ctx_dft != nullptr; bool has_ngram_cache = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_CACHE)); bool has_ngram_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE)); @@ -885,7 +1233,7 @@ common_speculative * common_speculative_init(common_params_speculative & params, bool has_ngram_mod = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_MOD)); // when adding a new type - update here the logic above - static_assert(COMMON_SPECULATIVE_TYPE_COUNT == 8); + static_assert(COMMON_SPECULATIVE_TYPE_COUNT == 9); // this list here defines the priority of the speculators // the one with highest priority are listed first @@ -911,7 +1259,7 @@ common_speculative * common_speculative_init(common_params_speculative & params, LOG_WRN("%s: draft model is not specified - cannot use 'draft' type\n", __func__); has_draft_simple = false; } - } else if 
(has_draft_model_path) { + } else if (has_draft_model_path && !has_mtp && !has_draft_eagle3) { LOG_WRN("%s: draft model is specified but 'draft' speculative type is not explicitly enabled - enabling it\n", __func__); has_draft_simple = true; } @@ -919,10 +1267,12 @@ common_speculative * common_speculative_init(common_params_speculative & params, if (has_draft_simple) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE, params)); } - // TODO: add MTP here if (has_draft_eagle3) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, params)); } + if (has_mtp) { + configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, params)); + } } std::vector> impls = {}; @@ -940,6 +1290,10 @@ common_speculative * common_speculative_init(common_params_speculative & params, impls.push_back(std::make_unique(config.params, n_seq)); break; } + case COMMON_SPECULATIVE_TYPE_DRAFT_MTP: { + impls.push_back(std::make_unique(config.params, n_seq)); + break; + } case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: { common_ngram_map ngram_map = get_common_ngram_map(config.type, config.params.ngram_simple); @@ -1040,6 +1394,20 @@ bool common_speculative_process(common_speculative * spec, const llama_batch & b return result; } +bool common_speculative_need_embd(common_speculative * spec) { + if (spec == nullptr) { + return false; + } + + for (auto & impl : spec->impls) { + if (impl->need_embd()) { + return true; + } + } + + return false; +} + void common_speculative_draft(common_speculative * spec) { if (spec == nullptr) { return; @@ -1122,14 +1490,15 @@ void common_speculative_draft(common_speculative * spec) { } void common_speculative_accept(common_speculative * spec, llama_seq_id seq_id, uint16_t n_accepted) { - if (n_accepted == 0) { - return; - } - common_speculative_impl * impl = spec->impl_last[seq_id]; GGML_ASSERT(impl); + // TODO: currently only the implementation that generated the draft is used to accept it + 
// however, some implementations (such as MTP) need to also "see" the accepted tokens + // extend `common_speculative_impl::accept()` with an extra argument `bool is_other` to + // inform the implementation if the accepted tokens are from another implementation and + // pass the accepted tokens to all remaining implementations using `is_other == true` { common_time_meas tm(impl->t_accept_us, !impl->gen_perf); if (n_accepted > 0) { diff --git a/common/speculative.h b/common/speculative.h index 51f0b059fa4..614db9b1b50 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -53,6 +53,9 @@ void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, co // process the batch and update the internal state of the speculative context bool common_speculative_process(common_speculative * spec, const llama_batch & batch); +// true if any implementation requires target embeddings to be extracted +bool common_speculative_need_embd(common_speculative * spec); + // generate drafts for the sequences specified with `common_speculative_get_draft_params` void common_speculative_draft(common_speculative * spec); diff --git a/conversion/base.py b/conversion/base.py index d89d32fe150..30c2124c2b9 100644 --- a/conversion/base.py +++ b/conversion/base.py @@ -91,6 +91,7 @@ class ModelBase: gguf_writer: gguf.GGUFWriter model_name: str | None metadata_override: Path | None + metadata: gguf.Metadata dir_model_card: Path remote_hf_model_id: str | None @@ -106,6 +107,11 @@ class ModelBase: disable_mistral_community_chat_template: bool = False sentence_transformers_dense_modules: bool = False + # MTP (multi-token prediction) export modes; set by main() before instantiation. + # Architectures opt in by overriding the handling (see _Qwen35MtpMixin). 
+ mtp_only: bool = False + no_mtp: bool = False + def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False, use_temp_file: bool = False, eager: bool = False, metadata_override: Path | None = None, model_name: str | None = None, diff --git a/conversion/qwen.py b/conversion/qwen.py index 919ecddcb91..4b86404262a 100644 --- a/conversion/qwen.py +++ b/conversion/qwen.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Callable, Iterable, TYPE_CHECKING +from pathlib import Path +from typing import Any, Callable, Iterable, TYPE_CHECKING import torch @@ -534,11 +535,93 @@ def set_gguf_parameters(self): self.gguf_writer.add_rope_dimension_sections(self._QWEN35_DEFAULT_MROPE_SECTION) +class _Qwen35MtpMixin: + """Shared MTP wiring for Qwen3.5/3.6 text variants. The HF config carries + the MTP block under `mtp_num_hidden_layers` and the tensors under + `mtp.*`; we extend block_count, emit the nextn metadata key, and remap + `mtp.*` to the standard layer-indexed nextn naming so the existing + tensor_map handles them.""" + + hparams: dict[str, Any] + model_arch: gguf.MODEL_ARCH + gguf_writer: gguf.GGUFWriter + block_count: int + tensor_map: gguf.TensorNameMap + no_mtp: bool + mtp_only: bool + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.block_count = self.hparams["num_hidden_layers"] + if not self.no_mtp: + self.block_count += self.hparams.get("mtp_num_hidden_layers", 0) + self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) + + @classmethod + def filter_tensors(cls, item): + name, _ = item + if name.startswith("mtp."): + if cls.no_mtp: + return None + return item + if cls.mtp_only: + canonical = name.replace("language_model.", "") + keep = canonical in ( + "model.embed_tokens.weight", "model.norm.weight", "lm_head.weight", + "embed_tokens.weight", "norm.weight", + ) + if not keep: + return None + return super().filter_tensors(item) # ty: 
ignore[unresolved-attribute] + + def set_gguf_parameters(self): + super().set_gguf_parameters() # ty: ignore[unresolved-attribute] + if self.no_mtp: + return + if (n := self.hparams.get("mtp_num_hidden_layers", 0)) > 0: + self.gguf_writer.add_nextn_predict_layers(n) + + def prepare_metadata(self, vocab_only: bool): + from_dir = self.fname_out.is_dir() + super().prepare_metadata(vocab_only=vocab_only) # ty: ignore[unresolved-attribute] + + if not self.mtp_only or not from_dir: + return + + output_type: str = self.ftype.name.partition("_")[2] # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute] + fname_default: str = gguf.naming_convention( + self.metadata.name, self.metadata.basename, self.metadata.finetune, # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute] + self.metadata.version, size_label=None, output_type=output_type, model_type=None) # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute] + self.fname_out = self.fname_out.parent / f"mtp-{fname_default}.gguf" + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if name.startswith("mtp."): + n_layer = self.hparams["num_hidden_layers"] + if name.find("layers.") != -1: + assert bid is not None + name = name.replace(f"mtp.layers.{bid}", f"model.layers.{bid + n_layer}") + else: + remapper = { + "mtp.fc": "model.layers.{bid}.eh_proj", + "mtp.pre_fc_norm_embedding": "model.layers.{bid}.enorm", + "mtp.pre_fc_norm_hidden": "model.layers.{bid}.hnorm", + "mtp.norm": "model.layers.{bid}.shared_head.norm", + } + stem = Path(name).stem + suffix = Path(name).suffix + tmpl = remapper[stem] + suffix + for b in range(n_layer, self.block_count): + yield from super().modify_tensors(data_torch, tmpl.format(bid=b), b) # ty: ignore[unresolved-attribute] + return + + yield from super().modify_tensors(data_torch, name, bid) # ty: ignore[unresolved-attribute] + + 
@ModelBase.register("Qwen3_5ForConditionalGeneration", "Qwen3_5ForCausalLM") -class Qwen3_5TextModel(_Qwen35MRopeMixin, _LinearAttentionVReorderBase): +class Qwen3_5TextModel(_Qwen35MtpMixin, _Qwen35MRopeMixin, _LinearAttentionVReorderBase): model_arch = gguf.MODEL_ARCH.QWEN35 @ModelBase.register("Qwen3_5MoeForConditionalGeneration", "Qwen3_5MoeForCausalLM") -class Qwen3_5MoeTextModel(_Qwen35MRopeMixin, _LinearAttentionVReorderBase): +class Qwen3_5MoeTextModel(_Qwen35MtpMixin, _Qwen35MRopeMixin, _LinearAttentionVReorderBase): model_arch = gguf.MODEL_ARCH.QWEN35MOE diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 7173f616009..ff840050861 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -117,6 +117,14 @@ def parse_args() -> argparse.Namespace: "--mmproj", action="store_true", help="(Experimental) Export multimodal projector (mmproj) for vision models. This will only work on some vision models. A prefix 'mmproj-' will be added to the output file name.", ) + parser.add_argument( + "--mtp", action="store_true", + help="(Experimental) Export only the multi-token prediction (MTP) head as a separate GGUF, suitable for use as a speculative draft. Output file name will get a '-MTP' suffix.", + ) + parser.add_argument( + "--no-mtp", action="store_true", + help="(Experimental) Exclude the multi-token prediction (MTP) head from the converted GGUF. Pair with --mtp on a second run to publish trunk and MTP as two files. 
Note: the split form duplicates embeddings, so the bundled default is more space-efficient overall.", + ) parser.add_argument( "--mistral-format", action="store_true", help="Whether the model is stored following the Mistral format.", @@ -233,6 +241,20 @@ def main() -> None: from conversion.mistral import MistralModel model_class = MistralModel + if args.mtp and args.no_mtp: + logger.error("--mtp and --no-mtp are mutually exclusive") + sys.exit(1) + + if args.mtp or args.no_mtp: + from conversion.qwen import _Qwen35MtpMixin + if not issubclass(model_class, _Qwen35MtpMixin): + logger.error("--mtp / --no-mtp are only supported for Qwen3.5/3.6 text variants today") + sys.exit(1) + if args.no_mtp: + model_class.no_mtp = True + if args.mtp: + model_class.mtp_only = True + model_instance = model_class(dir_model, output_type, fname_out, is_big_endian=args.bigendian, use_temp_file=args.use_temp_file, eager=args.no_lazy, diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 3357a0d9985..41566d41aef 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -2541,6 +2541,11 @@ extern "C" { // TODO: add ggml_gated_delta_net_set_bcast() to be able to configure Q, K broadcast type: tiled vs interleaved [TAG_GGML_GDN_BCAST] // ref: https://github.com/ggml-org/llama.cpp/pull/19468#discussion_r2786394306 + // + // state is a 3D tensor of shape (S_v*S_v*H, K, n_seqs): + // K == 1: output carries the final state only. 
+ // K > 1: output carries K snapshot slots; the kernel writes the last min(n_tokens, K) + // per-token snapshots into the trailing slots GGML_API struct ggml_tensor * ggml_gated_delta_net( struct ggml_context * ctx, struct ggml_tensor * q, diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index c0ffd9a048b..df0f405ed9f 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -753,7 +753,9 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(co GGML_ASSERT(src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_1); GGML_ASSERT(src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_1); GGML_ASSERT(src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_1); - GGML_ASSERT(src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_2); + // state shape is (S_v*S_v*H, K, n_seqs); the heads dim is nested inside axis 0, + // so a head-aligned split on the input cache reshapes to axis 0 here (not axis 2). + GGML_ASSERT(src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_2 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_1 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_0); return {GGML_BACKEND_SPLIT_AXIS_0, {0}, 1}; }; @@ -2140,4 +2142,3 @@ ggml_backend_t ggml_backend_meta_simple_backend(ggml_backend_t meta_backend, siz const ggml_backend_meta_context * backend_ctx = (const ggml_backend_meta_context *) meta_backend->context; return backend_ctx->backend_configs[index].backend; } - diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 7b05edf6b75..cd5c61a8187 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2943,7 +2943,9 @@ struct ggml_cplan ggml_graph_plan( case GGML_OP_GATED_DELTA_NET: { const int64_t S_v = node->src[2]->ne[0]; - cur = S_v * sizeof(float) * n_tasks; + const int64_t K = node->src[5]->ne[1]; // state is (D, K, n_seqs) + const int64_t per_thread = S_v + (K > 1 ? 
S_v * S_v : 0); + cur = per_thread * sizeof(float) * n_tasks; } break; case GGML_OP_COUNT: { diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 6bc8dc150ce..7485ba4fc86 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -10513,19 +10513,30 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( const bool kda = (neg0 == S_v); - // scratch layout per thread: [delta(S_v)] - const int64_t scratch_per_thread = S_v; + // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. + const int64_t K = src_state->ne[1]; + GGML_ASSERT(K >= 1); + // per-seq stride in floats (slot 0 of seq s lives at state + s * seq_stride) + const int64_t state_seq_stride = src_state->nb[2] / sizeof(float); + + const int64_t per_thread = S_v + (K > 1 ? S_v * S_v : 0); const int ith = params->ith; - float * delta = (float *)params->wdata + ith * scratch_per_thread + CACHE_LINE_SIZE_F32; + float * delta = (float *)params->wdata + ith * per_thread + CACHE_LINE_SIZE_F32; + float * state_work = K > 1 ? (delta + S_v) : nullptr; // output layout: [attn_scores | new_states] - // attn_scores: S_v * H * n_tokens * n_seqs floats - // new_states: S_v * S_v * H * n_seqs floats - const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs; + // attn_scores: S_v * H * n_tokens * n_seqs floats + // new_states: S_v * S_v * H * n_seqs * K floats (K snapshot slots; last min(n_tokens, K)) + const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs; + const int64_t state_size_per_snap = S_v * S_v * H * n_seqs; float * attn_out_base = (float *)dst->data; float * state_out_base = (float *)dst->data + attn_score_elems; + // snapshot slot mapping: target_slot = t - shift. When n_tokens < K only the last + // n_tokens slots are written; earlier slots are left untouched (caller-owned). 
+ const int64_t shift = n_tokens - K; + const float * state_in_base = (const float *)src_state->data; //const int64_t rq1 = nev1 / neq1; @@ -10545,10 +10556,15 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( const int64_t iq3 = iv3 / rq3; const int64_t ik3 = iv3 / rk3; - float * s_out = state_out_base + (iv3 * H + iv1) * S_v * S_v; + // For K=1, write directly to the single output slot to avoid an extra memcpy at the end. + // For K>1, work in scratch and copy out per-token when the slot is in range. + float * s_out = (K > 1) + ? state_work + : state_out_base + (iv3 * H + iv1) * S_v * S_v; - // copy input state into output buffer and operate in-place - const float * s_in = state_in_base + (iv3 * H + iv1) * S_v * S_v; + // copy input state into the working buffer and operate in-place + // state layout (D, K, n_seqs): slot 0 of seq iv3 starts at iv3 * state_seq_stride. + const float * s_in = state_in_base + iv3 * state_seq_stride + iv1 * S_v * S_v; memcpy(s_out, s_in, S_v * S_v * sizeof(float)); // attn output pointer for first token of this (head, seq) @@ -10598,6 +10614,15 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( } attn_data += S_v * H; // advance to next token + + if (K > 1) { + const int64_t target_slot = t - shift; + if (target_slot >= 0 && target_slot < K) { + float * curr_state_o = state_out_base + target_slot * state_size_per_snap + + (iv3 * H + iv1) * S_v * S_v; + memcpy(curr_state_o, s_out, S_v * S_v * sizeof(float)); + } + } } } } diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu index 6b44bec7317..b4c9845e7a7 100644 --- a/ggml/src/ggml-cuda/gated_delta_net.cu +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -1,6 +1,6 @@ #include "gated_delta_net.cuh" -template +template __global__ void __launch_bounds__((ggml_cuda_get_physical_warp_size() < S_v ? 
ggml_cuda_get_physical_warp_size() : S_v) * 4, 2) gated_delta_net_cuda(const float * q, const float * k, @@ -23,7 +23,8 @@ gated_delta_net_cuda(const float * q, int64_t sb3, const uint3 neqk1_magic, const uint3 rq3_magic, - float scale) { + float scale, + int K) { const uint32_t h_idx = blockIdx.x; const uint32_t sequence = blockIdx.y; // each warp owns one column, using warp-level primitives to reduce across rows @@ -37,9 +38,13 @@ gated_delta_net_cuda(const float * q, float * attn_data = dst; float * state = dst + attn_score_elems; - const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v; - state += state_offset; - curr_state += state_offset + col * S_v; + // input state layout (D, K, n_seqs) — seq stride is K * D = K * H * S_v * S_v. + // output state layout (per-slot D * n_seqs) — same per-(seq,head) offset as before. + const int64_t state_in_offset = sequence * K * H * S_v * S_v + h_idx * S_v * S_v; + const int64_t state_out_offset = (sequence * H + h_idx) * S_v * S_v; + const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // per-slot stride in output + state += state_out_offset; + curr_state += state_in_offset + col * S_v; attn_data += (sequence * n_tokens * H + h_idx) * S_v; constexpr int warp_size = ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v; @@ -54,6 +59,10 @@ gated_delta_net_cuda(const float * q, s_shard[r] = curr_state[i]; } + // slot mapping: target_slot = t - shift. When n_tokens < K only the last n_tokens slots + // are written; earlier slots are left untouched (caller-owned). 
+ const int shift = (int) n_tokens - K; + for (int t = 0; t < n_tokens; t++) { const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1; const float * k_t = k + iq3 * sq3 + t * sq2 + iq1 * sq1; @@ -135,17 +144,30 @@ gated_delta_net_cuda(const float * q, } attn_data += S_v * H; + + if constexpr (keep_rs_t) { + const int target_slot = t - shift; + if (target_slot >= 0 && target_slot < K) { + float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset; +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + curr_state[col * S_v + i] = s_shard[r]; + } + } + } } - // Write state back to global memory (transposed layout) + if constexpr (!keep_rs_t) { #pragma unroll - for (int r = 0; r < rows_per_lane; r++) { - const int i = r * warp_size + lane; - state[col * S_v + i] = s_shard[r]; + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + state[col * S_v + i] = s_shard[r]; + } } } -template +template static void launch_gated_delta_net( const float * q_d, const float * k_d, const float * v_d, const float * g_d, const float * b_d, const float * s_d, @@ -155,7 +177,7 @@ static void launch_gated_delta_net( int64_t sv1, int64_t sv2, int64_t sv3, int64_t sb1, int64_t sb2, int64_t sb3, int64_t neqk1, int64_t rq3, - float scale, cudaStream_t stream) { + float scale, int K, cudaStream_t stream) { //TODO: Add chunked kernel for even faster pre-fill const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size; const int num_warps = 4; @@ -169,29 +191,29 @@ static void launch_gated_delta_net( switch (S_v) { case 16: - gated_delta_net_cuda<16, KDA><<>>( + gated_delta_net_cuda<16, KDA, keep_rs_t><<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); + sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); break; case 32: - gated_delta_net_cuda<32, KDA><<>>( + 
gated_delta_net_cuda<32, KDA, keep_rs_t><<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); + sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); break; case 64: { - gated_delta_net_cuda<64, KDA><<>>( + gated_delta_net_cuda<64, KDA, keep_rs_t><<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); + sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); break; } case 128: { - gated_delta_net_cuda<128, KDA><<>>( + gated_delta_net_cuda<128, KDA, keep_rs_t><<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); + sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); break; } default: @@ -261,13 +283,29 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * cudaStream_t stream = ctx.stream(); + // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. 
+ const int K = (int) src_state->ne[1]; + const bool keep_rs = K > 1; + if (kda) { - launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, - S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1, rq3, scale, stream); + if (keep_rs) { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } else { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } } else { - launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, - S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1, rq3, scale, stream); + if (keep_rs) { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } else { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } } } diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index f0147af84c1..e288a27f992 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -590,6 +590,8 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net( const int ne20 = op->src[2]->ne[0]; // S_v const int ne21 = op->src[2]->ne[1]; // H const int ne30 = op->src[3]->ne[0]; // G + // state is src[5], 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. 
+ const int K = op->src[5]->ne[1]; const int nsg = op->src[2]->ne[0]/32; @@ -598,7 +600,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net( GGML_ASSERT(ne20 % 32 == 0); snprintf(base, 256, "kernel_gated_delta_net_%s_%d", ggml_type_name(op->src[0]->type), nsg); - snprintf(name, 256, "%s_ne20=%d_ne30=%d", base, ne20, ne30); + snprintf(name, 256, "%s_ne20=%d_ne30=%d_K=%d", base, ne20, ne30, K); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { @@ -606,6 +608,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net( ggml_metal_cv_set_int16(cv, ne20, FC_GATED_DELTA_NET + 0); ggml_metal_cv_set_int16(cv, ne30, FC_GATED_DELTA_NET + 1); + ggml_metal_cv_set_int16(cv, K, FC_GATED_DELTA_NET + 2); res = ggml_metal_library_compile_pipeline(lib, base, name, cv); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 3882b955847..82e29d5ad7c 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2531,6 +2531,7 @@ kernel void kernel_rwkv_wkv7_f32( constant short FC_gated_delta_net_ne20 [[function_constant(FC_GATED_DELTA_NET + 0)]]; constant short FC_gated_delta_net_ne30 [[function_constant(FC_GATED_DELTA_NET + 1)]]; +constant short FC_gated_delta_net_K [[function_constant(FC_GATED_DELTA_NET + 2)]]; #if 1 template @@ -2548,21 +2549,24 @@ kernel void kernel_gated_delta_net_impl( uint3 ntg[[threads_per_threadgroup]]) { #define S_v FC_gated_delta_net_ne20 #define G FC_gated_delta_net_ne30 +#define K FC_gated_delta_net_K const uint tx = tpitg.x; const uint ty = tpitg.y; - const uint i23 = tgpig.z; // B - const uint i21 = tgpig.y; // H - const uint i20 = tgpig.x*NSG + ty; + const uint i23 = tgpig.z; // B (n_seqs) + const uint i21 = tgpig.y; // H (head) + const uint i20 = tgpig.x*NSG + ty; // row within S_v const uint i01 = i21 % args.ne01; const uint i11 = i21 % args.ne11; const float scale = 
1.0f / sqrt((float)S_v); + // input state layout (D, K, n_seqs): per-seq stride is K*H*D; we read slot 0. // state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous - device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v; + const uint state_in_base = (i23*K*args.ne21 + i21)*S_v*S_v + i20*S_v; + device const float * s_ptr = (device const float *) (s) + state_in_base; float ls[NSG]; @@ -2580,6 +2584,17 @@ kernel void kernel_gated_delta_net_impl( device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21); device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G; + // snapshot slot mapping: target_slot = t - shift. When n_tokens < K, only the last + // n_tokens slots are written; earlier slots are left untouched (caller-owned). + const int shift = (int)args.ne22 - (int)K; + + // output state base offset: after attention scores + const uint attn_size = args.ne22 * args.ne21 * S_v * args.ne23; + // output state per-slot size: S_v * S_v * H * n_seqs + const uint state_size_per_snap = S_v * S_v * args.ne21 * args.ne23; + // per-(seq,head) offset within a slot + const uint state_out_base = (i23*args.ne21 + i21)*S_v*S_v + i20*S_v; + for (short t = 0; t < args.ne22; t++) { float s_k = 0.0f; @@ -2627,17 +2642,30 @@ kernel void kernel_gated_delta_net_impl( b_ptr += args.ne21; g_ptr += args.ne21*G; - } - device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v; + if (K > 1u) { + const int target_slot = (int)t - shift; + if (target_slot >= 0 && target_slot < (int)K) { + device float * dst_state = (device float *) (dst) + attn_size + (uint)target_slot * state_size_per_snap + state_out_base; + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + dst_state[is] = ls[j]; + } + } + } + } - FOR_UNROLL (short j = 0; j < NSG; j++) { - const short is = tx*NSG + j; - 
dst_state[is] = ls[j]; + if (K == 1u) { + device float * dst_state = (device float *) (dst) + attn_size + state_out_base; + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + dst_state[is] = ls[j]; + } } #undef S_v #undef G +#undef K } typedef decltype(kernel_gated_delta_net_impl<4>) kernel_gated_delta_net_t; diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 8c4cf9ef1db..d29a4bab2e2 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1506,6 +1506,7 @@ struct vk_op_gated_delta_net_push_constants { uint32_t sb1, sb2, sb3; uint32_t neq1, rq3; float scale; + uint32_t K; }; struct vk_op_ssm_scan_push_constants { @@ -10767,6 +10768,7 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s const ggml_tensor * src_q = dst->src[0]; const ggml_tensor * src_v = dst->src[2]; const ggml_tensor * src_beta = dst->src[4]; + const ggml_tensor * src_state = dst->src[5]; GGML_ASSERT(dst->buffer != nullptr); @@ -10775,6 +10777,9 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s const uint32_t n_tokens = (uint32_t)src_v->ne[2]; const uint32_t n_seqs = (uint32_t)src_v->ne[3]; + // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. 
+ const uint32_t K = (uint32_t)src_state->ne[1]; + const uint32_t s_off = S_v * H * n_tokens * n_seqs; vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op); @@ -10808,7 +10813,8 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s sv1, sv2, sv3, sb1, sb2, sb3, neq1, rq3, - scale + scale, + K }; ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp index 5e9f8308c1d..33c3202dbb7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp @@ -31,6 +31,7 @@ layout(push_constant) uniform Parameters { uint sb1, sb2, sb3; uint neq1, rq3; float scale; + uint K; }; layout(binding = 0) readonly buffer QBuf { FLOAT_TYPE data_q[]; }; @@ -101,13 +102,21 @@ void main() { const uint iq3 = seq_id / rq3; const uint state_size = S_V * S_V; - const uint state_base = (seq_id * H + head_id) * state_size; + // input state layout (D, K, n_seqs): per-seq stride is K*H*D; we read slot 0. + const uint state_in_base = (seq_id * K * H + head_id) * state_size; + // output state layout per slot: same per-(seq,head) offset as the single-slot case. + const uint state_out_base = (seq_id * H + head_id) * state_size; + const uint state_size_per_snap = state_size * H * n_seqs; FLOAT_TYPE s_shard[ROWS_PER_LANE]; [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { - s_shard[r] = FLOAT_TYPE(data_state[state_base + col * S_V + r * LANES_PER_COLUMN + lane]); + s_shard[r] = FLOAT_TYPE(data_state[state_in_base + col * S_V + r * LANES_PER_COLUMN + lane]); } + // snapshot slot mapping: target_slot = t - shift. When n_tokens < K, only the last + // n_tokens slots are written; earlier slots are left untouched (caller-owned). 
+ const int shift = int(n_tokens) - int(K); + uint attn_off = (seq_id * n_tokens * H + head_id) * S_V; for (uint t = 0; t < n_tokens; t++) { @@ -161,9 +170,21 @@ void main() { } attn_off += S_V * H; + + if (K > 1u) { + const int target_slot = int(t) - shift; + if (target_slot >= 0 && target_slot < int(K)) { + const uint slot_base = s_off + uint(target_slot) * state_size_per_snap + state_out_base; + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + data_dst[slot_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r]; + } + } + } } - [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { - data_dst[s_off + state_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r]; + if (K == 1u) { + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + data_dst[s_off + state_out_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r]; + } } } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 191cf2fa106..476c3079795 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -6210,11 +6210,13 @@ struct ggml_tensor * ggml_gated_delta_net( GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v); GGML_ASSERT(beta->ne[0] == 1); - GGML_ASSERT(ggml_nelements(state) == S_v * S_v * H * n_seqs); - - // concat output and new_state into a single tensor - // output: S_v * H * n_tokens * n_seqs, state: S_v * S_v * H * n_seqs - const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + S_v * n_seqs, 1, 1 }; + // state is a 3D tensor (S_v*S_v*H, K, n_seqs). K is the snapshot slot count. 
+ GGML_ASSERT(state->ne[0] == S_v * S_v * H); + GGML_ASSERT(state->ne[2] == n_seqs); + GGML_ASSERT(state->ne[3] == 1); + const int64_t K = state->ne[1]; + const int64_t state_rows = K * S_v * n_seqs; + const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + state_rows, 1, 1 }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); result->op = GGML_OP_GATED_DELTA_NET; diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 4055ec2873a..c25f217f990 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -2114,7 +2114,14 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_NORM, MODEL_TENSOR.SSM_BETA, MODEL_TENSOR.SSM_ALPHA, - MODEL_TENSOR.SSM_OUT + MODEL_TENSOR.SSM_OUT, + # NextN/MTP tensors - preserved but unused + MODEL_TENSOR.NEXTN_EH_PROJ, + MODEL_TENSOR.NEXTN_EMBED_TOKENS, + MODEL_TENSOR.NEXTN_ENORM, + MODEL_TENSOR.NEXTN_HNORM, + MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD, + MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM, ], MODEL_ARCH.QWEN35MOE: [ MODEL_TENSOR.TOKEN_EMBD, @@ -2145,7 +2152,14 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_NORM, MODEL_TENSOR.SSM_BETA, MODEL_TENSOR.SSM_ALPHA, - MODEL_TENSOR.SSM_OUT + MODEL_TENSOR.SSM_OUT, + # NextN/MTP tensors - preserved but unused + MODEL_TENSOR.NEXTN_EH_PROJ, + MODEL_TENSOR.NEXTN_EMBED_TOKENS, + MODEL_TENSOR.NEXTN_ENORM, + MODEL_TENSOR.NEXTN_HNORM, + MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD, + MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM, ], MODEL_ARCH.PLAMO: [ MODEL_TENSOR.TOKEN_EMBD, diff --git a/include/llama.h b/include/llama.h index 308e8ba9dbd..75095b22d08 100644 --- a/include/llama.h +++ b/include/llama.h @@ -198,6 +198,11 @@ extern "C" { LLAMA_SPLIT_MODE_TENSOR = 3, }; + enum llama_context_type { + LLAMA_CONTEXT_TYPE_DEFAULT = 0, + LLAMA_CONTEXT_TYPE_MTP = 1, + }; + // TODO: simplify (https://github.com/ggml-org/llama.cpp/pull/9294#pullrequestreview-2286561979) typedef struct llama_token_data { llama_token id; // token id @@ -333,9 +338,11 @@ extern "C" { uint32_t n_batch; // logical 
maximum batch size that can be submitted to llama_decode uint32_t n_ubatch; // physical maximum batch size uint32_t n_seq_max; // max number of sequences (i.e. distinct states for recurrent models) + uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback (0 = no rollback) [EXPERIMENTAL] int32_t n_threads; // number of threads to use for generation int32_t n_threads_batch; // number of threads to use for batch processing + enum llama_context_type ctx_type; // set the context type (e.g. MTP) enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type` enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id enum llama_attention_type attention_type; // attention type to use for embeddings @@ -530,6 +537,7 @@ extern "C" { LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx); + LLAMA_API uint32_t llama_n_rs_seq (const struct llama_context * ctx); DEPRECATED(LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model), "use llama_model_n_ctx_train instead"); DEPRECATED(LLAMA_API int32_t llama_n_embd (const struct llama_model * model), "use llama_model_n_embd instead"); diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 59dde99e362..c9eead18aa3 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -757,14 +757,15 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - // NextN/MTP tensors are currently ignored (reserved for future MTP support) - // These tensors only exist in the last layer(s) and are treated as output tensors - {LLM_TENSOR_NEXTN_EH_PROJ, 
{LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, - {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + // NextN/MTP tensors are stored per-block (blk.%d.nextn.*) even though only the + // last nextn_predict_layers blocks carry them. Classify as LAYER_REPEATING so + // the model loader doesn't fault on the block index. + {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // Nemotron 3 Super {LLM_TENSOR_FFN_LATENT_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_FFN_LATENT_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, @@ -877,6 +878,16 @@ bool llm_arch_is_diffusion(const llm_arch & arch) { } } +bool llm_arch_supports_rs_rollback(const llm_arch & arch) { + switch (arch) { + case LLM_ARCH_QWEN35: + case LLM_ARCH_QWEN35MOE: + return true; + default: + return false; + } +} + bool llm_arch_supports_sm_tensor(const llm_arch & arch) { switch (arch) { case LLM_ARCH_GROK: diff --git a/src/llama-arch.h b/src/llama-arch.h index e37d548c98e..89cf16cc37c 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -637,3 +637,4 @@ bool llm_arch_is_recurrent (const llm_arch & arch); bool llm_arch_is_hybrid (const llm_arch & arch); bool llm_arch_is_diffusion (const llm_arch & arch); bool llm_arch_supports_sm_tensor(const llm_arch & arch); 
+bool llm_arch_supports_rs_rollback(const llm_arch & arch); diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 3d9714ab166..d62abc4009b 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -2,6 +2,7 @@ #include "ggml.h" #include "llama-arch.h" +#include "llama-graph.h" #include "llama-impl.h" #include "llama-batch.h" #include "llama-io.h" @@ -21,6 +22,14 @@ // llama_context // +static llm_graph_type ctx_type_to_graph_type(llama_context_type ctx_type) { + switch (ctx_type) { + case LLAMA_CONTEXT_TYPE_DEFAULT: return LLM_GRAPH_TYPE_DEFAULT; + case LLAMA_CONTEXT_TYPE_MTP : return LLM_GRAPH_TYPE_DECODER_MTP; + } + throw std::runtime_error("Unsupported ctx type"); +} + llama_context::llama_context( const llama_model & model, llama_context_params params) : @@ -42,6 +51,13 @@ llama_context::llama_context( throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ)); } + cparams.n_rs_seq = params.n_rs_seq; + if (cparams.n_rs_seq > 0 && !llm_arch_supports_rs_rollback(model.arch)) { + LLAMA_LOG_DEBUG("%s: n_rs_seq=%u requested but model arch does not support recurrent partial rollback; clamping to 0\n", + __func__, cparams.n_rs_seq); + cparams.n_rs_seq = 0; + } + cparams.n_threads = params.n_threads; cparams.n_threads_batch = params.n_threads_batch; cparams.yarn_ext_factor = params.yarn_ext_factor >= 0.0f ? params.yarn_ext_factor : hparams.yarn_ext_factor; @@ -49,6 +65,7 @@ llama_context::llama_context( cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast; cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? 
params.yarn_beta_slow : hparams.yarn_beta_slow; cparams.embeddings = params.embeddings; + cparams.embeddings_pre_norm = false; cparams.offload_kqv = params.offload_kqv; cparams.no_perf = params.no_perf; cparams.pooling_type = params.pooling_type; @@ -65,6 +82,8 @@ llama_context::llama_context( cparams.cb_eval = params.cb_eval; cparams.cb_eval_user_data = params.cb_eval_user_data; + cparams.ctx_type = params.ctx_type; + // Initialize backend samplers here so they are part of the sampling graph // before the reserve passes run later in this function. This avoids a later // re-reserve when graph nodes change. @@ -206,6 +225,7 @@ llama_context::llama_context( LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? "true" : "false"); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); + LLAMA_LOG_INFO("%s: n_rs_seq = %u\n", __func__, cparams.n_rs_seq); if (cparams.n_ctx_seq < hparams.n_ctx_train) { LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n", @@ -278,6 +298,7 @@ llama_context::llama_context( /*.type_k =*/ params.type_k, /*.type_v =*/ params.type_v, /*.swa_full =*/ params.swa_full, + /*.ctx_type= */ cparams.ctx_type, }; memory.reset(model.create_memory(params_mem, cparams)); @@ -860,6 +881,33 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) { return it->second.data(); } +float * llama_context::get_embeddings_pre_norm() { + output_reorder(); + + return embd_pre_norm.data; +} + +float * llama_context::get_embeddings_pre_norm_ith(int32_t i) { + output_reorder(); + + try { + if (embd_pre_norm.data == nullptr) { + throw std::runtime_error("no pre-norm embeddings"); + } + + const int64_t j = output_resolve_row(i); + const uint32_t n_embd = model.hparams.n_embd; + return embd_pre_norm.data + j*n_embd; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid pre-norm 
embeddings id %d, reason: %s\n", __func__, i, err.what()); +#ifndef NDEBUG + GGML_ABORT("fatal error"); +#else + return nullptr; +#endif + } +} + llama_token llama_context::get_sampled_token_ith(int32_t idx) { output_reorder(); @@ -1040,6 +1088,12 @@ void llama_context::set_embeddings(bool value) { //sched_need_reserve = true; } +void llama_context::set_embeddings_pre_norm(bool value) { + LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); + + cparams.embeddings_pre_norm = value; +} + void llama_context::set_causal_attn(bool value) { LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); @@ -1241,7 +1295,9 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll } int llama_context::encode(const llama_batch & batch_inp) { - GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT + // MTP hook batches carry both token (next-token id) and embd (h_pre_norm row), + // so accept either present rather than requiring exactly one. + GGML_ASSERT(batch_inp.token || batch_inp.embd); if (batch_inp.n_tokens == 0) { LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); @@ -1312,8 +1368,9 @@ int llama_context::encode(const llama_batch & batch_inp) { } } - auto * t_logits = res->get_logits(); - auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); + auto * t_logits = res->get_logits(); + auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); + auto * t_h_pre_norm = cparams.embeddings_pre_norm ? 
res->get_h_pre_norm() : nullptr; // extract logits if (logits.data && t_logits) { @@ -1379,6 +1436,16 @@ int llama_context::encode(const llama_batch & batch_inp) { } } + // extract pre-norm embeddings (hidden state before the final output norm) + if (embd_pre_norm.data && t_h_pre_norm && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { + ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm); + GGML_ASSERT(backend_h != nullptr); + + const uint32_t n_embd = hparams.n_embd; + GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_pre_norm.size); + ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm.data, 0, n_tokens*n_embd*sizeof(float)); + } + // TODO: hacky solution if (model.arch == LLM_ARCH_T5 && t_embd) { //cross.t_embd = t_embd; @@ -1531,7 +1598,9 @@ static bool needs_raw_logits(const llama_ubatch & ubatch, const std::map remove all positions of that ubatch from the memory module @@ -1727,8 +1797,9 @@ int llama_context::decode(const llama_batch & batch_inp) { // ggml_graph_dump_dot(gf, NULL, "llama.dot"); //} - auto * t_logits = res->get_logits(); - auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr; + auto * t_logits = res->get_logits(); + auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr; + auto * t_h_pre_norm = cparams.embeddings_pre_norm ? res->get_h_pre_norm() : nullptr; if (t_embd && res->get_embd_pooled()) { t_embd = res->get_embd_pooled(); @@ -1809,6 +1880,20 @@ int llama_context::decode(const llama_batch & batch_inp) { } } + // extract pre-norm embeddings (hidden state before the final output norm) + // only meaningful in LLAMA_POOLING_TYPE_NONE (per-token); other pooling modes are ignored. 
+ if (embd_pre_norm.data && t_h_pre_norm && n_outputs > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { + ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm); + GGML_ASSERT(backend_h != nullptr); + + const uint32_t n_embd = hparams.n_embd; + float * embd_pre_norm_out = embd_pre_norm.data + n_outputs_prev*n_embd; + + GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); + GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_pre_norm.size); + ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm_out, 0, n_outputs*n_embd*sizeof(float)); + } + // Copy backend sampling output if this ubatch produced any sampling tensors. if (has_samplers && (!res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty())) { const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev); @@ -1893,10 +1978,12 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { const auto n_batch = cparams.n_batch; const auto n_vocab = vocab.n_tokens(); + const auto n_embd = hparams.n_embd; const auto n_embd_out = hparams.n_embd_out(); - bool has_logits = true; - bool has_embd = cparams.embeddings; + bool has_logits = true; + bool has_embd = cparams.embeddings; + bool has_embd_pre_norm = cparams.embeddings_pre_norm; // TODO: hacky enc-dec support if (model.arch == LLM_ARCH_T5) { @@ -1908,8 +1995,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { size_t backend_float_count = 0; size_t backend_token_count = 0; - logits.size = has_logits ? n_vocab*n_outputs_max : 0; - embd.size = has_embd ? n_embd_out*n_outputs_max : 0; + logits.size = has_logits ? n_vocab*n_outputs_max : 0; + embd.size = has_embd ? n_embd_out*n_outputs_max : 0; + embd_pre_norm.size = has_embd_pre_norm ? n_embd*n_outputs_max : 0; // Allocate backend sampling output buffers if there are backend samplers configured. 
const bool has_sampling = !sampling.samplers.empty(); @@ -1925,8 +2013,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0; const size_t new_size = - (logits.size + embd.size + backend_float_count) * sizeof(float) + - ( backend_token_count) * sizeof(llama_token); + (logits.size + embd.size + embd_pre_norm.size + backend_float_count) * sizeof(float) + + ( backend_token_count) * sizeof(llama_token); // alloc only when more than the current capacity is required // TODO: also consider shrinking the buffer @@ -1942,6 +2030,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { buf_output = nullptr; logits.data = nullptr; embd.data = nullptr; + embd_pre_norm.data = nullptr; } auto * buft = ggml_backend_cpu_buffer_type(); @@ -1970,6 +2059,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { embd = has_embd ? buffer_view{(float *) (base + offset), embd.size} : buffer_view{nullptr, 0}; offset += embd.size * sizeof(float); + embd_pre_norm = has_embd_pre_norm ? 
buffer_view{(float *) (base + offset), embd_pre_norm.size} : buffer_view{nullptr, 0}; + offset += embd_pre_norm.size * sizeof(float); + if (has_sampling) { sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)}; offset += sampling.logits.size * sizeof(float); @@ -2034,6 +2126,12 @@ void llama_context::output_reorder() { } } + if (embd_pre_norm.size > 0) { + for (uint64_t k = 0; k < n_embd; k++) { + std::swap(embd_pre_norm.data[i0*n_embd + k], embd_pre_norm.data[i1*n_embd + k]); + } + } + if (!sampling.samplers.empty()) { assert(sampling.logits.size > 0); assert(sampling.probs.size > 0); @@ -2121,7 +2219,7 @@ ggml_cgraph * llama_context::graph_reserve( auto * res = gf_res_reserve.get(); - const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT); + const auto gparams = graph_params(res, ubatch, mctx, ctx_type_to_graph_type(cparams.ctx_type)); res->reset(); @@ -3100,7 +3198,7 @@ void llama_context::opt_epoch_iter( auto * res = gf_res_prev.get(); - const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT); + const auto gparams = graph_params(res, ubatch, mctx.get(), ctx_type_to_graph_type(cparams.ctx_type)); res->reset(); @@ -3201,8 +3299,10 @@ llama_context_params llama_context_default_params() { /*.n_batch =*/ 2048, /*.n_ubatch =*/ 512, /*.n_seq_max =*/ 1, + /*.n_rs_seq =*/ 0, /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS, + /*.ctx_type =*/ LLAMA_CONTEXT_TYPE_DEFAULT, /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED, /*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED, /*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED, @@ -3306,6 +3406,13 @@ llama_context * llama_init_from_model( model->hparams.pooling_type, params.pooling_type); } + if (params.ctx_type == LLAMA_CONTEXT_TYPE_MTP && + model->hparams.nextn_predict_layers == 0) { + LLAMA_LOG_WARN("%s: context type MTP requested but model doesn't contain MTP layers\n", 
__func__); + return nullptr; + } + + try { auto * ctx = new llama_context(*model, params); return ctx; @@ -3347,6 +3454,10 @@ uint32_t llama_n_seq_max(const llama_context * ctx) { return ctx->n_seq_max(); } +uint32_t llama_n_rs_seq(const llama_context * ctx) { + return ctx->get_cparams().n_rs_seq; +} + const llama_model * llama_get_model(const llama_context * ctx) { return &ctx->get_model(); } @@ -3436,6 +3547,22 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) { return ctx->get_embeddings_seq(seq_id); } +void llama_set_embeddings_pre_norm(llama_context * ctx, bool value) { + ctx->set_embeddings_pre_norm(value); +} + +float * llama_get_embeddings_pre_norm(llama_context * ctx) { + ctx->synchronize(); + + return ctx->get_embeddings_pre_norm(); +} + +float * llama_get_embeddings_pre_norm_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return ctx->get_embeddings_pre_norm_ith(i); +} + bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) { return ctx->set_sampler(seq_id, smpl); } diff --git a/src/llama-context.h b/src/llama-context.h index 92d1b0cf95a..e16ac4c618b 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -84,6 +84,9 @@ struct llama_context { float * get_embeddings_ith(int32_t i); float * get_embeddings_seq(llama_seq_id seq_id); + float * get_embeddings_pre_norm(); + float * get_embeddings_pre_norm_ith(int32_t i); + llama_token * get_sampled_tokens() const; llama_token get_sampled_token_ith(int32_t idx); @@ -107,6 +110,7 @@ struct llama_context { void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data); void set_embeddings (bool value); + void set_embeddings_pre_norm(bool value); void set_causal_attn(bool value); void set_warmup(bool value); @@ -278,6 +282,11 @@ struct llama_context { // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE buffer_view embd = {nullptr, 0}; + // hidden state before the final output norm (2-dimensional 
array: [n_outputs][n_embd]) + // populated only when cparams.embeddings_pre_norm is enabled and the model graph + // sets llm_graph_result::t_h_pre_norm + buffer_view embd_pre_norm = {nullptr, 0}; + struct sampling_info { // !samplers.empty() to check if any samplers are active std::map samplers; diff --git a/src/llama-cparams.h b/src/llama-cparams.h index 9d359474132..5898a1c38d5 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -12,6 +12,7 @@ struct llama_cparams { uint32_t n_batch; uint32_t n_ubatch; uint32_t n_seq_max; + uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback int32_t n_threads; // number of threads to use for generation int32_t n_threads_batch; // number of threads to use for batch processing @@ -27,6 +28,7 @@ struct llama_cparams { float yarn_beta_slow; bool embeddings; + bool embeddings_pre_norm; // also extract the hidden state before the final output norm bool causal_attn; bool offload_kqv; bool flash_attn; @@ -40,6 +42,7 @@ struct llama_cparams { bool kv_unified; bool pipeline_parallel; + enum llama_context_type ctx_type; enum llama_pooling_type pooling_type; ggml_backend_sched_eval_callback cb_eval; diff --git a/src/llama-ext.h b/src/llama-ext.h index 8ce29d217cb..11f1986676a 100644 --- a/src/llama-ext.h +++ b/src/llama-ext.h @@ -88,3 +88,19 @@ LLAMA_API int32_t llama_model_n_devices(const struct llama_model * model); LLAMA_API ggml_backend_dev_t llama_model_get_device(const struct llama_model * model, int i); LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_context * ctx); + +// +// pre-norm embeddings (hidden state before the final output norm) +// + +// mirrors: +// LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings); +LLAMA_API void llama_set_embeddings_pre_norm(struct llama_context * ctx, bool value); + +// mirrors: +// LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); +LLAMA_API float * llama_get_embeddings_pre_norm(struct 
llama_context * ctx); + +// mirrors: +// LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i); +LLAMA_API float * llama_get_embeddings_pre_norm_ith(struct llama_context * ctx, int32_t i); diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index fe155c92dea..858c297dd76 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -2528,7 +2528,8 @@ ggml_tensor * llm_graph_context::build_rs( int32_t rs_zero, const llm_graph_get_rows_fn & get_state_rows) const { - ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, rs_size); + GGML_UNUSED(rs_size); + ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, s->ne[1]); // Clear a single state which will then be copied to the other cleared states. // Note that this is a no-op when the view is zero-sized. diff --git a/src/llama-graph.h b/src/llama-graph.h index 5cb1756c6a9..9e55d0a675e 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -32,6 +32,7 @@ enum llm_graph_type { LLM_GRAPH_TYPE_DEFAULT, LLM_GRAPH_TYPE_ENCODER, LLM_GRAPH_TYPE_DECODER, + LLM_GRAPH_TYPE_DECODER_MTP, }; enum llm_ffn_op_type { @@ -644,6 +645,7 @@ class llm_graph_result { ggml_tensor * get_logits() const { return t_logits; } ggml_tensor * get_embd() const { return t_embd; } ggml_tensor * get_embd_pooled() const { return t_embd_pooled; } + ggml_tensor * get_h_pre_norm() const { return t_h_pre_norm; } ggml_cgraph * get_gf() const { return gf; } ggml_context * get_ctx() const { return ctx_compute.get(); } @@ -672,6 +674,7 @@ class llm_graph_result { ggml_tensor * t_logits = nullptr; ggml_tensor * t_embd = nullptr; ggml_tensor * t_embd_pooled = nullptr; + ggml_tensor * t_h_pre_norm = nullptr; // [n_embd, n_outputs] hidden state before final output norm std::map t_sampled_logits; std::map t_candidates; diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 002d15d415f..2239309c8fb 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -229,6 +229,12 @@ uint32_t 
llama_hparams::n_embd_head_v_mla() const { } bool llama_hparams::has_kv(uint32_t il) const { + if (kv_only_nextn) { + // MTP head: only the trailing nextn_predict_layers blocks own a KV cache; + // the leading trunk blocks are not executed in this graph. + return nextn_predict_layers > 0 && il >= (n_layer - nextn_predict_layers); + } + if (n_layer_kv_from_start >= 0) { if (il < (uint32_t) n_layer_kv_from_start) { return true; diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 0160a89caa2..e2d051edc6c 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -92,6 +92,8 @@ struct llama_hparams { uint32_t moe_latent_size = 0; uint32_t nextn_predict_layers = 0; + bool kv_only_nextn = false; // if true, only the last nextn_predict_layers blocks have a KV cache (MTP head arches) + float f_norm_eps; float f_norm_rms_eps; float f_norm_group_eps; diff --git a/src/llama-memory-hybrid-iswa.cpp b/src/llama-memory-hybrid-iswa.cpp index 10e6b459797..a59561ea54d 100644 --- a/src/llama-memory-hybrid-iswa.cpp +++ b/src/llama-memory-hybrid-iswa.cpp @@ -24,6 +24,7 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa( uint32_t rs_size, /* common */ uint32_t n_seq_max, + uint32_t n_rs_seq, bool offload, bool unified, /* layer filters */ @@ -54,6 +55,7 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa( offload, rs_size, n_seq_max, + n_rs_seq, filter_recr == nullptr ? 
[&](int32_t il) { return hparams.is_recurrent(il); } : filter_recr diff --git a/src/llama-memory-hybrid-iswa.h b/src/llama-memory-hybrid-iswa.h index 807c8aac96c..c9d3f9f57c5 100644 --- a/src/llama-memory-hybrid-iswa.h +++ b/src/llama-memory-hybrid-iswa.h @@ -34,6 +34,7 @@ class llama_memory_hybrid_iswa : public llama_memory_i { uint32_t rs_size, /* common */ uint32_t n_seq_max, + uint32_t n_rs_seq, bool offload, bool unified, /* layer filters */ diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp index 4ce1af592c1..fd305cab79c 100644 --- a/src/llama-memory-hybrid.cpp +++ b/src/llama-memory-hybrid.cpp @@ -24,6 +24,7 @@ llama_memory_hybrid::llama_memory_hybrid( uint32_t rs_size, /* common */ uint32_t n_seq_max, + uint32_t n_rs_seq, bool offload, bool unified, /* layer filters */ @@ -54,6 +55,7 @@ llama_memory_hybrid::llama_memory_hybrid( offload, rs_size, n_seq_max, + n_rs_seq, filter_recr == nullptr ? [&](int32_t il) { return hparams.is_recurrent(il); } : filter_recr diff --git a/src/llama-memory-hybrid.h b/src/llama-memory-hybrid.h index 558cafdf984..484eafb7499 100644 --- a/src/llama-memory-hybrid.h +++ b/src/llama-memory-hybrid.h @@ -34,6 +34,7 @@ class llama_memory_hybrid : public llama_memory_i { uint32_t rs_size, /* common */ uint32_t n_seq_max, + uint32_t n_rs_seq, bool offload, bool unified, /* layer filters */ diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index c07f1d969cb..aeb866657f2 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -24,6 +24,7 @@ llama_memory_recurrent::llama_memory_recurrent( bool offload, uint32_t mem_size, uint32_t n_seq_max, + uint32_t n_rs_seq, const layer_filter_cb & filter) : hparams(model.hparams), n_seq_max(n_seq_max) { const int32_t n_layer = hparams.n_layer; @@ -31,6 +32,9 @@ llama_memory_recurrent::llama_memory_recurrent( size = mem_size; used = 0; + this->n_rs_seq = n_rs_seq; + rs_idx.assign(n_seq_max, 0); + cells.clear(); 
cells.resize(mem_size); @@ -92,8 +96,9 @@ llama_memory_recurrent::llama_memory_recurrent( throw std::runtime_error("failed to create ggml context for rs cache"); } - ggml_tensor * r = ggml_new_tensor_2d(ctx, type_r, hparams.n_embd_r(), mem_size); - ggml_tensor * s = ggml_new_tensor_2d(ctx, type_s, hparams.n_embd_s(), mem_size); + const uint32_t n_rows = mem_size * (1 + n_rs_seq); + ggml_tensor * r = ggml_new_tensor_2d(ctx, type_r, hparams.n_embd_r(), n_rows); + ggml_tensor * s = ggml_new_tensor_2d(ctx, type_s, hparams.n_embd_s(), n_rows); ggml_format_name(r, "cache_r_l%d", i); ggml_format_name(s, "cache_s_l%d", i); r_l[i] = r; @@ -115,8 +120,8 @@ llama_memory_recurrent::llama_memory_recurrent( const size_t memory_size_r = size_r_bytes(); const size_t memory_size_s = size_s_bytes(); - LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__, - (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f), mem_size, n_layer, n_seq_max, + LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs %2u rs_seq), R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__, + (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f), mem_size, n_layer, n_seq_max, n_rs_seq, ggml_type_name(type_r), (float)memory_size_r / (1024.0f * 1024.0f), ggml_type_name(type_s), (float)memory_size_s / (1024.0f * 1024.0f)); } @@ -138,10 +143,11 @@ void llama_memory_recurrent::clear(bool data) { ggml_backend_buffer_clear(buf.get(), 0); } } + + std::fill(rs_idx.begin(), rs_idx.end(), 0); } bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - //printf("[DEBUG] calling llama_memory_recurrent::seq_rm` with `seq_id=%d, p0=%d, p1=%d`\n", seq_id, p0, p1); uint32_t new_head = size; if (p0 < 0) { @@ -152,6 +158,15 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1 = std::numeric_limits::max(); } + const bool rm_all = p0 == 0 && p1 == std::numeric_limits::max(); 
+ if (rm_all) { + if (seq_id >= 0) { + set_rs_idx(seq_id, 0); + } else { + std::fill(rs_idx.begin(), rs_idx.end(), 0); + } + } + // models like Mamba or RWKV can't have a state partially erased at the end // of the sequence because their state isn't preserved for previous tokens if (seq_id >= (int64_t) size) { @@ -161,10 +176,16 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos if (0 <= seq_id) { int32_t & tail_id = cells[seq_id].tail; if (tail_id >= 0) { - const auto & cell = cells[tail_id]; - // partial intersection is invalid if it includes the final pos + auto & cell = cells[tail_id]; + + // partial rollback via per-token snapshot index (bounded by n_rs_seq) if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) { - //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false, p0 = %d, cell.pos = %d, p1 = %d\n", p0, cell.pos, p1); + const llama_pos rollback = cell.pos - (p0 - 1); + if (rollback >= 1 && rollback <= (llama_pos) n_rs_seq) { + set_rs_idx(seq_id, (uint32_t) rollback); + cell.pos = p0 - 1; + return true; + } return false; } // invalidate tails which will be cleared @@ -368,6 +389,13 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const { return result; } +void llama_memory_recurrent::set_rs_idx(llama_seq_id seq_id, uint32_t idx) { + if (seq_id < 0 || (size_t) seq_id >= rs_idx.size()) { + return; + } + rs_idx[seq_id] = (idx > n_rs_seq) ? 
n_rs_seq : idx; +} + std::map llama_memory_recurrent::memory_breakdown() const { std::map ret; for (const auto & [_, buf] : ctxs_bufs) { @@ -703,6 +731,7 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq GGML_UNUSED(flags); std::vector> cell_ranges; // ranges, from inclusive, to exclusive + std::vector> cell_ranges_data; // logical source row ranges uint32_t cell_count = 0; // Count the number of cells with the specified seq_id @@ -712,6 +741,35 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq const auto & cell = cells[i]; if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) { ++cell_count; + uint32_t rs_idx_cur = 0; + + if (n_rs_seq != 0) { + if (seq_id != -1) { + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < rs_idx.size()); + rs_idx_cur = rs_idx[seq_id]; + } else { + bool has_rs_idx = false; + for (const llama_seq_id cell_seq_id : cell.seq_id) { + GGML_ASSERT(cell_seq_id >= 0 && (size_t) cell_seq_id < rs_idx.size()); + + const uint32_t seq_rs_idx = rs_idx[cell_seq_id]; + if (!has_rs_idx) { + rs_idx_cur = seq_rs_idx; + has_rs_idx = true; + } else if (rs_idx_cur != seq_rs_idx) { + GGML_ABORT("cannot write shared recurrent state with different rollback indices"); + } + } + } + } + + const uint32_t cell_id = rs_idx_cur * size + (cell.src >= 0 ? 
cell.src : (int32_t) i); + if (cell_ranges_data.empty() || cell_ranges_data.back().second != cell_id) { + cell_ranges_data.emplace_back(cell_id, cell_id + 1); + } else { + cell_ranges_data.back().second++; + } + if (cell_range_begin == size) { cell_range_begin = i; } @@ -726,7 +784,7 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq cell_ranges.emplace_back(cell_range_begin, size); } - if (flags % LLAMA_STATE_SEQ_FLAGS_ON_DEVICE && cell_ranges.size() > 1) { + if ((flags & LLAMA_STATE_SEQ_FLAGS_ON_DEVICE) && cell_ranges.size() > 1) { GGML_ABORT("cannot save/load multiple ranges of cells to/from device memory\n"); } @@ -737,10 +795,16 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq } GGML_ASSERT(cell_count == cell_count_check); + cell_count_check = 0; + for (const auto & range : cell_ranges_data) { + cell_count_check += range.second - range.first; + } + GGML_ASSERT(cell_count == cell_count_check); + io.write(&cell_count, sizeof(cell_count)); state_write_meta(io, cell_ranges, seq_id); - state_write_data(io, cell_ranges); + state_write_data(io, cell_ranges_data); } void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { @@ -762,6 +826,14 @@ void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_i } throw std::runtime_error("failed to restore kv cache"); } + + if (n_rs_seq != 0) { + if (seq_id == -1) { + std::fill(rs_idx.begin(), rs_idx.end(), 0); + } else { + set_rs_idx(seq_id, 0); + } + } } void llama_memory_recurrent::state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id) const { @@ -804,7 +876,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: const uint64_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r()); io.write(&r_size_row, sizeof(r_size_row)); - // Write each range of cells of r_size_row length + // Write each logical cell 
row range. With pending recurrent rollback, + // the logical current state may live in a rollback snapshot plane. for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * r_size_row; @@ -825,7 +898,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: const uint64_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s()); io.write(&s_size_row, sizeof(s_size_row)); - // Write each range of S tensor rows + // Write each logical cell row range. With pending recurrent rollback, + // the logical current state may live in a rollback snapshot plane. for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * s_size_row; @@ -852,9 +926,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: // Write GQA embedding size io.write(&n_embd_s, sizeof(n_embd_s)); - // For each row, we get the element values of each cell + // For each row, we get the element values of each logical cell for (uint32_t j = 0; j < n_embd_s; ++j) { - // Write each range of cells of s_size_el length for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; const size_t src_offset = (range.first + j * mem_size) * s_size_el; @@ -1163,5 +1236,21 @@ ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const { } int32_t llama_memory_recurrent_context::s_copy(int i) const { - return mem->cells[i + mem->head].src0; + const uint32_t cell_idx = i + mem->head; + const int32_t src0 = mem->cells[cell_idx].src0; + + if (mem->n_rs_seq == 0) { + return src0; + } + + uint32_t idx = 0; + if (!mem->cells[cell_idx].seq_id.empty()) { + const llama_seq_id seq = *mem->cells[cell_idx].seq_id.begin(); + if (seq >= 0 && (size_t) seq < mem->rs_idx.size()) { + idx = mem->rs_idx[seq]; + // reset rollback idx + mem->rs_idx[seq] = 0; + } + } + return (int32_t)(idx * 
mem->size) + src0; } diff --git a/src/llama-memory-recurrent.h b/src/llama-memory-recurrent.h index 47f01d73912..29c58afc9c2 100644 --- a/src/llama-memory-recurrent.h +++ b/src/llama-memory-recurrent.h @@ -23,6 +23,7 @@ class llama_memory_recurrent : public llama_memory_i { bool offload, uint32_t mem_size, uint32_t n_seq_max, + uint32_t n_rs_seq, const layer_filter_cb & filter); ~llama_memory_recurrent() = default; @@ -69,6 +70,13 @@ class llama_memory_recurrent : public llama_memory_i { uint32_t size = 0; // total number of cells, shared across all sequences uint32_t used = 0; // used cells (i.e. at least one seq_id) + // number of recurrent-state snapshots per seq for rollback; tensors are widened to (1 + n_rs_seq) groups + uint32_t n_rs_seq = 0; + // per-seq rollback index + std::vector rs_idx; + + void set_rs_idx(llama_seq_id seq_id, uint32_t idx); + // computed before each graph build uint32_t n = 0; diff --git a/src/llama-memory.h b/src/llama-memory.h index 4a157b91fdb..4ad1612e45b 100644 --- a/src/llama-memory.h +++ b/src/llama-memory.h @@ -1,6 +1,7 @@ #pragma once #include "llama.h" +#include "llama-graph.h" #include #include @@ -20,6 +21,8 @@ struct llama_memory_params { // use full-size SWA cache bool swa_full; + + llama_context_type ctx_type; }; enum llama_memory_status { diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp index 4e65a45a50d..c645d0785ab 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp @@ -1312,9 +1312,16 @@ struct ggml_tensor * llama_model_loader::create_tensor_as_view(struct ggml_conte return tensor; } -void llama_model_loader::done_getting_tensors() const { - if (n_created != n_tensors) { - throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created)); +void llama_model_loader::done_getting_tensors(bool partial) const { + if (n_created > n_tensors) { + throw std::runtime_error(format("%s: too many tensors created; expected %d, got %d", 
__func__, n_tensors, n_created)); + } + if (n_created < n_tensors) { + if (!partial) { + throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created)); + } + LLAMA_LOG_INFO("%s: partial load — used %d of %d tensors in the file (rest belong to a sibling model on the same .gguf)\n", + __func__, n_created, n_tensors); } if (n_tensors_moved > 0) { LLAMA_LOG_DEBUG("%s: tensor '%s' (%s) (and %zu others) cannot be used with preferred buffer type %s, using %s instead\n", diff --git a/src/llama-model-loader.h b/src/llama-model-loader.h index 7b3d6703c03..c476026d3e5 100644 --- a/src/llama-model-loader.h +++ b/src/llama-model-loader.h @@ -184,7 +184,7 @@ struct llama_model_loader { struct ggml_tensor * create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::initializer_list & ne, size_t offset, bool required = true); - void done_getting_tensors() const; + void done_getting_tensors(bool partial = false) const; void init_mappings(bool prefetch = true, llama_mlocks * mlock_mmaps = nullptr); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 46ae010f800..8bf20a716eb 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1947,6 +1947,12 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, // checks default: { + // The MTP head is dense-attention only on hybrid Qwen3.5/3.6, so use a plain + // attention KV cache for the MTP context instead of the hybrid wrapper. 
+ const bool mtp_on_hybrid_qwen35 = + params.ctx_type == LLAMA_CONTEXT_TYPE_MTP && + (arch == LLM_ARCH_QWEN35 || arch == LLM_ARCH_QWEN35MOE); + if (llm_arch_is_recurrent(arch)) { res = new llama_memory_recurrent( *this, @@ -1955,8 +1961,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, cparams.offload_kqv, std::max((uint32_t) 1, cparams.n_seq_max), cparams.n_seq_max, + cparams.n_rs_seq, nullptr); - } else if (llm_arch_is_hybrid(arch)) { + } else if (llm_arch_is_hybrid(arch) && !mtp_on_hybrid_qwen35) { // The main difference between hybrid architectures is the // layer filters, so pick the right one here llama_memory_hybrid::layer_filter_cb filter_attn = nullptr; @@ -1971,6 +1978,14 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, filter_recr = [&](int32_t il) { return hparams.is_recurrent(il) && hparams.n_ff(il) == 0; }; + } else if (arch == LLM_ARCH_QWEN35 || arch == LLM_ARCH_QWEN35MOE) { + const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; + filter_attn = [&, n_main](int32_t il) { + return (uint32_t)il < n_main && !hparams.is_recurrent(il); + }; + filter_recr = [&, n_main](int32_t il) { + return (uint32_t)il < n_main && hparams.is_recurrent(il); + }; } if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { @@ -1988,6 +2003,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* recurrent_type_s */ GGML_TYPE_F32, /* recurrent_rs_size */ std::max((uint32_t) 1, cparams.n_seq_max), /* n_seq_max */ cparams.n_seq_max, + /* n_rs_seq */ cparams.n_rs_seq, /* offload */ cparams.offload_kqv, /* unified */ cparams.kv_unified, /* filter_attn */ std::move(filter_attn), @@ -2006,6 +2022,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* recurrent_type_v */ GGML_TYPE_F32, /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max), /* n_seq_max */ cparams.n_seq_max, + /* n_rs_seq */ cparams.n_rs_seq, /* offload */ 
cparams.offload_kqv, /* unified */ cparams.kv_unified, /* filter_attn */ std::move(filter_attn), @@ -2013,6 +2030,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, } } else { llama_memory_i::layer_reuse_cb reuse = nullptr; + llama_kv_cache::layer_filter_cb filter = nullptr; if (arch == LLM_ARCH_GEMMA3N || arch == LLM_ARCH_GEMMA4) { reuse = [&](int32_t il) { @@ -2024,6 +2042,11 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, }; } + if (mtp_on_hybrid_qwen35) { + const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; + filter = [n_main](int32_t il) { return (uint32_t)il >= n_main; }; + } + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { GGML_ASSERT(hparams.is_swa_any()); @@ -2039,7 +2062,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, cparams.n_seq_max, cparams.n_ubatch, 1, - nullptr, + filter, reuse); } else { GGML_ASSERT(!hparams.is_swa_any()); @@ -2056,7 +2079,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, 1, hparams.n_swa, hparams.swa_type, - nullptr, + filter, nullptr); } } @@ -2159,6 +2182,7 @@ int32_t llama_model_n_swa(const llama_model * model) { return model->hparams.n_swa; } + uint32_t llama_model_n_cls_out(const struct llama_model * model) { return model->hparams.n_cls_out; } diff --git a/src/models/delta-net-base.cpp b/src/models/delta-net-base.cpp index 6bc989c9509..2a4e00384e9 100644 --- a/src/models/delta-net-base.cpp +++ b/src/models/delta-net-base.cpp @@ -1,6 +1,7 @@ #include "models.h" #include "llama-impl.h" +#include "llama-memory-recurrent.h" // utility to get one slice from the third dimension // input dim: [x, y, c, b] @@ -397,7 +398,9 @@ std::pair llm_build_delta_net_base::build_delta_ne GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); - ggml_tensor * 
result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s); + // K=1 (final state only): reshape to 3D (S_v*S_v*H_v, 1, n_seqs) for ggml_gated_delta_net. + ggml_tensor * s_3d = ggml_reshape_3d(ctx0, s, S_v * S_v * H_v, 1, n_seqs); + ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s_3d); if (n_tokens == 1) { cb(result, LLAMA_TENSOR_NAME_FGDN_AR, il); } else { @@ -443,3 +446,141 @@ std::pair llm_build_delta_net_base::build_delta_ne return build_delta_net_chunking(q, k, v, g, b, s, il); } + +bool llm_build_delta_net_base::keep_rs() const { + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + return cparams.n_rs_seq > 0 + && n_seq_tokens > 1 + && (uint32_t) n_seq_tokens <= 1 + cparams.n_rs_seq; +} + +ggml_tensor * llm_build_delta_net_base::build_conv_state( + llm_graph_input_rs * inp, + ggml_tensor * conv_states_all, + ggml_tensor * qkv_mixed, + int64_t conv_kernel_size, + int64_t conv_channels, + int il) { + const auto * mctx_cur = inp->mctx; + const auto kv_head = mctx_cur->get_head(); + const uint32_t mem_size = mctx_cur->get_size(); + const int64_t n_seqs = ubatch.n_seqs; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + const bool keep = keep_rs(); + + ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); + cb(conv_states, "conv_states", il); + + conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); + cb(conv_states, "conv_states_reshaped", il); + + qkv_mixed = ggml_transpose(ctx0, qkv_mixed); + cb(qkv_mixed, "qkv_mixed_transposed", il); + + ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); + cb(conv_input, "conv_input", il); + + if (!keep) { + ggml_tensor * last_conv_states = + ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], + conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input)); + cb(last_conv_states, "last_conv_states", il); + + ggml_tensor * state_update_target 
= + ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1], + kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); + cb(state_update_target, "state_update_target", il); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); + } else { + const int64_t row_count = (conv_kernel_size - 1) * conv_channels; + const size_t row_size = row_count * ggml_element_size(conv_states_all); + for (int64_t t = 1; t <= n_seq_tokens; ++t) { + const uint32_t slot = (uint32_t)(n_seq_tokens - t); + ggml_tensor * src = + ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, + conv_input->nb[1], conv_input->nb[2], + t * ggml_element_size(conv_input)); + ggml_tensor * dst = + ggml_view_2d(ctx0, conv_states_all, row_count, n_seqs, + conv_states_all->nb[1], + ((size_t) slot * mem_size + kv_head) * row_size); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, src, dst)); + } + } + + return conv_input; +} + +ggml_tensor * llm_build_delta_net_base::build_recurrent_attn( + llm_graph_input_rs * inp, + ggml_tensor * ssm_states_all, + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il) { + const auto * mctx_cur = inp->mctx; + const auto kv_head = mctx_cur->get_head(); + const uint32_t mem_size = mctx_cur->get_size(); + + const int64_t S_v = s->ne[0]; + const int64_t H_v = s->ne[2]; + const int64_t n_seqs = s->ne[3]; + const int64_t n_seq_tokens = q->ne[2]; + + if (!keep_rs()) { + auto attn_out = build_delta_net(q, k, v, g, b, s, il); + ggml_tensor * output = attn_out.first; + ggml_tensor * new_state = attn_out.second; + cb(output, "attn_output", il); + cb(new_state, "new_state", il); + + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, new_state, + ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], + kv_head * hparams.n_embd_s() * 
ggml_element_size(ssm_states_all)))); + + return output; + } + + const int64_t D = S_v * S_v * H_v; + const int64_t K = (int64_t) cparams.n_rs_seq + 1; + + // TODO: remove pad + simplify + ggml_tensor * state_in_3d = ggml_reshape_3d(ctx0, s, D, 1, n_seqs); + ggml_tensor * state_3d = ggml_pad(ctx0, state_in_3d, 0, K - 1, 0, 0); + + ggml_tensor * gdn_out = ggml_gated_delta_net(ctx0, q, k, v, g, b, state_3d); + cb(gdn_out, LLAMA_TENSOR_NAME_FGDN_CH, il); + + const int64_t attn_score_elems = S_v * H_v * n_seq_tokens * n_seqs; + const int64_t state_size_per_snap = S_v * S_v * H_v * n_seqs; + + ggml_tensor * output = ggml_view_4d(ctx0, gdn_out, + S_v, H_v, n_seq_tokens, n_seqs, + ggml_row_size(gdn_out->type, S_v), + ggml_row_size(gdn_out->type, S_v * H_v), + ggml_row_size(gdn_out->type, S_v * H_v * n_seq_tokens), + 0); + cb(output, "attn_output", il); + + const size_t row_size = hparams.n_embd_s() * ggml_element_size(ssm_states_all); + for (int64_t k_i = 0; k_i < K; ++k_i) { + const uint32_t cache_slot = (uint32_t) (K - 1 - k_i); + ggml_tensor * src = ggml_view_4d(ctx0, gdn_out, + S_v, S_v, H_v, n_seqs, + ggml_row_size(gdn_out->type, S_v), + ggml_row_size(gdn_out->type, S_v * S_v), + ggml_row_size(gdn_out->type, S_v * S_v * H_v), + ggml_row_size(gdn_out->type, attn_score_elems + k_i * state_size_per_snap)); + ggml_tensor * dst = ggml_view_2d(ctx0, ssm_states_all, + hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], + ((size_t) cache_slot * mem_size + kv_head) * row_size); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, src, dst)); + } + + return output; +} diff --git a/src/models/models.h b/src/models/models.h index 6d5f18a8e20..4e40536a5ea 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -46,7 +46,7 @@ struct llm_build_delta_net_base : public llm_graph_context { ggml_tensor * s, int il); - // use the ggml_gated_delta_net fused operator + // use the ggml_gated_delta_net fused operator (K=1; state has shape (D, 1, n_seqs)) std::pair build_delta_net_fused( 
ggml_tensor * q, ggml_tensor * k, @@ -65,6 +65,32 @@ struct llm_build_delta_net_base : public llm_graph_context { ggml_tensor * b, ggml_tensor * s, int il); + + // true when speculative rollback is enabled and the batch fits in the rs cache + bool keep_rs() const; + + // read conv state from cache, concat with qkv_mixed, write back (single slot or per-token) + // qkv_mixed: (qkv_dim, n_seq_tokens, n_seqs); returns conv_input: (kernel_size + n_seq_tokens - 1, channels, n_seqs) + ggml_tensor * build_conv_state( + llm_graph_input_rs * inp, + ggml_tensor * conv_states_all, + ggml_tensor * qkv_mixed, + int64_t conv_kernel_size, + int64_t conv_channels, + int il); + + // run delta-net attention and write the new recurrent state(s) back to ssm_states_all + // s: (head_v_dim, head_v_dim, num_v_heads, n_seqs); returns output: (head_v_dim, num_v_heads, n_seq_tokens, n_seqs) + ggml_tensor * build_recurrent_attn( + llm_graph_input_rs * inp, + ggml_tensor * ssm_states_all, + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il); }; struct llm_build_rwkv6_base : public llm_graph_context { @@ -1739,6 +1765,10 @@ struct llama_model_qwen35 : public llama_model_base { const llama_model & model; }; + struct graph_mtp : public llm_graph_context { + graph_mtp(const llama_model & model, const llm_graph_params & params); + }; + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; @@ -1781,6 +1811,10 @@ struct llama_model_qwen35moe : public llama_model_base { const llama_model & model; }; + struct graph_mtp : public llm_graph_context { + graph_mtp(const llama_model & model, const llm_graph_params & params); + }; + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; diff --git a/src/models/qwen35.cpp b/src/models/qwen35.cpp index b188810f931..2b4d5b14cd4 100644 --- a/src/models/qwen35.cpp +++ b/src/models/qwen35.cpp @@ -12,16 +12,22 @@ void 
llama_model_qwen35::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - // Mark recurrent layers (linear attention layers) + // NextN/MTP (Qwen3.5/3.6): extra decoder block appended beyond the main stack + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + + // Mark recurrent layers (linear attention layers). MTP layers are dense + // attention-only and must be flagged non-recurrent. { + const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; uint32_t full_attn_interval = 4; ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + hparams.recurrent_layer_arr[i] = (i < n_main) && ((i + 1) % full_attn_interval != 0); } } - switch (hparams.n_layer) { + switch (hparams.n_layer - hparams.nextn_predict_layers) { case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_8B : LLM_TYPE_2B; break; case 32: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_9B; break; case 64: type = LLM_TYPE_27B; break; @@ -29,9 +35,14 @@ void llama_model_qwen35::load_arch_hparams(llama_model_loader & ml) { } } -void llama_model_qwen35::load_arch_tensors(llama_model_loader &) { +void llama_model_qwen35::load_arch_tensors(llama_model_loader & ml) { LLAMA_LOAD_LOCALS; + const uint32_t n_main = n_layer - hparams.nextn_predict_layers; + const bool mtp_only = (hparams.nextn_predict_layers > 0) && + (ml.get_weight("blk.0.attn_norm.weight") == nullptr); + const int trunk_flags = mtp_only ? 
TENSOR_NOT_REQUIRED : 0; + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); // output @@ -43,50 +54,85 @@ void llama_model_qwen35::load_arch_tensors(llama_model_loader &) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); } - // Calculate dimensions from hyperparameters - const int64_t head_k_dim = hparams.ssm_d_state; - const int64_t head_v_dim = hparams.ssm_d_state; - const int64_t n_k_heads = hparams.ssm_n_group; - const int64_t n_v_heads = hparams.ssm_dt_rank; - const int64_t key_dim = head_k_dim * n_k_heads; - const int64_t value_dim = head_v_dim * n_v_heads; - const int64_t conv_dim = key_dim * 2 + value_dim; + auto load_block_trunk = [&](int il, int flags) { + auto & layer = layers[il]; - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; + // Calculate dimensions from hyperparameters + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + const int64_t conv_dim = key_dim * 2 + value_dim; - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, flags); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, flags); - if (!hparams.is_recurrent(i)) { + if (!hparams.is_recurrent(il)) { // Attention layers - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 
2, n_embd_k_gqa, n_embd_v_gqa, flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, flags); // Q/K normalization for attention layers - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, flags); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, flags); } else { // Linear attention (gated delta net) specific tensors // Create tensors with calculated dimensions - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); - layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); - layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); - layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0); - layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0); - layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", il), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", il), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", il), { hparams.ssm_d_conv, conv_dim }, 
flags); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", il), { hparams.ssm_dt_rank }, flags); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, il), { hparams.ssm_dt_rank }, flags); + layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", il), { n_embd, n_v_heads }, flags); + layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", il), { n_embd, n_v_heads }, flags); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", il), { head_v_dim }, flags); + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", il), { value_dim, n_embd }, flags); } - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", il), {n_embd, n_ff}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", il), { n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", il), {n_embd, n_ff}, flags); + }; + + auto load_block_mtp = [&](int il) { + auto & layer = layers[il]; + + // MTP block looks like a full-attention Qwen3.5 decoder block. 
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, 0); + + create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", il), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", il), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", il), {n_embd, n_ff}, 0); + + // NextN-specific tensors that define the MTP block. + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", il), { 2 * n_embd, n_embd }, 0); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", il), { n_embd }, 0); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", il), { n_embd }, 0); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", il), { n_embd }, TENSOR_NOT_REQUIRED); + }; + + for (int i = 0; i < (int) n_main; ++i) { + load_block_trunk(i, trunk_flags); + } + for (int i = (int) n_main; i < n_layer; ++i) { + load_block_mtp(i); } } std::unique_ptr llama_model_qwen35::build_arch_graph(const llm_graph_params & params) const { + if (params.gtype == LLM_GRAPH_TYPE_DECODER_MTP) { + return std::make_unique(*this, params); + } 
return std::make_unique(*this, params); } @@ -111,7 +157,9 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_out_ids = build_inp_out_ids(); - for (int il = 0; il < n_layer; ++il) { + // MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass. + const int n_transformer_layers = n_layer - (int) hparams.nextn_predict_layers; + for (int il = 0; il < n_transformer_layers; ++il) { ggml_tensor * inpSA = inpL; cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); @@ -128,7 +176,7 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il); } - if (il == n_layer - 1 && inp_out_ids) { + if (il == n_transformer_layers - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -160,6 +208,9 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para } cur = inpL; + cb(cur, "h_pre_norm", -1); + res->t_h_pre_norm = cur; + // Final norm cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); @@ -297,8 +348,6 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear( const int64_t head_v_dim = d_inner / num_v_heads; const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const auto kv_head = mctx_cur->get_head(); - GGML_ASSERT(n_seqs != 0); GGML_ASSERT(ubatch.equal_seqs()); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); @@ -328,41 +377,14 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear( gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs); - // Get convolution states from cache ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); - // Build the convolution states tensor - ggml_tensor * conv_states = build_rs(inp, conv_states_all, 
hparams.n_embd_r(), n_seqs); - cb(conv_states, "conv_states", il); - - // Calculate convolution kernel size ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; const int64_t conv_kernel_size = conv_kernel->ne[0]; const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state; - conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); - cb(conv_states, "conv_states_reshaped", il); - - qkv_mixed = ggml_transpose(ctx0, qkv_mixed); - cb(qkv_mixed, "qkv_mixed_transposed", il); - - ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); - cb(conv_input, "conv_input", il); - - // Update convolution state cache - // Extract the last (conv_kernel_size - 1) states from conv_input - ggml_tensor * last_conv_states = - ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], - conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input)); - cb(last_conv_states, "last_conv_states", il); - - ggml_tensor * state_update_target = - ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1], - kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); - cb(state_update_target, "state_update_target", il); - - ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); + ggml_tensor * conv_input = build_conv_state(inp, conv_states_all, qkv_mixed, conv_kernel_size, conv_channels, il); ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs); @@ -413,7 +435,7 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear( //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); // if head keys and value keys are different, repeat to force tensors into matching shapes - // note: need explicit repeat 
only if we are not using the fused GDN + // note: need explicit repeat only if we are not using the fused GDN. if (num_k_heads != num_v_heads && (!cparams.fused_gdn_ar || !cparams.fused_gdn_ch)) { GGML_ASSERT(num_v_heads % num_k_heads == 0); q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); @@ -424,18 +446,7 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear( cb(k_conv, "k_conv_predelta", il); cb(v_conv, "v_conv_predelta", il); - auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il); - - ggml_tensor * output = attn_out.first; - ggml_tensor * new_state = attn_out.second; - cb(output, "attn_output", il); - cb(new_state, "new_state", il); - - // Update the recurrent states - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, new_state, - ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], - kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + ggml_tensor * output = build_recurrent_attn(inp, ssm_states_all, q_conv, k_conv, v_conv, gate, beta, state, il); // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); @@ -471,3 +482,146 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_ffn(ggml_tensor * cur, cons return cur; } + +// LLM_GRAPH_TYPE_DECODER_MTP draft head for Qwen3.5/3.6 dense series +llama_model_qwen35::graph_mtp::graph_mtp(const llama_model & model, const llm_graph_params & params) + : llm_graph_context(params) { + GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35 MTP requires nextn_predict_layers > 0"); + GGML_ASSERT(hparams.nextn_predict_layers == 1 && "QWEN35 MTP currently only supports a single MTP block"); + + const int64_t n_embd_head = hparams.n_embd_head_v(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + + // The MTP block lives at the source file's original layer index. 
+ const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers; + const auto & layer = model.layers[il]; + + GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj"); + GGML_ASSERT(layer.nextn.enorm && "MTP block missing nextn.enorm"); + GGML_ASSERT(layer.nextn.hnorm && "MTP block missing nextn.hnorm"); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + auto inp = std::make_unique(hparams.n_embd); + + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_input(inp->tokens); + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + ggml_set_input(inp->embd); + ggml_set_name(inp->embd, "mtp_h_input"); + + ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd; + + ggml_tensor * h_input = inp->embd; + ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens); + cb(tok_embd, "mtp_tok_embd", il); + + res->add_input(std::move(inp)); + + ggml_tensor * inp_pos = build_inp_pos(); + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); + cb(h_norm, "mtp_hnorm", il); + + ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il); + cb(e_norm, "mtp_enorm", il); + + ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, /*dim=*/ 0); + cb(concat, "mtp_concat", il); + + ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat); + cb(cur, "mtp_eh_proj", il); + + ggml_tensor * inpSA = cur; + + cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_norm", il); + + ggml_tensor * Qcur_full = build_lora_mm(layer.wq, cur, layer.wq_s); + cb(Qcur_full, "mtp_Qcur_full", il); + + ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, + n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * 
n_embd_head * 2 * n_head, + 0); + Qcur = build_norm(Qcur, layer.attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "mtp_Qcur_normed", il); + + ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full, + n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + ggml_element_size(Qcur_full) * n_embd_head); + gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); + cb(gate, "mtp_gate", il); + + ggml_tensor * Kcur = build_lora_mm(layer.wk, cur, layer.wk_s); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = build_norm(Kcur, layer.attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "mtp_Kcur_normed", il); + + ggml_tensor * Vcur = build_lora_mm(layer.wv, cur, layer.wv_s); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + cb(Vcur, "mtp_Vcur", il); + + Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + const float kq_scale = hparams.f_attention_scale == 0.0f + ? 
1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + cur = build_attn(inp_attn, + nullptr, nullptr, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "mtp_attn_pregate", il); + + cur = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate)); + cur = build_lora_mm(layer.wo, cur, layer.wo_s); + cb(cur, "mtp_attn_out", il); + + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "mtp_attn_residual", il); + + ggml_tensor * ffn_residual = cur; + cur = build_norm(cur, layer.attn_post_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_post_norm", il); + + cur = build_ffn(cur, + layer.ffn_up, nullptr, layer.ffn_up_s, + layer.ffn_gate, nullptr, layer.ffn_gate_s, + layer.ffn_down, nullptr, layer.ffn_down_s, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "mtp_ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_residual); + cb(cur, "mtp_post_ffn", il); + + // Pre-norm hidden state: used by the AR draft loop to seed the next MTP step. + // (In the trunk graph this is `t_h_pre_norm`; the MTP head reuses the same slot.) + cb(cur, "h_pre_norm", -1); + res->t_h_pre_norm = cur; + + ggml_tensor * head_norm_w = layer.nextn.shared_head_norm + ? layer.nextn.shared_head_norm + : model.output_norm; + GGML_ASSERT(head_norm_w && "QWEN35 MTP: missing both nextn.shared_head_norm and output_norm"); + cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1); + cb(cur, "mtp_shared_head_norm", -1); + + ggml_tensor * head_w = layer.nextn.shared_head_head ? 
layer.nextn.shared_head_head : model.output; + GGML_ASSERT(head_w && "QWEN35 MTP: missing LM head (nextn.shared_head_head or model.output)"); + cur = build_lora_mm(head_w, cur); + cb(cur, "result_output", -1); + + res->t_logits = cur; + ggml_build_forward_expand(gf, cur); +} diff --git a/src/models/qwen35moe.cpp b/src/models/qwen35moe.cpp index 8ec9b8c6f7d..22e3e110765 100644 --- a/src/models/qwen35moe.cpp +++ b/src/models/qwen35moe.cpp @@ -15,16 +15,22 @@ void llama_model_qwen35moe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - // Mark recurrent layers (linear attention layers) + // NextN/MTP (Qwen3.5/3.6): extra decoder block appended beyond the main stack + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + + // Mark recurrent layers (linear attention layers). MTP layers are dense + // attention-only and must be flagged non-recurrent. 
{ + const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; uint32_t full_attn_interval = 4; ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + hparams.recurrent_layer_arr[i] = (i < n_main) && ((i + 1) % full_attn_interval != 0); } } - switch (hparams.n_layer) { + switch (hparams.n_layer - hparams.nextn_predict_layers) { case 40: type = LLM_TYPE_35B_A3B; break; case 48: type = LLM_TYPE_122B_A10B; break; case 60: type = LLM_TYPE_397B_A17B; break; @@ -32,9 +38,14 @@ void llama_model_qwen35moe::load_arch_hparams(llama_model_loader & ml) { } } -void llama_model_qwen35moe::load_arch_tensors(llama_model_loader &) { +void llama_model_qwen35moe::load_arch_tensors(llama_model_loader & ml) { LLAMA_LOAD_LOCALS; + const uint32_t n_main = n_layer - hparams.nextn_predict_layers; + const bool mtp_only = (hparams.nextn_predict_layers > 0) && + (ml.get_weight("blk.0.attn_norm.weight") == nullptr); + const int trunk_flags = mtp_only ? TENSOR_NOT_REQUIRED : 0; + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); // output @@ -46,60 +57,105 @@ void llama_model_qwen35moe::load_arch_tensors(llama_model_loader &) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); } - const int64_t n_ff_exp = hparams.n_ff_exp ? 
hparams.n_ff_exp : n_ff / n_expert_used; + auto load_block_trunk = [&](int il, int flags) { + auto & layer = layers[il]; - // Calculate dimensions from hyperparameters - const int64_t head_k_dim = hparams.ssm_d_state; - const int64_t head_v_dim = hparams.ssm_d_state; - const int64_t n_k_heads = hparams.ssm_n_group; - const int64_t n_v_heads = hparams.ssm_dt_rank; - const int64_t key_dim = head_k_dim * n_k_heads; - const int64_t value_dim = head_v_dim * n_v_heads; - const int64_t conv_dim = key_dim * 2 + value_dim; + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff; - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; + // Calculate dimensions from hyperparameters + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + const int64_t conv_dim = key_dim * 2 + value_dim; - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, flags); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, flags); - if (!hparams.is_recurrent(i)) { + if (!hparams.is_recurrent(il)) { // Attention layers - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { 
n_embd_head_k * n_head, n_embd }, flags); // Q/K normalization for attention layers - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, flags); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, flags); } else { // Linear attention (gated delta net) specific tensors // Create tensors with calculated dimensions - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); - layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); - layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); - layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0); - layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0); - layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", il), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", il), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", il), { hparams.ssm_d_conv, conv_dim }, flags); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", il), { hparams.ssm_dt_rank }, flags); + 
layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, il), { hparams.ssm_dt_rank }, flags); + layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", il), { n_embd, n_v_heads }, flags); + layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", il), { n_embd, n_v_heads }, flags); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", il), { head_v_dim }, flags); + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", il), { value_dim, n_embd }, flags); } - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); - create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); + // Routed experts + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", il), { n_embd, n_expert }, flags); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", il), { n_ff_exp, n_embd, n_expert }, flags); + create_tensor_gate_up_exps(layer, il, n_embd, n_ff_exp, n_expert, flags); // Shared experts + layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", il), { n_embd }, flags); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", il), { n_embd, n_ff_shexp }, flags); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", il), { n_embd, n_ff_shexp }, flags); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", il), { n_ff_shexp, n_embd }, flags); + }; + + auto load_block_mtp = [&](int il) { + auto & layer = layers[il]; + + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; const int64_t n_ff_shexp = hparams.n_ff_shexp ? 
hparams.n_ff_shexp : n_ff; - layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0); - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0); + // MTP block looks like a full-attention Qwen3.5 decoder block with MoE FFN. + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, 0); + + create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, 0); + + // Routed experts + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", il), { n_embd, n_expert }, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", il), { n_ff_exp, n_embd, n_expert }, 0); + create_tensor_gate_up_exps(layer, il, n_embd, n_ff_exp, n_expert, 0); + + // Shared experts + layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", il), { n_embd }, 0); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", il), { n_embd, n_ff_shexp }, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", il), { n_embd, n_ff_shexp }, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", il), { n_ff_shexp, n_embd }, 0); + + // NextN-specific tensors that define the MTP block. 
+ layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", il), { 2 * n_embd, n_embd }, 0); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", il), { n_embd }, 0); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", il), { n_embd }, 0); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", il), { n_embd }, TENSOR_NOT_REQUIRED); + }; + + for (int i = 0; i < (int) n_main; ++i) { + load_block_trunk(i, trunk_flags); + } + for (int i = (int) n_main; i < n_layer; ++i) { + load_block_mtp(i); } } std::unique_ptr llama_model_qwen35moe::build_arch_graph(const llm_graph_params & params) const { + if (params.gtype == LLM_GRAPH_TYPE_DECODER_MTP) { + return std::make_unique(*this, params); + } return std::make_unique(*this, params); } @@ -124,7 +180,9 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_out_ids = build_inp_out_ids(); - for (int il = 0; il < n_layer; ++il) { + // MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass. 
+ const int n_transformer_layers = n_layer - (int) hparams.nextn_predict_layers; + for (int il = 0; il < n_transformer_layers; ++il) { ggml_tensor * inpSA = inpL; cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); @@ -141,7 +199,7 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il); } - if (il == n_layer - 1 && inp_out_ids) { + if (il == n_transformer_layers - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -173,6 +231,9 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p } cur = inpL; + cb(cur, "h_pre_norm", -1); + res->t_h_pre_norm = cur; + // Final norm cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); @@ -310,8 +371,6 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear( const int64_t head_v_dim = d_inner / num_v_heads; const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const auto kv_head = mctx_cur->get_head(); - GGML_ASSERT(n_seqs != 0); GGML_ASSERT(ubatch.equal_seqs()); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); @@ -341,41 +400,14 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear( gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs); - // Get convolution states from cache ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); - // Build the convolution states tensor - ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); - cb(conv_states, "conv_states", il); - - // Calculate convolution kernel size ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; const int64_t conv_kernel_size = conv_kernel->ne[0]; const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state; - conv_states = ggml_reshape_3d(ctx0, conv_states, 
conv_kernel_size - 1, conv_channels, n_seqs); - cb(conv_states, "conv_states_reshaped", il); - - qkv_mixed = ggml_transpose(ctx0, qkv_mixed); - cb(qkv_mixed, "qkv_mixed_transposed", il); - - ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); - cb(conv_input, "conv_input", il); - - // Update convolution state cache - // Extract the last (conv_kernel_size - 1) states from conv_input - ggml_tensor * last_conv_states = - ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], - conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input)); - cb(last_conv_states, "last_conv_states", il); - - ggml_tensor * state_update_target = - ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1], - kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); - cb(state_update_target, "state_update_target", il); - - ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); + ggml_tensor * conv_input = build_conv_state(inp, conv_states_all, qkv_mixed, conv_kernel_size, conv_channels, il); ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs); @@ -426,7 +458,7 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear( //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); // if head keys and value keys are different, repeat to force tensors into matching shapes - // note: need explicit repeat only if we are not using the fused GDN + // note: need explicit repeat only if we are not using the fused GDN. 
if (num_k_heads != num_v_heads && (!cparams.fused_gdn_ar || !cparams.fused_gdn_ch)) { GGML_ASSERT(num_v_heads % num_k_heads == 0); q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); @@ -437,18 +469,7 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear( cb(k_conv, "k_conv_predelta", il); cb(v_conv, "v_conv_predelta", il); - auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il); - - ggml_tensor * output = attn_out.first; - ggml_tensor * new_state = attn_out.second; - cb(output, "attn_output", il); - cb(new_state, "new_state", il); - - // Update the recurrent states - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, new_state, - ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], - kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + ggml_tensor * output = build_recurrent_attn(inp, ssm_states_all, q_conv, k_conv, v_conv, gate, beta, state, il); // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); @@ -525,3 +546,178 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_ffn(ggml_tensor * cur, c return cur; } + +// LLM_GRAPH_TYPE_DECODER_MTP draft head for Qwen3.5/3.6 MoE +llama_model_qwen35moe::graph_mtp::graph_mtp(const llama_model & model, const llm_graph_params & params) + : llm_graph_context(params) { + GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35MOE MTP requires nextn_predict_layers > 0"); + GGML_ASSERT(hparams.nextn_predict_layers == 1 && "QWEN35MOE MTP currently only supports a single MTP block"); + + const int64_t n_embd_head = hparams.n_embd_head_v(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + + const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers; + const auto & layer = model.layers[il]; + + GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing 
nextn.eh_proj"); + GGML_ASSERT(layer.nextn.enorm && "MTP block missing nextn.enorm"); + GGML_ASSERT(layer.nextn.hnorm && "MTP block missing nextn.hnorm"); + GGML_ASSERT(layer.ffn_gate_inp && "MTP block missing ffn_gate_inp"); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + auto inp = std::make_unique(hparams.n_embd); + + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_input(inp->tokens); + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + ggml_set_input(inp->embd); + ggml_set_name(inp->embd, "mtp_h_input"); + + ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd; + + ggml_tensor * h_input = inp->embd; + ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens); + cb(tok_embd, "mtp_tok_embd", il); + + res->add_input(std::move(inp)); + + ggml_tensor * inp_pos = build_inp_pos(); + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); + cb(h_norm, "mtp_hnorm", il); + + ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il); + cb(e_norm, "mtp_enorm", il); + + ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, /*dim=*/ 0); + cb(concat, "mtp_concat", il); + + ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat); + cb(cur, "mtp_eh_proj", il); + + ggml_tensor * inpSA = cur; + + cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_norm", il); + + ggml_tensor * Qcur_full = build_lora_mm(layer.wq, cur, layer.wq_s); + cb(Qcur_full, "mtp_Qcur_full", il); + + ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, + n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + 0); + Qcur = build_norm(Qcur, layer.attn_q_norm, nullptr, LLM_NORM_RMS, il); + 
cb(Qcur, "mtp_Qcur_normed", il); + + ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full, + n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + ggml_element_size(Qcur_full) * n_embd_head); + gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); + cb(gate, "mtp_gate", il); + + ggml_tensor * Kcur = build_lora_mm(layer.wk, cur, layer.wk_s); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = build_norm(Kcur, layer.attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "mtp_Kcur_normed", il); + + ggml_tensor * Vcur = build_lora_mm(layer.wv, cur, layer.wv_s); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + cb(Vcur, "mtp_Vcur", il); + + Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + const float kq_scale = hparams.f_attention_scale == 0.0f + ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + cur = build_attn(inp_attn, + nullptr, nullptr, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "mtp_attn_pregate", il); + + cur = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate)); + cur = build_lora_mm(layer.wo, cur, layer.wo_s); + cb(cur, "mtp_attn_out", il); + + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "mtp_attn_residual", il); + + ggml_tensor * ffn_residual = cur; + cur = build_norm(cur, layer.attn_post_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_post_norm", il); + + // MoE FFN — routed experts plus gated shared expert (mirrors qwen35moe). 
+ ggml_tensor * moe_out = + build_moe_ffn(cur, + layer.ffn_gate_inp, + layer.ffn_up_exps, + layer.ffn_gate_exps, + layer.ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + hparams.expert_weights_scale, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il, + nullptr, layer.ffn_gate_up_exps, + layer.ffn_up_exps_s, + layer.ffn_gate_exps_s, + layer.ffn_down_exps_s); + cb(moe_out, "mtp_ffn_moe_out", il); + + if (layer.ffn_up_shexp != nullptr) { + ggml_tensor * ffn_shexp = + build_ffn(cur, + layer.ffn_up_shexp, nullptr, layer.ffn_up_shexp_s, + layer.ffn_gate_shexp, nullptr, layer.ffn_gate_shexp_s, + layer.ffn_down_shexp, nullptr, layer.ffn_down_shexp_s, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "mtp_ffn_shexp", il); + + ggml_tensor * shared_gate = build_lora_mm(layer.ffn_gate_inp_shexp, cur); + shared_gate = ggml_sigmoid(ctx0, shared_gate); + cb(shared_gate, "mtp_shared_expert_gate_sigmoid", il); + + ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate); + cb(ffn_shexp, "mtp_ffn_shexp_gated", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + } else { + cur = moe_out; + } + cb(cur, "mtp_ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_residual); + cb(cur, "mtp_post_ffn", il); + + // Pre-norm hidden state: used by the AR draft loop to seed the next MTP step. + cb(cur, "h_pre_norm", -1); + res->t_h_pre_norm = cur; + + ggml_tensor * head_norm_w = layer.nextn.shared_head_norm + ? layer.nextn.shared_head_norm + : model.output_norm; + GGML_ASSERT(head_norm_w && "QWEN35MOE MTP: missing both nextn.shared_head_norm and output_norm"); + cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1); + cb(cur, "mtp_shared_head_norm", -1); + + ggml_tensor * head_w = layer.nextn.shared_head_head ? 
layer.nextn.shared_head_head : model.output; + GGML_ASSERT(head_w && "QWEN35MOE MTP: missing LM head (nextn.shared_head_head or model.output)"); + cur = build_lora_mm(head_w, cur); + cb(cur, "result_output", -1); + + res->t_logits = cur; + ggml_build_forward_expand(gf, cur); +} diff --git a/src/models/qwen3next.cpp b/src/models/qwen3next.cpp index bdc3026c1de..1d873427db5 100644 --- a/src/models/qwen3next.cpp +++ b/src/models/qwen3next.cpp @@ -378,8 +378,6 @@ ggml_tensor * llama_model_qwen3next::graph::build_layer_attn_linear( const int64_t head_v_dim = d_inner / num_v_heads; const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const auto kv_head = mctx_cur->get_head(); - GGML_ASSERT(n_seqs != 0); GGML_ASSERT(ubatch.equal_seqs()); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); @@ -429,41 +427,14 @@ ggml_tensor * llama_model_qwen3next::graph::build_layer_attn_linear( beta = ggml_reshape_4d(ctx0, beta, 1, num_v_heads, n_seq_tokens, n_seqs); gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs); - // Get convolution states from cache ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); - // Build the convolution states tensor - ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); - cb(conv_states, "conv_states", il); - - // Calculate convolution kernel size ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; const int64_t conv_kernel_size = conv_kernel->ne[0]; const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state; - conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); - cb(conv_states, "conv_states_reshaped", il); - - qkv_mixed = ggml_transpose(ctx0, qkv_mixed); - cb(qkv_mixed, "qkv_mixed_transposed", il); - - ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); - cb(conv_input, "conv_input", il); - - // Update convolution state cache - // Extract the last 
(conv_kernel_size - 1) states from conv_input - ggml_tensor * last_conv_states = - ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], - conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input)); - cb(last_conv_states, "last_conv_states", il); - - ggml_tensor * state_update_target = - ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1], - kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); - cb(state_update_target, "state_update_target", il); - - ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); + ggml_tensor * conv_input = build_conv_state(inp, conv_states_all, qkv_mixed, conv_kernel_size, conv_channels, il); ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs); @@ -540,18 +511,7 @@ ggml_tensor * llama_model_qwen3next::graph::build_layer_attn_linear( cb(k_conv, "k_conv_predelta", il); cb(v_conv, "v_conv_predelta", il); - auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il); - - ggml_tensor * output = attn_out.first; - ggml_tensor * new_state = attn_out.second; - cb(output, "attn_output", il); - cb(new_state, "new_state", il); - - // Update the recurrent states - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, new_state, - ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], - kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + ggml_tensor * output = build_recurrent_attn(inp, ssm_states_all, q_conv, k_conv, v_conv, gate, beta, state, il); // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 
3ee535224d9..0fdbd39c94a 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -252,6 +252,9 @@ llama_build_and_test(test-backend-sampler.cpp LABEL "model") llama_build_and_test(test-state-restore-fragmented.cpp LABEL "model" ARGS -m "${MODEL_DEST}") set_tests_properties(test-state-restore-fragmented PROPERTIES FIXTURES_REQUIRED test-download-model) +llama_build_and_test(test-recurrent-state-rollback.cpp LABEL "model" ARGS -m "${MODEL_DEST}") +set_tests_properties(test-recurrent-state-rollback PROPERTIES FIXTURES_REQUIRED test-download-model) + if (NOT GGML_BACKEND_DL) # these tests use the backends directly and cannot be built with dynamic loading llama_build_and_test(test-barrier.cpp) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 8a561c038a1..76f7cb5a867 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -3832,16 +3832,17 @@ struct test_gated_delta_net : public test_case { const int v_repeat; const bool permuted; const bool kda; + const int64_t K; // snapshot slot count: 1 = final-only, >1 = last K states std::string vars() override { - return VARS_TO_STR8(type, head_count, head_size, n_seq_tokens, n_seqs, v_repeat, permuted, kda); + return VARS_TO_STR9(type, head_count, head_size, n_seq_tokens, n_seqs, v_repeat, permuted, kda, K); } test_gated_delta_net(ggml_type type = GGML_TYPE_F32, int64_t head_count = 4, int64_t head_size = 16, int64_t n_seq_tokens = 1, int64_t n_seqs = 1, - int v_repeat = 1, bool permuted = false, bool kda = false) + int v_repeat = 1, bool permuted = false, bool kda = false, int64_t K = 1) : type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs), - v_repeat(v_repeat), permuted(permuted), kda(kda) {} + v_repeat(v_repeat), permuted(permuted), kda(kda), K(K) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * q; @@ -3863,7 +3864,7 @@ struct test_gated_delta_net : public test_case { const int64_t g_ne0 = kda ? 
head_size : 1; ggml_tensor * g = ggml_new_tensor_4d(ctx, type, g_ne0, head_count * v_repeat, n_seq_tokens, n_seqs); ggml_tensor * beta = ggml_new_tensor_4d(ctx, type, 1, head_count * v_repeat, n_seq_tokens, n_seqs); - ggml_tensor * state = ggml_new_tensor_2d(ctx, type, head_size * v_repeat * head_size * head_count, n_seqs); + ggml_tensor * state = ggml_new_tensor_3d(ctx, type, head_size * v_repeat * head_size * head_count, K, n_seqs); ggml_set_name(g, "g"); ggml_set_name(beta, "beta"); ggml_set_name(state, "state"); @@ -9034,6 +9035,18 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 33, 1, 1, false, true)); test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 100, 1, 1, false, true)); + // K > 1: output keeps the last min(n_tokens, K) per-token snapshots in the trailing K-token region. + // exact-match cases (K == n_seq_tokens): + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 16, 2, 1, 1, false, false, /*K=*/2)); + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 32, 4, 1, 1, false, false, /*K=*/4)); + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2, 1, false, false, /*K=*/4)); + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 8, 128, 4, 1, 1, false, false, /*K=*/4)); + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2, 1, false, true, /*K=*/4)); + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 8, 32, 4, 2, 2, false, true, /*K=*/4)); + // overflow: n_tokens > K — only the last K snapshots kept. 
+ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 32, 8, 1, 1, false, false, /*K=*/3)); + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 16, 2, 1, false, false, /*K=*/4)); + #if 0 // these tests are disabled to save execution time, sbut they can be handy for debugging test_cases.emplace_back(new test_llama(2, true)); diff --git a/tests/test-recurrent-state-rollback.cpp b/tests/test-recurrent-state-rollback.cpp new file mode 100644 index 00000000000..be19316db8a --- /dev/null +++ b/tests/test-recurrent-state-rollback.cpp @@ -0,0 +1,185 @@ +#include "arg.h" +#include "common.h" +#include "llama.h" + +#include +#include +#include +#include +#include + +static llama_context * make_ctx(const common_params & params, llama_model * model) { + auto cparams = common_context_params_to_llama(params); + cparams.n_seq_max = 1; + cparams.n_rs_seq = 8; + cparams.n_batch = std::max(cparams.n_batch, (uint32_t) (cparams.n_rs_seq + 1)); + cparams.n_ubatch = std::max(cparams.n_ubatch, (uint32_t) (cparams.n_rs_seq + 1)); + return llama_init_from_model(model, cparams); +} + +static bool decode_tokens(llama_context * ctx, const std::vector & tokens, uint32_t count) { + llama_batch batch = llama_batch_init(count, 0, 1); + for (uint32_t pos = 0; pos < count; ++pos) { + common_batch_add(batch, tokens[pos], pos, { 0 }, false); + } + const bool ok = llama_decode(ctx, batch) == 0; + llama_batch_free(batch); + return ok; +} + +static bool decode_one(llama_context * ctx, llama_token tok, llama_pos pos) { + llama_batch batch = llama_batch_init(1, 0, 1); + common_batch_add(batch, tok, pos, { 0 }, true); + const bool ok = llama_decode(ctx, batch) == 0; + llama_batch_free(batch); + return ok; +} + +int main(int argc, char ** argv) { + std::setlocale(LC_NUMERIC, "C"); + + common_params params; + params.sampling.seed = 1234; + params.n_predict = 1; + + common_init(); + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) { + return 1; + } + + 
ggml_backend_load_all(); + + common_init_result_ptr llama_init = common_init_from_params(params); + llama_model * model = llama_init->model(); + if (model == nullptr) { + fprintf(stderr, "%s : failed to init model\n", __func__); + return 1; + } + + if (!llama_model_is_recurrent(model) && !llama_model_is_hybrid(model)) { + fprintf(stderr, "%s : skipping for non-recurrent model\n", __func__); + return 0; + } + + const llama_vocab * vocab = llama_model_get_vocab(model); + const int n_vocab = llama_vocab_n_tokens(vocab); + + llama_context * ctx_src = make_ctx(params, model); + llama_context * ctx_dst = make_ctx(params, model); + if (ctx_src == nullptr || ctx_dst == nullptr) { + fprintf(stderr, "%s : failed to init contexts\n", __func__); + return 1; + } + + if (llama_n_rs_seq(ctx_src) == 0) { + fprintf(stderr, "%s : skipping because n_rs_seq is disabled\n", __func__); + llama_free(ctx_src); + llama_free(ctx_dst); + return 0; + } + + std::vector tokens = common_tokenize(ctx_src, "The quick brown fox jumps", true); + const uint32_t n_rs_seq = llama_n_rs_seq(ctx_src); + if (tokens.size() > n_rs_seq + 1) { + tokens.resize(n_rs_seq + 1); + } + if (tokens.size() < 2) { + fprintf(stderr, "%s : not enough prompt tokens\n", __func__); + return 1; + } + const uint32_t n_tokens = tokens.size(); + const llama_token last_tok = tokens.back(); + const llama_pos last_pos = (llama_pos) n_tokens - 2; + + // Decode the full prompt on the source, then roll back the last position. + // Rollback leaves the recurrent memory in a snapshot state (rs_idx != 0). + if (!decode_tokens(ctx_src, tokens, n_tokens)) { + fprintf(stderr, "%s : failed to decode prompt\n", __func__); + return 1; + } + if (!llama_memory_seq_rm(llama_get_memory(ctx_src), 0, last_pos, -1)) { + fprintf(stderr, "%s : rollback failed\n", __func__); + return 1; + } + + // Save the rolled-back state and restore it into a fresh context. 
+ common_prompt_checkpoint ckpt; + ckpt.update_tgt(ctx_src, 0, 0); + ckpt.load_tgt(ctx_dst, 0, 0); + + // Replay the rolled-back token on both contexts and compare logits. + if (!decode_one(ctx_src, last_tok, last_pos) || + !decode_one(ctx_dst, last_tok, last_pos)) { + fprintf(stderr, "%s : replay failed\n", __func__); + return 1; + } + + const float * logits_src = llama_get_logits_ith(ctx_src, 0); + const float * logits_dst = llama_get_logits_ith(ctx_dst, 0); + if (logits_src == nullptr || logits_dst == nullptr) { + fprintf(stderr, "%s : missing logits\n", __func__); + return 1; + } + + constexpr float eps = 1e-5f; + for (int i = 0; i < n_vocab; ++i) { + if (std::fabs(logits_src[i] - logits_dst[i]) > eps) { + fprintf(stderr, "%s : logits mismatch at token %d (%g != %g)\n", + __func__, i, (double) logits_src[i], (double) logits_dst[i]); + return 1; + } + } + + // Repeat the load into a context that already has its own rollback state: + // groups 1..n_rs_seq hold a *different* prompt's history, and rs_idx[0] is + // non-zero at load time. The restore must wipe that state and still match. 
+ llama_context * ctx_dirty = make_ctx(params, model); + if (ctx_dirty == nullptr) { + fprintf(stderr, "%s : failed to init dirty ctx\n", __func__); + return 1; + } + + std::vector noise = tokens; + for (auto & t : noise) { + t = (t + 1) % n_vocab; + if (t < 0) { + t = 0; + } + } + if (!decode_tokens(ctx_dirty, noise, n_tokens)) { + fprintf(stderr, "%s : dirty prompt decode failed\n", __func__); + return 1; + } + if (!llama_memory_seq_rm(llama_get_memory(ctx_dirty), 0, last_pos, -1)) { + fprintf(stderr, "%s : dirty rollback failed\n", __func__); + return 1; + } + + ckpt.load_tgt(ctx_dirty, 0, 0); + + if (!decode_one(ctx_dirty, last_tok, last_pos)) { + fprintf(stderr, "%s : dirty replay failed\n", __func__); + return 1; + } + + const float * logits_dirty = llama_get_logits_ith(ctx_dirty, 0); + if (logits_dirty == nullptr) { + fprintf(stderr, "%s : missing dirty logits\n", __func__); + return 1; + } + + for (int i = 0; i < n_vocab; ++i) { + if (std::fabs(logits_src[i] - logits_dirty[i]) > eps) { + fprintf(stderr, "%s : dirty-ctx logits mismatch at token %d (%g != %g)\n", + __func__, i, (double) logits_src[i], (double) logits_dirty[i]); + return 1; + } + } + + fprintf(stderr, "%s : recurrent rollback checkpoint restored successfully\n", __func__); + llama_free(ctx_src); + llama_free(ctx_dst); + llama_free(ctx_dirty); + return 0; +} diff --git a/tools/cli/README.md b/tools/cli/README.md index 9f0574d25d3..c40b5a21cc0 100644 --- a/tools/cli/README.md +++ b/tools/cli/README.md @@ -55,7 +55,6 @@ | `-ctv, --cache-type-v TYPE` | KV cache data type for V
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_V) | | `-dt, --defrag-thold N` | KV cache defragmentation threshold (DEPRECATED)
(env: LLAMA_ARG_DEFRAG_THOLD) | | `-np, --parallel N` | number of parallel sequences to decode (default: 1)
(env: LLAMA_ARG_N_PARALLEL) | -| `--rpc SERVERS` | comma-separated list of RPC servers (host:port)
(env: LLAMA_ARG_RPC) | | `--mlock` | force system to keep model in RAM rather than swapping or compressing
(env: LLAMA_ARG_MLOCK) | | `--mmap, --no-mmap` | whether to memory-map model. (if mmap disabled, slower load but may reduce pageouts if not using mlock) (default: enabled)
(env: LLAMA_ARG_MMAP) | | `-dio, --direct-io, -ndio, --no-direct-io` | use DirectIO if available. (default: disabled)
(env: LLAMA_ARG_DIO) | @@ -94,8 +93,8 @@ | `-v, --verbose, --log-verbose` | Set verbosity level to infinity (i.e. log all messages, useful for debugging) | | `--offline` | Offline mode: forces use of cache, prevents network access
(env: LLAMA_OFFLINE) | | `-lv, --verbosity, --log-verbosity N` | Set the verbosity threshold. Messages with a higher verbosity will be ignored. Values:
- 0: generic output
- 1: error
- 2: warning
- 3: info
- 4: debug
(default: 3)

(env: LLAMA_LOG_VERBOSITY) | -| `--log-prefix` | Enable prefix in log messages
(env: LLAMA_LOG_PREFIX) | -| `--log-timestamps` | Enable timestamps in log messages
(env: LLAMA_LOG_TIMESTAMPS) | +| `--log-prefix, --no-log-prefix` | Enable prefix in log messages
(env: LLAMA_ARG_LOG_PREFIX) | +| `--log-timestamps, --no-log-timestamps` | Enable timestamps in log messages
(env: LLAMA_ARG_LOG_TIMESTAMPS) | | `--spec-draft-type-k, -ctkd, --cache-type-k-draft TYPE` | KV cache data type for K for the draft model
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_SPEC_DRAFT_CACHE_TYPE_K) | | `--spec-draft-type-v, -ctvd, --cache-type-v-draft TYPE` | KV cache data type for V for the draft model
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_SPEC_DRAFT_CACHE_TYPE_V) | @@ -199,7 +198,7 @@ | `--spec-draft-device, -devd, --device-draft ` | comma-separated list of devices to use for offloading the draft model (none = don't offload)
use --list-devices to see a list of available devices | | `--spec-draft-ngl, -ngld, --gpu-layers-draft, --n-gpu-layers-draft N` | max. number of draft model layers to store in VRAM, either an exact number, 'auto', or 'all' (default: auto)
(env: LLAMA_ARG_N_GPU_LAYERS_DRAFT) | | `--spec-draft-model, -md, --model-draft FNAME` | draft model for speculative decoding (default: unused)
(env: LLAMA_ARG_SPEC_DRAFT_MODEL) | -| `--spec-type none,draft-simple,draft-eagle3,ngram-simple,ngram-map-k,ngram-map-k4v,ngram-mod,ngram-cache` | comma-separated list of types of speculative decoding to use (default: none)

(env: LLAMA_ARG_SPEC_TYPE) | +| `--spec-type none,draft-simple,draft-eagle3,draft-mtp,ngram-simple,ngram-map-k,ngram-map-k4v,ngram-mod,ngram-cache` | comma-separated list of types of speculative decoding to use (default: none)

(env: LLAMA_ARG_SPEC_TYPE) | | `--spec-ngram-mod-n-min N` | minimum number of ngram tokens to use for ngram-based speculative decoding (default: 48) | | `--spec-ngram-mod-n-max N` | maximum number of ngram tokens to use for ngram-based speculative decoding (default: 64) | | `--spec-ngram-mod-n-match N` | ngram-mod lookup length (default: 24) | diff --git a/tools/completion/README.md b/tools/completion/README.md index 048cf7416fc..e5dd7f6f4e7 100644 --- a/tools/completion/README.md +++ b/tools/completion/README.md @@ -138,7 +138,6 @@ llama-completion.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --ignore-eos -n -1 | `-ctv, --cache-type-v TYPE` | KV cache data type for V
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_V) | | `-dt, --defrag-thold N` | KV cache defragmentation threshold (DEPRECATED)
(env: LLAMA_ARG_DEFRAG_THOLD) | | `-np, --parallel N` | number of parallel sequences to decode (default: 1)
(env: LLAMA_ARG_N_PARALLEL) | -| `--rpc SERVERS` | comma-separated list of RPC servers (host:port)
(env: LLAMA_ARG_RPC) | | `--mlock` | force system to keep model in RAM rather than swapping or compressing
(env: LLAMA_ARG_MLOCK) | | `--mmap, --no-mmap` | whether to memory-map model. (if mmap disabled, slower load but may reduce pageouts if not using mlock) (default: enabled)
(env: LLAMA_ARG_MMAP) | | `-dio, --direct-io, -ndio, --no-direct-io` | use DirectIO if available. (default: disabled)
(env: LLAMA_ARG_DIO) | @@ -177,8 +176,8 @@ llama-completion.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --ignore-eos -n -1 | `-v, --verbose, --log-verbose` | Set verbosity level to infinity (i.e. log all messages, useful for debugging) | | `--offline` | Offline mode: forces use of cache, prevents network access
(env: LLAMA_OFFLINE) | | `-lv, --verbosity, --log-verbosity N` | Set the verbosity threshold. Messages with a higher verbosity will be ignored. Values:
- 0: generic output
- 1: error
- 2: warning
- 3: info
- 4: debug
(default: 3)

(env: LLAMA_LOG_VERBOSITY) | -| `--log-prefix` | Enable prefix in log messages
(env: LLAMA_LOG_PREFIX) | -| `--log-timestamps` | Enable timestamps in log messages
(env: LLAMA_LOG_TIMESTAMPS) | +| `--log-prefix, --no-log-prefix` | Enable prefix in log messages
(env: LLAMA_ARG_LOG_PREFIX) | +| `--log-timestamps, --no-log-timestamps` | Enable timestamps in log messages
(env: LLAMA_ARG_LOG_TIMESTAMPS) | | `--spec-draft-type-k, -ctkd, --cache-type-k-draft TYPE` | KV cache data type for K for the draft model
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_SPEC_DRAFT_CACHE_TYPE_K) | | `--spec-draft-type-v, -ctvd, --cache-type-v-draft TYPE` | KV cache data type for V for the draft model
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_SPEC_DRAFT_CACHE_TYPE_V) | diff --git a/tools/server/README.md b/tools/server/README.md index 2ed7fe16ee2..11098af2883 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -72,7 +72,6 @@ For the full list of features, please refer to [server's changelog](https://gith | `-ctk, --cache-type-k TYPE` | KV cache data type for K
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_K) | | `-ctv, --cache-type-v TYPE` | KV cache data type for V
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_V) | | `-dt, --defrag-thold N` | KV cache defragmentation threshold (DEPRECATED)
(env: LLAMA_ARG_DEFRAG_THOLD) | -| `--rpc SERVERS` | comma-separated list of RPC servers (host:port)
(env: LLAMA_ARG_RPC) | | `--mlock` | force system to keep model in RAM rather than swapping or compressing
(env: LLAMA_ARG_MLOCK) | | `--mmap, --no-mmap` | whether to memory-map model. (if mmap disabled, slower load but may reduce pageouts if not using mlock) (default: enabled)
(env: LLAMA_ARG_MMAP) | | `-dio, --direct-io, -ndio, --no-direct-io` | use DirectIO if available. (default: disabled)
(env: LLAMA_ARG_DIO) | @@ -111,8 +110,8 @@ For the full list of features, please refer to [server's changelog](https://gith | `-v, --verbose, --log-verbose` | Set verbosity level to infinity (i.e. log all messages, useful for debugging) | | `--offline` | Offline mode: forces use of cache, prevents network access
(env: LLAMA_OFFLINE) | | `-lv, --verbosity, --log-verbosity N` | Set the verbosity threshold. Messages with a higher verbosity will be ignored. Values:
- 0: generic output
- 1: error
- 2: warning
- 3: info
- 4: debug
(default: 3)

(env: LLAMA_LOG_VERBOSITY) | -| `--log-prefix` | Enable prefix in log messages
(env: LLAMA_LOG_PREFIX) | -| `--log-timestamps` | Enable timestamps in log messages
(env: LLAMA_LOG_TIMESTAMPS) | +| `--log-prefix, --no-log-prefix` | Enable prefix in log messages
(env: LLAMA_ARG_LOG_PREFIX) | +| `--log-timestamps, --no-log-timestamps` | Enable timestamps in log messages
(env: LLAMA_ARG_LOG_TIMESTAMPS) | | `--spec-draft-type-k, -ctkd, --cache-type-k-draft TYPE` | KV cache data type for K for the draft model
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_SPEC_DRAFT_CACHE_TYPE_K) | | `--spec-draft-type-v, -ctvd, --cache-type-v-draft TYPE` | KV cache data type for V for the draft model
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_SPEC_DRAFT_CACHE_TYPE_V) | @@ -189,11 +188,15 @@ For the full list of features, please refer to [server's changelog](https://gith | `--reuse-port` | allow multiple sockets to bind to the same port (default: disabled)
(env: LLAMA_ARG_REUSE_PORT) | | `--path PATH` | path to serve static files from (default: )
(env: LLAMA_ARG_STATIC_PATH) | | `--api-prefix PREFIX` | prefix path the server serves from, without the trailing slash (default: )
(env: LLAMA_ARG_API_PREFIX) | -| `--ui-config JSON` / `--webui-config JSON` (deprecated) | JSON that provides default UI settings (overrides UI defaults)
(env: LLAMA_ARG_UI_CONFIG / LLAMA_ARG_WEBUI_CONFIG) | -| `--ui-config-file PATH` / `--webui-config-file PATH` (deprecated) | JSON file that provides default UI settings (overrides UI defaults)
(env: LLAMA_ARG_UI_CONFIG_FILE / LLAMA_ARG_WEBUI_CONFIG_FILE) | -| `--ui-mcp-proxy, --no-ui-mcp-proxy` / `--webui-mcp-proxy, --no-webui-mcp-proxy` (deprecated) | experimental: whether to enable MCP CORS proxy - do not enable in untrusted environments (default: disabled)
(env: LLAMA_ARG_UI_MCP_PROXY / LLAMA_ARG_WEBUI_MCP_PROXY) | +| `--webui-config JSON` | [DEPRECATED: use --ui-config] JSON that provides default WebUI settings (overrides WebUI defaults)
(env: LLAMA_ARG_WEBUI_CONFIG) | +| `--ui-config JSON` | JSON that provides default UI settings (overrides UI defaults)
(env: LLAMA_ARG_UI_CONFIG) | +| `--webui-config-file PATH` | [DEPRECATED: use --ui-config-file] JSON file that provides default WebUI settings (overrides WebUI defaults)
(env: LLAMA_ARG_WEBUI_CONFIG_FILE) | +| `--ui-config-file PATH` | JSON file that provides default UI settings (overrides UI defaults)
(env: LLAMA_ARG_UI_CONFIG_FILE) | +| `--webui-mcp-proxy, --no-webui-mcp-proxy` | [DEPRECATED: use --ui-mcp-proxy/--no-ui-mcp-proxy] experimental: whether to enable MCP CORS proxy
(env: LLAMA_ARG_WEBUI_MCP_PROXY) | +| `--ui-mcp-proxy, --no-ui-mcp-proxy` | experimental: whether to enable MCP CORS proxy - do not enable in untrusted environments (default: disabled)
(env: LLAMA_ARG_UI_MCP_PROXY) | | `--tools TOOL1,TOOL2,...` | experimental: whether to enable built-in tools for AI agents - do not enable in untrusted environments (default: no tools)
specify "all" to enable all tools
available tools: read_file, file_glob_search, grep_search, exec_shell_command, write_file, edit_file, apply_diff, get_datetime
(env: LLAMA_ARG_TOOLS) | -| `--ui, --no-ui` / `--webui, --no-webui` (deprecated) | whether to enable the Web UI (default: enabled)
(env: LLAMA_ARG_UI / LLAMA_ARG_WEBUI) | +| `--webui, --no-webui` | [DEPRECATED: use --ui/--no-ui] whether to enable the Web UI
(env: LLAMA_ARG_WEBUI) | +| `--ui, --no-ui` | whether to enable the Web UI (default: enabled)
(env: LLAMA_ARG_UI) | | `--embedding, --embeddings` | restrict to only support embedding use case; use only with dedicated embedding models (default: disabled)
(env: LLAMA_ARG_EMBEDDINGS) | | `--rerank, --reranking` | enable reranking endpoint on server (default: disabled)
(env: LLAMA_ARG_RERANKING) | | `--api-key KEY` | API key to use for authentication, multiple keys can be provided as a comma-separated list (default: none)
(env: LLAMA_API_KEY) | @@ -248,7 +251,7 @@ For the full list of features, please refer to [server's changelog](https://gith | `--spec-draft-device, -devd, --device-draft ` | comma-separated list of devices to use for offloading the draft model (none = don't offload)
use --list-devices to see a list of available devices | | `--spec-draft-ngl, -ngld, --gpu-layers-draft, --n-gpu-layers-draft N` | max. number of draft model layers to store in VRAM, either an exact number, 'auto', or 'all' (default: auto)
(env: LLAMA_ARG_N_GPU_LAYERS_DRAFT) | | `--spec-draft-model, -md, --model-draft FNAME` | draft model for speculative decoding (default: unused)
(env: LLAMA_ARG_SPEC_DRAFT_MODEL) | -| `--spec-type none,draft-simple,draft-eagle3,ngram-simple,ngram-map-k,ngram-map-k4v,ngram-mod,ngram-cache` | comma-separated list of types of speculative decoding to use (default: none)

(env: LLAMA_ARG_SPEC_TYPE) | +| `--spec-type none,draft-simple,draft-eagle3,draft-mtp,ngram-simple,ngram-map-k,ngram-map-k4v,ngram-mod,ngram-cache` | comma-separated list of types of speculative decoding to use (default: none)

(env: LLAMA_ARG_SPEC_TYPE) | | `--spec-ngram-mod-n-min N` | minimum number of ngram tokens to use for ngram-based speculative decoding (default: 48) | | `--spec-ngram-mod-n-max N` | maximum number of ngram tokens to use for ngram-based speculative decoding (default: 64) | | `--spec-ngram-mod-n-match N` | ngram-mod lookup length (default: 24) | diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 1dc19536866..4d162f81d9b 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -145,9 +145,9 @@ struct server_slot { SLT_INF(*this, "clearing prompt with %zu tokens\n", prompt.tokens.size()); - llama_memory_seq_rm(llama_get_memory(ctx_tgt), id, -1, -1); + common_context_seq_rm(ctx_tgt, id, -1, -1); if (ctx_dft) { - llama_memory_seq_rm(llama_get_memory(ctx_dft), id, -1, -1); + common_context_seq_rm(ctx_dft, id, -1, -1); } prompt.tokens.clear(); @@ -238,8 +238,14 @@ struct server_slot { (ggml_time_us() - t_start) / 1000.0, n_text, (int) prompt.tokens.size()); } + bool need_embd() const { + GGML_ASSERT(task); + return task->need_embd() || (spec && common_speculative_need_embd(spec)); + } + // if the context does not have a memory module then all embeddings have to be computed within a single ubatch // also we cannot split if the pooling would require any past tokens + // (MTP supports splitting — uses task->need_embd() not need_embd()) bool can_split() const { GGML_ASSERT(task); @@ -511,12 +517,12 @@ struct server_slot { void copy_state_to(server_slot & other) const { GGML_ASSERT(state == SLOT_STATE_DONE_PROMPT); - llama_memory_seq_rm(llama_get_memory(ctx_tgt), other.id, -1, -1); - llama_memory_seq_cp(llama_get_memory(ctx_tgt), id, other.id, -1, -1); + common_context_seq_rm(ctx_tgt, other.id, -1, -1); + common_context_seq_cp(ctx_tgt, id, other.id, -1, -1); if (ctx_dft) { - llama_memory_seq_rm(llama_get_memory(ctx_dft), other.id, -1, -1); - llama_memory_seq_cp(llama_get_memory(ctx_dft), id, other.id, -1, -1); + 
common_context_seq_rm(ctx_dft, other.id, -1, -1); + common_context_seq_cp(ctx_dft, id, other.id, -1, -1); } other.n_decoded = n_decoded; @@ -775,10 +781,40 @@ struct server_context_impl { } auto cparams = common_context_params_to_llama(params_dft); + + const bool spec_mtp = std::find(params_base.speculative.types.begin(), + params_base.speculative.types.end(), + COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end(); + if (spec_mtp) { + cparams.ctx_type = LLAMA_CONTEXT_TYPE_MTP; + } + + // note: for small models maybe we can set this to the maximum possible draft from all speculative types + // the extra memory for small models is likely negligible? + cparams.n_rs_seq = 0; ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams)); ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get()); + params_base.speculative.draft.ctx_tgt = ctx_tgt; + params_base.speculative.draft.ctx_dft = ctx_dft.get(); + } else if (std::find(params_base.speculative.types.begin(), params_base.speculative.types.end(), + COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end()) { + SRV_INF("creating MTP draft context against the target model '%s'\n", + params_base.model.path.c_str()); + + auto cparams_mtp = common_context_params_to_llama(params_base); + cparams_mtp.ctx_type = LLAMA_CONTEXT_TYPE_MTP; + cparams_mtp.n_rs_seq = 0; + + ctx_dft.reset(llama_init_from_model(model_tgt, cparams_mtp)); + if (ctx_dft == nullptr) { + SRV_ERR("%s", "failed to create MTP context\n"); + return false; + } + + ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get()); + params_base.speculative.draft.ctx_tgt = ctx_tgt; params_base.speculative.draft.ctx_dft = ctx_dft.get(); } @@ -2194,12 +2230,12 @@ struct server_context_impl { SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard); - llama_memory_seq_rm (llama_get_memory(ctx_tgt), slot.id, n_keep , n_keep + n_discard); - 
llama_memory_seq_add(llama_get_memory(ctx_tgt), slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard); + common_context_seq_rm (ctx_tgt, slot.id, n_keep , n_keep + n_discard); + common_context_seq_add(ctx_tgt, slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard); if (ctx_dft) { - llama_memory_seq_rm (llama_get_memory(ctx_dft.get()), slot.id, n_keep , n_keep + n_discard); - llama_memory_seq_add(llama_get_memory(ctx_dft.get()), slot.id, n_keep + n_discard, slot.prompt.tokens.pos_next(), -n_discard); + common_context_seq_rm (ctx_dft.get(), slot.id, n_keep , n_keep + n_discard); + common_context_seq_add(ctx_dft.get(), slot.id, n_keep + n_discard, slot.prompt.tokens.pos_next(), -n_discard); } // add generated tokens to cache @@ -2306,14 +2342,23 @@ struct server_context_impl { slot.n_draft_total += draft.size(); // TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL] + const bool use_ckpt_dft = ctx_dft_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; + if (ctx_dft) { - ckpt.load_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + if (use_ckpt_dft) { + ckpt.load_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + } - llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), slot.id, ckpt.pos_max + 1, -1); + common_context_seq_rm(ctx_dft.get(), slot.id, ckpt.pos_max + 1, -1); } if (!draft.empty()) { - const bool use_ckpt_tgt = ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; + const bool use_ckpt_tgt = + ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL || + (ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_RS && draft.size() > llama_n_rs_seq(ctx_tgt)); + + const bool use_ckpt_dft = + (ctx_dft_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_RS && draft.size() > llama_n_rs_seq(ctx_dft.get())); if (use_ckpt_tgt) { //const int64_t t_start = ggml_time_us(); @@ -2328,6 +2373,10 @@ struct 
server_context_impl { (float) ckpt.size() / 1024 / 1024, (float) ckpt.data_dft.size() / 1024 / 1024); } + + if (use_ckpt_dft) { + ckpt.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + } } } @@ -2499,12 +2548,12 @@ struct server_context_impl { const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c; - llama_memory_seq_rm (llama_get_memory(ctx_tgt), slot.id, head_p, head_c); - llama_memory_seq_add(llama_get_memory(ctx_tgt), slot.id, head_c, head_c + n_match, kv_shift); + common_context_seq_rm (ctx_tgt, slot.id, head_p, head_c); + common_context_seq_add(ctx_tgt, slot.id, head_c, head_c + n_match, kv_shift); if (ctx_dft) { - llama_memory_seq_rm (llama_get_memory(ctx_dft.get()), slot.id, head_p, head_c); - llama_memory_seq_add(llama_get_memory(ctx_dft.get()), slot.id, head_c, head_c + n_match, kv_shift); + common_context_seq_rm (ctx_dft.get(), slot.id, head_p, head_c); + common_context_seq_add(ctx_dft.get(), slot.id, head_c, head_c + n_match, kv_shift); } for (size_t i = 0; i < n_match; i++) { @@ -2667,18 +2716,10 @@ struct server_context_impl { SLT_TRC(slot, "cached n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0); - if (!llama_memory_seq_rm(llama_get_memory(ctx_tgt), slot.id, p0, -1)) { - SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0); - - slot.prompt_clear(true); - - // there is no common part left - slot.n_prompt_tokens_cache = 0; - } else { - if (ctx_dft && !llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), slot.id, p0, -1)) { - GGML_ABORT("failed to truncate draft context\n"); - } - } + common_context_seq_rm(ctx_tgt, slot.id, p0, -1); + if (ctx_dft) { + common_context_seq_rm(ctx_dft.get(), slot.id, p0, -1); + } // If using an alora, there may be uncached tokens that come // before the invocation sequence. 
When this happens, the @@ -2703,9 +2744,11 @@ struct server_context_impl { // checkpoints are created only if: // - the model does not support partial sequence removal // - the model uses SWA (and we are not using `swa_full`) + // - the model supports partial sequence removal but only up to a fixed bound do_checkpoint = do_checkpoint && ( - (ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) || - (n_swa > 0)); + ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL || + ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_RS || + n_swa > 0); bool has_mtmd = false; @@ -2758,12 +2801,14 @@ struct server_context_impl { break; } - // embedding requires all tokens in the batch to be output + // embedding requires all tokens in the batch to be output; + // MTP also wants logits at every prompt position so the + // streaming hook can mirror t_h_pre_norm into ctx_dft. common_batch_add(batch, cur_tok, slot.prompt.tokens.pos_next(), { slot.id }, - slot.task->need_embd()); + slot.need_embd()); slot.prompt.tokens.push_back(cur_tok); slot.n_prompt_tokens_processed++; @@ -2877,7 +2922,7 @@ struct server_context_impl { slot_batched->lora[alora_disabled_id].scale = alora_scale; } - llama_set_embeddings(ctx_tgt, slot_batched->task->need_embd()); + llama_set_embeddings(ctx_tgt, slot_batched->need_embd()); } if (batch.n_tokens == 0) { @@ -3140,13 +3185,8 @@ struct server_context_impl { // verify and try to accept the draft { - const bool use_ckpt_tgt = ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; - - // only save the sampler sampler state if we use checkpoints - common_sampler_ptr smpl_save; - if (use_ckpt_tgt) { - smpl_save.reset(common_sampler_clone(slot.smpl.get())); - } + // save the sampler sampler state in case we need to restore it + common_sampler_ptr smpl_save(common_sampler_clone(slot.smpl.get())); GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1); auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx_tgt, slot.spec_i_batch, 
slot.spec_draft); @@ -3154,8 +3194,14 @@ struct server_context_impl { GGML_ASSERT(accepted.size() >= 1); + const uint32_t n_rollback = slot.spec_draft.size() + 1 - accepted.size(); + + const bool use_ckpt_tgt = + ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL || + (ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_RS && n_rollback > llama_n_rs_seq(ctx_tgt)); + // check for partial draft acceptance - if (accepted.size() < slot.spec_draft.size() + 1) { + if (n_rollback > 0) { if (use_ckpt_tgt) { if (trace > 0) { SLT_INF(slot, "accepted %2zu/%2zu draft tokens (restore checkpoint)\n", accepted.size() - 1, slot.spec_draft.size()); @@ -3171,13 +3217,13 @@ struct server_context_impl { { ckpt.load_tgt(slot.ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); - llama_memory_seq_rm(llama_get_memory(slot.ctx_tgt), slot.id, ckpt.pos_max + 1, -1); + common_context_seq_rm(slot.ctx_tgt, slot.id, ckpt.pos_max + 1, -1); } if (slot.ctx_dft) { ckpt.load_dft(slot.ctx_dft, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); - llama_memory_seq_rm(llama_get_memory(slot.ctx_dft), slot.id, ckpt.pos_max + 1, -1); + common_context_seq_rm(slot.ctx_dft, slot.id, ckpt.pos_max + 1, -1); } slot.prompt.tokens.keep_first(ckpt.n_tokens); @@ -3200,7 +3246,6 @@ struct server_context_impl { const auto ids = std::move(slot.spec_draft); - slot.n_decoded += ids.size(); slot.t_token_generation = std::max(1, t_current - slot.t_start_generation) / 1e3; // update how many tokens out of those tested were accepted @@ -3213,9 +3258,9 @@ struct server_context_impl { slot.sampled = ids.back(); // last accepted token SLT_DBG(slot, "add accepted tokens: sampled=%d, ids.size=%zu, n_draft=%zu\n", slot.sampled, ids.size(), n_draft); - llama_memory_seq_rm(llama_get_memory(slot.ctx_tgt), slot.id, slot.prompt.tokens.pos_next(), -1); + common_context_seq_rm(slot.ctx_tgt, slot.id, slot.prompt.tokens.pos_next(), -1); if (slot.ctx_dft) { - 
llama_memory_seq_rm(llama_get_memory(slot.ctx_dft), slot.id, slot.prompt.tokens.pos_next(), -1); + common_context_seq_rm(slot.ctx_dft, slot.id, slot.prompt.tokens.pos_next(), -1); } for (size_t i = 0; i < ids.size(); ++i) { @@ -3227,6 +3272,8 @@ struct server_context_impl { // TODO: set result.probs + slot.n_decoded += 1; + if (!process_token(result, slot)) { slot.print_timings(); send_final_response(slot); From 18675b6bbc8ea9b402453c256d70c09e4f95c32a Mon Sep 17 00:00:00 2001 From: "Alessandro de Oliveira Faria (A.K.A.CABELO)" Date: Sat, 16 May 2026 09:25:21 -0300 Subject: [PATCH 04/10] vendor : update cpp-httplib to 0.45.0 (#23103) --- scripts/sync_vendor.py | 2 +- vendor/cpp-httplib/httplib.cpp | 65 +++++++++++++++++----------------- vendor/cpp-httplib/httplib.h | 4 +-- 3 files changed, 36 insertions(+), 35 deletions(-) diff --git a/scripts/sync_vendor.py b/scripts/sync_vendor.py index 0ddef236a4c..658f7326b96 100755 --- a/scripts/sync_vendor.py +++ b/scripts/sync_vendor.py @@ -5,7 +5,7 @@ import sys import subprocess -HTTPLIB_VERSION = "refs/tags/v0.44.0" +HTTPLIB_VERSION = "refs/tags/v0.45.0" vendor = { "https://github.com/nlohmann/json/releases/latest/download/json.hpp": "vendor/nlohmann/json.hpp", diff --git a/vendor/cpp-httplib/httplib.cpp b/vendor/cpp-httplib/httplib.cpp index ed9b09ba01c..b28549607a2 100644 --- a/vendor/cpp-httplib/httplib.cpp +++ b/vendor/cpp-httplib/httplib.cpp @@ -4723,17 +4723,24 @@ write_multipart_ranges_data(Stream &strm, const Request &req, Response &res, }); } +bool has_framed_body(const Request &req) { + return is_chunked_transfer_encoding(req.headers) || + req.get_header_value_u64("Content-Length") > 0; +} + +bool is_connection_persistent(const Request &req) { + auto conn = req.get_header_value("Connection"); + if (conn == "close") { return false; } + if (req.version == "HTTP/1.0" && conn != "Keep-Alive") { return false; } + return true; +} + bool expect_content(const Request &req) { if (req.method == "POST" || req.method == 
"PUT" || req.method == "PATCH" || req.method == "DELETE") { return true; } - if (req.has_header("Content-Length") && - req.get_header_value_u64("Content-Length") > 0) { - return true; - } - if (is_chunked_transfer_encoding(req.headers)) { return true; } - return false; + return has_framed_body(req); } #ifdef _WIN32 @@ -7449,29 +7456,18 @@ bool Server::read_content_core( size_t /*len*/) { return receiver(buf, n); }; } - // RFC 7230 Section 3.3.3: If this is a request message and none of the above - // are true (no Transfer-Encoding and no Content-Length), then the message - // body length is zero (no message body is present). - // - // For non-SSL builds, detect clients that send a body without a - // Content-Length header (raw HTTP over TCP). Check both the stream's - // internal read buffer (data already read from the socket during header - // parsing) and the socket itself for pending data. If data is found and - // exceeds the configured payload limit, reject with 413. - // For SSL builds we cannot reliably peek the decrypted application bytes, - // so keep the original behaviour. + // RFC 9112 §6: no Transfer-Encoding and no Content-Length means no body. + // For non-SSL builds we still scan non-persistent connections for stray + // body bytes so the payload limit is enforced (413). On keep-alive, + // pending bytes may be the next request (issue #2450), so skip. #if !defined(CPPHTTPLIB_SSL_ENABLED) if (!req.has_header("Content-Length") && !detail::is_chunked_transfer_encoding(req.headers)) { - // Only check if payload_max_length is set to a finite value - if (payload_max_length_ > 0 && + if (!detail::is_connection_persistent(req) && payload_max_length_ > 0 && payload_max_length_ < (std::numeric_limits::max)()) { - // Check if there is data already buffered in the stream (read during - // header parsing) or pending on the socket. Use a non-blocking socket - // check to avoid deadlock when the client sends no body. 
- bool has_data = strm.is_readable(); + auto has_data = strm.is_readable(); if (!has_data) { - socket_t s = strm.socket(); + auto s = strm.socket(); if (s != INVALID_SOCKET) { has_data = detail::select_read(s, 0, 0) > 0; } @@ -8033,6 +8029,11 @@ get_client_ip(const std::string &x_forwarded_for, ip_list.emplace_back(std::string(b + r.first, b + r.second)); }); + // A malformed X-Forwarded-For (empty, comma-only, whitespace-only) yields + // no segments. Signal "no client IP derived" with an empty string so the + // caller can fall back to the connection-level remote address. + if (ip_list.empty()) { return std::string(); } + for (size_t i = 0; i < ip_list.size(); ++i) { auto ip = ip_list[i]; @@ -8123,7 +8124,8 @@ Server::process_request(Stream &strm, const std::string &remote_addr, if (!trusted_proxies_.empty() && req.has_header("X-Forwarded-For")) { auto x_forwarded_for = req.get_header_value("X-Forwarded-For"); - req.remote_addr = get_client_ip(x_forwarded_for, trusted_proxies_); + auto derived = get_client_ip(x_forwarded_for, trusted_proxies_); + req.remote_addr = derived.empty() ? remote_addr : derived; } else { req.remote_addr = remote_addr; } @@ -8325,15 +8327,14 @@ Server::process_request(Stream &strm, const std::string &remote_addr, ret = write_response(strm, close_connection, req, res); } - // Drain any unconsumed request body to prevent request smuggling on - // keep-alive connections. - if (!req.body_consumed_ && detail::expect_content(req)) { - int drain_status = 200; // required by read_content signature + // Drain any unconsumed framed body to prevent request smuggling on + // keep-alive. Without framing there is no body to drain — reading would + // consume the next request (issue #2450). 
+ if (!req.body_consumed_ && detail::has_framed_body(req)) { + int dummy_status; if (!detail::read_content( - strm, req, payload_max_length_, drain_status, nullptr, + strm, req, payload_max_length_, dummy_status, nullptr, [](const char *, size_t, size_t, size_t) { return true; }, false)) { - // Body exceeds payload limit or read error — close the connection - // to prevent leftover bytes from being misinterpreted. connection_closed = true; } } diff --git a/vendor/cpp-httplib/httplib.h b/vendor/cpp-httplib/httplib.h index b954b94af89..af750cdd905 100644 --- a/vendor/cpp-httplib/httplib.h +++ b/vendor/cpp-httplib/httplib.h @@ -8,8 +8,8 @@ #ifndef CPPHTTPLIB_HTTPLIB_H #define CPPHTTPLIB_HTTPLIB_H -#define CPPHTTPLIB_VERSION "0.44.0" -#define CPPHTTPLIB_VERSION_NUM "0x002c00" +#define CPPHTTPLIB_VERSION "0.45.0" +#define CPPHTTPLIB_VERSION_NUM "0x002d00" #ifdef _WIN32 #if defined(_WIN32_WINNT) && _WIN32_WINNT < 0x0A00 From 25b1bc9c2f9aa0a390b968ee1ffd9ff01340a3fe Mon Sep 17 00:00:00 2001 From: Holger Voormann Date: Sat, 16 May 2026 14:42:38 +0200 Subject: [PATCH 05/10] ui: Correct links in `tools/ui/README.md` [no ci] (#23139) In `tools/ui/README.md`, update the relative links, now that the `README.md` file has been moved from `tools/server/webui/` to `tools/ui/`. See https://github.com/ggml-org/llama.cpp/commit/59778f0196a82db32580bb649d5d839355d6d7bf. 
--- tools/ui/README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tools/ui/README.md b/tools/ui/README.md index abbbabe923d..4334f968162 100644 --- a/tools/ui/README.md +++ b/tools/ui/README.md @@ -681,6 +681,6 @@ tools/ui/ ## Related Documentation -- [llama.cpp Server README](../README.md) - Full server documentation -- [Multimodal Documentation](../../../docs/multimodal.md) - Image and audio support -- [Function Calling](../../../docs/function-calling.md) - Tool use capabilities +- [llama.cpp Server README](../server/README.md) - Full server documentation +- [Multimodal Documentation](../../docs/multimodal.md) - Image and audio support +- [Function Calling](../../docs/function-calling.md) - Tool use capabilities From 2eb3e6b2428c035dd24f996c7dfc48654dc19a6b Mon Sep 17 00:00:00 2001 From: Steve Lhomme Date: Sun, 10 May 2026 16:35:38 +0200 Subject: [PATCH 06/10] ggml: install ggml.pc in /pkgconfig (ggml/1480) That's always how it's done: https://github.com/search?q=path%3ACMakeLists.txt%20%22%24%7BCMAKE_INSTALL_LIBDIR%7D%2Fpkgconfig%22&type=code --- ggml/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index de765ef3e24..bdeca34bf9f 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -353,7 +353,7 @@ if (GGML_STANDALONE) @ONLY) install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ggml.pc - DESTINATION share/pkgconfig) + DESTINATION ${CMAKE_INSTALL_LIBDIR}/pkgconfig) endif() # From 560445bf34c87356ad0f8d80fb03ec5488850b65 Mon Sep 17 00:00:00 2001 From: CrispStrobe <154636388+CrispStrobe@users.noreply.github.com> Date: Sun, 10 May 2026 16:45:00 +0200 Subject: [PATCH 07/10] metal : tighten input-position loop in kernel_conv_transpose_1d (ggml/1477) For a given output position j on the time axis, only input positions i such that i*s0 <= j < i*s0 + K contribute -- i.e. i in [ceil((j - K + 1)/s0), floor(j/s0)] intersected with [0, IL-1]. 
That's at most ceil(K/s0) values (typically 2 for stride==K/2 transposed convs). The current kernel iterates the full IL range and filters with an `if`, amplifying per-thread work by IL/ceil(K/s0) (~160x for IL=320, K=10, s0=5 -- a representative codec-decoder shape). On Apple M1 the wasted work trips the macOS GPU watchdog (kIOGPUCommandBufferCallbackErrorImpactingInteractivity) on long graphs. Compute i_min, i_max analytically before the inner loop and iterate only [i_min, i_max]. Output is bit-identical (same multiplies and adds in the same order); loop bound shrinks by IL/ceil(K/s0). Tested on M1 with a downstream consumer running a TTS codec at full T_codec; end-to-end codec decode ~3-4x faster, zero watchdog hits across long synthesis runs vs ~30% pre-patch. --- ggml/src/ggml-metal/ggml-metal.metal | 31 +++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 82e29d5ad7c..f6ffb2b3a1c 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -4881,15 +4881,32 @@ kernel void kernel_conv_transpose_1d( uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]]) { - float v = 0.0f; + // For output position j on the time axis, only input positions + // i such that i*s0 <= j < i*s0 + K + // contribute -- i.e. i in [ceil((j - K + 1)/s0), floor(j/s0)] + // intersected with [0, IL-1]. That's at most ceil(K/s0) values + // (typically 2 for stride==K/2 transposed convs). + const int32_t j = tgpig[0]; + const int32_t s0 = args.s0; + const int32_t K = args.K; + const int32_t IL = args.IL; + + int32_t i_min; + { + int32_t a = j - K + 1; + i_min = a <= 0 ? 
0 : (a + s0 - 1) / s0; // ceil(a/s0) for a>0 + } + int32_t i_max = j / s0; + if (i_max > IL - 1) i_max = IL - 1; - for (int64_t c = 0; c < args.IC; c++) { - const int32_t kernel_offset = c * tgpg[1] * args.K + args.K * tgpig[1]; - const int32_t input_offset = c * args.IL; + float v = 0.0f; + if (i_min <= i_max) { + for (int64_t c = 0; c < args.IC; c++) { + const int32_t kernel_offset = c * tgpg[1] * K + K * tgpig[1]; + const int32_t input_offset = c * IL; - for (int64_t i = 0; i < args.IL; i++) { - if (tgpig[0] >= i * args.s0 && tgpig[0] < i * args.s0 + args.K) { - v += src0[kernel_offset + tgpig[0] - i * args.s0] * src1[input_offset + i]; + for (int32_t i = i_min; i <= i_max; i++) { + v += float(src0[kernel_offset + j - i * s0]) * src1[input_offset + i]; } } } From e6c37a1adc82509f3e8fbbb2d255ccaa433ab383 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 16 May 2026 15:59:09 +0300 Subject: [PATCH 08/10] ggml : bump version to 0.12.0 (ggml/1494) --- ggml/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index bdeca34bf9f..4aac5094d1c 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -4,8 +4,8 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) -set(GGML_VERSION_MINOR 11) -set(GGML_VERSION_PATCH 1) +set(GGML_VERSION_MINOR 12) +set(GGML_VERSION_PATCH 0) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") From 3a92bc99db3a9737b2688ee038f871760fe3ad34 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 16 May 2026 15:59:45 +0300 Subject: [PATCH 09/10] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 15685a0718f..0fa47782fd9 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ 
-628249b398293fc8d2fa81a449ae2920a02c6523 +0ce7ad348a3151e1da9f65d962044546bcaad421 From 0253fb21f595246f54c192fe8332f34173be251b Mon Sep 17 00:00:00 2001 From: Aleksander Grygier Date: Sat, 16 May 2026 15:20:27 +0200 Subject: [PATCH 10/10] ui: Add request timeout for MCP tool calls (#23138) * feat: Add request timeout for MCP tool calls in llama-ui * feat: MCP Settings tab with max timeout setting --- .../SettingsChat/SettingsChatFields.svelte | 2 ++ tools/ui/src/lib/constants/routes.ts | 1 + tools/ui/src/lib/constants/settings-keys.ts | 1 + .../ui/src/lib/constants/settings-registry.ts | 21 ++++++++++++++++++- tools/ui/src/lib/services/mcp.service.ts | 10 ++++++--- tools/ui/src/lib/stores/mcp.svelte.ts | 7 +++++-- tools/ui/src/lib/types/mcp.d.ts | 2 ++ tools/ui/src/lib/types/settings.d.ts | 1 + tools/ui/src/lib/utils/mcp.ts | 6 ++++-- 9 files changed, 43 insertions(+), 8 deletions(-) diff --git a/tools/ui/src/lib/components/app/settings/SettingsChat/SettingsChatFields.svelte b/tools/ui/src/lib/components/app/settings/SettingsChat/SettingsChatFields.svelte index 3ecf00adce8..069855eebef 100644 --- a/tools/ui/src/lib/components/app/settings/SettingsChat/SettingsChatFields.svelte +++ b/tools/ui/src/lib/components/app/settings/SettingsChat/SettingsChatFields.svelte @@ -79,6 +79,8 @@
{ // Update local config immediately for real-time badge feedback diff --git a/tools/ui/src/lib/constants/routes.ts b/tools/ui/src/lib/constants/routes.ts index 14416478f08..3b3fceea448 100644 --- a/tools/ui/src/lib/constants/routes.ts +++ b/tools/ui/src/lib/constants/routes.ts @@ -8,6 +8,7 @@ export const SETTINGS_SECTION_SLUGS = { PENALTIES: 'penalties', AGENTIC: 'agentic', DEVELOPER: 'developer', + MCP: 'mcp', TOOLS: 'tools', IMPORT_EXPORT: 'import-export' } as const; diff --git a/tools/ui/src/lib/constants/settings-keys.ts b/tools/ui/src/lib/constants/settings-keys.ts index b673bff278d..92a57f88acf 100644 --- a/tools/ui/src/lib/constants/settings-keys.ts +++ b/tools/ui/src/lib/constants/settings-keys.ts @@ -53,6 +53,7 @@ export const SETTINGS_KEYS = { DRY_PENALTY_LAST_N: 'dry_penalty_last_n', // MCP MCP_SERVERS: 'mcpServers', + MCP_REQUEST_TIMEOUT_SECONDS: 'mcpRequestTimeoutSeconds', AGENTIC_MAX_TURNS: 'agenticMaxTurns', ALWAYS_SHOW_AGENTIC_TURNS: 'alwaysShowAgenticTurns', AGENTIC_MAX_TOOL_PREVIEW_LINES: 'agenticMaxToolPreviewLines', diff --git a/tools/ui/src/lib/constants/settings-registry.ts b/tools/ui/src/lib/constants/settings-registry.ts index c4fc3fb301e..bdbb17d962c 100644 --- a/tools/ui/src/lib/constants/settings-registry.ts +++ b/tools/ui/src/lib/constants/settings-registry.ts @@ -23,7 +23,8 @@ import type { SettingsSectionEntry, SettingsSection } from '$lib/types'; -import { CLI_FLAGS } from '$lib/constants'; +import { CLI_FLAGS, DEFAULT_MCP_CONFIG } from '$lib/constants'; +import McpLogo from '$lib/components/app/mcp/McpLogo.svelte'; import { SETTINGS_KEYS } from './settings-keys'; import { ROUTES, SETTINGS_SECTION_SLUGS } from './routes'; import { TITLE_GENERATION } from './title-generation'; @@ -35,6 +36,7 @@ export const SETTINGS_SECTION_TITLES = { PENALTIES: 'Penalties', AGENTIC: 'Agentic', TOOLS: 'Tools', + MCP: 'MCP', IMPORT_EXPORT: 'Import/Export', DEVELOPER: 'Developer' } as const; @@ -657,6 +659,22 @@ const SETTINGS_REGISTRY: Record = { 
section: SETTINGS_SECTION_SLUGS.DEVELOPER } ] + }, + [SETTINGS_SECTION_SLUGS.MCP]: { + title: SETTINGS_SECTION_TITLES.MCP, + slug: SETTINGS_SECTION_SLUGS.MCP, + icon: McpLogo, + settings: [ + { + key: SETTINGS_KEYS.MCP_REQUEST_TIMEOUT_SECONDS, + label: 'Request timeout (seconds)', + help: 'Default timeout for individual MCP tool calls. Can be overridden per server.', + defaultValue: DEFAULT_MCP_CONFIG.requestTimeoutSeconds, + type: SettingsFieldType.INPUT, + section: SETTINGS_SECTION_SLUGS.MCP, + isPositiveInteger: true + } + ] } } as const; @@ -727,6 +745,7 @@ export const SETTINGS_CHAT_SECTIONS: SettingsSection[] = [ label: s.label, type: s.type, isExperimental: s.isExperimental, + isPositiveInteger: s.isPositiveInteger, help: s.help, options: s.options })) diff --git a/tools/ui/src/lib/services/mcp.service.ts b/tools/ui/src/lib/services/mcp.service.ts index 458013b5acb..44cbd4a8aaf 100644 --- a/tools/ui/src/lib/services/mcp.service.ts +++ b/tools/ui/src/lib/services/mcp.service.ts @@ -665,7 +665,9 @@ export class MCPService { tools: [], serverName, transportType, - connectionTimeMs: 0 + connectionTimeMs: 0, + requestTimeoutMs: + serverConfig.requestTimeoutMs ?? DEFAULT_MCP_CONFIG.requestTimeoutSeconds * 1000 }); const connectionTimeMs = Math.round(performance.now() - startTime); @@ -694,7 +696,9 @@ export class MCPService { clientCapabilities: effectiveCapabilities, protocolVersion: DEFAULT_MCP_CONFIG.protocolVersion, instructions, - connectionTimeMs + connectionTimeMs, + requestTimeoutMs: + serverConfig.requestTimeoutMs ?? 
DEFAULT_MCP_CONFIG.requestTimeoutSeconds * 1000 }; } @@ -813,7 +817,7 @@ export class MCPService { const result = await connection.client.callTool( { name: params.name, arguments: params.arguments }, undefined, - { signal } + { signal, timeout: connection.requestTimeoutMs } ); return { diff --git a/tools/ui/src/lib/stores/mcp.svelte.ts b/tools/ui/src/lib/stores/mcp.svelte.ts index 2a1eb3ff51d..8fb306da881 100644 --- a/tools/ui/src/lib/stores/mcp.svelte.ts +++ b/tools/ui/src/lib/stores/mcp.svelte.ts @@ -168,7 +168,9 @@ class MCPStore { enabled: Boolean((entry as { enabled?: unknown })?.enabled), url, name: (entry as { name?: string })?.name, - requestTimeoutSeconds: DEFAULT_MCP_CONFIG.requestTimeoutSeconds, + requestTimeoutSeconds: + (entry as { requestTimeoutSeconds?: number })?.requestTimeoutSeconds ?? + DEFAULT_MCP_CONFIG.requestTimeoutSeconds, headers: headers || undefined, useProxy: Boolean((entry as { useProxy?: unknown })?.useProxy) } satisfies MCPServerSettingsEntry; @@ -554,7 +556,8 @@ class MCPStore { url: serverData.url.trim(), name: serverData.name, headers: serverData.headers?.trim() || undefined, - requestTimeoutSeconds: DEFAULT_MCP_CONFIG.requestTimeoutSeconds, + requestTimeoutSeconds: + Number(config().mcpRequestTimeoutSeconds) || DEFAULT_MCP_CONFIG.requestTimeoutSeconds, useProxy: serverData.useProxy }; settingsStore.updateConfig(SETTINGS_KEYS.MCP_SERVERS, JSON.stringify([...servers, newServer])); diff --git a/tools/ui/src/lib/types/mcp.d.ts b/tools/ui/src/lib/types/mcp.d.ts index 3837bcdf1b8..7aa050cdfa7 100644 --- a/tools/ui/src/lib/types/mcp.d.ts +++ b/tools/ui/src/lib/types/mcp.d.ts @@ -135,6 +135,8 @@ export interface MCPConnection { protocolVersion?: string; instructions?: string; connectionTimeMs: number; + /** Configured timeout for individual requests (tool calls, etc.) 
in milliseconds */ + requestTimeoutMs: number; } /** diff --git a/tools/ui/src/lib/types/settings.d.ts b/tools/ui/src/lib/types/settings.d.ts index 1ab7a7e5d54..65096db3449 100644 --- a/tools/ui/src/lib/types/settings.d.ts +++ b/tools/ui/src/lib/types/settings.d.ts @@ -42,6 +42,7 @@ export interface SettingsFieldConfig { label: string; type: SettingsFieldType; isExperimental?: boolean; + isPositiveInteger?: boolean; help?: string; options?: Array<{ value: string; label: string; icon?: typeof Icon }>; } diff --git a/tools/ui/src/lib/utils/mcp.ts b/tools/ui/src/lib/utils/mcp.ts index ee27798455d..05fe90048f0 100644 --- a/tools/ui/src/lib/utils/mcp.ts +++ b/tools/ui/src/lib/utils/mcp.ts @@ -49,7 +49,7 @@ export function detectMcpTransportFromUrl(url: string): MCPTransportType { /** * Parses MCP server settings from a JSON string or array. - * requestTimeoutSeconds is not user-configurable in the UI, so we always use the default value. + * Preserves per-server requestTimeoutSeconds if stored, otherwise falls back to the global default. * @param rawServers - The raw servers to parse * @returns An empty array if the input is invalid. */ @@ -88,7 +88,9 @@ export function parseMcpServerSettings(rawServers: unknown): MCPServerSettingsEn enabled: Boolean((entry as { enabled?: unknown })?.enabled), url, name: (entry as { name?: string })?.name, - requestTimeoutSeconds: DEFAULT_MCP_CONFIG.requestTimeoutSeconds, + requestTimeoutSeconds: + (entry as { requestTimeoutSeconds?: number })?.requestTimeoutSeconds ?? + DEFAULT_MCP_CONFIG.requestTimeoutSeconds, headers: headers || undefined, useProxy: Boolean((entry as { useProxy?: unknown })?.useProxy) } satisfies MCPServerSettingsEntry;