Skip to content

Commit bcb5eeb

Browse files
authored
speculative-simple : add checkpoint support (ggml-org#22227)
* speculative-simple : add checkpoint support * cont : fix build
1 parent 225088e commit bcb5eeb

2 files changed

Lines changed: 112 additions & 10 deletions

File tree

examples/speculative-simple/speculative-simple.cpp

Lines changed: 104 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,24 @@
88
#include <clocale>
99
#include <cstdio>
1010
#include <cstring>
11+
#include <cinttypes>
1112
#include <string>
1213
#include <vector>
14+
#include <utility>
15+
16+
// Saved snapshot of the target context's sequence state, taken before a draft is evaluated
// so that the state can be restored when the draft is only partially accepted.
struct spec_checkpoint {
17+
// size of prompt_tgt at the moment the checkpoint was created;
// on restore, the context memory and prompt_tgt are trimmed back to this position
int64_t n_tokens = 0;
18+
19+
// raw state bytes filled by llama_state_seq_get_data_ext and
// read back by llama_state_seq_set_data_ext
std::vector<uint8_t> data;
20+
21+
// number of bytes currently held in the checkpoint buffer
size_t size() const {
22+
return data.size();
23+
}
24+
25+
// true when no checkpoint data has been stored
bool empty() const {
26+
return data.empty();
27+
}
28+
};
1329

1430
int main(int argc, char ** argv) {
1531
std::setlocale(LC_NUMERIC, "C");
@@ -46,6 +62,14 @@ int main(int argc, char ** argv) {
4662
model_tgt = llama_init_tgt->model();
4763
ctx_tgt = llama_init_tgt->context();
4864

65+
// check if the context supports partial sequence removal
66+
const auto ctx_seq_rm = common_context_can_seq_rm(ctx_tgt);
67+
const bool use_ckpt = (ctx_seq_rm == COMMON_CONTEXT_SEQ_RM_TYPE_FULL);
68+
69+
if (use_ckpt) {
70+
LOG_INF("speculative decoding will use checkpoints (context does not support partial sequence removal)\n");
71+
}
72+
4973
const llama_vocab * vocab = llama_model_get_vocab(model_tgt);
5074

5175
// load the draft model
@@ -119,7 +143,7 @@ int main(int argc, char ** argv) {
119143
const auto t_enc_start = ggml_time_us();
120144

121145
// target model sampling context
122-
struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling);
146+
common_sampler_ptr smpl(common_sampler_init(model_tgt, params.sampling));
123147

124148
// eval the prompt
125149
llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1));
@@ -142,21 +166,61 @@ int main(int argc, char ** argv) {
142166

143167
llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);
144168

169+
size_t n_draft = 0;
170+
171+
llama_tokens draft;
172+
spec_checkpoint spec_ckpt;
173+
145174
const auto t_enc_end = ggml_time_us();
146175

147176
const auto t_dec_start = ggml_time_us();
148177

149178
while (true) {
150-
// optionally, generate draft tokens that can be appended to the target batch
179+
// generate or reuse draft tokens
151180
//
152181
// this is the most important part of the speculation. the more probable tokens that are provided here
153182
// the better the performance will be. in theory, this computation can be performed asynchronously and even
154183
// offloaded to a remote device. it doesn't even have to be based on an LLM. instead, it can provide tokens
155184
// from a cache or lookup tables.
156185
//
157-
llama_tokens draft = common_speculative_draft(spec, params_spec, prompt_tgt, id_last);
186+
if (draft.empty()) {
187+
// generate a new draft
188+
draft = common_speculative_draft(spec, params_spec, prompt_tgt, id_last);
189+
190+
if ((int) draft.size() > params_spec.n_max) {
191+
LOG_WRN("draft size %zu exceeds max %d, truncating\n", draft.size(), params_spec.n_max);
192+
draft.resize(params_spec.n_max);
193+
}
194+
195+
if ((int) draft.size() < params_spec.n_min) {
196+
LOG_DBG("ignoring small draft: %zu < %d\n", draft.size(), params_spec.n_min);
197+
draft.clear();
198+
}
199+
200+
// save the original draft size
201+
n_draft = draft.size();
202+
203+
// save a checkpoint of the target context before evaluating the draft
204+
// this allows us to restore the state if partial draft acceptance occurs
205+
if (!draft.empty() && use_ckpt) {
206+
const size_t ckpt_size = llama_state_seq_get_size_ext(ctx_tgt, 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
207+
spec_ckpt.data.resize(ckpt_size);
158208

159-
//LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str());
209+
const size_t n = llama_state_seq_get_data_ext(ctx_tgt, spec_ckpt.data.data(), ckpt_size, 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
210+
GGML_ASSERT(n == ckpt_size);
211+
212+
spec_ckpt.n_tokens = (int64_t) prompt_tgt.size();
213+
LOG_DBG("created speculative checkpoint (n_tokens = %" PRId64 ", size = %.3f MiB)\n",
214+
spec_ckpt.n_tokens, (float) spec_ckpt.data.size() / 1024 / 1024);
215+
}
216+
} else {
217+
// we have a previous (partial) draft to reuse from checkpoint restoration
218+
if (use_ckpt) {
219+
GGML_ASSERT(!spec_ckpt.empty());
220+
}
221+
}
222+
223+
GGML_ASSERT(n_draft > 0);
160224

161225
// always have a token to evaluate from before - id_last
162226
common_batch_clear(batch_tgt);
@@ -178,21 +242,51 @@ int main(int argc, char ** argv) {
178242
llama_decode(ctx_tgt, batch_tgt);
179243
}
180244

245+
// only save the sampler state if we use checkpoints
246+
common_sampler_ptr smpl_save;
247+
if (use_ckpt) {
248+
smpl_save.reset(common_sampler_clone(smpl.get()));
249+
}
250+
181251
// sample from the full target batch and return the accepted tokens based on the target sampler
182252
//
183253
// for each token to be accepted, the sampler would have to sample that same token
184254
// in such cases, instead of decoding the sampled token as we normally do, we simply continue with the
185255
// available logits from the batch and sample the next token until we run out of logits or the sampler
186256
// disagrees with the draft
187257
//
188-
const auto ids = common_sampler_sample_and_accept_n(smpl, ctx_tgt, draft);
258+
auto ids = common_sampler_sample_and_accept_n(smpl.get(), ctx_tgt, draft);
189259

190260
//LOG_DBG("ids: %s\n", string_from(ctx_tgt, ids).c_str());
191261

192262
GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token
193263

264+
// check for partial draft acceptance:
265+
// if the context doesn't support partial sequence removal, restore the checkpoint
266+
// and make the accepted tokens the new partial draft for the next iteration
267+
if (use_ckpt && ids.size() - 1 < draft.size()) {
268+
LOG_DBG("partial acceptance: %zu < %zu, restoring checkpoint\n", ids.size() - 1, draft.size());
269+
270+
draft = std::move(ids);
271+
272+
const size_t n = llama_state_seq_set_data_ext(ctx_tgt, spec_ckpt.data.data(), spec_ckpt.size(), 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
273+
GGML_ASSERT(n == spec_ckpt.size());
274+
275+
llama_memory_seq_rm(llama_get_memory(ctx_tgt), 0, spec_ckpt.n_tokens, -1);
276+
277+
prompt_tgt.resize(spec_ckpt.n_tokens);
278+
smpl = std::move(smpl_save);
279+
280+
n_past = (int) prompt_tgt.size();
281+
282+
continue;
283+
}
284+
285+
common_speculative_accept(spec, ids.size() - 1);
286+
287+
// full acceptance: consume the draft and commit accepted tokens
194288
n_past += ids.size() - 1;
195-
n_drafted += draft.size(); // note: we ignore the discarded small drafts
289+
n_drafted += n_draft; // note: we ignore the discarded small drafts
196290
n_accept += ids.size() - 1;
197291
n_predict += ids.size();
198292

@@ -222,6 +316,9 @@ int main(int argc, char ** argv) {
222316

223317
LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d)\n", (int) ids.size() - 1, (int) draft.size(), id_last);
224318

319+
// clear the draft since it has been consumed
320+
draft.clear();
321+
225322
{
226323
LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
227324

@@ -254,11 +351,10 @@ int main(int argc, char ** argv) {
254351

255352
LOG_INF("\n");
256353
LOG_INF("target:\n\n");
257-
common_perf_print(ctx_tgt, smpl);
354+
common_perf_print(ctx_tgt, smpl.get());
258355

259356
llama_batch_free(batch_tgt);
260357

261-
common_sampler_free(smpl);
262358
common_speculative_free(spec);
263359

264360
llama_backend_free();

tools/server/server-context.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2961,7 +2961,13 @@ struct server_context_impl {
29612961

29622962
// verify and try to accept the draft
29632963
{
2964-
common_sampler_ptr smpl_save(common_sampler_clone(slot.smpl.get()));
2964+
const bool use_ckpt = slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
2965+
2966+
// only save the sampler state if we use checkpoints
2967+
common_sampler_ptr smpl_save;
2968+
if (use_ckpt) {
2969+
smpl_save.reset(common_sampler_clone(slot.smpl.get()));
2970+
}
29652971

29662972
GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1);
29672973
auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx, slot.spec_i_batch, slot.spec_draft);
@@ -2973,7 +2979,7 @@ struct server_context_impl {
29732979

29742980
// check for partial draft acceptance
29752981
if (accepted.size() < slot.spec_draft.size() + 1) {
2976-
if (slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
2982+
if (use_ckpt) {
29772983
// partial acceptance is not supported by the context -> truncate the draft and restore the state
29782984
slot.spec_draft = std::move(accepted);
29792985

0 commit comments

Comments
 (0)