Skip to content

Commit b14e3fb

Browse files
spec: support eagle3 for qwen3.5 & 3.6 (#24593)
* spec: support qwen3.5 & 3.6 eagle3 draft * eagle3: Add deferred boundary checkpoints restore support for hybrid models * apply suggestions Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * spec: adapt to API change * spec: fix naming * cont : add TODO --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
1 parent 159d093 commit b14e3fb

7 files changed

Lines changed: 92 additions & 2 deletions

File tree

common/common.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2034,7 +2034,7 @@ bool common_prompt_batch_decode(
20342034
}
20352035

20362036
size_t common_prompt_checkpoint::size() const {
2037-
return data_tgt.size() + data_dft.size();
2037+
return data_tgt.size() + data_dft.size() + data_spec.size();
20382038
}
20392039

20402040
bool common_prompt_checkpoint::empty() const {
@@ -2049,6 +2049,7 @@ void common_prompt_checkpoint::clear() {
20492049

20502050
data_tgt.clear();
20512051
data_dft.clear();
2052+
data_spec.clear();
20522053
}
20532054

20542055
void common_prompt_checkpoint::update_pos(
@@ -2138,4 +2139,5 @@ void common_prompt_checkpoint::clear_tgt() {
21382139

21392140
void common_prompt_checkpoint::clear_dft() {
21402141
data_dft.clear();
2142+
data_spec.clear();
21412143
}

common/common.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ struct common_params_speculative {
363363

364364
uint32_t need_n_rs_seq() const {
365365
bool needs_rs_seq = std::any_of(types.begin(), types.end(), [&](auto t) {
366-
return t == COMMON_SPECULATIVE_TYPE_DRAFT_MTP;
366+
return t == COMMON_SPECULATIVE_TYPE_DRAFT_MTP || t == COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3;
367367
});
368368

369369
return needs_rs_seq ? draft.n_max : 0u;
@@ -1065,6 +1065,10 @@ struct common_prompt_checkpoint {
10651065
std::vector<uint8_t> data_tgt;
10661066
std::vector<uint8_t> data_dft;
10671067

1068+
// (optional) speculative-decoding implementation state stashed with the checkpoint
1069+
// (e.g. eagle3's deferred-boundary g_embd row)
1070+
std::vector<uint8_t> data_spec;
1071+
10681072
size_t size() const;
10691073

10701074
bool empty() const;

common/speculative.cpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,10 @@ struct common_speculative_impl {
161161

162162
virtual void accept(llama_seq_id seq_id, uint16_t n_accepted, bool is_other) = 0;
163163

164+
// (optional) serialize/restore per-seq internal state (e.g. eagle3's deferred boundary).
165+
virtual bool get_state(llama_seq_id /*seq_id*/, std::vector<uint8_t> & /*data*/) const { return false; }
166+
virtual void set_state(llama_seq_id /*seq_id*/, const std::vector<uint8_t> & /*data*/) {}
167+
164168
// true if this implementation requires the target context to extract post-norm embeddings
165169
virtual bool need_embd() const = 0;
166170

@@ -841,6 +845,49 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
841845
(size_t) n_embd_dec * sizeof(float));
842846
}
843847

848+
// we only need to stash the deferred boundary's g_embd row for recurrent/hybrid targets:
849+
// their single-position checkpoints drop it on restore
850+
bool need_boundary_stash() const {
851+
const llama_model * model_tgt = llama_get_model(params.ctx_tgt);
852+
return llama_model_is_recurrent(model_tgt) || llama_model_is_hybrid(model_tgt);
853+
}
854+
855+
bool get_state(llama_seq_id seq_id, std::vector<uint8_t> & data) const override {
856+
if (!need_boundary_stash()) {
857+
return false;
858+
}
859+
if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq || pending_pos_last[seq_id] < 0) {
860+
return false;
861+
}
862+
863+
const llama_pos pos = pending_pos_last[seq_id];
864+
const std::vector<float> & g = pending_g_last[seq_id];
865+
866+
data.resize(sizeof(llama_pos) + g.size() * sizeof(float));
867+
std::memcpy(data.data(), &pos, sizeof(llama_pos));
868+
std::memcpy(data.data() + sizeof(llama_pos), g.data(), g.size() * sizeof(float));
869+
return true;
870+
}
871+
872+
void set_state(llama_seq_id seq_id, const std::vector<uint8_t> & data) override {
873+
if (!need_boundary_stash()) {
874+
return;
875+
}
876+
if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq) {
877+
return;
878+
}
879+
if (data.size() != sizeof(llama_pos) + (size_t) n_embd_dec * sizeof(float)) {
880+
return;
881+
}
882+
883+
llama_pos pos = -1;
884+
std::memcpy(&pos, data.data(), sizeof(llama_pos));
885+
886+
pending_pos_last[seq_id] = pos;
887+
pending_g_last[seq_id].resize(n_embd_dec);
888+
std::memcpy(pending_g_last[seq_id].data(), data.data() + sizeof(llama_pos), (size_t) n_embd_dec * sizeof(float));
889+
}
890+
844891
bool need_embd() const override {
845892
return false;
846893
}
@@ -2118,6 +2165,31 @@ void common_speculative_accept(common_speculative * spec, llama_seq_id seq_id, u
21182165
}
21192166
}
21202167

2168+
// TODO: support the case of more than one speculative implementations having a state
2169+
bool common_speculative_get_state(common_speculative * spec, llama_seq_id seq_id, std::vector<uint8_t> & data) {
2170+
if (spec == nullptr) {
2171+
return false;
2172+
}
2173+
2174+
for (auto & impl : spec->impls) {
2175+
if (impl->get_state(seq_id, data)) {
2176+
return true;
2177+
}
2178+
}
2179+
2180+
return false;
2181+
}
2182+
2183+
void common_speculative_set_state(common_speculative * spec, llama_seq_id seq_id, const std::vector<uint8_t> & data) {
2184+
if (spec == nullptr) {
2185+
return;
2186+
}
2187+
2188+
for (auto & impl : spec->impls) {
2189+
impl->set_state(seq_id, data);
2190+
}
2191+
}
2192+
21212193
void common_speculative_print_stats(const common_speculative * spec) {
21222194
if (spec == nullptr) {
21232195
return;

common/speculative.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,10 @@ void common_speculative_draft(common_speculative * spec);
6868
// informs the speculative context that n_accepted tokens were accepted by the target model
6969
void common_speculative_accept(common_speculative * spec, llama_seq_id, uint16_t n_accepted);
7070

71+
// (optional) get/set internal state
72+
bool common_speculative_get_state(common_speculative * spec, llama_seq_id seq_id, std::vector<uint8_t> & data);
73+
void common_speculative_set_state(common_speculative * spec, llama_seq_id seq_id, const std::vector<uint8_t> & data);
74+
7175
// print statistics about the speculative decoding
7276
void common_speculative_print_stats(const common_speculative * spec);
7377

src/models/qwen35.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,8 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para
156156

157157
// MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass.
158158
for (int il = 0; il < n_layer; ++il) {
159+
res->t_layer_inp[il] = inpL;
160+
159161
ggml_tensor * inpSA = inpL;
160162

161163
cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);

src/models/qwen35moe.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,8 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p
179179

180180
// MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass.
181181
for (int il = 0; il < n_layer; ++il) {
182+
res->t_layer_inp[il] = inpL;
183+
182184
ggml_tensor * inpSA = inpL;
183185

184186
cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);

tools/server/server-context.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2172,6 +2172,8 @@ struct server_context_impl {
21722172

21732173
cur.update_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
21742174
cur.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
2175+
// stash the draft's speculative state with the checkpoint
2176+
common_speculative_get_state(spec.get(), slot.id, cur.data_spec);
21752177

21762178
SLT_INF(slot,
21772179
"created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
@@ -2998,6 +3000,8 @@ struct server_context_impl {
29983000
// restore the context checkpoint
29993001
it->load_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
30003002
it->load_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
3003+
// restore the draft's speculative state
3004+
common_speculative_set_state(spec.get(), slot.id, it->data_spec);
30013005

30023006
pos_next = std::min(pos_next, std::max(it->pos_min + 1, it->pos_max));
30033007
n_past = std::min(slot.prompt.tokens.size_up_to_pos(pos_next), (size_t) it->n_tokens);

0 commit comments

Comments
 (0)