Skip to content

Commit 82ecad0

Browse files
committed
server: stop DFlash at grammar tool boundaries
Keep DFlash active before lazy grammar triggers, then stop speculative accept/drafting once grammar, reasoning-budget forcing, or raw tool-call markers require normal token-by-token sampling. Track accepted draft tokens separately from hidden-state rows so DFlash rollback and ring updates stay aligned at grammar/tool boundaries. Fixes #5 Refs #6
1 parent 84efc67 commit 82ecad0

4 files changed

Lines changed: 207 additions & 42 deletions

File tree

common/sampling.cpp

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -653,19 +653,38 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
653653
return id;
654654
}
655655

656+
static bool common_sampler_has_speculative_unsafe_grammar(const struct common_sampler * gsmpl) {
657+
if (!gsmpl || !gsmpl->grmr) {
658+
return false;
659+
}
660+
661+
// Lazy grammars are safe to speculate while still awaiting their trigger.
662+
// Once triggered, grammar-constrained regions need normal full-vocab
663+
// sampling and one-token streaming/parser boundaries.
664+
return llama_sampler_grammar_is_active(gsmpl->grmr);
665+
}
666+
667+
bool common_sampler_blocks_speculative(const struct common_sampler * gsmpl) {
668+
if (!gsmpl) {
669+
return true;
670+
}
671+
if (common_sampler_has_speculative_unsafe_grammar(gsmpl)) {
672+
return true;
673+
}
674+
return common_reasoning_budget_get_state(gsmpl->rbudget) == REASONING_BUDGET_FORCING;
675+
}
676+
656677
bool common_sampler_supports_reduced(struct common_sampler * gsmpl) {
657678
if (!gsmpl) {
658679
return false;
659680
}
660-
// A grammar sampler exists but may be lazy+inactive (awaiting trigger).
661-
// Only reject when grammar is actively constraining tokens.
662-
if (gsmpl->grmr && llama_sampler_grammar_is_active(gsmpl->grmr)) {
681+
if (common_sampler_has_speculative_unsafe_grammar(gsmpl)) {
663682
return false;
664683
}
665-
if (common_reasoning_budget_get_state(gsmpl->rbudget) == REASONING_BUDGET_FORCING) {
666-
return false;
684+
if (common_reasoning_budget_get_state(gsmpl->rbudget) != REASONING_BUDGET_FORCING) {
685+
return true;
667686
}
668-
return true;
687+
return false;
669688
}
670689

671690
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
@@ -682,6 +701,10 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
682701

683702
result.push_back(id);
684703

704+
if (common_sampler_blocks_speculative(gsmpl)) {
705+
break;
706+
}
707+
685708
if (draft[i] != id) {
686709
break;
687710
}
@@ -760,6 +783,10 @@ std::vector<llama_token> common_sampler_sample_reduced_and_accept_n(
760783
common_sampler_accept(gsmpl, id, true);
761784
result.push_back(id);
762785

786+
if (common_sampler_blocks_speculative(gsmpl)) {
787+
break;
788+
}
789+
763790
if (draft[i] != id) {
764791
break;
765792
}

common/sampling.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl);
6767
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
6868

6969
bool common_sampler_supports_reduced(struct common_sampler * gsmpl);
70+
bool common_sampler_blocks_speculative(const struct common_sampler * gsmpl);
7071

7172
// generalized version of common_sampler_sample
7273
//

common/speculative.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ struct common_speculative_state {
255255

256256
// called after verification decode with logits still in ctx
257257
// batch_tokens: tokens that were in the batch [id_last, draft0, draft1, ...]
258-
// n_accepted: how many were accepted (ids.size(), including the bonus token)
258+
// n_accepted: number of decoded batch rows to commit (root + accepted draft tokens)
259259
virtual void update_logits(llama_context * /*ctx*/, const llama_tokens & /*batch_tokens*/, int /*n_accepted*/) {}
260260

261261
// tree variant: accept specific capture-buffer indices instead of a contiguous block.
@@ -2106,9 +2106,9 @@ struct common_speculative_state_dflash : public common_speculative_state {
21062106
void update_logits(llama_context * ctx, const llama_tokens & batch_tokens, int n_accepted) override {
21072107
GGML_UNUSED(ctx);
21082108
GGML_UNUSED(batch_tokens);
2109-
// n_accepted includes the bonus token: [id_last, draft0, ..., draftN-1] → accepted count
2110-
// the verification batch had (1 + n_draft) tokens
2111-
// only the first n_accepted tokens' hidden states should be kept
2109+
// In this path n_accepted means committed hidden-state rows, not output-token count.
2110+
// [id_last, draft0, ..., draftN-1] => root + accepted draft tokens.
2111+
// Boundary stops pass root + accepted draft tokens even when no bonus token was sampled.
21122112
append_target_hiddens(n_accepted);
21132113
}
21142114

0 commit comments

Comments
 (0)