Skip to content

Commit ca9e370

Browse files
committed
cont : clean-up
1 parent cfb6c4c commit ca9e370

6 files changed

Lines changed: 100 additions & 93 deletions

File tree

common/speculative.cpp

Lines changed: 27 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -165,7 +165,7 @@ struct common_speculative_state_draft : public common_speculative_state {
165165
llama_context * ctx_dft;
166166

167167
struct common_speculative_checkpoint ckpt;
168-
bool use_checkpoint;
168+
bool use_checkpoint;
169169

170170
common_sampler * smpl;
171171

@@ -401,7 +401,7 @@ struct common_speculative_state_draft : public common_speculative_state {
401401
if (reuse_n < (int) prompt_dft.size() || do_restore) {
402402
if (use_checkpoint) {
403403
if (ckpt.n_tokens > (int64_t) prompt_dft.size()) {
404-
LOG_INF("%s: checkpoint is too large, prompt_tgt.size=%zu, ckpt.n_tokens=%zu, reuse_n=%d, prompt_dft.size=%zu\n",
404+
LOG_INF("%s: checkpoint is too large, prompt_tgt.size=%zu, ckpt.n_tokens=%" PRId64 ", reuse_n=%d, prompt_dft.size=%zu\n",
405405
__func__, prompt_tgt.size(), ckpt.n_tokens, reuse_n, prompt_dft.size());
406406
}
407407
draft_restore_checkpoint(ckpt.ckpt_size);
@@ -1207,36 +1207,36 @@ struct common_speculative_session::impl {
12071207
llama_tokens draft;
12081208

12091209
// use of checkpoints in speculative mode
1210-
bool spec_has_ckpt = false; // true if a checkpoint for rollback after partial speculation has been created
1211-
uint16_t spec_ckpt_n_denials = 0; // number of drafts not accepted at the current position (0 or 1)
1212-
size_t spec_ckpt_size_part = 0; // size of partial checkpoint
1210+
bool spec_has_ckpt = false; // true if a checkpoint for rollback after partial speculation has been created
1211+
uint16_t spec_ckpt_n_denials = 0; // number of drafts not accepted at the current position (0 or 1)
12131212

12141213
// Speculative decoding stats
12151214
int32_t n_draft_total = 0; // Total draft tokens generated
12161215
int32_t n_draft_accepted = 0; // Draft tokens actually accepted
12171216

1218-
impl(common_speculative_callback & callback,
1217+
impl(
12191218
const common_params_speculative & params,
1219+
common_speculative_callback & callback,
12201220
llama_context * ctx_tgt)
12211221
: callback(callback), params_spec(params), ctx_tgt(ctx_tgt) {
12221222
spec = common_speculative_init(params_spec, ctx_tgt);
12231223
}
12241224

1225-
void begin(const llama_tokens & prompt_history) {
1225+
void begin(const llama_tokens & prompt_history) const {
12261226
common_speculative_begin(spec, prompt_history);
12271227
}
12281228

1229-
bool has_batch_dft() {
1229+
bool has_batch_dft() const {
12301230
return !draft.empty();
12311231
}
12321232

12331233
void leave_draft_state() {
12341234
draft.clear();
1235-
spec_ckpt_n_denials = 0;
1235+
spec_ckpt_n_denials = 0;
12361236
}
12371237

12381238
llama_tokens compute_draft(
1239-
const llama_tokens & cached_text_tokens,
1239+
const llama_tokens & tokens,
12401240
llama_token id_last,
12411241
const int n_draft_max) {
12421242
if (spec == nullptr) {
@@ -1249,10 +1249,11 @@ struct common_speculative_session::impl {
12491249
leave_draft_state();
12501250
return draft;
12511251
}
1252+
12521253
if (params_spec.use_checkpoints && spec_ckpt_n_denials > 1) {
12531254
// We shouldn't get two denials.
12541255
LOG_WRN("%s: #tokens=%zu, spec_ckpt_n_denials=%d, id_last=%d, #draft=%zu\n", __func__,
1255-
cached_text_tokens.size(), spec_ckpt_n_denials, id_last, draft.size());
1256+
tokens.size(), spec_ckpt_n_denials, id_last, draft.size());
12561257
leave_draft_state();
12571258
return draft;
12581259
}
@@ -1267,12 +1268,12 @@ struct common_speculative_session::impl {
12671268
}
12681269
// we use the shortened draft of previous speculation
12691270
LOG_DBG("%s: reuse shortened draft, #tokens=%zu, id_last=%d, size=%zu\n", __func__,
1270-
cached_text_tokens.size(), id_last, draft.size());
1271+
tokens.size(), id_last, draft.size());
12711272
} else if (spec_ckpt_n_denials > 1) {
12721273
GGML_ABORT("illegal state: spec_ckpt_n_denials = %d > 1", spec_ckpt_n_denials);
12731274
} else {
12741275
// call the speculative implementation to create a draft
1275-
draft = common_speculative_draft(spec, params_spec, cached_text_tokens, id_last);
1276+
draft = common_speculative_draft(spec, params_spec, tokens, id_last);
12761277
LOG_DBG("draft: id_last=%d, #draft=%zu\n", id_last, draft.size());
12771278
if (draft.empty()) {
12781279
leave_draft_state();
@@ -1286,15 +1287,15 @@ struct common_speculative_session::impl {
12861287
}
12871288

12881289
bool do_checkpoint = !draft.empty() && params_spec.use_checkpoints;
1289-
if (do_checkpoint && cached_text_tokens.size() > 5 && draft.size() >= 3) {
1290+
if (do_checkpoint && tokens.size() > 5 && draft.size() >= 3) {
12901291
LOG_DBG("%s: #tokens=%zu, draft.size=%zu, n_spec_denials=%d, do_checkpoint=%s, id_last=%d, tokens=[..., %d, %d, %d], draft=[%d, %d, %d, ...]\n",
12911292
__func__,
1292-
cached_text_tokens.size(),
1293+
tokens.size(),
12931294
draft.size(), spec_ckpt_n_denials,
12941295
do_checkpoint ? "yes" : "no", id_last,
1295-
cached_text_tokens[cached_text_tokens.size() - 3],
1296-
cached_text_tokens[cached_text_tokens.size() - 2],
1297-
cached_text_tokens[cached_text_tokens.size() - 1],
1296+
tokens[tokens.size() - 3],
1297+
tokens[tokens.size() - 2],
1298+
tokens[tokens.size() - 1],
12981299
draft[0], draft[1], draft[2]);
12991300
}
13001301

@@ -1305,13 +1306,12 @@ struct common_speculative_session::impl {
13051306
}
13061307

13071308
if (do_checkpoint) {
1308-
const size_t n = callback.create_checkpoint();
1309+
const size_t n = callback.create_checkpoint(tokens.size());
13091310
if (n == 0) {
1310-
LOG_WRN("%s: checkpoint creation failed (#tokens=%zu)\n", __func__, cached_text_tokens.size());
1311+
LOG_WRN("%s: checkpoint creation failed (#tokens=%zu)\n", __func__, tokens.size());
13111312
leave_draft_state();
13121313
return draft;
13131314
}
1314-
spec_ckpt_size_part = n;
13151315
spec_has_ckpt = true;
13161316
}
13171317

@@ -1341,7 +1341,7 @@ struct common_speculative_session::impl {
13411341
if (spec_has_ckpt) {
13421342
// we need to rollback to the state before sampling the draft tokens
13431343
// (restore_checkpoint shortens context and slot.prompt.tokens)
1344-
const size_t n = callback.restore_checkpoint(spec_ckpt_size_part);
1344+
const size_t n = callback.restore_checkpoint();
13451345
LOG_DBG("%s: partial acceptance: %zu < %zu, restored checkpoint: got %zu bytes\n",
13461346
__func__,
13471347
ids.size() - 1, n_draft, n);
@@ -1367,8 +1367,10 @@ struct common_speculative_session::impl {
13671367
return common_speculative_accept_response{std::move(ids), 0, true};
13681368
}
13691369
}
1370+
13701371
const size_t draft_size_accepted = draft.size();
13711372
LOG_DBG("%s: draft.size=%zu, ids.size=%zu\n", __func__, draft_size_accepted, ids.size());
1373+
13721374
common_speculative_accept(spec, draft_size_accepted);
13731375
draft.clear();
13741376

@@ -1401,15 +1403,14 @@ struct common_speculative_session::impl {
14011403

14021404
leave_draft_state();
14031405

1404-
spec_has_ckpt = false;
1405-
spec_ckpt_size_part = 0;
1406+
spec_has_ckpt = false;
14061407
}
14071408
};
14081409

14091410
common_speculative_session::common_speculative_session(
1410-
common_speculative_callback & callback,
14111411
const common_params_speculative & params,
1412-
llama_context * ctx_tgt) : p_impl(new impl{callback, params, ctx_tgt}) {
1412+
common_speculative_callback & callback,
1413+
llama_context * ctx_tgt) : p_impl(new impl{params, callback, ctx_tgt}) {
14131414
}
14141415

14151416
common_speculative_session::~common_speculative_session() {

common/speculative.h

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -76,11 +76,11 @@ struct common_speculative_callback {
7676

7777
// Creates a checkpoint of the current state of the context.
7878
// Returns the size of the checkpoint in bytes.
79-
virtual size_t create_checkpoint() = 0;
79+
virtual size_t create_checkpoint(int64_t n_tokens) = 0;
8080

8181
// Restore a checkpoint previously created by create_checkpoint().
8282
// Returns the size of the restored checkpoint in bytes.
83-
virtual size_t restore_checkpoint(size_t ckpt_size_part_expected) = 0;
83+
virtual size_t restore_checkpoint() = 0;
8484

8585
// Delete a checkpoint previously created by create_checkpoint().
8686
virtual void delete_checkpoint() = 0;
@@ -99,8 +99,8 @@ struct common_speculative_accept_response {
9999
struct common_speculative_session {
100100

101101
common_speculative_session(
102-
common_speculative_callback & callback,
103102
const common_params_speculative & params,
103+
common_speculative_callback & callback,
104104
llama_context * ctx_tgt);
105105

106106
~common_speculative_session();

tools/server/server-common.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,10 @@ const llama_tokens & server_tokens::get_text_tokens() const {
396396
return tokens;
397397
}
398398

399+
const llama_tokens & server_tokens::get_tokens() const {
400+
return tokens;
401+
}
402+
399403
void server_tokens::set_token(llama_pos pos, llama_token id) {
400404
GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled
401405
tokens[pos] = id;

tools/server/server-common.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,8 @@ struct server_tokens {
192192
// for compatibility with speculative decoding, ctx shift, slot save/load
193193
const llama_tokens & get_text_tokens() const;
194194

195+
const llama_tokens & get_tokens() const;
196+
195197
// for compatibility with speculative decoding
196198
void set_token(llama_pos pos, llama_token id);
197199

0 commit comments

Comments
 (0)