Skip to content

Commit b3fed31

Browse files
authored
jinja, chat: add --reasoning-preserve flag (ggml-org#25105)
* jinja, chat: add --reasoning-preserve flag * correct help message
1 parent dbdaece commit b3fed31

5 files changed

Lines changed: 80 additions & 24 deletions

File tree

common/arg.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3296,6 +3296,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
32963296
params.sampling.reasoning_budget_message = value;
32973297
}
32983298
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET_MESSAGE"));
3299+
add_opt(common_arg(
3300+
{"--reasoning-preserve"},
3301+
{"--no-reasoning-preserve"},
3302+
"preserve reasoning trace in the full history, not just the last assistant message (default: template default)\n"
3303+
"compatible with certain templates having 'supports_preserve_reasoning' capability\n"
3304+
"example: https://docs.z.ai/guides/capabilities/thinking-mode#preserved-thinking",
3305+
[](common_params & params, bool value) {
3306+
if (value) {
3307+
params.default_template_kwargs["preserve_reasoning"] = "true";
3308+
} else {
3309+
params.default_template_kwargs["preserve_reasoning"] = "false";
3310+
}
3311+
}
3312+
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_REASONING_PRESERVE"));
32993313
add_opt(common_arg(
33003314
{"--chat-template"}, "JINJA_TEMPLATE",
33013315
string_format(

common/chat.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -912,6 +912,10 @@ static std::string common_chat_template_direct_apply_impl(
912912
if (inputs.add_generation_prompt) {
913913
inp["add_generation_prompt"] = true;
914914
}
915+
if (inp.contains("preserve_reasoning") && inp["preserve_reasoning"].is_boolean()) {
916+
bool enabled = inp["preserve_reasoning"].get<bool>();
917+
jinja::caps_apply_preserve_reasoning(ctx, enabled);
918+
}
915919

916920
jinja::global_from_json(ctx, inp, inputs.mark_input);
917921

common/jinja/caps.cpp

Lines changed: 44 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,34 @@ using json = nlohmann::ordered_json;
1616
namespace jinja {
1717

1818
using caps_json_fn = std::function<json()>;
19-
using caps_analyze_fn = std::function<void(bool, value &, value &)>;
19+
using caps_ctx_fn = std::function<void(context &)>;
20+
using caps_analyze_fn = std::function<void(bool, value &, value &, const std::string &)>;
21+
22+
void caps_apply_preserve_reasoning(jinja::context & ctx, bool enabled) {
23+
ctx.set_val("preserve_thinking", mk_val<value_bool>(enabled));
24+
ctx.set_val("clear_thinking", mk_val<value_bool>(!enabled));
25+
ctx.set_val("truncate_history_thinking", mk_val<value_bool>(!enabled));
26+
}
2027

2128
static void caps_try_execute(jinja::program & prog,
2229
const caps_json_fn & messages_fn,
30+
const caps_ctx_fn & ctx_fn,
2331
const caps_json_fn & tools_fn,
2432
const caps_analyze_fn & analyze_fn) {
2533
context ctx;
2634
ctx.is_get_stats = true;
2735
jinja::global_from_json(ctx, json{
2836
{"messages", messages_fn()},
29-
{"tools", tools_fn()},
37+
{"tools", tools_fn ? tools_fn() : json::array()},
3038
{"bos_token", ""},
3139
{"eos_token", ""},
3240
{"add_generation_prompt", true}
3341
}, true);
3442

43+
if (ctx_fn) {
44+
ctx_fn(ctx);
45+
}
46+
3547
auto messages = ctx.get_val("messages");
3648
auto tools = ctx.get_val("tools");
3749

@@ -49,7 +61,7 @@ static void caps_try_execute(jinja::program & prog,
4961
// ignore exceptions during capability analysis
5062
}
5163

52-
analyze_fn(success, messages, tools);
64+
analyze_fn(success, messages, tools, result);
5365
}
5466

5567
// for debugging only
@@ -109,11 +121,9 @@ caps caps_get(jinja::program & prog) {
109121
}
110122
});
111123
},
112-
[&]() {
113-
// tools
114-
return json{nullptr};
115-
},
116-
[&](bool success, value & messages, value &) {
124+
nullptr, // ctx_fn
125+
nullptr, // tools_fn
126+
[&](bool success, value & messages, value &, const std::string &) {
117127
auto & content = messages->at(0)->at("content");
118128
caps_print_stats(content, "messages[0].content");
119129
if (has_op(content, "selectattr") || has_op(content, "array_access")) {
@@ -145,11 +155,9 @@ caps caps_get(jinja::program & prog) {
145155
},
146156
});
147157
},
148-
[&]() {
149-
// tools
150-
return json::array();
151-
},
152-
[&](bool, value & messages, value &) {
158+
nullptr, // ctx_fn
159+
nullptr, // tools_fn
160+
[&](bool, value & messages, value &, const std::string &) {
153161
auto & content = messages->at(0)->at("content");
154162
caps_print_stats(content, "messages[0].content");
155163
if (!content->stats.used) {
@@ -201,6 +209,7 @@ caps caps_get(jinja::program & prog) {
201209
},
202210
});
203211
},
212+
nullptr, // ctx_fn
204213
[&]() {
205214
// tools
206215
return json::array({
@@ -224,7 +233,7 @@ caps caps_get(jinja::program & prog) {
224233
},
225234
});
226235
},
227-
[&](bool success, value & messages, value & tools) {
236+
[&](bool success, value & messages, value & tools, const std::string &) {
228237
if (!success) {
229238
return; // Nothing can be inferred
230239
}
@@ -293,6 +302,7 @@ caps caps_get(jinja::program & prog) {
293302
},
294303
});
295304
},
305+
nullptr, // ctx_fn
296306
[&]() {
297307
// tools
298308
return json::array({
@@ -316,7 +326,7 @@ caps caps_get(jinja::program & prog) {
316326
},
317327
});
318328
},
319-
[&](bool success, value & messages, value & tools) {
329+
[&](bool success, value & messages, value & tools, const std::string &) {
320330
if (!success) {
321331
result.supports_tool_calls = false;
322332
result.supports_tools = false;
@@ -394,6 +404,7 @@ caps caps_get(jinja::program & prog) {
394404
},
395405
});
396406
},
407+
nullptr, // ctx_fn
397408
[&]() {
398409
// tools
399410
return json::array({
@@ -417,7 +428,7 @@ caps caps_get(jinja::program & prog) {
417428
},
418429
});
419430
},
420-
[&](bool success, value & messages, value & /*tools*/) {
431+
[&](bool success, value & messages, value &, const std::string &) {
421432
if (!success) {
422433
result.supports_parallel_tool_calls = false;
423434
return;
@@ -438,11 +449,22 @@ caps caps_get(jinja::program & prog) {
438449
JJ_DEBUG("%s\n", ">>> Running capability check: preserve reasoning");
439450

440451
// case: preserve reasoning content in chat history
452+
const std::string reasoning_placeholder = "<REASONING_CONTENT_PLACEHOLDER>";
441453
caps_try_execute(
442454
prog,
443455
[&]() {
444456
// messages
445457
return json::array({
458+
{
459+
{"role", "user"},
460+
{"content", "User message"}
461+
},
462+
{
463+
{"role", "assistant"},
464+
{"content", "Assistant message"},
465+
// check of reasoning_content deeper in the history, not just the last assistant message
466+
{"reasoning_content", reasoning_placeholder}
467+
},
446468
{
447469
{"role", "user"},
448470
{"content", "User message"}
@@ -458,14 +480,13 @@ caps caps_get(jinja::program & prog) {
458480
},
459481
});
460482
},
461-
[&]() {
462-
// tools
463-
return json::array();
483+
[&](context & ctx) {
484+
caps_apply_preserve_reasoning(ctx, true);
464485
},
465-
[&](bool, value & messages, value &) {
466-
auto & content = messages->at(1)->at("reasoning_content");
467-
caps_print_stats(content, "messages[1].reasoning_content");
468-
if (content->stats.used) {
486+
nullptr, // tools_fn
487+
[&](bool, value &, value &, const std::string & output) {
488+
// note: we cannot use stats here because the reasoning_content may be used for "if" condition test, but not actually outputted in the final result
489+
if (output.find(reasoning_placeholder) != std::string::npos) {
469490
result.supports_preserve_reasoning = true;
470491
}
471492
}

common/jinja/caps.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ struct caps {
1212
bool supports_tool_calls = true;
1313
bool supports_system_role = true;
1414
bool supports_parallel_tool_calls = true;
15-
bool supports_preserve_reasoning = false; // support assistant message with reasoning_content
15+
16+
// supports preserve reasoning trace in the full history, not just the last assistant message
17+
bool supports_preserve_reasoning = false;
1618

1719
// one of the 2 content capabilities must be true
1820
bool supports_string_content = true;
@@ -29,4 +31,6 @@ struct caps {
2931

3032
caps caps_get(jinja::program & prog);
3133

34+
void caps_apply_preserve_reasoning(jinja::context & ctx, bool enabled);
35+
3236
} // namespace jinja

tools/server/server-context.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1538,6 +1538,19 @@ struct server_context_impl {
15381538
/* media_path */ params_base.media_path,
15391539
/* force_pure_content */ params_base.force_pure_content_parser
15401540
};
1541+
1542+
{
1543+
auto caps = common_chat_templates_get_caps(chat_params.tmpls.get());
1544+
auto it = params_base.default_template_kwargs.find("preserve_reasoning");
1545+
bool supported = caps.at("supports_preserve_reasoning");
1546+
bool enabled = it != params_base.default_template_kwargs.end();
1547+
if (supported && !enabled) {
1548+
SRV_INF("%s", "chat template supports preserving reasoning, consider enabling it via --reasoning-preserve\n");
1549+
}
1550+
if (!supported && enabled) {
1551+
SRV_WRN("%s", "chat template does NOT support preserving reasoning, --reasoning-preserve has no effect\n");
1552+
}
1553+
}
15411554
}
15421555

15431556
return true;

0 commit comments

Comments
 (0)