
Commit 4c7b337

ggerganov and Petros Sideris authored and committed
spec : parallel drafting support (ggml-org#22838)
* spec : refactor
* spec : drop support for incompatible vocabs
* spec : update common_speculative_init()
* cont : pass seq_id
* cont : dedup ctx_seq_rm_type
* server : sketch the ctx_dft decode loop
* server : draft prompt cache and checkpoints
* server : improve ctx names
* server, spec : transition to unified spec context
* cont : sync main and drft contexts
* cont : async drft eval when possible
* cont : handle non-ckpt models
* cont : pass correct n_past for drafting
* cont : process images through the draft context
* spec : handle draft running out of context
* server : fix mtmd draft processing
* server : fix URL for draft model
* server : add comment
* server : clean-up + dry
* speculative-simple : update
* spec : fix n_past type
* server : fix slot ctx_drft ptr
* tools : update readme
* naming : improve consistency
* spec : refactor for multi-sequence speculative context
* cont : prepare params
* cont : prepare params
* spec : support parallel drafts
* server : support parallel drafting
* llama : reuse device buffers when possible
* server, spec : clean-up
* cont : clean-up
* cont : minor
* spec : reset `drafting` flag at the end
* spec : introduce `common_speculative_process()`
* spec : allow for multiple spec types (chain of speculators)
* replace the old `type` field of type `common_speculative_type` in the `common_params_speculative` struct with a vector, to allow multiple types to be specified
* introduce `common_get_enabled_speculative_impls(const std::vector<enum common_speculative_type>)` to figure out which implementations the user has enabled
* introduce `common_speculative_type_from_names(const std::vector<std::string> & names)` to parse the user-provided spec types
* all speculators run sequentially, best one wins (we verify its drafted tokens)
* maximize the expected number of accepted tokens for the current round by taking the product of the probability of accepting a token (n_acc_tokens / n_gen_drafts) and the draft's length

---------

Co-authored-by: Petros Sideris <petros.sideris@nokia.com>
1 parent b365b82 commit 4c7b337

14 files changed

Lines changed: 712 additions & 389 deletions
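
The draft-selection rule in the last bullet can be sketched as follows. This is a minimal illustration of the scoring described in the commit message, not the code from this commit; the struct and field names (spec_impl_stats, n_acc_tokens, n_gen_drafts, draft) are assumptions chosen to mirror the terms used above.

#include <cstdint>
#include <vector>

// Illustrative stand-in for one speculator implementation's running statistics.
struct spec_impl_stats {
    int64_t n_acc_tokens = 0;   // draft tokens accepted by the target so far
    int64_t n_gen_drafts = 0;   // draft tokens generated so far
    std::vector<int32_t> draft; // tokens drafted in the current round
};

// Pick the draft that maximizes the expected number of accepted tokens:
//   E[accepted] ~= (n_acc_tokens / n_gen_drafts) * draft.size()
static const spec_impl_stats * pick_best_draft(const std::vector<spec_impl_stats> & impls) {
    const spec_impl_stats * best = nullptr;
    double best_score = -1.0;
    for (const auto & impl : impls) {
        if (impl.draft.empty() || impl.n_gen_drafts == 0) {
            continue; // nothing drafted, or no acceptance history yet
        }
        const double p_acc = (double) impl.n_acc_tokens / (double) impl.n_gen_drafts;
        const double score = p_acc * (double) impl.draft.size();
        if (score > best_score) {
            best_score = score;
            best       = &impl;
        }
    }
    return best; // the winner's tokens are what gets verified by the target
}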


common/arg.cpp

Lines changed: 4 additions & 4 deletions
@@ -606,7 +606,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
         }
     }
     common_params_handle_model(params.speculative.draft.mparams, params.hf_token, params.offline);
-    common_params_handle_model(params.vocoder.model, params.hf_token, params.offline);
+    common_params_handle_model(params.vocoder.model, params.hf_token, params.offline);
 }
 
 // model is required (except for server)
@@ -3483,7 +3483,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params, int value) {
             if (value < 0) {
                 throw std::invalid_argument("invalid value");
-            }
+            }
             for (int i = 0; i < value; ++i) {
                 static std::list<std::string> buft_overrides_draft;
                 buft_overrides_draft.push_back(llm_ffn_exps_block_regex(i));
@@ -3660,7 +3660,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params, int value) {
            if (value < 1 || value > 1024) {
                throw std::invalid_argument("ngram size N must be between 1 and 1024 inclusive");
-           }
+           }
            params.speculative.ngram_map_k.size_n = value;
        }
    ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
@@ -3670,7 +3670,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
        [](common_params & params, int value) {
            if (value < 1 || value > 1024) {
                throw std::invalid_argument("ngram size M must be between 1 and 1024 inclusive");
-           }
+           }
            params.speculative.ngram_map_k.size_m = value;
        }
    ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));

common/common.cpp

Lines changed: 100 additions & 1 deletion
@@ -1428,7 +1428,7 @@ common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) {
 
     // try to remove the last tokens
     if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
-        LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
+        LOG_WRN("%s: the context does not support partial sequence removal\n", __func__);
         res = COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
         goto done;
     }
@@ -1966,3 +1966,102 @@ bool common_prompt_batch_decode(
 
     return true;
 }
+
+size_t common_prompt_checkpoint::size() const {
+    return data_tgt.size() + data_dft.size();
+}
+
+bool common_prompt_checkpoint::empty() const {
+    return data_tgt.empty();
+}
+
+void common_prompt_checkpoint::clear() {
+    n_tokens = 0;
+
+    pos_min = 0;
+    pos_max = 0;
+
+    data_tgt.clear();
+    data_dft.clear();
+}
+
+void common_prompt_checkpoint::update_pos(
+        int64_t n_tokens,
+        llama_pos pos_min,
+        llama_pos pos_max) {
+    this->n_tokens = n_tokens;
+    this->pos_min  = pos_min;
+    this->pos_max  = pos_max;
+}
+
+void common_prompt_checkpoint::update_tgt(
+        llama_context * ctx,
+        llama_seq_id seq_id,
+        llama_state_seq_flags flags) {
+    if (ctx == nullptr) {
+        return;
+    }
+
+    const size_t ckpt_size = llama_state_seq_get_size_ext(ctx, seq_id, flags);
+
+    data_tgt.resize(ckpt_size);
+
+    const size_t n = llama_state_seq_get_data_ext(ctx, data_tgt.data(), ckpt_size, seq_id, flags);
+    if (n != ckpt_size) {
+        GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", ckpt_size, n);
+    }
+}
+
+void common_prompt_checkpoint::update_dft(
+        llama_context * ctx,
+        llama_seq_id seq_id,
+        llama_state_seq_flags flags) {
+    if (ctx == nullptr) {
+        return;
+    }
+
+    const size_t ckpt_size = llama_state_seq_get_size_ext(ctx, seq_id, flags);
+
+    data_dft.resize(ckpt_size);
+
+    const size_t n = llama_state_seq_get_data_ext(ctx, data_dft.data(), ckpt_size, seq_id, flags);
+    if (n != ckpt_size) {
+        GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", ckpt_size, n);
+    }
+}
+
+void common_prompt_checkpoint::load_tgt(
+        llama_context * ctx,
+        llama_seq_id seq_id,
+        llama_state_seq_flags flags) const {
+    if (ctx == nullptr) {
+        return;
+    }
+
+    if (data_tgt.empty()) {
+        return;
+    }
+
+    const size_t n = llama_state_seq_set_data_ext(ctx, data_tgt.data(), data_tgt.size(), seq_id, flags);
+    if (n != data_tgt.size()) {
+        GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", data_tgt.size(), n);
+    }
+}
+
+void common_prompt_checkpoint::load_dft(
+        llama_context * ctx,
+        llama_seq_id seq_id,
+        llama_state_seq_flags flags) const {
+    if (ctx == nullptr) {
+        return;
+    }
+
+    if (data_dft.empty()) {
+        return;
+    }
+
+    const size_t n = llama_state_seq_set_data_ext(ctx, data_dft.data(), data_dft.size(), seq_id, flags);
+    if (n != data_dft.size()) {
+        GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", data_dft.size(), n);
+    }
+}
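
The checkpoint methods above are thin wrappers around llama_state_seq_get_size_ext(), llama_state_seq_get_data_ext() and llama_state_seq_set_data_ext(). A minimal usage sketch, assuming ctx_tgt, ctx_dft, seq_id and flags come from the caller; the helper function itself is illustrative and not part of this commit.

#include "common.h"

// Illustrative helper: snapshot and later restore both contexts for one
// sequence. All parameters are assumed to come from the caller.
static void checkpoint_roundtrip(
        llama_context * ctx_tgt,
        llama_context * ctx_dft, // may be nullptr when no draft model is used
        llama_seq_id seq_id,
        llama_state_seq_flags flags,
        int64_t n_tokens,
        llama_pos pos_min,
        llama_pos pos_max) {
    common_prompt_checkpoint ckpt;

    // save: record position bookkeeping and serialize both sequence states
    ckpt.update_pos(n_tokens, pos_min, pos_max);
    ckpt.update_tgt(ctx_tgt, seq_id, flags); // aborts on size mismatch
    ckpt.update_dft(ctx_dft, seq_id, flags); // no-op when ctx_dft is nullptr

    // ... decoding may diverge the KV caches from the checkpointed state ...

    // restore: roll both contexts back to the snapshot
    ckpt.load_tgt(ctx_tgt, seq_id, flags);
    ckpt.load_dft(ctx_dft, seq_id, flags); // skipped when no draft data was saved
}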

common/common.h

Lines changed: 47 additions & 12 deletions
@@ -295,8 +295,6 @@ struct common_params_model {
     std::string name = ""; // in format <user>/<model>[:<tag>] (tag is optional) // NOLINT
 };
 
-struct common_ngram_mod;
-
 // draft-model-based speculative decoding parameters
 struct common_params_speculative_draft {
     int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
@@ -307,11 +305,9 @@ struct common_params_speculative_draft {
 
     common_params_model mparams;
 
-    llama_model * model = nullptr; // a llama_model that can be shared by multiple speculative contexts
-
-    llama_context_params cparams; // these are the parameters for the draft llama_context
+    llama_context * ctx_tgt = nullptr;
+    llama_context * ctx_dft = nullptr;
 
-    int32_t n_ctx = 0; // draft context size
     int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
 
     ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
@@ -322,7 +318,6 @@ struct common_params_speculative_draft {
 
     std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
 
-    std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
     std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
 };
 
@@ -331,9 +326,6 @@ struct common_params_speculative_ngram_mod {
 
     int32_t n_max = 64;
     int32_t n_min = 48;
-
-    // shared instance of the ngram container for all speculative decoding contexts
-    std::shared_ptr<common_ngram_mod> obj;
 };
 
 struct common_params_speculative_ngram_map {
@@ -348,8 +340,7 @@ struct common_params_speculative_ngram_cache {
 };
 
 struct common_params_speculative {
-    // TODO: become a vector in order to support "chains of speculators"
-    common_speculative_type type = COMMON_SPECULATIVE_TYPE_NONE;
+    std::vector<enum common_speculative_type> types = { COMMON_SPECULATIVE_TYPE_NONE };
 
     common_params_speculative_draft draft;
 
@@ -1026,3 +1017,47 @@ ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std
 
 // "adamw" or "sgd" (case insensitive)
 enum ggml_opt_optimizer_type common_opt_get_optimizer(const char *);
+
+//
+// prompt utils
+//
+
+struct common_prompt_checkpoint {
+    int64_t n_tokens;
+
+    llama_pos pos_min;
+    llama_pos pos_max;
+
+    std::vector<uint8_t> data_tgt;
+    std::vector<uint8_t> data_dft;
+
+    size_t size() const;
+
+    bool empty() const;
+    void clear();
+
+    void update_pos(
+            int64_t n_tokens,
+            llama_pos pos_min,
+            llama_pos pos_max);
+
+    void update_tgt(
+            llama_context * ctx,
+            llama_seq_id seq_id,
+            llama_state_seq_flags flags);
+
+    void update_dft(
+            llama_context * ctx,
+            llama_seq_id seq_id,
+            llama_state_seq_flags flags);
+
+    void load_tgt(
+            llama_context * ctx,
+            llama_seq_id seq_id,
+            llama_state_seq_flags flags) const;
+
+    void load_dft(
+            llama_context * ctx,
+            llama_seq_id seq_id,
+            llama_state_seq_flags flags) const;
+};
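
With the `types` vector replacing the single `type` field, a chain of speculators can now be configured in one struct. A hedged sketch using the parser declared in common/speculative.h; the name strings below are assumed spellings, not confirmed against what the parser accepts.

#include "common.h"
#include "speculative.h"

// Hypothetical chain-of-speculators configuration; the strings passed to
// common_speculative_types_from_names() are assumed spellings.
static void configure_chain(common_params_speculative & sparams) {
    sparams.types = common_speculative_types_from_names({ "draft", "ngram-mod" });
}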

common/speculative.cpp

Lines changed: 8 additions & 8 deletions
@@ -344,8 +344,8 @@ struct common_speculative_state_draft : public common_speculative_state {
     for (int i = 0; i < (int) prompt_dft.size(); ++i) {
         int cur = 0;
         while (i_start + cur < (int) prompt_cur.size() &&
-               i + cur < (int) prompt_dft.size() &&
-               prompt_cur[i_start + cur] == prompt_dft[i + cur]) {
+               i + cur < (int) prompt_dft.size() &&
+               prompt_cur[i_start + cur] == prompt_dft[i + cur]) {
             cur++;
         }
 
@@ -418,7 +418,7 @@ struct common_speculative_state_draft : public common_speculative_state {
                 LOG_ERR("%s: llama_memory_seq_rm failed, reuse_n=%d, prompt_dft.size=%zu\n", __func__, reuse_n, prompt_dft.size());
                 return;
             }
-            prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end());
+            prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end());
         }
     }
 }
@@ -782,7 +782,7 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
     n_low++;
     if (n_low >= 3) {
         if (verbose) {
-            LOG_WRN("%s: low acceptance streak (%d) – resetting ngram_mod\n", __func__, n_low);
+            LOG_WRN("%s: low acceptance streak (%d) – resetting ngram_mod\n", __func__, n_low);
         }
 
         mod.reset();
@@ -1065,12 +1065,12 @@ common_speculative * common_speculative_init(
     uint16_t mgram_size_value = ngram_map.size_value;
 
     auto config_simple = common_ngram_simple_config {
-        /* .size_ngram = */ ngram_size_key,
-        /* .size_mgram = */ mgram_size_value
+        /* .size_ngram = */ ngram_size_key,
+        /* .size_mgram = */ mgram_size_value
    };
    auto state = std::make_unique<common_speculative_state_ngram_simple>(
-        /* .type = */ config.type,
-        /* .state = */ config_simple
+        /* .type = */ config.type,
+        /* .state = */ config_simple
    );
    impls.push_back(std::move(state));
    break;
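
The realigned while loop in the first hunk is part of the prefix-reuse scan: for each offset into the cached draft prompt it counts how many tokens match the current prompt, so that the longest matching run of the draft KV cache can be kept. A simplified, self-contained sketch of that scan; the real code applies further constraints before erasing the stale tail.

#include <cstdint>
#include <vector>

using llama_tokens = std::vector<int32_t>; // local stand-in for the llama.cpp typedef

// Simplified prefix-reuse scan: find the longest run of prompt_dft that
// matches prompt_cur starting at i_start, returning its offset and length.
static void longest_reuse(
        const llama_tokens & prompt_cur,
        const llama_tokens & prompt_dft,
        int i_start,
        int & reuse_i,
        int & reuse_n) {
    reuse_i = 0;
    reuse_n = 0;
    for (int i = 0; i < (int) prompt_dft.size(); ++i) {
        int cur = 0;
        while (i_start + cur < (int) prompt_cur.size() &&
               i       + cur < (int) prompt_dft.size() &&
               prompt_cur[i_start + cur] == prompt_dft[i + cur]) {
            cur++;
        }
        if (cur > reuse_n) {
            reuse_i = i; // where the matching run begins in the cached draft prompt
            reuse_n = cur;
        }
    }
}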

common/speculative.h

Lines changed: 38 additions & 15 deletions
@@ -5,36 +5,59 @@
 
 struct common_speculative;
 
+// comma separated list of the provided types
+std::string common_speculative_type_name_str(const std::vector<enum common_speculative_type> & types);
+
 // comma separated list of all types
-std::string common_speculative_type_name_str();
+const char * common_speculative_all_types_str();
+
+// parse user provided types
+std::vector<enum common_speculative_type> common_speculative_types_from_names(const std::vector<std::string> & names);
 
 // convert string to type
 enum common_speculative_type common_speculative_type_from_name(const std::string & name);
 
 // convert type to string
 std::string common_speculative_type_to_str(enum common_speculative_type type);
 
-common_speculative * common_speculative_init(
-        common_params_speculative & params,
-        llama_context * ctx_tgt);
+common_speculative * common_speculative_init(common_params_speculative & params, uint32_t n_seq);
 
 void common_speculative_free(common_speculative * spec);
 
+struct common_speculative_draft_params {
+    // this flag is used to chain the drafts through all the available implementations;
+    // after the first successful draft from an implementation, we set it
+    // to false to prevent further drafts for that sequence
+    // at the end of the draft() call, all drafting flags will be reset to false
+    bool drafting = false;
+
+    // overrides individual configurations (-1 disabled)
+    // can be used to constrain the max draft based on the remaining context size
+    int32_t n_max = -1;
+
+    llama_pos n_past;
+    llama_token id_last;
+
+    // TODO: remove in the future by keeping track of the prompt from the _begin() call and the consecutive accept calls
+    const llama_tokens * prompt;
+
+    // the generated draft from the last _draft() call
+    llama_tokens * result;
+};
+
+common_speculative_draft_params & common_speculative_get_draft_params(common_speculative * spec, llama_seq_id seq_id);
+
 // optionally call once at the beginning of a new generation
-void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt);
+void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, const llama_tokens & prompt);
 
-// sample up to n_draft tokens and add them to the batch using the draft model
-llama_tokens common_speculative_draft(
-        common_speculative * spec,
-        const common_params_speculative & params,
-        const llama_tokens & prompt,
-        llama_token id_last);
+// process the batch and update the internal state of the speculative context
+bool common_speculative_process(common_speculative * spec, const llama_batch & batch);
 
-// informs the speculative decoder that n_accepted tokens were accepted by the target model
-void common_speculative_accept(common_speculative * spec, uint16_t n_accepted);
+// generate drafts for the sequences specified with `common_speculative_get_draft_params`
+void common_speculative_draft(common_speculative * spec);
 
-int32_t common_speculative_n_max(const common_speculative * spec, const common_params_speculative & params);
-int32_t common_speculative_n_min(const common_speculative * spec, const common_params_speculative & params);
+// informs the speculative context that n_accepted tokens were accepted by the target model
+void common_speculative_accept(common_speculative * spec, llama_seq_id, uint16_t n_accepted);
 
 // print statistics about the speculative decoding
 void common_speculative_print_stats(const common_speculative * spec);
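
Taken together, the declarations above suggest a per-sequence drafting round along the following lines. This is a hedged sketch, not the server's actual loop: target-side verification is elided, and common_speculative_process() would additionally be called with each decoded batch to keep the internal state in sync.

#include "common.h"
#include "speculative.h"

// Illustrative single drafting round for one sequence; all parameters are
// assumed to come from the caller.
static void draft_one_round(
        common_speculative * spec,
        llama_seq_id seq_id,
        const llama_tokens & prompt,
        llama_pos n_past,
        llama_token id_last,
        llama_tokens & draft) {
    // optionally start a new generation for this sequence
    common_speculative_begin(spec, seq_id, prompt);

    // mark the sequence as drafting and fill in its parameters
    common_speculative_draft_params & dp = common_speculative_get_draft_params(spec, seq_id);
    dp.drafting = true;
    dp.n_past   = n_past;
    dp.id_last  = id_last;
    dp.prompt   = &prompt;
    dp.result   = &draft;
    // dp.n_max could be lowered here to fit the remaining context size

    // generate drafts for all sequences currently marked as drafting
    common_speculative_draft(spec);

    // ... verify `draft` against the target model and count n_accepted ...
    const uint16_t n_accepted = 0; // placeholder for the verification result

    common_speculative_accept(spec, seq_id, n_accepted);
}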
