Skip to content

Commit a4701c9

Browse files
authored
common/autoparser: fixes for newline handling / forced tool calls (ggml-org#22654)
* chat/autoparser: the fixes * Move optspace() to chat-peg-parser, comment out server tests invalidated due to content now allowed with forced tool calls. * Trim whitespace on apply instead
1 parent 994118a commit a4701c9

10 files changed

Lines changed: 402 additions & 107 deletions

common/chat-auto-parser-generator.cpp

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -136,10 +136,10 @@ common_peg_parser analyze_reasoning::build_parser(parser_build_context & ctx) co
136136
if (!end.empty()) {
137137
if (!start.empty()) {
138138
// Standard tag-based: optional(<think>reasoning</think>)
139-
return p.optional(start + p.reasoning(p.until(end)) + end + p.space());
139+
return p.optional(p.optspace(start) + p.reasoning(p.until(trim_whitespace(end))) + p.optspace(end));
140140
}
141141
// Delimiter-style (empty start)
142-
return p.optional(p.reasoning(p.until(end)) + end + p.space());
142+
return p.optional(p.reasoning(p.until(trim_whitespace(end))) + p.optspace(end));
143143
}
144144
}
145145

@@ -186,7 +186,6 @@ common_peg_parser analyze_tools::build_parser(parser_build_context & ctx) const
186186
common_peg_parser analyze_tools::build_tool_parser_json_native(parser_build_context & ctx) const {
187187
auto & p = ctx.p;
188188
const auto & inputs = ctx.inputs;
189-
bool force_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED;
190189

191190
// Build effective field names with dot notation if function_field is set
192191
std::string name_field = format.name_field;
@@ -225,8 +224,7 @@ common_peg_parser analyze_tools::build_tool_parser_json_native(parser_build_cont
225224
tool_start = format.per_call_start;
226225
}
227226

228-
return ctx.reasoning_parser + (force_tools ? p.eps() : p.optional(p.content(p.until(tool_start)))) + tools_parser +
229-
p.end();
227+
return ctx.reasoning_parser + p.optional(p.content(p.until(tool_start))) + tools_parser + p.end();
230228
}
231229

232230
common_peg_parser analyze_tools::build_func_parser(common_chat_peg_builder & p, const std::string & name,
@@ -270,7 +268,6 @@ common_peg_parser analyze_tools::build_func_parser(common_chat_peg_builder & p,
270268
common_peg_parser analyze_tools::build_tool_parser_tag_json(parser_build_context & ctx) const {
271269
auto & p = ctx.p;
272270
const auto & inputs = ctx.inputs;
273-
bool force_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED;
274271

275272
common_peg_parser tool_choice = p.choice();
276273

@@ -336,14 +333,12 @@ common_peg_parser analyze_tools::build_tool_parser_tag_json(parser_build_context
336333

337334
std::string trigger_marker = !format.section_start.empty() ? format.section_start : format.per_call_start;
338335
auto content_before_tools = trigger_marker.empty() ? p.eps() : p.until(trigger_marker);
339-
return ctx.reasoning_parser + (force_tools ? p.eps() : p.optional(p.content(content_before_tools))) + tool_calls +
340-
p.end();
336+
return ctx.reasoning_parser + p.optional(p.content(content_before_tools)) + tool_calls + p.end();
341337
}
342338

343339
common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_context & ctx) const {
344340
auto & p = ctx.p;
345341
const auto & inputs = ctx.inputs;
346-
bool force_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED;
347342

348343
auto until_suffix = p.rule("until-suffix", p.until(arguments.value_suffix));
349344

@@ -471,8 +466,7 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
471466

472467
std::string trigger_marker = !format.section_start.empty() ? format.section_start : format.per_call_start;
473468
auto content_before_tools = trigger_marker.empty() ? p.eps() : p.until(trigger_marker);
474-
return ctx.reasoning_parser + (force_tools ? p.eps() : p.optional(p.content(content_before_tools))) + tool_calls +
475-
p.end();
469+
return ctx.reasoning_parser + p.optional(p.content(content_before_tools)) + tool_calls + p.end();
476470
}
477471

478472
} // namespace autoparser

common/chat-diff-analyzer.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ void analyze_reasoning::compare_thinking_enabled() {
342342
if (left_trimmed.empty() && !diff.right.empty()) {
343343
if (!right_trimmed.empty() && string_ends_with(comparison->output_B, right_trimmed)) {
344344
if (start.empty()) {
345-
start = trim_leading_whitespace(diff.right);
345+
start = diff.right;
346346
mode = reasoning_mode::TAG_BASED;
347347
}
348348
}
@@ -353,7 +353,7 @@ void analyze_reasoning::compare_thinking_enabled() {
353353
if (seg.size() >= 2 && seg[seg.size() - 1].value == left_trimmed && seg[seg.size() - 2].type == segment_type::MARKER) {
354354
start = seg[seg.size() - 2].value;
355355
}
356-
end = trim_trailing_whitespace(diff.left);
356+
end = diff.left;
357357
mode = reasoning_mode::TAG_BASED;
358358
}
359359
}
@@ -445,14 +445,14 @@ void analyze_reasoning::compare_reasoning_scope() {
445445
auto result = parser_wrapped.parse_anywhere_and_extract(comparison->output_B);
446446
if (result.result.success()) {
447447
start = result.tags["pre"];
448-
end = trim_trailing_whitespace(result.tags["post"]);
448+
end = result.tags["post"];
449449
} else {
450450
auto parser_delimiter = build_tagged_peg_parser([&](common_peg_parser_builder &p) {
451451
return p.literal(reasoning_content) + p.space() + p.optional(p.tag("post", (p.marker() + p.space())));
452452
});
453453
result = parser_delimiter.parse_anywhere_and_extract(comparison->output_B);
454454
if (result.result.success()) {
455-
end = trim_trailing_whitespace(result.tags["post"]);
455+
end = result.tags["post"];
456456
} else {
457457
LOG_DBG(ANSI_ORANGE "%s: Unable to extract reasoning markers, falling back to reasoning = NONE\n" ANSI_RESET, __func__);
458458
mode = reasoning_mode::NONE;

common/chat-peg-parser.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -816,6 +816,32 @@ common_peg_parser common_chat_peg_builder::prefix(const std::string & s, const s
816816
return literal(s.substr(0, s.rfind(delimiter)));
817817
}
818818

819+
common_peg_parser common_chat_peg_builder::optspace(const std::string & tag) {
820+
auto parser = eps();
821+
size_t end_of_prefix_space = tag.size();
822+
size_t start_of_suffix_space = tag.size();
823+
for (size_t i = 0; i < tag.size(); i++) {
824+
if (!std::isspace(tag[i])) {
825+
end_of_prefix_space = i;
826+
break;
827+
}
828+
}
829+
for (size_t i = tag.size(); i > 0; i--) {
830+
if (!std::isspace(tag[i - 1])) {
831+
start_of_suffix_space = i;
832+
break;
833+
}
834+
}
835+
for (size_t i = 0; i < end_of_prefix_space; i++) {
836+
parser += optional(literal(std::string(1, tag[i])));
837+
}
838+
parser += literal(tag.substr(end_of_prefix_space, start_of_suffix_space - end_of_prefix_space));
839+
for (size_t i = start_of_suffix_space; i < tag.size(); i++) {
840+
parser += optional(literal(std::string(1, tag[i])));
841+
}
842+
return parser;
843+
}
844+
819845
common_peg_parser common_chat_peg_builder::standard_json_tools(
820846
const std::string & section_start,
821847
const std::string & section_end,

common/chat-peg-parser.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,9 @@ class common_chat_peg_builder : public common_peg_parser_builder {
9696
// Return a parser that parses the prefix of a string, up to a given delimiter.
9797
common_peg_parser prefix(const std::string & s, const std::string & delimiter = {});
9898

99+
// Return a parser that parses all elements of tag, but leading and trailing spaces are optional
100+
common_peg_parser optspace(const std::string & tag);
101+
99102
// Legacy-compatible helper for building standard JSON tool calls
100103
// Used by tests and manual parsers
101104
// name_key/args_key: JSON key names for function name and arguments

common/chat.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2221,8 +2221,8 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
22212221
auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser);
22222222
auto_params.supports_thinking = autoparser.reasoning.mode != autoparser::reasoning_mode::NONE;
22232223
if (auto_params.supports_thinking) {
2224-
auto_params.thinking_start_tag = autoparser.reasoning.start;
2225-
auto_params.thinking_end_tag = autoparser.reasoning.end;
2224+
auto_params.thinking_start_tag = trim_whitespace(autoparser.reasoning.start);
2225+
auto_params.thinking_end_tag = trim_whitespace(autoparser.reasoning.end);
22262226
}
22272227
auto_params.generation_prompt = params.generation_prompt;
22282228
common_peg_arena arena;

common/reasoning-budget.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,8 @@ static void common_reasoning_budget_apply(struct llama_sampler * smpl, llama_tok
158158
for (size_t i = 0; i < cur_p->size; i++) {
159159
if (cur_p->data[i].id != forced) {
160160
cur_p->data[i].logit = -INFINITY;
161+
} else {
162+
cur_p->data[i].logit = +INFINITY; // force the token
161163
}
162164
}
163165
}

scripts/server-test-function-call.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -79,15 +79,18 @@ def print_info(msg):
7979
# ---------------------------------------------------------------------------
8080

8181

82-
def chat_completion(url, messages, tools=None, stream=False):
82+
def chat_completion(url, messages, tools=None, stream=False, force_tools=False):
8383
payload = {
8484
"messages": messages,
8585
"stream": stream,
8686
"max_tokens": 4096,
8787
}
8888
if tools:
8989
payload["tools"] = tools
90-
payload["tool_choice"] = "auto"
90+
if force_tools:
91+
payload["tool_choice"] = "required"
92+
else:
93+
payload["tool_choice"] = "auto"
9194

9295
try:
9396
response = requests.post(url, json=payload, stream=stream)
@@ -160,7 +163,13 @@ def chat_completion(url, messages, tools=None, stream=False):
160163
return result
161164

162165

163-
def run_agentic_loop(url, messages, tools, mock_tool_responses, stream, max_turns=6):
166+
def all_tools_called(tools, all_tool_calls):
167+
all_tool_names = set([tc["function"]["name"] for tc in tools])
168+
all_called_tool_names = set([tc["function"]["name"] for tc in all_tool_calls])
169+
return all_tool_names == all_called_tool_names
170+
171+
172+
def run_agentic_loop(url, messages, tools, mock_tool_responses, stream, max_turns=6, force_tools=False):
164173
"""
165174
Drive the multi-turn tool-call loop:
166175
1. Send messages to model.
@@ -172,8 +181,8 @@ def run_agentic_loop(url, messages, tools, mock_tool_responses, stream, max_turn
172181
msgs = list(messages)
173182
all_tool_calls: list[dict] = []
174183

175-
for _ in range(max_turns):
176-
result = chat_completion(url, msgs, tools=tools, stream=stream)
184+
for t in range(max_turns):
185+
result = chat_completion(url, msgs, tools=tools, stream=stream, force_tools=(force_tools and not all_tools_called(tools, all_tool_calls)))
177186
if result is None:
178187
return all_tool_calls, None
179188

@@ -235,17 +244,18 @@ def run_agentic_loop(url, messages, tools, mock_tool_responses, stream, max_turn
235244
# ---------------------------------------------------------------------------
236245

237246

238-
def run_test(url, test_case, stream):
247+
def run_test(url, test_case, stream, force_tools):
239248
name = test_case["name"]
240249
mode = f"{'stream' if stream else 'non-stream'}"
241-
print_header(f"{name} [{mode}]")
250+
print_header(f"{name} [{mode}, force_tools={force_tools}] ")
242251

243252
all_tool_calls, final_content = run_agentic_loop(
244253
url,
245254
messages=test_case["messages"],
246255
tools=test_case["tools"],
247256
mock_tool_responses=test_case["mock_tool_responses"],
248257
stream=stream,
258+
force_tools=force_tools
249259
)
250260

251261
if final_content is None and not all_tool_calls:
@@ -1093,6 +1103,9 @@ def main():
10931103
parser.add_argument(
10941104
"--stream-only", action="store_true", help="Only run streaming mode tests"
10951105
)
1106+
parser.add_argument(
1107+
"--force-tools", action="store_true", help="Change tool mode to forced instead of auto"
1108+
)
10961109
parser.add_argument(
10971110
"--test",
10981111
help="Run only the test whose name contains this substring (case-insensitive)",
@@ -1103,10 +1116,13 @@ def main():
11031116
print_info(f"Testing server at {url}")
11041117

11051118
modes = []
1119+
force_tools = False
11061120
if not args.stream_only:
11071121
modes.append(False)
11081122
if not args.no_stream:
11091123
modes.append(True)
1124+
if args.force_tools:
1125+
force_tools = True
11101126

11111127
cases: list[dict] = ALL_TEST_CASES
11121128
if args.test:
@@ -1121,7 +1137,7 @@ def main():
11211137
for stream in modes:
11221138
for case in cases:
11231139
total += 1
1124-
if run_test(url, case, stream=stream):
1140+
if run_test(url, case, stream=stream, force_tools=force_tools):
11251141
passed += 1
11261142

11271143
color = GREEN if passed == total else RED

0 commit comments

Comments
 (0)