Skip to content

Commit 39cf5d6

Browse files
authored
common : delegate assistant continuation to underlying template handlers (#23089)
* common : delegate assistant continuation to template handler
* server : implement echo parameter to exclude assistant prefill in the response
* server : fix tests for prefill
* server : use existing llama template
* cont : clean up
1 parent a6d6183 commit 39cf5d6

10 files changed

Lines changed: 1110 additions & 189 deletions

File tree

common/chat-auto-parser-generator.cpp

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,33 @@ common_chat_params peg_generator::generate_parser(const common_chat_template &
4343
const autoparser & autoparser) {
4444
// Create the result structure
4545
common_chat_params data;
46-
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
47-
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
48-
data.preserved_tokens = autoparser.preserved_tokens;
46+
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
47+
data.generation_prompt = common_chat_template_generation_prompt(tmpl, inputs);
48+
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
49+
data.preserved_tokens = autoparser.preserved_tokens;
50+
51+
std::string parser_generation_prompt = data.generation_prompt;
52+
53+
if (inputs.continue_final_message != COMMON_CHAT_CONTINUATION_NONE && !inputs.continue_msg.empty()) {
54+
// Build up generation prompt manually
55+
const auto & msg = inputs.continue_msg;
56+
57+
if (!autoparser.reasoning.start.empty()) {
58+
data.generation_prompt = data.generation_prompt.substr(0, data.generation_prompt.find(autoparser.reasoning.start));
59+
data.generation_prompt += autoparser.reasoning.start + msg.reasoning_content;
60+
if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) {
61+
data.generation_prompt += autoparser.reasoning.end;
62+
}
63+
}
64+
65+
if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) {
66+
data.generation_prompt += msg.render_content();
67+
}
68+
69+
data.prompt += data.generation_prompt;
70+
}
4971

50-
auto parser = autoparser.build_parser(inputs);
72+
auto parser = autoparser.build_parser(inputs, parser_generation_prompt);
5173
data.parser = parser.save();
5274

5375
// Build grammar if tools are present
@@ -87,7 +109,7 @@ common_chat_params peg_generator::generate_parser(const common_chat_template &
87109
return data;
88110
}
89111

90-
common_peg_arena autoparser::build_parser(const generation_params & inputs) const {
112+
common_peg_arena autoparser::build_parser(const generation_params & inputs, const std::string & generation_prompt) const {
91113
if (!analysis_complete) {
92114
throw std::invalid_argument("Cannot call build_parser on autoparser without performing analysis first, call analyze_template(...)");
93115
}
@@ -121,7 +143,7 @@ common_peg_arena autoparser::build_parser(const generation_params & inputs) cons
121143
} else {
122144
parser = content.build_parser(ctx);
123145
}
124-
return pure_content ? p.prefix(inputs.generation_prompt, reasoning.start) + parser : p.prefix(inputs.generation_prompt, reasoning.start) << parser;
146+
return pure_content ? p.prefix(generation_prompt, reasoning.start) + parser : p.prefix(generation_prompt, reasoning.start) << parser;
125147
});
126148
}
127149

common/chat-auto-parser.h

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,16 +60,21 @@ struct generation_params {
6060
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_AUTO;
6161
bool stream = true;
6262
std::string grammar;
63-
bool add_generation_prompt = false;
64-
bool enable_thinking = true;
65-
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
66-
std::string generation_prompt;
63+
bool add_generation_prompt = false;
64+
common_chat_continuation continue_final_message = COMMON_CHAT_CONTINUATION_NONE;
65+
common_chat_msg continue_msg;
66+
bool enable_thinking = true;
67+
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
6768
json extra_context;
6869
bool add_bos = false;
6970
bool add_eos = false;
7071
bool is_inference = true;
7172
bool add_inference = false;
7273
bool mark_input = true; // whether to mark input strings in the jinja context
74+
75+
bool has_continuation() const {
76+
return continue_final_message != COMMON_CHAT_CONTINUATION_NONE && !continue_msg.empty();
77+
}
7378
};
7479

7580
// ============================================================================
@@ -386,7 +391,7 @@ struct autoparser {
386391
void analyze_template(const common_chat_template & tmpl);
387392

388393
// Build the PEG parser for this template
389-
common_peg_arena build_parser(const generation_params & inputs) const;
394+
common_peg_arena build_parser(const generation_params & inputs, const std::string & generation_prompt) const;
390395

391396
private:
392397
// Collect tokens from entire analysis to preserve

common/chat-peg-parser.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -785,7 +785,7 @@ common_peg_parser common_chat_peg_builder::prefix(const std::string & s, const s
785785
if (delimiter.empty()) {
786786
return literal(s);
787787
}
788-
return literal(s.substr(0, s.rfind(delimiter)));
788+
return literal(s.substr(0, s.find(delimiter)));
789789
}
790790

791791
common_peg_parser common_chat_peg_builder::optspace(const std::string & tag) {

0 commit comments

Comments
 (0)