1313#include < cstring>
1414#include < iomanip>
1515#include < map>
16+ #include < cinttypes>
1617
1718#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
1819#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
@@ -144,10 +145,28 @@ struct common_speculative_state {
144145 virtual void accept (uint16_t n_accepted) = 0;
145146};
146147
148+ struct common_speculative_checkpoint {
149+ llama_pos pos_min = 0 ;
150+ llama_pos pos_max = 0 ;
151+
152+ int64_t n_tokens = 0 ;
153+
154+ std::vector<uint8_t > data;
155+
156+ size_t size () const {
157+ return data.size ();
158+ }
159+
160+ size_t ckpt_size = 0 ;
161+ };
162+
147163struct common_speculative_state_draft : public common_speculative_state {
148164 llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
149165 llama_context * ctx_dft;
150166
167+ struct common_speculative_checkpoint ckpt;
168+ bool use_checkpoint;
169+
151170 common_sampler * smpl;
152171
153172 llama_batch batch;
@@ -160,10 +179,12 @@ struct common_speculative_state_draft : public common_speculative_state {
160179 enum common_speculative_type type,
161180 llama_context * ctx_tgt,
162181 llama_context * ctx_dft,
163- const std::vector<std::pair<std::string, std::string>> & replacements)
182+ const std::vector<std::pair<std::string, std::string>> & replacements,
183+ bool use_checkpoint)
164184 : common_speculative_state(type)
165185 , ctx_tgt(ctx_tgt)
166186 , ctx_dft(ctx_dft)
187+ , use_checkpoint(use_checkpoint)
167188 {
168189 batch = llama_batch_init (llama_n_batch (ctx_dft), 0 , 1 );
169190 smpl = nullptr ;
@@ -218,7 +239,48 @@ struct common_speculative_state_draft : public common_speculative_state {
218239 }
219240
220241 void begin (const llama_tokens & prompt) override {
221- GGML_UNUSED (prompt);
242+ if (use_checkpoint && ckpt.size () > 0 ) {
243+ // delete checkpoint
244+ LOG_DBG (" %s: delete checkpoint, prompt.size=%zu, pos_min=%d, pos_max=%d, n_tokens=%" PRId64 " , size=%.3f MiB\n " ,
245+ __func__, prompt.size (), ckpt.pos_min , ckpt.pos_max , ckpt.n_tokens , (float ) ckpt.data .size () / 1024 / 1024 );
246+ ckpt.pos_min = 0 ;
247+ ckpt.pos_max = 0 ;
248+ ckpt.n_tokens = 0 ;
249+ ckpt.ckpt_size = 0 ;
250+ ckpt.data .clear ();
251+ }
252+ }
253+
254+ size_t draft_create_checkpoint (int n_tokens_prompt, int n_tokens_batch) {
255+ int slot_id = 0 ;
256+ const size_t checkpoint_size = llama_state_seq_get_size_ext (ctx_dft, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
257+
258+ ckpt.pos_min = llama_memory_seq_pos_min (llama_get_memory (ctx_dft), slot_id);
259+ ckpt.pos_max = llama_memory_seq_pos_max (llama_get_memory (ctx_dft), slot_id);
260+ ckpt.n_tokens = n_tokens_prompt - n_tokens_batch;
261+ ckpt.data .resize (checkpoint_size);
262+
263+ const size_t n = llama_state_seq_get_data_ext (ctx_dft, ckpt.data .data (), checkpoint_size, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
264+ if (n != checkpoint_size) {
265+ GGML_ABORT (" checkpoint size mismatch: expected %zu, got %zu\n " , checkpoint_size, n);
266+ }
267+
268+ LOG_DBG (" %s: pos_min = %d, pos_max = %d, size = %.3f MiB\n " , __func__,
269+ ckpt.pos_min , ckpt.pos_max , (float ) ckpt.data .size () / 1024 / 1024 );
270+ return n;
271+ }
272+
273+ size_t draft_restore_checkpoint (size_t ckpt_size_part_expected) {
274+ int slot_id = 0 ;
275+ LOG_DBG (" %s: pos_min = %d, pos_max = %d\n " , __func__, ckpt.pos_min , ckpt.pos_max );
276+ const size_t n = llama_state_seq_set_data_ext (ctx_dft, ckpt.data .data (), ckpt.size (), slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
277+ if (n != ckpt_size_part_expected) {
278+ GGML_ABORT (" %s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu" ,
279+ __func__, ckpt.pos_min , ckpt.pos_max , ckpt.size (), ckpt_size_part_expected, n);
280+ }
281+ llama_memory_seq_rm (llama_get_memory (ctx_dft), slot_id, ckpt.pos_max + 1 , -1 );
282+
283+ return n;
222284 }
223285
224286 void draft (
@@ -236,8 +298,8 @@ struct common_speculative_state_draft : public common_speculative_state {
236298
237299 auto * mem_dft = llama_get_memory (ctx_dft);
238300
239- int reuse_i = 0 ;
240- int reuse_n = 0 ;
301+ int reuse_i = 0 ; // index of part to be reused in prompt_dft
302+ int reuse_n = 0 ; // length of part to be reused in prompt_dft
241303
242304 const int n_ctx = llama_n_ctx (ctx_dft) - params.n_max ;
243305
@@ -287,18 +349,26 @@ struct common_speculative_state_draft : public common_speculative_state {
287349 }
288350 }
289351
290- LOG_DBG (" %s: reuse_i = %d, reuse_n = %d, prompt = %d\n " , __func__, reuse_i, reuse_n, (int ) prompt_dft.size ());
352+ LOG_DBG (" %s: reuse_i = %d, reuse_n = %d, #prompt_dft = %zu, #prompt_cur = %zu\n " ,
353+ __func__, reuse_i, reuse_n, prompt_dft.size (), prompt_cur.size ());
354+ if (use_checkpoint && ckpt.ckpt_size == 0 && reuse_n > 0 ) {
355+ LOG_DBG (" %s: no checkpoint available, no reuse, (reuse_i=%d, reuse_n=%d) -> (0, 0)\n " ,
356+ __func__, reuse_i, reuse_n);
357+ reuse_i = 0 ;
358+ reuse_n = 0 ;
359+ }
291360
292361 result.clear ();
293362 result.reserve (params.n_max );
294363
295- if (reuse_n == 0 ) {
364+ bool needs_ckpt = use_checkpoint && prompt_dft.size () > 0 ;
365+ if (reuse_n == 0 || (use_checkpoint && reuse_i > 0 )) {
296366 llama_memory_clear (mem_dft, false );
297367 prompt_dft.clear ();
298368 } else {
299369 // this happens when a previous draft has been discarded (for example, due to being too small), but the
300370 // target model agreed with it. in this case, we simply pass back the previous results to save compute
301- if (reuse_i + reuse_n < (int ) prompt_dft.size () && prompt_dft[reuse_i + reuse_n] == id_last) {
371+ if (reuse_i + reuse_n < (int64_t ) prompt_dft.size () && prompt_dft[reuse_i + reuse_n] == id_last) {
302372 for (int i = reuse_i + reuse_n + 1 ; i < (int ) prompt_dft.size (); ++i) {
303373 result.push_back (prompt_dft[i]);
304374
@@ -310,19 +380,50 @@ struct common_speculative_state_draft : public common_speculative_state {
310380 return ;
311381 }
312382
383+ bool do_restore = false ;
384+ if (prompt_dft.size () > prompt_cur.size () && reuse_i + reuse_n < (int64_t ) prompt_dft.size ()) {
385+ // This can happen after a partial acceptance (speculative decoding with checkpoints)
386+ LOG_DBG (" %s: #prompt_dft=%zu, #prompt_cur=%zu, shorten draft\n " ,
387+ __func__, prompt_dft.size (), prompt_cur.size ());
388+ prompt_dft.resize (prompt_cur.size ());
389+ do_restore = true ;
390+ }
391+
313392 if (reuse_i > 0 ) {
314- llama_memory_seq_rm (mem_dft, 0 , 0 , reuse_i);
393+ bool is_removed = llama_memory_seq_rm (mem_dft, 0 , 0 , reuse_i);
394+ if (!is_removed) {
395+ LOG_ERR (" %s: llama_memory_seq_rm failed, reuse_i=%d\n " , __func__, reuse_i);
396+ }
315397 llama_memory_seq_add (mem_dft, 0 , reuse_i, -1 , -reuse_i);
316398
317399 prompt_dft.erase (prompt_dft.begin (), prompt_dft.begin () + reuse_i);
318400 }
319401
320- if (reuse_n < (int ) prompt_dft.size ()) {
321- llama_memory_seq_rm (mem_dft, 0 , reuse_n, -1 );
322- prompt_dft.erase (prompt_dft.begin () + reuse_n, prompt_dft.end ());
402+ if (reuse_n < (int ) prompt_dft.size () || do_restore) {
403+ if (use_checkpoint) {
404+ if (ckpt.n_tokens > (int64_t ) prompt_dft.size ()) {
405+ LOG_INF (" %s: checkpoint is too large, prompt_tgt.size=%zu, ckpt.n_tokens=%" PRId64 " , reuse_n=%d, prompt_dft.size=%zu\n " ,
406+ __func__, prompt_tgt.size (), ckpt.n_tokens , reuse_n, prompt_dft.size ());
407+ }
408+ draft_restore_checkpoint (ckpt.ckpt_size );
409+ reuse_n = ckpt.n_tokens ;
410+ prompt_dft.resize (reuse_n);
411+ needs_ckpt = false ;
412+ } else {
413+ bool is_removed = llama_memory_seq_rm (mem_dft, 0 , reuse_n, -1 );
414+ if (!is_removed) {
415+ LOG_ERR (" %s: llama_memory_seq_rm failed, reuse_n=%d, prompt_dft.size=%zu\n " ,
416+ __func__, reuse_n, prompt_dft.size ());
417+ }
418+ prompt_dft.erase (prompt_dft.begin () + reuse_n, prompt_dft.end ());
419+ }
323420 }
324421 }
325422
423+ if (needs_ckpt) {
424+ ckpt.ckpt_size = draft_create_checkpoint (prompt_dft.size (), batch.n_tokens );
425+ }
426+
326427 // prepare a batch to evaluate any new tokens in the prompt
327428 common_batch_clear (batch);
328429
@@ -337,7 +438,11 @@ struct common_speculative_state_draft : public common_speculative_state {
337438 if (batch.n_tokens > 0 ) {
338439 // LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
339440
340- llama_decode (ctx_dft, batch);
441+ int ret = llama_decode (ctx_dft, batch);
442+ if (ret != 0 && ret != 1 ) {
443+ LOG_WRN (" %s: llama_decode returned %d, prompt_cur.size=%zu\n " ,
444+ __func__, ret, prompt_cur.size ());
445+ }
341446 }
342447
343448 const llama_pos n_past = prompt_dft.size ();
@@ -351,7 +456,11 @@ struct common_speculative_state_draft : public common_speculative_state {
351456
352457 LOG_DBG (" %s: draft prompt: %s\n " , __func__, string_from (ctx_dft, prompt_dft).c_str ());
353458
354- llama_decode (ctx_dft, batch);
459+ int ret = llama_decode (ctx_dft, batch);
460+ if (ret != 0 && ret != 1 ) {
461+ LOG_WRN (" %s: llama_decode returned %d, prompt_cur.size=%zu, prompt_dft.size=%zu\n " ,
462+ __func__, ret, prompt_cur.size (), prompt_dft.size ());
463+ }
355464
356465 common_sampler_reset (smpl);
357466
@@ -387,7 +496,11 @@ struct common_speculative_state_draft : public common_speculative_state {
387496 common_batch_add (batch, id, n_past + i + 1 , { 0 }, true );
388497
389498 // evaluate the drafted tokens on the draft model
390- llama_decode (ctx_dft, batch);
499+ ret = llama_decode (ctx_dft, batch);
500+ if (ret != 0 ) {
501+ LOG_WRN (" %s: llama_decode[%d] returned %d, prompt_cur.size=%zu, prompt_dft.size=%zu\n " ,
502+ __func__, i, ret, prompt_cur.size (), prompt_dft.size ());
503+ }
391504
392505 prompt_dft.push_back (id);
393506 }
@@ -739,6 +852,7 @@ struct common_speculative_state_ngram_cache : public common_speculative_state {
739852
740853struct common_speculative {
741854 std::vector<std::unique_ptr<common_speculative_state>> impls; // list of implementations to use and their states
855+
742856 common_speculative_state * curr_impl = nullptr ; // current implementation in use (for stats)
743857};
744858
@@ -798,13 +912,13 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
798912 return it->second ;
799913}
800914
801- bool common_speculative_is_compat (llama_context * ctx_tgt) {
915+ common_speculative_compat_type common_speculative_is_compat (llama_context * ctx_tgt) {
802916 auto * mem = llama_get_memory (ctx_tgt);
803917 if (mem == nullptr ) {
804- return false ;
918+ return COMMON_SPECULATIVE_COMPAT_TYPE_NO ;
805919 }
806920
807- bool res = true ;
921+ common_speculative_compat_type res = COMMON_SPECULATIVE_COMPAT_TYPE_FULL ;
808922
809923 llama_memory_clear (mem, true );
810924
@@ -816,14 +930,14 @@ bool common_speculative_is_compat(llama_context * ctx_tgt) {
816930 int ret = llama_decode (ctx_tgt, llama_batch_get_one (tmp.data (), tmp.size ()));
817931 if (ret != 0 ) {
818932 LOG_ERR (" %s: llama_decode() failed: %d\n " , __func__, ret);
819- res = false ;
933+ res = COMMON_SPECULATIVE_COMPAT_TYPE_NO ;
820934 goto done;
821935 }
822936
823937 // try to remove the last tokens
824938 if (!llama_memory_seq_rm (mem, 0 , 1 , -1 )) {
825939 LOG_WRN (" %s: the target context does not support partial sequence removal\n " , __func__);
826- res = false ;
940+ res = COMMON_SPECULATIVE_COMPAT_TYPE_CKPT ;
827941 goto done;
828942 }
829943
@@ -909,9 +1023,10 @@ common_speculative * common_speculative_init(
9091023 break ;
9101024 case COMMON_SPECULATIVE_TYPE_DRAFT: {
9111025 impls.push_back (std::make_unique<common_speculative_state_draft>(config.type ,
912- /* .ctx_tgt = */ ctx_tgt,
913- /* .ctx_dft = */ ctx_dft,
914- /* .replacements = */ params.replacements
1026+ /* .ctx_tgt = */ ctx_tgt,
1027+ /* .ctx_dft = */ ctx_dft,
1028+ /* .replacements = */ params.replacements ,
1029+ /* .use_checkpoint= */ params.use_checkpoints // TODO: this should be based on the draft model!
9151030 ));
9161031 break ;
9171032 }
@@ -966,7 +1081,8 @@ common_speculative * common_speculative_init(
9661081 }
9671082
9681083 auto * result = new common_speculative {
969- /* .impls = */ std::move (impls)
1084+ /* .impls = */ std::move (impls),
1085+ /* .curr_impl = */ nullptr ,
9701086 };
9711087
9721088 return result;
0 commit comments