Skip to content

Commit 2830d68

Browse files
committed
tests: cover DFlash tool call boundaries
Add regression coverage for Kimi/Qwen partial tool-call streaming, raw marker quarantine, fenced-code false positives, direct Qwen function starts, lazy grammar triggers, and DFlash boundary plumbing.
1 parent 7924a6e commit 2830d68

3 files changed

Lines changed: 164 additions & 2 deletions

File tree

tests/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -157,7 +157,7 @@ if (NOT WIN32 OR NOT BUILD_SHARED_LIBS)
157157
llama_build_and_test(test-grammar-integration.cpp)
158158
llama_build_and_test(test-llama-grammar.cpp)
159159
llama_build_and_test(test-chat.cpp WORKING_DIRECTORY ${PROJECT_SOURCE_DIR})
160-
target_include_directories(test-chat PRIVATE ${PROJECT_SOURCE_DIR}/tools/server)
160+
target_include_directories(test-chat PRIVATE ${PROJECT_SOURCE_DIR}/tools/server ${PROJECT_SOURCE_DIR}/tools/mtmd)
161161
target_link_libraries(test-chat PRIVATE server-context)
162162
llama_build_and_test(test-server-loop-guard.cpp)
163163
target_include_directories(test-server-loop-guard PRIVATE ${PROJECT_SOURCE_DIR}/tools/server)

tests/test-chat.cpp

Lines changed: 137 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
#include "../src/llama-grammar.h"
99
#include "../src/unicode.h"
1010
#include "../tools/server/server-chat.h"
11+
#include "../tools/server/server-task.h"
1112
#include "chat-auto-parser.h"
1213
#include "chat.h"
1314
#include "common.h"
@@ -4434,6 +4435,141 @@ static void test_msg_diffs_compute() {
44344435
}
44354436
}
44364437

4438+
// Regression test for the tool-call diffs that task_result_state::update_chat_msg
// writes into its output vector while text is fed to it piece by piece.
// Covers the Kimi-K2-Thinking template (marker-delimited tool calls), raw
// "<function=" text, fenced code containing a marker, and the Qwen3.5-4B
// template (direct "<function=...>" calls without an outer wrapper).
// NOTE(review): the 2nd and 4th arguments of update_chat_msg are presumably
// "is partial" and "filter partial tool calls" — confirm against server-task.h.
static void test_task_result_state_tool_call_stream_filter() {
4439+
auto tmpls = read_templates("models/templates/Kimi-K2-Thinking.jinja");
4440+
4441+
common_chat_templates_inputs inputs;
4442+
inputs.messages = { message_user };
4443+
inputs.tools = { special_function_tool };
4444+
inputs.parallel_tool_calls = true;
4445+
4446+
make_peg_parser parser(tmpls.get(), inputs);
4447+
4448+
// Copy the generated parser params, then attach the parser arena and turn on tool-call parsing.
common_chat_parser_params parser_params(parser.params_);
4449+
parser_params.parser = parser.arena_;
4450+
parser_params.parse_tool_calls = true;
4451+
4452+
// `state` is shared by the next four sections, which feed one tool call in four pieces.
task_result_state state(parser_params);
4453+
4454+
// Piece 1: section/call begin markers plus the function id, no argument marker yet -> no diffs expected.
{
4455+
std::vector<common_chat_msg_diff> diffs;
4456+
state.update_chat_msg(
4457+
"<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0",
4458+
true,
4459+
diffs,
4460+
true);
4461+
assert_equals(size_t(0), diffs.size());
4462+
}
4463+
4464+
// Piece 2: argument-begin marker -> exactly one diff carrying the tool name and id with empty arguments.
{
4465+
std::vector<common_chat_msg_diff> diffs;
4466+
state.update_chat_msg("<|tool_call_argument_begin|>", true, diffs, true);
4467+
assert_equals(size_t(1), diffs.size());
4468+
assert_equals(size_t(0), diffs[0].tool_call_index);
4469+
assert_equals(std::string("special_function"), diffs[0].tool_call_delta.name);
4470+
assert_equals(std::string("functions.special_function:0"), diffs[0].tool_call_delta.id);
4471+
assert_equals(std::string(""), diffs[0].tool_call_delta.arguments);
4472+
}
4473+
4474+
// Piece 3: incomplete JSON arguments -> no diffs expected.
{
4475+
std::vector<common_chat_msg_diff> diffs;
4476+
state.update_chat_msg("{\"arg1\": ", true, diffs, true);
4477+
assert_equals(size_t(0), diffs.size());
4478+
}
4479+
4480+
// Piece 4: arguments finished and call/section closed -> one diff with the whole
// argument string (text from pieces 3 and 4 combined) and empty name/id.
{
4481+
std::vector<common_chat_msg_diff> diffs;
4482+
state.update_chat_msg("1}<|tool_call_end|><|tool_calls_section_end|>", true, diffs, true);
4483+
assert_equals(size_t(1), diffs.size());
4484+
assert_equals(size_t(0), diffs[0].tool_call_index);
4485+
assert_equals(std::string(""), diffs[0].tool_call_delta.name);
4486+
assert_equals(std::string(""), diffs[0].tool_call_delta.id);
4487+
assert_equals(std::string("{\"arg1\": 1}"), diffs[0].tool_call_delta.arguments);
4488+
}
4489+
4490+
// Fresh state: plain text yields a content diff; a following "<function=...>" line whose
// name is free-form text yields no diffs, and neither does the text fed after it.
{
4491+
task_result_state raw_state(parser_params);
4492+
std::vector<common_chat_msg_diff> diffs;
4493+
raw_state.update_chat_msg("Visible before marker\n", true, diffs, true);
4494+
assert_equals(size_t(1), diffs.size());
4495+
assert_equals(std::string("Visible before marker\n"), diffs[0].content_delta);
4496+
4497+
diffs.clear();
4498+
raw_state.update_chat_msg("<function=read the llama_perf_context_data struct>", true, diffs, true);
4499+
assert_equals(size_t(0), diffs.size());
4500+
4501+
diffs.clear();
4502+
// Second argument is false here, unlike every other call in this test —
// presumably the final (non-partial) update; TODO confirm.
raw_state.update_chat_msg(" trailing text", false, diffs, true);
4503+
assert_equals(size_t(0), diffs.size());
4504+
}
4505+
4506+
// Fresh state: a "<function=" marker inside a fenced code block must come back
// unchanged as a single content delta.
{
4507+
task_result_state code_state(parser_params);
4508+
std::vector<common_chat_msg_diff> diffs;
4509+
const std::string code = "```xml\n<function=example>\n```\n";
4510+
code_state.update_chat_msg(code, true, diffs, true);
4511+
assert_equals(size_t(1), diffs.size());
4512+
assert_equals(code, diffs[0].content_delta);
4513+
}
4514+
4515+
// Qwen3.5-4B template: grammar trigger check, one-shot parse of a wrapperless call,
// then the same call fed piece by piece through a task_result_state.
{
4516+
auto qwen_tmpls = read_templates("models/templates/Qwen3.5-4B.jinja");
4517+
4518+
common_chat_templates_inputs qwen_inputs;
4519+
qwen_inputs.messages = { message_user };
4520+
qwen_inputs.tools = { special_function_tool };
4521+
qwen_inputs.parallel_tool_calls = true;
4522+
4523+
make_peg_parser qwen_parser(qwen_tmpls.get(), qwen_inputs);
4524+
common_chat_parser_params qwen_params(qwen_parser.params_);
4525+
qwen_params.parser = qwen_parser.arena_;
4526+
qwen_params.parse_tool_calls = true;
4527+
4528+
// At least one grammar trigger must have the exact value "<function=".
bool has_direct_function_trigger = false;
4529+
for (const auto & trigger : qwen_parser.params_.grammar_triggers) {
4530+
has_direct_function_trigger = has_direct_function_trigger || trigger.value == "<function=";
4531+
}
4532+
assert_equals(true, has_direct_function_trigger);
4533+
4534+
// A call that starts directly at "<function=...>" must parse into one tool call.
const std::string direct_call =
4535+
"<function=special_function>\n"
4536+
"<parameter=arg1>\n"
4537+
"1\n"
4538+
"</parameter>\n"
4539+
"</function>\n";
4540+
const auto direct_msg = common_chat_peg_parse(qwen_parser.arena_, direct_call, false, qwen_params);
4541+
assert_equals(size_t(1), direct_msg.tool_calls.size());
4542+
assert_equals(std::string("special_function"), direct_msg.tool_calls[0].name);
4543+
assert_equals(std::string("{\"arg1\":1}"), direct_msg.tool_calls[0].arguments);
4544+
4545+
task_result_state qwen_state(qwen_params);
4546+
4547+
// Function start line alone -> no diffs expected.
std::vector<common_chat_msg_diff> diffs;
4548+
qwen_state.update_chat_msg("<function=special_function>\n", true, diffs, true);
4549+
assert_equals(std::vector<common_chat_msg_diff>{}, diffs);
4550+
4551+
diffs.clear();
4552+
// Parameter start -> one diff with the tool name, a non-empty id and empty arguments.
qwen_state.update_chat_msg("<parameter=arg1>\n", true, diffs, true);
4553+
assert_equals(size_t(1), diffs.size());
4554+
assert_equals(size_t(0), diffs[0].tool_call_index);
4555+
assert_equals(std::string("special_function"), diffs[0].tool_call_delta.name);
4556+
assert_equals(false, diffs[0].tool_call_delta.id.empty());
4557+
assert_equals(std::string(""), diffs[0].tool_call_delta.arguments);
4558+
4559+
diffs.clear();
4560+
// Parameter value without its closing tag -> no diffs expected.
qwen_state.update_chat_msg("1\n", true, diffs, true);
4561+
assert_equals(size_t(0), diffs.size());
4562+
4563+
diffs.clear();
4564+
// Closing tags -> one diff with the complete JSON arguments and empty name/id.
qwen_state.update_chat_msg("</parameter>\n</function>\n", true, diffs, true);
4565+
assert_equals(size_t(1), diffs.size());
4566+
assert_equals(size_t(0), diffs[0].tool_call_index);
4567+
assert_equals(std::string(""), diffs[0].tool_call_delta.name);
4568+
assert_equals(std::string(""), diffs[0].tool_call_delta.id);
4569+
assert_equals(std::string("{\"arg1\":1}"), diffs[0].tool_call_delta.arguments);
4570+
}
4571+
}
4572+
44374573
int main(int argc, char ** argv) {
44384574
bool detailed_debug = false;
44394575
bool only_run_filtered = false;
@@ -4512,6 +4648,7 @@ int main(int argc, char ** argv) {
45124648
test_convert_responses_to_chatcmpl();
45134649
test_developer_role_to_system_workaround();
45144650
test_reka_edge_common_path();
4651+
test_task_result_state_tool_call_stream_filter();
45154652
test_template_output_peg_parsers(detailed_debug);
45164653
std::cout << "\n[chat] All tests passed!" << '\n';
45174654
}

tests/test-dflash-plumbing.cpp

Lines changed: 26 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -48,6 +48,8 @@ int main(int argc, char ** argv) {
4848
const std::string sampling_h = read_file(root + "/common/sampling.h");
4949
const std::string sampling_cpp = read_file(root + "/common/sampling.cpp");
5050
const std::string server_context = read_file(root + "/tools/server/server-context.cpp");
51+
const std::string server_task = read_file(root + "/tools/server/server-task.cpp");
52+
const std::string chat_auto_parser_generator = read_file(root + "/common/chat-auto-parser-generator.cpp");
5153
const std::string speculative = read_file(root + "/common/speculative.cpp");
5254
const std::string dflash_draft = read_file(root + "/src/models/dflash_draft.cpp");
5355
const std::string memory_recurrent = read_file(root + "/src/llama-memory-recurrent.cpp");
@@ -188,11 +190,34 @@ int main(int argc, char ** argv) {
188190
ok &= expect(cuda_argmax.find("const float raw_logit = heap_idx[i] >= 0 ? rowx[heap_idx[i]] : -FLT_MAX;") != std::string::npos, "CUDA deterministic top-K must return raw logits, not zero scores");
189191
ok &= expect(cuda_argmax.find("cub::DeviceTopK::MaxPairs") != std::string::npos, "CUDA deterministic top-K must use CUB fast path when available");
190192
ok &= expect(sampling_h.find("common_sampler_sample_reduced_and_accept_n") != std::string::npos, "common sampler must expose reduced-candidate verifier sampling");
193+
ok &= expect(sampling_h.find("common_sampler_blocks_speculative") != std::string::npos, "common sampler must expose grammar/reasoning guard for speculative decoding");
194+
ok &= expect(sampling_cpp.find("Lazy grammars are safe to speculate while still awaiting their trigger") != std::string::npos, "speculative guard must keep DFlash available before lazy grammar triggers");
195+
ok &= expect(sampling_cpp.find("if (common_sampler_blocks_speculative(gsmpl))") != std::string::npos, "speculative accept must stop when a token activates grammar/reasoning boundaries");
191196
ok &= expect(sampling_cpp.find("llama_sampler_apply(gsmpl->chain, &gsmpl->cur_p)") != std::string::npos, "reduced verifier must still run the sampler chain");
192197
ok &= expect(sampling_cpp.find("gsmpl->cur_p = { gsmpl->cur.data(), gsmpl->cur.size(), -1, false }") != std::string::npos, "reduced verifier sampler must tolerate unsorted GPU top-K candidates");
193198
ok &= expect(sampling_cpp.find("common_reasoning_budget_get_state(gsmpl->rbudget) != REASONING_BUDGET_FORCING") != std::string::npos, "reduced verifier must allow passthrough reasoning-budget tracking");
194199
ok &= expect(sampling_cpp.find("llama_sampler_apply(gsmpl->rbudget, &gsmpl->cur_p)") != std::string::npos, "reduced verifier must preserve reasoning-budget sampler state");
195200
ok &= expect(server_context.find("dflash_select_reduced_verify_plan") != std::string::npos, "server must explicitly choose reduced verifier eligibility");
201+
ok &= expect(server_context.find("common_sampler_blocks_speculative(slot.smpl.get())") != std::string::npos, "DFlash server path must skip drafting when grammar/reasoning guard requires full sampling");
202+
ok &= expect(server_context.find("common_sampler_blocks_speculative(smpl)") != std::string::npos, "DFlash rejection sampling must stop at grammar/reasoning boundaries");
203+
ok &= expect(server_context.find("speculative_flat_result_has_bonus") != std::string::npos, "server must distinguish grammar-boundary stops from bonus-token accepts");
204+
ok &= expect(server_context.find("n_hidden_keep = ids.empty() ? 0 : n_accepted_draft + 1") != std::string::npos, "DFlash ring/tape keep count must include root plus accepted draft tokens");
205+
ok &= expect(server_context.find("common_speculative_accept(slot.spec.get(), n_accepted_draft)") != std::string::npos, "speculative stats must count accepted draft tokens, not bonus-token-shaped results");
206+
ok &= expect(server_context.find("llama_dflash_rollback(ctx, slot.id, seq_backup, slot.n_pos_before_draft, n_hidden_keep)") != std::string::npos, "DFlash rollback must use the hidden-state keep count at grammar boundaries");
207+
ok &= expect(server_context.find("dflash_suppressed_for_reasoning_tool_marker") != std::string::npos, "server must disable DFlash after raw tool markers inside hidden reasoning without steering generation");
208+
ok &= expect(server_task.find("state.update_chat_msg(content, true, oaicompat_msg_diffs, true)") != std::string::npos, "streaming responses must filter partial tool-call deltas");
209+
ok &= expect(server_task.find("task_result_has_complete_partial_tool_calls") != std::string::npos, "streaming responses must allow complete tool-call deltas before final EOS");
210+
ok &= expect(server_task.find("task_result_filter_incomplete_partial_tool_calls") != std::string::npos, "streaming responses must expose stable tool-call headers without partial arguments");
211+
ok &= expect(server_task.find("A partial stream may expose the stable tool name/id for UX") != std::string::npos, "partial tool-call streaming must document the header-only reliability boundary");
212+
ok &= expect(server_task.find("task_result_quarantine_raw_tool_text") != std::string::npos, "streaming responses must quarantine malformed raw tool markers in tool-parsing mode");
213+
ok &= expect(server_task.find("task_result_pos_is_in_code_fence") != std::string::npos, "raw marker quarantine must avoid code fence content");
214+
ok &= expect(server_task.find("task_result_starts_with_raw_tool_marker") != std::string::npos, "streaming responses must suppress parser fallback text for wrapperless raw tool calls");
215+
ok &= expect(server_task.find("task_result_freeze_text_fields") != std::string::npos, "incomplete parsed tool calls must not leak fallback text/reasoning deltas");
216+
ok &= expect(server_context.find("raw tool marker observed while lazy grammar is enabled") != std::string::npos, "DFlash must suppress after raw tool markers even outside parsed reasoning");
217+
ok &= expect(server_context.find("server_tail_pos_is_in_code_fence") != std::string::npos, "DFlash raw-marker suppression must avoid fenced code content");
218+
ok &= expect(server_context.find("server_tail_tool_marker_has_boundary") != std::string::npos, "DFlash raw-marker suppression must avoid embedded string false positives");
219+
ok &= expect(chat_auto_parser_generator.find("allow_direct_func_start") != std::string::npos, "tag-style parsers must accept valid direct function starts without outer wrappers");
220+
ok &= expect(chat_auto_parser_generator.find("autoparser.tools.function.name_prefix") != std::string::npos, "lazy grammar triggers must include structural function markers");
196221
ok &= expect(server_context.find("sampling.has_logit_bias() || sampling.ignore_eos") != std::string::npos, "server must not treat inactive precomputed EOG biases as active logit bias");
197222
ok &= expect(server_context.find("finite-reasoning-budget") != std::string::npos, "server must disable reduced verifier only for finite reasoning budgets");
198223
ok &= expect(server_context.find("llama_set_dflash_verify_logits(ctx, dflash_verify_graph_enabled") != std::string::npos, "server must enable reduced verifier graph once per eligible batch");
@@ -224,7 +249,7 @@ int main(int argc, char ** argv) {
224249
ok &= expect(server_context.find("rows_available") != std::string::npos, "server DFlash verifier padding must respect batch and ubatch capacity");
225250
ok &= expect(server_context.find("for (int idx : slot.spec_pad_i_batch)") != std::string::npos, "reduced verifier coverage must account for explicit padding rows");
226251
ok &= expect(server_context.find("const bool had_dflash_padding = !slot.spec_pad_i_batch.empty()") != std::string::npos, "server must remember verifier padding through accept bookkeeping");
227-
ok &= expect(server_context.find("const bool all_accepted_flat = (ids.size() == n_draft + 1) && !had_dflash_padding") != std::string::npos, "DFlash verifier padding must force rollback even when all real draft tokens were accepted");
252+
ok &= expect(server_context.find("const bool all_accepted_flat = (n_accepted_draft == (int) n_draft) && !had_dflash_padding") != std::string::npos, "DFlash verifier padding must force rollback even when all real draft tokens were accepted");
228253

229254
return ok ? 0 : 1;
230255
}

0 commit comments

Comments
 (0)