
Commit d775992

common : do not pass prompt tokens to reasoning budget sampler (#22488)

1 parent 41a63be

4 files changed: 38 additions & 70 deletions
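In short: common_reasoning_budget_init no longer inspects the prompt itself. The caller constructs the sampler in its default IDLE state and replays the prompt-derived prefill tokens through llama_sampler_accept, mirroring what common_sampler_init already does for the grammar sampler. A minimal sketch of the new call pattern, assuming the identifiers from the diff below and an already-computed prefill:

    // Sketch only: vocab, the token sequences, budget, and prefill_tokens are
    // assumed to exist; the calls are the ones introduced by this commit.
    struct llama_sampler * rbudget = common_reasoning_budget_init(
        vocab, start_tokens, end_tokens, forced_tokens, budget);

    // The sampler starts IDLE; replaying the prompt tokens lets its own state
    // machine decide whether the prefill opened a reasoning block.
    for (const auto & token : prefill_tokens) {
        llama_sampler_accept(rbudget, token);
    }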


common/reasoning-budget.cpp

Lines changed: 0 additions & 28 deletions
@@ -232,34 +232,6 @@ static struct llama_sampler * common_reasoning_budget_init_state(
     );
 }
 
-struct llama_sampler * common_reasoning_budget_init(
-        const struct llama_vocab * vocab,
-        const std::vector<llama_token> & start_tokens,
-        const std::vector<llama_token> & end_tokens,
-        const std::vector<llama_token> & forced_tokens,
-        int32_t budget,
-        const std::vector<llama_token> & prefill_tokens) {
-    // Determine initial state from prefill: COUNTING if the prefill begins with
-    // the start sequence but does not also contain the end sequence after it.
-    common_reasoning_budget_state initial_state = REASONING_BUDGET_IDLE;
-    if (!prefill_tokens.empty() && !start_tokens.empty() &&
-        prefill_tokens.size() >= start_tokens.size() &&
-        std::equal(start_tokens.begin(), start_tokens.end(), prefill_tokens.begin())) {
-        initial_state = REASONING_BUDGET_COUNTING;
-        // If the end sequence also follows the start in the prefill, reasoning
-        // was opened and immediately closed — stay IDLE.
-        if (!end_tokens.empty() &&
-            prefill_tokens.size() >= start_tokens.size() + end_tokens.size()) {
-            auto end_start = prefill_tokens.end() - (ptrdiff_t) end_tokens.size();
-            if (end_start >= prefill_tokens.begin() + (ptrdiff_t) start_tokens.size() &&
-                std::equal(end_tokens.begin(), end_tokens.end(), end_start)) {
-                initial_state = REASONING_BUDGET_IDLE;
-            }
-        }
-    }
-    return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state);
-}
-
 struct llama_sampler * common_reasoning_budget_init(
         const struct llama_vocab * vocab,
         const std::vector<llama_token> & start_tokens,
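The deleted overload classified the prefill up front with explicit prefix/suffix matching; after this change the same classification should fall out of replaying the tokens. A hedged illustration, with made-up token ids standing in for the start/end sequences and a live vocab assumed:

    // Hypothetical ids: 1001 opens a reasoning block, 1002 closes it.
    const std::vector<llama_token> start_toks  = { 1001 };
    const std::vector<llama_token> end_toks    = { 1002 };
    const std::vector<llama_token> forced_toks = { 1002 };

    struct llama_sampler * smpl = common_reasoning_budget_init(
        vocab, start_toks, end_toks, forced_toks, /*budget=*/64);

    // Prefill that opens a block: the sampler should now be COUNTING.
    llama_sampler_accept(smpl, 1001);
    GGML_ASSERT(common_reasoning_budget_get_state(smpl) == REASONING_BUDGET_COUNTING);

    // Prefill that also closed the block: back to IDLE, matching the old
    // "opened and immediately closed" special case.
    llama_sampler_accept(smpl, 1002);
    GGML_ASSERT(common_reasoning_budget_get_state(smpl) == REASONING_BUDGET_IDLE);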

common/reasoning-budget.h

Lines changed: 2 additions & 15 deletions
@@ -29,27 +29,14 @@ enum common_reasoning_budget_state {
 // end_tokens - token sequence for natural deactivation
 // forced_tokens - token sequence forced when budget expires
 // budget - max tokens allowed in the reasoning block
-// prefill_tokens - tokens already present in the prompt (generation prompt);
-//                  used to determine the initial state: COUNTING if they begin
-//                  with start_tokens (but don't also end with end_tokens),
-//                  IDLE otherwise. COUNTING with budget <= 0 is promoted to FORCING.
+// initial_state - initial state
 //
 struct llama_sampler * common_reasoning_budget_init(
         const struct llama_vocab * vocab,
         const std::vector<llama_token> & start_tokens,
         const std::vector<llama_token> & end_tokens,
         const std::vector<llama_token> & forced_tokens,
         int32_t budget,
-        const std::vector<llama_token> & prefill_tokens = {});
-
-// Variant that takes an explicit initial state (used by tests and clone).
-// COUNTING with budget <= 0 is promoted to FORCING.
-struct llama_sampler * common_reasoning_budget_init(
-        const struct llama_vocab * vocab,
-        const std::vector<llama_token> & start_tokens,
-        const std::vector<llama_token> & end_tokens,
-        const std::vector<llama_token> & forced_tokens,
-        int32_t budget,
-        common_reasoning_budget_state initial_state);
+        common_reasoning_budget_state initial_state = REASONING_BUDGET_IDLE);
 
 common_reasoning_budget_state common_reasoning_budget_get_state(const struct llama_sampler * smpl);
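With the prefill-based overload removed, the single remaining entry point covers both uses: ordinary callers rely on the IDLE default, while tests and clone (the users of the old explicit-state variant) pass a state directly. A sketch, assuming an existing vocab and token sequences:

    // Default construction: starts IDLE.
    auto * a = common_reasoning_budget_init(vocab, start_toks, end_toks, forced_toks, 64);

    // Explicit state, as tests and clone need; per the removed comment,
    // COUNTING with budget <= 0 is promoted to FORCING.
    auto * b = common_reasoning_budget_init(vocab, start_toks, end_toks, forced_toks, 64,
                                            REASONING_BUDGET_COUNTING);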

common/sampling.cpp

Lines changed: 34 additions & 25 deletions
@@ -260,32 +260,35 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
         }
     }
 
-    // Feed generation prompt tokens to the grammar sampler so it advances past
-    // tokens the template already placed in the prompt.
-    // Only applies to output-format and tool-call grammars; user-supplied grammars must not be prefilled.
+    // Compute prefill tokens from the generation prompt
     std::vector<llama_token> prefill_tokens;
-    if (!params.generation_prompt.empty() && common_grammar_needs_prefill(params.grammar)) {
+    if (!params.generation_prompt.empty()) {
         GGML_ASSERT(vocab != nullptr);
-        prefill_tokens = common_tokenize(vocab, params.generation_prompt, false, true);
-        if (!prefill_tokens.empty()) {
-            std::string first_token = common_token_to_piece(vocab, prefill_tokens[0], true);
-            if (std::isspace(first_token[0]) && !std::isspace(params.generation_prompt[0])) {
-                // Some tokenizers will add a space before the first special token, need to remove
-                prefill_tokens = std::vector<llama_token>(prefill_tokens.begin() + 1, prefill_tokens.end());
+        auto tokens = common_tokenize(vocab, params.generation_prompt, false, true);
+        for (size_t i = 0; i < tokens.size(); i++) {
+            std::string piece = common_token_to_piece(vocab, tokens[i], true);
+            if (i == 0 && std::isspace(piece[0]) && !std::isspace(params.generation_prompt[0])) {
+                // Some tokenizers will add a space before the first special token, need to exclude
+                continue;
             }
+            LOG_DBG("%s: prefill token: %d = %s\n", __func__, tokens[i], piece.c_str());
+            prefill_tokens.push_back(tokens[i]);
         }
+    }
 
-        if (grmr && !params.grammar_lazy) {
-            try {
-                for (const auto & token : prefill_tokens) {
-                    llama_sampler_accept(grmr, token);
-                    LOG_DBG("%s: accepted prefill token (%d)\n", __func__, token);
-                }
-            } catch (std::exception &e) {
-                LOG_ERR("%s: error initializing grammar sampler for grammar:\n%s\n\nGeneration prompt:\n'%s'\n", __func__,
-                    common_grammar_value(params.grammar).c_str(), params.generation_prompt.c_str());
-                throw e;
+    // Feed generation prompt tokens to the grammar sampler so it advances past
+    // tokens the template already placed in the prompt.
+    // Only applies to output-format and tool-call grammars; user-supplied grammars must not be prefilled.
+    if (grmr && !params.grammar_lazy && common_grammar_needs_prefill(params.grammar)) {
+        try {
+            for (const auto & token : prefill_tokens) {
+                llama_sampler_accept(grmr, token);
+                LOG_DBG("%s: grammar accepted prefill token (%d)\n", __func__, token);
             }
+        } catch (std::exception &e) {
+            LOG_ERR("%s: error initializing grammar sampler for grammar:\n%s\n\nGeneration prompt:\n'%s'\n", __func__,
+                common_grammar_value(params.grammar).c_str(), params.generation_prompt.c_str());
+            throw e;
         }
     }
 
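The rewritten loop above filters tokens one by one instead of erasing the first element, so the tokenizer's spurious leading space is skipped without rebuilding the vector. Pulled out as a standalone helper it would look roughly like this (the helper name is hypothetical; the calls are the ones used in the hunk, declared in llama.cpp's common/common.h):

    #include <cctype>
    #include <string>
    #include <vector>

    // Hypothetical helper equivalent to the prefill computation above;
    // assumes common.h for common_tokenize / common_token_to_piece.
    static std::vector<llama_token> compute_prefill_tokens(
            const struct llama_vocab * vocab, const std::string & generation_prompt) {
        std::vector<llama_token> prefill_tokens;
        if (generation_prompt.empty()) {
            return prefill_tokens;
        }
        auto tokens = common_tokenize(vocab, generation_prompt, /*add_special=*/false, /*parse_special=*/true);
        for (size_t i = 0; i < tokens.size(); i++) {
            std::string piece = common_token_to_piece(vocab, tokens[i], /*special=*/true);
            // Some tokenizers prepend a space before a leading special token;
            // drop it so the prefill matches the literal template output.
            if (i == 0 && !piece.empty() && std::isspace((unsigned char) piece[0]) &&
                !std::isspace((unsigned char) generation_prompt[0])) {
                continue;
            }
            prefill_tokens.push_back(tokens[i]);
        }
        return prefill_tokens;
    }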
@@ -296,8 +299,12 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
             params.reasoning_budget_start,
             params.reasoning_budget_end,
             params.reasoning_budget_forced,
-            params.reasoning_budget_tokens < 0 ? INT_MAX : params.reasoning_budget_tokens,
-            prefill_tokens);
+            params.reasoning_budget_tokens < 0 ? INT_MAX : params.reasoning_budget_tokens);
+
+        for (const auto & token : prefill_tokens) {
+            llama_sampler_accept(rbudget, token);
+            LOG_DBG("%s: reasoning-budget accepted prefill token (%d)\n", __func__, token);
+        }
     }
 
     if (params.has_logit_bias()) {
@@ -431,17 +438,19 @@ static bool grammar_should_apply(struct common_sampler * gsmpl) {
     return true;
 }
 
-void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
+void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool is_generated) {
     if (!gsmpl) {
         return;
     }
 
     const auto tm = gsmpl->tm();
 
     // grammar_should_apply() checks the reasoning budget state, so calculate this before we accept
-    accept_grammar = accept_grammar && grammar_should_apply(gsmpl);
+    const auto accept_grammar = is_generated && grammar_should_apply(gsmpl);
 
-    llama_sampler_accept(gsmpl->rbudget, token);
+    if (gsmpl->rbudget && is_generated) {
+        llama_sampler_accept(gsmpl->rbudget, token);
+    }
 
     if (gsmpl->grmr && accept_grammar) {
         llama_sampler_accept(gsmpl->grmr, token);

common/sampling.h

Lines changed: 2 additions & 2 deletions
@@ -41,8 +41,8 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
 
 void common_sampler_free(struct common_sampler * gsmpl);
 
-// if accept_grammar is true, the token is accepted both by the sampling chain and the grammar
-void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar);
+// if is_generated is true, the token is accepted by the sampling chain, the reasoning budget sampler, and the grammar sampler
+void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool is_generated);
 void common_sampler_reset (struct common_sampler * gsmpl);
 struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl);

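Call-site reading of the renamed flag: prompt-side tokens pass is_generated = false, so they update only the sampler's internal bookkeeping, while sampled tokens pass true and also advance the grammar and reasoning-budget samplers. A hedged sketch of a generation loop under those assumptions (ctx, vocab, gsmpl, and prompt_tokens come from elsewhere):

    // Prompt tokens: bookkeeping only, no grammar / reasoning-budget updates.
    for (const llama_token t : prompt_tokens) {
        common_sampler_accept(gsmpl, t, /*is_generated=*/false);
    }

    while (true) {
        const llama_token t = common_sampler_sample(gsmpl, ctx, /*idx=*/-1);
        // Sampled tokens advance every sampler, including the reasoning budget.
        common_sampler_accept(gsmpl, t, /*is_generated=*/true);
        if (llama_vocab_is_eog(vocab, t)) {
            break;
        }
        // ... decode t and continue ...
    }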