Skip to content

Commit 063d9c1

Browse files
authored
common/peg : refactor until gbnf grammar generation (ggml-org#24839)
* common/peg : refactor until gbnf grammar into an ac automaton
* cont : add a test with multiple strings
* cont : pad state with 0s so rules line up
* cont : clean up comments
* cont : use set everywhere
* cont : inline state num string padding
* cont : add a ref to PR
* cont : fix regression in server-tools.cpp
1 parent c576070 commit 063d9c1

4 files changed

Lines changed: 202 additions & 83 deletions

File tree

common/peg-parser.cpp

Lines changed: 120 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,14 @@
66
#include "unicode.h"
77

88
#include <algorithm>
9+
#include <deque>
910
#include <initializer_list>
1011
#include <map>
1112
#include <memory>
1213
#include <nlohmann/json.hpp>
1314
#include <regex>
15+
#include <set>
1416
#include <stdexcept>
15-
#include <unordered_set>
1617

1718
// Trick to catch missing branches
1819
template <typename T>
@@ -88,40 +89,7 @@ struct trie {
8889
return match_result{match_result::NO_MATCH};
8990
}
9091

91-
struct prefix_and_next {
92-
std::vector<uint32_t> prefix;
93-
std::vector<uint32_t> next_chars;
94-
};
95-
96-
std::vector<prefix_and_next> collect_prefix_and_next() {
97-
std::vector<uint32_t> prefix;
98-
std::vector<prefix_and_next> result;
99-
collect_prefix_and_next(0, prefix, result);
100-
return result;
101-
}
102-
10392
private:
104-
void collect_prefix_and_next(size_t index, std::vector<uint32_t> & prefix, std::vector<prefix_and_next> & out) {
105-
if (!nodes[index].is_word) {
106-
if (!nodes[index].children.empty()) {
107-
std::vector<uint32_t> chars;
108-
chars.reserve(nodes[index].children.size());
109-
for (const auto & p : nodes[index].children) {
110-
chars.push_back(p.first);
111-
}
112-
out.emplace_back(prefix_and_next{prefix, chars});
113-
}
114-
}
115-
116-
for (const auto & p : nodes[index].children) {
117-
uint32_t ch = p.first;
118-
auto child = p.second;
119-
prefix.push_back(ch);
120-
collect_prefix_and_next(child, prefix, out);
121-
prefix.pop_back();
122-
}
123-
}
124-
12593
size_t create_node() {
12694
size_t index = nodes.size();
12795
nodes.emplace_back();
@@ -153,6 +121,65 @@ struct trie {
153121
}
154122
};
155123

124+
// Aho-Corasick automaton
125+
struct aho_corasick {
126+
trie t;
127+
std::vector<size_t> fail; // failure links
128+
std::vector<size_t> order; // states in BFS order
129+
std::vector<bool> terminal; // match states (directly or via a suffix link)
130+
std::set<uint32_t> alphabet; // every character with a transition
131+
132+
aho_corasick(const std::vector<std::string> & strings) : t(strings) {
133+
const auto & nodes = t.nodes;
134+
const size_t n = nodes.size();
135+
136+
fail.assign(n, 0);
137+
order.reserve(n);
138+
139+
std::deque<size_t> queue{ 0 };
140+
while (!queue.empty()) {
141+
size_t u = queue.front();
142+
queue.pop_front();
143+
order.push_back(u);
144+
for (const auto & [ch, v] : nodes[u].children) {
145+
if (u != 0) {
146+
size_t f = fail[u];
147+
while (f && nodes[f].children.find(ch) == nodes[f].children.end()) {
148+
f = fail[f];
149+
}
150+
auto it = nodes[f].children.find(ch);
151+
fail[v] = (it != nodes[f].children.end() && it->second != v) ? it->second : 0;
152+
}
153+
queue.push_back(v);
154+
}
155+
}
156+
157+
terminal.assign(n, false);
158+
for (size_t u : order) {
159+
terminal[u] = nodes[u].is_word || (u != 0 && terminal[fail[u]]);
160+
}
161+
162+
for (const auto & node : nodes) {
163+
for (const auto & [ch, v] : node.children) {
164+
alphabet.insert(ch);
165+
}
166+
}
167+
}
168+
169+
size_t num_states() const { return t.nodes.size(); }
170+
bool is_terminal(size_t s) const { return terminal[s]; }
171+
172+
// follow failure links until a transition on `ch` exists.
173+
size_t next(size_t state, uint32_t ch) const {
174+
const auto & nodes = t.nodes;
175+
while (state && nodes[state].children.find(ch) == nodes[state].children.end()) {
176+
state = fail[state];
177+
}
178+
auto it = nodes[state].children.find(ch);
179+
return it != nodes[state].children.end() ? it->second : 0;
180+
}
181+
};
182+
156183
static std::pair<uint32_t, size_t> parse_hex_escape(const std::string & str, size_t pos, int hex_count) {
157184
if (pos + hex_count > str.length()) {
158185
return {0, 0};
@@ -992,12 +1019,12 @@ void common_peg_arena::resolve_refs() {
9921019
}
9931020

9941021
std::string common_peg_arena::dump(common_peg_parser_id id) const {
995-
std::unordered_set<common_peg_parser_id> visited;
1022+
std::set<common_peg_parser_id> visited;
9961023
return dump_impl(id, visited);
9971024
}
9981025

9991026
std::string common_peg_arena::dump_impl(common_peg_parser_id id,
1000-
std::unordered_set<common_peg_parser_id> & visited) const {
1027+
std::set<common_peg_parser_id> & visited) const {
10011028
// Check for cycles
10021029
if (visited.count(id)) {
10031030
return "[cycle]";
@@ -1502,61 +1529,74 @@ static std::string gbnf_escape_char_class(uint32_t c) {
15021529
return std::string(buf);
15031530
}
15041531

1505-
static std::string gbnf_excluding_pattern(const std::vector<std::string> & strings) {
1506-
trie matcher(strings);
1507-
auto pieces = matcher.collect_prefix_and_next();
1532+
// GBNF grammar matching strings that contain no string in `strings` as a
1533+
// substring. Emits the complement of an Aho-Corasick automaton DFA and returns
1534+
// the start state rule name.
1535+
//
1536+
// ref: https://github.com/ggml-org/llama.cpp/pull/24839
1537+
static std::string gbnf_excluding_grammar(const common_grammar_builder & builder,
1538+
const std::string & prefix,
1539+
const std::vector<std::string> & strings) {
1540+
aho_corasick ac(strings);
15081541

1509-
std::string pattern;
1510-
std::string trailing; // optional proper-prefix of a delimiter, allowed only at the very end
1511-
for (size_t i = 0; i < pieces.size(); ++i) {
1512-
if (i > 0) {
1513-
pattern += " | ";
1542+
auto state_name = [&](size_t s) -> std::string {
1543+
if (s == 0) {
1544+
return prefix;
15141545
}
1546+
std::string num = std::to_string(s);
1547+
num = num.size() == 1 ? ("0" + num) : num;
1548+
return prefix + "-" + num;
1549+
};
15151550

1516-
const auto & pre = pieces[i].prefix;
1517-
const auto & chars = pieces[i].next_chars;
1518-
1519-
std::string cls;
1520-
cls.reserve(chars.size());
1551+
auto char_class = [](const std::vector<uint32_t> & chars, bool negate) {
1552+
std::string s = negate ? "[^" : "[";
15211553
for (uint32_t ch : chars) {
1522-
cls += gbnf_escape_char_class(ch);
1523-
}
1524-
1525-
if (!pre.empty()) {
1526-
std::string pre_literal = gbnf_format_literal(common_unicode_cpts_to_utf8(pre));
1527-
pattern += pre_literal + " [^" + cls + "]";
1528-
// Each interior alternative consumes a delimiter-prefix plus a disambiguating
1529-
// char, so the repetition alone cannot match a value that *ends* on a proper
1530-
// prefix of a delimiter (e.g. a trailing "\n" when the delimiter is
1531-
// "\n</parameter>\n"). The runtime until() (greedy first-match) accepts such
1532-
// values, so without this the grammar would reject input the parser accepts.
1533-
// Allow the value to terminate on any proper prefix as an optional tail.
1534-
// This makes the grammar a slight superset of the runtime language (a value
1535-
// may end on the longest prefix, which greedy first-match would not itself
1536-
// produce); harmless for constrained generation, which only needs to admit
1537-
// every runtime-valid string.
1538-
if (!trailing.empty()) {
1539-
trailing += " | ";
1540-
}
1541-
trailing += pre_literal;
1542-
} else {
1543-
pattern += "[^" + cls + "]";
1554+
s += gbnf_escape_char_class(ch);
15441555
}
1556+
return s + "]";
1557+
};
1558+
1559+
for (size_t q = 0; q < ac.num_states(); q++) {
1560+
if (ac.is_terminal(q)) {
1561+
continue; // match states are dropped
1562+
}
1563+
1564+
std::map<size_t, std::vector<uint32_t>> buckets;
1565+
std::vector<uint32_t> excluded;
1566+
for (uint32_t c : ac.alphabet) {
1567+
size_t d = ac.next(q, c);
1568+
if (ac.is_terminal(d)) {
1569+
excluded.push_back(c); // completes a forbidden string -> omit
1570+
} else if (d != 0) {
1571+
buckets[d].push_back(c); // specific non-root destination
1572+
excluded.push_back(c);
1573+
}
1574+
}
1575+
1576+
std::string rhs = "|"; // every state is accepting
1577+
for (const auto & [d, chars] : buckets) {
1578+
rhs += " " + char_class(chars, false) + " " + state_name(d) + " |";
1579+
}
1580+
rhs += " " + char_class(excluded, true) + " " + state_name(0);
1581+
1582+
builder.add_rule(state_name(q), rhs);
15451583
}
15461584

1547-
std::string result = "(" + pattern + ")*";
1548-
if (!trailing.empty()) {
1549-
result += " (" + trailing + ")?";
1585+
// An empty delimiter makes the start state terminal. Emit an entry rule
1586+
// that matches nothing so the returned reference stays valid.
1587+
if (ac.is_terminal(0)) {
1588+
builder.add_rule(prefix, "|");
15501589
}
1551-
return result;
1590+
1591+
return state_name(0);
15521592
}
15531593

1554-
static std::unordered_set<std::string> collect_reachable_rules(
1594+
static std::set<std::string> collect_reachable_rules(
15551595
const common_peg_arena & arena,
15561596
const common_peg_parser_id & rule
15571597
) {
1558-
std::unordered_set<std::string> reachable;
1559-
std::unordered_set<std::string> visited;
1598+
std::set<std::string> reachable;
1599+
std::set<std::string> visited;
15601600

15611601
std::function<void(common_peg_parser_id)> visit = [&](common_peg_parser_id id) {
15621602
const auto & parser = arena.get(id);
@@ -1765,7 +1805,7 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
17651805
if (p.delimiters.empty()) {
17661806
return ".*";
17671807
}
1768-
return gbnf_excluding_pattern(p.delimiters);
1808+
return gbnf_excluding_grammar(builder, "until-" + std::to_string(id), p.delimiters);
17691809
} else if constexpr (std::is_same_v<T, common_peg_schema_parser>) {
17701810
if (schema_delegates(p)) {
17711811
return to_gbnf(p.child);
@@ -1789,7 +1829,7 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
17891829
};
17901830

17911831
// Collect reachable rules
1792-
std::unordered_set<std::string> reachable_rules;
1832+
std::set<std::string> reachable_rules;
17931833

17941834
if (lazy) {
17951835
// Collect rules reachable from trigger rules

common/peg-parser.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
#include <nlohmann/json_fwd.hpp>
44

55
#include <memory>
6+
#include <set>
67
#include <unordered_map>
7-
#include <unordered_set>
88
#include <string>
99
#include <string_view>
1010
#include <functional>
@@ -335,7 +335,7 @@ class common_peg_arena {
335335
friend class common_peg_parser_builder;
336336

337337
private:
338-
std::string dump_impl(common_peg_parser_id id, std::unordered_set<common_peg_parser_id> & visited) const;
338+
std::string dump_impl(common_peg_parser_id id, std::set<common_peg_parser_id> & visited) const;
339339

340340
common_peg_parser_id add_parser(common_peg_parser_variant parser);
341341
void add_rule(const std::string & name, common_peg_parser_id id);

tests/peg-parser/test-gbnf-generation.cpp

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,86 @@ void test_gbnf_generation(testing &t) {
129129
});
130130

131131
assert_gbnf_equal(t, R"""(
132-
root ::= ([^<] | "<" [^/] | "</" [^t] | "</t" [^a] | "</ta" [^g] | "</tag" [^>])* ("<" | "</" | "</t" | "</ta" | "</tag")?
132+
root ::= until-0
133133
space ::= | " " | "\n"{1,2} [ \t]{0,20}
134+
until-0 ::= | [<] until-0-01 | [^<] until-0
135+
until-0-01 ::= | [<] until-0-01 | [/] until-0-02 | [^/<] until-0
136+
until-0-02 ::= | [<] until-0-01 | [t] until-0-03 | [^<t] until-0
137+
until-0-03 ::= | [<] until-0-01 | [a] until-0-04 | [^<a] until-0
138+
until-0-04 ::= | [<] until-0-01 | [g] until-0-05 | [^<g] until-0
139+
until-0-05 ::= | [<] until-0-01 | [^<>] until-0
140+
)""", gbnf);
141+
});
142+
143+
t.test("until grammar overlapping delimiter", [](testing &t) {
144+
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
145+
return p.until("\n</parameter>\n");
146+
});
147+
148+
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
149+
parser.build_grammar(builder);
150+
});
151+
152+
assert_gbnf_equal(t, R"""(
153+
root ::= until-0
154+
space ::= | " " | "\n"{1,2} [ \t]{0,20}
155+
until-0 ::= | [\n] until-0-01 | [^\n] until-0
156+
until-0-01 ::= | [\n] until-0-01 | [<] until-0-02 | [^\n<] until-0
157+
until-0-02 ::= | [\n] until-0-01 | [/] until-0-03 | [^\n/] until-0
158+
until-0-03 ::= | [\n] until-0-01 | [p] until-0-04 | [^\np] until-0
159+
until-0-04 ::= | [\n] until-0-01 | [a] until-0-05 | [^\na] until-0
160+
until-0-05 ::= | [\n] until-0-01 | [r] until-0-06 | [^\nr] until-0
161+
until-0-06 ::= | [\n] until-0-01 | [a] until-0-07 | [^\na] until-0
162+
until-0-07 ::= | [\n] until-0-01 | [m] until-0-08 | [^\nm] until-0
163+
until-0-08 ::= | [\n] until-0-01 | [e] until-0-09 | [^\ne] until-0
164+
until-0-09 ::= | [\n] until-0-01 | [t] until-0-10 | [^\nt] until-0
165+
until-0-10 ::= | [\n] until-0-01 | [e] until-0-11 | [^\ne] until-0
166+
until-0-11 ::= | [\n] until-0-01 | [r] until-0-12 | [^\nr] until-0
167+
until-0-12 ::= | [\n] until-0-01 | [>] until-0-13 | [^\n>] until-0
168+
until-0-13 ::= | [^\n] until-0
169+
)""", gbnf);
170+
});
171+
172+
// DeepSeek-V3.2 tag prefix. The DSML token (｜DSML｜) embeds U+FF5C,
173+
// so the delimiter mixes ASCII and multi-byte codepoints.
174+
t.test("until grammar unicode delimiter", [](testing &t) {
175+
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
176+
return p.until("<｜DSML｜");
177+
});
178+
179+
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
180+
parser.build_grammar(builder);
181+
});
182+
183+
assert_gbnf_equal(t, R"""(
184+
root ::= until-0
185+
space ::= | " " | "\n"{1,2} [ \t]{0,20}
186+
until-0 ::= | [<] until-0-01 | [^<] until-0
187+
until-0-01 ::= | [<] until-0-01 | [\uFF5C] until-0-02 | [^<\uFF5C] until-0
188+
until-0-02 ::= | [<] until-0-01 | [D] until-0-03 | [^<D] until-0
189+
until-0-03 ::= | [<] until-0-01 | [S] until-0-04 | [^<S] until-0
190+
until-0-04 ::= | [<] until-0-01 | [M] until-0-05 | [^<M] until-0
191+
until-0-05 ::= | [<] until-0-01 | [L] until-0-06 | [^<L] until-0
192+
until-0-06 ::= | [<] until-0-01 | [^<\uFF5C] until-0
193+
)""", gbnf);
194+
});
195+
196+
t.test("until grammar multiple delimiters", [](testing &t) {
197+
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
198+
return p.until_one_of({"ab", "cd", "ef"});
199+
});
200+
201+
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
202+
parser.build_grammar(builder);
203+
});
204+
205+
assert_gbnf_equal(t, R"""(
206+
root ::= until-0
207+
space ::= | " " | "\n"{1,2} [ \t]{0,20}
208+
until-0 ::= | [a] until-0-01 | [c] until-0-03 | [e] until-0-05 | [^ace] until-0
209+
until-0-01 ::= | [a] until-0-01 | [c] until-0-03 | [e] until-0-05 | [^abce] until-0
210+
until-0-03 ::= | [a] until-0-01 | [c] until-0-03 | [e] until-0-05 | [^acde] until-0
211+
until-0-05 ::= | [a] until-0-01 | [c] until-0-03 | [e] until-0-05 | [^acef] until-0
134212
)""", gbnf);
135213
});
136214

tools/server/server-tools.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <cstring>
1212
#include <climits>
1313
#include <algorithm>
14+
#include <unordered_set>
1415

1516
namespace fs = std::filesystem;
1617

0 commit comments

Comments
 (0)