Skip to content

Commit e21cdc1

Browse files
authored
common/gemma4 : handle parsing edge cases (ggml-org#21760)
1 parent e974923 commit e21cdc1

5 files changed

Lines changed: 140 additions & 7 deletions

File tree

common/chat.cpp

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1091,6 +1091,14 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ
10911091
common_chat_params data;
10921092

10931093
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
1094+
1095+
if (inputs.add_generation_prompt && string_ends_with(data.prompt, "<turn|>\n")) {
1096+
// This may happen if the model generates content + tool_call, the
1097+
// template does not add the model's next turn and confuses the model
1098+
// from emitting its proper reasoning token sequence.
1099+
data.prompt += "<|turn>model\n";
1100+
}
1101+
10941102
data.format = COMMON_CHAT_FORMAT_PEG_GEMMA4;
10951103
data.supports_thinking = true;
10961104
data.thinking_start_tag = "<|channel>thought";
@@ -1118,7 +1126,8 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ
11181126
p.rule("thought", p.content(p.literal("<|channel>thought") + p.space() + p.until("<channel|>") + p.literal("<channel|>")));
11191127
}
11201128

1121-
auto thought = (p.peek(p.literal("<|channel>")) + p.ref("thought")) | p.negate(p.literal("<|channel>"));
1129+
auto consume_empty_channels = p.gbnf(p.zero_or_more(p.literal("<|channel>") + p.negate(p.literal("thought"))), "");
1130+
auto thought = (p.peek(p.literal("<|channel>")) + consume_empty_channels + p.ref("thought")) | p.negate(p.literal("<|channel>"));
11221131

11231132
if (has_response_format) {
11241133
auto response_format = p.literal("```json") <<
@@ -1182,12 +1191,16 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ
11821191
/* max = */ inputs.parallel_tool_calls ? -1 : 1
11831192
));
11841193

1185-
auto content = p.rule("content", p.content(p.until_one_of({"<|channel>", "<|tool_call>"})));
1194+
auto scan_to_toolcall = p.rule("scan-to-toolcall", p.until("<|tool_call>"));
1195+
auto content = p.rule("content", p.content(p.until_one_of({"<|channel>", "<channel|>", "<|tool_call>"})));
11861196
auto message = p.rule("message", thought + content);
1187-
return start + p.zero_or_more(message) + tool_call;
1197+
return start + p.zero_or_more(message) + scan_to_toolcall + tool_call;
11881198
}
11891199

1190-
auto content = p.rule("content", p.content(p.until("<|channel>")));
1200+
// Gemma 4 may emit an extra <|channel>thought\n<channel|> at the end of the content. It may
1201+
// also emit a single trailing <channel|> token. Consume all complete reasoning blocks and
1202+
// then stop at the first unmatched <channel|> token.
1203+
auto content = p.rule("content", p.content(p.until_one_of({"<|channel>", "<channel|>"})));
11911204
auto message = p.rule("message", thought + content);
11921205
return start + p.one_or_more(message);
11931206
});

common/peg-parser.cpp

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -890,6 +890,10 @@ struct parser_executor {
890890
}
891891
return result;
892892
}
893+
894+
common_peg_parse_result operator()(const common_peg_gbnf_parser & p) {
895+
return arena.parse(p.child, ctx, start_pos);
896+
}
893897
};
894898

895899
common_peg_parse_result common_peg_arena::parse(common_peg_parse_context & ctx, size_t start) const {
@@ -957,7 +961,8 @@ void common_peg_arena::resolve_refs() {
957961
std::is_same_v<T, common_peg_and_parser> ||
958962
std::is_same_v<T, common_peg_not_parser> ||
959963
std::is_same_v<T, common_peg_tag_parser> ||
960-
std::is_same_v<T, common_peg_atomic_parser>) {
964+
std::is_same_v<T, common_peg_atomic_parser> ||
965+
std::is_same_v<T, common_peg_gbnf_parser>) {
961966
p.child = resolve_ref(p.child);
962967
} else if constexpr (std::is_same_v<T, common_peg_rule_parser>) {
963968
p.child = resolve_ref(p.child);
@@ -1036,6 +1041,8 @@ std::string common_peg_arena::dump_impl(common_peg_parser_id
10361041
return "Not(" + dump_impl(p.child, visited) + ")";
10371042
} else if constexpr (std::is_same_v<T, common_peg_atomic_parser>) {
10381043
return "Atomic(" + dump_impl(p.child, visited) + ")";
1044+
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
1045+
return "Gbnf(" + p.grammar + ", " + dump_impl(p.child, visited) + ")";
10391046
} else if constexpr (std::is_same_v<T, common_peg_any_parser>) {
10401047
return "Any";
10411048
} else if constexpr (std::is_same_v<T, common_peg_space_parser>) {
@@ -1565,6 +1572,7 @@ static std::unordered_set<std::string> collect_reachable_rules(
15651572
std::is_same_v<T, common_peg_not_parser> ||
15661573
std::is_same_v<T, common_peg_tag_parser> ||
15671574
std::is_same_v<T, common_peg_atomic_parser> ||
1575+
std::is_same_v<T, common_peg_gbnf_parser> ||
15681576
std::is_same_v<T, common_peg_schema_parser>) {
15691577
visit(p.child);
15701578
} else if constexpr (std::is_same_v<T, common_peg_rule_parser>) {
@@ -1651,10 +1659,13 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
16511659
} else if constexpr (std::is_same_v<T, common_peg_sequence_parser>) {
16521660
std::string s;
16531661
for (const auto & child : p.children) {
1662+
auto child_gbnf = to_gbnf(child);
1663+
if (child_gbnf.empty()) {
1664+
continue;
1665+
}
16541666
if (!s.empty()) {
16551667
s += " ";
16561668
}
1657-
auto child_gbnf = to_gbnf(child);
16581669
const auto & child_parser = effective_parser(child);
16591670
if (std::holds_alternative<common_peg_choice_parser>(child_parser) ||
16601671
std::holds_alternative<common_peg_sequence_parser>(child_parser)) {
@@ -1754,6 +1765,8 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
17541765
return to_gbnf(p.child);
17551766
} else if constexpr (std::is_same_v<T, common_peg_atomic_parser>) {
17561767
return to_gbnf(p.child);
1768+
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
1769+
return p.grammar;
17571770
} else {
17581771
static_assert(is_always_false_v<T>);
17591772
}
@@ -1888,6 +1901,8 @@ static nlohmann::json serialize_parser_variant(const common_peg_parser_variant &
18881901
{"child", p.child},
18891902
{"tag", p.tag}
18901903
};
1904+
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
1905+
return json{{"type", "gbnf"}, {"child", p.child}, {"grammar", p.grammar}};
18911906
}
18921907
}, variant);
18931908
}
@@ -2050,6 +2065,16 @@ static common_peg_parser_variant deserialize_parser_variant(const nlohmann::json
20502065
};
20512066
}
20522067

2068+
if (type == "gbnf") {
2069+
if (!j.contains("child") || !j.contains("grammar")) {
2070+
throw std::runtime_error("gbnf parser missing required fields");
2071+
}
2072+
return common_peg_gbnf_parser{
2073+
j["child"].get<common_peg_parser_id>(),
2074+
j["grammar"].get<std::string>(),
2075+
};
2076+
}
2077+
20532078
throw std::runtime_error("Unknown parser type: " + type);
20542079
}
20552080

common/peg-parser.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,11 @@ struct common_peg_tag_parser {
270270
std::string tag;
271271
};
272272

273+
struct common_peg_gbnf_parser {
274+
common_peg_parser_id child;
275+
std::string grammar;
276+
};
277+
273278
// Variant holding all parser types
274279
using common_peg_parser_variant = std::variant<
275280
common_peg_epsilon_parser,
@@ -290,7 +295,8 @@ using common_peg_parser_variant = std::variant<
290295
common_peg_rule_parser,
291296
common_peg_ref_parser,
292297
common_peg_atomic_parser,
293-
common_peg_tag_parser
298+
common_peg_tag_parser,
299+
common_peg_gbnf_parser
294300
>;
295301

296302
class common_peg_arena {
@@ -504,6 +510,10 @@ class common_peg_parser_builder {
504510
// Unlike rules, you can tag multiple nodes with the same tag.
505511
common_peg_parser tag(const std::string & tag, const common_peg_parser & p) { return add(common_peg_tag_parser{p.id(), tag}); }
506512

513+
// Wraps a child parser but emits a custom GBNF grammar string instead of
514+
// the child's grammar. Parsing delegates entirely to the child.
515+
common_peg_parser gbnf(const common_peg_parser & p, const std::string & grammar) { return add(common_peg_gbnf_parser{p, grammar}); }
516+
507517
void set_root(const common_peg_parser & p);
508518

509519
common_peg_arena build();

tests/peg-parser/test-gbnf-generation.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,66 @@ void test_gbnf_generation(testing &t) {
258258
)""", gbnf);
259259
});
260260

261+
t.test("silent parser emits nothing in gbnf", [](testing &t) {
262+
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
263+
return p.literal("hello") + p.gbnf(p.literal("world"), "");
264+
});
265+
266+
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
267+
parser.build_grammar(builder);
268+
});
269+
270+
assert_gbnf_equal(t, R"""(
271+
root ::= "hello"
272+
space ::= | " " | "\n"{1,2} [ \t]{0,20}
273+
)""", gbnf);
274+
});
275+
276+
t.test("silent choice inside sequence emits nothing", [](testing &t) {
277+
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
278+
return p.literal("a") + p.gbnf(p.literal("b") | p.literal("c"), "") + p.literal("d");
279+
});
280+
281+
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
282+
parser.build_grammar(builder);
283+
});
284+
285+
assert_gbnf_equal(t, R"""(
286+
root ::= "a" "d"
287+
space ::= | " " | "\n"{1,2} [ \t]{0,20}
288+
)""", gbnf);
289+
});
290+
291+
t.test("silent wrapped in tag emits nothing", [](testing &t) {
292+
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
293+
return p.literal("a") + p.tag("t", p.gbnf(p.literal("b"), ""));
294+
});
295+
296+
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
297+
parser.build_grammar(builder);
298+
});
299+
300+
assert_gbnf_equal(t, R"""(
301+
root ::= "a"
302+
space ::= | " " | "\n"{1,2} [ \t]{0,20}
303+
)""", gbnf);
304+
});
305+
306+
t.test("gbnf parser emits custom grammar", [](testing &t) {
307+
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
308+
return p.literal("a") + p.gbnf(p.literal("b"), "[a-z]+");
309+
});
310+
311+
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
312+
parser.build_grammar(builder);
313+
});
314+
315+
assert_gbnf_equal(t, R"""(
316+
root ::= "a" [a-z]+
317+
space ::= | " " | "\n"{1,2} [ \t]{0,20}
318+
)""", gbnf);
319+
});
320+
261321
t.test("nested transparent wrappers get parenthesized", [](testing &t) {
262322
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
263323
return p.literal("x") + p.tag("outer", p.atomic(p.literal("a") | p.literal("b")));

tests/test-chat.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2118,6 +2118,31 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
21182118
.tools({ amount_tool })
21192119
.expect(message_with_tool_calls("amount", R"({"orig": 1.5e10})"))
21202120
.run();
2121+
2122+
// Edge cases
2123+
tst.test(
2124+
"<|channel>thought\n<channel|>Hello, world!\nWhat's up?<channel|>")
2125+
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
2126+
.expect(message_assist)
2127+
.run();
2128+
2129+
tst.test(
2130+
"<|channel>thought\n<channel|>Hello, world!\nWhat's up?<|channel>thought\n<channel|>")
2131+
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
2132+
.expect(message_assist)
2133+
.run();
2134+
2135+
tst.test(
2136+
"<|channel>thought\n<channel|>Hello, world!\nWhat's up?<|channel>thought\n<channel|><channel|>")
2137+
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
2138+
.expect(message_assist)
2139+
.run();
2140+
2141+
tst.test(
2142+
"<|channel><|channel>thought\n<channel|>Hello, world!\nWhat's up?")
2143+
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
2144+
.expect(message_assist)
2145+
.run();
21212146
}
21222147

21232148
{

0 commit comments

Comments
 (0)