
Commit 80afa33

spec : fix draft model checkpoints (#22521)
* spec : fix draft model checkpoints
* cont : clean-up
* cont : gate the ngram-mod reset warning behind verbose flag
1 parent b42c7fa commit 80afa33

2 files changed: 62 additions & 59 deletions
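
For orientation, here is a minimal sketch of the checkpoint mechanics that the fixed code in common/speculative.cpp relies on. The llama.cpp calls and the LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY flag are the ones used in the diff below; the draft_ckpt struct and the free-standing ckpt_create / ckpt_restore helpers are named for this sketch only and stand in for the corresponding members of common_speculative_state_draft (common_speculative_checkpoint, create_checkpoint, restore_checkpoint).

    // Sketch (assumed names, simplified error handling): snapshot the draft context's
    // sequence state after its prompt batch is decoded, and roll back to that snapshot
    // after a partial acceptance instead of re-decoding the draft prompt.
    #include "llama.h"

    #include <cstdint>
    #include <vector>

    struct draft_ckpt {
        llama_pos pos_min  = 0;
        llama_pos pos_max  = 0;
        int64_t   n_tokens = 0;              // prompt tokens covered by the snapshot
        std::vector<uint8_t> data;           // serialized sequence state

        size_t size() const { return data.size(); }
    };

    // take the snapshot (the patch calls this right after the draft prompt batch is decoded)
    static size_t ckpt_create(llama_context * ctx_dft, draft_ckpt & ckpt, int n_tokens_prompt) {
        const int slot_id = 0;
        const size_t sz = llama_state_seq_get_size_ext(ctx_dft, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

        ckpt.pos_min  = llama_memory_seq_pos_min(llama_get_memory(ctx_dft), slot_id);
        ckpt.pos_max  = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), slot_id);
        ckpt.n_tokens = n_tokens_prompt;
        ckpt.data.resize(sz);

        return llama_state_seq_get_data_ext(ctx_dft, ckpt.data.data(), sz, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
    }

    // roll the draft context back to the snapshot and drop anything decoded past it
    static size_t ckpt_restore(llama_context * ctx_dft, const draft_ckpt & ckpt) {
        const int slot_id = 0;
        const size_t n = llama_state_seq_set_data_ext(ctx_dft, ckpt.data.data(), ckpt.size(), slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

        llama_memory_seq_rm(llama_get_memory(ctx_dft), slot_id, ckpt.pos_max + 1, -1);

        return n;
    }

The main behavioral change of the commit is where these two steps happen: the snapshot is taken right after the draft prompt batch is decoded (create_checkpoint(prompt_dft.size())), it is restored whenever a partial acceptance leaves the cached draft context ahead of what was reused, and a checkpoint that covers more tokens than can be reused is treated as outdated and dropped (ckpt = {}).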

common/speculative.cpp

Lines changed: 44 additions & 55 deletions
@@ -167,16 +167,14 @@ struct common_speculative_checkpoint {
     size_t size() const {
         return data.size();
     }
-
-    size_t ckpt_size = 0;
 };
 
 struct common_speculative_state_draft : public common_speculative_state {
     llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
     llama_context * ctx_dft;
 
     bool use_ckpt = false;
-    struct common_speculative_checkpoint ckpt;
+    common_speculative_checkpoint ckpt;
 
     common_sampler * smpl;
 
@@ -249,26 +247,16 @@ struct common_speculative_state_draft : public common_speculative_state {
         llama_batch_free(batch);
     }
 
-    void begin(const llama_tokens & prompt) override {
-        if (use_ckpt && ckpt.size() > 0) {
-            // delete checkpoint
-            LOG_DBG("%s: delete checkpoint, prompt.size=%zu, pos_min=%d, pos_max=%d, n_tokens=%" PRId64 ", size=%.3f MiB\n",
-                    __func__, prompt.size(), ckpt.pos_min, ckpt.pos_max, ckpt.n_tokens, (float) ckpt.data.size() / 1024 / 1024);
-            ckpt.pos_min = 0;
-            ckpt.pos_max = 0;
-            ckpt.n_tokens = 0;
-            ckpt.ckpt_size = 0;
-            ckpt.data.clear();
-        }
+    void begin(const llama_tokens & /*prompt*/) override {
     }
 
-    size_t draft_create_checkpoint(int n_tokens_prompt, int n_tokens_batch) {
+    size_t create_checkpoint(int n_tokens_prompt) {
         int slot_id = 0;
         const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx_dft, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
 
         ckpt.pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_dft), slot_id);
         ckpt.pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), slot_id);
-        ckpt.n_tokens = n_tokens_prompt - n_tokens_batch;
+        ckpt.n_tokens = n_tokens_prompt;
         ckpt.data.resize(checkpoint_size);
 
         const size_t n = llama_state_seq_get_data_ext(ctx_dft, ckpt.data.data(), checkpoint_size, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
@@ -281,13 +269,13 @@ struct common_speculative_state_draft : public common_speculative_state {
         return n;
     }
 
-    size_t draft_restore_checkpoint(size_t ckpt_size_part_expected) {
+    size_t restore_checkpoint() {
         int slot_id = 0;
         LOG_DBG("%s: pos_min = %d, pos_max = %d\n", __func__, ckpt.pos_min, ckpt.pos_max);
         const size_t n = llama_state_seq_set_data_ext(ctx_dft, ckpt.data.data(), ckpt.size(), slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
-        if (n != ckpt_size_part_expected) {
-            GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
-                    __func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt_size_part_expected, n);
+        if (n != ckpt.size()) {
+            GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu",
+                    __func__, ckpt.pos_min, ckpt.pos_max, ckpt.size());
         }
         llama_memory_seq_rm(llama_get_memory(ctx_dft), slot_id, ckpt.pos_max + 1, -1);
 
@@ -346,35 +334,45 @@ struct common_speculative_state_draft : public common_speculative_state {
 
         const int i_start = std::max<int>(0, (int) prompt_cur.size() - n_ctx);
 
+        if (use_ckpt && i_start > 0) {
+            LOG_WRN("%s: context shift is not supported with checkpoint-based contexts - skipping\n", __func__);
+            return;
+        }
+
         // reuse as much as possible from the old draft context
         // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
         for (int i = 0; i < (int) prompt_dft.size(); ++i) {
             int cur = 0;
             while (i_start + cur < (int) prompt_cur.size() &&
-                   i + cur < (int) prompt_dft.size() &&
-                   prompt_cur[i_start + cur] == prompt_dft[i + cur]) {
+                    i + cur < (int) prompt_dft.size() &&
+                    prompt_cur[i_start + cur] == prompt_dft[i + cur]) {
                 cur++;
             }
 
             if ((cur >= 256 || n_ctx >= (int) prompt_cur.size()) && cur > reuse_n) {
                 reuse_i = i;
                 reuse_n = cur;
             }
+
+            if (use_ckpt) {
+                break;
+            }
         }
 
         LOG_DBG("%s: reuse_i = %d, reuse_n = %d, #prompt_dft = %zu, #prompt_cur = %zu\n",
                 __func__, reuse_i, reuse_n, prompt_dft.size(), prompt_cur.size());
-        if (use_ckpt && ckpt.ckpt_size == 0 && reuse_n > 0) {
-            LOG_DBG("%s: no checkpoint available, no reuse, (reuse_i=%d, reuse_n=%d) -> (0, 0)\n",
-                    __func__, reuse_i, reuse_n);
+        if (use_ckpt && ckpt.n_tokens > reuse_n) {
+            LOG_DBG("%s: checkpoint (n_tokens = %d) is outdated -> delete it\n", __func__, (int) ckpt.n_tokens);
+
             reuse_i = 0;
             reuse_n = 0;
+
+            ckpt = {};
         }
 
         result.clear();
         result.reserve(sparams.n_max);
 
-        bool needs_ckpt = use_ckpt && prompt_dft.size() > 0;
         if (reuse_n == 0 || (use_ckpt && reuse_i > 0)) {
             llama_memory_clear(mem_dft, false);
             prompt_dft.clear();
@@ -393,50 +391,38 @@ struct common_speculative_state_draft : public common_speculative_state {
                 return;
             }
 
-            bool do_restore = false;
-            if (prompt_dft.size() > prompt_cur.size() && reuse_i + reuse_n < (int64_t) prompt_dft.size()) {
-                // This can happen after a partial acceptance (speculative decoding with checkpoints)
-                LOG_DBG("%s: #prompt_dft=%zu, #prompt_cur=%zu, shorten draft\n",
-                        __func__, prompt_dft.size(), prompt_cur.size());
-                prompt_dft.resize(prompt_cur.size());
-                do_restore = true;
-            }
-
             if (reuse_i > 0) {
+                GGML_ASSERT(!use_ckpt);
+
                 bool is_removed = llama_memory_seq_rm (mem_dft, 0, 0, reuse_i);
                 if (!is_removed) {
                     LOG_ERR("%s: llama_memory_seq_rm failed, reuse_i=%d\n", __func__, reuse_i);
+                    return;
                 }
                 llama_memory_seq_add(mem_dft, 0, reuse_i, -1, -reuse_i);
 
                 prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i);
             }
 
-            if (reuse_n < (int) prompt_dft.size() || do_restore) {
+            if (reuse_n < (int) prompt_dft.size()) {
                 if (use_ckpt) {
-                    if (ckpt.n_tokens > (int64_t) prompt_dft.size()) {
-                        LOG_INF("%s: checkpoint is too large, prompt_tgt.size=%zu, ckpt.n_tokens=%" PRId64 ", reuse_n=%d, prompt_dft.size=%zu\n",
-                                __func__, prompt_tgt.size(), ckpt.n_tokens, reuse_n, prompt_dft.size());
+                    if (ckpt.n_tokens > 0) {
+                        LOG_DBG("%s: restoring checkpoint, reuse_n=%d, prompt_dft.size=%zu\n", __func__, reuse_n, prompt_dft.size());
+                        restore_checkpoint();
+                        reuse_n = ckpt.n_tokens;
+                        prompt_dft.resize(reuse_n);
                     }
-                    draft_restore_checkpoint(ckpt.ckpt_size);
-                    reuse_n = ckpt.n_tokens;
-                    prompt_dft.resize(reuse_n);
-                    needs_ckpt = false;
                 } else {
-                    bool is_removed = llama_memory_seq_rm (mem_dft, 0, reuse_n, -1);
+                    const bool is_removed = llama_memory_seq_rm(mem_dft, 0, reuse_n, -1);
                     if (!is_removed) {
-                        LOG_ERR("%s: llama_memory_seq_rm failed, reuse_n=%d, prompt_dft.size=%zu\n",
-                                __func__, reuse_n, prompt_dft.size());
+                        LOG_ERR("%s: llama_memory_seq_rm failed, reuse_n=%d, prompt_dft.size=%zu\n", __func__, reuse_n, prompt_dft.size());
+                        return;
                     }
                     prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end());
                 }
             }
         }
 
-        if (needs_ckpt) {
-            ckpt.ckpt_size = draft_create_checkpoint(prompt_dft.size(), batch.n_tokens);
-        }
-
         // prepare a batch to evaluate any new tokens in the prompt
         common_batch_clear(batch);
 
@@ -450,12 +436,17 @@ struct common_speculative_state_draft : public common_speculative_state {
         // we should rarely end-up here during normal decoding
         if (batch.n_tokens > 0) {
             //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
+            LOG_DBG("%s: draft prompt batch: %d tokens\n", __func__, batch.n_tokens);
 
             int ret = llama_decode(ctx_dft, batch);
             if (ret != 0 && ret != 1) {
                 LOG_WRN("%s: llama_decode returned %d, prompt_cur.size=%zu\n",
                         __func__, ret, prompt_cur.size());
             }
+
+            if (use_ckpt) {
+                create_checkpoint(prompt_dft.size());
+            }
         }
 
         const llama_pos n_past = prompt_dft.size();
@@ -784,17 +775,15 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
     }
 
     void accept(uint16_t n_accepted) override {
-        if (verbose) {
-            LOG_INF("%s: accepted %d tokens from %zu drafted tokens\n", __func__, n_accepted, n_draft_last);
-        }
-
        // compute acceptance fraction if we have a recorded draft length
        if (n_draft_last > 0) {
            const double f_acc = (double)n_accepted / (double)n_draft_last;
            if (f_acc < 0.5) {
                n_low++;
                if (n_low >= 3) {
-                    LOG_WRN("%s: low acceptance streak (%d) – resetting ngram_mod\n", __func__, n_low);
+                    if (verbose) {
+                        LOG_WRN("%s: low acceptance streak (%d) – resetting ngram_mod\n", __func__, n_low);
+                    }
 
                    mod.reset();
                    n_low = 0;

tools/server/server-context.cpp

Lines changed: 18 additions & 4 deletions
@@ -680,6 +680,7 @@ struct server_context_impl {
     // slots / clients
     std::vector<server_slot> slots;
 
+    int trace = 0;
     int slots_debug = 0;
     int n_empty_consecutive = 0;
 
@@ -918,12 +919,21 @@ struct server_context_impl {
             slot.reset();
         }
 
+        {
+            const char * LLAMA_TRACE = getenv("LLAMA_TRACE");
+            trace = LLAMA_TRACE ? atoi(LLAMA_TRACE) : 0;
+
+            if (trace) {
+                SRV_WRN("LLAMA_TRACE = %d\n", trace);
+            }
+        }
+
         {
             const char * LLAMA_SERVER_SLOTS_DEBUG = getenv("LLAMA_SERVER_SLOTS_DEBUG");
             slots_debug = LLAMA_SERVER_SLOTS_DEBUG ? atoi(LLAMA_SERVER_SLOTS_DEBUG) : 0;
 
             if (slots_debug) {
-                SRV_WRN("slots debug = %d\n", slots_debug);
+                SRV_WRN("LLAMA_SERVER_SLOTS_DEBUG = %d\n", slots_debug);
             }
         }
 
@@ -2974,13 +2984,15 @@ struct server_context_impl {
             auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx, slot.spec_i_batch, slot.spec_draft);
             slot.spec_i_batch.clear();
 
-            SLT_DBG(slot, "%s: n_draft=%zu, accepted=%zu\n", __func__, slot.spec_draft.size(), accepted.size());
-
             GGML_ASSERT(accepted.size() >= 1);
 
             // check for partial draft acceptance
             if (accepted.size() < slot.spec_draft.size() + 1) {
                 if (use_ckpt) {
+                    if (trace > 0) {
+                        SLT_INF(slot, "accepted %2zu/%2zu draft tokens (restore checkpoint)\n", accepted.size() - 1, slot.spec_draft.size());
+                    }
+
                     // partial acceptance is not supported by the context -> truncate the draft and restore the state
                     slot.spec_draft = std::move(accepted);
 
@@ -3002,8 +3014,10 @@ struct server_context_impl {
 
                     continue;
                 }
+            }
 
-                LOG_DBG("%s: partial acceptance: %zu < %zu\n", __func__, accepted.size(), slot.spec_draft.size());
+            if (trace > 0) {
+                SLT_INF(slot, "accepted %2zu/%2zu draft tokens\n", accepted.size() - 1, n_draft);
             }
 
             common_speculative_accept(slot.spec.get(), accepted.size() - 1);
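
The server-side half of the change reads a LLAMA_TRACE level from the environment once at startup and uses it to gate the new per-slot draft-acceptance logging (the SLT_INF lines above). Below is a stand-alone sketch of the same read-once getenv/atoi gating; the SLT_INF macro is specific to the server and is replaced here with printf.

    // Minimal, self-contained illustration of the LLAMA_TRACE gating added in this commit:
    // an integer trace level read once from the environment, 0 (or unset) meaning "off".
    #include <cstdio>
    #include <cstdlib>

    int main() {
        const char * LLAMA_TRACE = std::getenv("LLAMA_TRACE");
        const int trace = LLAMA_TRACE ? std::atoi(LLAMA_TRACE) : 0;

        if (trace > 0) {
            // in the server this is an SLT_INF "accepted N/M draft tokens" line per batch
            std::printf("tracing enabled at level %d\n", trace);
        }
        return 0;
    }

In practice, starting the server with LLAMA_TRACE=1 (or any non-zero value) in the environment turns the "accepted N/M draft tokens" lines on, while leaving it unset keeps them silent.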
