Skip to content

Commit 1bbec6a

Browse files
authored
jinja : add capability check for object args (#20612)
1 parent f47a246 commit 1bbec6a

3 files changed

Lines changed: 114 additions & 13 deletions

File tree

common/chat.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1519,7 +1519,6 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
15191519
// map developer to system for all models except for GPT-OSS
15201520
workaround::map_developer_role_to_system(params.messages);
15211521
}
1522-
workaround::func_args_not_string(params.messages);
15231522

15241523
if (!tmpl.original_caps().supports_system_role) {
15251524
workaround::system_message_not_supported(params.messages);
@@ -1532,6 +1531,10 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
15321531
workaround::requires_non_null_content(params.messages);
15331532
}
15341533

1534+
if (tmpl.original_caps().supports_object_arguments) {
1535+
workaround::func_args_not_string(params.messages);
1536+
}
1537+
15351538
params.extra_context = common_chat_extra_context();
15361539
for (auto el : inputs.chat_template_kwargs) {
15371540
params.extra_context[el.first] = json::parse(el.second);

common/jinja/caps.cpp

Lines changed: 108 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ std::map<std::string, bool> caps::to_map() const {
7575
{"supports_parallel_tool_calls", supports_parallel_tool_calls},
7676
{"supports_system_role", supports_system_role},
7777
{"supports_preserve_reasoning", supports_preserve_reasoning},
78+
{"supports_object_arguments", supports_object_arguments},
7879
};
7980
}
8081

@@ -158,9 +159,9 @@ caps caps_get(jinja::program & prog) {
158159
}
159160
);
160161

161-
JJ_DEBUG("%s\n", ">>> Running capability check: single tool support");
162+
JJ_DEBUG("%s\n", ">>> Running capability check: single tool with object arguments support");
162163

163-
// case: tools support: single call
164+
// case: tools support: single call with object arguments
164165
caps_try_execute(
165166
prog,
166167
[&]() {
@@ -226,9 +227,7 @@ caps caps_get(jinja::program & prog) {
226227
},
227228
[&](bool success, value & messages, value & tools) {
228229
if (!success) {
229-
result.supports_tool_calls = false;
230-
result.supports_tools = false;
231-
return;
230+
return; // Nothing can be inferred
232231
}
233232

234233
auto & tool_name = tools->at(0)->at("function")->at("name");
@@ -242,16 +241,117 @@ caps caps_get(jinja::program & prog) {
242241
caps_print_stats(tool_calls, "messages[1].tool_calls");
243242
if (!tool_calls->stats.used) {
244243
result.supports_tool_calls = false;
244+
return;
245+
}
246+
247+
auto & tool_arg = tool_calls->at(0)->at("function")->at("arguments")->at("arg");
248+
caps_print_stats(tool_arg, "messages[1].tool_calls[0].function.arguments.arg");
249+
if (tool_arg->stats.used) {
250+
result.supports_object_arguments = true;
245251
}
246252
}
247253
);
248254

255+
if (!result.supports_object_arguments) {
256+
JJ_DEBUG("%s\n", ">>> Running capability check: single tool with string arguments support");
257+
258+
// case: tools support: single call with string arguments
259+
caps_try_execute(
260+
prog,
261+
[&]() {
262+
// messages
263+
return json::array({
264+
{
265+
{"role", "user"},
266+
{"content", "User message"},
267+
},
268+
{
269+
{"role", "assistant"},
270+
{"content", ""}, // Some templates expect content to be empty with tool calls
271+
{"tool_calls", json::array({
272+
{
273+
{"id", "call00001"},
274+
{"type", "function"},
275+
{"function", {
276+
{"name", "tool1"},
277+
{"arguments", R"({"arg": "value"})"}
278+
}}
279+
}
280+
})}
281+
},
282+
{
283+
{"role", "tool"},
284+
{"content", "Tool response"},
285+
{"tool_call_id", "call00001"}
286+
},
287+
{
288+
{"role", "assistant"},
289+
{"content", "The tool response was 'tool response'"}
290+
},
291+
{
292+
{"role", "user"},
293+
{"content", "User message"},
294+
},
295+
});
296+
},
297+
[&]() {
298+
// tools
299+
return json::array({
300+
{
301+
{"name", "tool"},
302+
{"type", "function"},
303+
{"function", {
304+
{"name", "tool1"},
305+
{"description", "Tool description"},
306+
{"parameters", {
307+
{"type", "object"},
308+
{"properties", {
309+
{"arg", {
310+
{"type", "string"},
311+
{"description", "Arg description"},
312+
}},
313+
}},
314+
{"required", json::array({ "arg" })},
315+
}},
316+
}},
317+
},
318+
});
319+
},
320+
[&](bool success, value & messages, value & tools) {
321+
if (!success) {
322+
result.supports_tool_calls = false;
323+
result.supports_tools = false;
324+
return;
325+
}
326+
327+
auto & tool_name = tools->at(0)->at("function")->at("name");
328+
caps_print_stats(tool_name, "tools[0].function.name");
329+
caps_print_stats(tools, "tools");
330+
if (!tool_name->stats.used) {
331+
result.supports_tools = false;
332+
}
333+
334+
auto & tool_calls = messages->at(1)->at("tool_calls");
335+
caps_print_stats(tool_calls, "messages[1].tool_calls");
336+
if (!tool_calls->stats.used) {
337+
result.supports_tool_calls = false;
338+
return;
339+
}
340+
}
341+
);
342+
}
343+
249344
JJ_DEBUG("%s\n", ">>> Running capability check: parallel tool support");
250345

251346
// case: tools support: parallel calls
252347
caps_try_execute(
253348
prog,
254349
[&]() {
350+
json args = json(R"({"arg": "value"})");
351+
if (result.supports_object_arguments) {
352+
args = json{{"arg", "value"}};
353+
}
354+
255355
// messages
256356
return json::array({
257357
{
@@ -267,19 +367,15 @@ caps caps_get(jinja::program & prog) {
267367
{"type", "function"},
268368
{"function", {
269369
{"name", "tool1"},
270-
{"arguments", {
271-
{"arg", "value"}
272-
}}
370+
{"arguments", args}
273371
}}
274372
},
275373
{
276374
{"id", "call00002"},
277375
{"type", "function"},
278376
{"function", {
279377
{"name", "tool1"},
280-
{"arguments", {
281-
{"arg", "value"}
282-
}}
378+
{"arguments", args}
283379
}}
284380
}
285381
})}
@@ -328,7 +424,7 @@ caps caps_get(jinja::program & prog) {
328424
return;
329425
}
330426

331-
auto & tool_calls = messages->at(1)->at("tool_calls");;
427+
auto & tool_calls = messages->at(1)->at("tool_calls");
332428
caps_print_stats(tool_calls, "messages[1].tool_calls");
333429

334430
// check for second tool call usage

common/jinja/caps.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ struct caps {
1818
bool supports_string_content = true;
1919
bool supports_typed_content = false;
2020

21+
bool supports_object_arguments = false;
22+
2123
// for reporting on server
2224
std::map<std::string, bool> to_map() const;
2325

0 commit comments

Comments
 (0)