
Commit dad4e2e

TheTom and claude committed
Cherry-pick upstream speculative decoding for hybrid models
Cherry-picks 4 upstream PRs to enable speculative decoding on hybrid MoE+SSM architectures (Qwen3.6-35B-A3B):

- ggml-org#19493 — speculative checkpointing (save/restore recurrent state)
- ggml-org#22114 — refactor "use checkpoint" logic
- ggml-org#22168 — reset i_last on low acceptance streak
- ggml-org#22223 — add --spec-default argument

Smoke tested on M5 Max with turbo4 KV — zero regression.

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 4d24ad8 commit dad4e2e

13 files changed

Lines changed: 481 additions & 227 deletions

common/arg.cpp

Lines changed: 12 additions & 1 deletion
@@ -291,7 +291,7 @@ static bool common_params_handle_remote_preset(common_params & params, llama_exa
         hf_tag = "default";
     }

-    std::string model_endpoint = get_model_endpoint();
+    std::string model_endpoint = common_get_model_endpoint();
     auto preset_url = model_endpoint + hf_repo + "/resolve/main/preset.ini";

     // prepare local path for caching
@@ -3890,6 +3890,17 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     }
     ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));

+    add_opt(common_arg(
+        {"--spec-default"},
+        string_format("enable default speculative decoding config"),
+        [](common_params & params) {
+            params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MOD;
+            params.speculative.ngram_size_n = 24;
+            params.speculative.n_min = 48;
+            params.speculative.n_max = 64;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
+
     return ctx_arg;
 }
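For reference, the new flag is a one-shot preset: it switches the speculative type to ngram-mod and sets the ngram/draft sizes in a single option. A hypothetical invocation (the model path is a placeholder, not part of this diff):

```sh
# Hypothetical usage; any llama.cpp server build that includes this
# commit should accept the flag. The model path is a placeholder.
llama-server -m ./models/model.gguf --spec-default
```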

common/chat.cpp

Lines changed: 1 addition & 1 deletion
@@ -2334,7 +2334,7 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars
         ? input
         : params.generation_prompt + input;

-    LOG_DBG("Parsing PEG input with format %s: %s\n", common_chat_format_name(params.format), effective_input.c_str());
+    //LOG_DBG("Parsing PEG input with format %s: %s\n", common_chat_format_name(params.format), effective_input.c_str());

     common_peg_parse_flags flags = COMMON_PEG_PARSE_FLAG_LENIENT;
     if (params.debug) {

common/common.cpp

Lines changed: 37 additions & 1 deletion
@@ -1387,7 +1387,7 @@ common_init_result_ptr common_init_from_params(common_params & params) {

 common_init_result::~common_init_result() = default;

-std::string get_model_endpoint() {
+std::string common_get_model_endpoint() {
     const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
     // We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility.
     const char * hf_endpoint_env = getenv("HF_ENDPOINT");
@@ -1402,6 +1402,42 @@ std::string get_model_endpoint() {
     return model_endpoint;
 }

+common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) {
+    auto * mem = llama_get_memory(ctx);
+    if (mem == nullptr) {
+        return COMMON_CONTEXT_SEQ_RM_TYPE_NO;
+    }
+
+    common_context_seq_rm_type res = COMMON_CONTEXT_SEQ_RM_TYPE_PART;
+
+    llama_memory_clear(mem, true);
+
+    // eval 2 tokens to check if the context is compatible
+    std::vector<llama_token> tmp;
+    tmp.push_back(0);
+    tmp.push_back(0);
+
+    int ret = llama_decode(ctx, llama_batch_get_one(tmp.data(), tmp.size()));
+    if (ret != 0) {
+        LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret);
+        res = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
+        goto done;
+    }
+
+    // try to remove the last tokens
+    if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
+        LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
+        res = COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
+        goto done;
+    }
+
+done:
+    llama_memory_clear(mem, true);
+    llama_synchronize(ctx);
+
+    return res;
+}
+
 void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora) {
     std::vector<llama_adapter_lora *> loras;
     std::vector<float> scales;
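Since common_context_can_seq_rm probes by decoding two dummy tokens and then attempting a partial removal (and clears the memory on both entry and exit), a caller is expected to run it once at startup, before any real evaluation. A minimal sketch of such a caller, assuming a context created elsewhere; the helper name and branch comments are illustrative, not from this diff:

```cpp
#include "common.h"
#include "llama.h"

// Hypothetical helper (not part of this commit): pick a rollback strategy
// for speculative decoding based on what the context's memory supports.
// Must run before real tokens are evaluated, since the probe clears the
// context's memory.
static bool spec_can_rollback_partially(llama_context * ctx) {
    switch (common_context_can_seq_rm(ctx)) {
        case COMMON_CONTEXT_SEQ_RM_TYPE_PART:
            return true;   // cache can drop just the rejected draft tail
        case COMMON_CONTEXT_SEQ_RM_TYPE_FULL:  // whole sequences only
        case COMMON_CONTEXT_SEQ_RM_TYPE_NO:    // no memory module
        default:
            return false;  // fall back to checkpoint save/restore
    }
}
```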

common/common.h

Lines changed: 21 additions & 6 deletions
@@ -10,7 +10,6 @@
 #include <sstream>
 #include <string>
 #include <string_view>
-#include <variant>
 #include <vector>
 #include <map>

@@ -315,15 +314,15 @@ struct common_params_speculative {
     // general-purpose speculative decoding parameters

     int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
-    int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
+    int32_t n_min = 0;  // minimum number of draft tokens to use for speculative decoding
     float p_split = 0.1f; // speculative decoding split probability
     float p_min = 0.75f; // minimum speculative decoding probability (greedy)

     // ngram-based speculative decoding

-    uint16_t ngram_size_n = 12; // ngram size for lookup
-    uint16_t ngram_size_m = 48; // mgram size for speculative tokens
-    uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed
+    uint16_t ngram_size_n = 12;  // ngram size for lookup
+    uint16_t ngram_size_m = 48;  // mgram size for speculative tokens
+    uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed

     std::shared_ptr<common_ngram_mod> ngram_mod;

@@ -859,7 +858,23 @@ struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_p
 // clear LoRA adapters from context, then apply new list of adapters
 void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora);

-std::string get_model_endpoint();
+// model endpoint from env
+std::string common_get_model_endpoint();
+
+//
+// Context utils
+//
+
+enum common_context_seq_rm_type {
+    COMMON_CONTEXT_SEQ_RM_TYPE_NO   = 0, // seq_rm not supported (e.g. no memory module)
+    COMMON_CONTEXT_SEQ_RM_TYPE_PART = 1, // can seq_rm partial sequences
+    COMMON_CONTEXT_SEQ_RM_TYPE_FULL = 2, // can seq_rm full sequences only
+};
+
+// check if the llama_context can remove sequences
+// note: clears the memory of the context
+common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx);
+

 //
 // Batch utils

common/hf-cache.cpp

Lines changed: 2 additions & 2 deletions
@@ -229,7 +229,7 @@ static nl::json api_get(const std::string & url,
 static std::string get_repo_commit(const std::string & repo_id,
                                    const std::string & token) {
     try {
-        auto endpoint = get_model_endpoint();
+        auto endpoint = common_get_model_endpoint();
         auto json = api_get(endpoint + "api/models/" + repo_id + "/refs", token);

         if (!json.is_object() ||
@@ -307,7 +307,7 @@ hf_files get_repo_files(const std::string & repo_id,
     hf_files files;

     try {
-        auto endpoint = get_model_endpoint();
+        auto endpoint = common_get_model_endpoint();
         auto json = api_get(endpoint + "api/models/" + repo_id + "/tree/" + commit + "?recursive=true", token);

         if (!json.is_array()) {

common/ngram-map.cpp

Lines changed: 4 additions & 4 deletions
@@ -208,7 +208,7 @@ void common_ngram_map_begin(
             count_keys, count_keys_del, count_values_del, count_map_entries_upd);
     }

-    map.idx_last_check = (map.size_last_begin > 0) ? map.size_last_begin - 1 : 0;
+    map.idx_last_check = size_begin;
     map.size_last_begin = size_begin;
 }

@@ -231,7 +231,7 @@ void common_ngram_map_draft(common_ngram_map & map,
         GGML_ABORT("%s: cur_len exceeds UINT32_MAX: %zu", __func__, cur_len);
     }

-    if (map.idx_last_check > cur_len) {
+    if (map.idx_last_check > cur_len) {
         // Should not happen because of common_ngram_map_begin().
         GGML_ABORT("%s: map.idx_last_check > cur_len: %zu > %zu", __func__, map.idx_last_check, cur_len);
     }
@@ -386,7 +386,7 @@ void common_ngram_map_draft(common_ngram_map & map,
     LOG_DBG("%s: key_idx = %zu, key_offset = %zu, key_num = %d, draft.size = %zu\n", __func__,
         curr_key.key_idx, key_offset, curr_key.key_num, draft.size());

-    map.last_draft_created = false;
+    map.last_draft_created = true;
     map.last_draft_key_idx = key_offset;
     map.last_draft_value_idx = 0; // value 0 is used for simple mode
     return;
@@ -524,7 +524,7 @@ void common_ngram_map_accept(common_ngram_map & map, uint16_t n_accepted) {
     struct common_ngram_map_value & curr_value = curr_key.values[val_idx]; // value used for draft generation.

     // update the value statistics
-    LOG_INF("common_ngram_map_send_accepted: n_accepted = %d, prev value_num = %d\n",
+    LOG_DBG("common_ngram_map_send_accepted: n_accepted = %d, prev value_num = %d\n",
         n_accepted, curr_value.n_accepted);
     curr_value.n_accepted = n_accepted;
 }
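The last two hunks appear to work as a pair: common_ngram_map_accept updates statistics for the draft that common_ngram_map_draft produced, so the draft path records last_draft_created = true, presumably so the accept step can tell that a draft exists. A comment-only sketch of the assumed lifecycle, reconstructed from the function names in this diff (only common_ngram_map_accept's signature appears in full here; the rest is not the verbatim upstream API):

```cpp
// Assumed per-step lifecycle of the ngram map (reconstructed; argument
// lists abbreviated):
//
//   common_ngram_map_begin(...)  - resync with the current context;
//                                  idx_last_check now restarts at size_begin
//   common_ngram_map_draft(...)  - propose draft tokens; on success it sets
//                                  last_draft_created = true so the accept
//                                  step knows a draft was produced
//   common_ngram_map_accept(map, n_accepted)
//                                - feed back how many drafted tokens the
//                                  target model accepted (now logged at
//                                  debug level instead of info)
```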
