99 changes: 44 additions & 55 deletions common/speculative.cpp
@@ -167,16 +167,14 @@ struct common_speculative_checkpoint {
size_t size() const {
return data.size();
}

size_t ckpt_size = 0;
};

struct common_speculative_state_draft : public common_speculative_state {
llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
llama_context * ctx_dft;

bool use_ckpt = false;
struct common_speculative_checkpoint ckpt;
common_speculative_checkpoint ckpt;

common_sampler * smpl;

@@ -249,26 +247,16 @@ struct common_speculative_state_draft : public common_speculative_state {
llama_batch_free(batch);
}

void begin(const llama_tokens & prompt) override {
if (use_ckpt && ckpt.size() > 0) {
// delete checkpoint
LOG_DBG("%s: delete checkpoint, prompt.size=%zu, pos_min=%d, pos_max=%d, n_tokens=%" PRId64 ", size=%.3f MiB\n",
__func__, prompt.size(), ckpt.pos_min, ckpt.pos_max, ckpt.n_tokens, (float) ckpt.data.size() / 1024 / 1024);
ckpt.pos_min = 0;
ckpt.pos_max = 0;
ckpt.n_tokens = 0;
ckpt.ckpt_size = 0;
ckpt.data.clear();
}
void begin(const llama_tokens & /*prompt*/) override {
}

size_t draft_create_checkpoint(int n_tokens_prompt, int n_tokens_batch) {
size_t create_checkpoint(int n_tokens_prompt) {
int slot_id = 0;
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx_dft, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

ckpt.pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_dft), slot_id);
ckpt.pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), slot_id);
ckpt.n_tokens = n_tokens_prompt - n_tokens_batch;
ckpt.n_tokens = n_tokens_prompt;
ckpt.data.resize(checkpoint_size);

const size_t n = llama_state_seq_get_data_ext(ctx_dft, ckpt.data.data(), checkpoint_size, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
@@ -281,13 +269,13 @@ struct common_speculative_state_draft : public common_speculative_state {
return n;
}

size_t draft_restore_checkpoint(size_t ckpt_size_part_expected) {
size_t restore_checkpoint() {
int slot_id = 0;
LOG_DBG("%s: pos_min = %d, pos_max = %d\n", __func__, ckpt.pos_min, ckpt.pos_max);
const size_t n = llama_state_seq_set_data_ext(ctx_dft, ckpt.data.data(), ckpt.size(), slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
if (n != ckpt_size_part_expected) {
GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt_size_part_expected, n);
if (n != ckpt.size()) {
GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu",
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size());
}
llama_memory_seq_rm(llama_get_memory(ctx_dft), slot_id, ckpt.pos_max + 1, -1);

@@ -346,35 +334,45 @@ struct common_speculative_state_draft : public common_speculative_state {

const int i_start = std::max<int>(0, (int) prompt_cur.size() - n_ctx);

if (use_ckpt && i_start > 0) {
LOG_WRN("%s: context shift is not supported with checkpoint-based contexts - skipping\n", __func__);
return;
}

// reuse as much as possible from the old draft context
// ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
for (int i = 0; i < (int) prompt_dft.size(); ++i) {
int cur = 0;
while (i_start + cur < (int) prompt_cur.size() &&
i + cur < (int) prompt_dft.size() &&
prompt_cur[i_start + cur] == prompt_dft[i + cur]) {
i + cur < (int) prompt_dft.size() &&
prompt_cur[i_start + cur] == prompt_dft[i + cur]) {
cur++;
}

if ((cur >= 256 || n_ctx >= (int) prompt_cur.size()) && cur > reuse_n) {
reuse_i = i;
reuse_n = cur;
}

if (use_ckpt) {
break;
}
}

LOG_DBG("%s: reuse_i = %d, reuse_n = %d, #prompt_dft = %zu, #prompt_cur = %zu\n",
__func__, reuse_i, reuse_n, prompt_dft.size(), prompt_cur.size());
if (use_ckpt && ckpt.ckpt_size == 0 && reuse_n > 0) {
LOG_DBG("%s: no checkpoint available, no reuse, (reuse_i=%d, reuse_n=%d) -> (0, 0)\n",
__func__, reuse_i, reuse_n);
if (use_ckpt && ckpt.n_tokens > reuse_n) {
LOG_DBG("%s: checkpoint (n_tokens = %d) is outdated -> delete it\n", __func__, (int) ckpt.n_tokens);

reuse_i = 0;
reuse_n = 0;

ckpt = {};
}

result.clear();
result.reserve(sparams.n_max);

bool needs_ckpt = use_ckpt && prompt_dft.size() > 0;
if (reuse_n == 0 || (use_ckpt && reuse_i > 0)) {
llama_memory_clear(mem_dft, false);
prompt_dft.clear();
@@ -393,50 +391,38 @@ struct common_speculative_state_draft : public common_speculative_state {
return;
}

bool do_restore = false;
if (prompt_dft.size() > prompt_cur.size() && reuse_i + reuse_n < (int64_t) prompt_dft.size()) {
// This can happen after a partial acceptance (speculative decoding with checkpoints)
LOG_DBG("%s: #prompt_dft=%zu, #prompt_cur=%zu, shorten draft\n",
__func__, prompt_dft.size(), prompt_cur.size());
prompt_dft.resize(prompt_cur.size());
do_restore = true;
}

if (reuse_i > 0) {
GGML_ASSERT(!use_ckpt);

bool is_removed = llama_memory_seq_rm (mem_dft, 0, 0, reuse_i);
if (!is_removed) {
LOG_ERR("%s: llama_memory_seq_rm failed, reuse_i=%d\n", __func__, reuse_i);
return;
}
llama_memory_seq_add(mem_dft, 0, reuse_i, -1, -reuse_i);

prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i);
}

if (reuse_n < (int) prompt_dft.size() || do_restore) {
if (reuse_n < (int) prompt_dft.size()) {
if (use_ckpt) {
if (ckpt.n_tokens > (int64_t) prompt_dft.size()) {
LOG_INF("%s: checkpoint is too large, prompt_tgt.size=%zu, ckpt.n_tokens=%" PRId64 ", reuse_n=%d, prompt_dft.size=%zu\n",
__func__, prompt_tgt.size(), ckpt.n_tokens, reuse_n, prompt_dft.size());
if (ckpt.n_tokens > 0) {
LOG_DBG("%s: restoring checkpoint, reuse_n=%d, prompt_dft.size=%zu\n", __func__, reuse_n, prompt_dft.size());
restore_checkpoint();
reuse_n = ckpt.n_tokens;
prompt_dft.resize(reuse_n);
}
draft_restore_checkpoint(ckpt.ckpt_size);
reuse_n = ckpt.n_tokens;
prompt_dft.resize(reuse_n);
needs_ckpt = false;
} else {
bool is_removed = llama_memory_seq_rm (mem_dft, 0, reuse_n, -1);
const bool is_removed = llama_memory_seq_rm(mem_dft, 0, reuse_n, -1);
if (!is_removed) {
LOG_ERR("%s: llama_memory_seq_rm failed, reuse_n=%d, prompt_dft.size=%zu\n",
__func__, reuse_n, prompt_dft.size());
LOG_ERR("%s: llama_memory_seq_rm failed, reuse_n=%d, prompt_dft.size=%zu\n", __func__, reuse_n, prompt_dft.size());
return;
}
prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end());
}
}
}

if (needs_ckpt) {
ckpt.ckpt_size = draft_create_checkpoint(prompt_dft.size(), batch.n_tokens);
}

// prepare a batch to evaluate any new tokens in the prompt
common_batch_clear(batch);

Expand All @@ -450,12 +436,17 @@ struct common_speculative_state_draft : public common_speculative_state {
// we should rarely end-up here during normal decoding
if (batch.n_tokens > 0) {
//LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
LOG_DBG("%s: draft prompt batch: %d tokens\n", __func__, batch.n_tokens);

int ret = llama_decode(ctx_dft, batch);
if (ret != 0 && ret != 1) {
LOG_WRN("%s: llama_decode returned %d, prompt_cur.size=%zu\n",
__func__, ret, prompt_cur.size());
}

if (use_ckpt) {
create_checkpoint(prompt_dft.size());
}
}

const llama_pos n_past = prompt_dft.size();
@@ -784,17 +775,15 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
}

void accept(uint16_t n_accepted) override {
if (verbose) {
LOG_INF("%s: accepted %d tokens from %zu drafted tokens\n", __func__, n_accepted, n_draft_last);
}

// compute acceptance fraction if we have a recorded draft length
if (n_draft_last > 0) {
const double f_acc = (double)n_accepted / (double)n_draft_last;
if (f_acc < 0.5) {
n_low++;
if (n_low >= 3) {
LOG_WRN("%s: low acceptance streak (%d) – resetting ngram_mod\n", __func__, n_low);
if (verbose) {
LOG_WRN("%s: low acceptance streak (%d) – resetting ngram_mod\n", __func__, n_low);
}

mod.reset();
n_low = 0;
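
Taken together, the speculative.cpp hunks reduce the checkpoint logic to a single save/restore pair over the draft context's partial sequence state. The following is a minimal standalone sketch of that pair, not part of the diff: DraftCheckpoint is a hypothetical mirror of common_speculative_checkpoint, while the llama_* calls are the same ones used in the hunks above.

// Minimal sketch of the draft-context checkpoint save/restore (assumption:
// DraftCheckpoint is an illustrative stand-in for common_speculative_checkpoint).
#include <cstdint>
#include <vector>
#include "llama.h"

struct DraftCheckpoint {
    llama_pos pos_min  = 0;
    llama_pos pos_max  = 0;
    int64_t   n_tokens = 0;
    std::vector<uint8_t> data; // partial sequence state of the draft context
};

// save the partial state of sequence 0 right after the draft prompt was decoded
static DraftCheckpoint draft_save(llama_context * ctx_dft, int64_t n_tokens_prompt) {
    const int seq_id = 0;

    DraftCheckpoint ckpt;
    ckpt.pos_min  = llama_memory_seq_pos_min(llama_get_memory(ctx_dft), seq_id);
    ckpt.pos_max  = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
    ckpt.n_tokens = n_tokens_prompt;

    ckpt.data.resize(llama_state_seq_get_size_ext(ctx_dft, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY));
    llama_state_seq_get_data_ext(ctx_dft, ckpt.data.data(), ckpt.data.size(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

    return ckpt;
}

// restore the saved state and drop anything decoded past the checkpointed position
static void draft_restore(llama_context * ctx_dft, const DraftCheckpoint & ckpt) {
    const int seq_id = 0;

    llama_state_seq_set_data_ext(ctx_dft, ckpt.data.data(), ckpt.data.size(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
    llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, ckpt.pos_max + 1, -1);
}

In the new flow, the save half runs once after the draft prompt batch is decoded (the create_checkpoint(prompt_dft.size()) call), and the restore half runs when a partial acceptance forces the draft context back to the checkpointed position.
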
22 changes: 18 additions & 4 deletions tools/server/server-context.cpp
@@ -680,6 +680,7 @@ struct server_context_impl {
// slots / clients
std::vector<server_slot> slots;

int trace = 0;
int slots_debug = 0;
int n_empty_consecutive = 0;

@@ -918,12 +919,21 @@ struct server_context_impl {
slot.reset();
}

{
const char * LLAMA_TRACE = getenv("LLAMA_TRACE");
trace = LLAMA_TRACE ? atoi(LLAMA_TRACE) : 0;

if (trace) {
SRV_WRN("LLAMA_TRACE = %d\n", trace);
}
}

{
const char * LLAMA_SERVER_SLOTS_DEBUG = getenv("LLAMA_SERVER_SLOTS_DEBUG");
slots_debug = LLAMA_SERVER_SLOTS_DEBUG ? atoi(LLAMA_SERVER_SLOTS_DEBUG) : 0;

if (slots_debug) {
SRV_WRN("slots debug = %d\n", slots_debug);
SRV_WRN("LLAMA_SERVER_SLOTS_DEBUG = %d\n", slots_debug);
}
}

@@ -2974,13 +2984,15 @@ struct server_context_impl {
auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx, slot.spec_i_batch, slot.spec_draft);
slot.spec_i_batch.clear();

SLT_DBG(slot, "%s: n_draft=%zu, accepted=%zu\n", __func__, slot.spec_draft.size(), accepted.size());

GGML_ASSERT(accepted.size() >= 1);

// check for partial draft acceptance
if (accepted.size() < slot.spec_draft.size() + 1) {
if (use_ckpt) {
if (trace > 0) {
SLT_INF(slot, "accepted %2zu/%2zu draft tokens (restore checkpoint)\n", accepted.size() - 1, slot.spec_draft.size());
}

// partial acceptance is not supported by the context -> truncate the draft and restore the state
slot.spec_draft = std::move(accepted);

Expand All @@ -3002,8 +3014,10 @@ struct server_context_impl {

continue;
}
}

LOG_DBG("%s: partial acceptance: %zu < %zu\n", __func__, accepted.size(), slot.spec_draft.size());
if (trace > 0) {
SLT_INF(slot, "accepted %2zu/%2zu draft tokens\n", accepted.size() - 1, n_draft);
}

common_speculative_accept(slot.spec.get(), accepted.size() - 1);
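
The server-context.cpp changes mostly add per-slot acceptance logging gated by an environment variable. A minimal sketch of that gate follows, assuming a hypothetical read_env_int helper; the PR reads LLAMA_TRACE and LLAMA_SERVER_SLOTS_DEBUG inline in the same way.

// Sketch of the env-var gate (assumption: helper name is illustrative, not from the PR)
#include <cstdlib>

static int read_env_int(const char * name) {
    const char * val = std::getenv(name);
    return val ? std::atoi(val) : 0;
}

// usage: int trace = read_env_int("LLAMA_TRACE");
// launching with `LLAMA_TRACE=1 llama-server ...` makes the server log lines such as
// "accepted N/M draft tokens" for every speculative decoding step
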