88#include < clocale>
99#include < cstdio>
1010#include < cstring>
11+ #include < cinttypes>
1112#include < string>
1213#include < vector>
14+ #include < utility>
15+
16+ struct spec_checkpoint {
17+ int64_t n_tokens = 0 ;
18+
19+ std::vector<uint8_t > data;
20+
21+ size_t size () const {
22+ return data.size ();
23+ }
24+
25+ bool empty () const {
26+ return data.empty ();
27+ }
28+ };
1329
1430int main (int argc, char ** argv) {
1531 std::setlocale (LC_NUMERIC , " C" );
@@ -46,6 +62,14 @@ int main(int argc, char ** argv) {
4662 model_tgt = llama_init_tgt->model ();
4763 ctx_tgt = llama_init_tgt->context ();
4864
65+ // check if the context supports partial sequence removal
66+ const auto ctx_seq_rm = common_context_can_seq_rm (ctx_tgt);
67+ const bool use_ckpt = (ctx_seq_rm == COMMON_CONTEXT_SEQ_RM_TYPE_FULL );
68+
69+ if (use_ckpt) {
70+ LOG_INF (" speculative decoding will use checkpoints (context does not support partial sequence removal)\n " );
71+ }
72+
4973 const llama_vocab * vocab = llama_model_get_vocab (model_tgt);
5074
5175 // load the draft model
@@ -119,7 +143,7 @@ int main(int argc, char ** argv) {
119143 const auto t_enc_start = ggml_time_us ();
120144
121145 // target model sampling context
122- struct common_sampler * smpl = common_sampler_init (model_tgt, params.sampling );
146+ common_sampler_ptr smpl ( common_sampler_init (model_tgt, params.sampling ) );
123147
124148 // eval the prompt
125149 llama_decode (ctx_tgt, llama_batch_get_one (inp.data (), inp.size () - 1 ));
@@ -142,21 +166,61 @@ int main(int argc, char ** argv) {
142166
143167 llama_batch batch_tgt = llama_batch_init (llama_n_batch (ctx_tgt), 0 , 1 );
144168
169+ size_t n_draft = 0 ;
170+
171+ llama_tokens draft;
172+ spec_checkpoint spec_ckpt;
173+
145174 const auto t_enc_end = ggml_time_us ();
146175
147176 const auto t_dec_start = ggml_time_us ();
148177
149178 while (true ) {
150- // optionally, generate draft tokens that can be appended to the target batch
179+ // generate or reuse draft tokens
151180 //
152181 // this is the most important part of the speculation. the more probable tokens that are provided here
153182 // the better the performance will be. in theory, this computation can be performed asynchronously and even
154183 // offloaded to a remote device. it doesn't even have to be based on an LLM. instead, it can provide tokens
155184 // from a cache or lookup tables.
156185 //
157- llama_tokens draft = common_speculative_draft (spec, params_spec, prompt_tgt, id_last);
186+ if (draft.empty ()) {
187+ // generate a new draft
188+ draft = common_speculative_draft (spec, params_spec, prompt_tgt, id_last);
189+
190+ if ((int ) draft.size () > params_spec.n_max ) {
191+ LOG_WRN (" draft size %zu exceeds max %d, truncating\n " , draft.size (), params_spec.n_max );
192+ draft.resize (params_spec.n_max );
193+ }
194+
195+ if ((int ) draft.size () < params_spec.n_min ) {
196+ LOG_DBG (" ignoring small draft: %zu < %d\n " , draft.size (), params_spec.n_min );
197+ draft.clear ();
198+ }
199+
200+ // save the original draft size
201+ n_draft = draft.size ();
202+
203+ // save a checkpoint of the target context before evaluating the draft
204+ // this allows us to restore the state if partial draft acceptance occurs
205+ if (!draft.empty () && use_ckpt) {
206+ const size_t ckpt_size = llama_state_seq_get_size_ext (ctx_tgt, 0 , LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY );
207+ spec_ckpt.data .resize (ckpt_size);
158208
159- // LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str());
209+ const size_t n = llama_state_seq_get_data_ext (ctx_tgt, spec_ckpt.data .data (), ckpt_size, 0 , LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY );
210+ GGML_ASSERT (n == ckpt_size);
211+
212+ spec_ckpt.n_tokens = (int64_t ) prompt_tgt.size ();
213+ LOG_DBG (" created speculative checkpoint (n_tokens = %" PRId64 " , size = %.3f MiB)\n " ,
214+ spec_ckpt.n_tokens , (float ) spec_ckpt.data .size () / 1024 / 1024 );
215+ }
216+ } else {
217+ // we have a previous (partial) draft to reuse from checkpoint restoration
218+ if (use_ckpt) {
219+ GGML_ASSERT (!spec_ckpt.empty ());
220+ }
221+ }
222+
223+ GGML_ASSERT (n_draft > 0 );
160224
161225 // always have a token to evaluate from before - id_last
162226 common_batch_clear (batch_tgt);
@@ -178,21 +242,51 @@ int main(int argc, char ** argv) {
178242 llama_decode (ctx_tgt, batch_tgt);
179243 }
180244
245+ // only save the sampler sampler state if we use checkpoints
246+ common_sampler_ptr smpl_save;
247+ if (use_ckpt) {
248+ smpl_save.reset (common_sampler_clone (smpl.get ()));
249+ }
250+
181251 // sample from the full target batch and return the accepted tokens based on the target sampler
182252 //
183253 // for each token to be accepted, the sampler would have to sample that same token
184254 // in such cases, instead of decoding the sampled token as we normally do, we simply continue with the
185255 // available logits from the batch and sample the next token until we run out of logits or the sampler
186256 // disagrees with the draft
187257 //
188- const auto ids = common_sampler_sample_and_accept_n (smpl, ctx_tgt, draft);
258+ auto ids = common_sampler_sample_and_accept_n (smpl. get () , ctx_tgt, draft);
189259
190260 // LOG_DBG("ids: %s\n", string_from(ctx_tgt, ids).c_str());
191261
192262 GGML_ASSERT (ids.size () > 0 ); // there will always be at least one accepted token
193263
264+ // check for partial draft acceptance:
265+ // if the context doesn't support partial sequence removal, restore the checkpoint
266+ // and make the accepted tokens the new partial draft for the next iteration
267+ if (use_ckpt && ids.size () - 1 < draft.size ()) {
268+ LOG_DBG (" partial acceptance: %zu < %zu, restoring checkpoint\n " , ids.size () - 1 , draft.size ());
269+
270+ draft = std::move (ids);
271+
272+ const size_t n = llama_state_seq_set_data_ext (ctx_tgt, spec_ckpt.data .data (), spec_ckpt.size (), 0 , LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY );
273+ GGML_ASSERT (n == spec_ckpt.size ());
274+
275+ llama_memory_seq_rm (llama_get_memory (ctx_tgt), 0 , spec_ckpt.n_tokens , -1 );
276+
277+ prompt_tgt.resize (spec_ckpt.n_tokens );
278+ smpl = std::move (smpl_save);
279+
280+ n_past = (int ) prompt_tgt.size ();
281+
282+ continue ;
283+ }
284+
285+ common_speculative_accept (spec, ids.size () - 1 );
286+
287+ // full acceptance: consume the draft and commit accepted tokens
194288 n_past += ids.size () - 1 ;
195- n_drafted += draft. size () ; // note: we ignore the discarded small drafts
289+ n_drafted += n_draft ; // note: we ignore the discarded small drafts
196290 n_accept += ids.size () - 1 ;
197291 n_predict += ids.size ();
198292
@@ -222,6 +316,9 @@ int main(int argc, char ** argv) {
222316
223317 LOG_DBG (" accepted %d/%d draft tokens, the last target token is: (%d)\n " , (int ) ids.size () - 1 , (int ) draft.size (), id_last);
224318
319+ // clear the draft since it has been consumed
320+ draft.clear ();
321+
225322 {
226323 LOG_DBG (" clear kv cache from any extra tokens, n_past = %d\n " , n_past);
227324
@@ -254,11 +351,10 @@ int main(int argc, char ** argv) {
254351
255352 LOG_INF (" \n " );
256353 LOG_INF (" target:\n\n " );
257- common_perf_print (ctx_tgt, smpl);
354+ common_perf_print (ctx_tgt, smpl. get () );
258355
259356 llama_batch_free (batch_tgt);
260357
261- common_sampler_free (smpl);
262358 common_speculative_free (spec);
263359
264360 llama_backend_free ();
0 commit comments