|
8 | 8 | #include "../src/llama-grammar.h" |
9 | 9 | #include "../src/unicode.h" |
10 | 10 | #include "../tools/server/server-chat.h" |
| 11 | +#include "../tools/server/server-task.h" |
11 | 12 | #include "chat-auto-parser.h" |
12 | 13 | #include "chat.h" |
13 | 14 | #include "common.h" |
@@ -4434,6 +4435,141 @@ static void test_msg_diffs_compute() { |
4434 | 4435 | } |
4435 | 4436 | } |
4436 | 4437 |
|
| 4438 | +static void test_task_result_state_tool_call_stream_filter() { |
| 4439 | + auto tmpls = read_templates("models/templates/Kimi-K2-Thinking.jinja"); |
| 4440 | + |
| 4441 | + common_chat_templates_inputs inputs; |
| 4442 | + inputs.messages = { message_user }; |
| 4443 | + inputs.tools = { special_function_tool }; |
| 4444 | + inputs.parallel_tool_calls = true; |
| 4445 | + |
| 4446 | + make_peg_parser parser(tmpls.get(), inputs); |
| 4447 | + |
| 4448 | + common_chat_parser_params parser_params(parser.params_); |
| 4449 | + parser_params.parser = parser.arena_; |
| 4450 | + parser_params.parse_tool_calls = true; |
| 4451 | + |
| 4452 | + task_result_state state(parser_params); |
| 4453 | + |
| 4454 | + { |
| 4455 | + std::vector<common_chat_msg_diff> diffs; |
| 4456 | + state.update_chat_msg( |
| 4457 | + "<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0", |
| 4458 | + true, |
| 4459 | + diffs, |
| 4460 | + true); |
| 4461 | + assert_equals(size_t(0), diffs.size()); |
| 4462 | + } |
| 4463 | + |
| 4464 | + { |
| 4465 | + std::vector<common_chat_msg_diff> diffs; |
| 4466 | + state.update_chat_msg("<|tool_call_argument_begin|>", true, diffs, true); |
| 4467 | + assert_equals(size_t(1), diffs.size()); |
| 4468 | + assert_equals(size_t(0), diffs[0].tool_call_index); |
| 4469 | + assert_equals(std::string("special_function"), diffs[0].tool_call_delta.name); |
| 4470 | + assert_equals(std::string("functions.special_function:0"), diffs[0].tool_call_delta.id); |
| 4471 | + assert_equals(std::string(""), diffs[0].tool_call_delta.arguments); |
| 4472 | + } |
| 4473 | + |
| 4474 | + { |
| 4475 | + std::vector<common_chat_msg_diff> diffs; |
| 4476 | + state.update_chat_msg("{\"arg1\": ", true, diffs, true); |
| 4477 | + assert_equals(size_t(0), diffs.size()); |
| 4478 | + } |
| 4479 | + |
| 4480 | + { |
| 4481 | + std::vector<common_chat_msg_diff> diffs; |
| 4482 | + state.update_chat_msg("1}<|tool_call_end|><|tool_calls_section_end|>", true, diffs, true); |
| 4483 | + assert_equals(size_t(1), diffs.size()); |
| 4484 | + assert_equals(size_t(0), diffs[0].tool_call_index); |
| 4485 | + assert_equals(std::string(""), diffs[0].tool_call_delta.name); |
| 4486 | + assert_equals(std::string(""), diffs[0].tool_call_delta.id); |
| 4487 | + assert_equals(std::string("{\"arg1\": 1}"), diffs[0].tool_call_delta.arguments); |
| 4488 | + } |
| 4489 | + |
| 4490 | + { |
| 4491 | + task_result_state raw_state(parser_params); |
| 4492 | + std::vector<common_chat_msg_diff> diffs; |
| 4493 | + raw_state.update_chat_msg("Visible before marker\n", true, diffs, true); |
| 4494 | + assert_equals(size_t(1), diffs.size()); |
| 4495 | + assert_equals(std::string("Visible before marker\n"), diffs[0].content_delta); |
| 4496 | + |
| 4497 | + diffs.clear(); |
| 4498 | + raw_state.update_chat_msg("<function=read the llama_perf_context_data struct>", true, diffs, true); |
| 4499 | + assert_equals(size_t(0), diffs.size()); |
| 4500 | + |
| 4501 | + diffs.clear(); |
| 4502 | + raw_state.update_chat_msg(" trailing text", false, diffs, true); |
| 4503 | + assert_equals(size_t(0), diffs.size()); |
| 4504 | + } |
| 4505 | + |
| 4506 | + { |
| 4507 | + task_result_state code_state(parser_params); |
| 4508 | + std::vector<common_chat_msg_diff> diffs; |
| 4509 | + const std::string code = "```xml\n<function=example>\n```\n"; |
| 4510 | + code_state.update_chat_msg(code, true, diffs, true); |
| 4511 | + assert_equals(size_t(1), diffs.size()); |
| 4512 | + assert_equals(code, diffs[0].content_delta); |
| 4513 | + } |
| 4514 | + |
| 4515 | + { |
| 4516 | + auto qwen_tmpls = read_templates("models/templates/Qwen3.5-4B.jinja"); |
| 4517 | + |
| 4518 | + common_chat_templates_inputs qwen_inputs; |
| 4519 | + qwen_inputs.messages = { message_user }; |
| 4520 | + qwen_inputs.tools = { special_function_tool }; |
| 4521 | + qwen_inputs.parallel_tool_calls = true; |
| 4522 | + |
| 4523 | + make_peg_parser qwen_parser(qwen_tmpls.get(), qwen_inputs); |
| 4524 | + common_chat_parser_params qwen_params(qwen_parser.params_); |
| 4525 | + qwen_params.parser = qwen_parser.arena_; |
| 4526 | + qwen_params.parse_tool_calls = true; |
| 4527 | + |
| 4528 | + bool has_direct_function_trigger = false; |
| 4529 | + for (const auto & trigger : qwen_parser.params_.grammar_triggers) { |
| 4530 | + has_direct_function_trigger = has_direct_function_trigger || trigger.value == "<function="; |
| 4531 | + } |
| 4532 | + assert_equals(true, has_direct_function_trigger); |
| 4533 | + |
| 4534 | + const std::string direct_call = |
| 4535 | + "<function=special_function>\n" |
| 4536 | + "<parameter=arg1>\n" |
| 4537 | + "1\n" |
| 4538 | + "</parameter>\n" |
| 4539 | + "</function>\n"; |
| 4540 | + const auto direct_msg = common_chat_peg_parse(qwen_parser.arena_, direct_call, false, qwen_params); |
| 4541 | + assert_equals(size_t(1), direct_msg.tool_calls.size()); |
| 4542 | + assert_equals(std::string("special_function"), direct_msg.tool_calls[0].name); |
| 4543 | + assert_equals(std::string("{\"arg1\":1}"), direct_msg.tool_calls[0].arguments); |
| 4544 | + |
| 4545 | + task_result_state qwen_state(qwen_params); |
| 4546 | + |
| 4547 | + std::vector<common_chat_msg_diff> diffs; |
| 4548 | + qwen_state.update_chat_msg("<function=special_function>\n", true, diffs, true); |
| 4549 | + assert_equals(std::vector<common_chat_msg_diff>{}, diffs); |
| 4550 | + |
| 4551 | + diffs.clear(); |
| 4552 | + qwen_state.update_chat_msg("<parameter=arg1>\n", true, diffs, true); |
| 4553 | + assert_equals(size_t(1), diffs.size()); |
| 4554 | + assert_equals(size_t(0), diffs[0].tool_call_index); |
| 4555 | + assert_equals(std::string("special_function"), diffs[0].tool_call_delta.name); |
| 4556 | + assert_equals(false, diffs[0].tool_call_delta.id.empty()); |
| 4557 | + assert_equals(std::string(""), diffs[0].tool_call_delta.arguments); |
| 4558 | + |
| 4559 | + diffs.clear(); |
| 4560 | + qwen_state.update_chat_msg("1\n", true, diffs, true); |
| 4561 | + assert_equals(size_t(0), diffs.size()); |
| 4562 | + |
| 4563 | + diffs.clear(); |
| 4564 | + qwen_state.update_chat_msg("</parameter>\n</function>\n", true, diffs, true); |
| 4565 | + assert_equals(size_t(1), diffs.size()); |
| 4566 | + assert_equals(size_t(0), diffs[0].tool_call_index); |
| 4567 | + assert_equals(std::string(""), diffs[0].tool_call_delta.name); |
| 4568 | + assert_equals(std::string(""), diffs[0].tool_call_delta.id); |
| 4569 | + assert_equals(std::string("{\"arg1\":1}"), diffs[0].tool_call_delta.arguments); |
| 4570 | + } |
| 4571 | +} |
| 4572 | + |
4437 | 4573 | int main(int argc, char ** argv) { |
4438 | 4574 | bool detailed_debug = false; |
4439 | 4575 | bool only_run_filtered = false; |
@@ -4512,6 +4648,7 @@ int main(int argc, char ** argv) { |
4512 | 4648 | test_convert_responses_to_chatcmpl(); |
4513 | 4649 | test_developer_role_to_system_workaround(); |
4514 | 4650 | test_reka_edge_common_path(); |
| 4651 | + test_task_result_state_tool_call_stream_filter(); |
4515 | 4652 | test_template_output_peg_parsers(detailed_debug); |
4516 | 4653 | std::cout << "\n[chat] All tests passed!" << '\n'; |
4517 | 4654 | } |
|
0 commit comments