Skip to content

Commit 029ca70

Browse files
committed
server : renamed spec checkpoints option
1 parent 3c9289a commit 029ca70

4 files changed

Lines changed: 16 additions & 13 deletions

File tree

common/arg.cpp

Lines changed: 10 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -3465,13 +3465,17 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
34653465
}
34663466
).set_examples({LLAMA_EXAMPLE_SERVER}));
34673467
add_opt(common_arg(
3468-
{"--spec-ckpt-num-tries"}, "N",
3469-
string_format("number of tries for speculative decoding with recurrent memory (default: %d)", params.speculative.ckpt_num_tries),
3470-
[](common_params & params, int value) {
3471-
if (value < 0 || value > 10) {
3472-
throw std::invalid_argument("number of tries must be between 0 and 10 inclusive");
3468+
{"--spec-use-checkpoints"}, "[on|off|auto]",
3469+
string_format("use checkpoints to rewind token history in recurrent models ('on', 'off', or 'auto', default: %s)",
3470+
params.speculative.use_checkpoints ? "on" : "off"),
3471+
[](common_params & params, const std::string & value) {
3472+
if (is_truthy(value) || is_autoy(value)) {
3473+
params.speculative.use_checkpoints = true;
3474+
} else if (is_falsey(value)) {
3475+
params.speculative.use_checkpoints = false;
3476+
} else {
3477+
throw std::invalid_argument("invalid value for --spec-use-checkpoints");
34733478
}
3474-
params.speculative.ckpt_num_tries = value;
34753479
}
34763480
).set_examples({LLAMA_EXAMPLE_SERVER}));
34773481
add_opt(common_arg(

common/common.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,8 @@ struct common_params_speculative {
270270
uint16_t ngram_size_n = 12; // ngram size for lookup
271271
uint16_t ngram_size_m = 48; // mgram size for speculative tokens
272272
uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed
273-
uint16_t ckpt_num_tries = 0; // number of tries in case of recurrent memory
273+
bool use_checkpoints = false; // use checkpoints to rewind in token history of recurrent models
274+
274275

275276
std::shared_ptr<common_ngram_mod> ngram_mod;
276277

common/speculative.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1136,8 +1136,8 @@ struct common_speculative_session::impl {
11361136
clear_draft();
11371137
return draft;
11381138
}
1139-
if (params_spec.ckpt_num_tries > 0
1140-
&& spec_ckpt_n_denials >= params_spec.ckpt_num_tries) {
1139+
if (params_spec.use_checkpoints
1140+
&& spec_ckpt_n_denials > 0) {
11411141
clear_draft();
11421142
return draft;
11431143
}
@@ -1166,7 +1166,7 @@ struct common_speculative_session::impl {
11661166
draft.resize(n_draft_max);
11671167
}
11681168

1169-
bool do_checkpoint = !draft.empty() && params_spec.ckpt_num_tries > 0;
1169+
bool do_checkpoint = !draft.empty() && params_spec.use_checkpoints;
11701170
if (do_checkpoint && cached_text_tokens.size() > 5) {
11711171
LOG_DBG("draft.size = %zu, n_spec_denials = %d, do_checkpoint = %s, tokens=[..., %d, %d, %d]\n",
11721172
draft.size(), spec_ckpt_n_denials,
@@ -1235,8 +1235,6 @@ struct common_speculative_session::impl {
12351235
return common_speculative_accept_response(std::move(ids), n_draft, true);
12361236
}
12371237

1238-
//spec_ckpt_n_accepted = (spec_ckpt_n_denials < params_spec.ckpt_num_tries) ? (int) (ids.size() - 1) : 0;
1239-
12401238
callback.batch_clear();
12411239
}
12421240
}

tools/server/server-context.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -838,7 +838,7 @@ struct server_context_impl {
838838
slot.prompt.tokens.has_mtmd = mctx != nullptr;
839839

840840
// try speculative decoding
841-
if (can_spec || params_base.speculative.ckpt_num_tries > 0) {
841+
if (can_spec || params_base.speculative.use_checkpoints) {
842842
if (mctx) {
843843
SRV_ERR("%s\n", "speculative decoding is not supported with multimodal");
844844
return false;

0 commit comments

Comments
 (0)