@@ -167,16 +167,14 @@ struct common_speculative_checkpoint {
167167 size_t size () const {
168168 return data.size ();
169169 }
170-
171- size_t ckpt_size = 0 ;
172170};
173171
174172struct common_speculative_state_draft : public common_speculative_state {
175173 llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
176174 llama_context * ctx_dft;
177175
178176 bool use_ckpt = false ;
179- struct common_speculative_checkpoint ckpt;
177+ common_speculative_checkpoint ckpt;
180178
181179 common_sampler * smpl;
182180
@@ -249,26 +247,16 @@ struct common_speculative_state_draft : public common_speculative_state {
249247 llama_batch_free (batch);
250248 }
251249
252- void begin (const llama_tokens & prompt) override {
253- if (use_ckpt && ckpt.size () > 0 ) {
254- // delete checkpoint
255- LOG_DBG (" %s: delete checkpoint, prompt.size=%zu, pos_min=%d, pos_max=%d, n_tokens=%" PRId64 " , size=%.3f MiB\n " ,
256- __func__, prompt.size (), ckpt.pos_min , ckpt.pos_max , ckpt.n_tokens , (float ) ckpt.data .size () / 1024 / 1024 );
257- ckpt.pos_min = 0 ;
258- ckpt.pos_max = 0 ;
259- ckpt.n_tokens = 0 ;
260- ckpt.ckpt_size = 0 ;
261- ckpt.data .clear ();
262- }
250+ void begin (const llama_tokens & /* prompt*/ ) override {
263251 }
264252
265- size_t draft_create_checkpoint (int n_tokens_prompt, int n_tokens_batch ) {
253+ size_t create_checkpoint (int n_tokens_prompt) {
266254 int slot_id = 0 ;
267255 const size_t checkpoint_size = llama_state_seq_get_size_ext (ctx_dft, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
268256
269257 ckpt.pos_min = llama_memory_seq_pos_min (llama_get_memory (ctx_dft), slot_id);
270258 ckpt.pos_max = llama_memory_seq_pos_max (llama_get_memory (ctx_dft), slot_id);
271- ckpt.n_tokens = n_tokens_prompt - n_tokens_batch ;
259+ ckpt.n_tokens = n_tokens_prompt;
272260 ckpt.data .resize (checkpoint_size);
273261
274262 const size_t n = llama_state_seq_get_data_ext (ctx_dft, ckpt.data .data (), checkpoint_size, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
@@ -281,13 +269,13 @@ struct common_speculative_state_draft : public common_speculative_state {
281269 return n;
282270 }
283271
284- size_t draft_restore_checkpoint ( size_t ckpt_size_part_expected ) {
272+ size_t restore_checkpoint ( ) {
285273 int slot_id = 0 ;
286274 LOG_DBG (" %s: pos_min = %d, pos_max = %d\n " , __func__, ckpt.pos_min , ckpt.pos_max );
287275 const size_t n = llama_state_seq_set_data_ext (ctx_dft, ckpt.data .data (), ckpt.size (), slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
288- if (n != ckpt_size_part_expected ) {
289- GGML_ABORT (" %s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu " ,
290- __func__, ckpt.pos_min , ckpt.pos_max , ckpt.size (), ckpt_size_part_expected, n );
276+ if (n != ckpt. size () ) {
277+ GGML_ABORT (" %s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu" ,
278+ __func__, ckpt.pos_min , ckpt.pos_max , ckpt.size ());
291279 }
292280 llama_memory_seq_rm (llama_get_memory (ctx_dft), slot_id, ckpt.pos_max + 1 , -1 );
293281
@@ -346,35 +334,45 @@ struct common_speculative_state_draft : public common_speculative_state {
346334
347335 const int i_start = std::max<int >(0 , (int ) prompt_cur.size () - n_ctx);
348336
337+ if (use_ckpt && i_start > 0 ) {
338+ LOG_WRN (" %s: context shift is not supported with checkpoint-based contexts - skipping\n " , __func__);
339+ return ;
340+ }
341+
349342 // reuse as much as possible from the old draft context
350343 // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
351344 for (int i = 0 ; i < (int ) prompt_dft.size (); ++i) {
352345 int cur = 0 ;
353346 while (i_start + cur < (int ) prompt_cur.size () &&
354- i + cur < (int ) prompt_dft.size () &&
355- prompt_cur[i_start + cur] == prompt_dft[i + cur]) {
347+ i + cur < (int ) prompt_dft.size () &&
348+ prompt_cur[i_start + cur] == prompt_dft[i + cur]) {
356349 cur++;
357350 }
358351
359352 if ((cur >= 256 || n_ctx >= (int ) prompt_cur.size ()) && cur > reuse_n) {
360353 reuse_i = i;
361354 reuse_n = cur;
362355 }
356+
357+ if (use_ckpt) {
358+ break ;
359+ }
363360 }
364361
365362 LOG_DBG (" %s: reuse_i = %d, reuse_n = %d, #prompt_dft = %zu, #prompt_cur = %zu\n " ,
366363 __func__, reuse_i, reuse_n, prompt_dft.size (), prompt_cur.size ());
367- if (use_ckpt && ckpt.ckpt_size == 0 && reuse_n > 0 ) {
368- LOG_DBG (" %s: no checkpoint available, no reuse, (reuse_i=%d, reuse_n= %d) -> (0, 0) \n " ,
369- __func__, reuse_i, reuse_n);
364+ if (use_ckpt && ckpt.n_tokens > reuse_n ) {
365+ LOG_DBG (" %s: checkpoint (n_tokens = %d) is outdated -> delete it \n " , __func__, ( int ) ckpt. n_tokens );
366+
370367 reuse_i = 0 ;
371368 reuse_n = 0 ;
369+
370+ ckpt = {};
372371 }
373372
374373 result.clear ();
375374 result.reserve (sparams.n_max );
376375
377- bool needs_ckpt = use_ckpt && prompt_dft.size () > 0 ;
378376 if (reuse_n == 0 || (use_ckpt && reuse_i > 0 )) {
379377 llama_memory_clear (mem_dft, false );
380378 prompt_dft.clear ();
@@ -393,50 +391,38 @@ struct common_speculative_state_draft : public common_speculative_state {
393391 return ;
394392 }
395393
396- bool do_restore = false ;
397- if (prompt_dft.size () > prompt_cur.size () && reuse_i + reuse_n < (int64_t ) prompt_dft.size ()) {
398- // This can happen after a partial acceptance (speculative decoding with checkpoints)
399- LOG_DBG (" %s: #prompt_dft=%zu, #prompt_cur=%zu, shorten draft\n " ,
400- __func__, prompt_dft.size (), prompt_cur.size ());
401- prompt_dft.resize (prompt_cur.size ());
402- do_restore = true ;
403- }
404-
405394 if (reuse_i > 0 ) {
395+ GGML_ASSERT (!use_ckpt);
396+
406397 bool is_removed = llama_memory_seq_rm (mem_dft, 0 , 0 , reuse_i);
407398 if (!is_removed) {
408399 LOG_ERR (" %s: llama_memory_seq_rm failed, reuse_i=%d\n " , __func__, reuse_i);
400+ return ;
409401 }
410402 llama_memory_seq_add (mem_dft, 0 , reuse_i, -1 , -reuse_i);
411403
412404 prompt_dft.erase (prompt_dft.begin (), prompt_dft.begin () + reuse_i);
413405 }
414406
415- if (reuse_n < (int ) prompt_dft.size () || do_restore ) {
407+ if (reuse_n < (int ) prompt_dft.size ()) {
416408 if (use_ckpt) {
417- if (ckpt.n_tokens > (int64_t ) prompt_dft.size ()) {
418- LOG_INF (" %s: checkpoint is too large, prompt_tgt.size=%zu, ckpt.n_tokens=%" PRId64 " , reuse_n=%d, prompt_dft.size=%zu\n " ,
419- __func__, prompt_tgt.size (), ckpt.n_tokens , reuse_n, prompt_dft.size ());
409+ if (ckpt.n_tokens > 0 ) {
410+ LOG_DBG (" %s: restoring checkpoint, reuse_n=%d, prompt_dft.size=%zu\n " , __func__, reuse_n, prompt_dft.size ());
411+ restore_checkpoint ();
412+ reuse_n = ckpt.n_tokens ;
413+ prompt_dft.resize (reuse_n);
420414 }
421- draft_restore_checkpoint (ckpt.ckpt_size );
422- reuse_n = ckpt.n_tokens ;
423- prompt_dft.resize (reuse_n);
424- needs_ckpt = false ;
425415 } else {
426- bool is_removed = llama_memory_seq_rm (mem_dft, 0 , reuse_n, -1 );
416+ const bool is_removed = llama_memory_seq_rm (mem_dft, 0 , reuse_n, -1 );
427417 if (!is_removed) {
428- LOG_ERR (" %s: llama_memory_seq_rm failed, reuse_n=%d, prompt_dft.size=%zu\n " ,
429- __func__, reuse_n, prompt_dft. size ()) ;
418+ LOG_ERR (" %s: llama_memory_seq_rm failed, reuse_n=%d, prompt_dft.size=%zu\n " , __func__, reuse_n, prompt_dft. size ());
419+ return ;
430420 }
431421 prompt_dft.erase (prompt_dft.begin () + reuse_n, prompt_dft.end ());
432422 }
433423 }
434424 }
435425
436- if (needs_ckpt) {
437- ckpt.ckpt_size = draft_create_checkpoint (prompt_dft.size (), batch.n_tokens );
438- }
439-
440426 // prepare a batch to evaluate any new tokens in the prompt
441427 common_batch_clear (batch);
442428
@@ -450,12 +436,17 @@ struct common_speculative_state_draft : public common_speculative_state {
450436 // we should rarely end-up here during normal decoding
451437 if (batch.n_tokens > 0 ) {
452438 // LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
439+ LOG_DBG (" %s: draft prompt batch: %d tokens\n " , __func__, batch.n_tokens );
453440
454441 int ret = llama_decode (ctx_dft, batch);
455442 if (ret != 0 && ret != 1 ) {
456443 LOG_WRN (" %s: llama_decode returned %d, prompt_cur.size=%zu\n " ,
457444 __func__, ret, prompt_cur.size ());
458445 }
446+
447+ if (use_ckpt) {
448+ create_checkpoint (prompt_dft.size ());
449+ }
459450 }
460451
461452 const llama_pos n_past = prompt_dft.size ();
@@ -784,17 +775,15 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
784775 }
785776
786777 void accept (uint16_t n_accepted) override {
787- if (verbose) {
788- LOG_INF (" %s: accepted %d tokens from %zu drafted tokens\n " , __func__, n_accepted, n_draft_last);
789- }
790-
791778 // compute acceptance fraction if we have a recorded draft length
792779 if (n_draft_last > 0 ) {
793780 const double f_acc = (double )n_accepted / (double )n_draft_last;
794781 if (f_acc < 0.5 ) {
795782 n_low++;
796783 if (n_low >= 3 ) {
797- LOG_WRN (" %s: low acceptance streak (%d) – resetting ngram_mod\n " , __func__, n_low);
784+ if (verbose) {
785+ LOG_WRN (" %s: low acceptance streak (%d) – resetting ngram_mod\n " , __func__, n_low);
786+ }
798787
799788 mod.reset ();
800789 n_low = 0 ;
0 commit comments