@@ -124,6 +124,66 @@ static void test_reasoning_budget(
124124 (void )sequence;
125125}
126126
127+ static llama_token get_forced_token (struct llama_sampler * sampler, llama_token max_token) {
128+ std::vector<llama_token_data> cur;
129+ const size_t n_vocab = (size_t ) max_token + 1 ;
130+ for (size_t i = 0 ; i < n_vocab; i++) {
131+ cur.emplace_back (llama_token_data{(llama_token) i, logf ((float ) (i + 1 )), 0 .0f });
132+ }
133+
134+ llama_token_data_array cur_p = { cur.data (), cur.size (), -1 , false };
135+ llama_sampler_apply (sampler, &cur_p);
136+
137+ size_t finite_count = 0 ;
138+ llama_token finite_token = LLAMA_TOKEN_NULL ;
139+ for (size_t i = 0 ; i < cur.size (); i++) {
140+ if (std::isfinite (cur[i].logit )) {
141+ finite_count++;
142+ finite_token = cur[i].id ;
143+ }
144+ }
145+
146+ GGML_ASSERT (finite_count == 1 && " sampler is not forcing exactly one token" );
147+ return finite_token;
148+ }
149+
150+ static void test_reasoning_budget_clone_mid_counting () {
151+ const std::vector<llama_token> start = {100 };
152+ const std::vector<llama_token> end = {101 };
153+ const std::vector<llama_token> forced = {102 , 101 };
154+
155+ auto * sampler = common_reasoning_budget_init (nullptr , start, end, forced, 2 , REASONING_BUDGET_IDLE );
156+
157+ llama_sampler_accept (sampler, 100 ); // COUNTING, remaining=2
158+ llama_sampler_accept (sampler, 50 ); // COUNTING, remaining=1
159+
160+ auto * clone = llama_sampler_clone (sampler);
161+ llama_sampler_accept (clone, 51 ); // should exhaust the cloned remaining budget
162+
163+ GGML_ASSERT (get_forced_token (clone, 102 ) == 102 && " cloned counting state lost remaining budget" );
164+
165+ llama_sampler_free (clone);
166+ llama_sampler_free (sampler);
167+ }
168+
169+ static void test_reasoning_budget_clone_mid_forcing () {
170+ const std::vector<llama_token> start = {100 };
171+ const std::vector<llama_token> end = {101 };
172+ const std::vector<llama_token> forced = {102 , 101 };
173+
174+ auto * sampler = common_reasoning_budget_init (nullptr , start, end, forced, 0 , REASONING_BUDGET_FORCING );
175+
176+ GGML_ASSERT (get_forced_token (sampler, 102 ) == 102 );
177+ llama_sampler_accept (sampler, 102 ); // advance to the second forced token
178+
179+ auto * clone = llama_sampler_clone (sampler);
180+
181+ GGML_ASSERT (get_forced_token (clone, 102 ) == 101 && " cloned forcing state lost force position" );
182+
183+ llama_sampler_free (clone);
184+ llama_sampler_free (sampler);
185+ }
186+
127187// UTF-8 boundary detection unit test
128188// Tests common_utf8_is_complete() from reasoning-budget.h
129189static void test_utf8_boundary_detection () {
@@ -250,7 +310,10 @@ int main(void) {
250310 7 ); // forcing continues through i=7
251311 }
252312
253- printf (" OK (6 tests passed)\n " );
313+ test_reasoning_budget_clone_mid_counting ();
314+ test_reasoning_budget_clone_mid_forcing ();
315+
316+ printf (" OK (8 tests passed)\n " );
254317
255318 printf (" Testing UTF-8 boundary detection... " );
256319 test_utf8_boundary_detection ();
0 commit comments