Skip to content

Commit db8e326

Browse files
committed
spec : introduce common_speculative_process()
1 parent 0d5dd61 commit db8e326

3 files changed

Lines changed: 97 additions & 32 deletions

File tree

common/speculative.cpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,8 @@ struct common_speculative_impl {
149149

150150
virtual void begin(llama_seq_id seq_id, const llama_tokens & prompt) = 0;
151151

152+
virtual bool process(const llama_batch & batch) = 0;
153+
152154
virtual void draft(common_speculative_draft_params_vec & dparams) = 0;
153155

154156
virtual void accept(llama_seq_id seq_id, uint16_t n_accepted) = 0;
@@ -223,6 +225,20 @@ struct common_speculative_state_draft : public common_speculative_impl {
223225
// noop
224226
}
225227

228+
bool process(const llama_batch & batch) override {
229+
auto * ctx_dft = params.ctx_dft;
230+
231+
const int ret = llama_decode(ctx_dft, batch);
232+
233+
if (ret != 0) {
234+
LOG_ERR("%s: failed to decode draft batch, ret = %d\n", __func__, ret);
235+
236+
return false;
237+
}
238+
239+
return true;
240+
}
241+
226242
void draft(common_speculative_draft_params_vec & dparams) override {
227243
auto & ctx_dft = params.ctx_dft;
228244

@@ -345,6 +361,11 @@ struct common_speculative_state_eagle3 : public common_speculative_impl {
345361
// noop
346362
}
347363

364+
bool process(const llama_batch & /*batch*/) override {
365+
// TODO: implement
366+
return true;
367+
}
368+
348369
void draft(common_speculative_draft_params_vec & /*dparams*/) override {
349370
// TODO: implement
350371
}
@@ -372,6 +393,11 @@ struct common_speculative_state_ngram_simple : public common_speculative_impl {
372393
// noop
373394
}
374395

396+
bool process(const llama_batch & /*batch*/) override {
397+
// TODO: implement
398+
return true;
399+
}
400+
375401
void draft(common_speculative_draft_params_vec & dparams) override {
376402
assert(dparams.size() == n_seq);
377403

@@ -413,6 +439,11 @@ struct common_speculative_state_ngram_map_k : public common_speculative_impl {
413439
common_ngram_map_begin(config[seq_id], prompt);
414440
}
415441

442+
bool process(const llama_batch & /*batch*/) override {
443+
// TODO: implement
444+
return true;
445+
}
446+
416447
void draft(common_speculative_draft_params_vec & dparams) override {
417448
assert(dparams.size() == n_seq);
418449

@@ -559,6 +590,11 @@ struct common_speculative_state_ngram_mod : public common_speculative_impl {
559590
sinfo.n_draft_last = result.size();
560591
}
561592

593+
bool process(const llama_batch & /*batch*/) override {
594+
// TODO: implement
595+
return true;
596+
}
597+
562598
void draft(common_speculative_draft_params_vec & dparams) override {
563599
assert(dparams.size() == n_seq);
564600

@@ -706,6 +742,11 @@ struct common_speculative_state_ngram_cache : public common_speculative_impl {
706742
}
707743
}
708744

745+
bool process(const llama_batch & /*batch*/) override {
746+
// TODO: implement
747+
return true;
748+
}
749+
709750
void draft(common_speculative_draft_params_vec & dparams) override {
710751
assert(dparams.size() == n_seq);
711752

@@ -937,6 +978,20 @@ void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, co
937978
}
938979
}
939980

981+
bool common_speculative_process(common_speculative * spec, const llama_batch & batch) {
982+
bool result = true;
983+
984+
if (spec == nullptr) {
985+
return result;
986+
}
987+
988+
for (auto & impl : spec->impls) {
989+
result = result && impl->process(batch);
990+
}
991+
992+
return result;
993+
}
994+
940995
void common_speculative_draft(common_speculative * spec) {
941996
if (spec == nullptr) {
942997
return;

common/speculative.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ struct common_speculative_draft_params {
2222
// this flag is used to chain the drafts through all the available implementations
2323
// after the first successful draft from an implementation, we set it
2424
// to false to prevent further drafts for that sequence
25+
// at the end of the draft() call, all drafting flags will be reset to false
2526
bool drafting = false;
2627

2728
// overrides individual configurations (-1 disabled)
@@ -43,8 +44,8 @@ common_speculative_draft_params & common_speculative_get_draft_params(common_spe
4344
// optionally call once at the beginning of a new generation
4445
void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, const llama_tokens & prompt);
4546

46-
// TODO: implement [TAG_COMMON_SPECULATIVE_PROCESS]
47-
//bool common_speculative_process(common_speculative * spec, const llama_batch & batch);
47+
// process the batch and update the internal state of the speculative context
48+
bool common_speculative_process(common_speculative * spec, const llama_batch & batch);
4849

4950
// generate drafts for the sequences specified with `common_speculative_get_draft_params`
5051
void common_speculative_draft(common_speculative * spec);

tools/server/server-context.cpp

Lines changed: 39 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2676,6 +2676,7 @@ struct server_context_impl {
26762676
if (ctx_dft) {
26772677
// TODO: in the future, figure out how to infuse target embeddings to the images
26782678
// for now, we skip this for simplicity
2679+
// maybe we simply need to call `common_speculative_process()` on the mtmd batches in the `process_chunk` above?
26792680
res = input_tokens.process_chunk(ctx_dft.get(), mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out);
26802681
if (res != 0) {
26812682
GGML_ABORT("failed to process multi-modal data on draft context\n");
@@ -2925,36 +2926,44 @@ struct server_context_impl {
29252926
// | Eagle3 | yes |
29262927
// | DFlash | yes | https://github.com/ggml-org/llama.cpp/pull/22728#issuecomment-4405406982
29272928
//
2928-
// TODO: move to `common_speculative_process(spec, batch, ...)` [TAG_COMMON_SPECULATIVE_PROCESS]
2929-
if (ctx_dft) {
2930-
// TODO: update as needed for MTP, Eagle3, etc.
2931-
const bool need_tgt_embd = false;
2932-
2933-
if (need_tgt_embd) {
2934-
llama_synchronize(ctx_tgt);
2935-
}
2936-
2937-
// the logic here varies depending on the speculative decoding method
2938-
// - some draft contexts require embeddings from the target context, others don't
2939-
// - some draft contexts involve an encoder step to transform the target embeddings to draft embeddings
2940-
// TODO: extract this in a function ?
2941-
{
2942-
// TODO: hook the embeddings from the last target batch here
2943-
if (llama_model_has_encoder(model_dft.get())) {
2944-
//llama_encode(ctx_dft, ...);
2945-
2946-
GGML_ABORT("not implemented yet\n");
2947-
}
2948-
2949-
const int ret = llama_decode(ctx_dft.get(), batch_view);
2950-
2951-
if (ret != 0) {
2952-
SRV_ERR("failed to decode draft batch, ret = %d\n", ret);
2953-
2954-
// TODO: handle error
2955-
break;
2956-
}
2957-
}
2929+
// note: this logic is now moved in `common_speculative_process()`
2930+
// keeping the sketch here until for a bit, until the logic is finalized
2931+
//
2932+
//if (ctx_dft) {
2933+
// // TODO: update as needed for MTP, Eagle3, etc.
2934+
// const bool need_tgt_embd = false;
2935+
2936+
// if (need_tgt_embd) {
2937+
// llama_synchronize(ctx_tgt);
2938+
// }
2939+
2940+
// // the logic here varies depending on the speculative decoding method
2941+
// // - some draft contexts require embeddings from the target context, others don't
2942+
// // - some draft contexts involve an encoder step to transform the target embeddings to draft embeddings
2943+
// // TODO: extract this in a function ?
2944+
// {
2945+
// // TODO: hook the embeddings from the last target batch here
2946+
// if (llama_model_has_encoder(model_dft.get())) {
2947+
// //llama_encode(ctx_dft, ...);
2948+
2949+
// GGML_ABORT("not implemented yet\n");
2950+
// }
2951+
2952+
// const int ret = llama_decode(ctx_dft.get(), batch_view);
2953+
2954+
// if (ret != 0) {
2955+
// SRV_ERR("failed to decode draft batch, ret = %d\n", ret);
2956+
2957+
// // TODO: handle error
2958+
// break;
2959+
// }
2960+
// }
2961+
//}
2962+
if (!common_speculative_process(spec.get(), batch_view)) {
2963+
SRV_ERR("%s", "failed to process speculative batch\n");
2964+
2965+
// TODO: handle error
2966+
break;
29582967
}
29592968

29602969
// move the head of the batch forward with the number of tokens we just processed

0 commit comments

Comments
 (0)