diff --git a/common/speculative.cpp b/common/speculative.cpp
index bda9993b159..bbf88fa6e71 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -167,8 +167,6 @@ struct common_speculative_checkpoint {
     size_t size() const {
         return data.size();
     }
-
-    size_t ckpt_size = 0;
 };
 
 struct common_speculative_state_draft : public common_speculative_state {
@@ -176,7 +174,7 @@ struct common_speculative_state_draft : public common_speculative_state {
     llama_context * ctx_dft;
 
     bool use_ckpt = false;
-    struct common_speculative_checkpoint ckpt;
+    common_speculative_checkpoint ckpt;
 
     common_sampler * smpl;
 
@@ -249,26 +247,16 @@ struct common_speculative_state_draft : public common_speculative_state {
         llama_batch_free(batch);
     }
 
-    void begin(const llama_tokens & prompt) override {
-        if (use_ckpt && ckpt.size() > 0) {
-            // delete checkpoint
-            LOG_DBG("%s: delete checkpoint, prompt.size=%zu, pos_min=%d, pos_max=%d, n_tokens=%" PRId64 ", size=%.3f MiB\n",
-                    __func__, prompt.size(), ckpt.pos_min, ckpt.pos_max, ckpt.n_tokens, (float) ckpt.data.size() / 1024 / 1024);
-            ckpt.pos_min   = 0;
-            ckpt.pos_max   = 0;
-            ckpt.n_tokens  = 0;
-            ckpt.ckpt_size = 0;
-            ckpt.data.clear();
-        }
+    void begin(const llama_tokens & /*prompt*/) override {
     }
 
-    size_t draft_create_checkpoint(int n_tokens_prompt, int n_tokens_batch) {
+    size_t create_checkpoint(int n_tokens_prompt) {
         int slot_id = 0;
         const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx_dft, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
 
         ckpt.pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_dft), slot_id);
         ckpt.pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), slot_id);
-        ckpt.n_tokens = n_tokens_prompt - n_tokens_batch;
+        ckpt.n_tokens = n_tokens_prompt;
 
         ckpt.data.resize(checkpoint_size);
         const size_t n = llama_state_seq_get_data_ext(ctx_dft, ckpt.data.data(), checkpoint_size, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
@@ -281,13 +269,13 @@ struct common_speculative_state_draft : public common_speculative_state {
         return n;
     }
 
-    size_t draft_restore_checkpoint(size_t ckpt_size_part_expected) {
+    size_t restore_checkpoint() {
         int slot_id = 0;
         LOG_DBG("%s: pos_min = %d, pos_max = %d\n", __func__, ckpt.pos_min, ckpt.pos_max);
         const size_t n = llama_state_seq_set_data_ext(ctx_dft, ckpt.data.data(), ckpt.size(), slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
-        if (n != ckpt_size_part_expected) {
-            GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
-                    __func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt_size_part_expected, n);
+        if (n != ckpt.size()) {
+            GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu)",
+                    __func__, ckpt.pos_min, ckpt.pos_max, ckpt.size());
         }
 
         llama_memory_seq_rm(llama_get_memory(ctx_dft), slot_id, ckpt.pos_max + 1, -1);
@@ -346,13 +334,18 @@ struct common_speculative_state_draft : public common_speculative_state {
 
         const int i_start = std::max(0, (int) prompt_cur.size() - n_ctx);
 
+        if (use_ckpt && i_start > 0) {
+            LOG_WRN("%s: context shift is not supported with checkpoint-based contexts - skipping\n", __func__);
+            return;
+        }
+
         // reuse as much as possible from the old draft context
         // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
         for (int i = 0; i < (int) prompt_dft.size(); ++i) {
             int cur = 0;
             while (i_start + cur < (int) prompt_cur.size() &&
-                   i + cur < (int) prompt_dft.size() &&
-                   prompt_cur[i_start + cur] == prompt_dft[i + cur]) {
+                   i       + cur < (int) prompt_dft.size() &&
+                   prompt_cur[i_start + cur] == prompt_dft[i + cur]) {
                 cur++;
             }
 
@@ -360,21 +353,26 @@ struct common_speculative_state_draft : public common_speculative_state {
                 reuse_i = i;
                 reuse_n = cur;
             }
+
+            if (use_ckpt) {
+                break;
+            }
         }
 
         LOG_DBG("%s: reuse_i = %d, reuse_n = %d, #prompt_dft = %zu, #prompt_cur = %zu\n", __func__, reuse_i, reuse_n, prompt_dft.size(), prompt_cur.size());
 
-        if (use_ckpt && ckpt.ckpt_size == 0 && reuse_n > 0) {
-            LOG_DBG("%s: no checkpoint available, no reuse, (reuse_i=%d, reuse_n=%d) -> (0, 0)\n",
-                    __func__, reuse_i, reuse_n);
+        if (use_ckpt && ckpt.n_tokens > reuse_n) {
+            LOG_DBG("%s: checkpoint (n_tokens = %d) is outdated -> delete it\n", __func__, (int) ckpt.n_tokens);
+
             reuse_i = 0;
             reuse_n = 0;
+
+            ckpt = {};
         }
 
         result.clear();
         result.reserve(sparams.n_max);
 
-        bool needs_ckpt = use_ckpt && prompt_dft.size() > 0;
-
         if (reuse_n == 0 || (use_ckpt && reuse_i > 0)) {
             llama_memory_clear(mem_dft, false);
 
             prompt_dft.clear();
@@ -393,50 +391,38 @@ struct common_speculative_state_draft : public common_speculative_state {
             return;
         }
 
-        bool do_restore = false;
-
-        if (prompt_dft.size() > prompt_cur.size() && reuse_i + reuse_n < (int64_t) prompt_dft.size()) {
-            // This can happen after a partial acceptance (speculative decoding with checkpoints)
-            LOG_DBG("%s: #prompt_dft=%zu, #prompt_cur=%zu, shorten draft\n",
-                    __func__, prompt_dft.size(), prompt_cur.size());
-            prompt_dft.resize(prompt_cur.size());
-            do_restore = true;
-        }
-
         if (reuse_i > 0) {
+            GGML_ASSERT(!use_ckpt);
+
             bool is_removed = llama_memory_seq_rm (mem_dft, 0, 0, reuse_i);
             if (!is_removed) {
                 LOG_ERR("%s: llama_memory_seq_rm failed, reuse_i=%d\n", __func__, reuse_i);
+                return;
             }
             llama_memory_seq_add(mem_dft, 0, reuse_i, -1, -reuse_i);
 
             prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i);
         }
 
-        if (reuse_n < (int) prompt_dft.size() || do_restore) {
+        if (reuse_n < (int) prompt_dft.size()) {
             if (use_ckpt) {
-                if (ckpt.n_tokens > (int64_t) prompt_dft.size()) {
-                    LOG_INF("%s: checkpoint is too large, prompt_tgt.size=%zu, ckpt.n_tokens=%" PRId64 ", reuse_n=%d, prompt_dft.size=%zu\n",
-                            __func__, prompt_tgt.size(), ckpt.n_tokens, reuse_n, prompt_dft.size());
+                if (ckpt.n_tokens > 0) {
+                    LOG_DBG("%s: restoring checkpoint, reuse_n=%d, prompt_dft.size=%zu\n", __func__, reuse_n, prompt_dft.size());
+
+                    restore_checkpoint();
+
+                    reuse_n = ckpt.n_tokens;
+                    prompt_dft.resize(reuse_n);
                 }
-
-                draft_restore_checkpoint(ckpt.ckpt_size);
-                reuse_n = ckpt.n_tokens;
-                prompt_dft.resize(reuse_n);
-                needs_ckpt = false;
             } else {
-                bool is_removed = llama_memory_seq_rm (mem_dft, 0, reuse_n, -1);
+                const bool is_removed = llama_memory_seq_rm(mem_dft, 0, reuse_n, -1);
                 if (!is_removed) {
-                    LOG_ERR("%s: llama_memory_seq_rm failed, reuse_n=%d, prompt_dft.size=%zu\n",
-                            __func__, reuse_n, prompt_dft.size());
+                    LOG_ERR("%s: llama_memory_seq_rm failed, reuse_n=%d, prompt_dft.size=%zu\n", __func__, reuse_n, prompt_dft.size());
+                    return;
                 }
+
                 prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end());
             }
         }
 
-        if (needs_ckpt) {
-            ckpt.ckpt_size = draft_create_checkpoint(prompt_dft.size(), batch.n_tokens);
-        }
-
         // prepare a batch to evaluate any new tokens in the prompt
         common_batch_clear(batch);
@@ -450,12 +436,17 @@ struct common_speculative_state_draft : public common_speculative_state {
         // we should rarely end-up here during normal decoding
         if (batch.n_tokens > 0) {
             //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
+            LOG_DBG("%s: draft prompt batch: %d tokens\n", __func__, batch.n_tokens);
 
             int ret = llama_decode(ctx_dft, batch);
             if (ret != 0 && ret != 1) {
                 LOG_WRN("%s: llama_decode returned %d, prompt_cur.size=%zu\n", __func__, ret, prompt_cur.size());
             }
+
+            if (use_ckpt) {
+                create_checkpoint(prompt_dft.size());
+            }
         }
 
         const llama_pos n_past = prompt_dft.size();
@@ -784,17 +775,15 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
     }
 
     void accept(uint16_t n_accepted) override {
-        if (verbose) {
-            LOG_INF("%s: accepted %d tokens from %zu drafted tokens\n", __func__, n_accepted, n_draft_last);
-        }
-
        // compute acceptance fraction if we have a recorded draft length
        if (n_draft_last > 0) {
            const double f_acc = (double) n_accepted / (double) n_draft_last;
 
            if (f_acc < 0.5) {
                n_low++;
                if (n_low >= 3) {
-                    LOG_WRN("%s: low acceptance streak (%d) – resetting ngram_mod\n", __func__, n_low);
+                    if (verbose) {
+                        LOG_WRN("%s: low acceptance streak (%d) – resetting ngram_mod\n", __func__, n_low);
+                    }
 
                     mod.reset();
                     n_low = 0;
diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index ee8366d28c2..2d3003f03a8 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -680,6 +680,7 @@ struct server_context_impl {
     // slots / clients
     std::vector<server_slot> slots;
 
+    int trace = 0;
    int slots_debug = 0;
 
    int n_empty_consecutive = 0;
@@ -918,12 +919,21 @@ struct server_context_impl {
            slot.reset();
        }
 
+        {
+            const char * LLAMA_TRACE = getenv("LLAMA_TRACE");
+            trace = LLAMA_TRACE ? atoi(LLAMA_TRACE) : 0;
+
+            if (trace) {
+                SRV_WRN("LLAMA_TRACE = %d\n", trace);
+            }
+        }
+
        {
            const char * LLAMA_SERVER_SLOTS_DEBUG = getenv("LLAMA_SERVER_SLOTS_DEBUG");
            slots_debug = LLAMA_SERVER_SLOTS_DEBUG ? atoi(LLAMA_SERVER_SLOTS_DEBUG) : 0;
 
            if (slots_debug) {
-                SRV_WRN("slots debug = %d\n", slots_debug);
+                SRV_WRN("LLAMA_SERVER_SLOTS_DEBUG = %d\n", slots_debug);
            }
        }
@@ -2974,13 +2984,15 @@ struct server_context_impl {
                    auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx, slot.spec_i_batch, slot.spec_draft);
                    slot.spec_i_batch.clear();
 
-                    SLT_DBG(slot, "%s: n_draft=%zu, accepted=%zu\n", __func__, slot.spec_draft.size(), accepted.size());
-
                    GGML_ASSERT(accepted.size() >= 1);
 
                    // check for partial draft acceptance
                    if (accepted.size() < slot.spec_draft.size() + 1) {
                        if (use_ckpt) {
+                            if (trace > 0) {
+                                SLT_INF(slot, "accepted %2zu/%2zu draft tokens (restore checkpoint)\n", accepted.size() - 1, slot.spec_draft.size());
+                            }
+
                            // partial acceptance is not supported by the context -> truncate the draft and restore the state
                            slot.spec_draft = std::move(accepted);
@@ -3002,8 +3014,10 @@ struct server_context_impl {
                            continue;
                        }
+                    }
 
-                    LOG_DBG("%s: partial acceptance: %zu < %zu\n", __func__, accepted.size(), slot.spec_draft.size());
+                    if (trace > 0) {
+                        SLT_INF(slot, "accepted %2zu/%2zu draft tokens\n", accepted.size() - 1, n_draft);
                    }
 
                    common_speculative_accept(slot.spec.get(), accepted.size() - 1);
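
Note on the checkpoint flow this patch converges on: the draft context is snapshotted exactly once, right after the prompt batch is decoded (create_checkpoint() at the end of the llama_decode block), and restored verbatim whenever only part of a draft is accepted (restore_checkpoint()). Since the snapshot can never contain drafted tokens, its full size is simply ckpt.data.size(), which is what makes the old ckpt_size/needs_ckpt bookkeeping removable. The sketch below shows that save/restore cycle in isolation; the struct and the two helper names are invented for illustration, while the llama_* calls and the LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY flag are the ones the patch itself uses.

    // illustrative sketch, not part of the patch
    #include "llama.h"

    #include <cstdint>
    #include <vector>

    struct draft_checkpoint {
        llama_pos pos_min  = 0;
        llama_pos pos_max  = 0;
        int64_t   n_tokens = 0;    // length of the prompt covered by the snapshot

        std::vector<uint8_t> data; // serialized sequence state
    };

    // mirrors create_checkpoint(): call right after decoding the prompt batch
    static bool draft_ckpt_save(llama_context * ctx, draft_checkpoint & ckpt, int64_t n_tokens_prompt) {
        const llama_seq_id seq = 0;

        const size_t size = llama_state_seq_get_size_ext(ctx, seq, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

        ckpt.pos_min  = llama_memory_seq_pos_min(llama_get_memory(ctx), seq);
        ckpt.pos_max  = llama_memory_seq_pos_max(llama_get_memory(ctx), seq);
        ckpt.n_tokens = n_tokens_prompt;

        ckpt.data.resize(size);

        return llama_state_seq_get_data_ext(ctx, ckpt.data.data(), size, seq, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == size;
    }

    // mirrors restore_checkpoint(): call before re-drafting after a partial acceptance
    static bool draft_ckpt_load(llama_context * ctx, const draft_checkpoint & ckpt) {
        const llama_seq_id seq = 0;

        const size_t n = llama_state_seq_set_data_ext(ctx, ckpt.data.data(), ckpt.data.size(), seq, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
        if (n != ckpt.data.size()) {
            return false; // the patch aborts here instead (GGML_ABORT)
        }

        // drop any cells past the checkpointed range, as restore_checkpoint() does
        llama_memory_seq_rm(llama_get_memory(ctx), seq, ckpt.pos_max + 1, -1);

        return true;
    }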