Skip to content

Commit ec8bc44

Browse files
committed
cont : minor
1 parent b3bd3bd commit ec8bc44

4 files changed

Lines changed: 46 additions & 42 deletions

File tree

common/speculative.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -938,7 +938,7 @@ void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, co
938938
}
939939

940940
void common_speculative_draft(common_speculative * spec) {
941-
if (!spec) {
941+
if (spec == nullptr) {
942942
return;
943943
}
944944

common/speculative.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,10 @@ void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, co
4646
// TODO: implement [TAG_COMMON_SPECULATIVE_PROCESS]
4747
//bool common_speculative_process(common_speculative * spec, const llama_batch & batch);
4848

49-
// generate drafts for the sequences specified in dparams
50-
// requires that `dparams.size() == n_seq`, the value used during common_speculative_init()
49+
// generate drafts for the sequences specified with `common_speculative_get_draft_params`
5150
void common_speculative_draft(common_speculative * spec);
5251

53-
// informs the speculative decoder that n_accepted tokens were accepted by the target model
52+
// informs the speculative context that n_accepted tokens were accepted by the target model
5453
void common_speculative_accept(common_speculative * spec, llama_seq_id, uint16_t n_accepted);
5554

5655
// print statistics about the speculative decoding

src/llama-context.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2484,6 +2484,7 @@ class llama_io_write_device : public llama_io_write_i {
24842484
} else {
24852485
//LLAMA_LOG_INFO("%s: reallocating tensors in '%s' buffer %.3f MiB\n", __func__, ggml_backend_buft_name(buft), mbuf.total_size/1024.0/1024.0);
24862486

2487+
// save the old buffer and allocate the new tensors in it
24872488
auto buf = std::move(mbuf_cur.buf);
24882489

24892490
mbuf_cur = std::move(mbuf);

tools/server/server-context.cpp

Lines changed: 42 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2204,45 +2204,45 @@ struct server_context_impl {
22042204

22052205
if (spec) {
22062206
common_speculative_get_draft_params(spec.get(), slot.id).drafting = false;
2207-
}
22082207

2209-
const bool use_ckpt_tgt = ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
2210-
const bool use_ckpt_dft = ctx_dft_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
2208+
const bool use_ckpt_tgt = ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
2209+
const bool use_ckpt_dft = ctx_dft_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
22112210

2212-
const int n_draft_max = slot.get_n_draft_max();
2211+
const int n_draft_max = slot.get_n_draft_max();
22132212

2214-
if (n_draft_max > 0) {
2215-
GGML_ASSERT(slot.can_speculate());
2213+
if (n_draft_max > 0) {
2214+
GGML_ASSERT(slot.can_speculate());
22162215

2217-
if (!slot.spec_draft.empty()) {
2218-
// we have a previous (partial) draft to reuse
2219-
if (use_ckpt_tgt) {
2220-
GGML_ASSERT(!slot.spec_ckpt.empty());
2221-
}
2222-
} else {
2223-
GGML_ASSERT(slot.spec_i_batch.empty());
2216+
if (!slot.spec_draft.empty()) {
2217+
// we have a previous (partial) draft to reuse
2218+
if (use_ckpt_tgt) {
2219+
GGML_ASSERT(!slot.spec_ckpt.empty());
2220+
}
2221+
} else {
2222+
GGML_ASSERT(slot.spec_i_batch.empty());
22242223

2225-
slot.spec_ckpt.update_pos(
2226-
slot.prompt.n_tokens(),
2227-
llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id),
2228-
llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id));
2224+
slot.spec_ckpt.update_pos(
2225+
slot.prompt.n_tokens(),
2226+
llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id),
2227+
llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id));
22292228

2230-
if (use_ckpt_dft) {
2231-
slot.spec_ckpt.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
2232-
}
2229+
if (use_ckpt_dft) {
2230+
slot.spec_ckpt.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
2231+
}
22332232

2234-
slot.spec_prompt = slot.prompt.tokens.get_text_tokens();
2233+
slot.spec_prompt = slot.prompt.tokens.get_text_tokens();
22352234

2236-
common_speculative_get_draft_params(spec.get(), slot.id) = {
2237-
/* .drafting = */ true,
2238-
/* .n_max = */ n_draft_max,
2239-
/* .n_past = */ slot.prompt.n_tokens(),
2240-
/* .id_last = */ slot.sampled,
2241-
/* .prompt = */ &slot.spec_prompt,
2242-
/* .result = */ &slot.spec_draft,
2243-
};
2235+
common_speculative_get_draft_params(spec.get(), slot.id) = {
2236+
/* .drafting = */ true,
2237+
/* .n_max = */ n_draft_max,
2238+
/* .n_past = */ slot.prompt.n_tokens(),
2239+
/* .id_last = */ slot.sampled,
2240+
/* .prompt = */ &slot.spec_prompt,
2241+
/* .result = */ &slot.spec_draft,
2242+
};
22442243

2245-
drafting.push_back(&slot);
2244+
drafting.push_back(&slot);
2245+
}
22462246
}
22472247
}
22482248
}
@@ -2256,29 +2256,33 @@ struct server_context_impl {
22562256
for (auto * slot_ptr : drafting) {
22572257
auto & slot = *slot_ptr;
22582258

2259-
slot.n_draft_total += slot.spec_draft.size();
2259+
auto & draft = slot.spec_draft;
2260+
auto & ckpt = slot.spec_ckpt;
2261+
2262+
slot.n_draft_total += draft.size();
22602263

22612264
// TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL]
22622265
if (ctx_dft) {
2263-
slot.spec_ckpt.load_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
2266+
ckpt.load_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
22642267

2265-
llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), slot.id, slot.spec_ckpt.pos_max + 1, -1);
2268+
llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), slot.id, ckpt.pos_max + 1, -1);
22662269
}
22672270

2268-
const bool use_ckpt_tgt = ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
2271+
if (!draft.empty()) {
2272+
const bool use_ckpt_tgt = ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
22692273

2270-
if (!slot.spec_draft.empty()) {
22712274
if (use_ckpt_tgt) {
22722275
//const int64_t t_start = ggml_time_us();
22732276

2274-
slot.spec_ckpt.update_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
2277+
ckpt.update_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
22752278

22762279
//const int64_t t_total = ggml_time_us() - t_start;
22772280
//printf("checkpoint total: %f ms\n", t_total / 1000.0);
22782281

22792282
SLT_DBG(slot, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %d, size = %.3f MiB, draft = %.3f MiB)\n",
2280-
slot.spec_ckpt.pos_min, slot.spec_ckpt.pos_max, slot.prompt.n_tokens(),
2281-
(float) slot.spec_ckpt.size() / 1024 / 1024, (float) slot.spec_ckpt.data_dft.size() / 1024 / 1024);
2283+
ckpt.pos_min, ckpt.pos_max, slot.prompt.n_tokens(),
2284+
(float) ckpt.size() / 1024 / 1024,
2285+
(float) ckpt.data_dft.size() / 1024 / 1024);
22822286
}
22832287
}
22842288
}

0 commit comments

Comments
 (0)