Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2034,7 +2034,7 @@ bool common_prompt_batch_decode(
}

// Total number of bytes held by this checkpoint: the target buffer, the draft buffer
// and the (optional) speculative-state buffer.
size_t common_prompt_checkpoint::size() const {
    // a single return that accounts for all three buffers - an earlier return that only
    // summed data_tgt and data_dft would make this one unreachable and under-report the size
    return data_tgt.size() + data_dft.size() + data_spec.size();
}

bool common_prompt_checkpoint::empty() const {
Expand All @@ -2049,6 +2049,7 @@ void common_prompt_checkpoint::clear() {

data_tgt.clear();
data_dft.clear();
data_spec.clear();
}

void common_prompt_checkpoint::update_pos(
Expand Down Expand Up @@ -2138,4 +2139,5 @@ void common_prompt_checkpoint::clear_tgt() {

// Drop the draft-side data of the checkpoint.
void common_prompt_checkpoint::clear_dft() {
    // the speculative state is released together with the draft buffer
    data_spec.clear();
    data_dft.clear();
}
6 changes: 5 additions & 1 deletion common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ struct common_params_speculative {

uint32_t need_n_rs_seq() const {
bool needs_rs_seq = std::any_of(types.begin(), types.end(), [&](auto t) {
return t == COMMON_SPECULATIVE_TYPE_DRAFT_MTP;
return t == COMMON_SPECULATIVE_TYPE_DRAFT_MTP || t == COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3;
});

return needs_rs_seq ? draft.n_max : 0u;
Expand Down Expand Up @@ -1065,6 +1065,10 @@ struct common_prompt_checkpoint {
std::vector<uint8_t> data_tgt;
std::vector<uint8_t> data_dft;

// (optional) speculative-decoding implementation state stashed with the checkpoint
// (e.g. eagle3's deferred-boundary g_embd row)
std::vector<uint8_t> data_spec;

size_t size() const;

bool empty() const;
Expand Down
72 changes: 72 additions & 0 deletions common/speculative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,10 @@ struct common_speculative_impl {

virtual void accept(llama_seq_id seq_id, uint16_t n_accepted, bool is_other) = 0;

// (optional) serialize/restore per-seq internal state (e.g. eagle3's deferred boundary).
// default: the implementation has no state to serialize - `data` is left untouched
// returns true only when an override actually wrote something into `data`
virtual bool get_state(llama_seq_id /*seq_id*/, std::vector<uint8_t> & /*data*/) const {
    return false;
}
// default: the implementation keeps no restorable state, so the blob is ignored
virtual void set_state(llama_seq_id /*seq_id*/, const std::vector<uint8_t> & /*data*/) {
}

// true if this implementation requires the target context to extract post-norm embeddings
virtual bool need_embd() const = 0;

Expand Down Expand Up @@ -841,6 +845,49 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
(size_t) n_embd_dec * sizeof(float));
}

// we only need to stash the deferred boundary's g_embd row for recurrent/hybrid targets:
// their single-position checkpoints drop it on restore
bool need_boundary_stash() const {
const llama_model * model_tgt = llama_get_model(params.ctx_tgt);
return llama_model_is_recurrent(model_tgt) || llama_model_is_hybrid(model_tgt);
}

// serialize the pending boundary of `seq_id` into `data`
// layout: [llama_pos position][float row of the pending g_embd]
// returns false (leaving `data` untouched) when stashing is not needed, the sequence id
// is out of range, or the sequence has no pending boundary (position < 0)
bool get_state(llama_seq_id seq_id, std::vector<uint8_t> & data) const override {
    if (!need_boundary_stash()) {
        return false;
    }

    // the range check is evaluated first so the per-seq arrays are never indexed out of bounds
    const bool seq_ok = seq_id >= 0 && seq_id < (llama_seq_id) n_seq;
    if (!seq_ok || pending_pos_last[seq_id] < 0) {
        return false;
    }

    const llama_pos boundary_pos = pending_pos_last[seq_id];
    const std::vector<float> & row = pending_g_last[seq_id];

    const size_t n_hdr = sizeof(llama_pos);
    const size_t n_row = row.size() * sizeof(float);

    data.resize(n_hdr + n_row);

    uint8_t * dst = data.data();
    std::memcpy(dst,         &boundary_pos, n_hdr);
    std::memcpy(dst + n_hdr, row.data(),    n_row);

    return true;
}

// restore the pending boundary of `seq_id` from a blob produced by get_state()
// expected layout: [llama_pos position][n_embd_dec floats]
// the call is a no-op when stashing is not needed, the sequence id is out of range,
// or the blob does not have exactly the expected size
void set_state(llama_seq_id seq_id, const std::vector<uint8_t> & data) override {
    if (!need_boundary_stash()) {
        return;
    }

    const size_t n_hdr = sizeof(llama_pos);
    const size_t n_row = (size_t) n_embd_dec * sizeof(float);

    const bool seq_ok = seq_id >= 0 && seq_id < (llama_seq_id) n_seq;
    if (!seq_ok || data.size() != n_hdr + n_row) {
        return;
    }

    const uint8_t * src = data.data();

    // header: position of the pending boundary
    llama_pos boundary_pos = -1;
    std::memcpy(&boundary_pos, src, n_hdr);
    pending_pos_last[seq_id] = boundary_pos;

    // payload: the g_embd row that belongs to that position
    auto & row = pending_g_last[seq_id];
    row.resize(n_embd_dec);
    std::memcpy(row.data(), src + n_hdr, n_row);
}

// this implementation does not ask for embeddings via need_embd()
bool need_embd() const override { return false; }
Expand Down Expand Up @@ -2118,6 +2165,31 @@ void common_speculative_accept(common_speculative * spec, llama_seq_id seq_id, u
}
}

// TODO: support the case of more than one speculative implementations having a state
//
// asks the implementations, in order, to serialize their state for `seq_id` into `data`
// and stops at the first one that reports success
// returns false when `spec` is null or no implementation produced a state
bool common_speculative_get_state(common_speculative * spec, llama_seq_id seq_id, std::vector<uint8_t> & data) {
    if (!spec) {
        return false;
    }

    bool found = false;

    for (auto it = spec->impls.begin(); !found && it != spec->impls.end(); ++it) {
        found = (*it)->get_state(seq_id, data);
    }

    return found;
}

// hands the previously saved state blob for `seq_id` to every speculative implementation
// does nothing when `spec` is null
void common_speculative_set_state(common_speculative * spec, llama_seq_id seq_id, const std::vector<uint8_t> & data) {
    if (spec != nullptr) {
        for (auto it = spec->impls.begin(); it != spec->impls.end(); ++it) {
            (*it)->set_state(seq_id, data);
        }
    }
}

Comment thread
ggerganov marked this conversation as resolved.
void common_speculative_print_stats(const common_speculative * spec) {
if (spec == nullptr) {
return;
Expand Down
4 changes: 4 additions & 0 deletions common/speculative.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ void common_speculative_draft(common_speculative * spec);
// informs the speculative context that n_accepted tokens were accepted by the target model
void common_speculative_accept(common_speculative * spec, llama_seq_id, uint16_t n_accepted);

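// (optional) get/set the per-sequence internal state of the speculative implementation
// (get returns false when there is no state to save)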
bool common_speculative_get_state(common_speculative * spec, llama_seq_id seq_id, std::vector<uint8_t> & data);
void common_speculative_set_state(common_speculative * spec, llama_seq_id seq_id, const std::vector<uint8_t> & data);

// print statistics about the speculative decoding
void common_speculative_print_stats(const common_speculative * spec);

Expand Down
2 changes: 2 additions & 0 deletions src/models/qwen35.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para

// MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass.
for (int il = 0; il < n_layer; ++il) {
res->t_layer_inp[il] = inpL;

ggml_tensor * inpSA = inpL;

cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
Expand Down
2 changes: 2 additions & 0 deletions src/models/qwen35moe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,8 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p

// MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass.
for (int il = 0; il < n_layer; ++il) {
res->t_layer_inp[il] = inpL;

ggml_tensor * inpSA = inpL;

cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
Expand Down
4 changes: 4 additions & 0 deletions tools/server/server-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2157,6 +2157,8 @@ struct server_context_impl {

cur.update_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
cur.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
// stash the draft's speculative state with the checkpoint
common_speculative_get_state(spec.get(), slot.id, cur.data_spec);

SLT_INF(slot,
"created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
Expand Down Expand Up @@ -2981,6 +2983,8 @@ struct server_context_impl {
// restore the context checkpoint
it->load_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
it->load_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
// restore the draft's speculative state
common_speculative_set_state(spec.get(), slot.id, it->data_spec);

pos_next = std::min(pos_next, std::max(it->pos_min + 1, it->pos_max));
n_past = std::min(slot.prompt.tokens.size_up_to_pos(pos_next), (size_t) it->n_tokens);
Expand Down