@@ -1136,21 +1136,24 @@ struct common_speculative_session::impl {
11361136 clear_draft ();
11371137 return draft;
11381138 }
1139- if (params_spec.use_checkpoints
1140- && spec_ckpt_n_denials > 0 ) {
1139+ if (params_spec.use_checkpoints && spec_ckpt_n_denials > 1 ) {
1140+ // We shouldn't get two denials.
1141+ LOG_WRN (" %s: #tokens=%zu, spec_ckpt_n_denials=%d, id_last=%d, #draft=%zu\n " , __func__,
1142+ cached_text_tokens.size (), spec_ckpt_n_denials, id_last, draft.size ());
11411143 clear_draft ();
11421144 return draft;
11431145 }
11441146
1145- if (spec_ckpt_n_denials > 0 ) {
1147+ if (spec_ckpt_n_denials == 1 ) {
11461148 // there is a previous speculation which wasn't accepted in full length
11471149 if (draft.empty ()) {
11481150 LOG_WRN (" %s: draft of length 0 after denied checkpoint\n " , __func__);
11491151 clear_draft ();
11501152 return draft;
11511153 }
11521154 // we use the shortened draft of previous speculation
1153- LOG_INF (" %s: resuse shortened draft, size=%zu\n " , __func__, draft.size ());
1155+ LOG_DBG (" %s: reuse shortened draft, #tokens=%zu, id_last=%d, size=%zu\n " , __func__,
1156+ cached_text_tokens.size (), id_last, draft.size ());
11541157 } else {
11551158 // call the speculative implementation to create a draft
11561159 draft = common_speculative_draft (spec, params_spec, cached_text_tokens, id_last);
@@ -1167,32 +1170,35 @@ struct common_speculative_session::impl {
11671170 }
11681171
11691172 bool do_checkpoint = !draft.empty () && params_spec.use_checkpoints ;
1170- if (do_checkpoint && cached_text_tokens.size () > 5 ) {
1171- LOG_DBG (" draft.size = %zu, n_spec_denials = %d, do_checkpoint = %s, tokens=[..., %d, %d, %d]\n " ,
1173+ if (do_checkpoint && cached_text_tokens.size () > 5 && draft.size () >= 3 ) {
1174+ LOG_DBG (" %s: #tokens=%zu, draft.size=%zu, n_spec_denials=%d, do_checkpoint=%s, id_last=%d, tokens=[..., %d, %d, %d], draft=[%d, %d, %d, ...]\n " ,
1175+ __func__,
1176+ cached_text_tokens.size (),
11721177 draft.size (), spec_ckpt_n_denials,
1173- do_checkpoint ? " yes" : " no" ,
1178+ do_checkpoint ? " yes" : " no" , id_last,
11741179 cached_text_tokens[cached_text_tokens.size () - 3 ],
11751180 cached_text_tokens[cached_text_tokens.size () - 2 ],
1176- cached_text_tokens[cached_text_tokens.size () - 1 ]);
1181+ cached_text_tokens[cached_text_tokens.size () - 1 ],
1182+ draft[0 ], draft[1 ], draft[2 ]);
1183+ }
1184+
1185+ if (params_spec.n_min > (int ) draft.size ()) {
1186+ LOG_DBG (" ignoring small draft: %d < %d\n " , (int ) draft.size (), params_spec.n_min );
1187+ clear_draft ();
1188+ return draft;
11771189 }
11781190
11791191 if (do_checkpoint) {
11801192 const size_t n = callback.create_checkpoint ();
11811193 if (n == 0 ) {
1182- LOG_WRN (" checkpoint creation failed" );
1194+ LOG_WRN (" %s: checkpoint creation failed (#tokens=%zu) \n " , __func__, cached_text_tokens. size () );
11831195 clear_draft ();
11841196 return draft;
11851197 }
11861198 spec_ckpt_size_part = n;
11871199 spec_has_ckpt = true ;
11881200 }
11891201
1190- if (params_spec.n_min > (int ) draft.size ()) {
1191- LOG_DBG (" ignoring small draft: %d < %d\n " , (int ) draft.size (), params_spec.n_min );
1192- clear_draft ();
1193- return draft;
1194- }
1195-
11961202 // add last sampled token to the batch
11971203 callback.batch_add_token (id_last, true );
11981204
@@ -1219,27 +1225,31 @@ struct common_speculative_session::impl {
12191225 if (spec_has_ckpt) {
12201226 // we need to rollback to the state before sampling the draft tokens
12211227 const size_t n = callback.restore_checkpoint (spec_ckpt_size_part);
1222- LOG_INF (" partial acceptance: %zu < %zu, restored checkpoint: got %zu bytes\n " ,
1223- ids.size () -1 , n_draft, n);
1228+ LOG_DBG (" %s: partial acceptance: %zu < %zu, restored checkpoint: got %zu bytes\n " ,
1229+ __func__,
1230+ ids.size () - 1 , n_draft, n);
12241231
1225- // rollback to the state before sampling the draft tokens
1226-
1227- // Delete Checkpoint
1232+ // delete Checkpoint
12281233 callback.delete_checkpoint ();
12291234 spec_has_ckpt = false ;
12301235
1231- if (n_draft > 0 && spec_ckpt_n_denials == 0 ) {
1236+ spec_ckpt_n_denials++;
1237+ if (ids.size () > 1u + static_cast <std::size_t >(params_spec.n_min ) && spec_ckpt_n_denials == 1 ) {
12321238 // we will do the batch again but with the shortened draft
1233- spec_ckpt_n_denials++;
1234-
12351239 return common_speculative_accept_response (std::move (ids), n_draft, true );
12361240 }
12371241
1238- callback.batch_clear ();
1242+ LOG_DBG (" %s: don't accept partial draft, n_draft=%zu, ids.size=%zu\n " , __func__, n_draft, ids.size ());
1243+ draft.clear ();
1244+
1245+ // use the sampled token only
1246+ ids.resize (1 );
1247+ // drafted tokens in prompt have been deleted in restore_checkpoint(...).
1248+ return common_speculative_accept_response{std::move (ids), 0 , false };
12391249 }
12401250 }
12411251 const size_t draft_size_accepted = draft.size ();
1242- LOG_DBG (" %s: draft.size=%zu\n " , __func__, draft_size_accepted);
1252+ LOG_DBG (" %s: draft.size=%zu, ids.size=%zu \n " , __func__, draft_size_accepted, ids. size () );
12431253 common_speculative_accept (spec, draft_size_accepted);
12441254 draft.clear ();
12451255
0 commit comments