@@ -165,7 +165,7 @@ struct common_speculative_state_draft : public common_speculative_state {
165165 llama_context * ctx_dft;
166166
167167 struct common_speculative_checkpoint ckpt;
168- bool use_checkpoint;
168+ bool use_checkpoint;
169169
170170 common_sampler * smpl;
171171
@@ -401,7 +401,7 @@ struct common_speculative_state_draft : public common_speculative_state {
401401 if (reuse_n < (int ) prompt_dft.size () || do_restore) {
402402 if (use_checkpoint) {
403403 if (ckpt.n_tokens > (int64_t ) prompt_dft.size ()) {
404- LOG_INF (" %s: checkpoint is too large, prompt_tgt.size=%zu, ckpt.n_tokens=%zu , reuse_n=%d, prompt_dft.size=%zu\n " ,
404+ LOG_INF (" %s: checkpoint is too large, prompt_tgt.size=%zu, ckpt.n_tokens=%" PRId64 " , reuse_n=%d, prompt_dft.size=%zu\n " ,
405405 __func__, prompt_tgt.size (), ckpt.n_tokens , reuse_n, prompt_dft.size ());
406406 }
407407 draft_restore_checkpoint (ckpt.ckpt_size );
@@ -1207,36 +1207,36 @@ struct common_speculative_session::impl {
12071207 llama_tokens draft;
12081208
12091209 // use of checkpoints in speculative mode
1210- bool spec_has_ckpt = false ; // true if a checkpoint for rollback after partial speculation has been created
1211- uint16_t spec_ckpt_n_denials = 0 ; // number of drafts not accepted at the current position (0 or 1)
1212- size_t spec_ckpt_size_part = 0 ; // size of partial checkpoint
1210+ bool spec_has_ckpt = false ; // true if a checkpoint for rollback after partial speculation has been created
1211+ uint16_t spec_ckpt_n_denials = 0 ; // number of drafts not accepted at the current position (0 or 1)
12131212
12141213 // Speculative decoding stats
12151214 int32_t n_draft_total = 0 ; // Total draft tokens generated
12161215 int32_t n_draft_accepted = 0 ; // Draft tokens actually accepted
12171216
// Constructor of the session implementation (shown as a diff: '-' is the old
// signature line, '+' lines are the new ones). The change moves `params` ahead
// of `callback` in the parameter list; `ctx_tgt` stays last.
// The member-initializer list stores the three arguments in `callback`,
// `params_spec` and `ctx_tgt`; the body then sets `spec` from
// common_speculative_init(params_spec, ctx_tgt).
1218- impl (common_speculative_callback & callback,
1217+ impl (
12191218 const common_params_speculative & params,
1219+ common_speculative_callback & callback,
12201220 llama_context * ctx_tgt)
12211221 : callback(callback), params_spec(params), ctx_tgt(ctx_tgt) {
12221222 spec = common_speculative_init (params_spec, ctx_tgt);
12231223 }
12241224
// Forwards `prompt_history` together with the member `spec` to
// common_speculative_begin(). The '+' line only adds a `const` qualifier to
// the method; the body is unchanged.
1225- void begin (const llama_tokens & prompt_history) {
1225+ void begin (const llama_tokens & prompt_history) const {
12261226 common_speculative_begin (spec, prompt_history);
12271227 }
12281228
// Returns true when the member `draft` is non-empty, false otherwise.
// The '+' line only adds a `const` qualifier to the method.
1229- bool has_batch_dft () {
1229+ bool has_batch_dft () const {
12301230 return !draft.empty ();
12311231 }
12321232
// Leaves the drafting state: clears `draft` and resets `spec_ckpt_n_denials`
// to 0. It does not touch `spec_has_ckpt`.
// NOTE(review): the '-'/'+' pair below reads identically in this view, so the
// change is presumably whitespace/alignment only - confirm against the raw diff.
12331234 void leave_draft_state () {
12341234 draft.clear ();
1235- spec_ckpt_n_denials = 0 ;
1235+ spec_ckpt_n_denials = 0 ;
12361236 }
12371237
12381238 llama_tokens compute_draft (
1239- const llama_tokens & cached_text_tokens ,
1239+ const llama_tokens & tokens ,
12401240 llama_token id_last,
12411241 const int n_draft_max) {
12421242 if (spec == nullptr ) {
@@ -1249,10 +1249,11 @@ struct common_speculative_session::impl {
12491249 leave_draft_state ();
12501250 return draft;
12511251 }
1252+
12521253 if (params_spec.use_checkpoints && spec_ckpt_n_denials > 1 ) {
12531254 // We shouldn't get two denials.
12541255 LOG_WRN (" %s: #tokens=%zu, spec_ckpt_n_denials=%d, id_last=%d, #draft=%zu\n " , __func__,
1255- cached_text_tokens .size (), spec_ckpt_n_denials, id_last, draft.size ());
1256+ tokens .size (), spec_ckpt_n_denials, id_last, draft.size ());
12561257 leave_draft_state ();
12571258 return draft;
12581259 }
@@ -1267,12 +1268,12 @@ struct common_speculative_session::impl {
12671268 }
12681269 // we use the shortened draft of previous speculation
12691270 LOG_DBG (" %s: reuse shortened draft, #tokens=%zu, id_last=%d, size=%zu\n " , __func__,
1270- cached_text_tokens .size (), id_last, draft.size ());
1271+ tokens .size (), id_last, draft.size ());
12711272 } else if (spec_ckpt_n_denials > 1 ) {
12721273 GGML_ABORT (" illegal state: spec_ckpt_n_denials = %d > 1" , spec_ckpt_n_denials);
12731274 } else {
12741275 // call the speculative implementation to create a draft
1275- draft = common_speculative_draft (spec, params_spec, cached_text_tokens , id_last);
1276+ draft = common_speculative_draft (spec, params_spec, tokens , id_last);
12761277 LOG_DBG (" draft: id_last=%d, #draft=%zu\n " , id_last, draft.size ());
12771278 if (draft.empty ()) {
12781279 leave_draft_state ();
@@ -1286,15 +1287,15 @@ struct common_speculative_session::impl {
12861287 }
12871288
12881289 bool do_checkpoint = !draft.empty () && params_spec.use_checkpoints ;
1289- if (do_checkpoint && cached_text_tokens .size () > 5 && draft.size () >= 3 ) {
1290+ if (do_checkpoint && tokens .size () > 5 && draft.size () >= 3 ) {
12901291 LOG_DBG (" %s: #tokens=%zu, draft.size=%zu, n_spec_denials=%d, do_checkpoint=%s, id_last=%d, tokens=[..., %d, %d, %d], draft=[%d, %d, %d, ...]\n " ,
12911292 __func__,
1292- cached_text_tokens .size (),
1293+ tokens .size (),
12931294 draft.size (), spec_ckpt_n_denials,
12941295 do_checkpoint ? " yes" : " no" , id_last,
1295- cached_text_tokens[cached_text_tokens .size () - 3 ],
1296- cached_text_tokens[cached_text_tokens .size () - 2 ],
1297- cached_text_tokens[cached_text_tokens .size () - 1 ],
1296+ tokens[tokens .size () - 3 ],
1297+ tokens[tokens .size () - 2 ],
1298+ tokens[tokens .size () - 1 ],
12981299 draft[0 ], draft[1 ], draft[2 ]);
12991300 }
13001301
@@ -1305,13 +1306,12 @@ struct common_speculative_session::impl {
13051306 }
13061307
13071308 if (do_checkpoint) {
1308- const size_t n = callback.create_checkpoint ();
1309+ const size_t n = callback.create_checkpoint (tokens. size () );
13091310 if (n == 0 ) {
1310- LOG_WRN (" %s: checkpoint creation failed (#tokens=%zu)\n " , __func__, cached_text_tokens .size ());
1311+ LOG_WRN (" %s: checkpoint creation failed (#tokens=%zu)\n " , __func__, tokens .size ());
13111312 leave_draft_state ();
13121313 return draft;
13131314 }
1314- spec_ckpt_size_part = n;
13151315 spec_has_ckpt = true ;
13161316 }
13171317
@@ -1341,7 +1341,7 @@ struct common_speculative_session::impl {
13411341 if (spec_has_ckpt) {
13421342 // we need to rollback to the state before sampling the draft tokens
13431343 // (restore_checkpoint shortens context and slot.prompt.tokens)
1344- const size_t n = callback.restore_checkpoint (spec_ckpt_size_part );
1344+ const size_t n = callback.restore_checkpoint ();
13451345 LOG_DBG (" %s: partial acceptance: %zu < %zu, restored checkpoint: got %zu bytes\n " ,
13461346 __func__,
13471347 ids.size () - 1 , n_draft, n);
@@ -1367,8 +1367,10 @@ struct common_speculative_session::impl {
13671367 return common_speculative_accept_response{std::move (ids), 0 , true };
13681368 }
13691369 }
1370+
13701371 const size_t draft_size_accepted = draft.size ();
13711372 LOG_DBG (" %s: draft.size=%zu, ids.size=%zu\n " , __func__, draft_size_accepted, ids.size ());
1373+
13721374 common_speculative_accept (spec, draft_size_accepted);
13731375 draft.clear ();
13741376
@@ -1401,15 +1403,14 @@ struct common_speculative_session::impl {
14011403
14021404 leave_draft_state ();
14031405
1404- spec_has_ckpt = false ;
1405- spec_ckpt_size_part = 0 ;
1406+ spec_has_ckpt = false ;
14061407 }
14071408};
14081409
// Public constructor of common_speculative_session: allocates an `impl` with
// `new` and stores the pointer in `p_impl`; the constructor body is empty.
// The diff reorders the parameters from (callback, params, ctx_tgt) to
// (params, callback, ctx_tgt) and reorders the arguments passed to impl{...}
// in the same way.
// NOTE(review): call sites using the old argument order need updating - verify
// against callers, which are not visible here.
14091410common_speculative_session::common_speculative_session (
1410- common_speculative_callback & callback,
14111411 const common_params_speculative & params,
1412- llama_context * ctx_tgt) : p_impl(new impl{callback, params, ctx_tgt}) {
1412+ common_speculative_callback & callback,
1413+ llama_context * ctx_tgt) : p_impl(new impl{params, callback, ctx_tgt}) {
14131414}
14141415
14151416common_speculative_session::~common_speculative_session () {
0 commit comments