Skip to content

Commit 1c5d208

Browse files
committed
fix: defer common_sampler_accept until after p_accept resolution
Fixes sampler state bug identified by Ooooze - previously common_sampler_accept was called with target id before p_accept check, leaving grammar FSM and gsmpl->prev tracking wrong token when draft token was substituted.
1 parent 12706d0 commit 1c5d208

1 file changed

Lines changed: 6 additions & 3 deletions

File tree

common/sampling.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -614,8 +614,6 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
614614
size_t i = 0;
615615
for (; i < draft.size(); i++) {
616616
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
617-
common_sampler_accept(gsmpl, id, true);
618-
result.push_back(id);
619617
if (draft[i] != id) {
620618
if (p_accept > 0.0f) {
621619
const float * logits = llama_get_logits_ith(ctx, idxs[i]);
@@ -625,12 +623,17 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
625623
for (int j = 0; j < n_vocab; j++) sum += expf(logits[j] - max_l);
626624
const float p_main = expf(logits[draft[i]] - max_l) / sum;
627625
if (p_main >= p_accept) {
628-
result.back() = draft[i];
626+
common_sampler_accept(gsmpl, draft[i], true);
627+
result.push_back(draft[i]);
629628
continue;
630629
}
631630
}
631+
common_sampler_accept(gsmpl, id, true);
632+
result.push_back(id);
632633
break;
633634
}
635+
common_sampler_accept(gsmpl, id, true);
636+
result.push_back(id);
634637
}
635638

636639
if (i == draft.size()) {

0 commit comments

Comments
 (0)