Commit de71b5f

server : refactor "use checkpoint" logic (#22114)
1 parent 788fcbc commit de71b5f

7 files changed

Lines changed: 93 additions & 92 deletions


common/arg.cpp

Lines changed: 1 addition & 1 deletion
@@ -292,7 +292,7 @@ static bool common_params_handle_remote_preset(common_params & params, llama_exa
         hf_tag = "default";
     }
 
-    std::string model_endpoint = get_model_endpoint();
+    std::string model_endpoint = common_get_model_endpoint();
     auto preset_url = model_endpoint + hf_repo + "/resolve/main/preset.ini";
 
     // prepare local path for caching

common/common.cpp

Lines changed: 37 additions & 1 deletion
@@ -1382,7 +1382,7 @@ common_init_result_ptr common_init_from_params(common_params & params) {
 
 common_init_result::~common_init_result() = default;
 
-std::string get_model_endpoint() {
+std::string common_get_model_endpoint() {
     const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
     // We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility.
     const char * hf_endpoint_env = getenv("HF_ENDPOINT");
@@ -1397,6 +1397,42 @@ std::string get_model_endpoint() {
     return model_endpoint;
 }
 
+common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) {
+    auto * mem = llama_get_memory(ctx);
+    if (mem == nullptr) {
+        return COMMON_CONTEXT_SEQ_RM_TYPE_NO;
+    }
+
+    common_context_seq_rm_type res = COMMON_CONTEXT_SEQ_RM_TYPE_PART;
+
+    llama_memory_clear(mem, true);
+
+    // eval 2 tokens to check if the context is compatible
+    std::vector<llama_token> tmp;
+    tmp.push_back(0);
+    tmp.push_back(0);
+
+    int ret = llama_decode(ctx, llama_batch_get_one(tmp.data(), tmp.size()));
+    if (ret != 0) {
+        LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret);
+        res = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
+        goto done;
+    }
+
+    // try to remove the last tokens
+    if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
+        LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
+        res = COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
+        goto done;
+    }
+
+done:
+    llama_memory_clear(mem, true);
+    llama_synchronize(ctx);
+
+    return res;
+}
+
 void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora) {
     std::vector<llama_adapter_lora *> loras;
     std::vector<float> scales;

common/common.h

Lines changed: 20 additions & 5 deletions
@@ -308,10 +308,9 @@ struct common_params_speculative {
 
     // ngram-based speculative decoding
 
-    uint16_t ngram_size_n = 12; // ngram size for lookup
-    uint16_t ngram_size_m = 48; // mgram size for speculative tokens
-    uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed
-    bool use_checkpoints = false; // use checkpoints to rewind in token history of recurrent models
+    uint16_t ngram_size_n   = 12; // ngram size for lookup
+    uint16_t ngram_size_m   = 48; // mgram size for speculative tokens
+    uint16_t ngram_min_hits = 1;  // minimum hits at ngram/mgram lookup for mgram to be proposed
 
     std::shared_ptr<common_ngram_mod> ngram_mod;
 
@@ -847,7 +846,23 @@ struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_p
 // clear LoRA adapters from context, then apply new list of adapters
 void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora);
 
-std::string get_model_endpoint();
+// model endpoint from env
+std::string common_get_model_endpoint();
+
+//
+// Context utils
+//
+
+enum common_context_seq_rm_type {
+    COMMON_CONTEXT_SEQ_RM_TYPE_NO   = 0, // seq_rm not supported (e.g. no memory module)
+    COMMON_CONTEXT_SEQ_RM_TYPE_PART = 1, // can seq_rm partial sequences
+    COMMON_CONTEXT_SEQ_RM_TYPE_FULL = 2, // can seq_rm full sequences only
+};
+
+// check if the llama_context can remove sequences
+// note: clears the memory of the context
+common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx);
+
 
 //
 // Batch utils
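
Taken together, the new `common_context_seq_rm_type` enum and `common_context_can_seq_rm()` let a caller probe a context once, up front, and pick a rollback strategy from the result. A minimal usage sketch (only the probe call and the enum values come from this commit; the helper function and its log messages are illustrative):

```cpp
#include "common.h"
#include "log.h"

#include "llama.h"

// Illustrative helper (not part of this commit): probe a freshly created
// context and report how rejected draft tokens can be rolled back.
// Note: common_context_can_seq_rm() clears the context's memory, so it should
// be called before any real decoding happens.
static void report_seq_rm_support(llama_context * ctx) {
    switch (common_context_can_seq_rm(ctx)) {
        case COMMON_CONTEXT_SEQ_RM_TYPE_PART:
            LOG_INF("partial seq_rm supported -> rejected tokens can be removed in place\n");
            break;
        case COMMON_CONTEXT_SEQ_RM_TYPE_FULL:
            LOG_INF("only full-sequence removal supported -> fall back to state checkpoints\n");
            break;
        case COMMON_CONTEXT_SEQ_RM_TYPE_NO:
            LOG_WRN("seq_rm not supported -> speculative decoding is disabled for this context\n");
            break;
    }
}
```

This mirrors how tools/server/server-context.cpp uses the probe later in this commit: PART keeps the in-place rollback path, FULL switches the slot to checkpoint-based restore, and NO disables speculative decoding.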

common/hf-cache.cpp

Lines changed: 2 additions & 2 deletions
@@ -230,7 +230,7 @@ static nl::json api_get(const std::string & url,
 static std::string get_repo_commit(const std::string & repo_id,
                                    const std::string & token) {
     try {
-        auto endpoint = get_model_endpoint();
+        auto endpoint = common_get_model_endpoint();
         auto json = api_get(endpoint + "api/models/" + repo_id + "/refs", token);
 
         if (!json.is_object() ||
@@ -308,7 +308,7 @@ hf_files get_repo_files(const std::string & repo_id,
     hf_files files;
 
     try {
-        auto endpoint = get_model_endpoint();
+        auto endpoint = common_get_model_endpoint();
         auto json = api_get(endpoint + "api/models/" + repo_id + "/tree/" + commit + "?recursive=true", token);
 
         if (!json.is_array()) {

common/speculative.cpp

Lines changed: 14 additions & 48 deletions
@@ -164,8 +164,8 @@ struct common_speculative_state_draft : public common_speculative_state {
     llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
     llama_context * ctx_dft;
 
+    bool use_ckpt = false;
     struct common_speculative_checkpoint ckpt;
-    bool use_checkpoint;
 
     common_sampler * smpl;
 
@@ -180,11 +180,11 @@ struct common_speculative_state_draft : public common_speculative_state {
             llama_context * ctx_tgt,
             llama_context * ctx_dft,
             const std::vector<std::pair<std::string, std::string>> & replacements,
-            bool use_checkpoint)
+            bool use_ckpt)
         : common_speculative_state(type)
         , ctx_tgt(ctx_tgt)
         , ctx_dft(ctx_dft)
-        , use_checkpoint(use_checkpoint)
+        , use_ckpt(use_ckpt)
     {
         batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
         smpl = nullptr;
@@ -239,7 +239,7 @@ struct common_speculative_state_draft : public common_speculative_state {
     }
 
     void begin(const llama_tokens & prompt) override {
-        if (use_checkpoint && ckpt.size() > 0) {
+        if (use_ckpt && ckpt.size() > 0) {
            // delete checkpoint
            LOG_DBG("%s: delete checkpoint, prompt.size=%zu, pos_min=%d, pos_max=%d, n_tokens=%" PRId64 ", size=%.3f MiB\n",
                __func__, prompt.size(), ckpt.pos_min, ckpt.pos_max, ckpt.n_tokens, (float) ckpt.data.size() / 1024 / 1024);
@@ -351,7 +351,7 @@ struct common_speculative_state_draft : public common_speculative_state {
 
         LOG_DBG("%s: reuse_i = %d, reuse_n = %d, #prompt_dft = %zu, #prompt_cur = %zu\n",
             __func__, reuse_i, reuse_n, prompt_dft.size(), prompt_cur.size());
-        if (use_checkpoint && ckpt.ckpt_size == 0 && reuse_n > 0) {
+        if (use_ckpt && ckpt.ckpt_size == 0 && reuse_n > 0) {
             LOG_DBG("%s: no checkpoint available, no reuse, (reuse_i=%d, reuse_n=%d) -> (0, 0)\n",
                 __func__, reuse_i, reuse_n);
             reuse_i = 0;
@@ -361,8 +361,8 @@ struct common_speculative_state_draft : public common_speculative_state {
         result.clear();
         result.reserve(params.n_max);
 
-        bool needs_ckpt = use_checkpoint && prompt_dft.size() > 0;
-        if (reuse_n == 0 || (use_checkpoint && reuse_i > 0)) {
+        bool needs_ckpt = use_ckpt && prompt_dft.size() > 0;
+        if (reuse_n == 0 || (use_ckpt && reuse_i > 0)) {
             llama_memory_clear(mem_dft, false);
             prompt_dft.clear();
         } else {
@@ -400,7 +400,7 @@ struct common_speculative_state_draft : public common_speculative_state {
         }
 
         if (reuse_n < (int) prompt_dft.size() || do_restore) {
-            if (use_checkpoint) {
+            if (use_ckpt) {
                 if (ckpt.n_tokens > (int64_t) prompt_dft.size()) {
                     LOG_INF("%s: checkpoint is too large, prompt_tgt.size=%zu, ckpt.n_tokens=%" PRId64 ", reuse_n=%d, prompt_dft.size=%zu\n",
                         __func__, prompt_tgt.size(), ckpt.n_tokens, reuse_n, prompt_dft.size());
@@ -912,42 +912,6 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
     return it->second;
 }
 
-common_speculative_compat_type common_speculative_is_compat(llama_context * ctx_tgt) {
-    auto * mem = llama_get_memory(ctx_tgt);
-    if (mem == nullptr) {
-        return COMMON_SPECULATIVE_COMPAT_TYPE_NO;
-    }
-
-    common_speculative_compat_type res = COMMON_SPECULATIVE_COMPAT_TYPE_FULL;
-
-    llama_memory_clear(mem, true);
-
-    // eval 2 tokens to check if the context is compatible
-    std::vector<llama_token> tmp;
-    tmp.push_back(0);
-    tmp.push_back(0);
-
-    int ret = llama_decode(ctx_tgt, llama_batch_get_one(tmp.data(), tmp.size()));
-    if (ret != 0) {
-        LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret);
-        res = COMMON_SPECULATIVE_COMPAT_TYPE_NO;
-        goto done;
-    }
-
-    // try to remove the last tokens
-    if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
-        LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
-        res = COMMON_SPECULATIVE_COMPAT_TYPE_CKPT;
-        goto done;
-    }
-
-done:
-    llama_memory_clear(mem, true);
-    llama_synchronize(ctx_tgt);
-
-    return res;
-}
-
 // initialization of the speculative decoding system
 //
 common_speculative * common_speculative_init(
@@ -1022,11 +986,13 @@ common_speculative * common_speculative_init(
         case COMMON_SPECULATIVE_TYPE_NONE:
             break;
         case COMMON_SPECULATIVE_TYPE_DRAFT: {
+            const bool use_ckpt = common_context_can_seq_rm(ctx_dft) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
+
             impls.push_back(std::make_unique<common_speculative_state_draft>(config.type,
-                /* .ctx_tgt = */ ctx_tgt,
-                /* .ctx_dft = */ ctx_dft,
-                /* .replacements = */ params.replacements,
-                /* .use_checkpoint= */ params.use_checkpoints // TODO: this should be based on the draft model!
+                /* .ctx_tgt = */ ctx_tgt,
+                /* .ctx_dft = */ ctx_dft,
+                /* .replacements = */ params.replacements,
+                /* .use_ckpt = */ use_ckpt
             ));
             break;
         }
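
The deleted `common_speculative_is_compat()` was essentially the same probe as the new `common_context_can_seq_rm()` in common/common.cpp, so its return values map one-to-one onto the new enum. A rough correspondence, inferred from the two implementations in this commit (the left-hand names were removed from speculative.h below):

```cpp
// Inferred mapping between the removed enum (speculative.h) and the new one (common.h):
//
//   COMMON_SPECULATIVE_COMPAT_TYPE_NO   -> COMMON_CONTEXT_SEQ_RM_TYPE_NO    // no memory module / decode failed
//   COMMON_SPECULATIVE_COMPAT_TYPE_FULL -> COMMON_CONTEXT_SEQ_RM_TYPE_PART  // partial sequence removal works
//   COMMON_SPECULATIVE_COMPAT_TYPE_CKPT -> COMMON_CONTEXT_SEQ_RM_TYPE_FULL  // only full-sequence removal; use checkpoints
```

Note also that `use_ckpt` is now derived per draft context via `common_context_can_seq_rm(ctx_dft)` instead of being passed in through `common_params_speculative::use_checkpoints`, which resolves the old "TODO: this should be based on the draft model!".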

common/speculative.h

Lines changed: 0 additions & 10 deletions
@@ -14,16 +14,6 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
 // convert type to string
 std::string common_speculative_type_to_str(enum common_speculative_type type);
 
-enum common_speculative_compat_type {
-    COMMON_SPECULATIVE_COMPAT_TYPE_NO   = 0,
-    COMMON_SPECULATIVE_COMPAT_TYPE_FULL = 1,
-    COMMON_SPECULATIVE_COMPAT_TYPE_CKPT = 2,
-};
-
-// check if the llama_context is compatible for speculative decoding
-// note: clears the memory of the context
-common_speculative_compat_type common_speculative_is_compat(llama_context * ctx_tgt);
-
 common_speculative * common_speculative_init(
         common_params_speculative & params,
         llama_context * ctx_tgt);

tools/server/server-context.cpp

Lines changed: 19 additions & 25 deletions
@@ -78,9 +78,10 @@ enum server_state {
 struct server_slot {
     int id;
 
-    // TODO: change to unique_ptrs for consistency:
     llama_context * ctx = nullptr;
 
+    common_context_seq_rm_type ctx_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
+
     // multimodal
     mtmd_context * mctx = nullptr;
 
@@ -90,7 +91,6 @@ struct server_slot {
     server_prompt_checkpoint spec_ckpt;
     common_speculative_ptr spec;
 
-
     // TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state
     // see https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837
     std::unique_ptr<const server_task> task;
@@ -343,7 +343,7 @@ struct server_slot {
 
         if (!spec_draft.empty()) {
             // we have a previous (partial) draft to reuse
-            if (task->params.speculative.use_checkpoints) {
+            if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
                 GGML_ASSERT(!spec_ckpt.empty());
             }
         } else {
@@ -362,15 +362,13 @@ struct server_slot {
             spec_draft.clear();
         }
 
-        if (!spec_draft.empty() && params_spec.use_checkpoints) {
+        if (!spec_draft.empty() && ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
             const auto n_tokens = prompt.tokens.size();
 
-            auto & ckpt = spec_ckpt;
-
-            ckpt = server_get_checkpoint(ctx, this->id, n_tokens);
+            spec_ckpt = server_get_checkpoint(ctx, this->id, n_tokens);
 
             SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %zu, size = %.3f MiB)\n",
-                ckpt.pos_min, ckpt.pos_max, n_tokens, (float) ckpt.data.size() / 1024 / 1024);
+                spec_ckpt.pos_min, spec_ckpt.pos_max, n_tokens, (float) spec_ckpt.data.size() / 1024 / 1024);
         }
     }
 
@@ -871,14 +869,13 @@ struct server_context_impl {
 
         slots.clear();
 
-        const auto spec_type = common_speculative_is_compat(ctx);
-        if (spec_type == COMMON_SPECULATIVE_COMPAT_TYPE_NO) {
+        const auto ctx_seq_rm_type = common_context_can_seq_rm(ctx);
+        if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
             SRV_WRN("%s", "speculative decoding not supported by this context\n");
         }
 
-        if (spec_type == COMMON_SPECULATIVE_COMPAT_TYPE_CKPT) {
+        if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
             SRV_WRN("%s", "speculative decoding will use checkpoints\n");
-            params_base.speculative.use_checkpoints = true;
         }
 
         // initialize slots
@@ -893,11 +890,13 @@ struct server_context_impl {
             slot.ctx = ctx;
             slot.n_ctx = n_ctx_slot;
 
+            slot.ctx_seq_rm_type = ctx_seq_rm_type;
+
             slot.mctx = mctx;
             slot.prompt.tokens.has_mtmd = mctx != nullptr;
 
             // try speculative decoding
-            if (spec_type != COMMON_SPECULATIVE_COMPAT_TYPE_NO) {
+            if (ctx_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
                 slot.spec.reset(common_speculative_init(params_base.speculative, slot.ctx));
 
                 if (slot.spec) {
@@ -2588,15 +2587,11 @@ struct server_context_impl {
 
         // make a checkpoint of the parts of the memory that cannot be rolled back.
         // checkpoints are created only if:
+        //  - the model does not support partial sequence removal
         //  - the model uses SWA and we are not using `swa_full`
-        //  - the model architecture is marked as recurrent or hybrid
-        //
-        // TODO: try to make this conditional on the context or the memory module, instead of the model type
         do_checkpoint = do_checkpoint && (
-            llama_model_is_recurrent(model) ||
-            llama_model_is_hybrid(model) ||
-            (llama_model_n_swa(model) > 0 && !params_base.swa_full)
-        );
+            (slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) ||
+            (llama_model_n_swa(model) > 0 && !params_base.swa_full));
 
         bool has_mtmd = false;
 
@@ -2965,8 +2960,6 @@ struct server_context_impl {
 
         // verify and try to accept the draft
         {
-            const auto & params_spec = slot.task->params.speculative;
-
             common_sampler_ptr smpl_save(common_sampler_clone(slot.smpl.get()));
 
             GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1);
@@ -2979,13 +2972,14 @@ struct server_context_impl {
 
             // check for partial draft acceptance
             if (accepted.size() < slot.spec_draft.size() + 1) {
-                if (params_spec.use_checkpoints) {
+                if (slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
                     // partial acceptance is not supported by the context -> truncate the draft and restore the state
                     slot.spec_draft = std::move(accepted);
 
-                    auto & ckpt = slot.spec_ckpt;
+                    const auto & ckpt = slot.spec_ckpt;
 
-                    SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n", ckpt.pos_min, ckpt.pos_max, ckpt.size());
+                    SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n",
+                        ckpt.pos_min, ckpt.pos_max, ckpt.size());
 
                     const size_t n = llama_state_seq_set_data_ext(slot.ctx, ckpt.data.data(), ckpt.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
                     if (n != ckpt.size()) {