From a6d6183dbc5e30ed7cc92c7d29fc5daccab86a9e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Ekstr=C3=B6m?= Date: Sun, 17 May 2026 14:12:11 +0300 Subject: [PATCH 1/8] ggml-vulkan/CMakeLists: add a check for SPIRV-Headers (#22009) * ci/run: set explicit SPIR-V Headers search path for macOS vulkan CI For whatever reason, the files are under additional sub-path `vulkan/` under the cmake directory, which does not match either current LunarG macOS Vulkan SDK structure (`lib/cmake/SPIRV-Headers`), nor what gets installed when you run the cmake build+install for SPIRV-Headers itself on at least Linux (`share/cmake/SPIRV-Headers`). This allows for SPIRV-Headers to be found, as currently the CI runner's setup does not seem to include the relevant path in list of search locations. * ggml-vulkan/CMakeLists: add a check for SPIRV-Headers This is installed by the project if it is built and installed. Receiving an error during the configuration step is generally preferred to receiving an error in the middle of a build. --- ci/run.sh | 6 ++++++ ggml/src/ggml-vulkan/CMakeLists.txt | 2 ++ 2 files changed, 8 insertions(+) diff --git a/ci/run.sh b/ci/run.sh index 529da07779fd..a8cbd3371d37 100755 --- a/ci/run.sh +++ b/ci/run.sh @@ -117,6 +117,12 @@ if [ ! -z ${GG_BUILD_VULKAN} ]; then # if on Mac, disable METAL if [[ "$OSTYPE" == "darwin"* ]]; then CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_METAL=OFF -DGGML_BLAS=OFF" + + MACOS_RUNNER_CUSTOM_VULKAN_CMAKE_LOCATION="/usr/local/lib/cmake/vulkan" + MACOS_RUNNER_CUSTOM_SPIRV_HEADERS_LOCATION="${MACOS_RUNNER_CUSTOM_VULKAN_CMAKE_LOCATION}/SPIRV-Headers/SPIRV-HeadersConfig.cmake" + if [[ -f "${MACOS_RUNNER_CUSTOM_SPIRV_HEADERS_LOCATION}" || -h "${MACOS_RUNNER_CUSTOM_SPIRV_HEADERS_LOCATION}" ]]; then + CMAKE_EXTRA="${CMAKE_EXTRA} -DSPIRV-Headers_DIR=${MACOS_RUNNER_CUSTOM_VULKAN_CMAKE_LOCATION}/SPIRV-Headers" + fi fi # Build shared libs on Windows diff --git a/ggml/src/ggml-vulkan/CMakeLists.txt b/ggml/src/ggml-vulkan/CMakeLists.txt index 715a263a6d09..6dbcea065b35 100644 --- a/ggml/src/ggml-vulkan/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/CMakeLists.txt @@ -8,6 +8,8 @@ endif() find_package(Vulkan COMPONENTS glslc REQUIRED) +find_package(SPIRV-Headers REQUIRED) + if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") # Parallel build object files add_definitions(/MP) From 39cf5d61915769124b7efbbfa69c46f19a6363ee Mon Sep 17 00:00:00 2001 From: Aldehir Rojas Date: Sun, 17 May 2026 07:36:05 -0400 Subject: [PATCH 2/8] common : delegate assistant continuation to underlying template handlers (#23089) * common : delegate assistant continuation to template handler * server : implement echo parameter to exclude assistant prefill in the response * server : fix tests for prefill * server : use existing llama template * cont : clean up --- common/chat-auto-parser-generator.cpp | 34 +- common/chat-auto-parser.h | 15 +- common/chat-peg-parser.cpp | 2 +- common/chat.cpp | 255 +++++- common/chat.h | 23 +- tests/test-chat.cpp | 833 +++++++++++++++++- tools/server/server-common.cpp | 107 +-- tools/server/server-task.cpp | 12 + tools/server/server-task.h | 6 +- .../server/tests/unit/test_chat_completion.py | 12 +- 10 files changed, 1110 insertions(+), 189 deletions(-) diff --git a/common/chat-auto-parser-generator.cpp b/common/chat-auto-parser-generator.cpp index 6021fc4ede51..db3a6cc6fe38 100644 --- a/common/chat-auto-parser-generator.cpp +++ b/common/chat-auto-parser-generator.cpp @@ -43,11 +43,33 @@ common_chat_params peg_generator::generate_parser(const common_chat_template & const autoparser & autoparser) { // Create the result structure common_chat_params data; - data.prompt = common_chat_template_direct_apply(tmpl, inputs); - data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; - data.preserved_tokens = autoparser.preserved_tokens; + data.prompt = common_chat_template_direct_apply(tmpl, inputs); + data.generation_prompt = common_chat_template_generation_prompt(tmpl, inputs); + data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; + data.preserved_tokens = autoparser.preserved_tokens; + + std::string parser_generation_prompt = data.generation_prompt; + + if (inputs.continue_final_message != COMMON_CHAT_CONTINUATION_NONE && !inputs.continue_msg.empty()) { + // Build up generation prompt manually + const auto & msg = inputs.continue_msg; + + if (!autoparser.reasoning.start.empty()) { + data.generation_prompt = data.generation_prompt.substr(0, data.generation_prompt.find(autoparser.reasoning.start)); + data.generation_prompt += autoparser.reasoning.start + msg.reasoning_content; + if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) { + data.generation_prompt += autoparser.reasoning.end; + } + } + + if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) { + data.generation_prompt += msg.render_content(); + } + + data.prompt += data.generation_prompt; + } - auto parser = autoparser.build_parser(inputs); + auto parser = autoparser.build_parser(inputs, parser_generation_prompt); data.parser = parser.save(); // Build grammar if tools are present @@ -87,7 +109,7 @@ common_chat_params peg_generator::generate_parser(const common_chat_template & return data; } -common_peg_arena autoparser::build_parser(const generation_params & inputs) const { +common_peg_arena autoparser::build_parser(const generation_params & inputs, const std::string & generation_prompt) const { if (!analysis_complete) { throw std::invalid_argument("Cannot call build_parser on autoparser without performing analysis first, call analyze_template(...)"); } @@ -121,7 +143,7 @@ common_peg_arena autoparser::build_parser(const generation_params & inputs) cons } else { parser = content.build_parser(ctx); } - return pure_content ? p.prefix(inputs.generation_prompt, reasoning.start) + parser : p.prefix(inputs.generation_prompt, reasoning.start) << parser; + return pure_content ? p.prefix(generation_prompt, reasoning.start) + parser : p.prefix(generation_prompt, reasoning.start) << parser; }); } diff --git a/common/chat-auto-parser.h b/common/chat-auto-parser.h index 6c547409760d..c680e6868676 100644 --- a/common/chat-auto-parser.h +++ b/common/chat-auto-parser.h @@ -60,16 +60,21 @@ struct generation_params { common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_AUTO; bool stream = true; std::string grammar; - bool add_generation_prompt = false; - bool enable_thinking = true; - std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); - std::string generation_prompt; + bool add_generation_prompt = false; + common_chat_continuation continue_final_message = COMMON_CHAT_CONTINUATION_NONE; + common_chat_msg continue_msg; + bool enable_thinking = true; + std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); json extra_context; bool add_bos = false; bool add_eos = false; bool is_inference = true; bool add_inference = false; bool mark_input = true; // whether to mark input strings in the jinja context + + bool has_continuation() const { + return continue_final_message != COMMON_CHAT_CONTINUATION_NONE && !continue_msg.empty(); + } }; // ============================================================================ @@ -386,7 +391,7 @@ struct autoparser { void analyze_template(const common_chat_template & tmpl); // Build the PEG parser for this template - common_peg_arena build_parser(const generation_params & inputs) const; + common_peg_arena build_parser(const generation_params & inputs, const std::string & generation_prompt) const; private: // Collect tokens from entire analysis to preserve diff --git a/common/chat-peg-parser.cpp b/common/chat-peg-parser.cpp index 79274febe7e3..12e747d1ca14 100644 --- a/common/chat-peg-parser.cpp +++ b/common/chat-peg-parser.cpp @@ -785,7 +785,7 @@ common_peg_parser common_chat_peg_builder::prefix(const std::string & s, const s if (delimiter.empty()) { return literal(s); } - return literal(s.substr(0, s.rfind(delimiter))); + return literal(s.substr(0, s.find(delimiter))); } common_peg_parser common_chat_peg_builder::optspace(const std::string & tag) { diff --git a/common/chat.cpp b/common/chat.cpp index 70b9f5dc2c58..56873e3a1e93 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -70,6 +70,26 @@ static bool has_content_or_tool_calls(const common_chat_msg & msg) { return !msg.content.empty() || !msg.tool_calls.empty(); } +std::string common_chat_msg::render_content(const std::string & delimiter) const { + if (!content.empty() && !content_parts.empty()) { + throw std::runtime_error("Cannot specify both content and content_parts"); + } + if (!content.empty()) { + return content; + } + + std::string text; + for (const auto & part : content_parts) { + if (part.type == "text") { + if (!text.empty()) { + text += delimiter; + } + text += part.text; + } + } + return text; +} + json common_chat_msg::to_json_oaicompat(bool concat_typed_text) const { if (!content.empty() && !content_parts.empty()) { throw std::runtime_error("Cannot specify both content and content_parts"); @@ -451,6 +471,22 @@ std::vector common_chat_tools_parse_oaicompat(const json & too return result; } +common_chat_continuation common_chat_continuation_parse(const nlohmann::ordered_json & value) { + if (value.is_boolean() && value.get()) { + return COMMON_CHAT_CONTINUATION_AUTO; + } + if (value.is_string()) { + auto value_str = value.get(); + if (value_str == "reasoning_content") { + return COMMON_CHAT_CONTINUATION_REASONING; + } + if (value_str == "content") { + return COMMON_CHAT_CONTINUATION_CONTENT; + } + } + return COMMON_CHAT_CONTINUATION_NONE; +} + bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) { if (use_jinja) { try { @@ -811,6 +847,36 @@ std::string common_chat_template_direct_apply( return common_chat_template_direct_apply_impl(tmpl, inputs, std::nullopt, std::nullopt, std::nullopt); } +static std::string common_chat_template_generation_prompt_impl( + const common_chat_template & tmpl, + const autoparser::generation_params & inputs, + const std::optional & messages_override = std::nullopt, + const std::optional & tools_override = std::nullopt, + const std::optional & additional_context = std::nullopt) { + + auto adjusted_messages = messages_override ? *messages_override : inputs.messages; + + autoparser::generation_params params = inputs; + params.add_generation_prompt = false; + params.continue_final_message = COMMON_CHAT_CONTINUATION_NONE; + std::string no_gen_prompt = common_chat_template_direct_apply_impl(tmpl, params, adjusted_messages, tools_override, additional_context); + params.add_generation_prompt = true; + std::string gen_prompt = common_chat_template_direct_apply_impl(tmpl, params, adjusted_messages, tools_override, additional_context); + + size_t prefix_len = 0; + size_t min_size = std::min(no_gen_prompt.size(), gen_prompt.size()); + while (prefix_len < min_size && no_gen_prompt[prefix_len] == gen_prompt[prefix_len]) { + prefix_len++; + } + return gen_prompt.substr(prefix_len); +} + +std::string common_chat_template_generation_prompt( + const common_chat_template & tmpl, + const autoparser::generation_params & inputs) { + return common_chat_template_generation_prompt_impl(tmpl, inputs, std::nullopt, std::nullopt, std::nullopt); +} + static common_chat_params common_chat_params_init_ministral_3(const common_chat_template & tmpl, const autoparser::generation_params & inputs) { common_chat_params data; @@ -863,6 +929,7 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_ data.thinking_start_tag = "[THINK]"; data.thinking_end_tag = "[/THINK]"; data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs, /* messages_override = */ adjusted_messages); + data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs, /* messages_override = */ adjusted_messages); data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; data.preserved_tokens = { "[THINK]", @@ -871,8 +938,19 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_ "[ARGS]", }; + if (inputs.has_continuation()) { + const auto & msg = inputs.continue_msg; + + data.generation_prompt = "[THINK]" + msg.reasoning_content; + if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) { + data.generation_prompt += "[/THINK]" + msg.render_content(); + } + + data.prompt += data.generation_prompt; + } + auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { - auto generation_prompt = p.prefix(inputs.generation_prompt, "[THINK]"); + auto generation_prompt = p.eps(); auto reasoning = extract_reasoning ? p.optional("[THINK]" + p.reasoning(p.until("[/THINK]")) + "[/THINK]") : p.eps(); @@ -963,6 +1041,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp } data.prompt = prompt; + data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs, /* messages_override= */ adjusted_messages); data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; data.supports_thinking = true; @@ -972,6 +1051,18 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp "<|channel|>", "<|constrain|>", "<|message|>", "<|start|>", "<|end|>", }; + // Adjust prompt for continuation + if (inputs.has_continuation()) { + const auto & msg = inputs.continue_msg; + + data.generation_prompt = "<|start|>assistant<|channel|>analysis<|message|>" + msg.reasoning_content; + if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) { + data.generation_prompt += "<|end|><|start|>assistant<|channel|>final<|message|>" + msg.render_content(); + } + + data.prompt += data.generation_prompt; + } + auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); auto has_response_format = !inputs.json_schema.is_null() && inputs.json_schema.is_object(); auto include_grammar = has_response_format || (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE); @@ -1080,12 +1171,14 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ common_chat_params data; data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs); + data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs); if (inputs.add_generation_prompt && string_ends_with(data.prompt, "\n")) { // This may happen if the model generates content + tool_call, the // template does not add the model's next turn and confuses the model // from emitting its proper reasoning token sequence. - data.prompt += "<|turn>model\n"; + data.generation_prompt = "<|turn>model\n"; + data.prompt += data.generation_prompt; } data.format = COMMON_CHAT_FORMAT_PEG_GEMMA4; @@ -1101,13 +1194,25 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ "<|turn>", }; + if (inputs.has_continuation()) { + const auto & msg = inputs.continue_msg; + + data.generation_prompt = string_ends_with(data.prompt, "\n") ? "<|turn>model\n" : ""; + data.generation_prompt += "<|channel>thought\n" + msg.reasoning_content; + if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) { + data.generation_prompt += "" + msg.render_content(); + } + + data.prompt += data.generation_prompt; + } + auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); auto has_response_format = !inputs.json_schema.is_null() && inputs.json_schema.is_object(); auto include_grammar = has_response_format || (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE); auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE; auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { - auto start = p.rule("start", p.prefix(inputs.generation_prompt, "<|channel>")); + auto start = p.rule("start", p.optional(p.literal("<|turn>model\n"))); if (extract_reasoning) { p.rule("thought", p.literal("<|channel>thought") + p.space() + p.reasoning(p.until("")) + p.literal("")); @@ -1224,15 +1329,22 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_ const autoparser::generation_params & inputs) { common_chat_params data; - data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs); - data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; - data.preserved_tokens = { + data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs); + data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs); + data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; + data.preserved_tokens = { ">>>all", }; auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE; + if (inputs.has_continuation()) { + const auto & msg = inputs.continue_msg; + data.generation_prompt = "<|start_header_id|>assistant<|end_header_id|>\n\n>>>all\n" + msg.render_content(); + data.prompt += data.generation_prompt; + } + auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { // Functionary v3.2 format: // - Normal content: >>>all\n{content} @@ -1244,7 +1356,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_ // When no tools, content goes until end auto content_until_tool = p.literal("all\n") + p.content(p.until(">>>")); auto content_until_end = p.literal("all\n") + p.content(p.rest()); - auto generation_prompt = p.literal(inputs.generation_prompt); + auto generation_prompt = p.literal("<|start_header_id|>assistant<|end_header_id|>\n\n>>>"); // If no tools or tool_choice is NONE, just parse content if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) { @@ -1318,9 +1430,10 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp const autoparser::generation_params & inputs) { common_chat_params data; - data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs); - data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; - data.supports_thinking = true; + data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs); + data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs); + data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; + data.supports_thinking = true; data.preserved_tokens = { "<|tool_calls_section_begin|>", "<|tool_calls_section_end|>", @@ -1343,10 +1456,22 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp const std::string THINK_START = ""; const std::string THINK_END = ""; + const std::string GEN_PROMPT = "<|im_assistant|>assistant<|im_middle|>"; data.thinking_start_tag = THINK_START; data.thinking_end_tag = THINK_END; + if (inputs.has_continuation()) { + const auto & msg = inputs.continue_msg; + + data.generation_prompt = GEN_PROMPT + THINK_START + msg.reasoning_content; + if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) { + data.generation_prompt += THINK_END + msg.render_content(); + } + + data.prompt += data.generation_prompt; + } + auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { // Kimi K2 Thinking format: // - Reasoning: {reasoning} @@ -1366,7 +1491,7 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp auto reasoning = extract_reasoning ? p.optional(THINK_START + p.reasoning( p.until_one_of({ THINK_END, "<|tool_calls_section_begin|>", "<|tool_call_begin|>" })) + p.optional(p.literal(THINK_END))) : p.eps(); - auto generation_prompt = p.prefix(inputs.generation_prompt, THINK_START); + auto generation_prompt = p.literal(GEN_PROMPT); // Content only parser (no tools) @@ -1442,6 +1567,7 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat common_chat_params data; data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs); + data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs); data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; data.supports_thinking = true; data.preserved_tokens = { @@ -1461,12 +1587,24 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat const std::string TOOL_CALL_END = "<|tool_call_end|>"; const std::string THINK_START = ""; const std::string THINK_END = ""; + const std::string GEN_PROMPT = "<|im_start|>assistant\n"; data.thinking_start_tag = THINK_START; data.thinking_end_tag = THINK_END; + if (inputs.has_continuation()) { + const auto & msg = inputs.continue_msg; + + data.generation_prompt = GEN_PROMPT + THINK_START + msg.reasoning_content; + if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) { + data.generation_prompt += THINK_END + msg.render_content(); + } + + data.prompt += data.generation_prompt; + } + auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { - auto generation_prompt = p.prefix(inputs.generation_prompt, THINK_START); + auto generation_prompt = p.literal(GEN_PROMPT); auto end = p.end(); auto reasoning = p.eps(); @@ -1521,6 +1659,7 @@ static common_chat_params common_chat_params_init_lfm2_5(const common_chat_templ common_chat_params data; data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs); + data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs); data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; data.supports_thinking = true; data.preserved_tokens = { @@ -1536,12 +1675,24 @@ static common_chat_params common_chat_params_init_lfm2_5(const common_chat_templ const std::string THINK_START = ""; const std::string THINK_END = ""; + const std::string GEN_PROMPT = "<|im_start|>assistant\n"; data.thinking_start_tag = THINK_START; data.thinking_end_tag = THINK_END; + if (inputs.has_continuation()) { + const auto & msg = inputs.continue_msg; + + data.generation_prompt = GEN_PROMPT + THINK_START + msg.reasoning_content; + if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) { + data.generation_prompt += THINK_END + msg.render_content(); + } + + data.prompt += data.generation_prompt; + } + auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { - auto generation_prompt = p.prefix(inputs.generation_prompt, THINK_START); + auto generation_prompt = p.literal(GEN_PROMPT); auto end = p.end(); auto reasoning = p.eps(); @@ -1592,6 +1743,7 @@ static common_chat_params common_chat_params_init_gigachat_v3( common_chat_params data; data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs); + data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs); data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; data.supports_thinking = false; data.preserved_tokens = { @@ -1599,6 +1751,12 @@ static common_chat_params common_chat_params_init_gigachat_v3( "<|role_sep|>\n", }; + if (inputs.has_continuation()) { + const auto & msg = inputs.continue_msg; + data.generation_prompt = "assistant<|role_sep|>\n" + msg.render_content(); + data.prompt += data.generation_prompt; + } + auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE; const auto *tool_call_start_prefix = "<|message_sep|>\n\nfunction call<|role_sep|>\n"; @@ -1634,7 +1792,7 @@ static common_chat_params common_chat_params_init_gigachat_v3( ret = p.content(p.rest()); } - return p.literal(inputs.generation_prompt) + ret; + return p.literal("assistant<|role_sep|>\n") + ret; }); data.parser = parser.save(); @@ -1662,12 +1820,13 @@ static common_chat_params common_chat_params_init_deepseek_v3_2(const common_cha const autoparser::generation_params & inputs) { common_chat_params data; - data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs); - data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; - data.supports_thinking = true; + data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs); + data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs); + data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; + data.supports_thinking = true; data.thinking_start_tag = ""; data.thinking_end_tag = ""; - data.preserved_tokens = { + data.preserved_tokens = { "|DSML|", "", "", @@ -1687,9 +1846,21 @@ static common_chat_params common_chat_params_init_deepseek_v3_2(const common_cha const std::string INVOKE_END = ""; const std::string PARAM_START = "<" + DSML + "parameter"; const std::string PARAM_END = ""; + const std::string GEN_PROMPT = "<|Assistant|>"; + + if (inputs.has_continuation()) { + const auto & msg = inputs.continue_msg; + + data.generation_prompt = GEN_PROMPT + THINK_START + msg.reasoning_content; + if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) { + data.generation_prompt += THINK_END + msg.render_content(); + } + + data.prompt += data.generation_prompt; + } auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { - auto generation_prompt = p.prefix(inputs.generation_prompt, THINK_START); + auto generation_prompt = p.literal(GEN_PROMPT); auto end = p.end(); auto reasoning = p.eps(); @@ -2116,21 +2287,6 @@ std::optional common_chat_try_specialized_template( return std::nullopt; } -static std::string common_chat_templates_generation_prompt(const common_chat_template & tmpl, const autoparser::generation_params & inputs) { - autoparser::generation_params params = inputs; - params.add_generation_prompt = false; - std::string no_gen_prompt = common_chat_template_direct_apply_impl(tmpl, params); - params.add_generation_prompt = true; - std::string gen_prompt = common_chat_template_direct_apply_impl(tmpl, params); - - size_t prefix_len = 0; - size_t min_size = std::min(no_gen_prompt.size(), gen_prompt.size()); - while (prefix_len < min_size && no_gen_prompt[prefix_len] == gen_prompt[prefix_len]) { - prefix_len++; - } - return gen_prompt.substr(prefix_len); -} - static common_chat_params common_chat_templates_apply_jinja(const struct common_chat_templates * tmpls, const struct common_chat_templates_inputs & inputs) { autoparser::generation_params params; @@ -2149,6 +2305,27 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_ params.add_bos = tmpls->add_bos; params.add_eos = tmpls->add_eos; + params.continue_final_message = inputs.continue_final_message; + if (params.continue_final_message != COMMON_CHAT_CONTINUATION_NONE) { + params.add_generation_prompt = false; + + if (!inputs.messages.empty()) { + // Render messages[:-1] and store continuation message separately + params.continue_msg = inputs.messages.back(); + params.messages.erase(params.messages.size() - 1); + } + + if (params.continue_final_message == COMMON_CHAT_CONTINUATION_AUTO && !inputs.messages.empty()) { + // Resolve based on message content + params.continue_final_message = COMMON_CHAT_CONTINUATION_CONTENT; + if (!params.continue_msg.reasoning_content.empty() && + params.continue_msg.content.empty() && + params.continue_msg.content_parts.empty()) { + params.continue_final_message = COMMON_CHAT_CONTINUATION_REASONING; + } + } + } + if (src.find("<|channel|>") == std::string::npos) { // map developer to system for all models except for GPT-OSS workaround::map_developer_role_to_system(params.messages); @@ -2169,8 +2346,6 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_ workaround::func_args_not_string(params.messages); } - params.generation_prompt = common_chat_templates_generation_prompt(tmpl, params); - params.extra_context = common_chat_extra_context(); for (auto el : inputs.chat_template_kwargs) { params.extra_context[el.first] = json::parse(el.second); @@ -2200,17 +2375,16 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_ auto params_copy = params; params_copy.reasoning_format = COMMON_REASONING_FORMAT_NONE; data.prompt = common_chat_template_direct_apply_impl(tmpl, params_copy); + data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, params); data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; - data.generation_prompt = params.generation_prompt; - auto parser = build_chat_peg_parser([¶ms](common_chat_peg_builder &p) { - return p.prefix(params.generation_prompt) << p.content(p.rest()); + auto parser = build_chat_peg_parser([&data](common_chat_peg_builder &p) { + return p.literal(data.generation_prompt) << p.content(p.rest()); }); data.parser = parser.save(); return data; } if (auto result = common_chat_try_specialized_template(tmpl, src, params)) { - result->generation_prompt = params.generation_prompt; return *result; } @@ -2224,7 +2398,6 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_ auto_params.thinking_start_tag = trim_whitespace(autoparser.reasoning.start); auto_params.thinking_end_tag = trim_whitespace(autoparser.reasoning.end); } - auto_params.generation_prompt = params.generation_prompt; common_peg_arena arena; arena.load(auto_params.parser); LOG_DBG("%s: generated parser:\n%s\n\nparser generation prompt: %s\n", __func__, arena.dump(arena.root()).c_str(), auto_params.generation_prompt.c_str()); diff --git a/common/chat.h b/common/chat.h index 054f5ffe777f..8ace3e6ba69b 100644 --- a/common/chat.h +++ b/common/chat.h @@ -89,6 +89,8 @@ struct common_chat_msg { nlohmann::ordered_json to_json_oaicompat(bool concat_typed_text = false) const; + std::string render_content(const std::string & delimiter = "\n\n") const; + bool empty() const { return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty(); @@ -164,12 +166,22 @@ enum common_chat_format { COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats }; + +// Continuation method provided via `continue_final_message` +enum common_chat_continuation { + COMMON_CHAT_CONTINUATION_NONE, + COMMON_CHAT_CONTINUATION_AUTO, + COMMON_CHAT_CONTINUATION_REASONING, + COMMON_CHAT_CONTINUATION_CONTENT, +}; + struct common_chat_templates_inputs { std::vector messages; std::string grammar; std::string json_schema; - bool add_generation_prompt = true; - bool use_jinja = true; + bool add_generation_prompt = true; + common_chat_continuation continue_final_message = COMMON_CHAT_CONTINUATION_NONE; + bool use_jinja = true; // Parameters below only supported when use_jinja is true std::vector tools; common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO; @@ -207,6 +219,7 @@ struct common_chat_parser_params { bool reasoning_in_content = false; std::string generation_prompt; bool parse_tool_calls = true; + bool echo = false; // Include assistant prefilled msg in output bool debug = false; // Enable debug output for PEG parser common_peg_arena parser = {}; common_chat_parser_params() = default; @@ -267,6 +280,8 @@ std::vector common_chat_msgs_parse_oaicompat(const nlohmann::or std::vector common_chat_tools_parse_oaicompat(const nlohmann::ordered_json & tools); +common_chat_continuation common_chat_continuation_parse(const nlohmann::ordered_json & value); + // DEPRECATED: only used in tests nlohmann::ordered_json common_chat_msgs_to_json_oaicompat(const std::vector & msgs, bool concat_typed_text = false); @@ -279,6 +294,10 @@ std::string common_chat_template_direct_apply( const common_chat_template & tmpl, const autoparser::generation_params & inputs); +std::string common_chat_template_generation_prompt( + const common_chat_template & tmpl, + const autoparser::generation_params & inputs); + std::optional common_chat_try_specialized_template( const common_chat_template & tmpl, const std::string & src, diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 05c60d297438..a428ef35c18a 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -1,4 +1,4 @@ -// Tests chat handling, including grammar generation and parsing for tool calling, for various templates. +// Tests chat handling, including grammar genration and parsing for tool calling, for various templates. // // Also acts as a CLI to generate a Markdown summary of the formats of Jinja templates, // e.g. given Minja (http://github.com/google/minja) checked out in parent dir: @@ -100,6 +100,25 @@ template static void assert_equals(const T & expected, const T & actua } } +static void assert_contains(const std::string & haystack, const std::string & needle) { + if (haystack.find(needle) == std::string::npos) { + LOG_ERR("Expected to contain: %s\n", needle.c_str()); + LOG_ERR("Actual: %s\n", haystack.c_str()); + common_log_flush(common_log_main()); + throw std::runtime_error("Test failed"); + } +} + +static void assert_ends_with(const std::string & str, const std::string & suffix) { + if (str.size() < suffix.size() || + str.compare(str.size() - suffix.size(), suffix.size(), suffix) != 0) { + LOG_ERR("Expected to end with: %s\n", suffix.c_str()); + LOG_ERR("Actual: %s\n", str.c_str()); + common_log_flush(common_log_main()); + throw std::runtime_error("Test failed"); + } +} + static std::string read_file(const std::string & path) { std::ifstream fs(path, std::ios_base::binary); if (!fs.is_open()) { @@ -945,6 +964,8 @@ const common_chat_msg message_assist_call_python_lines_unclosed = simple_assist_msg("", "", "python", "{\"code\":\"# This is a program:\\nprint('hey')"); const common_chat_msg message_assist_json_content = simple_assist_msg("{\n \"response\": \"Hello, world!\\nWhat's up?\"\n}"); +const common_chat_msg message_assist_prefill_content = simple_assist_msg("Hello, ", "I'm thinking"); +const common_chat_msg message_assist_prefill_reasoning = simple_assist_msg("", "I'm"); // Use for PEG parser implementations struct peg_test_case { @@ -1350,7 +1371,10 @@ class peg_test_builder { peg_test_case tc_; public: - peg_test_builder(peg_tester & tester, const std::string & input) : tester_(tester) { tc_.input = input; } + peg_test_builder(peg_tester & tester, const std::string & input) : tester_(tester) { + tc_.input = input; + tc_.params.add_generation_prompt = true; + } // Parameter setters peg_test_builder & reasoning_format(common_reasoning_format fmt) { @@ -1373,6 +1397,16 @@ class peg_test_builder { return *this; } + peg_test_builder & add_generation_prompt(bool val) { + tc_.params.add_generation_prompt = val; + return *this; + } + + peg_test_builder & continue_final_message(common_chat_continuation cont) { + tc_.params.continue_final_message = cont; + return *this; + } + peg_test_builder & json_schema(const std::string & schema) { tc_.params.json_schema = schema; return *this; @@ -2045,6 +2079,27 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .expect_content("Final answer without tools.") .run(); + // Continuation tests + tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinking\n\n\nHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + { common_chat_msg user_start; user_start.role = "user"; @@ -2204,6 +2259,27 @@ static void test_template_output_peg_parsers(bool detailed_debug) { }) .expect_reconstruction() .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinking[/THINK]Hello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); } { @@ -2350,6 +2426,26 @@ static void test_template_output_peg_parsers(bool detailed_debug) { }) .run(); + // Continuation tests + tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinking\n\nHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); } { @@ -2409,6 +2505,14 @@ static void test_template_output_peg_parsers(bool detailed_debug) { tst.test("Hello, world!").expect(simple_assist_msg("Hello, world!")).expect_reconstruction().run(); tst.test("Line 1\nLine 2\nLine 3").expect(simple_assist_msg("Line 1\nLine 2\nLine 3")).expect_reconstruction().run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_content("Hello, world!\nWhat's up?") + .run(); } { @@ -2580,6 +2684,27 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .expect(message_assist) .run(); + // Continuation tests + tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinkingHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + { // additional tests for https://github.com/ggml-org/llama.cpp/pull/21760 auto tmpls = read_templates("models/templates/google-gemma-4-31B-it.jinja"); @@ -2633,6 +2758,27 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .run(); tst.test("Hello, world!").reasoning_format(COMMON_REASONING_FORMAT_AUTO).expect(simple_assist_msg("Hello, world!")).run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinking\n\nHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); } { // NousResearch-Hermes-2-Pro and Hermes-3 (tool calling models) @@ -2656,6 +2802,14 @@ static void test_template_output_peg_parsers(bool detailed_debug) { // Note: Hermes template doesn't support thinking/reasoning natively // Note: We only support one tool calling format per template, no alternate formats + + // Continuation tests + tst.test("world!\nWhat's up?") + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_content("Hello, world!\nWhat's up?") + .run(); } { // Test simple content-only template @@ -2679,6 +2833,27 @@ static void test_template_output_peg_parsers(bool detailed_debug) { // .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) // .expect(message_assist_thoughts) // .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinkingHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); } { @@ -2695,6 +2870,14 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .tools({ special_function_tool }) .expect(message_assist_call) .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_content("Hello, world!\nWhat's up?") + .run(); } { @@ -2784,6 +2967,26 @@ static void test_template_output_peg_parsers(bool detailed_debug) { }) .run(); + // Continuation tests + tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinking\nHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); } { @@ -3019,6 +3222,14 @@ static void test_template_output_peg_parsers(bool detailed_debug) { { "set_unit", R"({"unit": "celsius"})", {} }, }) .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_content("Hello, world!\nWhat's up?") + .run(); } { auto tst = peg_tester("models/templates/deepseek-ai-DeepSeek-V3.1.jinja", detailed_debug); @@ -3030,6 +3241,27 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) .expect(message_with_tool_calls("get_time", "{\"city\":\"XYZCITY\"}")) .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinkingHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); } { @@ -3284,6 +3516,27 @@ static void test_template_output_peg_parsers(bool detailed_debug) { { "magic_int", R"({"ref": 42, "name": "foo bar"})", {} }, }) .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinkingHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); } // GLM-4.6 tests - format: function_name\n...\n...\n @@ -3380,6 +3633,27 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .expect_reconstruction() .run(); } + + // Continuation tests + tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinkingHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); } // Verify the throw path produces a readable error message, not std::out_of_range. @@ -3630,6 +3904,27 @@ static void test_template_output_peg_parsers(bool detailed_debug) { } }) .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinkingHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); } { @@ -3657,6 +3952,27 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .expect(kimi_id_special_func_tool_call) .expect_reconstruction() .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinkingHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); } // LFM2-8B-A1B tests - uses <|tool_list_start|>/<|tool_list_end|> and <|tool_call_start|>[name(args)]<|tool_call_end|> @@ -3742,6 +4058,27 @@ static void test_template_output_peg_parsers(bool detailed_debug) { { "special_function", R"({"arg1": 1})", {} }, }) .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinkingHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); } // LFM2.5 tests - uses plain "List of tools: [...]" and bare [name(args)] without wrapper tokens @@ -3803,6 +4140,27 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .tools({ empty_args_tool }) .expect(simple_assist_msg("", "", "empty_args", "{}")) .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinkingHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); } // Reka-Edge tests - uses native JSON format with per-call wrapper @@ -3888,6 +4246,27 @@ static void test_template_output_peg_parsers(bool detailed_debug) { { "special_function", R"({"arg1": 1})", {} }, }) .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinking\n\n\nHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); } @@ -3900,6 +4279,14 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .expect(message_assist_call) .expect_reconstruction() .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_content("Hello, world!\nWhat's up?") + .run(); } // MiniMax-M2 tests - XML invoke format with parameter tags @@ -3924,6 +4311,27 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .expect(message_assist_call) .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinking\n\n\nHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); } // NVIDIA-Nemotron-Nano-v2 tests - ... format @@ -3934,6 +4342,14 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .tools({ special_function_tool }) .expect(message_assist_call) .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_content("Hello, world!\nWhat's up?") + .run(); } // CohereForAI-c4ai-command-r7b (uses START_RESPONSE/END_RESPONSE, START_THINKING/END_THINKING, START_ACTION/END_ACTION) @@ -3963,6 +4379,14 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .tools({ special_function_tool }) .expect(message_assist_call) .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_content("Hello, world!\nWhat's up?") + .run(); } // mistralai-Mistral-Nemo-Instruct-2407.jinja @@ -3974,6 +4398,14 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .expect(message_assist_call_id) .expect_reconstruction() .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_content("Hello, world!\nWhat's up?") + .run(); } { auto tst = peg_tester("models/templates/meetkai-functionary-medium-v3.1.jinja", detailed_debug); @@ -3983,6 +4415,14 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .expect(message_assist_call) .expect_reconstruction() .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_content("Hello, world!\nWhat's up?") + .run(); } // Functionary v3.2 - recipient-based format: >>>recipient\n{content} { @@ -3993,6 +4433,14 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .expect(message_assist_call) .expect_reconstruction() .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_content("Hello, world!\nWhat's up?") + .run(); } // FireFunction @@ -4004,6 +4452,14 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .expect(message_assist_call) .expect_reconstruction() .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_content("Hello, world!\nWhat's up?") + .run(); } // DeepSeek R1 Distill Llama 8B - reasoning tests only (forced open thinking) @@ -4019,6 +4475,27 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) .expect(message_assist_thoughts) .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinkingHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); } // llama-cpp DeepSeek R1 template (always forced-open thinking) { @@ -4036,6 +4513,27 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .parallel_tool_calls(true) .expect(message_assist_call) .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinkingHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); } // DeepSeek R1 Distill Qwen 32B - reasoning tests only (forced open thinking) // Note: Template uses forced-open mode (prompt ends with ), so input shouldn't include opening tag @@ -4056,6 +4554,27 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) .expect(message_assist_call) .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinkingHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); } // MiMo-VL / Hermes 3 / Qwen 2.5 (Common JSON format) @@ -4069,6 +4588,14 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .expect(message_assist_call) .expect_reconstruction() .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_content("Hello, world!\nWhat's up?") + .run(); } // Reka Edge @@ -4114,6 +4641,27 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .is_partial(true) .expect(message_assist_call_cutoff_args) .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinking\n\n\nHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); } // Apriel 1.5 @@ -4124,6 +4672,14 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .tools({ special_function_tool }) .expect(message_assist_call) .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_content("Hello, world!\nWhat's up?") + .run(); } // Apriel 1.6 Thinker (reasoning-only support) @@ -4147,6 +4703,27 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .tools({ special_function_tool }) .expect(simple_assist_msg("", "Here are my reasoning steps:\nI'm\nthinking", "special_function", "{\"arg1\":1}")) .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinking\n[BEGIN FINAL RESPONSE]\nHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); } // Mistral Small 3.2 - FUNC_BRACKET_TAG format: [TOOL_CALLS]func_name[CALL_ID]id[ARGS]{...} @@ -4172,7 +4749,13 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .expect_reconstruction() .run(); - + // Continuation tests + tst.test("world!\nWhat's up?") + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_content("Hello, world!\nWhat's up?") + .run(); } // Devstral { @@ -4194,18 +4777,42 @@ static void test_template_output_peg_parsers(bool detailed_debug) { // Llama 3.1 auto tst = peg_tester("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja", detailed_debug); tst.test("Hello, world!\nWhat's up?").tools({ special_function_tool }).expect(message_assist).expect_reconstruction().run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_content("Hello, world!\nWhat's up?") + .run(); } { // Llama 3.2 auto tst = peg_tester("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", detailed_debug); tst.test("Hello, world!\nWhat's up?").tools({ special_function_tool }).expect(message_assist).expect_reconstruction().run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_content("Hello, world!\nWhat's up?") + .run(); } { // Llama 3.3 auto tst = peg_tester("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja", detailed_debug); tst.test("Hello, world!\nWhat's up?").tools({ python_tool }).expect(message_assist).expect_reconstruction().run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_content("Hello, world!\nWhat's up?") + .run(); } // GPT-OSS format tests @@ -4366,6 +4973,27 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .reasoning_format(COMMON_REASONING_FORMAT_AUTO) .expect(message_assist_thoughts) .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinking<|end|><|start|>assistant<|channel|>final<|message|>Hello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); } { @@ -4556,6 +5184,26 @@ static void test_template_output_peg_parsers(bool detailed_debug) { }) .run(); + // Continuation tests + tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinking\n\nHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); } // GigaChat V3 @@ -4576,6 +5224,14 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .expect(message_assist_call_content) .expect_reconstruction() .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_content("Hello, world!\nWhat's up?") + .run(); } // GigaChat V3.1 @@ -4596,55 +5252,154 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .expect(message_assist_call_content) .expect_reconstruction() .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_content("Hello, world!\nWhat's up?") + .run(); } } -static void test_reka_edge_common_path() { - auto tmpls = read_templates("models/templates/Reka-Edge.jinja"); +static void test_template_generation_prompt() { + common_chat_msg system_msg; + system_msg.role = "system"; + system_msg.content ="You are a helpful assistant."; - { - common_chat_templates_inputs inputs; - common_chat_msg system_msg; - system_msg.role = "system"; - system_msg.content = "Use tools when needed."; + common_chat_msg tool_call_msg = simple_assist_msg("", "", "special_function", "{\"arg1\": 1}"); - common_chat_msg tool_call_msg = simple_assist_msg("", "", "special_function", "{\"arg1\": 1}"); + common_chat_msg tool_msg; + tool_msg.role = "tool"; + tool_msg.tool_name = "special_function"; + tool_msg.tool_call_id = "call0"; + tool_msg.content = "Sunny"; - common_chat_msg tool_msg; - tool_msg.role = "tool"; - tool_msg.tool_name = "special_function"; - tool_msg.tool_call_id = "call0"; - tool_msg.content = "Sunny"; + struct test_case_options { + std::vector messages; + bool add_generation_prompt = true; + common_chat_continuation continue_final_message = COMMON_CHAT_CONTINUATION_NONE; + }; - inputs.messages = { system_msg, message_user, tool_call_msg, tool_msg, message_user }; - inputs.tools = { special_function_tool }; - inputs.enable_thinking = true; - inputs.add_generation_prompt = true; + auto basic = [&]() { + test_case_options opts; + opts.messages = { system_msg, message_user }; + return opts; + }; + + auto continuation_content = [&]() { + test_case_options opts; + opts.messages = { system_msg, message_user, message_assist_prefill_content }; + opts.add_generation_prompt = false; + opts.continue_final_message = COMMON_CHAT_CONTINUATION_CONTENT; + return opts; + }; + + auto continuation_reasoning = [&]() { + test_case_options opts; + opts.messages = { system_msg, message_user, message_assist_prefill_reasoning }; + opts.add_generation_prompt = false; + opts.continue_final_message = COMMON_CHAT_CONTINUATION_REASONING; + return opts; + }; + + auto check = [&](const common_chat_templates_ptr & tmpls, + const test_case_options & opts, + const std::string & expected_generation_prompt) { + common_chat_templates_inputs inputs; + inputs.messages = opts.messages; + inputs.add_generation_prompt = opts.add_generation_prompt; + inputs.continue_final_message = opts.continue_final_message; auto params = common_chat_templates_apply(tmpls.get(), inputs); - if (params.prompt.find("\nSunny\n") == std::string::npos) { - throw std::runtime_error("Reka Edge prompt did not render tool response history"); - } - if (params.prompt.rfind("assistant: \n") == std::string::npos) { - throw std::runtime_error("Reka Edge prompt did not render thinking generation prompt"); - } + assert_contains(params.prompt, system_msg.content); + assert_contains(params.prompt, message_user.content); + assert_equals(expected_generation_prompt, params.generation_prompt); + assert_ends_with(params.prompt, expected_generation_prompt); + }; + + { + auto tmpls = read_templates("models/templates/Qwen3.5-4B.jinja"); + check(tmpls, basic(), "<|im_start|>assistant\n\n"); + check(tmpls, continuation_content(), "<|im_start|>assistant\n\nI'm thinking\n\n\nHello, "); + check(tmpls, continuation_reasoning(), "<|im_start|>assistant\n\nI'm"); } { - common_chat_templates_inputs inputs; - inputs.messages = { - message_user, - simple_assist_msg("The first point is") - }; - inputs.add_generation_prompt = false; - inputs.enable_thinking = false; - inputs.chat_template_kwargs["continue_final_message"] = "true"; + auto tmpls = read_templates("models/templates/openai-gpt-oss-120b.jinja"); + check(tmpls, basic(), "<|start|>assistant"); + check(tmpls, continuation_content(), "<|start|>assistant<|channel|>analysis<|message|>I'm thinking<|end|><|start|>assistant<|channel|>final<|message|>Hello, "); + check(tmpls, continuation_reasoning(), "<|start|>assistant<|channel|>analysis<|message|>I'm"); + } - auto params = common_chat_templates_apply(tmpls.get(), inputs); - if (string_ends_with(params.prompt, "")) { - throw std::runtime_error("Reka Edge continue_final_message unexpectedly closed the assistant turn"); - } + { + auto tmpls = read_templates("models/templates/mistralai-Ministral-3-14B-Reasoning-2512.jinja"); + check(tmpls, basic(), ""); + check(tmpls, continuation_content(), "[THINK]I'm thinking[/THINK]Hello, "); + check(tmpls, continuation_reasoning(), "[THINK]I'm"); + } + + { + auto tmpls = read_templates("models/templates/google-gemma-4-31B-it.jinja"); + check(tmpls, basic(), "<|turn>model\n"); + check(tmpls, continuation_content(), "<|turn>model\n<|channel>thought\nI'm thinkingHello, "); + check(tmpls, continuation_reasoning(), "<|turn>model\n<|channel>thought\nI'm"); + + // Special case when last message is a tool response + test_case_options after_tool_call = continuation_reasoning(); + after_tool_call.messages = { system_msg, message_user, tool_call_msg, tool_msg, message_assist_prefill_reasoning }; + check(tmpls, after_tool_call, "<|channel>thought\nI'm"); + } + + { + auto tmpls = read_templates("models/templates/meetkai-functionary-medium-v3.2.jinja"); + check(tmpls, basic(), "<|start_header_id|>assistant<|end_header_id|>\n\n>>>"); + check(tmpls, continuation_content(), "<|start_header_id|>assistant<|end_header_id|>\n\n>>>all\nHello, "); + check(tmpls, continuation_reasoning(), "<|start_header_id|>assistant<|end_header_id|>\n\n>>>all\n"); + } + + { + auto tmpls = read_templates("models/templates/Reka-Edge.jinja"); + check(tmpls, basic(), "assistant: \n"); + check(tmpls, continuation_content(), "assistant: \nI'm thinking\n\n\nHello, "); + check(tmpls, continuation_reasoning(), "assistant: \nI'm"); + } + + { + auto tmpls = read_templates("models/templates/moonshotai-Kimi-K2.jinja"); + check(tmpls, basic(), "<|im_assistant|>assistant<|im_middle|>"); + check(tmpls, continuation_content(), "<|im_assistant|>assistant<|im_middle|>I'm thinkingHello, "); + check(tmpls, continuation_reasoning(), "<|im_assistant|>assistant<|im_middle|>I'm"); + } + + { + auto tmpls = read_templates("models/templates/LFM2-8B-A1B.jinja"); + check(tmpls, basic(), "<|im_start|>assistant\n"); + check(tmpls, continuation_content(), "<|im_start|>assistant\nI'm thinkingHello, "); + check(tmpls, continuation_reasoning(), "<|im_start|>assistant\nI'm"); + } + + { + auto tmpls = read_templates("models/templates/LFM2.5-Instruct.jinja"); + check(tmpls, basic(), "<|im_start|>assistant\n"); + check(tmpls, continuation_content(), "<|im_start|>assistant\nI'm thinkingHello, "); + check(tmpls, continuation_reasoning(), "<|im_start|>assistant\nI'm"); + } + + { + auto tmpls = read_templates("models/templates/GigaChat3-10B-A1.8B.jinja"); + check(tmpls, basic(), "assistant<|role_sep|>\n"); + check(tmpls, continuation_content(), "assistant<|role_sep|>\nHello, "); + check(tmpls, continuation_reasoning(), "assistant<|role_sep|>\n"); + } + + { + auto tmpls = read_templates("models/templates/deepseek-ai-DeepSeek-V3.2.jinja"); + check(tmpls, basic(), "<|Assistant|>"); + check(tmpls, continuation_content(), "<|Assistant|>I'm thinkingHello, "); + check(tmpls, continuation_reasoning(), "<|Assistant|>I'm"); } } @@ -4841,7 +5596,7 @@ int main(int argc, char ** argv) { test_tools_oaicompat_json_conversion(); test_convert_responses_to_chatcmpl(); test_developer_role_to_system_workaround(); - test_reka_edge_common_path(); + test_template_generation_prompt(); test_template_output_peg_parsers(detailed_debug); std::cout << "\n[chat] All tests passed!" << '\n'; } diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index 73de0d3bba11..dc00edfa82aa 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -1032,23 +1032,33 @@ json oaicompat_chat_params_parse( auto caps = common_chat_templates_get_caps(opt.tmpls.get()); common_chat_templates_inputs inputs; - inputs.messages = common_chat_msgs_parse_oaicompat(messages); - inputs.tools = common_chat_tools_parse_oaicompat(tools); - inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(tool_choice); - inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump(); - inputs.grammar = grammar; - inputs.use_jinja = opt.use_jinja; - inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", caps["supports_parallel_tool_calls"]); - inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true); - const bool continue_final_message = json_value(body, "continue_final_message", false); - if (continue_final_message && inputs.add_generation_prompt) { + inputs.messages = common_chat_msgs_parse_oaicompat(messages); + inputs.tools = common_chat_tools_parse_oaicompat(tools); + inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(tool_choice); + inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump(); + inputs.grammar = grammar; + inputs.use_jinja = opt.use_jinja; + inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", caps["supports_parallel_tool_calls"]); + inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true); + inputs.continue_final_message = body.contains("continue_final_message") ? + common_chat_continuation_parse(body.at("continue_final_message")) : + COMMON_CHAT_CONTINUATION_NONE; + if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_NONE && opt.prefill_assistant + && !inputs.messages.empty() && inputs.messages.back().role == "assistant") { + if (inputs.messages.size() >= 2 && inputs.messages[inputs.messages.size() - 2].role == "assistant") { + throw std::invalid_argument("Cannot have 2 or more assistant messages at the end of the list."); + } + inputs.continue_final_message = COMMON_CHAT_CONTINUATION_AUTO; + inputs.add_generation_prompt = false; + } + if (inputs.continue_final_message != COMMON_CHAT_CONTINUATION_NONE && inputs.add_generation_prompt) { throw std::invalid_argument("Cannot set both add_generation_prompt and continue_final_message to true."); } - inputs.reasoning_format = opt.reasoning_format; + inputs.reasoning_format = opt.reasoning_format; if (body.contains("reasoning_format")) { inputs.reasoning_format = common_reasoning_format_from_name(body.at("reasoning_format").get()); } - inputs.enable_thinking = opt.enable_thinking; + inputs.enable_thinking = opt.enable_thinking; if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) { if (body.contains("grammar")) { throw std::invalid_argument("Cannot use custom grammar constraints with tools."); @@ -1073,84 +1083,11 @@ json oaicompat_chat_params_parse( throw std::invalid_argument("invalid type for \"enable_thinking\" (expected boolean, got string)"); } - // if the assistant message appears at the end of list, we do not add end-of-turn token - // for ex. this can be useful to modify the reasoning process in reasoning models - // continue_final_message is the explicit opt in alias from the vLLM/transformers API, - // equivalent to the prefill_assistant heuristic - bool prefill_assistant_message = !inputs.messages.empty() && inputs.messages.back().role == "assistant" - && (continue_final_message || opt.prefill_assistant); - common_chat_msg last_message; - if (prefill_assistant_message) { - last_message = inputs.messages.back(); - inputs.messages.pop_back(); - - /* sanity check, max one assistant message at the end of the list */ - if (!inputs.messages.empty() && inputs.messages.back().role == "assistant"){ - throw std::invalid_argument("Cannot have 2 or more assistant messages at the end of the list."); - } - - // reject reasoning prefill on channel based templates that do not expose explicit thinking tags - if (!last_message.reasoning_content.empty() && inputs.enable_thinking) { - auto probe_params = common_chat_templates_apply(opt.tmpls.get(), inputs); - if (probe_params.supports_thinking && probe_params.thinking_end_tag.empty()) { - throw std::invalid_argument("Assistant prefill with reasoning_content is not supported yet for this template."); - } - } - - inputs.add_generation_prompt = true; - } inputs.force_pure_content = opt.force_pure_content; // Apply chat template to the list of messages auto chat_params = common_chat_templates_apply(opt.tmpls.get(), inputs); - /* Append assistant prefilled message */ - if (prefill_assistant_message) { - const bool thinking_active = chat_params.supports_thinking && !chat_params.thinking_end_tag.empty(); - const bool has_reasoning = !last_message.reasoning_content.empty(); - const bool has_content = !last_message.content.empty() || !last_message.content_parts.empty(); - const bool mid_reasoning = has_reasoning && !has_content; - - // some templates inject thinking_start in generation_prompt, others let the model emit it - const bool gp_has_think = thinking_active - && chat_params.generation_prompt.find(chat_params.thinking_start_tag) != std::string::npos; - - // open the thinking block when reasoning is present and the template did not inject it - if (has_reasoning) { - if (thinking_active && !gp_has_think) { - chat_params.prompt += chat_params.thinking_start_tag; - } - chat_params.prompt += last_message.reasoning_content; - } - - if (thinking_active) { - if (mid_reasoning) { - // model continues inside the thinking block, keep generation_prompt open on think - if (!gp_has_think) { - chat_params.generation_prompt += chat_params.thinking_start_tag; - } - } else { - // close thinking block when reasoning is followed by content, or when the template forced it open - if (has_reasoning || gp_has_think) { - chat_params.prompt += chat_params.thinking_end_tag; - } - // strip thinking_start from generation_prompt so the parser routes model output as content - auto pos = chat_params.generation_prompt.rfind(chat_params.thinking_start_tag); - if (pos != std::string::npos) { - chat_params.generation_prompt = chat_params.generation_prompt.substr(0, pos); - } - } - } - - if (!last_message.content_parts.empty()) { - for (auto & p : last_message.content_parts) { - chat_params.prompt += p.text; - } - } else { - chat_params.prompt += last_message.content; - } - } - llama_params["chat_format"] = static_cast(chat_params.format); llama_params["prompt"] = chat_params.prompt; if (!chat_params.grammar.empty()) { diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index cbc40a35fc4a..d45513dbebae 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -144,6 +144,17 @@ json task_params::to_json(bool only_metrics) const { // // task_result_state // +task_result_state::task_result_state(const common_chat_parser_params & chat_parser_params) + : chat_parser_params(chat_parser_params) + , oai_resp_id("resp_" + random_string()) + , oai_resp_reasoning_id("rs_" + random_string()) + , oai_resp_message_id("msg_" + random_string()) { + if (!chat_parser_params.echo) { + // initialize chat_msg to avoid emitting a delta containing the assistant prefill + chat_msg = common_chat_parse("", true, chat_parser_params); + } +} + common_chat_msg task_result_state::update_chat_msg( const std::string & text_added, bool is_partial, @@ -421,6 +432,7 @@ task_params server_task::params_from_json_cmpl( if (data.contains("chat_parser")) { params.chat_parser_params.parser.load(data.at("chat_parser").get()); } + params.chat_parser_params.echo = json_value(data, "echo", false); } { diff --git a/tools/server/server-task.h b/tools/server/server-task.h index 64bdecd794f1..0978bb6ff162 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -112,11 +112,7 @@ struct task_result_state { const std::string oai_resp_message_id; std::string oai_resp_fc_id; // function call ID for current args delta - task_result_state(const common_chat_parser_params & chat_parser_params) - : chat_parser_params(chat_parser_params) - , oai_resp_id("resp_" + random_string()) - , oai_resp_reasoning_id("rs_" + random_string()) - , oai_resp_message_id("msg_" + random_string()) {} + task_result_state(const common_chat_parser_params & chat_parser_params); // parse partial tool calls and update the internal state common_chat_msg update_chat_msg( diff --git a/tools/server/tests/unit/test_chat_completion.py b/tools/server/tests/unit/test_chat_completion.py index 243e41605782..f80e46133c77 100644 --- a/tools/server/tests/unit/test_chat_completion.py +++ b/tools/server/tests/unit/test_chat_completion.py @@ -158,11 +158,12 @@ def test_chat_template(): @pytest.mark.parametrize("prefill,re_prefill", [ ("Whill", "Whill"), - ([{"type": "text", "text": "Wh"}, {"type": "text", "text": "ill"}], "Whill"), + ([{"type": "text", "text": "Wh"}, {"type": "text", "text": "ill"}], "Wh\n\nill"), ]) def test_chat_template_assistant_prefill(prefill, re_prefill): global server - server.chat_template = "llama3" + server.jinja = True + server.chat_template_file = "../../../models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja" server.debug = True # to get the "__verbose" object in the response server.start() res = server.make_request("POST", "/chat/completions", data={ @@ -175,14 +176,15 @@ def test_chat_template_assistant_prefill(prefill, re_prefill): }) assert res.status_code == 200 assert "__verbose" in res.body - assert res.body["__verbose"]["prompt"] == f" <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{re_prefill}" + assert res.body["__verbose"]["prompt"].endswith(f"<|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{re_prefill}") def test_chat_template_continue_final_message_vllm_compat(): """continue_final_message is the vLLM/transformers explicit alias for the prefill_assistant heuristic. Both must produce the same prompt.""" global server - server.chat_template = "llama3" + server.jinja = True + server.chat_template_file = "../../../models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja" server.debug = True server.start() res = server.make_request("POST", "/chat/completions", data={ @@ -197,7 +199,7 @@ def test_chat_template_continue_final_message_vllm_compat(): }) assert res.status_code == 200 assert "__verbose" in res.body - assert res.body["__verbose"]["prompt"] == " <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nWhill" + assert res.body["__verbose"]["prompt"].endswith("<|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nWhill") def test_chat_template_continue_final_message_mutual_exclusion(): From 3e12fbdea5c1ac4225c7dcf79506d30950283fc3 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Sun, 17 May 2026 23:30:25 +0800 Subject: [PATCH 3/8] llama: avoid copying logits during prompt decode in MTP (#23198) * llama: avoid copying logits during prompt decode in MTP * review: update comment * llama-graph: call set_output for t_h_pre_norm --- common/speculative.cpp | 27 +++++++++++++++-- common/speculative.h | 5 +++- src/llama-context.cpp | 51 ++++++++++++++++++++++++--------- src/llama-context.h | 2 +- src/llama-cparams.h | 3 +- src/llama-ext.h | 10 +++---- src/llama-graph.cpp | 3 ++ src/models/qwen35.cpp | 6 +++- src/models/qwen35moe.cpp | 6 +++- tools/server/server-context.cpp | 5 ++++ 10 files changed, 91 insertions(+), 27 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index 3488b9393c5a..e591bab875db 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -146,8 +146,11 @@ struct common_speculative_impl { virtual void accept(llama_seq_id seq_id, uint16_t n_accepted) = 0; - // true if this implementation requires the target context to extract embeddings + // true if this implementation requires the target context to extract post-norm embeddings virtual bool need_embd() const = 0; + + // true if this implementation requires the target context to extract pre-norm embeddings + virtual bool need_embd_pre_norm() const { return false; } }; struct common_speculative_impl_draft_simple : public common_speculative_impl { @@ -429,8 +432,8 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl { s.reset(common_sampler_init(llama_get_model(ctx_dft), sparams)); } - llama_set_embeddings_pre_norm(ctx_tgt, true); - llama_set_embeddings_pre_norm(ctx_dft, true); + llama_set_embeddings_pre_norm(ctx_tgt, true, /*masked*/ false); + llama_set_embeddings_pre_norm(ctx_dft, true, /*masked*/ true); pending_h.assign(n_seq, std::vector(n_embd, 0.0f)); @@ -691,6 +694,10 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl { } bool need_embd() const override { + return false; + } + + bool need_embd_pre_norm() const override { return true; } }; @@ -1408,6 +1415,20 @@ bool common_speculative_need_embd(common_speculative * spec) { return false; } +bool common_speculative_need_embd_pre_norm(common_speculative * spec) { + if (spec == nullptr) { + return false; + } + + for (auto & impl : spec->impls) { + if (impl->need_embd_pre_norm()) { + return true; + } + } + + return false; +} + void common_speculative_draft(common_speculative * spec) { if (spec == nullptr) { return; diff --git a/common/speculative.h b/common/speculative.h index 614db9b1b509..f24bac79edb7 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -53,9 +53,12 @@ void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, co // process the batch and update the internal state of the speculative context bool common_speculative_process(common_speculative * spec, const llama_batch & batch); -// true if any implementation requires target embeddings to be extracted +// true if any implementation requires target post-norm embeddings to be extracted bool common_speculative_need_embd(common_speculative * spec); +// true if any implementation requires target pre-norm embeddings to be extracted +bool common_speculative_need_embd_pre_norm(common_speculative * spec); + // generate drafts for the sequences specified with `common_speculative_get_draft_params` void common_speculative_draft(common_speculative * spec); diff --git a/src/llama-context.cpp b/src/llama-context.cpp index d62abc4009b8..b1b12d017c0a 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -895,8 +895,17 @@ float * llama_context::get_embeddings_pre_norm_ith(int32_t i) { throw std::runtime_error("no pre-norm embeddings"); } - const int64_t j = output_resolve_row(i); const uint32_t n_embd = model.hparams.n_embd; + + if (!cparams.embeddings_pre_norm_masked) { + // unmasked: pre-norm rows are stored densely, indexed by raw token position. + if (i < 0 || (size_t)(i + 1) * n_embd > embd_pre_norm.size) { + throw std::runtime_error(format("out of range [0, %zu)", embd_pre_norm.size / n_embd)); + } + return embd_pre_norm.data + (size_t) i * n_embd; + } + + const int64_t j = output_resolve_row(i); return embd_pre_norm.data + j*n_embd; } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: invalid pre-norm embeddings id %d, reason: %s\n", __func__, i, err.what()); @@ -1088,10 +1097,11 @@ void llama_context::set_embeddings(bool value) { //sched_need_reserve = true; } -void llama_context::set_embeddings_pre_norm(bool value) { - LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); +void llama_context::set_embeddings_pre_norm(bool value, bool masked) { + LLAMA_LOG_DEBUG("%s: value = %d, masked = %d\n", __func__, value, masked); - cparams.embeddings_pre_norm = value; + cparams.embeddings_pre_norm = value; + cparams.embeddings_pre_norm_masked = masked; } void llama_context::set_causal_attn(bool value) { @@ -1737,6 +1747,7 @@ int llama_context::decode(const llama_batch & batch_inp) { }; int64_t n_outputs_prev = 0; + int64_t n_tokens_prev = 0; do { const auto & ubatch = mctx->get_ubatch(); @@ -1882,16 +1893,21 @@ int llama_context::decode(const llama_batch & batch_inp) { // extract pre-norm embeddings (hidden state before the final output norm) // only meaningful in LLAMA_POOLING_TYPE_NONE (per-token); other pooling modes are ignored. - if (embd_pre_norm.data && t_h_pre_norm && n_outputs > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { - ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm); - GGML_ASSERT(backend_h != nullptr); + { + const bool masked = cparams.embeddings_pre_norm_masked; + const int64_t n_rows = masked ? n_outputs : (int64_t) ubatch.n_tokens; + const int64_t offset = masked ? n_outputs_prev : n_tokens_prev; + + if (embd_pre_norm.data && t_h_pre_norm && n_rows > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { + ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm); + GGML_ASSERT(backend_h != nullptr); - const uint32_t n_embd = hparams.n_embd; - float * embd_pre_norm_out = embd_pre_norm.data + n_outputs_prev*n_embd; + const uint32_t n_embd = hparams.n_embd; + float * embd_pre_norm_out = embd_pre_norm.data + offset*n_embd; - GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); - GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_pre_norm.size); - ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm_out, 0, n_outputs*n_embd*sizeof(float)); + GGML_ASSERT((offset + n_rows)*n_embd <= (int64_t) embd_pre_norm.size); + ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm_out, 0, n_rows*n_embd*sizeof(float)); + } } // Copy backend sampling output if this ubatch produced any sampling tensors. @@ -1908,6 +1924,7 @@ int llama_context::decode(const llama_batch & batch_inp) { } n_outputs_prev += n_outputs; + n_tokens_prev += ubatch.n_tokens; } while (mctx->next()); // set to total number of outputs in the batch, for use in llama_get_logits_ith @@ -1999,6 +2016,12 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { embd.size = has_embd ? n_embd_out*n_outputs_max : 0; embd_pre_norm.size = has_embd_pre_norm ? n_embd*n_outputs_max : 0; + if (has_embd_pre_norm && !cparams.embeddings_pre_norm_masked) { + // unmasked: pre-norm row exists for every token in the batch, not just + // those flagged via batch.logits[i] -> size by token count instead. + embd_pre_norm.size = (size_t) n_embd * n_batch; + } + // Allocate backend sampling output buffers if there are backend samplers configured. const bool has_sampling = !sampling.samplers.empty(); if (has_sampling) { @@ -3547,8 +3570,8 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) { return ctx->get_embeddings_seq(seq_id); } -void llama_set_embeddings_pre_norm(llama_context * ctx, bool value) { - ctx->set_embeddings_pre_norm(value); +void llama_set_embeddings_pre_norm(llama_context * ctx, bool value, bool masked) { + ctx->set_embeddings_pre_norm(value, masked); } float * llama_get_embeddings_pre_norm(llama_context * ctx) { diff --git a/src/llama-context.h b/src/llama-context.h index e16ac4c618ba..d03f681d4a13 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -110,7 +110,7 @@ struct llama_context { void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data); void set_embeddings (bool value); - void set_embeddings_pre_norm(bool value); + void set_embeddings_pre_norm(bool value, bool masked); void set_causal_attn(bool value); void set_warmup(bool value); diff --git a/src/llama-cparams.h b/src/llama-cparams.h index 5898a1c38d51..20ec59fe3357 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -28,7 +28,8 @@ struct llama_cparams { float yarn_beta_slow; bool embeddings; - bool embeddings_pre_norm; // also extract the hidden state before the final output norm + bool embeddings_pre_norm; // also extract the hidden state before the final output norm + bool embeddings_pre_norm_masked; // extract for only rows where batch.logits != 0 bool causal_attn; bool offload_kqv; bool flash_attn; diff --git a/src/llama-ext.h b/src/llama-ext.h index 11f1986676a5..edfa71c207c5 100644 --- a/src/llama-ext.h +++ b/src/llama-ext.h @@ -93,14 +93,14 @@ LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_c // pre-norm embeddings (hidden state before the final output norm) // -// mirrors: -// LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings); -LLAMA_API void llama_set_embeddings_pre_norm(struct llama_context * ctx, bool value); +// Set whether the context outputs pre-norm embeddings or not +// If masked == true, output the embeddings only for the tokens with batch.logits != 0 +// If masked == false, output the embeddings for all tokens in the batch regardless of batch.logits +LLAMA_API void llama_set_embeddings_pre_norm(struct llama_context * ctx, bool value, bool masked); // mirrors: // LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); -LLAMA_API float * llama_get_embeddings_pre_norm(struct llama_context * ctx); +LLAMA_API float * llama_get_embeddings_pre_norm (struct llama_context * ctx); -// mirrors: // LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i); LLAMA_API float * llama_get_embeddings_pre_norm_ith(struct llama_context * ctx, int32_t i); diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 858c297dd762..31cf41a1c2d2 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -848,6 +848,9 @@ void llm_graph_result::set_outputs() { if (t_embd_pooled != nullptr) { ggml_set_output(t_embd_pooled); } + if (t_h_pre_norm != nullptr) { + ggml_set_output(t_h_pre_norm); + } for (auto & [seq_id, t] : t_sampled) { if (t != nullptr) { ggml_set_output(t); diff --git a/src/models/qwen35.cpp b/src/models/qwen35.cpp index 2b4d5b14cd42..361d7538a038 100644 --- a/src/models/qwen35.cpp +++ b/src/models/qwen35.cpp @@ -176,7 +176,7 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il); } - if (il == n_transformer_layers - 1 && inp_out_ids) { + if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_pre_norm_masked) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -211,6 +211,10 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para cb(cur, "h_pre_norm", -1); res->t_h_pre_norm = cur; + if (!cparams.embeddings_pre_norm_masked && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + } + // Final norm cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); diff --git a/src/models/qwen35moe.cpp b/src/models/qwen35moe.cpp index 22e3e1107655..4f63c410d668 100644 --- a/src/models/qwen35moe.cpp +++ b/src/models/qwen35moe.cpp @@ -199,7 +199,7 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il); } - if (il == n_transformer_layers - 1 && inp_out_ids) { + if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_pre_norm_masked) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -234,6 +234,10 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p cb(cur, "h_pre_norm", -1); res->t_h_pre_norm = cur; + if (!cparams.embeddings_pre_norm_masked && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + } + // Final norm cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 1ce7f0958279..0f3fb9efa3ca 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -243,6 +243,11 @@ struct server_slot { return task->need_embd() || (spec && common_speculative_need_embd(spec)); } + bool need_embd_pre_norm() const { + GGML_ASSERT(task); + return spec && common_speculative_need_embd_pre_norm(spec); + } + // if the context does not have a memory module then all embeddings have to be computed within a single ubatch // also we cannot split if the pooling would require any past tokens // (MTP supports splitting — uses task->need_embd() not need_embd()) From 84c678242a501edff0ea16551a88cfd76c11881c Mon Sep 17 00:00:00 2001 From: Oliver Simons Date: Sun, 17 May 2026 18:00:10 +0200 Subject: [PATCH 4/8] CUDA: Continue directly including cuda/iterator (#23102) Cont of #22936, forgot to update one site --- ggml/src/ggml-cuda/top-k.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/ggml/src/ggml-cuda/top-k.cu b/ggml/src/ggml-cuda/top-k.cu index 59ce36fb1c95..db1d39e2dc71 100644 --- a/ggml/src/ggml-cuda/top-k.cu +++ b/ggml/src/ggml-cuda/top-k.cu @@ -5,6 +5,7 @@ # include # if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2) # define CUB_TOP_K_AVAILABLE +# include using namespace cub; # endif // CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2 #endif // GGML_CUDA_USE_CUB From e0de4c24194a8b819c0df4ca1acc309f1aeb51fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Sun, 17 May 2026 18:07:21 +0200 Subject: [PATCH 5/8] cmake : do not install conversion script (#23204) --- CMakeLists.txt | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 447460723139..599827a8dc00 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -286,18 +286,6 @@ install(FILES ${CMAKE_CURRENT_BINARY_DIR}/llama-config.cmake ${CMAKE_CURRENT_BINARY_DIR}/llama-version.cmake DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/llama) -install( - FILES convert_hf_to_gguf.py - PERMISSIONS - OWNER_READ - OWNER_WRITE - OWNER_EXECUTE - GROUP_READ - GROUP_EXECUTE - WORLD_READ - WORLD_EXECUTE - DESTINATION ${CMAKE_INSTALL_BINDIR}) - configure_file(cmake/llama.pc.in "${CMAKE_CURRENT_BINARY_DIR}/llama.pc" @ONLY) From 87589042cac2c390cec8d68fb2fad64e0a2a252a Mon Sep 17 00:00:00 2001 From: Aldehir Rojas Date: Sun, 17 May 2026 14:42:26 -0400 Subject: [PATCH 6/8] cmake : fix LLAMA_BUILD_UI logic (#23190) --- CMakeLists.txt | 9 ++------- common/common.h | 2 -- tools/server/server-http.cpp | 8 +++----- tools/ui/CMakeLists.txt | 13 +++---------- 4 files changed, 8 insertions(+), 24 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 599827a8dc00..d6d6bb0e7048 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -108,20 +108,15 @@ option(LLAMA_BUILD_TESTS "llama: build tests" option(LLAMA_BUILD_TOOLS "llama: build tools" ${LLAMA_STANDALONE}) option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE}) option(LLAMA_BUILD_SERVER "llama: build server example" ${LLAMA_STANDALONE}) -# Deprecated: use LLAMA_BUILD_UI instead (kept for backward compat) -option(LLAMA_BUILD_WEBUI "llama: build the embedded Web UI for server (deprecated: use LLAMA_BUILD_UI)" ON) -option(LLAMA_USE_PREBUILT_WEBUI "llama: use prebuilt WebUI from HF Bucket when available (deprecated: use LLAMA_USE_PREBUILT_UI)" ON) - -# New option names option(LLAMA_BUILD_UI "llama: build the embedded Web UI for server" ON) option(LLAMA_USE_PREBUILT_UI "llama: use prebuilt UI from HF Bucket when available (requires LLAMA_BUILD_UI=ON)" ON) # Backward compat: when old var is set but new one isn't, forward the value -if(DEFINED LLAMA_BUILD_WEBUI AND NOT DEFINED LLAMA_BUILD_UI) +if(DEFINED LLAMA_BUILD_WEBUI) set(LLAMA_BUILD_UI ${LLAMA_BUILD_WEBUI}) message(DEPRECATION "LLAMA_BUILD_WEBUI is deprecated, use LLAMA_BUILD_UI instead") endif() -if(DEFINED LLAMA_USE_PREBUILT_WEBUI AND NOT DEFINED LLAMA_USE_PREBUILT_UI) +if(DEFINED LLAMA_USE_PREBUILT_WEBUI) set(LLAMA_USE_PREBUILT_UI ${LLAMA_USE_PREBUILT_WEBUI}) message(DEPRECATION "LLAMA_USE_PREBUILT_WEBUI is deprecated, use LLAMA_USE_PREBUILT_UI instead") endif() diff --git a/common/common.h b/common/common.h index 514bab11942c..1d3d788b2de4 100644 --- a/common/common.h +++ b/common/common.h @@ -617,8 +617,6 @@ struct common_params { // UI configs #ifdef LLAMA_UI_DEFAULT_ENABLED bool ui = LLAMA_UI_DEFAULT_ENABLED != 0; -#elif defined(LLAMA_WEBUI_DEFAULT_ENABLED) - bool ui = LLAMA_WEBUI_DEFAULT_ENABLED != 0; #else bool ui = true; // default to enabled when not set #endif diff --git a/tools/server/server-http.cpp b/tools/server/server-http.cpp index 39a21f4ecc7f..9d008fc94c2a 100644 --- a/tools/server/server-http.cpp +++ b/tools/server/server-http.cpp @@ -231,11 +231,10 @@ bool server_http_context::init(const common_params & params) { }; auto middleware_server_state = [this](const httplib::Request & req, httplib::Response & res) { - (void)req; // suppress unused parameter warning when LLAMA_BUILD_UI / LLAMA_BUILD_WEBUI is not defined + (void)req; // suppress unused parameter warning when LLAMA_BUILD_UI is not defined bool ready = is_ready.load(); if (!ready) { -// Support both old and new preprocessor defines -#if defined(LLAMA_BUILD_UI) || defined(LLAMA_BUILD_WEBUI) +#if defined(LLAMA_BUILD_UI) auto tmp = string_split(req.path, '.'); if (req.path == "/" || (tmp.size() > 0 && tmp.back() == "html")) { res.status = 503; @@ -313,8 +312,7 @@ bool server_http_context::init(const common_params & params) { return 1; } } else { -// Support both old and new preprocessor defines -#if defined(LLAMA_BUILD_UI) || defined(LLAMA_BUILD_WEBUI) +#if defined(LLAMA_BUILD_UI) // using embedded static index.html srv->Get(params.api_prefix + "/", [](const httplib::Request & /*req*/, httplib::Response & res) { // COEP and COOP headers, required by pyodide (python interpreter) diff --git a/tools/ui/CMakeLists.txt b/tools/ui/CMakeLists.txt index 9687ca92e55e..383940cb6369 100644 --- a/tools/ui/CMakeLists.txt +++ b/tools/ui/CMakeLists.txt @@ -14,12 +14,7 @@ endif() set(TARGET_SRCS "") set(UI_COMPILE_DEFS "") -# Support both old (LLAMA_BUILD_WEBUI) and new (LLAMA_BUILD_UI) option names -if(LLAMA_BUILD_WEBUI OR LLAMA_BUILD_UI) - if(LLAMA_BUILD_WEBUI AND NOT LLAMA_BUILD_UI) - message(DEPRECATION "LLAMA_BUILD_WEBUI is deprecated, use LLAMA_BUILD_UI instead") - endif() - +if(LLAMA_BUILD_UI) set(PUBLIC_ASSETS index.html bundle.js @@ -125,19 +120,17 @@ if(LLAMA_BUILD_WEBUI OR LLAMA_BUILD_UI) endforeach() list(APPEND UI_COMPILE_DEFS - LLAMA_BUILD_WEBUI # Deprecated: use LLAMA_BUILD_UI LLAMA_BUILD_UI - LLAMA_WEBUI_DEFAULT_ENABLED=1 # Deprecated: use LLAMA_UI_DEFAULT_ENABLED LLAMA_UI_DEFAULT_ENABLED=1 ) message(STATUS "UI: embedded with source: ${UI_SOURCE}") else() message(WARNING "UI: no source available. Neither local build (build/tools/ui/dist/) nor HF Bucket download succeeded.") message(WARNING "UI: building server without embedded UI. Set LLAMA_BUILD_UI=OFF to suppress this warning.") - list(APPEND UI_COMPILE_DEFS LLAMA_WEBUI_DEFAULT_ENABLED=0 LLAMA_UI_DEFAULT_ENABLED=0) + list(APPEND UI_COMPILE_DEFS LLAMA_UI_DEFAULT_ENABLED=0) endif() else() - list(APPEND UI_COMPILE_DEFS LLAMA_WEBUI_DEFAULT_ENABLED=0 LLAMA_UI_DEFAULT_ENABLED=0) + list(APPEND UI_COMPILE_DEFS LLAMA_UI_DEFAULT_ENABLED=0) endif() # Build the static library From 726704a160c7d86acff1ff11e04b8316cf69d951 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Sun, 17 May 2026 15:05:11 -0600 Subject: [PATCH 7/8] feat: Support d_conv=15 for ssm-conv.cu (#23017) Branch: ModalityConditionalAdapters AI-usage: none Signed-off-by: Gabe Goodhart --- ggml/src/ggml-cuda/ssm-conv.cu | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-cuda/ssm-conv.cu b/ggml/src/ggml-cuda/ssm-conv.cu index 4841389fbc88..4c4daf85dc67 100644 --- a/ggml/src/ggml-cuda/ssm-conv.cu +++ b/ggml/src/ggml-cuda/ssm-conv.cu @@ -140,11 +140,12 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const floa }; switch (nc) { - case 3: launch_kernel(std::integral_constant{}); break; - case 4: launch_kernel(std::integral_constant{}); break; - case 5: launch_kernel(std::integral_constant{}); break; - case 9: launch_kernel(std::integral_constant{}); break; - default: GGML_ABORT("Only support kernel sizes 3, 4, 5, 9 right now."); + case 3: launch_kernel(std::integral_constant{}); break; + case 4: launch_kernel(std::integral_constant{}); break; + case 5: launch_kernel(std::integral_constant{}); break; + case 9: launch_kernel(std::integral_constant{}); break; + case 15: launch_kernel(std::integral_constant{}); break; + default: GGML_ABORT("Only support kernel sizes 3, 4, 5, 9, 15 right now."); } } From dd7cad7197f991b18ded6aca46ff095972b95318 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Mon, 18 May 2026 02:33:14 +0200 Subject: [PATCH 8/8] cmake : do not check for bin install dir (#23234) --- cmake/llama-config.cmake.in | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/llama-config.cmake.in b/cmake/llama-config.cmake.in index 90cbec5b6f13..b4defc76ff0e 100644 --- a/cmake/llama-config.cmake.in +++ b/cmake/llama-config.cmake.in @@ -7,7 +7,7 @@ set(LLAMA_SHARED_LIB @BUILD_SHARED_LIBS@) set_and_check(LLAMA_INCLUDE_DIR "@PACKAGE_LLAMA_INCLUDE_INSTALL_DIR@") set_and_check(LLAMA_LIB_DIR "@PACKAGE_LLAMA_LIB_INSTALL_DIR@") -set_and_check(LLAMA_BIN_DIR "@PACKAGE_LLAMA_BIN_INSTALL_DIR@") +set(LLAMA_BIN_DIR "@PACKAGE_LLAMA_BIN_INSTALL_DIR@") find_package(ggml REQUIRED HINTS ${LLAMA_LIB_DIR}/cmake)