|
2 | 2 |
|
3 | 3 | #include "chat_template.h" |
4 | 4 |
|
| 5 | +#include "jinja/lexer.h" |
| 6 | +#include "jinja/parser.h" |
| 7 | +#include "jinja/runtime.h" |
| 8 | +#include "jinja/value.h" |
| 9 | + |
| 10 | +#include <nlohmann/json.hpp> |
| 11 | + |
| 12 | +#include <memory> |
| 13 | +#include <stdexcept> |
| 14 | + |
5 | 15 | namespace dflash::common { |
6 | 16 |
|
7 | 17 | // Qwen3.5 tool preamble — matches the official Jinja template exactly. |
@@ -155,4 +165,103 @@ std::string render_chat_template( |
155 | 165 | return result; |
156 | 166 | } |
157 | 167 |
|
| 168 | +// ─── Jinja path ───────────────────────────────────────────────────────── |
| 169 | +// |
| 170 | +// Render via a Jinja chat template (e.g. froggeric Qwen3.6 template). Each |
| 171 | +// thread caches the most-recently-parsed program for its template source, |
| 172 | +// so steady-state cost is just the runtime execute (parse happens once per |
| 173 | +// process per template). |
| 174 | + |
| 175 | +namespace { |
| 176 | + |
| 177 | +struct JinjaCache { |
| 178 | + std::string src; |
| 179 | + std::shared_ptr<jinja::program> prog; |
| 180 | +}; |
| 181 | + |
| 182 | +static thread_local JinjaCache tls_jinja_cache; |
| 183 | + |
| 184 | +static std::shared_ptr<jinja::program> get_or_parse(const std::string & template_src) { |
| 185 | + if (tls_jinja_cache.prog && tls_jinja_cache.src == template_src) { |
| 186 | + return tls_jinja_cache.prog; |
| 187 | + } |
| 188 | + jinja::lexer lex; |
| 189 | + jinja::lexer_result lex_res; |
| 190 | + try { |
| 191 | + lex_res = lex.tokenize(template_src); |
| 192 | + } catch (const std::exception & e) { |
| 193 | + throw std::runtime_error(std::string("jinja lexer: ") + e.what()); |
| 194 | + } |
| 195 | + auto prog = std::make_shared<jinja::program>(jinja::parse_from_tokens(lex_res)); |
| 196 | + tls_jinja_cache.src = template_src; |
| 197 | + tls_jinja_cache.prog = prog; |
| 198 | + return prog; |
| 199 | +} |
| 200 | + |
| 201 | +} // namespace |
| 202 | + |
| 203 | +std::string render_chat_template_jinja( |
| 204 | + const std::string & template_src, |
| 205 | + const std::vector<ChatMessage> & messages, |
| 206 | + const std::string & bos_token, |
| 207 | + const std::string & eos_token, |
| 208 | + bool add_generation_prompt, |
| 209 | + bool enable_thinking, |
| 210 | + const std::string & tools_json) |
| 211 | +{ |
| 212 | + if (template_src.empty()) { |
| 213 | + throw std::runtime_error("render_chat_template_jinja: template_src is empty"); |
| 214 | + } |
| 215 | + |
| 216 | + auto prog = get_or_parse(template_src); |
| 217 | + |
| 218 | + // Build the JSON input that mirrors llama.cpp's |
| 219 | + // common_chat_template_direct_apply_impl. Field names must match the |
| 220 | + // names the Jinja templates expect (messages, tools, bos_token, |
| 221 | + // eos_token, add_generation_prompt, enable_thinking). |
| 222 | + nlohmann::ordered_json messages_j = nlohmann::ordered_json::array(); |
| 223 | + for (const auto & m : messages) { |
| 224 | + nlohmann::ordered_json mj; |
| 225 | + mj["role"] = m.role; |
| 226 | + mj["content"] = m.content; |
| 227 | + if (!m.tool_call_id.empty()) { |
| 228 | + mj["tool_call_id"] = m.tool_call_id; |
| 229 | + } |
| 230 | + messages_j.push_back(std::move(mj)); |
| 231 | + } |
| 232 | + |
| 233 | + nlohmann::ordered_json inputs; |
| 234 | + inputs["messages"] = std::move(messages_j); |
| 235 | + inputs["bos_token"] = bos_token; |
| 236 | + inputs["eos_token"] = eos_token; |
| 237 | + inputs["add_generation_prompt"] = add_generation_prompt; |
| 238 | + inputs["enable_thinking"] = enable_thinking; |
| 239 | + |
| 240 | + bool has_tools = !tools_json.empty() && tools_json != "[]" && tools_json != "null"; |
| 241 | + if (has_tools) { |
| 242 | + try { |
| 243 | + inputs["tools"] = nlohmann::ordered_json::parse(tools_json); |
| 244 | + } catch (const std::exception & e) { |
| 245 | + throw std::runtime_error( |
| 246 | + std::string("render_chat_template_jinja: failed to parse tools JSON: ") + e.what()); |
| 247 | + } |
| 248 | + } |
| 249 | + |
| 250 | + jinja::context ctx(template_src); |
| 251 | + try { |
| 252 | + jinja::global_from_json(ctx, inputs, /*mark_input=*/false); |
| 253 | + } catch (const std::exception & e) { |
| 254 | + throw std::runtime_error(std::string("jinja global_from_json: ") + e.what()); |
| 255 | + } |
| 256 | + |
| 257 | + try { |
| 258 | + jinja::runtime rt(ctx); |
| 259 | + jinja::value results = rt.execute(*prog); |
| 260 | + auto parts = jinja::runtime::gather_string_parts(results); |
| 261 | + return parts->as_string().str(); |
| 262 | + } catch (const std::exception & e) { |
| 263 | + throw std::runtime_error(std::string("jinja runtime: ") + e.what()); |
| 264 | + } |
| 265 | +} |
| 266 | + |
158 | 267 | } // namespace dflash::common |
0 commit comments