
Commit 35a41cb

srogmann, ggerganov, and CISC authored and committed
spec : add self-speculative decoding (no draft model required) + refactor (ggml-org#18471)
* server: introduce self-speculative decoding
* server: moved self-call into speculative.cpp
* can_speculate() includes self-speculation

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* server: can_speculate() tests self-spec
* server: replace can_speculate() with slot.can_speculate()

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* common: use %zu format specifier for size_t in logging

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* server: can_speculate() requires a task instance
* common: ngram map, config self-speculative decoding
* common: add enum common_speculative_type
* common: add vector of speculative states
* common: add option --spec-draftless
* server: cleanup (remove slot.batch_spec, rename)
* common: moved self-spec impl to ngram-map
* common: cleanup (use common_speculative_state_draft)
* spec : refactor
* cont : naming
* spec: remove --spec-config
* doc: (draftless) speculative decoding
* common: print performance in spec decoding
* minor : cleanup
* common : better names
* minor : cleanup + fix build
* minor: comments
* CODEOWNERS: add common/ngram-map.* (ggml-org#18471)
* common : rename speculative.draftless_type -> speculative.type
* ngram-map : fix uninitialized values
* ngram-map : take into account the input can become shorter
* ngram-map : revert len check for now
* arg : change `--spec-draftless` -> `--spec-type`
* spec : add common_speculative_state::accept()
* spec : refactor + add common_speculative_begin()
* spec : fix begin() call with mtmd
* spec : additional refactor + remove common_speculative_params

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
1 parent fe13ab6 commit 35a41cb

19 files changed

Lines changed: 1640 additions & 435 deletions
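
This commit lets the server speculate without a separate draft model: draft tokens are proposed from n-grams already present in the context and then verified by the target model in a single batch. Below is a rough sketch of the idea behind the ngram-simple variant; the code is illustrative only, not the actual common/ngram-map.* implementation, and all names in it are assumptions:

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

using llama_token = int32_t;

// Illustrative sketch: search the context for an earlier occurrence of the
// last n tokens (the lookup n-gram) and, on a hit, propose the m tokens that
// followed it as a draft (the m-gram). The target model then verifies the
// draft in one batch and keeps the longest accepted prefix.
static std::vector<llama_token> draft_ngram_simple(
        const std::vector<llama_token> & ctx, size_t n, size_t m) {
    std::vector<llama_token> draft;
    if (ctx.size() < n + 1) {
        return draft;
    }
    const size_t tail = ctx.size() - n;    // start of the lookup n-gram
    for (size_t pos = tail; pos-- > 0; ) { // scan right-to-left: prefer recent matches
        bool match = true;
        for (size_t i = 0; i < n; ++i) {
            if (ctx[pos + i] != ctx[tail + i]) { match = false; break; }
        }
        if (match) {
            const size_t beg = pos + n; // first token after the earlier occurrence
            const size_t end = std::min(beg + m, ctx.size());
            draft.assign(ctx.begin() + beg, ctx.begin() + end);
            break;
        }
    }
    return draft;
}
```

The ngram-map variants replace this linear scan with a map from n-gram keys to candidate m-grams, and the new --spec-ngram-* options below bound the lookup (n-gram length N, draft length M, check rate, minimum hits).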

CODEOWNERS

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@
 /common/jinja/ @ngxson @CISC @aldehir
 /common/llguidance.* @ggerganov
 /common/log.* @ggerganov
+/common/ngram-map.* @srogmann
 /common/peg-parser.* @aldehir
 /common/sampling.* @ggerganov
 /common/speculative.* @ggerganov

common/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -73,6 +73,8 @@ add_library(${TARGET} STATIC
     log.h
     ngram-cache.cpp
     ngram-cache.h
+    ngram-map.cpp
+    ngram-map.h
     peg-parser.cpp
     peg-parser.h
     preset.cpp

common/arg.cpp

Lines changed: 74 additions & 13 deletions
@@ -6,6 +6,7 @@
 #include "json-schema-to-grammar.h"
 #include "log.h"
 #include "sampling.h"
+#include "speculative.h"
 #include "preset.h"

 // fix problem with std::min and std::max
@@ -579,14 +580,14 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
         params.mmproj = res.mmproj;
     }
     // only download mmproj if the current example is using it
-    for (auto & ex : mmproj_examples) {
+    for (const auto & ex : mmproj_examples) {
         if (ctx_arg.ex == ex) {
             common_params_handle_model(params.mmproj, params.hf_token, params.offline);
             break;
         }
     }
-    common_params_handle_model(params.speculative.model, params.hf_token, params.offline);
-    common_params_handle_model(params.vocoder.model, params.hf_token, params.offline);
+    common_params_handle_model(params.speculative.mparams_dft, params.hf_token, params.offline);
+    common_params_handle_model(params.vocoder.model, params.hf_token, params.offline);
 }

 // model is required (except for server)
@@ -1216,16 +1217,16 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         {"-lcs", "--lookup-cache-static"}, "FNAME",
         "path to static lookup cache to use for lookup decoding (not updated by generation)",
         [](common_params & params, const std::string & value) {
-            params.lookup_cache_static = value;
+            params.speculative.lookup_cache_static = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_LOOKUP}));
+    ).set_examples({LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"-lcd", "--lookup-cache-dynamic"}, "FNAME",
         "path to dynamic lookup cache to use for lookup decoding (updated by generation)",
         [](common_params & params, const std::string & value) {
-            params.lookup_cache_dynamic = value;
+            params.speculative.lookup_cache_dynamic = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_LOOKUP}));
+    ).set_examples({LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"-c", "--ctx-size"}, "N",
         string_format("size of the prompt context (default: %d, 0 = loaded from model)", params.n_ctx),
@@ -2563,7 +2564,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         {"-hfd", "-hfrd", "--hf-repo-draft"}, "<user>/<model>[:quant]",
         "Same as --hf-repo, but for the draft model (default: unused)",
         [](common_params & params, const std::string & value) {
-            params.speculative.model.hf_repo = value;
+            params.speculative.mparams_dft.hf_repo = value;
         }
     ).set_env("LLAMA_ARG_HFD_REPO"));
     add_opt(common_arg(
@@ -3384,7 +3385,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         {"-md", "--model-draft"}, "FNAME",
         "draft model for speculative decoding (default: unused)",
         [](common_params & params, const std::string & value) {
-            params.speculative.model.path = value;
+            params.speculative.mparams_dft.path = value;
         }
     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_MODEL_DRAFT"));
     add_opt(common_arg(
@@ -3394,6 +3395,66 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.speculative.replacements.push_back({ tgt, dft });
         }
     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
+    add_opt(common_arg(
+        {"--spec-type"}, "[none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v]",
+        string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n",
+            common_speculative_type_to_str(params.speculative.type).c_str()),
+        [](common_params & params, const std::string & value) {
+            if (value == "none") {
+                params.speculative.type = COMMON_SPECULATIVE_TYPE_NONE;
+            } else if (value == "ngram-cache") {
+                params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_CACHE;
+            } else if (value == "ngram-simple") {
+                params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE;
+            } else if (value == "ngram-map-k") {
+                params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K;
+            } else if (value == "ngram-map-k4v") {
+                params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V;
+            } else {
+                throw std::invalid_argument("unknown speculative decoding type without draft model");
+            }
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}));
+    add_opt(common_arg(
+        {"--spec-ngram-size-n"}, "N",
+        string_format("ngram size N for ngram-simple/ngram-map speculative decoding, length of lookup n-gram (default: %d)", params.speculative.ngram_size_n),
+        [](common_params & params, int value) {
+            if (value < 1 || value > 1024) {
+                throw std::invalid_argument("ngram size N must be between 1 and 1024 inclusive");
+            }
+            params.speculative.ngram_size_n = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}));
+    add_opt(common_arg(
+        {"--spec-ngram-size-m"}, "N",
+        string_format("ngram size M for ngram-simple/ngram-map speculative decoding, length of draft m-gram (default: %d)", params.speculative.ngram_size_m),
+        [](common_params & params, int value) {
+            if (value < 1 || value > 1024) {
+                throw std::invalid_argument("ngram size M must be between 1 and 1024 inclusive");
+            }
+            params.speculative.ngram_size_m = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}));
+    add_opt(common_arg(
+        {"--spec-ngram-check-rate"}, "N",
+        string_format("ngram check rate for ngram-simple/ngram-map speculative decoding (default: %d)", params.speculative.ngram_check_rate),
+        [](common_params & params, int value) {
+            if (value < 1) {
+                throw std::invalid_argument("ngram check rate must be at least 1");
+            }
+            params.speculative.ngram_check_rate = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}));
+    add_opt(common_arg(
+        {"--spec-ngram-min-hits"}, "N",
+        string_format("minimum hits for ngram-map speculative decoding (default: %d)", params.speculative.ngram_min_hits),
+        [](common_params & params, int value) {
+            if (value < 1) {
+                throw std::invalid_argument("ngram min hits must be at least 1");
+            }
+            params.speculative.ngram_min_hits = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"-ctkd", "--cache-type-k-draft"}, "TYPE",
         string_format(
@@ -3620,8 +3681,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params) {
             params.model.hf_repo = "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF";
             params.model.hf_file = "qwen2.5-coder-7b-q8_0.gguf";
-            params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
-            params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
+            params.speculative.mparams_dft.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
+            params.speculative.mparams_dft.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
             params.port = 8012;
             params.n_ubatch = 1024;
             params.n_batch = 1024;
@@ -3636,8 +3697,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params) {
             params.model.hf_repo = "ggml-org/Qwen2.5-Coder-14B-Q8_0-GGUF";
             params.model.hf_file = "qwen2.5-coder-14b-q8_0.gguf";
-            params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
-            params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
+            params.speculative.mparams_dft.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
+            params.speculative.mparams_dft.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
             params.port = 8012;
             params.n_ubatch = 1024;
             params.n_batch = 1024;
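
The new --spec-* options are server-only and write straight into common_params_speculative (see common/common.h below). As a sketch, the following shows roughly what the parser stores for a given command line; the wrapper function itself is illustrative, not part of the commit:

```cpp
#include "common.h" // common_params and the speculative fields changed in this commit

// Illustrative: roughly what
//   llama-server ... --spec-type ngram-map-k4v --spec-ngram-size-n 12 --spec-ngram-size-m 48
// ends up storing after argument parsing.
static void configure_self_speculation(common_params & params) {
    params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V;
    params.speculative.ngram_size_n = 12;    // length of the lookup n-gram
    params.speculative.ngram_size_m = 48;    // length of the proposed draft m-gram
    params.speculative.ngram_check_rate = 1; // attempt a lookup on every token
    params.speculative.ngram_min_hits = 1;   // propose after a single n-gram hit
}
```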

common/common.cpp

Lines changed: 4 additions & 5 deletions
@@ -1097,7 +1097,10 @@ common_init_result::common_init_result(common_params & params) :
     if (params.fit_params) {
         LOG_INF("%s: fitting params to device memory, for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on\n", __func__);
         llama_params_fit(params.model.path.c_str(), &mparams, &cparams,
-            params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target.data(), params.fit_params_min_ctx,
+            params.tensor_split,
+            params.tensor_buft_overrides.data(),
+            params.fit_params_target.data(),
+            params.fit_params_min_ctx,
             params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);
     }

@@ -1208,10 +1211,6 @@ std::vector<llama_adapter_lora_ptr> & common_init_result::lora() {
     return pimpl->lora;
 }

-void common_init_result::free_context() {
-    pimpl->context.reset();
-}
-
 common_init_result_ptr common_init_from_params(common_params & params) {
     common_init_result_ptr res(new common_init_result(params));

common/common.h

Lines changed: 46 additions & 18 deletions
@@ -164,6 +164,16 @@ enum common_params_sampling_config : uint64_t {
     COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA = 1 << 11,
 };

+enum common_speculative_type {
+    COMMON_SPECULATIVE_TYPE_NONE,          // no speculative decoding
+    COMMON_SPECULATIVE_TYPE_DRAFT,         // draft model
+    COMMON_SPECULATIVE_TYPE_EAGLE3,        // eagle draft model
+    COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE,  // simple self-speculative decoding
+    COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K,   // self-speculative decoding with n-gram keys only
+    COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values
+    COMMON_SPECULATIVE_TYPE_NGRAM_CACHE,   // self-speculative decoding with 3-level n-gram cache
+    COMMON_SPECULATIVE_TYPE_COUNT          // number of types, unknown type
+};

 // sampling parameters
 struct common_params_sampling {
@@ -243,24 +253,50 @@ struct common_params_model {
 };

 struct common_params_speculative {
-    std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
+    common_speculative_type type = COMMON_SPECULATIVE_TYPE_NONE; // type of speculative decoding

-    int32_t n_ctx = 0; // draft context size
-    int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
-    int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
-    int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
-    float p_split = 0.1f; // speculative decoding split probability
-    float p_min = 0.75f; // minimum speculative decoding probability (greedy)
-    std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
-    std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
+    // general-purpose speculative decoding parameters
+
+    int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
+    int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
+    float p_split = 0.1f; // speculative decoding split probability
+    float p_min = 0.75f; // minimum speculative decoding probability (greedy)
+
+    // ngram-based speculative decoding
+
+    uint16_t ngram_size_n = 12; // ngram size for lookup
+    uint16_t ngram_size_m = 48; // mgram size for speculative tokens
+    uint16_t ngram_check_rate = 1; // check rate for ngram lookup
+    uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed
+
+    std::string lookup_cache_static; // path of static ngram cache file for lookup decoding // NOLINT
+    std::string lookup_cache_dynamic; // path of dynamic ngram cache file for lookup decoding // NOLINT
+
+    // draft-model speculative decoding
+
+    struct common_params_model mparams_dft;
+
+    llama_model * model_dft = nullptr; // a llama_model that can be shared by multiple speculative contexts
+
+    llama_context_params cparams_dft; // these are the parameters for the draft llama_context
+
+    int32_t n_ctx = 0; // draft context size
+    int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)

     ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
     ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V

     struct cpu_params cpuparams;
     struct cpu_params cpuparams_batch;

-    struct common_params_model model;
+    std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
+
+    std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
+    std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
+
+    bool has_dft() const {
+        return !mparams_dft.path.empty() || !mparams_dft.hf_repo.empty();
+    }
 };

 struct common_params_vocoder {
@@ -378,8 +414,6 @@ struct common_params {
     std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state // NOLINT
     std::string input_prefix = ""; // string to prefix user inputs with // NOLINT
     std::string input_suffix = ""; // string to suffix user inputs with // NOLINT
-    std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding // NOLINT
-    std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT
     std::string logits_file = ""; // file for saving *all* logits // NOLINT

     // llama-debug specific options
@@ -575,10 +609,6 @@ struct common_params {
     // return false from callback to abort model loading or true to continue
     llama_progress_callback load_progress_callback = NULL;
     void * load_progress_callback_user_data = NULL;
-
-    bool has_speculative() const {
-        return !speculative.model.path.empty() || !speculative.model.hf_repo.empty();
-    }
 };

 // call once at the start of a program if it uses libcommon
@@ -714,8 +744,6 @@ struct common_init_result {

     std::vector<llama_adapter_lora_ptr> & lora();

-    void free_context();
-
 private:
     struct impl;
     std::unique_ptr<impl> pimpl;
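
has_dft() replaces the old common_params::has_speculative(): a configured draft model is now just one speculative type among several. A hypothetical sketch of the resulting dispatch (the actual selection logic lives in common/speculative.cpp and is not part of this excerpt):

```cpp
// Hypothetical helper: pick the effective speculation type for a run.
// A configured draft model (--model-draft / --hf-repo-draft) takes
// precedence; otherwise fall back to the draftless --spec-type setting.
static common_speculative_type effective_spec_type(const common_params_speculative & spec) {
    if (spec.has_dft()) {
        return COMMON_SPECULATIVE_TYPE_DRAFT;
    }
    return spec.type; // COMMON_SPECULATIVE_TYPE_NONE disables speculation
}
```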

common/ngram-cache.cpp

Lines changed: 3 additions & 4 deletions
@@ -192,12 +192,12 @@ void common_ngram_cache_draft(
             break;
         }

-        LOG(" - draft candidate: token=%d\n", drafted_token);
+        LOG_DBG(" - draft candidate: token=%d\n", drafted_token);
         draft.push_back(drafted_token);
     }
 }

-void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & filename) {
+void common_ngram_cache_save(common_ngram_cache & ngram_cache, const std::string & filename) {
     std::ofstream file_out(filename, std::ios::binary);
     for (std::pair<common_ngram, common_ngram_cache_part> item : ngram_cache) {
         const common_ngram ngram = item.first;
@@ -217,10 +217,9 @@ void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & fil
             file_out.write(reinterpret_cast<const char *>(&count), sizeof(int32_t));
         }
     }
-
 }

-common_ngram_cache common_ngram_cache_load(std::string & filename) {
+common_ngram_cache common_ngram_cache_load(const std::string & filename) {
     std::ifstream hashmap_file(filename, std::ios::binary);
     if (!hashmap_file) {
         throw std::ifstream::failure("Unable to open file " + filename);

common/ngram-cache.h

Lines changed: 2 additions & 2 deletions
@@ -88,12 +88,12 @@ void common_ngram_cache_draft(
 // Save an ngram cache to a file.
 // ngram_cache: the ngram cache to save.
 // filename: the path under which to save the ngram cache.
-void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & filename);
+void common_ngram_cache_save(common_ngram_cache & ngram_cache, const std::string & filename);

 // Load an ngram cache saved with common_ngram_cache_save.
 // filename: the path from which to load the ngram cache.
 // returns: an ngram cache containing the information saved to filename.
-common_ngram_cache common_ngram_cache_load(std::string & filename);
+common_ngram_cache common_ngram_cache_load(const std::string & filename);

 // Merge two ngram caches.
 // ngram_cache_target: the ngram cache to which to add the information from ngram_cache_add.
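
Taking filename by const std::string & lets callers pass temporaries and string literals, which the old non-const reference signatures rejected. A minimal round-trip sketch under that assumption; the path is illustrative, and the merge call uses common_ngram_cache_merge as documented above:

```cpp
#include "ngram-cache.h"

// Illustrative round-trip: save a cache, reload it, and merge it back.
static void ngram_cache_roundtrip(common_ngram_cache & cache) {
    // a string literal now binds directly to const std::string &;
    // previously this required a named, mutable std::string variable
    common_ngram_cache_save(cache, "ngrams.bin");

    common_ngram_cache reloaded = common_ngram_cache_load("ngrams.bin");
    common_ngram_cache_merge(cache, reloaded);
}
```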
