diff --git a/CMakeLists.txt b/CMakeLists.txt index 44746072313..d6d6bb0e704 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -108,20 +108,15 @@ option(LLAMA_BUILD_TESTS "llama: build tests" option(LLAMA_BUILD_TOOLS "llama: build tools" ${LLAMA_STANDALONE}) option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE}) option(LLAMA_BUILD_SERVER "llama: build server example" ${LLAMA_STANDALONE}) -# Deprecated: use LLAMA_BUILD_UI instead (kept for backward compat) -option(LLAMA_BUILD_WEBUI "llama: build the embedded Web UI for server (deprecated: use LLAMA_BUILD_UI)" ON) -option(LLAMA_USE_PREBUILT_WEBUI "llama: use prebuilt WebUI from HF Bucket when available (deprecated: use LLAMA_USE_PREBUILT_UI)" ON) - -# New option names option(LLAMA_BUILD_UI "llama: build the embedded Web UI for server" ON) option(LLAMA_USE_PREBUILT_UI "llama: use prebuilt UI from HF Bucket when available (requires LLAMA_BUILD_UI=ON)" ON) # Backward compat: when old var is set but new one isn't, forward the value -if(DEFINED LLAMA_BUILD_WEBUI AND NOT DEFINED LLAMA_BUILD_UI) +if(DEFINED LLAMA_BUILD_WEBUI) set(LLAMA_BUILD_UI ${LLAMA_BUILD_WEBUI}) message(DEPRECATION "LLAMA_BUILD_WEBUI is deprecated, use LLAMA_BUILD_UI instead") endif() -if(DEFINED LLAMA_USE_PREBUILT_WEBUI AND NOT DEFINED LLAMA_USE_PREBUILT_UI) +if(DEFINED LLAMA_USE_PREBUILT_WEBUI) set(LLAMA_USE_PREBUILT_UI ${LLAMA_USE_PREBUILT_WEBUI}) message(DEPRECATION "LLAMA_USE_PREBUILT_WEBUI is deprecated, use LLAMA_USE_PREBUILT_UI instead") endif() @@ -286,18 +281,6 @@ install(FILES ${CMAKE_CURRENT_BINARY_DIR}/llama-config.cmake ${CMAKE_CURRENT_BINARY_DIR}/llama-version.cmake DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/llama) -install( - FILES convert_hf_to_gguf.py - PERMISSIONS - OWNER_READ - OWNER_WRITE - OWNER_EXECUTE - GROUP_READ - GROUP_EXECUTE - WORLD_READ - WORLD_EXECUTE - DESTINATION ${CMAKE_INSTALL_BINDIR}) - configure_file(cmake/llama.pc.in "${CMAKE_CURRENT_BINARY_DIR}/llama.pc" @ONLY) diff --git a/ci/run.sh b/ci/run.sh 
index 529da07779f..a8cbd3371d3 100755 --- a/ci/run.sh +++ b/ci/run.sh @@ -117,6 +117,12 @@ if [ ! -z ${GG_BUILD_VULKAN} ]; then # if on Mac, disable METAL if [[ "$OSTYPE" == "darwin"* ]]; then CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_METAL=OFF -DGGML_BLAS=OFF" + + MACOS_RUNNER_CUSTOM_VULKAN_CMAKE_LOCATION="/usr/local/lib/cmake/vulkan" + MACOS_RUNNER_CUSTOM_SPIRV_HEADERS_LOCATION="${MACOS_RUNNER_CUSTOM_VULKAN_CMAKE_LOCATION}/SPIRV-Headers/SPIRV-HeadersConfig.cmake" + if [[ -f "${MACOS_RUNNER_CUSTOM_SPIRV_HEADERS_LOCATION}" || -h "${MACOS_RUNNER_CUSTOM_SPIRV_HEADERS_LOCATION}" ]]; then + CMAKE_EXTRA="${CMAKE_EXTRA} -DSPIRV-Headers_DIR=${MACOS_RUNNER_CUSTOM_VULKAN_CMAKE_LOCATION}/SPIRV-Headers" + fi fi # Build shared libs on Windows diff --git a/cmake/llama-config.cmake.in b/cmake/llama-config.cmake.in index 90cbec5b6f1..b4defc76ff0 100644 --- a/cmake/llama-config.cmake.in +++ b/cmake/llama-config.cmake.in @@ -7,7 +7,7 @@ set(LLAMA_SHARED_LIB @BUILD_SHARED_LIBS@) set_and_check(LLAMA_INCLUDE_DIR "@PACKAGE_LLAMA_INCLUDE_INSTALL_DIR@") set_and_check(LLAMA_LIB_DIR "@PACKAGE_LLAMA_LIB_INSTALL_DIR@") -set_and_check(LLAMA_BIN_DIR "@PACKAGE_LLAMA_BIN_INSTALL_DIR@") +set(LLAMA_BIN_DIR "@PACKAGE_LLAMA_BIN_INSTALL_DIR@") find_package(ggml REQUIRED HINTS ${LLAMA_LIB_DIR}/cmake) diff --git a/common/chat-auto-parser-generator.cpp b/common/chat-auto-parser-generator.cpp index 6021fc4ede5..db3a6cc6fe3 100644 --- a/common/chat-auto-parser-generator.cpp +++ b/common/chat-auto-parser-generator.cpp @@ -43,11 +43,33 @@ common_chat_params peg_generator::generate_parser(const common_chat_template & const autoparser & autoparser) { // Create the result structure common_chat_params data; - data.prompt = common_chat_template_direct_apply(tmpl, inputs); - data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; - data.preserved_tokens = autoparser.preserved_tokens; + data.prompt = common_chat_template_direct_apply(tmpl, inputs); + data.generation_prompt = common_chat_template_generation_prompt(tmpl, inputs); + 
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; + data.preserved_tokens = autoparser.preserved_tokens; + + std::string parser_generation_prompt = data.generation_prompt; + + if (inputs.continue_final_message != COMMON_CHAT_CONTINUATION_NONE && !inputs.continue_msg.empty()) { + // Build up generation prompt manually + const auto & msg = inputs.continue_msg; + + if (!autoparser.reasoning.start.empty()) { + data.generation_prompt = data.generation_prompt.substr(0, data.generation_prompt.find(autoparser.reasoning.start)); + data.generation_prompt += autoparser.reasoning.start + msg.reasoning_content; + if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) { + data.generation_prompt += autoparser.reasoning.end; + } + } + + if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) { + data.generation_prompt += msg.render_content(); + } + + data.prompt += data.generation_prompt; + } - auto parser = autoparser.build_parser(inputs); + auto parser = autoparser.build_parser(inputs, parser_generation_prompt); data.parser = parser.save(); // Build grammar if tools are present @@ -87,7 +109,7 @@ common_chat_params peg_generator::generate_parser(const common_chat_template & return data; } -common_peg_arena autoparser::build_parser(const generation_params & inputs) const { +common_peg_arena autoparser::build_parser(const generation_params & inputs, const std::string & generation_prompt) const { if (!analysis_complete) { throw std::invalid_argument("Cannot call build_parser on autoparser without performing analysis first, call analyze_template(...)"); } @@ -121,7 +143,7 @@ common_peg_arena autoparser::build_parser(const generation_params & inputs) cons } else { parser = content.build_parser(ctx); } - return pure_content ? p.prefix(inputs.generation_prompt, reasoning.start) + parser : p.prefix(inputs.generation_prompt, reasoning.start) << parser; + return pure_content ? 
p.prefix(generation_prompt, reasoning.start) + parser : p.prefix(generation_prompt, reasoning.start) << parser; }); } diff --git a/common/chat-auto-parser.h b/common/chat-auto-parser.h index 6c547409760..c680e686867 100644 --- a/common/chat-auto-parser.h +++ b/common/chat-auto-parser.h @@ -60,16 +60,21 @@ struct generation_params { common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_AUTO; bool stream = true; std::string grammar; - bool add_generation_prompt = false; - bool enable_thinking = true; - std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); - std::string generation_prompt; + bool add_generation_prompt = false; + common_chat_continuation continue_final_message = COMMON_CHAT_CONTINUATION_NONE; + common_chat_msg continue_msg; + bool enable_thinking = true; + std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); json extra_context; bool add_bos = false; bool add_eos = false; bool is_inference = true; bool add_inference = false; bool mark_input = true; // whether to mark input strings in the jinja context + + bool has_continuation() const { + return continue_final_message != COMMON_CHAT_CONTINUATION_NONE && !continue_msg.empty(); + } }; // ============================================================================ @@ -386,7 +391,7 @@ struct autoparser { void analyze_template(const common_chat_template & tmpl); // Build the PEG parser for this template - common_peg_arena build_parser(const generation_params & inputs) const; + common_peg_arena build_parser(const generation_params & inputs, const std::string & generation_prompt) const; private: // Collect tokens from entire analysis to preserve diff --git a/common/chat-peg-parser.cpp b/common/chat-peg-parser.cpp index 79274febe7e..12e747d1ca1 100644 --- a/common/chat-peg-parser.cpp +++ b/common/chat-peg-parser.cpp @@ -785,7 +785,7 @@ common_peg_parser common_chat_peg_builder::prefix(const std::string & s, const s if (delimiter.empty()) { return 
literal(s); } - return literal(s.substr(0, s.rfind(delimiter))); + return literal(s.substr(0, s.find(delimiter))); } common_peg_parser common_chat_peg_builder::optspace(const std::string & tag) { diff --git a/common/chat.cpp b/common/chat.cpp index 70b9f5dc2c5..56873e3a1e9 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -70,6 +70,26 @@ static bool has_content_or_tool_calls(const common_chat_msg & msg) { return !msg.content.empty() || !msg.tool_calls.empty(); } +std::string common_chat_msg::render_content(const std::string & delimiter) const { + if (!content.empty() && !content_parts.empty()) { + throw std::runtime_error("Cannot specify both content and content_parts"); + } + if (!content.empty()) { + return content; + } + + std::string text; + for (const auto & part : content_parts) { + if (part.type == "text") { + if (!text.empty()) { + text += delimiter; + } + text += part.text; + } + } + return text; +} + json common_chat_msg::to_json_oaicompat(bool concat_typed_text) const { if (!content.empty() && !content_parts.empty()) { throw std::runtime_error("Cannot specify both content and content_parts"); @@ -451,6 +471,22 @@ std::vector common_chat_tools_parse_oaicompat(const json & too return result; } +common_chat_continuation common_chat_continuation_parse(const nlohmann::ordered_json & value) { + if (value.is_boolean() && value.get()) { + return COMMON_CHAT_CONTINUATION_AUTO; + } + if (value.is_string()) { + auto value_str = value.get(); + if (value_str == "reasoning_content") { + return COMMON_CHAT_CONTINUATION_REASONING; + } + if (value_str == "content") { + return COMMON_CHAT_CONTINUATION_CONTENT; + } + } + return COMMON_CHAT_CONTINUATION_NONE; +} + bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) { if (use_jinja) { try { @@ -811,6 +847,36 @@ std::string common_chat_template_direct_apply( return common_chat_template_direct_apply_impl(tmpl, inputs, std::nullopt, std::nullopt, std::nullopt); } +static std::string 
common_chat_template_generation_prompt_impl( + const common_chat_template & tmpl, + const autoparser::generation_params & inputs, + const std::optional & messages_override = std::nullopt, + const std::optional & tools_override = std::nullopt, + const std::optional & additional_context = std::nullopt) { + + auto adjusted_messages = messages_override ? *messages_override : inputs.messages; + + autoparser::generation_params params = inputs; + params.add_generation_prompt = false; + params.continue_final_message = COMMON_CHAT_CONTINUATION_NONE; + std::string no_gen_prompt = common_chat_template_direct_apply_impl(tmpl, params, adjusted_messages, tools_override, additional_context); + params.add_generation_prompt = true; + std::string gen_prompt = common_chat_template_direct_apply_impl(tmpl, params, adjusted_messages, tools_override, additional_context); + + size_t prefix_len = 0; + size_t min_size = std::min(no_gen_prompt.size(), gen_prompt.size()); + while (prefix_len < min_size && no_gen_prompt[prefix_len] == gen_prompt[prefix_len]) { + prefix_len++; + } + return gen_prompt.substr(prefix_len); +} + +std::string common_chat_template_generation_prompt( + const common_chat_template & tmpl, + const autoparser::generation_params & inputs) { + return common_chat_template_generation_prompt_impl(tmpl, inputs, std::nullopt, std::nullopt, std::nullopt); +} + static common_chat_params common_chat_params_init_ministral_3(const common_chat_template & tmpl, const autoparser::generation_params & inputs) { common_chat_params data; @@ -863,6 +929,7 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_ data.thinking_start_tag = "[THINK]"; data.thinking_end_tag = "[/THINK]"; data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs, /* messages_override = */ adjusted_messages); + data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs, /* messages_override = */ adjusted_messages); data.format = 
COMMON_CHAT_FORMAT_PEG_NATIVE; data.preserved_tokens = { "[THINK]", @@ -871,8 +938,19 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_ "[ARGS]", }; + if (inputs.has_continuation()) { + const auto & msg = inputs.continue_msg; + + data.generation_prompt = "[THINK]" + msg.reasoning_content; + if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) { + data.generation_prompt += "[/THINK]" + msg.render_content(); + } + + data.prompt += data.generation_prompt; + } + auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { - auto generation_prompt = p.prefix(inputs.generation_prompt, "[THINK]"); + auto generation_prompt = p.eps(); auto reasoning = extract_reasoning ? p.optional("[THINK]" + p.reasoning(p.until("[/THINK]")) + "[/THINK]") : p.eps(); @@ -963,6 +1041,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp } data.prompt = prompt; + data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs, /* messages_override= */ adjusted_messages); data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; data.supports_thinking = true; @@ -972,6 +1051,18 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp "<|channel|>", "<|constrain|>", "<|message|>", "<|start|>", "<|end|>", }; + // Adjust prompt for continuation + if (inputs.has_continuation()) { + const auto & msg = inputs.continue_msg; + + data.generation_prompt = "<|start|>assistant<|channel|>analysis<|message|>" + msg.reasoning_content; + if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) { + data.generation_prompt += "<|end|><|start|>assistant<|channel|>final<|message|>" + msg.render_content(); + } + + data.prompt += data.generation_prompt; + } + auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); auto has_response_format = !inputs.json_schema.is_null() && inputs.json_schema.is_object(); auto include_grammar = has_response_format || (has_tools && 
inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE); @@ -1080,12 +1171,14 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ common_chat_params data; data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs); + data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs); if (inputs.add_generation_prompt && string_ends_with(data.prompt, "\n")) { // This may happen if the model generates content + tool_call, the // template does not add the model's next turn and confuses the model // from emitting its proper reasoning token sequence. - data.prompt += "<|turn>model\n"; + data.generation_prompt = "<|turn>model\n"; + data.prompt += data.generation_prompt; } data.format = COMMON_CHAT_FORMAT_PEG_GEMMA4; @@ -1101,13 +1194,25 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ "<|turn>", }; + if (inputs.has_continuation()) { + const auto & msg = inputs.continue_msg; + + data.generation_prompt = string_ends_with(data.prompt, "\n") ? 
"<|turn>model\n" : ""; + data.generation_prompt += "<|channel>thought\n" + msg.reasoning_content; + if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) { + data.generation_prompt += "" + msg.render_content(); + } + + data.prompt += data.generation_prompt; + } + auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); auto has_response_format = !inputs.json_schema.is_null() && inputs.json_schema.is_object(); auto include_grammar = has_response_format || (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE); auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE; auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { - auto start = p.rule("start", p.prefix(inputs.generation_prompt, "<|channel>")); + auto start = p.rule("start", p.optional(p.literal("<|turn>model\n"))); if (extract_reasoning) { p.rule("thought", p.literal("<|channel>thought") + p.space() + p.reasoning(p.until("")) + p.literal("")); @@ -1224,15 +1329,22 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_ const autoparser::generation_params & inputs) { common_chat_params data; - data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs); - data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; - data.preserved_tokens = { + data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs); + data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs); + data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; + data.preserved_tokens = { ">>>all", }; auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE; + if (inputs.has_continuation()) { + const auto & msg = inputs.continue_msg; + data.generation_prompt = "<|start_header_id|>assistant<|end_header_id|>\n\n>>>all\n" + msg.render_content(); + data.prompt += data.generation_prompt; + } + auto parser = 
build_chat_peg_parser([&](common_chat_peg_builder & p) { // Functionary v3.2 format: // - Normal content: >>>all\n{content} @@ -1244,7 +1356,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_ // When no tools, content goes until end auto content_until_tool = p.literal("all\n") + p.content(p.until(">>>")); auto content_until_end = p.literal("all\n") + p.content(p.rest()); - auto generation_prompt = p.literal(inputs.generation_prompt); + auto generation_prompt = p.literal("<|start_header_id|>assistant<|end_header_id|>\n\n>>>"); // If no tools or tool_choice is NONE, just parse content if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) { @@ -1318,9 +1430,10 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp const autoparser::generation_params & inputs) { common_chat_params data; - data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs); - data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; - data.supports_thinking = true; + data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs); + data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs); + data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; + data.supports_thinking = true; data.preserved_tokens = { "<|tool_calls_section_begin|>", "<|tool_calls_section_end|>", @@ -1343,10 +1456,22 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp const std::string THINK_START = ""; const std::string THINK_END = ""; + const std::string GEN_PROMPT = "<|im_assistant|>assistant<|im_middle|>"; data.thinking_start_tag = THINK_START; data.thinking_end_tag = THINK_END; + if (inputs.has_continuation()) { + const auto & msg = inputs.continue_msg; + + data.generation_prompt = GEN_PROMPT + THINK_START + msg.reasoning_content; + if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) { + data.generation_prompt += THINK_END + msg.render_content(); + } + + data.prompt += 
data.generation_prompt; + } + auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { // Kimi K2 Thinking format: // - Reasoning: {reasoning} @@ -1366,7 +1491,7 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp auto reasoning = extract_reasoning ? p.optional(THINK_START + p.reasoning( p.until_one_of({ THINK_END, "<|tool_calls_section_begin|>", "<|tool_call_begin|>" })) + p.optional(p.literal(THINK_END))) : p.eps(); - auto generation_prompt = p.prefix(inputs.generation_prompt, THINK_START); + auto generation_prompt = p.literal(GEN_PROMPT); // Content only parser (no tools) @@ -1442,6 +1567,7 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat common_chat_params data; data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs); + data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs); data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; data.supports_thinking = true; data.preserved_tokens = { @@ -1461,12 +1587,24 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat const std::string TOOL_CALL_END = "<|tool_call_end|>"; const std::string THINK_START = ""; const std::string THINK_END = ""; + const std::string GEN_PROMPT = "<|im_start|>assistant\n"; data.thinking_start_tag = THINK_START; data.thinking_end_tag = THINK_END; + if (inputs.has_continuation()) { + const auto & msg = inputs.continue_msg; + + data.generation_prompt = GEN_PROMPT + THINK_START + msg.reasoning_content; + if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) { + data.generation_prompt += THINK_END + msg.render_content(); + } + + data.prompt += data.generation_prompt; + } + auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { - auto generation_prompt = p.prefix(inputs.generation_prompt, THINK_START); + auto generation_prompt = p.literal(GEN_PROMPT); auto end = p.end(); auto reasoning = p.eps(); @@ -1521,6 +1659,7 @@ static 
common_chat_params common_chat_params_init_lfm2_5(const common_chat_templ common_chat_params data; data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs); + data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs); data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; data.supports_thinking = true; data.preserved_tokens = { @@ -1536,12 +1675,24 @@ static common_chat_params common_chat_params_init_lfm2_5(const common_chat_templ const std::string THINK_START = ""; const std::string THINK_END = ""; + const std::string GEN_PROMPT = "<|im_start|>assistant\n"; data.thinking_start_tag = THINK_START; data.thinking_end_tag = THINK_END; + if (inputs.has_continuation()) { + const auto & msg = inputs.continue_msg; + + data.generation_prompt = GEN_PROMPT + THINK_START + msg.reasoning_content; + if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) { + data.generation_prompt += THINK_END + msg.render_content(); + } + + data.prompt += data.generation_prompt; + } + auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { - auto generation_prompt = p.prefix(inputs.generation_prompt, THINK_START); + auto generation_prompt = p.literal(GEN_PROMPT); auto end = p.end(); auto reasoning = p.eps(); @@ -1592,6 +1743,7 @@ static common_chat_params common_chat_params_init_gigachat_v3( common_chat_params data; data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs); + data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs); data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; data.supports_thinking = false; data.preserved_tokens = { @@ -1599,6 +1751,12 @@ static common_chat_params common_chat_params_init_gigachat_v3( "<|role_sep|>\n", }; + if (inputs.has_continuation()) { + const auto & msg = inputs.continue_msg; + data.generation_prompt = "assistant<|role_sep|>\n" + msg.render_content(); + data.prompt += data.generation_prompt; + } + auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); auto 
include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE; const auto *tool_call_start_prefix = "<|message_sep|>\n\nfunction call<|role_sep|>\n"; @@ -1634,7 +1792,7 @@ static common_chat_params common_chat_params_init_gigachat_v3( ret = p.content(p.rest()); } - return p.literal(inputs.generation_prompt) + ret; + return p.literal("assistant<|role_sep|>\n") + ret; }); data.parser = parser.save(); @@ -1662,12 +1820,13 @@ static common_chat_params common_chat_params_init_deepseek_v3_2(const common_cha const autoparser::generation_params & inputs) { common_chat_params data; - data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs); - data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; - data.supports_thinking = true; + data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs); + data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs); + data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; + data.supports_thinking = true; data.thinking_start_tag = ""; data.thinking_end_tag = ""; - data.preserved_tokens = { + data.preserved_tokens = { "|DSML|", "", "", @@ -1687,9 +1846,21 @@ static common_chat_params common_chat_params_init_deepseek_v3_2(const common_cha const std::string INVOKE_END = ""; const std::string PARAM_START = "<" + DSML + "parameter"; const std::string PARAM_END = ""; + const std::string GEN_PROMPT = "<|Assistant|>"; + + if (inputs.has_continuation()) { + const auto & msg = inputs.continue_msg; + + data.generation_prompt = GEN_PROMPT + THINK_START + msg.reasoning_content; + if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) { + data.generation_prompt += THINK_END + msg.render_content(); + } + + data.prompt += data.generation_prompt; + } auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { - auto generation_prompt = p.prefix(inputs.generation_prompt, THINK_START); + auto generation_prompt = p.literal(GEN_PROMPT); auto end = p.end(); auto reasoning = p.eps(); @@ -2116,21 
+2287,6 @@ std::optional common_chat_try_specialized_template( return std::nullopt; } -static std::string common_chat_templates_generation_prompt(const common_chat_template & tmpl, const autoparser::generation_params & inputs) { - autoparser::generation_params params = inputs; - params.add_generation_prompt = false; - std::string no_gen_prompt = common_chat_template_direct_apply_impl(tmpl, params); - params.add_generation_prompt = true; - std::string gen_prompt = common_chat_template_direct_apply_impl(tmpl, params); - - size_t prefix_len = 0; - size_t min_size = std::min(no_gen_prompt.size(), gen_prompt.size()); - while (prefix_len < min_size && no_gen_prompt[prefix_len] == gen_prompt[prefix_len]) { - prefix_len++; - } - return gen_prompt.substr(prefix_len); -} - static common_chat_params common_chat_templates_apply_jinja(const struct common_chat_templates * tmpls, const struct common_chat_templates_inputs & inputs) { autoparser::generation_params params; @@ -2149,6 +2305,27 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_ params.add_bos = tmpls->add_bos; params.add_eos = tmpls->add_eos; + params.continue_final_message = inputs.continue_final_message; + if (params.continue_final_message != COMMON_CHAT_CONTINUATION_NONE) { + params.add_generation_prompt = false; + + if (!inputs.messages.empty()) { + // Render messages[:-1] and store continuation message separately + params.continue_msg = inputs.messages.back(); + params.messages.erase(params.messages.size() - 1); + } + + if (params.continue_final_message == COMMON_CHAT_CONTINUATION_AUTO && !inputs.messages.empty()) { + // Resolve based on message content + params.continue_final_message = COMMON_CHAT_CONTINUATION_CONTENT; + if (!params.continue_msg.reasoning_content.empty() && + params.continue_msg.content.empty() && + params.continue_msg.content_parts.empty()) { + params.continue_final_message = COMMON_CHAT_CONTINUATION_REASONING; + } + } + } + if (src.find("<|channel|>") == 
std::string::npos) { // map developer to system for all models except for GPT-OSS workaround::map_developer_role_to_system(params.messages); @@ -2169,8 +2346,6 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_ workaround::func_args_not_string(params.messages); } - params.generation_prompt = common_chat_templates_generation_prompt(tmpl, params); - params.extra_context = common_chat_extra_context(); for (auto el : inputs.chat_template_kwargs) { params.extra_context[el.first] = json::parse(el.second); @@ -2200,17 +2375,16 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_ auto params_copy = params; params_copy.reasoning_format = COMMON_REASONING_FORMAT_NONE; data.prompt = common_chat_template_direct_apply_impl(tmpl, params_copy); + data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, params); data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; - data.generation_prompt = params.generation_prompt; - auto parser = build_chat_peg_parser([¶ms](common_chat_peg_builder &p) { - return p.prefix(params.generation_prompt) << p.content(p.rest()); + auto parser = build_chat_peg_parser([&data](common_chat_peg_builder &p) { + return p.literal(data.generation_prompt) << p.content(p.rest()); }); data.parser = parser.save(); return data; } if (auto result = common_chat_try_specialized_template(tmpl, src, params)) { - result->generation_prompt = params.generation_prompt; return *result; } @@ -2224,7 +2398,6 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_ auto_params.thinking_start_tag = trim_whitespace(autoparser.reasoning.start); auto_params.thinking_end_tag = trim_whitespace(autoparser.reasoning.end); } - auto_params.generation_prompt = params.generation_prompt; common_peg_arena arena; arena.load(auto_params.parser); LOG_DBG("%s: generated parser:\n%s\n\nparser generation prompt: %s\n", __func__, arena.dump(arena.root()).c_str(), auto_params.generation_prompt.c_str()); 
diff --git a/common/chat.h b/common/chat.h index 054f5ffe777..8ace3e6ba69 100644 --- a/common/chat.h +++ b/common/chat.h @@ -89,6 +89,8 @@ struct common_chat_msg { nlohmann::ordered_json to_json_oaicompat(bool concat_typed_text = false) const; + std::string render_content(const std::string & delimiter = "\n\n") const; + bool empty() const { return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty(); @@ -164,12 +166,22 @@ enum common_chat_format { COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats }; + +// Continuation method provided via `continue_final_message` +enum common_chat_continuation { + COMMON_CHAT_CONTINUATION_NONE, + COMMON_CHAT_CONTINUATION_AUTO, + COMMON_CHAT_CONTINUATION_REASONING, + COMMON_CHAT_CONTINUATION_CONTENT, +}; + struct common_chat_templates_inputs { std::vector messages; std::string grammar; std::string json_schema; - bool add_generation_prompt = true; - bool use_jinja = true; + bool add_generation_prompt = true; + common_chat_continuation continue_final_message = COMMON_CHAT_CONTINUATION_NONE; + bool use_jinja = true; // Parameters below only supported when use_jinja is true std::vector tools; common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO; @@ -207,6 +219,7 @@ struct common_chat_parser_params { bool reasoning_in_content = false; std::string generation_prompt; bool parse_tool_calls = true; + bool echo = false; // Include assistant prefilled msg in output bool debug = false; // Enable debug output for PEG parser common_peg_arena parser = {}; common_chat_parser_params() = default; @@ -267,6 +280,8 @@ std::vector common_chat_msgs_parse_oaicompat(const nlohmann::or std::vector common_chat_tools_parse_oaicompat(const nlohmann::ordered_json & tools); +common_chat_continuation common_chat_continuation_parse(const nlohmann::ordered_json & value); + // DEPRECATED: only used in tests nlohmann::ordered_json 
common_chat_msgs_to_json_oaicompat(const std::vector & msgs, bool concat_typed_text = false); @@ -279,6 +294,10 @@ std::string common_chat_template_direct_apply( const common_chat_template & tmpl, const autoparser::generation_params & inputs); +std::string common_chat_template_generation_prompt( + const common_chat_template & tmpl, + const autoparser::generation_params & inputs); + std::optional common_chat_try_specialized_template( const common_chat_template & tmpl, const std::string & src, diff --git a/common/common.h b/common/common.h index 514bab11942..1d3d788b2de 100644 --- a/common/common.h +++ b/common/common.h @@ -617,8 +617,6 @@ struct common_params { // UI configs #ifdef LLAMA_UI_DEFAULT_ENABLED bool ui = LLAMA_UI_DEFAULT_ENABLED != 0; -#elif defined(LLAMA_WEBUI_DEFAULT_ENABLED) - bool ui = LLAMA_WEBUI_DEFAULT_ENABLED != 0; #else bool ui = true; // default to enabled when not set #endif diff --git a/common/speculative.cpp b/common/speculative.cpp index 3488b9393c5..e591bab875d 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -146,8 +146,11 @@ struct common_speculative_impl { virtual void accept(llama_seq_id seq_id, uint16_t n_accepted) = 0; - // true if this implementation requires the target context to extract embeddings + // true if this implementation requires the target context to extract post-norm embeddings virtual bool need_embd() const = 0; + + // true if this implementation requires the target context to extract pre-norm embeddings + virtual bool need_embd_pre_norm() const { return false; } }; struct common_speculative_impl_draft_simple : public common_speculative_impl { @@ -429,8 +432,8 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl { s.reset(common_sampler_init(llama_get_model(ctx_dft), sparams)); } - llama_set_embeddings_pre_norm(ctx_tgt, true); - llama_set_embeddings_pre_norm(ctx_dft, true); + llama_set_embeddings_pre_norm(ctx_tgt, true, /*masked*/ false); + 
llama_set_embeddings_pre_norm(ctx_dft, true, /*masked*/ true); pending_h.assign(n_seq, std::vector(n_embd, 0.0f)); @@ -691,6 +694,10 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl { } bool need_embd() const override { + return false; + } + + bool need_embd_pre_norm() const override { return true; } }; @@ -1408,6 +1415,20 @@ bool common_speculative_need_embd(common_speculative * spec) { return false; } +bool common_speculative_need_embd_pre_norm(common_speculative * spec) { + if (spec == nullptr) { + return false; + } + + for (auto & impl : spec->impls) { + if (impl->need_embd_pre_norm()) { + return true; + } + } + + return false; +} + void common_speculative_draft(common_speculative * spec) { if (spec == nullptr) { return; diff --git a/common/speculative.h b/common/speculative.h index 614db9b1b50..f24bac79edb 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -53,9 +53,12 @@ void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, co // process the batch and update the internal state of the speculative context bool common_speculative_process(common_speculative * spec, const llama_batch & batch); -// true if any implementation requires target embeddings to be extracted +// true if any implementation requires target post-norm embeddings to be extracted bool common_speculative_need_embd(common_speculative * spec); +// true if any implementation requires target pre-norm embeddings to be extracted +bool common_speculative_need_embd_pre_norm(common_speculative * spec); + // generate drafts for the sequences specified with `common_speculative_get_draft_params` void common_speculative_draft(common_speculative * spec); diff --git a/ggml/src/ggml-cuda/ssm-conv.cu b/ggml/src/ggml-cuda/ssm-conv.cu index 4841389fbc8..4c4daf85dc6 100644 --- a/ggml/src/ggml-cuda/ssm-conv.cu +++ b/ggml/src/ggml-cuda/ssm-conv.cu @@ -140,11 +140,12 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const floa }; 
switch (nc) { - case 3: launch_kernel(std::integral_constant{}); break; - case 4: launch_kernel(std::integral_constant{}); break; - case 5: launch_kernel(std::integral_constant{}); break; - case 9: launch_kernel(std::integral_constant{}); break; - default: GGML_ABORT("Only support kernel sizes 3, 4, 5, 9 right now."); + case 3: launch_kernel(std::integral_constant{}); break; + case 4: launch_kernel(std::integral_constant{}); break; + case 5: launch_kernel(std::integral_constant{}); break; + case 9: launch_kernel(std::integral_constant{}); break; + case 15: launch_kernel(std::integral_constant{}); break; + default: GGML_ABORT("Only support kernel sizes 3, 4, 5, 9, 15 right now."); } } diff --git a/ggml/src/ggml-cuda/top-k.cu b/ggml/src/ggml-cuda/top-k.cu index 59ce36fb1c9..db1d39e2dc7 100644 --- a/ggml/src/ggml-cuda/top-k.cu +++ b/ggml/src/ggml-cuda/top-k.cu @@ -5,6 +5,7 @@ # include # if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2) # define CUB_TOP_K_AVAILABLE +# include using namespace cub; # endif // CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2 #endif // GGML_CUDA_USE_CUB diff --git a/ggml/src/ggml-vulkan/CMakeLists.txt b/ggml/src/ggml-vulkan/CMakeLists.txt index 715a263a6d0..6dbcea065b3 100644 --- a/ggml/src/ggml-vulkan/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/CMakeLists.txt @@ -8,6 +8,8 @@ endif() find_package(Vulkan COMPONENTS glslc REQUIRED) +find_package(SPIRV-Headers REQUIRED) + if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") # Parallel build object files add_definitions(/MP) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index d62abc4009b..b1b12d017c0 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -895,8 +895,17 @@ float * llama_context::get_embeddings_pre_norm_ith(int32_t i) { throw std::runtime_error("no pre-norm embeddings"); } - const int64_t j = output_resolve_row(i); const uint32_t n_embd = model.hparams.n_embd; + + if (!cparams.embeddings_pre_norm_masked) { + // unmasked: pre-norm rows are stored densely, 
indexed by raw token position. + if (i < 0 || (size_t)(i + 1) * n_embd > embd_pre_norm.size) { + throw std::runtime_error(format("out of range [0, %zu)", embd_pre_norm.size / n_embd)); + } + return embd_pre_norm.data + (size_t) i * n_embd; + } + + const int64_t j = output_resolve_row(i); return embd_pre_norm.data + j*n_embd; } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: invalid pre-norm embeddings id %d, reason: %s\n", __func__, i, err.what()); @@ -1088,10 +1097,11 @@ void llama_context::set_embeddings(bool value) { //sched_need_reserve = true; } -void llama_context::set_embeddings_pre_norm(bool value) { - LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); +void llama_context::set_embeddings_pre_norm(bool value, bool masked) { + LLAMA_LOG_DEBUG("%s: value = %d, masked = %d\n", __func__, value, masked); - cparams.embeddings_pre_norm = value; + cparams.embeddings_pre_norm = value; + cparams.embeddings_pre_norm_masked = masked; } void llama_context::set_causal_attn(bool value) { @@ -1737,6 +1747,7 @@ int llama_context::decode(const llama_batch & batch_inp) { }; int64_t n_outputs_prev = 0; + int64_t n_tokens_prev = 0; do { const auto & ubatch = mctx->get_ubatch(); @@ -1882,16 +1893,21 @@ int llama_context::decode(const llama_batch & batch_inp) { // extract pre-norm embeddings (hidden state before the final output norm) // only meaningful in LLAMA_POOLING_TYPE_NONE (per-token); other pooling modes are ignored. - if (embd_pre_norm.data && t_h_pre_norm && n_outputs > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { - ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm); - GGML_ASSERT(backend_h != nullptr); + { + const bool masked = cparams.embeddings_pre_norm_masked; + const int64_t n_rows = masked ? n_outputs : (int64_t) ubatch.n_tokens; + const int64_t offset = masked ? 
n_outputs_prev : n_tokens_prev; + + if (embd_pre_norm.data && t_h_pre_norm && n_rows > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { + ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm); + GGML_ASSERT(backend_h != nullptr); - const uint32_t n_embd = hparams.n_embd; - float * embd_pre_norm_out = embd_pre_norm.data + n_outputs_prev*n_embd; + const uint32_t n_embd = hparams.n_embd; + float * embd_pre_norm_out = embd_pre_norm.data + offset*n_embd; - GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); - GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_pre_norm.size); - ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm_out, 0, n_outputs*n_embd*sizeof(float)); + GGML_ASSERT((offset + n_rows)*n_embd <= (int64_t) embd_pre_norm.size); + ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm_out, 0, n_rows*n_embd*sizeof(float)); + } } // Copy backend sampling output if this ubatch produced any sampling tensors. @@ -1908,6 +1924,7 @@ int llama_context::decode(const llama_batch & batch_inp) { } n_outputs_prev += n_outputs; + n_tokens_prev += ubatch.n_tokens; } while (mctx->next()); // set to total number of outputs in the batch, for use in llama_get_logits_ith @@ -1999,6 +2016,12 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { embd.size = has_embd ? n_embd_out*n_outputs_max : 0; embd_pre_norm.size = has_embd_pre_norm ? n_embd*n_outputs_max : 0; + if (has_embd_pre_norm && !cparams.embeddings_pre_norm_masked) { + // unmasked: pre-norm row exists for every token in the batch, not just + // those flagged via batch.logits[i] -> size by token count instead. + embd_pre_norm.size = (size_t) n_embd * n_batch; + } + // Allocate backend sampling output buffers if there are backend samplers configured. 
const bool has_sampling = !sampling.samplers.empty(); if (has_sampling) { @@ -3547,8 +3570,8 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) { return ctx->get_embeddings_seq(seq_id); } -void llama_set_embeddings_pre_norm(llama_context * ctx, bool value) { - ctx->set_embeddings_pre_norm(value); +void llama_set_embeddings_pre_norm(llama_context * ctx, bool value, bool masked) { + ctx->set_embeddings_pre_norm(value, masked); } float * llama_get_embeddings_pre_norm(llama_context * ctx) { diff --git a/src/llama-context.h b/src/llama-context.h index e16ac4c618b..d03f681d4a1 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -110,7 +110,7 @@ struct llama_context { void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data); void set_embeddings (bool value); - void set_embeddings_pre_norm(bool value); + void set_embeddings_pre_norm(bool value, bool masked); void set_causal_attn(bool value); void set_warmup(bool value); diff --git a/src/llama-cparams.h b/src/llama-cparams.h index 5898a1c38d5..20ec59fe335 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -28,7 +28,8 @@ struct llama_cparams { float yarn_beta_slow; bool embeddings; - bool embeddings_pre_norm; // also extract the hidden state before the final output norm + bool embeddings_pre_norm; // also extract the hidden state before the final output norm + bool embeddings_pre_norm_masked; // extract for only rows where batch.logits != 0 bool causal_attn; bool offload_kqv; bool flash_attn; diff --git a/src/llama-ext.h b/src/llama-ext.h index 11f1986676a..edfa71c207c 100644 --- a/src/llama-ext.h +++ b/src/llama-ext.h @@ -93,14 +93,14 @@ LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_c // pre-norm embeddings (hidden state before the final output norm) // -// mirrors: -// LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings); -LLAMA_API void llama_set_embeddings_pre_norm(struct 
llama_context * ctx, bool value); +// Set whether the context outputs pre-norm embeddings or not +// If masked == true, output the embeddings only for the tokens with batch.logits != 0 +// If masked == false, output the embeddings for all tokens in the batch regardless of batch.logits +LLAMA_API void llama_set_embeddings_pre_norm(struct llama_context * ctx, bool value, bool masked); // mirrors: // LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); -LLAMA_API float * llama_get_embeddings_pre_norm(struct llama_context * ctx); +LLAMA_API float * llama_get_embeddings_pre_norm (struct llama_context * ctx); -// mirrors: // LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i); LLAMA_API float * llama_get_embeddings_pre_norm_ith(struct llama_context * ctx, int32_t i); diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 858c297dd76..31cf41a1c2d 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -848,6 +848,9 @@ void llm_graph_result::set_outputs() { if (t_embd_pooled != nullptr) { ggml_set_output(t_embd_pooled); } + if (t_h_pre_norm != nullptr) { + ggml_set_output(t_h_pre_norm); + } for (auto & [seq_id, t] : t_sampled) { if (t != nullptr) { ggml_set_output(t); diff --git a/src/models/qwen35.cpp b/src/models/qwen35.cpp index 2b4d5b14cd4..361d7538a03 100644 --- a/src/models/qwen35.cpp +++ b/src/models/qwen35.cpp @@ -176,7 +176,7 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il); } - if (il == n_transformer_layers - 1 && inp_out_ids) { + if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_pre_norm_masked) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -211,6 +211,10 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para cb(cur, "h_pre_norm", -1); res->t_h_pre_norm = cur; + if 
(!cparams.embeddings_pre_norm_masked && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + } + // Final norm cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); diff --git a/src/models/qwen35moe.cpp b/src/models/qwen35moe.cpp index 22e3e110765..4f63c410d66 100644 --- a/src/models/qwen35moe.cpp +++ b/src/models/qwen35moe.cpp @@ -199,7 +199,7 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il); } - if (il == n_transformer_layers - 1 && inp_out_ids) { + if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_pre_norm_masked) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -234,6 +234,10 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p cb(cur, "h_pre_norm", -1); res->t_h_pre_norm = cur; + if (!cparams.embeddings_pre_norm_masked && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + } + // Final norm cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 05c60d29743..a428ef35c18 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -1,4 +1,4 @@ -// Tests chat handling, including grammar generation and parsing for tool calling, for various templates. +// Tests chat handling, including grammar generation and parsing for tool calling, for various templates. // // Also acts as a CLI to generate a Markdown summary of the formats of Jinja templates, // e.g. 
given Minja (http://github.com/google/minja) checked out in parent dir: @@ -100,6 +100,25 @@ template static void assert_equals(const T & expected, const T & actua } } +static void assert_contains(const std::string & haystack, const std::string & needle) { + if (haystack.find(needle) == std::string::npos) { + LOG_ERR("Expected to contain: %s\n", needle.c_str()); + LOG_ERR("Actual: %s\n", haystack.c_str()); + common_log_flush(common_log_main()); + throw std::runtime_error("Test failed"); + } +} + +static void assert_ends_with(const std::string & str, const std::string & suffix) { + if (str.size() < suffix.size() || + str.compare(str.size() - suffix.size(), suffix.size(), suffix) != 0) { + LOG_ERR("Expected to end with: %s\n", suffix.c_str()); + LOG_ERR("Actual: %s\n", str.c_str()); + common_log_flush(common_log_main()); + throw std::runtime_error("Test failed"); + } +} + static std::string read_file(const std::string & path) { std::ifstream fs(path, std::ios_base::binary); if (!fs.is_open()) { @@ -945,6 +964,8 @@ const common_chat_msg message_assist_call_python_lines_unclosed = simple_assist_msg("", "", "python", "{\"code\":\"# This is a program:\\nprint('hey')"); const common_chat_msg message_assist_json_content = simple_assist_msg("{\n \"response\": \"Hello, world!\\nWhat's up?\"\n}"); +const common_chat_msg message_assist_prefill_content = simple_assist_msg("Hello, ", "I'm thinking"); +const common_chat_msg message_assist_prefill_reasoning = simple_assist_msg("", "I'm"); // Use for PEG parser implementations struct peg_test_case { @@ -1350,7 +1371,10 @@ class peg_test_builder { peg_test_case tc_; public: - peg_test_builder(peg_tester & tester, const std::string & input) : tester_(tester) { tc_.input = input; } + peg_test_builder(peg_tester & tester, const std::string & input) : tester_(tester) { + tc_.input = input; + tc_.params.add_generation_prompt = true; + } // Parameter setters peg_test_builder & reasoning_format(common_reasoning_format fmt) { @@ -1373,6 
+1397,16 @@ class peg_test_builder { return *this; } + peg_test_builder & add_generation_prompt(bool val) { + tc_.params.add_generation_prompt = val; + return *this; + } + + peg_test_builder & continue_final_message(common_chat_continuation cont) { + tc_.params.continue_final_message = cont; + return *this; + } + peg_test_builder & json_schema(const std::string & schema) { tc_.params.json_schema = schema; return *this; @@ -2045,6 +2079,27 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .expect_content("Final answer without tools.") .run(); + // Continuation tests + tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinking\n\n\nHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + { common_chat_msg user_start; user_start.role = "user"; @@ -2204,6 +2259,27 @@ static void test_template_output_peg_parsers(bool detailed_debug) { }) .expect_reconstruction() .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinking[/THINK]Hello, world!\nWhat's up?") + 
.reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); } { @@ -2350,6 +2426,26 @@ static void test_template_output_peg_parsers(bool detailed_debug) { }) .run(); + // Continuation tests + tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinking\n\nHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); } { @@ -2409,6 +2505,14 @@ static void test_template_output_peg_parsers(bool detailed_debug) { tst.test("Hello, world!").expect(simple_assist_msg("Hello, world!")).expect_reconstruction().run(); tst.test("Line 1\nLine 2\nLine 3").expect(simple_assist_msg("Line 1\nLine 2\nLine 3")).expect_reconstruction().run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_content("Hello, world!\nWhat's up?") + .run(); } { @@ -2580,6 +2684,27 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .expect(message_assist) .run(); + // Continuation tests + tst.test("world!\nWhat's up?") + 
.reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinkingHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + { // additional tests for https://github.com/ggml-org/llama.cpp/pull/21760 auto tmpls = read_templates("models/templates/google-gemma-4-31B-it.jinja"); @@ -2633,6 +2758,27 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .run(); tst.test("Hello, world!").reasoning_format(COMMON_REASONING_FORMAT_AUTO).expect(simple_assist_msg("Hello, world!")).run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinking\n\nHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); } { // NousResearch-Hermes-2-Pro and Hermes-3 (tool calling models) @@ -2656,6 +2802,14 @@ static void test_template_output_peg_parsers(bool detailed_debug) { // 
Note: Hermes template doesn't support thinking/reasoning natively // Note: We only support one tool calling format per template, no alternate formats + + // Continuation tests + tst.test("world!\nWhat's up?") + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_content("Hello, world!\nWhat's up?") + .run(); } { // Test simple content-only template @@ -2679,6 +2833,27 @@ static void test_template_output_peg_parsers(bool detailed_debug) { // .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) // .expect(message_assist_thoughts) // .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinkingHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); } { @@ -2695,6 +2870,14 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .tools({ special_function_tool }) .expect(message_assist_call) .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_content("Hello, world!\nWhat's up?") + .run(); } { @@ -2784,6 +2967,26 @@ static void test_template_output_peg_parsers(bool detailed_debug) { }) .run(); + // Continuation tests + 
tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinking\nHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); } { @@ -3019,6 +3222,14 @@ static void test_template_output_peg_parsers(bool detailed_debug) { { "set_unit", R"({"unit": "celsius"})", {} }, }) .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_content("Hello, world!\nWhat's up?") + .run(); } { auto tst = peg_tester("models/templates/deepseek-ai-DeepSeek-V3.1.jinja", detailed_debug); @@ -3030,6 +3241,27 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) .expect(message_with_tool_calls("get_time", "{\"city\":\"XYZCITY\"}")) .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinkingHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .enable_thinking(true) + 
.messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); } { @@ -3284,6 +3516,27 @@ static void test_template_output_peg_parsers(bool detailed_debug) { { "magic_int", R"({"ref": 42, "name": "foo bar"})", {} }, }) .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinkingHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); } // GLM-4.6 tests - format: function_name\n...\n...\n @@ -3380,6 +3633,27 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .expect_reconstruction() .run(); } + + // Continuation tests + tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinkingHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + 
.continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); } // Verify the throw path produces a readable error message, not std::out_of_range. @@ -3630,6 +3904,27 @@ static void test_template_output_peg_parsers(bool detailed_debug) { } }) .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinkingHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); } { @@ -3657,6 +3952,27 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .expect(kimi_id_special_func_tool_call) .expect_reconstruction() .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinkingHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, 
world!\nWhat's up?") + .run(); } // LFM2-8B-A1B tests - uses <|tool_list_start|>/<|tool_list_end|> and <|tool_call_start|>[name(args)]<|tool_call_end|> @@ -3742,6 +4058,27 @@ static void test_template_output_peg_parsers(bool detailed_debug) { { "special_function", R"({"arg1": 1})", {} }, }) .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinkingHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); } // LFM2.5 tests - uses plain "List of tools: [...]" and bare [name(args)] without wrapper tokens @@ -3803,6 +4140,27 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .tools({ empty_args_tool }) .expect(simple_assist_msg("", "", "empty_args", "{}")) .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinkingHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + 
.continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); } // Reka-Edge tests - uses native JSON format with per-call wrapper @@ -3888,6 +4246,27 @@ static void test_template_output_peg_parsers(bool detailed_debug) { { "special_function", R"({"arg1": 1})", {} }, }) .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinking\n\n\nHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); } @@ -3900,6 +4279,14 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .expect(message_assist_call) .expect_reconstruction() .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_content("Hello, world!\nWhat's up?") + .run(); } // MiniMax-M2 tests - XML invoke format with parameter tags @@ -3924,6 +4311,27 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .expect(message_assist_call) .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + 
.add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinking\n\n\nHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); } // NVIDIA-Nemotron-Nano-v2 tests - ... format @@ -3934,6 +4342,14 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .tools({ special_function_tool }) .expect(message_assist_call) .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_content("Hello, world!\nWhat's up?") + .run(); } // CohereForAI-c4ai-command-r7b (uses START_RESPONSE/END_RESPONSE, START_THINKING/END_THINKING, START_ACTION/END_ACTION) @@ -3963,6 +4379,14 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .tools({ special_function_tool }) .expect(message_assist_call) .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_content("Hello, world!\nWhat's up?") + .run(); } // mistralai-Mistral-Nemo-Instruct-2407.jinja @@ -3974,6 +4398,14 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .expect(message_assist_call_id) .expect_reconstruction() .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + 
.continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_content("Hello, world!\nWhat's up?") + .run(); } { auto tst = peg_tester("models/templates/meetkai-functionary-medium-v3.1.jinja", detailed_debug); @@ -3983,6 +4415,14 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .expect(message_assist_call) .expect_reconstruction() .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_content("Hello, world!\nWhat's up?") + .run(); } // Functionary v3.2 - recipient-based format: >>>recipient\n{content} { @@ -3993,6 +4433,14 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .expect(message_assist_call) .expect_reconstruction() .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_content("Hello, world!\nWhat's up?") + .run(); } // FireFunction @@ -4004,6 +4452,14 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .expect(message_assist_call) .expect_reconstruction() .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_content("Hello, world!\nWhat's up?") + .run(); } // DeepSeek R1 Distill Llama 8B - reasoning tests only (forced open thinking) @@ -4019,6 +4475,27 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) .expect(message_assist_thoughts) .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .enable_thinking(true) + .messages({ 
message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinkingHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); } // llama-cpp DeepSeek R1 template (always forced-open thinking) { @@ -4036,6 +4513,27 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .parallel_tool_calls(true) .expect(message_assist_call) .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinkingHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); } // DeepSeek R1 Distill Qwen 32B - reasoning tests only (forced open thinking) // Note: Template uses forced-open mode (prompt ends with ), so input shouldn't include opening tag @@ -4056,6 +4554,27 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) .expect(message_assist_call) .run(); + + // Continuation tests + 
tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinkingHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); } // MiMo-VL / Hermes 3 / Qwen 2.5 (Common JSON format) @@ -4069,6 +4588,14 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .expect(message_assist_call) .expect_reconstruction() .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_content("Hello, world!\nWhat's up?") + .run(); } // Reka Edge @@ -4114,6 +4641,27 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .is_partial(true) .expect(message_assist_call_cutoff_args) .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinking\n\n\nHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + 
.add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); } // Apriel 1.5 @@ -4124,6 +4672,14 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .tools({ special_function_tool }) .expect(message_assist_call) .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_content("Hello, world!\nWhat's up?") + .run(); } // Apriel 1.6 Thinker (reasoning-only support) @@ -4147,6 +4703,27 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .tools({ special_function_tool }) .expect(simple_assist_msg("", "Here are my reasoning steps:\nI'm\nthinking", "special_function", "{\"arg1\":1}")) .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinking\n[BEGIN FINAL RESPONSE]\nHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); } // Mistral Small 3.2 - FUNC_BRACKET_TAG format: [TOOL_CALLS]func_name[CALL_ID]id[ARGS]{...} @@ -4172,7 +4749,13 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .expect_reconstruction() .run(); - + // Continuation tests + tst.test("world!\nWhat's up?") + 
.messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_content("Hello, world!\nWhat's up?") + .run(); } // Devstral { @@ -4194,18 +4777,42 @@ static void test_template_output_peg_parsers(bool detailed_debug) { // Llama 3.1 auto tst = peg_tester("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja", detailed_debug); tst.test("Hello, world!\nWhat's up?").tools({ special_function_tool }).expect(message_assist).expect_reconstruction().run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_content("Hello, world!\nWhat's up?") + .run(); } { // Llama 3.2 auto tst = peg_tester("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", detailed_debug); tst.test("Hello, world!\nWhat's up?").tools({ special_function_tool }).expect(message_assist).expect_reconstruction().run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_content("Hello, world!\nWhat's up?") + .run(); } { // Llama 3.3 auto tst = peg_tester("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja", detailed_debug); tst.test("Hello, world!\nWhat's up?").tools({ python_tool }).expect(message_assist).expect_reconstruction().run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_content("Hello, world!\nWhat's up?") + .run(); } // GPT-OSS format tests @@ -4366,6 +4973,27 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .reasoning_format(COMMON_REASONING_FORMAT_AUTO) 
.expect(message_assist_thoughts) .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinking<|end|><|start|>assistant<|channel|>final<|message|>Hello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); } { @@ -4556,6 +5184,26 @@ static void test_template_output_peg_parsers(bool detailed_debug) { }) .run(); + // Continuation tests + tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); + + tst.test(" thinking\n\nHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .enable_thinking(true) + .messages({ message_user, message_assist_prefill_reasoning }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_REASONING) + .expect_reasoning("I'm thinking") + .expect_content("Hello, world!\nWhat's up?") + .run(); } // GigaChat V3 @@ -4576,6 +5224,14 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .expect(message_assist_call_content) .expect_reconstruction() .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .messages({ message_user, 
message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_content("Hello, world!\nWhat's up?") + .run(); } // GigaChat V3.1 @@ -4596,55 +5252,154 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .expect(message_assist_call_content) .expect_reconstruction() .run(); + + // Continuation tests + tst.test("world!\nWhat's up?") + .messages({ message_user, message_assist_prefill_content }) + .add_generation_prompt(false) + .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_content("Hello, world!\nWhat's up?") + .run(); } } -static void test_reka_edge_common_path() { - auto tmpls = read_templates("models/templates/Reka-Edge.jinja"); +static void test_template_generation_prompt() { + common_chat_msg system_msg; + system_msg.role = "system"; + system_msg.content ="You are a helpful assistant."; - { - common_chat_templates_inputs inputs; - common_chat_msg system_msg; - system_msg.role = "system"; - system_msg.content = "Use tools when needed."; + common_chat_msg tool_call_msg = simple_assist_msg("", "", "special_function", "{\"arg1\": 1}"); - common_chat_msg tool_call_msg = simple_assist_msg("", "", "special_function", "{\"arg1\": 1}"); + common_chat_msg tool_msg; + tool_msg.role = "tool"; + tool_msg.tool_name = "special_function"; + tool_msg.tool_call_id = "call0"; + tool_msg.content = "Sunny"; - common_chat_msg tool_msg; - tool_msg.role = "tool"; - tool_msg.tool_name = "special_function"; - tool_msg.tool_call_id = "call0"; - tool_msg.content = "Sunny"; + struct test_case_options { + std::vector messages; + bool add_generation_prompt = true; + common_chat_continuation continue_final_message = COMMON_CHAT_CONTINUATION_NONE; + }; - inputs.messages = { system_msg, message_user, tool_call_msg, tool_msg, message_user }; - inputs.tools = { special_function_tool }; - inputs.enable_thinking = true; - inputs.add_generation_prompt = true; + auto basic = [&]() { + 
test_case_options opts; + opts.messages = { system_msg, message_user }; + return opts; + }; + + auto continuation_content = [&]() { + test_case_options opts; + opts.messages = { system_msg, message_user, message_assist_prefill_content }; + opts.add_generation_prompt = false; + opts.continue_final_message = COMMON_CHAT_CONTINUATION_CONTENT; + return opts; + }; + + auto continuation_reasoning = [&]() { + test_case_options opts; + opts.messages = { system_msg, message_user, message_assist_prefill_reasoning }; + opts.add_generation_prompt = false; + opts.continue_final_message = COMMON_CHAT_CONTINUATION_REASONING; + return opts; + }; + + auto check = [&](const common_chat_templates_ptr & tmpls, + const test_case_options & opts, + const std::string & expected_generation_prompt) { + common_chat_templates_inputs inputs; + inputs.messages = opts.messages; + inputs.add_generation_prompt = opts.add_generation_prompt; + inputs.continue_final_message = opts.continue_final_message; auto params = common_chat_templates_apply(tmpls.get(), inputs); - if (params.prompt.find("\nSunny\n") == std::string::npos) { - throw std::runtime_error("Reka Edge prompt did not render tool response history"); - } - if (params.prompt.rfind("assistant: \n") == std::string::npos) { - throw std::runtime_error("Reka Edge prompt did not render thinking generation prompt"); - } + assert_contains(params.prompt, system_msg.content); + assert_contains(params.prompt, message_user.content); + assert_equals(expected_generation_prompt, params.generation_prompt); + assert_ends_with(params.prompt, expected_generation_prompt); + }; + + { + auto tmpls = read_templates("models/templates/Qwen3.5-4B.jinja"); + check(tmpls, basic(), "<|im_start|>assistant\n\n"); + check(tmpls, continuation_content(), "<|im_start|>assistant\n\nI'm thinking\n\n\nHello, "); + check(tmpls, continuation_reasoning(), "<|im_start|>assistant\n\nI'm"); } { - common_chat_templates_inputs inputs; - inputs.messages = { - message_user, - 
simple_assist_msg("The first point is") - }; - inputs.add_generation_prompt = false; - inputs.enable_thinking = false; - inputs.chat_template_kwargs["continue_final_message"] = "true"; + auto tmpls = read_templates("models/templates/openai-gpt-oss-120b.jinja"); + check(tmpls, basic(), "<|start|>assistant"); + check(tmpls, continuation_content(), "<|start|>assistant<|channel|>analysis<|message|>I'm thinking<|end|><|start|>assistant<|channel|>final<|message|>Hello, "); + check(tmpls, continuation_reasoning(), "<|start|>assistant<|channel|>analysis<|message|>I'm"); + } - auto params = common_chat_templates_apply(tmpls.get(), inputs); - if (string_ends_with(params.prompt, "")) { - throw std::runtime_error("Reka Edge continue_final_message unexpectedly closed the assistant turn"); - } + { + auto tmpls = read_templates("models/templates/mistralai-Ministral-3-14B-Reasoning-2512.jinja"); + check(tmpls, basic(), ""); + check(tmpls, continuation_content(), "[THINK]I'm thinking[/THINK]Hello, "); + check(tmpls, continuation_reasoning(), "[THINK]I'm"); + } + + { + auto tmpls = read_templates("models/templates/google-gemma-4-31B-it.jinja"); + check(tmpls, basic(), "<|turn>model\n"); + check(tmpls, continuation_content(), "<|turn>model\n<|channel>thought\nI'm thinkingHello, "); + check(tmpls, continuation_reasoning(), "<|turn>model\n<|channel>thought\nI'm"); + + // Special case when last message is a tool response + test_case_options after_tool_call = continuation_reasoning(); + after_tool_call.messages = { system_msg, message_user, tool_call_msg, tool_msg, message_assist_prefill_reasoning }; + check(tmpls, after_tool_call, "<|channel>thought\nI'm"); + } + + { + auto tmpls = read_templates("models/templates/meetkai-functionary-medium-v3.2.jinja"); + check(tmpls, basic(), "<|start_header_id|>assistant<|end_header_id|>\n\n>>>"); + check(tmpls, continuation_content(), "<|start_header_id|>assistant<|end_header_id|>\n\n>>>all\nHello, "); + check(tmpls, continuation_reasoning(), 
"<|start_header_id|>assistant<|end_header_id|>\n\n>>>all\n"); + } + + { + auto tmpls = read_templates("models/templates/Reka-Edge.jinja"); + check(tmpls, basic(), "assistant: \n"); + check(tmpls, continuation_content(), "assistant: \nI'm thinking\n\n\nHello, "); + check(tmpls, continuation_reasoning(), "assistant: \nI'm"); + } + + { + auto tmpls = read_templates("models/templates/moonshotai-Kimi-K2.jinja"); + check(tmpls, basic(), "<|im_assistant|>assistant<|im_middle|>"); + check(tmpls, continuation_content(), "<|im_assistant|>assistant<|im_middle|>I'm thinkingHello, "); + check(tmpls, continuation_reasoning(), "<|im_assistant|>assistant<|im_middle|>I'm"); + } + + { + auto tmpls = read_templates("models/templates/LFM2-8B-A1B.jinja"); + check(tmpls, basic(), "<|im_start|>assistant\n"); + check(tmpls, continuation_content(), "<|im_start|>assistant\nI'm thinkingHello, "); + check(tmpls, continuation_reasoning(), "<|im_start|>assistant\nI'm"); + } + + { + auto tmpls = read_templates("models/templates/LFM2.5-Instruct.jinja"); + check(tmpls, basic(), "<|im_start|>assistant\n"); + check(tmpls, continuation_content(), "<|im_start|>assistant\nI'm thinkingHello, "); + check(tmpls, continuation_reasoning(), "<|im_start|>assistant\nI'm"); + } + + { + auto tmpls = read_templates("models/templates/GigaChat3-10B-A1.8B.jinja"); + check(tmpls, basic(), "assistant<|role_sep|>\n"); + check(tmpls, continuation_content(), "assistant<|role_sep|>\nHello, "); + check(tmpls, continuation_reasoning(), "assistant<|role_sep|>\n"); + } + + { + auto tmpls = read_templates("models/templates/deepseek-ai-DeepSeek-V3.2.jinja"); + check(tmpls, basic(), "<|Assistant|>"); + check(tmpls, continuation_content(), "<|Assistant|>I'm thinkingHello, "); + check(tmpls, continuation_reasoning(), "<|Assistant|>I'm"); } } @@ -4841,7 +5596,7 @@ int main(int argc, char ** argv) { test_tools_oaicompat_json_conversion(); test_convert_responses_to_chatcmpl(); test_developer_role_to_system_workaround(); - 
test_reka_edge_common_path(); + test_template_generation_prompt(); test_template_output_peg_parsers(detailed_debug); std::cout << "\n[chat] All tests passed!" << '\n'; } diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index 73de0d3bba1..dc00edfa82a 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -1032,23 +1032,33 @@ json oaicompat_chat_params_parse( auto caps = common_chat_templates_get_caps(opt.tmpls.get()); common_chat_templates_inputs inputs; - inputs.messages = common_chat_msgs_parse_oaicompat(messages); - inputs.tools = common_chat_tools_parse_oaicompat(tools); - inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(tool_choice); - inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump(); - inputs.grammar = grammar; - inputs.use_jinja = opt.use_jinja; - inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", caps["supports_parallel_tool_calls"]); - inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true); - const bool continue_final_message = json_value(body, "continue_final_message", false); - if (continue_final_message && inputs.add_generation_prompt) { + inputs.messages = common_chat_msgs_parse_oaicompat(messages); + inputs.tools = common_chat_tools_parse_oaicompat(tools); + inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(tool_choice); + inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump(); + inputs.grammar = grammar; + inputs.use_jinja = opt.use_jinja; + inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", caps["supports_parallel_tool_calls"]); + inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true); + inputs.continue_final_message = body.contains("continue_final_message") ? 
+ common_chat_continuation_parse(body.at("continue_final_message")) : + COMMON_CHAT_CONTINUATION_NONE; + if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_NONE && opt.prefill_assistant + && !inputs.messages.empty() && inputs.messages.back().role == "assistant") { + if (inputs.messages.size() >= 2 && inputs.messages[inputs.messages.size() - 2].role == "assistant") { + throw std::invalid_argument("Cannot have 2 or more assistant messages at the end of the list."); + } + inputs.continue_final_message = COMMON_CHAT_CONTINUATION_AUTO; + inputs.add_generation_prompt = false; + } + if (inputs.continue_final_message != COMMON_CHAT_CONTINUATION_NONE && inputs.add_generation_prompt) { throw std::invalid_argument("Cannot set both add_generation_prompt and continue_final_message to true."); } - inputs.reasoning_format = opt.reasoning_format; + inputs.reasoning_format = opt.reasoning_format; if (body.contains("reasoning_format")) { inputs.reasoning_format = common_reasoning_format_from_name(body.at("reasoning_format").get()); } - inputs.enable_thinking = opt.enable_thinking; + inputs.enable_thinking = opt.enable_thinking; if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) { if (body.contains("grammar")) { throw std::invalid_argument("Cannot use custom grammar constraints with tools."); @@ -1073,84 +1083,11 @@ json oaicompat_chat_params_parse( throw std::invalid_argument("invalid type for \"enable_thinking\" (expected boolean, got string)"); } - // if the assistant message appears at the end of list, we do not add end-of-turn token - // for ex. 
this can be useful to modify the reasoning process in reasoning models - // continue_final_message is the explicit opt in alias from the vLLM/transformers API, - // equivalent to the prefill_assistant heuristic - bool prefill_assistant_message = !inputs.messages.empty() && inputs.messages.back().role == "assistant" - && (continue_final_message || opt.prefill_assistant); - common_chat_msg last_message; - if (prefill_assistant_message) { - last_message = inputs.messages.back(); - inputs.messages.pop_back(); - - /* sanity check, max one assistant message at the end of the list */ - if (!inputs.messages.empty() && inputs.messages.back().role == "assistant"){ - throw std::invalid_argument("Cannot have 2 or more assistant messages at the end of the list."); - } - - // reject reasoning prefill on channel based templates that do not expose explicit thinking tags - if (!last_message.reasoning_content.empty() && inputs.enable_thinking) { - auto probe_params = common_chat_templates_apply(opt.tmpls.get(), inputs); - if (probe_params.supports_thinking && probe_params.thinking_end_tag.empty()) { - throw std::invalid_argument("Assistant prefill with reasoning_content is not supported yet for this template."); - } - } - - inputs.add_generation_prompt = true; - } inputs.force_pure_content = opt.force_pure_content; // Apply chat template to the list of messages auto chat_params = common_chat_templates_apply(opt.tmpls.get(), inputs); - /* Append assistant prefilled message */ - if (prefill_assistant_message) { - const bool thinking_active = chat_params.supports_thinking && !chat_params.thinking_end_tag.empty(); - const bool has_reasoning = !last_message.reasoning_content.empty(); - const bool has_content = !last_message.content.empty() || !last_message.content_parts.empty(); - const bool mid_reasoning = has_reasoning && !has_content; - - // some templates inject thinking_start in generation_prompt, others let the model emit it - const bool gp_has_think = thinking_active - && 
chat_params.generation_prompt.find(chat_params.thinking_start_tag) != std::string::npos; - - // open the thinking block when reasoning is present and the template did not inject it - if (has_reasoning) { - if (thinking_active && !gp_has_think) { - chat_params.prompt += chat_params.thinking_start_tag; - } - chat_params.prompt += last_message.reasoning_content; - } - - if (thinking_active) { - if (mid_reasoning) { - // model continues inside the thinking block, keep generation_prompt open on think - if (!gp_has_think) { - chat_params.generation_prompt += chat_params.thinking_start_tag; - } - } else { - // close thinking block when reasoning is followed by content, or when the template forced it open - if (has_reasoning || gp_has_think) { - chat_params.prompt += chat_params.thinking_end_tag; - } - // strip thinking_start from generation_prompt so the parser routes model output as content - auto pos = chat_params.generation_prompt.rfind(chat_params.thinking_start_tag); - if (pos != std::string::npos) { - chat_params.generation_prompt = chat_params.generation_prompt.substr(0, pos); - } - } - } - - if (!last_message.content_parts.empty()) { - for (auto & p : last_message.content_parts) { - chat_params.prompt += p.text; - } - } else { - chat_params.prompt += last_message.content; - } - } - llama_params["chat_format"] = static_cast(chat_params.format); llama_params["prompt"] = chat_params.prompt; if (!chat_params.grammar.empty()) { diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 1ce7f095827..0f3fb9efa3c 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -243,6 +243,11 @@ struct server_slot { return task->need_embd() || (spec && common_speculative_need_embd(spec)); } + bool need_embd_pre_norm() const { + GGML_ASSERT(task); + return spec && common_speculative_need_embd_pre_norm(spec); + } + // if the context does not have a memory module then all embeddings have to be computed within a single ubatch // 
also we cannot split if the pooling would require any past tokens // (MTP supports splitting — uses task->need_embd() not need_embd()) diff --git a/tools/server/server-http.cpp b/tools/server/server-http.cpp index 39a21f4ecc7..9d008fc94c2 100644 --- a/tools/server/server-http.cpp +++ b/tools/server/server-http.cpp @@ -231,11 +231,10 @@ bool server_http_context::init(const common_params & params) { }; auto middleware_server_state = [this](const httplib::Request & req, httplib::Response & res) { - (void)req; // suppress unused parameter warning when LLAMA_BUILD_UI / LLAMA_BUILD_WEBUI is not defined + (void)req; // suppress unused parameter warning when LLAMA_BUILD_UI is not defined bool ready = is_ready.load(); if (!ready) { -// Support both old and new preprocessor defines -#if defined(LLAMA_BUILD_UI) || defined(LLAMA_BUILD_WEBUI) +#if defined(LLAMA_BUILD_UI) auto tmp = string_split(req.path, '.'); if (req.path == "/" || (tmp.size() > 0 && tmp.back() == "html")) { res.status = 503; @@ -313,8 +312,7 @@ bool server_http_context::init(const common_params & params) { return 1; } } else { -// Support both old and new preprocessor defines -#if defined(LLAMA_BUILD_UI) || defined(LLAMA_BUILD_WEBUI) +#if defined(LLAMA_BUILD_UI) // using embedded static index.html srv->Get(params.api_prefix + "/", [](const httplib::Request & /*req*/, httplib::Response & res) { // COEP and COOP headers, required by pyodide (python interpreter) diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index cbc40a35fc4..d45513dbeba 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -144,6 +144,17 @@ json task_params::to_json(bool only_metrics) const { // // task_result_state // +task_result_state::task_result_state(const common_chat_parser_params & chat_parser_params) + : chat_parser_params(chat_parser_params) + , oai_resp_id("resp_" + random_string()) + , oai_resp_reasoning_id("rs_" + random_string()) + , oai_resp_message_id("msg_" + random_string()) 
{ + if (!chat_parser_params.echo) { + // initialize chat_msg to avoid emitting a delta containing the assistant prefill + chat_msg = common_chat_parse("", true, chat_parser_params); + } +} + common_chat_msg task_result_state::update_chat_msg( const std::string & text_added, bool is_partial, @@ -421,6 +432,7 @@ task_params server_task::params_from_json_cmpl( if (data.contains("chat_parser")) { params.chat_parser_params.parser.load(data.at("chat_parser").get()); } + params.chat_parser_params.echo = json_value(data, "echo", false); } { diff --git a/tools/server/server-task.h b/tools/server/server-task.h index 64bdecd794f..0978bb6ff16 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -112,11 +112,7 @@ struct task_result_state { const std::string oai_resp_message_id; std::string oai_resp_fc_id; // function call ID for current args delta - task_result_state(const common_chat_parser_params & chat_parser_params) - : chat_parser_params(chat_parser_params) - , oai_resp_id("resp_" + random_string()) - , oai_resp_reasoning_id("rs_" + random_string()) - , oai_resp_message_id("msg_" + random_string()) {} + task_result_state(const common_chat_parser_params & chat_parser_params); // parse partial tool calls and update the internal state common_chat_msg update_chat_msg( diff --git a/tools/server/tests/unit/test_chat_completion.py b/tools/server/tests/unit/test_chat_completion.py index 243e4160578..f80e46133c7 100644 --- a/tools/server/tests/unit/test_chat_completion.py +++ b/tools/server/tests/unit/test_chat_completion.py @@ -158,11 +158,12 @@ def test_chat_template(): @pytest.mark.parametrize("prefill,re_prefill", [ ("Whill", "Whill"), - ([{"type": "text", "text": "Wh"}, {"type": "text", "text": "ill"}], "Whill"), + ([{"type": "text", "text": "Wh"}, {"type": "text", "text": "ill"}], "Wh\n\nill"), ]) def test_chat_template_assistant_prefill(prefill, re_prefill): global server - server.chat_template = "llama3" + server.jinja = True + 
server.chat_template_file = "../../../models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja" server.debug = True # to get the "__verbose" object in the response server.start() res = server.make_request("POST", "/chat/completions", data={ @@ -175,14 +176,15 @@ def test_chat_template_assistant_prefill(prefill, re_prefill): }) assert res.status_code == 200 assert "__verbose" in res.body - assert res.body["__verbose"]["prompt"] == f" <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{re_prefill}" + assert res.body["__verbose"]["prompt"].endswith(f"<|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{re_prefill}") def test_chat_template_continue_final_message_vllm_compat(): """continue_final_message is the vLLM/transformers explicit alias for the prefill_assistant heuristic. Both must produce the same prompt.""" global server - server.chat_template = "llama3" + server.jinja = True + server.chat_template_file = "../../../models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja" server.debug = True server.start() res = server.make_request("POST", "/chat/completions", data={ @@ -197,7 +199,7 @@ def test_chat_template_continue_final_message_vllm_compat(): }) assert res.status_code == 200 assert "__verbose" in res.body - assert res.body["__verbose"]["prompt"] == " <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nWhill" + assert res.body["__verbose"]["prompt"].endswith("<|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nWhill") def test_chat_template_continue_final_message_mutual_exclusion(): diff --git a/tools/ui/CMakeLists.txt b/tools/ui/CMakeLists.txt index 
9687ca92e55..383940cb636 100644 --- a/tools/ui/CMakeLists.txt +++ b/tools/ui/CMakeLists.txt @@ -14,12 +14,7 @@ endif() set(TARGET_SRCS "") set(UI_COMPILE_DEFS "") -# Support both old (LLAMA_BUILD_WEBUI) and new (LLAMA_BUILD_UI) option names -if(LLAMA_BUILD_WEBUI OR LLAMA_BUILD_UI) - if(LLAMA_BUILD_WEBUI AND NOT LLAMA_BUILD_UI) - message(DEPRECATION "LLAMA_BUILD_WEBUI is deprecated, use LLAMA_BUILD_UI instead") - endif() - +if(LLAMA_BUILD_UI) set(PUBLIC_ASSETS index.html bundle.js @@ -125,19 +120,17 @@ if(LLAMA_BUILD_WEBUI OR LLAMA_BUILD_UI) endforeach() list(APPEND UI_COMPILE_DEFS - LLAMA_BUILD_WEBUI # Deprecated: use LLAMA_BUILD_UI LLAMA_BUILD_UI - LLAMA_WEBUI_DEFAULT_ENABLED=1 # Deprecated: use LLAMA_UI_DEFAULT_ENABLED LLAMA_UI_DEFAULT_ENABLED=1 ) message(STATUS "UI: embedded with source: ${UI_SOURCE}") else() message(WARNING "UI: no source available. Neither local build (build/tools/ui/dist/) nor HF Bucket download succeeded.") message(WARNING "UI: building server without embedded UI. Set LLAMA_BUILD_UI=OFF to suppress this warning.") - list(APPEND UI_COMPILE_DEFS LLAMA_WEBUI_DEFAULT_ENABLED=0 LLAMA_UI_DEFAULT_ENABLED=0) + list(APPEND UI_COMPILE_DEFS LLAMA_UI_DEFAULT_ENABLED=0) endif() else() - list(APPEND UI_COMPILE_DEFS LLAMA_WEBUI_DEFAULT_ENABLED=0 LLAMA_UI_DEFAULT_ENABLED=0) + list(APPEND UI_COMPILE_DEFS LLAMA_UI_DEFAULT_ENABLED=0) endif() # Build the static library