diff --git a/common/common.cpp b/common/common.cpp index b01772e1cbfe..f3f114f68245 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -2034,7 +2034,7 @@ bool common_prompt_batch_decode( } size_t common_prompt_checkpoint::size() const { - return data_tgt.size() + data_dft.size(); + return data_tgt.size() + data_dft.size() + data_spec.size(); } bool common_prompt_checkpoint::empty() const { @@ -2049,6 +2049,7 @@ void common_prompt_checkpoint::clear() { data_tgt.clear(); data_dft.clear(); + data_spec.clear(); } void common_prompt_checkpoint::update_pos( @@ -2138,4 +2139,5 @@ void common_prompt_checkpoint::clear_tgt() { void common_prompt_checkpoint::clear_dft() { data_dft.clear(); + data_spec.clear(); } diff --git a/common/common.h b/common/common.h index 040b9cf23312..535a4ed335ad 100644 --- a/common/common.h +++ b/common/common.h @@ -363,7 +363,7 @@ struct common_params_speculative { uint32_t need_n_rs_seq() const { bool needs_rs_seq = std::any_of(types.begin(), types.end(), [&](auto t) { - return t == COMMON_SPECULATIVE_TYPE_DRAFT_MTP; + return t == COMMON_SPECULATIVE_TYPE_DRAFT_MTP || t == COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3; }); return needs_rs_seq ? draft.n_max : 0u; @@ -1065,6 +1065,10 @@ struct common_prompt_checkpoint { std::vector data_tgt; std::vector data_dft; + // (optional) speculative-decoding implementation state stashed with the checkpoint + // (e.g. eagle3's deferred-boundary g_embd row) + std::vector data_spec; + size_t size() const; bool empty() const; diff --git a/common/speculative.cpp b/common/speculative.cpp index 6f387f2cfc13..9c20585dc3e3 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -161,6 +161,10 @@ struct common_speculative_impl { virtual void accept(llama_seq_id seq_id, uint16_t n_accepted, bool is_other) = 0; + // (optional) serialize/restore per-seq internal state (e.g. eagle3's deferred boundary). + virtual bool get_state(llama_seq_id /*seq_id*/, std::vector & /*data*/) const { return false; } + virtual void set_state(llama_seq_id /*seq_id*/, const std::vector & /*data*/) {} + // true if this implementation requires the target context to extract post-norm embeddings virtual bool need_embd() const = 0; @@ -841,6 +845,49 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl { (size_t) n_embd_dec * sizeof(float)); } + // we only need to stash the deferred boundary's g_embd row for recurrent/hybrid targets: + // their single-position checkpoints drop it on restore + bool need_boundary_stash() const { + const llama_model * model_tgt = llama_get_model(params.ctx_tgt); + return llama_model_is_recurrent(model_tgt) || llama_model_is_hybrid(model_tgt); + } + + bool get_state(llama_seq_id seq_id, std::vector & data) const override { + if (!need_boundary_stash()) { + return false; + } + if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq || pending_pos_last[seq_id] < 0) { + return false; + } + + const llama_pos pos = pending_pos_last[seq_id]; + const std::vector & g = pending_g_last[seq_id]; + + data.resize(sizeof(llama_pos) + g.size() * sizeof(float)); + std::memcpy(data.data(), &pos, sizeof(llama_pos)); + std::memcpy(data.data() + sizeof(llama_pos), g.data(), g.size() * sizeof(float)); + return true; + } + + void set_state(llama_seq_id seq_id, const std::vector & data) override { + if (!need_boundary_stash()) { + return; + } + if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq) { + return; + } + if (data.size() != sizeof(llama_pos) + (size_t) n_embd_dec * sizeof(float)) { + return; + } + + llama_pos pos = -1; + std::memcpy(&pos, data.data(), sizeof(llama_pos)); + + pending_pos_last[seq_id] = pos; + pending_g_last[seq_id].resize(n_embd_dec); + std::memcpy(pending_g_last[seq_id].data(), data.data() + sizeof(llama_pos), (size_t) n_embd_dec * sizeof(float)); + } + bool need_embd() const override { return false; } @@ -2118,6 +2165,31 @@ void common_speculative_accept(common_speculative * spec, llama_seq_id seq_id, u } } +// TODO: support the case of more than one speculative implementations having a state +bool common_speculative_get_state(common_speculative * spec, llama_seq_id seq_id, std::vector & data) { + if (spec == nullptr) { + return false; + } + + for (auto & impl : spec->impls) { + if (impl->get_state(seq_id, data)) { + return true; + } + } + + return false; +} + +void common_speculative_set_state(common_speculative * spec, llama_seq_id seq_id, const std::vector & data) { + if (spec == nullptr) { + return; + } + + for (auto & impl : spec->impls) { + impl->set_state(seq_id, data); + } +} + void common_speculative_print_stats(const common_speculative * spec) { if (spec == nullptr) { return; diff --git a/common/speculative.h b/common/speculative.h index bf76ad709e26..c58fac3cc6d0 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -68,6 +68,10 @@ void common_speculative_draft(common_speculative * spec); // informs the speculative context that n_accepted tokens were accepted by the target model void common_speculative_accept(common_speculative * spec, llama_seq_id, uint16_t n_accepted); +// (optional) get/set internal state +bool common_speculative_get_state(common_speculative * spec, llama_seq_id seq_id, std::vector & data); +void common_speculative_set_state(common_speculative * spec, llama_seq_id seq_id, const std::vector & data); + // print statistics about the speculative decoding void common_speculative_print_stats(const common_speculative * spec); diff --git a/src/models/qwen35.cpp b/src/models/qwen35.cpp index 6783d98ec204..d8ffe43ae76c 100644 --- a/src/models/qwen35.cpp +++ b/src/models/qwen35.cpp @@ -156,6 +156,8 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para // MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass. for (int il = 0; il < n_layer; ++il) { + res->t_layer_inp[il] = inpL; + ggml_tensor * inpSA = inpL; cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); diff --git a/src/models/qwen35moe.cpp b/src/models/qwen35moe.cpp index eb5e9a406a15..7b0876cbb04b 100644 --- a/src/models/qwen35moe.cpp +++ b/src/models/qwen35moe.cpp @@ -179,6 +179,8 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p // MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass. for (int il = 0; il < n_layer; ++il) { + res->t_layer_inp[il] = inpL; + ggml_tensor * inpSA = inpL; cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 31280d63c4ea..59048889e09f 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -2157,6 +2157,8 @@ struct server_context_impl { cur.update_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); cur.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + // stash the draft's speculative state with the checkpoint + common_speculative_get_state(spec.get(), slot.id, cur.data_spec); SLT_INF(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", @@ -2981,6 +2983,8 @@ struct server_context_impl { // restore the context checkpoint it->load_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); it->load_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + // restore the draft's speculative state + common_speculative_set_state(spec.get(), slot.id, it->data_spec); pos_next = std::min(pos_next, std::max(it->pos_min + 1, it->pos_max)); n_past = std::min(slot.prompt.tokens.size_up_to_pos(pos_next), (size_t) it->n_tokens);