66#include " unicode.h"
77
88#include < algorithm>
9+ #include < deque>
910#include < initializer_list>
1011#include < map>
1112#include < memory>
1213#include < nlohmann/json.hpp>
1314#include < regex>
15+ #include < set>
1416#include < stdexcept>
15- #include < unordered_set>
1617
1718// Trick to catch missing branches
1819template <typename T>
@@ -88,40 +89,7 @@ struct trie {
8889 return match_result{match_result::NO_MATCH };
8990 }
9091
91- struct prefix_and_next {
92- std::vector<uint32_t > prefix;
93- std::vector<uint32_t > next_chars;
94- };
95-
96- std::vector<prefix_and_next> collect_prefix_and_next () {
97- std::vector<uint32_t > prefix;
98- std::vector<prefix_and_next> result;
99- collect_prefix_and_next (0 , prefix, result);
100- return result;
101- }
102-
10392 private:
104- void collect_prefix_and_next (size_t index, std::vector<uint32_t > & prefix, std::vector<prefix_and_next> & out) {
105- if (!nodes[index].is_word ) {
106- if (!nodes[index].children .empty ()) {
107- std::vector<uint32_t > chars;
108- chars.reserve (nodes[index].children .size ());
109- for (const auto & p : nodes[index].children ) {
110- chars.push_back (p.first );
111- }
112- out.emplace_back (prefix_and_next{prefix, chars});
113- }
114- }
115-
116- for (const auto & p : nodes[index].children ) {
117- uint32_t ch = p.first ;
118- auto child = p.second ;
119- prefix.push_back (ch);
120- collect_prefix_and_next (child, prefix, out);
121- prefix.pop_back ();
122- }
123- }
124-
12593 size_t create_node () {
12694 size_t index = nodes.size ();
12795 nodes.emplace_back ();
@@ -153,6 +121,65 @@ struct trie {
153121 }
154122};
155123
124+ // Aho-Corasick automaton
125+ struct aho_corasick {
126+ trie t;
127+ std::vector<size_t > fail; // failure links
128+ std::vector<size_t > order; // states in BFS order
129+ std::vector<bool > terminal; // match states (directly or via a suffix link)
130+ std::set<uint32_t > alphabet; // every character with a transition
131+
132+ aho_corasick (const std::vector<std::string> & strings) : t(strings) {
133+ const auto & nodes = t.nodes ;
134+ const size_t n = nodes.size ();
135+
136+ fail.assign (n, 0 );
137+ order.reserve (n);
138+
139+ std::deque<size_t > queue{ 0 };
140+ while (!queue.empty ()) {
141+ size_t u = queue.front ();
142+ queue.pop_front ();
143+ order.push_back (u);
144+ for (const auto & [ch, v] : nodes[u].children ) {
145+ if (u != 0 ) {
146+ size_t f = fail[u];
147+ while (f && nodes[f].children .find (ch) == nodes[f].children .end ()) {
148+ f = fail[f];
149+ }
150+ auto it = nodes[f].children .find (ch);
151+ fail[v] = (it != nodes[f].children .end () && it->second != v) ? it->second : 0 ;
152+ }
153+ queue.push_back (v);
154+ }
155+ }
156+
157+ terminal.assign (n, false );
158+ for (size_t u : order) {
159+ terminal[u] = nodes[u].is_word || (u != 0 && terminal[fail[u]]);
160+ }
161+
162+ for (const auto & node : nodes) {
163+ for (const auto & [ch, v] : node.children ) {
164+ alphabet.insert (ch);
165+ }
166+ }
167+ }
168+
169+ size_t num_states () const { return t.nodes .size (); }
170+ bool is_terminal (size_t s) const { return terminal[s]; }
171+
172+ // follow failure links until a transition on `ch` exists.
173+ size_t next (size_t state, uint32_t ch) const {
174+ const auto & nodes = t.nodes ;
175+ while (state && nodes[state].children .find (ch) == nodes[state].children .end ()) {
176+ state = fail[state];
177+ }
178+ auto it = nodes[state].children .find (ch);
179+ return it != nodes[state].children .end () ? it->second : 0 ;
180+ }
181+ };
182+
156183static std::pair<uint32_t , size_t > parse_hex_escape (const std::string & str, size_t pos, int hex_count) {
157184 if (pos + hex_count > str.length ()) {
158185 return {0 , 0 };
@@ -992,12 +1019,12 @@ void common_peg_arena::resolve_refs() {
9921019}
9931020
9941021std::string common_peg_arena::dump (common_peg_parser_id id) const {
995- std::unordered_set <common_peg_parser_id> visited;
1022+ std::set <common_peg_parser_id> visited;
9961023 return dump_impl (id, visited);
9971024}
9981025
9991026std::string common_peg_arena::dump_impl (common_peg_parser_id id,
1000- std::unordered_set <common_peg_parser_id> & visited) const {
1027+ std::set <common_peg_parser_id> & visited) const {
10011028 // Check for cycles
10021029 if (visited.count (id)) {
10031030 return " [cycle]" ;
@@ -1502,61 +1529,74 @@ static std::string gbnf_escape_char_class(uint32_t c) {
15021529 return std::string (buf);
15031530}
15041531
1505- static std::string gbnf_excluding_pattern (const std::vector<std::string> & strings) {
1506- trie matcher (strings);
1507- auto pieces = matcher.collect_prefix_and_next ();
1532+ // GBNF grammar matching strings that contain no string in `strings` as a
1533+ // substring. Emits the complement of an Aho-Corasick automaton DFA and returns
1534+ // the start state rule name.
1535+ //
1536+ // ref: https://github.com/ggml-org/llama.cpp/pull/24839
1537+ static std::string gbnf_excluding_grammar (const common_grammar_builder & builder,
1538+ const std::string & prefix,
1539+ const std::vector<std::string> & strings) {
1540+ aho_corasick ac (strings);
15081541
1509- std::string pattern;
1510- std::string trailing; // optional proper-prefix of a delimiter, allowed only at the very end
1511- for (size_t i = 0 ; i < pieces.size (); ++i) {
1512- if (i > 0 ) {
1513- pattern += " | " ;
1542+ auto state_name = [&](size_t s) -> std::string {
1543+ if (s == 0 ) {
1544+ return prefix;
15141545 }
1546+ std::string num = std::to_string (s);
1547+ num = num.size () == 1 ? (" 0" + num) : num;
1548+ return prefix + " -" + num;
1549+ };
15151550
1516- const auto & pre = pieces[i].prefix ;
1517- const auto & chars = pieces[i].next_chars ;
1518-
1519- std::string cls;
1520- cls.reserve (chars.size ());
1551+ auto char_class = [](const std::vector<uint32_t > & chars, bool negate) {
1552+ std::string s = negate ? " [^" : " [" ;
15211553 for (uint32_t ch : chars) {
1522- cls += gbnf_escape_char_class (ch);
1523- }
1524-
1525- if (!pre .empty ()) {
1526- std::string pre_literal = gbnf_format_literal (common_unicode_cpts_to_utf8 (pre ));
1527- pattern += pre_literal + " [^" + cls + " ]" ;
1528- // Each interior alternative consumes a delimiter-prefix plus a disambiguating
1529- // char, so the repetition alone cannot match a value that *ends* on a proper
1530- // prefix of a delimiter (e.g. a trailing "\n" when the delimiter is
1531- // "\n</parameter>\n"). The runtime until() (greedy first-match) accepts such
1532- // values, so without this the grammar would reject input the parser accepts.
1533- // Allow the value to terminate on any proper prefix as an optional tail.
1534- // This makes the grammar a slight superset of the runtime language (a value
1535- // may end on the longest prefix, which greedy first-match would not itself
1536- // produce); harmless for constrained generation, which only needs to admit
1537- // every runtime-valid string.
1538- if (!trailing.empty ()) {
1539- trailing += " | " ;
1540- }
1541- trailing += pre_literal;
1542- } else {
1543- pattern += " [^" + cls + " ]" ;
1554+ s += gbnf_escape_char_class (ch);
15441555 }
1556+ return s + " ]" ;
1557+ };
1558+
1559+ for (size_t q = 0 ; q < ac.num_states (); q++) {
1560+ if (ac.is_terminal (q)) {
1561+ continue ; // match states are dropped
1562+ }
1563+
1564+ std::map<size_t , std::vector<uint32_t >> buckets;
1565+ std::vector<uint32_t > excluded;
1566+ for (uint32_t c : ac.alphabet ) {
1567+ size_t d = ac.next (q, c);
1568+ if (ac.is_terminal (d)) {
1569+ excluded.push_back (c); // completes a forbidden string -> omit
1570+ } else if (d != 0 ) {
1571+ buckets[d].push_back (c); // specific non-root destination
1572+ excluded.push_back (c);
1573+ }
1574+ }
1575+
1576+ std::string rhs = " |" ; // every state is accepting
1577+ for (const auto & [d, chars] : buckets) {
1578+ rhs += " " + char_class (chars, false ) + " " + state_name (d) + " |" ;
1579+ }
1580+ rhs += " " + char_class (excluded, true ) + " " + state_name (0 );
1581+
1582+ builder.add_rule (state_name (q), rhs);
15451583 }
15461584
1547- std::string result = " (" + pattern + " )*" ;
1548- if (!trailing.empty ()) {
1549- result += " (" + trailing + " )?" ;
1585+ // An empty delimiter makes the start state terminal. Emit an entry rule
1586+ // that matches nothing so the returned reference stays valid.
1587+ if (ac.is_terminal (0 )) {
1588+ builder.add_rule (prefix, " |" );
15501589 }
1551- return result;
1590+
1591+ return state_name (0 );
15521592}
15531593
1554- static std::unordered_set <std::string> collect_reachable_rules (
1594+ static std::set <std::string> collect_reachable_rules (
15551595 const common_peg_arena & arena,
15561596 const common_peg_parser_id & rule
15571597) {
1558- std::unordered_set <std::string> reachable;
1559- std::unordered_set <std::string> visited;
1598+ std::set <std::string> reachable;
1599+ std::set <std::string> visited;
15601600
15611601 std::function<void (common_peg_parser_id)> visit = [&](common_peg_parser_id id) {
15621602 const auto & parser = arena.get (id);
@@ -1765,7 +1805,7 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
17651805 if (p.delimiters .empty ()) {
17661806 return " .*" ;
17671807 }
1768- return gbnf_excluding_pattern ( p.delimiters );
1808+ return gbnf_excluding_grammar (builder, " until- " + std::to_string (id), p.delimiters );
17691809 } else if constexpr (std::is_same_v<T, common_peg_schema_parser>) {
17701810 if (schema_delegates (p)) {
17711811 return to_gbnf (p.child );
@@ -1789,7 +1829,7 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
17891829 };
17901830
17911831 // Collect reachable rules
1792- std::unordered_set <std::string> reachable_rules;
1832+ std::set <std::string> reachable_rules;
17931833
17941834 if (lazy) {
17951835 // Collect rules reachable from trigger rules
0 commit comments