Skip to content

Commit 52b3df0

Browse files
authored
common/peg : implement ac parser for stricter grammar generation (#24869)
* common/peg : implement ac parser * cont : extract functions * cont : tidy up * cont : remove a test * cont : move ac() def
1 parent 7c082bc commit 52b3df0

4 files changed

Lines changed: 190 additions & 34 deletions

File tree

common/chat-auto-parser-generator.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -395,10 +395,11 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
395395
arguments.name_suffix) +
396396
arguments.value_prefix +
397397
(schema_info.resolves_to_string(param_schema) ?
398-
p.tool_arg_string_value(until_suffix) :
399-
p.tool_arg_json_value(p.schema(
400-
p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, false))) +
401-
p.tool_arg_close(p.literal(arguments.value_suffix)));
398+
p.ac(p.tool_arg_string_value(until_suffix) +
399+
p.tool_arg_close(p.literal(arguments.value_suffix)), arguments.value_suffix) :
400+
(p.tool_arg_json_value(p.schema(
401+
p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, false)) +
402+
p.tool_arg_close(p.literal(arguments.value_suffix)))));
402403

403404
auto named_arg = p.rule("tool-" + name + "-arg-" + param_name, arg);
404405
if (is_required) {

common/peg-parser.cpp

Lines changed: 102 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -921,6 +921,10 @@ struct parser_executor {
921921
common_peg_parse_result operator()(const common_peg_gbnf_parser & p) {
922922
return arena.parse(p.child, ctx, start_pos);
923923
}
924+
925+
common_peg_parse_result operator()(const common_peg_ac_parser & p) {
926+
return arena.parse(p.child, ctx, start_pos);
927+
}
924928
};
925929

926930
common_peg_parse_result common_peg_arena::parse(common_peg_parse_context & ctx, size_t start) const {
@@ -989,7 +993,8 @@ void common_peg_arena::resolve_refs() {
989993
std::is_same_v<T, common_peg_not_parser> ||
990994
std::is_same_v<T, common_peg_tag_parser> ||
991995
std::is_same_v<T, common_peg_atomic_parser> ||
992-
std::is_same_v<T, common_peg_gbnf_parser>) {
996+
std::is_same_v<T, common_peg_gbnf_parser> ||
997+
std::is_same_v<T, common_peg_ac_parser>) {
993998
p.child = resolve_ref(p.child);
994999
} else if constexpr (std::is_same_v<T, common_peg_rule_parser>) {
9951000
p.child = resolve_ref(p.child);
@@ -1070,6 +1075,8 @@ std::string common_peg_arena::dump_impl(common_peg_parser_id
10701075
return "Atomic(" + dump_impl(p.child, visited) + ")";
10711076
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
10721077
return "Gbnf(" + p.grammar + ", " + dump_impl(p.child, visited) + ")";
1078+
} else if constexpr (std::is_same_v<T, common_peg_ac_parser>) {
1079+
return "Ac(" + string_join(p.delimiters, " | ") + ", " + dump_impl(p.child, visited) + ")";
10731080
} else if constexpr (std::is_same_v<T, common_peg_any_parser>) {
10741081
return "Any";
10751082
} else if constexpr (std::is_same_v<T, common_peg_space_parser>) {
@@ -1479,6 +1486,13 @@ common_peg_parser common_peg_parser_builder::json_member(const std::string & key
14791486
});
14801487
}
14811488

1489+
common_peg_parser common_peg_parser_builder::ac(const common_peg_parser & p, const std::vector<std::string> & delimiters) {
1490+
if (delimiters.empty()) {
1491+
throw std::runtime_error("ac parser requires at least one delimiter");
1492+
}
1493+
return add(common_peg_ac_parser{p, delimiters});
1494+
}
1495+
14821496
static std::string gbnf_escape_char_class(uint32_t c) {
14831497
if (c == '-' || c == ']' || c == '[' || c == '\\') {
14841498
return "\\" + std::string(1, (char) c);
@@ -1529,14 +1543,22 @@ static std::string gbnf_escape_char_class(uint32_t c) {
15291543
return std::string(buf);
15301544
}
15311545

1532-
// GBNF grammar matching strings that contain no string in `strings` as a
1533-
// substring. Emits the complement of an Aho-Corasick automaton DFA and returns
1534-
// the start state rule name.
1535-
//
1536-
// ref: https://github.com/ggml-org/llama.cpp/pull/24839
1537-
static std::string gbnf_excluding_grammar(const common_grammar_builder & builder,
1538-
const std::string & prefix,
1539-
const std::vector<std::string> & strings) {
1546+
static std::string gbnf_char_class(const std::vector<uint32_t> & chars, bool negate) {
1547+
std::string s = negate ? "[^" : "[";
1548+
for (uint32_t ch : chars) {
1549+
s += gbnf_escape_char_class(ch);
1550+
}
1551+
return s + "]";
1552+
}
1553+
1554+
static std::string gbnf_ac_grammar(
1555+
const common_grammar_builder & builder,
1556+
const std::string & prefix,
1557+
const std::vector<std::string> & strings,
1558+
const std::function<std::string(const std::vector<uint32_t> &,
1559+
const std::map<size_t, std::vector<uint32_t>> &,
1560+
const std::vector<uint32_t> &,
1561+
const std::function<std::string(size_t)> &)> & build_rule) {
15401562
aho_corasick ac(strings);
15411563

15421564
auto state_name = [&](size_t s) -> std::string {
@@ -1548,49 +1570,85 @@ static std::string gbnf_excluding_grammar(const common_grammar_builder & builder
15481570
return prefix + "-" + num;
15491571
};
15501572

1551-
auto char_class = [](const std::vector<uint32_t> & chars, bool negate) {
1552-
std::string s = negate ? "[^" : "[";
1553-
for (uint32_t ch : chars) {
1554-
s += gbnf_escape_char_class(ch);
1555-
}
1556-
return s + "]";
1557-
};
1558-
15591573
for (size_t q = 0; q < ac.num_states(); q++) {
15601574
if (ac.is_terminal(q)) {
1561-
continue; // match states are dropped
1575+
continue; // match states
15621576
}
15631577

15641578
std::map<size_t, std::vector<uint32_t>> buckets;
1565-
std::vector<uint32_t> excluded;
1579+
std::vector<uint32_t> completing; // chars that complete a delimiter
1580+
std::vector<uint32_t> specific; // chars with an explicit transition
15661581
for (uint32_t c : ac.alphabet) {
15671582
size_t d = ac.next(q, c);
15681583
if (ac.is_terminal(d)) {
1569-
excluded.push_back(c); // completes a forbidden string -> omit
1584+
completing.push_back(c);
1585+
specific.push_back(c);
15701586
} else if (d != 0) {
15711587
buckets[d].push_back(c); // specific non-root destination
1572-
excluded.push_back(c);
1588+
specific.push_back(c);
15731589
}
15741590
}
15751591

1576-
std::string rhs = "|"; // every state is accepting
1577-
for (const auto & [d, chars] : buckets) {
1578-
rhs += " " + char_class(chars, false) + " " + state_name(d) + " |";
1579-
}
1580-
rhs += " " + char_class(excluded, true) + " " + state_name(0);
1581-
1582-
builder.add_rule(state_name(q), rhs);
1592+
builder.add_rule(state_name(q), build_rule(completing, buckets, specific, state_name));
15831593
}
15841594

15851595
// An empty delimiter makes the start state terminal. Emit an entry rule
1586-
// that matches nothing so the returned reference stays valid.
1596+
// that matches the empty string so the returned reference stays valid.
15871597
if (ac.is_terminal(0)) {
15881598
builder.add_rule(prefix, "|");
15891599
}
15901600

15911601
return state_name(0);
15921602
}
15931603

1604+
// GBNF grammar matching strings that contain no string in `strings` as a
1605+
// substring. Emits the complement of an Aho-Corasick automaton DFA and returns
1606+
// the start state rule name.
1607+
//
1608+
// ref: https://github.com/ggml-org/llama.cpp/pull/24839
1609+
static std::string gbnf_excluding_grammar(const common_grammar_builder & builder,
1610+
const std::string & prefix,
1611+
const std::vector<std::string> & strings) {
1612+
return gbnf_ac_grammar(builder, prefix, strings,
1613+
[](const std::vector<uint32_t> & /*completing*/,
1614+
const std::map<size_t, std::vector<uint32_t>> & buckets,
1615+
const std::vector<uint32_t> & specific,
1616+
const std::function<std::string(size_t)> & state_name) {
1617+
// every state is accepting and completing chars get no
1618+
// alternative, so a forbidden string can never be matched
1619+
std::string rhs = "|";
1620+
for (const auto & [d, chars] : buckets) {
1621+
rhs += " " + gbnf_char_class(chars, false) + " " + state_name(d) + " |";
1622+
}
1623+
rhs += " " + gbnf_char_class(specific, true) + " " + state_name(0);
1624+
return rhs;
1625+
});
1626+
}
1627+
1628+
// GBNF grammar matching everything up to and including the first occurrence of
1629+
// any string in `strings`. Emits the Aho-Corasick automaton DFA and returns
1630+
// the start state rule name.
1631+
static std::string gbnf_including_grammar(const common_grammar_builder & builder,
1632+
const std::string & prefix,
1633+
const std::vector<std::string> & strings) {
1634+
return gbnf_ac_grammar(builder, prefix, strings,
1635+
[](const std::vector<uint32_t> & completing,
1636+
const std::map<size_t, std::vector<uint32_t>> & buckets,
1637+
const std::vector<uint32_t> & specific,
1638+
const std::function<std::string(size_t)> & state_name) {
1639+
std::vector<std::string> alts;
1640+
if (!completing.empty()) {
1641+
alts.push_back(gbnf_char_class(completing, false)); // terminate on match
1642+
}
1643+
for (const auto & [d, chars] : buckets) {
1644+
alts.push_back(gbnf_char_class(chars, false) + " " + state_name(d));
1645+
}
1646+
// every other character keeps scanning from the start state
1647+
alts.push_back(gbnf_char_class(specific, true) + " " + state_name(0));
1648+
return string_join(alts, " | ");
1649+
});
1650+
}
1651+
15941652
static std::set<std::string> collect_reachable_rules(
15951653
const common_peg_arena & arena,
15961654
const common_peg_parser_id & rule
@@ -1628,6 +1686,7 @@ static std::set<std::string> collect_reachable_rules(
16281686
std::is_same_v<T, common_peg_tag_parser> ||
16291687
std::is_same_v<T, common_peg_atomic_parser> ||
16301688
std::is_same_v<T, common_peg_gbnf_parser> ||
1689+
std::is_same_v<T, common_peg_ac_parser> ||
16311690
std::is_same_v<T, common_peg_schema_parser>) {
16321691
visit(p.child);
16331692
} else if constexpr (std::is_same_v<T, common_peg_rule_parser>) {
@@ -1822,6 +1881,8 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
18221881
return to_gbnf(p.child);
18231882
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
18241883
return p.grammar;
1884+
} else if constexpr (std::is_same_v<T, common_peg_ac_parser>) {
1885+
return gbnf_including_grammar(builder, "ac-" + std::to_string(id), p.delimiters);
18251886
} else {
18261887
static_assert(is_always_false_v<T>);
18271888
}
@@ -1958,6 +2019,8 @@ static nlohmann::json serialize_parser_variant(const common_peg_parser_variant &
19582019
};
19592020
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
19602021
return json{{"type", "gbnf"}, {"child", p.child}, {"grammar", p.grammar}};
2022+
} else if constexpr (std::is_same_v<T, common_peg_ac_parser>) {
2023+
return json{{"type", "ac"}, {"child", p.child}, {"delimiters", p.delimiters}};
19612024
}
19622025
}, variant);
19632026
}
@@ -2130,6 +2193,16 @@ static common_peg_parser_variant deserialize_parser_variant(const nlohmann::json
21302193
};
21312194
}
21322195

2196+
if (type == "ac") {
2197+
if (!j.contains("child") || !j.contains("delimiters") || !j["delimiters"].is_array() || j["delimiters"].empty()) {
2198+
throw std::runtime_error("ac parser requires 'child' and a non-empty 'delimiters' array");
2199+
}
2200+
return common_peg_ac_parser{
2201+
j["child"].get<common_peg_parser_id>(),
2202+
j["delimiters"].get<std::vector<std::string>>(),
2203+
};
2204+
}
2205+
21332206
throw std::runtime_error("Unknown parser type: " + type);
21342207
}
21352208

common/peg-parser.h

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,11 @@ struct common_peg_gbnf_parser {
275275
std::string grammar;
276276
};
277277

278+
struct common_peg_ac_parser {
279+
common_peg_parser_id child;
280+
std::vector<std::string> delimiters;
281+
};
282+
278283
// Variant holding all parser types
279284
using common_peg_parser_variant = std::variant<
280285
common_peg_epsilon_parser,
@@ -296,7 +301,8 @@ using common_peg_parser_variant = std::variant<
296301
common_peg_ref_parser,
297302
common_peg_atomic_parser,
298303
common_peg_tag_parser,
299-
common_peg_gbnf_parser
304+
common_peg_gbnf_parser,
305+
common_peg_ac_parser
300306
>;
301307

302308
class common_peg_arena {
@@ -514,6 +520,13 @@ class common_peg_parser_builder {
514520
// the child's grammar. Parsing delegates entirely to the child.
515521
common_peg_parser gbnf(const common_peg_parser & p, const std::string & grammar) { return add(common_peg_gbnf_parser{p, grammar}); }
516522

523+
// Wraps a child parser but emits a GBNF grammar built from the Aho-Corasick
524+
// automaton of `delimiters`, matching everything up to and including the
525+
// first delimiter. Parsing delegates entirely to the child, which is
526+
// responsible for consuming the delimiter (e.g. until(D) + literal(D)).
527+
common_peg_parser ac(const common_peg_parser & p, const std::vector<std::string> & delimiters);
528+
common_peg_parser ac(const common_peg_parser & p, const std::string & delimiter) { return ac(p, std::vector<std::string>{delimiter}); }
529+
517530
void set_root(const common_peg_parser & p);
518531

519532
common_peg_arena build();

tests/peg-parser/test-gbnf-generation.cpp

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,75 @@ void test_gbnf_generation(testing &t) {
212212
)""", gbnf);
213213
});
214214

215+
t.test("ac grammar", [](testing &t) {
216+
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
217+
return p.ac(p.until("</tag>") + p.literal("</tag>"), "</tag>");
218+
});
219+
220+
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
221+
parser.build_grammar(builder);
222+
});
223+
224+
assert_gbnf_equal(t, R"""(
225+
ac-3 ::= [<] ac-3-01 | [^<] ac-3
226+
ac-3-01 ::= [<] ac-3-01 | [/] ac-3-02 | [^/<] ac-3
227+
ac-3-02 ::= [<] ac-3-01 | [t] ac-3-03 | [^<t] ac-3
228+
ac-3-03 ::= [<] ac-3-01 | [a] ac-3-04 | [^<a] ac-3
229+
ac-3-04 ::= [<] ac-3-01 | [g] ac-3-05 | [^<g] ac-3
230+
ac-3-05 ::= [>] | [<] ac-3-01 | [^<>] ac-3
231+
root ::= ac-3
232+
space ::= | " " | "\n"{1,2} [ \t]{0,20}
233+
)""", gbnf);
234+
});
235+
236+
t.test("ac grammar terminates at first delimiter", [](testing &t) {
237+
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
238+
return p.ac(p.until("\n</parameter>\n") + p.literal("\n</parameter>\n"), "\n</parameter>\n");
239+
});
240+
241+
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
242+
parser.build_grammar(builder);
243+
});
244+
245+
assert_gbnf_equal(t, R"""(
246+
ac-3 ::= [\n] ac-3-01 | [^\n] ac-3
247+
ac-3-01 ::= [\n] ac-3-01 | [<] ac-3-02 | [^\n<] ac-3
248+
ac-3-02 ::= [\n] ac-3-01 | [/] ac-3-03 | [^\n/] ac-3
249+
ac-3-03 ::= [\n] ac-3-01 | [p] ac-3-04 | [^\np] ac-3
250+
ac-3-04 ::= [\n] ac-3-01 | [a] ac-3-05 | [^\na] ac-3
251+
ac-3-05 ::= [\n] ac-3-01 | [r] ac-3-06 | [^\nr] ac-3
252+
ac-3-06 ::= [\n] ac-3-01 | [a] ac-3-07 | [^\na] ac-3
253+
ac-3-07 ::= [\n] ac-3-01 | [m] ac-3-08 | [^\nm] ac-3
254+
ac-3-08 ::= [\n] ac-3-01 | [e] ac-3-09 | [^\ne] ac-3
255+
ac-3-09 ::= [\n] ac-3-01 | [t] ac-3-10 | [^\nt] ac-3
256+
ac-3-10 ::= [\n] ac-3-01 | [e] ac-3-11 | [^\ne] ac-3
257+
ac-3-11 ::= [\n] ac-3-01 | [r] ac-3-12 | [^\nr] ac-3
258+
ac-3-12 ::= [\n] ac-3-01 | [>] ac-3-13 | [^\n>] ac-3
259+
ac-3-13 ::= [\n] | [^\n] ac-3
260+
root ::= ac-3
261+
space ::= | " " | "\n"{1,2} [ \t]{0,20}
262+
)""", gbnf);
263+
});
264+
265+
t.test("ac grammar multiple delimiters", [](testing &t) {
266+
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
267+
return p.ac(p.eps(), std::vector<std::string>{"ab", "cd", "ef"});
268+
});
269+
270+
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
271+
parser.build_grammar(builder);
272+
});
273+
274+
assert_gbnf_equal(t, R"""(
275+
ac-1 ::= [a] ac-1-01 | [c] ac-1-03 | [e] ac-1-05 | [^ace] ac-1
276+
ac-1-01 ::= [b] | [a] ac-1-01 | [c] ac-1-03 | [e] ac-1-05 | [^abce] ac-1
277+
ac-1-03 ::= [d] | [a] ac-1-01 | [c] ac-1-03 | [e] ac-1-05 | [^acde] ac-1
278+
ac-1-05 ::= [f] | [a] ac-1-01 | [c] ac-1-03 | [e] ac-1-05 | [^acef] ac-1
279+
root ::= ac-1
280+
space ::= | " " | "\n"{1,2} [ \t]{0,20}
281+
)""", gbnf);
282+
});
283+
215284
t.test("complex expressions with parentheses", [](testing &t) {
216285
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
217286
return p.one_or_more(p.literal("a") | p.literal("b"));

0 commit comments

Comments
 (0)