@@ -164,8 +164,8 @@ struct common_speculative_state_draft : public common_speculative_state {
164164 llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
165165 llama_context * ctx_dft;
166166
167+ bool use_ckpt = false ;
167168 struct common_speculative_checkpoint ckpt;
168- bool use_checkpoint;
169169
170170 common_sampler * smpl;
171171
@@ -180,11 +180,11 @@ struct common_speculative_state_draft : public common_speculative_state {
180180 llama_context * ctx_tgt,
181181 llama_context * ctx_dft,
182182 const std::vector<std::pair<std::string, std::string>> & replacements,
183- bool use_checkpoint )
183+ bool use_ckpt )
184184 : common_speculative_state(type)
185185 , ctx_tgt(ctx_tgt)
186186 , ctx_dft(ctx_dft)
187- , use_checkpoint(use_checkpoint )
187+ , use_ckpt(use_ckpt )
188188 {
189189 batch = llama_batch_init (llama_n_batch (ctx_dft), 0 , 1 );
190190 smpl = nullptr ;
@@ -239,7 +239,7 @@ struct common_speculative_state_draft : public common_speculative_state {
239239 }
240240
241241 void begin (const llama_tokens & prompt) override {
242- if (use_checkpoint && ckpt.size () > 0 ) {
242+ if (use_ckpt && ckpt.size () > 0 ) {
243243 // delete checkpoint
244244 LOG_DBG (" %s: delete checkpoint, prompt.size=%zu, pos_min=%d, pos_max=%d, n_tokens=%" PRId64 " , size=%.3f MiB\n " ,
245245 __func__, prompt.size (), ckpt.pos_min , ckpt.pos_max , ckpt.n_tokens , (float ) ckpt.data .size () / 1024 / 1024 );
@@ -351,7 +351,7 @@ struct common_speculative_state_draft : public common_speculative_state {
351351
352352 LOG_DBG (" %s: reuse_i = %d, reuse_n = %d, #prompt_dft = %zu, #prompt_cur = %zu\n " ,
353353 __func__, reuse_i, reuse_n, prompt_dft.size (), prompt_cur.size ());
354- if (use_checkpoint && ckpt.ckpt_size == 0 && reuse_n > 0 ) {
354+ if (use_ckpt && ckpt.ckpt_size == 0 && reuse_n > 0 ) {
355355 LOG_DBG (" %s: no checkpoint available, no reuse, (reuse_i=%d, reuse_n=%d) -> (0, 0)\n " ,
356356 __func__, reuse_i, reuse_n);
357357 reuse_i = 0 ;
@@ -361,8 +361,8 @@ struct common_speculative_state_draft : public common_speculative_state {
361361 result.clear ();
362362 result.reserve (params.n_max );
363363
364- bool needs_ckpt = use_checkpoint && prompt_dft.size () > 0 ;
365- if (reuse_n == 0 || (use_checkpoint && reuse_i > 0 )) {
364+ bool needs_ckpt = use_ckpt && prompt_dft.size () > 0 ;
365+ if (reuse_n == 0 || (use_ckpt && reuse_i > 0 )) {
366366 llama_memory_clear (mem_dft, false );
367367 prompt_dft.clear ();
368368 } else {
@@ -400,7 +400,7 @@ struct common_speculative_state_draft : public common_speculative_state {
400400 }
401401
402402 if (reuse_n < (int ) prompt_dft.size () || do_restore) {
403- if (use_checkpoint ) {
403+ if (use_ckpt ) {
404404 if (ckpt.n_tokens > (int64_t ) prompt_dft.size ()) {
405405 LOG_INF (" %s: checkpoint is too large, prompt_tgt.size=%zu, ckpt.n_tokens=%" PRId64 " , reuse_n=%d, prompt_dft.size=%zu\n " ,
406406 __func__, prompt_tgt.size (), ckpt.n_tokens , reuse_n, prompt_dft.size ());
@@ -912,42 +912,6 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
912912 return it->second ;
913913}
914914
// Probe whether the target context can be used for speculative decoding and,
// if so, in which mode.
//
// Returns:
//   COMMON_SPECULATIVE_COMPAT_TYPE_NO   - the context has no memory object, or
//                                         the trial llama_decode() call failed
//   COMMON_SPECULATIVE_COMPAT_TYPE_CKPT - the trial decode succeeded, but
//                                         llama_memory_seq_rm() refused the
//                                         partial removal
//   COMMON_SPECULATIVE_COMPAT_TYPE_FULL - both the trial decode and the partial
//                                         removal succeeded
//
// Side effects: the memory of ctx_tgt is cleared both before and after the
// probe, so any state previously held in it is discarded by this call.
915- common_speculative_compat_type common_speculative_is_compat (llama_context * ctx_tgt) {
916- auto * mem = llama_get_memory (ctx_tgt);
917- if (mem == nullptr ) {
918- return COMMON_SPECULATIVE_COMPAT_TYPE_NO;
919- }
920-
// start optimistic; the checks below only ever downgrade the result
921- common_speculative_compat_type res = COMMON_SPECULATIVE_COMPAT_TYPE_FULL;
922-
// start the probe from an empty memory
923- llama_memory_clear (mem, true );
924-
925- // eval 2 tokens to check if the context is compatible
// (both tokens have id 0; only the success of the decode matters here)
926- std::vector<llama_token> tmp;
927- tmp.push_back (0 );
928- tmp.push_back (0 );
929-
930- int ret = llama_decode (ctx_tgt, llama_batch_get_one (tmp.data (), tmp.size ()));
931- if (ret != 0 ) {
932- LOG_ERR (" %s: llama_decode() failed: %d\n " , __func__, ret);
933- res = COMMON_SPECULATIVE_COMPAT_TYPE_NO;
934- goto done;
935- }
936-
937- // try to remove the last tokens
// NOTE(review): arguments presumably mean seq_id = 0, range [1, end) - confirm against llama.h
938- if (!llama_memory_seq_rm (mem, 0 , 1 , -1 )) {
939- LOG_WRN (" %s: the target context does not support partial sequence removal\n " , __func__);
940- res = COMMON_SPECULATIVE_COMPAT_TYPE_CKPT;
941- goto done;
942- }
943-
// common exit path: drop whatever the probe left in the memory, then synchronize the context
944- done:
945- llama_memory_clear (mem, true );
946- llama_synchronize (ctx_tgt);
947-
948- return res;
949- }
950-
951915// initialization of the speculative decoding system
952916//
953917common_speculative * common_speculative_init (
@@ -1022,11 +986,13 @@ common_speculative * common_speculative_init(
1022986 case COMMON_SPECULATIVE_TYPE_NONE:
1023987 break ;
1024988 case COMMON_SPECULATIVE_TYPE_DRAFT: {
989+ const bool use_ckpt = common_context_can_seq_rm (ctx_dft) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
990+
1025991 impls.push_back (std::make_unique<common_speculative_state_draft>(config.type ,
1026- /* .ctx_tgt = */ ctx_tgt,
1027- /* .ctx_dft = */ ctx_dft,
1028- /* .replacements = */ params.replacements ,
1029- /* .use_checkpoint= */ params. use_checkpoints // TODO: this should be based on the draft model!
992+ /* .ctx_tgt = */ ctx_tgt,
993+ /* .ctx_dft = */ ctx_dft,
994+ /* .replacements = */ params.replacements ,
995+ /* .use_ckpt = */ use_ckpt
1030996 ));
1031997 break ;
1032998 }
0 commit comments