Skip to content

Commit cc25c9c

Browse files
committed
jinja: add --dump-prog for debugging
1 parent beac530 commit cc25c9c

3 files changed

Lines changed: 186 additions & 3 deletions

File tree

common/jinja/runtime.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -954,4 +954,46 @@ value keyword_argument_expression::execute_impl(context & ctx) {
954954
return mk_val<value_kwarg>(k, v);
955955
}
956956

957+
std::string runtime::debug_dump_program(const program & prog, const std::string & src) {
958+
std::ostringstream oss;
959+
size_t lvl = 0;
960+
context ctx;
961+
ctx.src.reset(new std::string(src));
962+
963+
auto indent = [](size_t lvl) -> std::string {
964+
return std::string(lvl * 2, ' ');
965+
};
966+
967+
ctx.visitor = [&](bool is_leaf, statement * node, std::vector<visitor_pair> children) {
968+
oss << indent(lvl) << node->type() << ":\n";
969+
lvl++;
970+
if (is_leaf) {
971+
const auto & pos = node->pos;
972+
oss << indent(lvl) << "(leaf) at " << get_line_col(src, pos) << " in source:\n";
973+
std::string snippet = peak_source(src, pos);
974+
string_replace_all(snippet, "\n", "\n" + indent(lvl));
975+
oss << indent(lvl) << snippet << "\n";
976+
} else {
977+
for (auto & [label, children_vec] : children) {
978+
oss << indent(lvl) << label << ":\n";
979+
lvl++;
980+
for (auto * child : children_vec) {
981+
if (!child) {
982+
continue;
983+
}
984+
child->visit(ctx);
985+
}
986+
lvl--;
987+
}
988+
}
989+
lvl--;
990+
};
991+
992+
for (const auto & stmt : prog.body) {
993+
stmt->visit(ctx);
994+
}
995+
996+
return oss.str();
997+
}
998+
957999
} // namespace jinja

common/jinja/runtime.h

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,19 @@ const T * cast_stmt(const statement_ptr & ptr) {
4747
// not thread-safe
4848
void enable_debug(bool enable);
4949

50+
// for visiting AST nodes
51+
// function signature: void(bool is_leaf, statement * node, pair of <label, children>)
52+
using visitor_pair = std::pair<std::string, std::vector<statement *>>;
53+
using visitor_fn = std::function<void(bool, statement *, std::vector<visitor_pair>)>;
54+
5055
struct context {
5156
std::shared_ptr<std::string> src; // for debugging; use shared_ptr to avoid copying on scope creation
5257
std::time_t current_time; // for functions that need current time
5358

5459
bool is_get_stats = false; // whether to collect stats
5560

61+
visitor_fn visitor;
62+
5663
// src is optional, used for error reporting
5764
context(std::string src = "") : src(std::make_shared<std::string>(std::move(src))) {
5865
env = mk_val<value_object>();
@@ -99,13 +106,23 @@ struct context {
99106
value_object env;
100107
};
101108

109+
// utils for visiting AST nodes
110+
static std::vector<statement *> stmts_to_ptr(const statements & stmts) {
111+
std::vector<statement *> children;
112+
for (const auto & stmt : stmts) {
113+
children.push_back(stmt.get());
114+
}
115+
return children;
116+
}
117+
102118
/**
103119
* Base class for all nodes in the AST.
104120
*/
105121
struct statement {
106122
size_t pos; // position in source, for debugging
107123
virtual ~statement() = default;
108124
virtual std::string type() const { return "Statement"; }
125+
virtual void visit(context & ctx) { ctx.visitor(true, this, {}); }
109126

110127
// execute_impl must be overridden by derived classes
111128
virtual value execute_impl(context &) { throw_exec_error(); }
@@ -166,6 +183,13 @@ struct if_statement : public statement {
166183

167184
std::string type() const override { return "If"; }
168185
value execute_impl(context & ctx) override;
186+
void visit(context & ctx) override {
187+
ctx.visitor(false, this, {
188+
{"test", {test.get()}},
189+
{"body", stmts_to_ptr(body)},
190+
{"alternate", stmts_to_ptr(alternate)}
191+
});
192+
}
169193
};
170194

171195
struct identifier;
@@ -190,6 +214,14 @@ struct for_statement : public statement {
190214

191215
std::string type() const override { return "For"; }
192216
value execute_impl(context & ctx) override;
217+
void visit(context & ctx) override {
218+
ctx.visitor(false, this, {
219+
{"loopvar", {loopvar.get()}},
220+
{"iterable", {iterable.get()}},
221+
{"body", stmts_to_ptr(body)},
222+
{"default_block", stmts_to_ptr(default_block)}
223+
});
224+
}
193225
};
194226

195227
struct break_statement : public statement {
@@ -241,6 +273,13 @@ struct set_statement : public statement {
241273

242274
std::string type() const override { return "Set"; }
243275
value execute_impl(context & ctx) override;
276+
void visit(context & ctx) override {
277+
ctx.visitor(false, this, {
278+
{"assignee", {assignee.get()}},
279+
{"value", {val.get()}},
280+
{"body", stmts_to_ptr(body)}
281+
});
282+
}
244283
};
245284

246285
struct macro_statement : public statement {
@@ -256,6 +295,13 @@ struct macro_statement : public statement {
256295

257296
std::string type() const override { return "Macro"; }
258297
value execute_impl(context & ctx) override;
298+
void visit(context & ctx) override {
299+
ctx.visitor(false, this, {
300+
{"name", {name.get()}},
301+
{"args", stmts_to_ptr(args)},
302+
{"body", stmts_to_ptr(body)}
303+
});
304+
}
259305
};
260306

261307
struct comment_statement : public statement {
@@ -289,6 +335,12 @@ struct member_expression : public expression {
289335
}
290336
std::string type() const override { return "MemberExpression"; }
291337
value execute_impl(context & ctx) override;
338+
void visit(context & ctx) override {
339+
ctx.visitor(false, this, {
340+
{"object", {object.get()}},
341+
{"property", {property.get()}}
342+
});
343+
}
292344
};
293345

294346
struct call_expression : public expression {
@@ -302,6 +354,12 @@ struct call_expression : public expression {
302354
}
303355
std::string type() const override { return "CallExpression"; }
304356
value execute_impl(context & ctx) override;
357+
void visit(context & ctx) override {
358+
ctx.visitor(false, this, {
359+
{"callee", {callee.get()}},
360+
{"args", stmts_to_ptr(args)}
361+
});
362+
}
305363
};
306364

307365
/**
@@ -405,6 +463,12 @@ struct binary_expression : public expression {
405463
}
406464
std::string type() const override { return "BinaryExpression"; }
407465
value execute_impl(context & ctx) override;
466+
void visit(context & ctx) override {
467+
ctx.visitor(false, this, {
468+
{"left", {left.get()}},
469+
{"right", {right.get()}}
470+
});
471+
}
408472
};
409473

410474
/**
@@ -431,6 +495,12 @@ struct filter_expression : public expression {
431495

432496
std::string type() const override { return "FilterExpression"; }
433497
value execute_impl(context & ctx) override;
498+
void visit(context & ctx) override {
499+
ctx.visitor(false, this, {
500+
{"operand", {operand.get()}},
501+
{"filter", {filter.get()}}
502+
});
503+
}
434504
};
435505

436506
struct filter_statement : public statement {
@@ -443,6 +513,12 @@ struct filter_statement : public statement {
443513
}
444514
std::string type() const override { return "FilterStatement"; }
445515
value execute_impl(context & ctx) override;
516+
void visit(context & ctx) override {
517+
ctx.visitor(false, this, {
518+
{"filter", {filter.get()}},
519+
{"body", stmts_to_ptr(body)}
520+
});
521+
}
446522
};
447523

448524
/**
@@ -468,6 +544,12 @@ struct select_expression : public expression {
468544
}
469545
return lhs->execute_impl(ctx);
470546
}
547+
void visit(context & ctx) override {
548+
ctx.visitor(false, this, {
549+
{"lhs", {lhs.get()}},
550+
{"test", {test.get()}}
551+
});
552+
}
471553
};
472554

473555
/**
@@ -486,6 +568,12 @@ struct test_expression : public expression {
486568
}
487569
std::string type() const override { return "TestExpression"; }
488570
value execute_impl(context & ctx) override;
571+
void visit(context & ctx) override {
572+
ctx.visitor(false, this, {
573+
{"operand", {operand.get()}},
574+
{"test", {test.get()}}
575+
});
576+
}
489577
};
490578

491579
/**
@@ -501,6 +589,11 @@ struct unary_expression : public expression {
501589
}
502590
std::string type() const override { return "UnaryExpression"; }
503591
value execute_impl(context & ctx) override;
592+
void visit(context & ctx) override {
593+
ctx.visitor(false, this, {
594+
{"argument", {argument.get()}}
595+
});
596+
}
504597
};
505598

506599
struct slice_expression : public expression {
@@ -518,6 +611,13 @@ struct slice_expression : public expression {
518611
[[noreturn]] value execute_impl(context &) override {
519612
throw std::runtime_error("must be handled by MemberExpression");
520613
}
614+
void visit(context & ctx) override {
615+
ctx.visitor(false, this, {
616+
{"start_expr", {start_expr.get()}},
617+
{"stop_expr", {stop_expr.get()}},
618+
{"step_expr", {step_expr.get()}}
619+
});
620+
}
521621
};
522622

523623
struct keyword_argument_expression : public expression {
@@ -531,6 +631,12 @@ struct keyword_argument_expression : public expression {
531631
}
532632
std::string type() const override { return "KeywordArgumentExpression"; }
533633
value execute_impl(context & ctx) override;
634+
void visit(context & ctx) override {
635+
ctx.visitor(false, this, {
636+
{"key", {key.get()}},
637+
{"val", {val.get()}}
638+
});
639+
}
534640
};
535641

536642
struct spread_expression : public expression {
@@ -539,6 +645,11 @@ struct spread_expression : public expression {
539645
chk_type<expression>(this->argument);
540646
}
541647
std::string type() const override { return "SpreadExpression"; }
648+
void visit(context & ctx) override {
649+
ctx.visitor(false, this, {
650+
{"argument", {argument.get()}}
651+
});
652+
}
542653
};
543654

544655
struct call_statement : public statement {
@@ -553,6 +664,13 @@ struct call_statement : public statement {
553664
}
554665
std::string type() const override { return "CallStatement"; }
555666
value execute_impl(context & ctx) override;
667+
void visit(context & ctx) override {
668+
ctx.visitor(false, this, {
669+
{"call", {call.get()}},
670+
{"caller_args", stmts_to_ptr(caller_args)},
671+
{"body", stmts_to_ptr(body)}
672+
});
673+
}
556674
};
557675

558676
struct ternary_expression : public expression {
@@ -575,6 +693,13 @@ struct ternary_expression : public expression {
575693
return false_expr->execute(ctx);
576694
}
577695
}
696+
void visit(context & ctx) override {
697+
ctx.visitor(false, this, {
698+
{"condition", {condition.get()}},
699+
{"true_expr", {true_expr.get()}},
700+
{"false_expr", {false_expr.get()}}
701+
});
702+
}
578703
};
579704

580705
struct raised_exception : public std::exception {
@@ -648,6 +773,8 @@ struct runtime {
648773
}
649774
return parts;
650775
}
776+
777+
static std::string debug_dump_program(const program & prog, const std::string & src);
651778
};
652779

653780
} // namespace jinja

tests/test-chat-template.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ using json = nlohmann::ordered_json;
2525
static int main_automated_tests(void);
2626

2727
static void run_multiple(const std::string& dir_path, bool stop_on_first_failure, const json& input, bool use_common = false);
28-
static void run_single(const std::string& contents, json input, bool use_common = false, const std::string & output_path = "");
28+
static void run_single(const std::string& contents, json input, bool use_common = false, bool dump_prog = false, const std::string & output_path = "");
2929

3030
static std::string HELP = R"(
3131
Usage: test-chat-template [OPTIONS] PATH_TO_TEMPLATE
@@ -35,6 +35,7 @@ Usage: test-chat-template [OPTIONS] PATH_TO_TEMPLATE
3535
--json <path> Path to the JSON input file.
3636
--stop-on-first-fail Stop testing on the first failure (default: false).
3737
--no-common Use direct Jinja engine instead of common chat templates (default: use common).
38+
--dump-prog Dump the parsed program for debugging (only for single template runs).
3839
--output <path> Path to output results (only for single template runs).
3940
If PATH_TO_TEMPLATE is a file, runs that single template.
4041
If PATH_TO_TEMPLATE is a directory, runs all .jinja files in that directory.
@@ -118,6 +119,7 @@ int main(int argc, char ** argv) {
118119
std::string & json_to_use = DEFAULT_JSON;
119120
bool stop_on_first_fail = false;
120121
bool use_common = true;
122+
bool dump_prog = false;
121123

122124
for (size_t i = 1; i < args.size(); i++) {
123125
if (args[i] == "--help" || args[i] == "-h") {
@@ -136,6 +138,8 @@ int main(int argc, char ** argv) {
136138
i++;
137139
} else if (args[i] == "--no-common") {
138140
use_common = true;
141+
} else if (args[i] == "--dump-prog") {
142+
dump_prog = true;
139143
} else if (tmpl_path.empty()) {
140144
tmpl_path = args[i];
141145
} else {
@@ -172,7 +176,7 @@ int main(int argc, char ** argv) {
172176
std::string contents = std::string(
173177
std::istreambuf_iterator<char>(infile),
174178
std::istreambuf_iterator<char>());
175-
run_single(contents, input_json, use_common, output_path);
179+
run_single(contents, input_json, use_common, dump_prog, output_path);
176180
} else {
177181
std::cerr << "Error: PATH_TO_TEMPLATE is not a valid file or directory: " << tmpl_path << "\n";
178182
return 1;
@@ -276,11 +280,21 @@ static jinja::value_string format_using_direct_engine(
276280
}
277281

278282

279-
void run_single(const std::string& contents, json input, bool use_common, const std::string & output_path) {
283+
void run_single(const std::string& contents, json input, bool use_common, bool dump_prog, const std::string & output_path) {
280284
jinja::enable_debug(true);
281285

282286
jinja::value_string output_parts;
283287

288+
if (dump_prog) {
289+
jinja::lexer lexer;
290+
auto lexer_res = lexer.tokenize(contents);
291+
jinja::program ast = jinja::parse_from_tokens(lexer_res);
292+
std::string prog_dump = jinja::runtime::debug_dump_program(ast, contents);
293+
std::cout << "\n=== DUMPED PROGRAM ===\n";
294+
std::cout << prog_dump << "\n";
295+
return;
296+
}
297+
284298
if (use_common) {
285299
std::string bos_token = "<s>";
286300
std::string eos_token = "</s>";

0 commit comments

Comments
 (0)