Skip to content

Commit edc9b88

Browse files
committed
server : fix spec checkpoints, logging
1 parent 029ca70 commit edc9b88

3 files changed

Lines changed: 50 additions & 42 deletions

File tree

common/speculative.cpp

Lines changed: 35 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1136,21 +1136,24 @@ struct common_speculative_session::impl {
11361136
clear_draft();
11371137
return draft;
11381138
}
1139-
if (params_spec.use_checkpoints
1140-
&& spec_ckpt_n_denials > 0) {
1139+
if (params_spec.use_checkpoints && spec_ckpt_n_denials > 1) {
1140+
// We shouldn't get two denials.
1141+
LOG_WRN("%s: #tokens=%zu, spec_ckpt_n_denials=%d, id_last=%d, #draft=%zu\n", __func__,
1142+
cached_text_tokens.size(), spec_ckpt_n_denials, id_last, draft.size());
11411143
clear_draft();
11421144
return draft;
11431145
}
11441146

1145-
if (spec_ckpt_n_denials > 0) {
1147+
if (spec_ckpt_n_denials == 1) {
11461148
// there is a previous speculation which wasn't accepted in full length
11471149
if (draft.empty()) {
11481150
LOG_WRN("%s: draft of length 0 after denied checkpoint\n", __func__);
11491151
clear_draft();
11501152
return draft;
11511153
}
11521154
// we use the shortened draft of previous speculation
1153-
LOG_INF("%s: resuse shortened draft, size=%zu\n", __func__, draft.size());
1155+
LOG_DBG("%s: reuse shortened draft, #tokens=%zu, id_last=%d, size=%zu\n", __func__,
1156+
cached_text_tokens.size(), id_last, draft.size());
11541157
} else {
11551158
// call the speculative implementation to create a draft
11561159
draft = common_speculative_draft(spec, params_spec, cached_text_tokens, id_last);
@@ -1167,32 +1170,35 @@ struct common_speculative_session::impl {
11671170
}
11681171

11691172
bool do_checkpoint = !draft.empty() && params_spec.use_checkpoints;
1170-
if (do_checkpoint && cached_text_tokens.size() > 5) {
1171-
LOG_DBG("draft.size = %zu, n_spec_denials = %d, do_checkpoint = %s, tokens=[..., %d, %d, %d]\n",
1173+
if (do_checkpoint && cached_text_tokens.size() > 5 && draft.size() >= 3) {
1174+
LOG_DBG("%s: #tokens=%zu, draft.size=%zu, n_spec_denials=%d, do_checkpoint=%s, id_last=%d, tokens=[..., %d, %d, %d], draft=[%d, %d, %d, ...]\n",
1175+
__func__,
1176+
cached_text_tokens.size(),
11721177
draft.size(), spec_ckpt_n_denials,
1173-
do_checkpoint ? "yes" : "no",
1178+
do_checkpoint ? "yes" : "no", id_last,
11741179
cached_text_tokens[cached_text_tokens.size() - 3],
11751180
cached_text_tokens[cached_text_tokens.size() - 2],
1176-
cached_text_tokens[cached_text_tokens.size() - 1]);
1181+
cached_text_tokens[cached_text_tokens.size() - 1],
1182+
draft[0], draft[1], draft[2]);
1183+
}
1184+
1185+
if (params_spec.n_min > (int) draft.size()) {
1186+
LOG_DBG("ignoring small draft: %d < %d\n", (int) draft.size(), params_spec.n_min);
1187+
clear_draft();
1188+
return draft;
11771189
}
11781190

11791191
if (do_checkpoint) {
11801192
const size_t n = callback.create_checkpoint();
11811193
if (n == 0) {
1182-
LOG_WRN("checkpoint creation failed");
1194+
LOG_WRN("%s: checkpoint creation failed (#tokens=%zu)\n", __func__, cached_text_tokens.size());
11831195
clear_draft();
11841196
return draft;
11851197
}
11861198
spec_ckpt_size_part = n;
11871199
spec_has_ckpt = true;
11881200
}
11891201

1190-
if (params_spec.n_min > (int) draft.size()) {
1191-
LOG_DBG("ignoring small draft: %d < %d\n", (int) draft.size(), params_spec.n_min);
1192-
clear_draft();
1193-
return draft;
1194-
}
1195-
11961202
// add last sampled token to the batch
11971203
callback.batch_add_token(id_last, true);
11981204

@@ -1219,27 +1225,31 @@ struct common_speculative_session::impl {
12191225
if (spec_has_ckpt) {
12201226
// we need to rollback to the state before sampling the draft tokens
12211227
const size_t n = callback.restore_checkpoint(spec_ckpt_size_part);
1222-
LOG_INF("partial acceptance: %zu < %zu, restored checkpoint: got %zu bytes\n",
1223-
ids.size() -1 , n_draft, n);
1228+
LOG_DBG("%s: partial acceptance: %zu < %zu, restored checkpoint: got %zu bytes\n",
1229+
__func__,
1230+
ids.size() - 1, n_draft, n);
12241231

1225-
// rollback to the state before sampling the draft tokens
1226-
1227-
// Delete Checkpoint
1232+
// delete Checkpoint
12281233
callback.delete_checkpoint();
12291234
spec_has_ckpt = false;
12301235

1231-
if (n_draft > 0 && spec_ckpt_n_denials == 0) {
1236+
spec_ckpt_n_denials++;
1237+
if (ids.size() > 1u + static_cast<std::size_t>(params_spec.n_min) && spec_ckpt_n_denials == 1) {
12321238
// we will do the batch again but with the shortened draft
1233-
spec_ckpt_n_denials++;
1234-
12351239
return common_speculative_accept_response(std::move(ids), n_draft, true);
12361240
}
12371241

1238-
callback.batch_clear();
1242+
LOG_DBG("%s: don't accept partial draft, n_draft=%zu, ids.size=%zu\n", __func__, n_draft, ids.size());
1243+
draft.clear();
1244+
1245+
// use the sampled token only
1246+
ids.resize(1);
1247+
// drafted tokens in prompt have been deleted in restore_checkpoint(...).
1248+
return common_speculative_accept_response{std::move(ids), 0, false};
12391249
}
12401250
}
12411251
const size_t draft_size_accepted = draft.size();
1242-
LOG_DBG("%s: draft.size=%zu\n", __func__, draft_size_accepted);
1252+
LOG_DBG("%s: draft.size=%zu, ids.size=%zu\n", __func__, draft_size_accepted, ids.size());
12431253
common_speculative_accept(spec, draft_size_accepted);
12441254
draft.clear();
12451255

common/speculative.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,6 @@ struct common_speculative_callback {
6161
// Add a token to the draft sequence.
6262
virtual void batch_add_token(const llama_token token, bool logits) = 0;
6363

64-
// Clears the batch context.
65-
virtual void batch_clear() = 0;
66-
6764
// Sample and accept tokens from the main model.
6865
virtual llama_tokens sampler_sample_and_accept_n(const llama_tokens & drafted) = 0;
6966

tools/server/server-context.cpp

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
12
#include "server-context.h"
23
#include "server-common.h"
34
#include "server-http.h"
@@ -55,7 +56,7 @@ struct server_slot {
5556
mtmd_context * mctx = nullptr;
5657

5758
std::unique_ptr<common_speculative_callback> spec_callback;
58-
common_speculative_session * spec_session = nullptr;
59+
std::unique_ptr<common_speculative_session> spec_session = nullptr;
5960

6061
// TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state
6162
// see https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837
@@ -638,10 +639,6 @@ struct server_context_impl {
638639
slot.prompt.tokens.push_back(token);
639640
}
640641

641-
void batch_clear() override {
642-
common_batch_clear(ctx_impl.batch);
643-
}
644-
645642
std::vector<llama_token> sampler_sample_and_accept_n(const llama_tokens & drafted) override {
646643
if (slot.i_batch_dft.size() != 1 + drafted.size()) {
647644
GGML_ABORT("%s: #i_batch_dft = %zu != 1 + #drafted=%zu",
@@ -662,15 +659,15 @@ struct server_context_impl {
662659
const auto & cur_with_size = ctx_impl.get_checkpoint(slot, pos_min, pos_max);
663660
auto & cur = cur_with_size.checkpoint;
664661

665-
SLT_INF(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
662+
SLT_DBG(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
666663
(int) slot.prompt.checkpoints.size(), ctx_impl.params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
667664
return cur_with_size.size;
668665
}
669666

670667
size_t restore_checkpoint(size_t ckpt_size_part_expected) override {
671668
auto & ckpt = slot.prompt.checkpoints.back();
672669

673-
SLT_INF(slot, "restoring checkpoint (pos_min = %d, pos_max = %d)\n", ckpt.pos_min, ckpt.pos_max);
670+
SLT_DBG(slot, "restoring checkpoint (pos_min = %d, pos_max = %d)\n", ckpt.pos_min, ckpt.pos_max);
674671
const size_t n = llama_state_seq_set_data_ext(ctx_impl.ctx,
675672
ckpt.data.data(), ckpt.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
676673
if (n != ckpt_size_part_expected) {
@@ -844,7 +841,7 @@ struct server_context_impl {
844841
return false;
845842
}
846843
slot.spec_callback = std::make_unique<server_speculative_callback>(slot, *this);
847-
slot.spec_session = new common_speculative_session(*slot.spec_callback,
844+
slot.spec_session = std::make_unique<common_speculative_session>(*slot.spec_callback,
848845
params_base.speculative, slot.ctx);
849846
SLT_INF(slot, "%s", "speculative decoding context initialized\n");
850847
}
@@ -2156,12 +2153,16 @@ struct server_context_impl {
21562153
// generate draft tokens in speculative decoding mode
21572154
// TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
21582155
// perform the speculative drafting for all sequences at the same time in a single batch
2156+
llama_tokens draft;
21592157
const int n_draft_max_slot = slot.get_n_draft_max();
2160-
2161-
const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
2162-
llama_tokens draft = slot.spec_session->compute_draft(cached_text_tokens, slot.sampled, n_draft_max_slot);
2163-
if (draft.size() > 0) {
2164-
SLT_DBG(slot, "compute_draft: #tokens=%d\n", (int) draft.size());
2158+
if (n_draft_max_slot > 0) {
2159+
const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
2160+
// compute draft and add draft to internal batch
2161+
draft = slot.spec_session->compute_draft(cached_text_tokens, slot.sampled, n_draft_max_slot);
2162+
if (draft.size() > 0) {
2163+
SLT_DBG(slot, "compute_draft: #cached_text_tokens=%zu, #tokens=%zu, #i_batch_dft=%zu\n",
2164+
cached_text_tokens.size(), draft.size(), slot.i_batch_dft.size());
2165+
}
21652166
}
21662167

21672168
if (draft.empty()) {
@@ -2857,7 +2858,7 @@ struct server_context_impl {
28572858
slot.i_batch_dft.clear();
28582859
const size_t n_draft = accept_response.draft_size_initial;
28592860
if (accept_response.skip_acceptance) {
2860-
SLT_INF(slot, "partial acceptance: n_tokens=%zu, n_draft=%zu\n", accept_response.tokens.size(), n_draft);
2861+
SLT_DBG(slot, "partial acceptance: n_tokens=%zu, n_draft=%zu\n", accept_response.tokens.size(), n_draft);
28612862
continue;
28622863
}
28632864
const auto ids = accept_response.tokens;

0 commit comments

Comments
 (0)