Skip to content

Commit 1856e35

Browse files
committed
reasoning-budget: clone should do a deep-copy
1 parent d528444 commit 1856e35

2 files changed

Lines changed: 74 additions & 12 deletions

File tree

common/reasoning-budget.cpp

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -171,22 +171,12 @@ static void common_reasoning_budget_reset(struct llama_sampler * smpl) {
171171
ctx->force_pos = 0;
172172
}
173173

174-
// forward declaration for use in clone
175174
static struct llama_sampler * common_reasoning_budget_init_state(
176175
const struct llama_vocab * vocab, const std::vector<llama_token> & start_tokens,
177176
const std::vector<llama_token> & end_tokens, const std::vector<llama_token> & forced_tokens,
178177
int32_t budget, common_reasoning_budget_state initial_state);
179178

180-
static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl) {
181-
const auto * ctx = (const common_reasoning_budget_ctx *) smpl->ctx;
182-
return common_reasoning_budget_init_state(
183-
ctx->vocab,
184-
ctx->start_matcher.tokens,
185-
ctx->end_matcher.tokens,
186-
ctx->forced_tokens,
187-
ctx->budget,
188-
ctx->state);
189-
}
179+
static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl);
190180

191181
static void common_reasoning_budget_free(struct llama_sampler * smpl) {
192182
delete (common_reasoning_budget_ctx *) smpl->ctx;
@@ -205,6 +195,15 @@ static struct llama_sampler_i common_reasoning_budget_i = {
205195
/* .backend_set_input = */ nullptr,
206196
};
207197

198+
static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl) {
199+
const auto * ctx = (const common_reasoning_budget_ctx *) smpl->ctx;
200+
201+
return llama_sampler_init(
202+
/* .iface = */ &common_reasoning_budget_i,
203+
/* .ctx = */ new common_reasoning_budget_ctx(*ctx)
204+
);
205+
}
206+
208207
static struct llama_sampler * common_reasoning_budget_init_state(
209208
const struct llama_vocab * vocab,
210209
const std::vector<llama_token> & start_tokens,

tests/test-reasoning-budget.cpp

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,66 @@ static void test_reasoning_budget(
124124
(void)sequence;
125125
}
126126

127+
static llama_token get_forced_token(struct llama_sampler * sampler, llama_token max_token) {
128+
std::vector<llama_token_data> cur;
129+
const size_t n_vocab = (size_t) max_token + 1;
130+
for (size_t i = 0; i < n_vocab; i++) {
131+
cur.emplace_back(llama_token_data{(llama_token) i, logf((float) (i + 1)), 0.0f});
132+
}
133+
134+
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
135+
llama_sampler_apply(sampler, &cur_p);
136+
137+
size_t finite_count = 0;
138+
llama_token finite_token = LLAMA_TOKEN_NULL;
139+
for (size_t i = 0; i < cur.size(); i++) {
140+
if (std::isfinite(cur[i].logit)) {
141+
finite_count++;
142+
finite_token = cur[i].id;
143+
}
144+
}
145+
146+
GGML_ASSERT(finite_count == 1 && "sampler is not forcing exactly one token");
147+
return finite_token;
148+
}
149+
150+
static void test_reasoning_budget_clone_mid_counting() {
151+
const std::vector<llama_token> start = {100};
152+
const std::vector<llama_token> end = {101};
153+
const std::vector<llama_token> forced = {102, 101};
154+
155+
auto * sampler = common_reasoning_budget_init(nullptr, start, end, forced, 2, REASONING_BUDGET_IDLE);
156+
157+
llama_sampler_accept(sampler, 100); // COUNTING, remaining=2
158+
llama_sampler_accept(sampler, 50); // COUNTING, remaining=1
159+
160+
auto * clone = llama_sampler_clone(sampler);
161+
llama_sampler_accept(clone, 51); // should exhaust the cloned remaining budget
162+
163+
GGML_ASSERT(get_forced_token(clone, 102) == 102 && "cloned counting state lost remaining budget");
164+
165+
llama_sampler_free(clone);
166+
llama_sampler_free(sampler);
167+
}
168+
169+
static void test_reasoning_budget_clone_mid_forcing() {
170+
const std::vector<llama_token> start = {100};
171+
const std::vector<llama_token> end = {101};
172+
const std::vector<llama_token> forced = {102, 101};
173+
174+
auto * sampler = common_reasoning_budget_init(nullptr, start, end, forced, 0, REASONING_BUDGET_FORCING);
175+
176+
GGML_ASSERT(get_forced_token(sampler, 102) == 102);
177+
llama_sampler_accept(sampler, 102); // advance to the second forced token
178+
179+
auto * clone = llama_sampler_clone(sampler);
180+
181+
GGML_ASSERT(get_forced_token(clone, 102) == 101 && "cloned forcing state lost force position");
182+
183+
llama_sampler_free(clone);
184+
llama_sampler_free(sampler);
185+
}
186+
127187
// UTF-8 boundary detection unit test
128188
// Tests common_utf8_is_complete() from reasoning-budget.h
129189
static void test_utf8_boundary_detection() {
@@ -250,7 +310,10 @@ int main(void) {
250310
7); // forcing continues through i=7
251311
}
252312

253-
printf("OK (6 tests passed)\n");
313+
test_reasoning_budget_clone_mid_counting();
314+
test_reasoning_budget_clone_mid_forcing();
315+
316+
printf("OK (8 tests passed)\n");
254317

255318
printf("Testing UTF-8 boundary detection... ");
256319
test_utf8_boundary_detection();

0 commit comments

Comments
 (0)