Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions common/speculative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,12 @@ struct common_speculative_state {

// Optional hook: drain any in-flight async work (prepare_next) and discard.
// The base implementation is intentionally empty; only states that actually keep
// in-flight work need to override it.
virtual void cancel() {}

// Phase C.2.1 — cold-restart hook (foundational, no behavior change here).
// Stronger than cancel(): meant to clear all per-iteration state accumulated during a generation.
// The base implementation simply forwards to cancel(), so a state without extra
// per-iteration bookkeeping needs no override. MTP overrides it to also reset h_idx,
// the adaptive-skip counters and the cached spec params.
virtual void reset() { cancel(); }
};

struct common_speculative_state_draft : public common_speculative_state {
Expand Down Expand Up @@ -845,6 +851,26 @@ struct common_speculative_state_mtp : public common_speculative_state {
mtp_drain_pending_discard();
}

// Phase C.2.1 — cold-restart MTP state at a known boundary (e.g. image-encoding → text continuation).
// Drains any in-flight draft (like cancel) AND resets h_idx + adaptive-skip counters +
// cached spec params. Intended post-condition: the next begin()/draft() pair behaves as if
// MTP was just constructed. KV memory and embeddings setting on the target are untouched —
// the host owns those.
// NOTE(review): step 1 repeats the cancel logic inline rather than calling cancel();
// presumably deliberate — confirm the two are kept in sync if cancel() ever changes.
void reset() override {
// 1. clear the one-shot skip flag, then drain the in-flight async draft (cancel semantics).
//    The flag is cleared before the drain; keep this order unless the drain helper is
//    verified not to read it.
skip_streak_last_draft = false;
mtp_drain_pending_discard();

// 2. put the per-iteration h_prev index back to its default (-1) and zero the
//    adaptive-skip tracking counters
h_idx = -1;
prev_n_acc_drafts = 0;
zero_accept_streak = 0;

// 3. value-initialize the cached spec params from the prior draft() call so the next draft
//    re-computes n_steps from scratch when the host passes fresh params.
last_spec_params = common_params_speculative{};
}

void prepare_next(llama_token id_last) override {
// Kill switch for A/B testing depth-2 vs sync.
static const bool depth2_disabled = []() {
Expand Down Expand Up @@ -1569,6 +1595,15 @@ void common_speculative_cancel(common_speculative * spec) {
}
}

// Cold-restarts every speculative implementation owned by `spec` by invoking its reset() hook.
// A null `spec` is accepted and ignored.
void common_speculative_reset(common_speculative * spec) {
    if (spec != nullptr) {
        for (auto & state : spec->impls) {
            state->reset();
        }
    }
}

void common_speculative_print_stats(const common_speculative * spec) {
if (spec == nullptr) {
return;
Expand Down
17 changes: 17 additions & 0 deletions common/speculative.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,5 +58,22 @@ void common_speculative_prepare_next(common_speculative * spec, llama_token id_l
// snapshot (e.g. slot stop / release / new request seq_rm). Safe no-op when nothing is pending.
void common_speculative_cancel(common_speculative * spec);

// Phase C.2.1 — Cold-restart the speculative state machine (foundational API, no behavior change here).
//
// Stronger than cancel(): in addition to draining any in-flight draft, this clears all
// per-iteration state accumulated during a generation — h_idx is reset to its default
// (-1 = "last output"), draft-history counters used by adaptive skip (prev_n_acc_drafts,
// zero_accept_streak, skip_streak_last_draft) are zeroed, and any cached spec params from
// the previous draft() call are forgotten. After reset(), the implementation behaves as
// if begin() had just been called on a fresh prompt.
//
// Intended use: at known state-boundaries that are NOT prompt boundaries but DO invalidate
// the assistant's hidden-state assumptions — e.g. when a slot transitions from image-encoding
// (where MTP was gated off) back to text continuation (where MTP should re-engage from a clean
// slate). The next few text tokens incur the usual warmup cost but state desync is avoided.
//
// Safe no-op when spec is null. Implementations that do not override reset() fall back to cancel().
void common_speculative_reset(common_speculative * spec);

// print statistics about the speculative decoding
void common_speculative_print_stats(const common_speculative * spec);
4 changes: 4 additions & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,10 @@ if (NOT WIN32 OR NOT BUILD_SHARED_LIBS)
# these tests are disabled on Windows because they use internal functions not exported with LLAMA_API (when building with shared libraries)
llama_build_and_test(test-sampling.cpp)
llama_build_and_test(test-speculative-mtp.cpp)
# Phase C.2.0 — server_tokens coexistence APIs unit tests
llama_build_and_test(test-server-tokens.cpp)
target_include_directories(test-server-tokens PRIVATE ${PROJECT_SOURCE_DIR}/tools/server ${PROJECT_SOURCE_DIR}/tools/mtmd)
target_link_libraries(test-server-tokens PRIVATE server-context)
llama_build_and_test(test-reasoning-budget.cpp)
llama_build_and_test(test-grammar-parser.cpp)
llama_build_and_test(test-grammar-integration.cpp)
Expand Down
121 changes: 121 additions & 0 deletions tests/test-server-tokens.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
// Phase C.2.0 — unit tests for server_tokens coexistence APIs introduced for MTP+mmproj dispatch.
//
// Scope:
// - is_pure_text_continuation(from_idx)
// - last_image_end_idx()
// - get_text_tokens_post_media()
//
// These APIs are foundational and do not change runtime behavior; they expose information that
// future per-batch dispatch will use. This file covers what is testable WITHOUT loading a model
// or running the full mtmd pipeline:
// - non-multimodal (has_mtmd=false) buffers
// - empty multimodal (has_mtmd=true, no chunks) buffers
// - empty buffer (size==0) edge cases
//
// The WITH-image cases require a real mtmd_input_chunk (which goes through the mtmd public API
// requiring an image file + clip model). Those are covered by integration tests in C.2.4 once
// the dispatch behavior is wired up.

#include "server-common.h"

#include <cstdio>
#include <cstdlib>

// Minimal assertion helper for this test binary.
// If `cond` evaluates to false, prints "FAIL <file>:<line> <msg> (cond: <stringified cond>)"
// to stderr and terminates the process with exit status 1. Because the condition text is
// stringified into the message, rewording a condition also changes the failure output.
// Wrapped in do { ... } while (0) so a use followed by ';' forms a single statement.
#define CHECK(cond, msg) \
do { \
if (!(cond)) { \
std::fprintf(stderr, "FAIL %s:%d %s (cond: %s)\n", __FILE__, __LINE__, msg, #cond); \
std::exit(1); \
} \
} while (0)

// Empty text-only buffer (has_mtmd=false, no tokens): the media-aware queries are expected
// to report "no media" and an empty post-media tail.
static void test_non_mtmd_empty_buffer() {
llama_tokens t;
server_tokens st(t, /*has_mtmd*/ false);

CHECK(st.size() == 0, "empty size");
CHECK(st.empty(), "empty()");
// with no media chunk, the end index of the "last image" is expected to be 0
CHECK(st.last_image_end_idx() == 0, "last_image_end_idx empty");
// any start index, including one beyond the (empty) buffer, counts as pure text
CHECK(st.is_pure_text_continuation(0), "pure-text @ 0 (empty)");
CHECK(st.is_pure_text_continuation(100), "pure-text @ 100 (empty/past-end)");

llama_tokens out = st.get_text_tokens_post_media();
CHECK(out.empty(), "post-media tail empty for empty buffer");
}

// Text-only buffer with five tokens (has_mtmd=false): every start index is a pure-text
// continuation and the post-media tail is a separate copy of the whole buffer.
static void test_non_mtmd_text_only() {
llama_tokens t = {1, 2, 3, 4, 5};
server_tokens st(t, /*has_mtmd*/ false);

CHECK(st.size() == 5, "size==5");
CHECK(!st.empty(), "!empty");
CHECK(st.last_image_end_idx() == 0, "last_image_end_idx text-only -> 0");

// is_pure_text_continuation always true when !has_mtmd
// (probed inside the buffer, exactly at its end, and well past it)
CHECK(st.is_pure_text_continuation(0), "pure-text @ 0");
CHECK(st.is_pure_text_continuation(3), "pure-text @ 3");
CHECK(st.is_pure_text_continuation(5), "pure-text @ 5 (at end)");
CHECK(st.is_pure_text_continuation(999), "pure-text @ 999 (past end)");

// For non-mtmd, get_text_tokens_post_media returns all tokens (no NULL stripped because none present).
llama_tokens out = st.get_text_tokens_post_media();
CHECK(out.size() == 5, "post-media tail size matches buffer");
// the size check above exits on mismatch, so indexing t with i is in range here
for (size_t i = 0; i < out.size(); ++i) {
CHECK(out[i] == t[i], "post-media tail token matches");
}

// get_text_tokens() must still return the canonical reference for non-mtmd path.
// Differing data() pointers show the tail does not alias the buffer returned by get_text_tokens().
const llama_tokens & ref = st.get_text_tokens();
CHECK(ref.size() == 5, "get_text_tokens() size");
CHECK(ref.data() != out.data(), "post-media tail is a distinct copy");
}

// mtmd-enabled buffer that holds no media chunk: expected to behave like a text-only buffer.
static void test_mtmd_empty_chunks() {
// server_tokens with has_mtmd=true but no media chunks added: same observable behavior as non-mtmd
// for the new APIs (per-API contract: empty map → return as text-only).
// We construct via the llama_tokens ctor with has_mtmd=false and then flip the flag through
// the has_mtmd member, which is assigned directly below (so it must be publicly writable).
llama_tokens t = {10, 20, 30};
server_tokens st(t, /*has_mtmd*/ false);
st.has_mtmd = true; // simulate mtmd-enabled buffer with no chunks yet

CHECK(st.last_image_end_idx() == 0, "mtmd+empty-map: last_image_end_idx==0");
CHECK(st.is_pure_text_continuation(0), "mtmd+empty-map: pure @ 0");
CHECK(st.is_pure_text_continuation(3), "mtmd+empty-map: pure @ 3");
CHECK(st.is_pure_text_continuation(999), "mtmd+empty-map: pure @ past-end");

// with no media recorded, the tail is expected to be the full token sequence
llama_tokens out = st.get_text_tokens_post_media();
CHECK(out.size() == 3, "mtmd+empty-map: tail returns all text");
CHECK(out[0] == 10 && out[1] == 20 && out[2] == 30, "mtmd+empty-map: tail content matches");
}

// Boundary sweep of is_pure_text_continuation() on a three-token text-only buffer.
static void test_pure_text_continuation_semantics() {
// The contract: is_pure_text_continuation(from_idx) returns true iff there is NO image chunk
// extending past from_idx. We can verify the non-mtmd / empty-mtmd branches here (the
// with-image branch is exercised by integration tests once mtmd is wired up).
llama_tokens t = {7, 8, 9};
server_tokens st(t, false);

// start indices below, at, and beyond size(), plus the largest representable index
CHECK(st.is_pure_text_continuation(0), "from_idx<size: true");
CHECK(st.is_pure_text_continuation(2), "from_idx<size: true");
CHECK(st.is_pure_text_continuation(3), "from_idx==size: true");
CHECK(st.is_pure_text_continuation(4), "from_idx>size: true (past end)");
// NOTE(review): SIZE_MAX is declared in <cstdint>, which this file does not include directly;
// presumably it arrives transitively via server-common.h — consider including it explicitly.
CHECK(st.is_pure_text_continuation(SIZE_MAX), "from_idx=SIZE_MAX: true (past end)");
}

// Runs every test group in order. A failing CHECK terminates the process with status 1,
// so reaching the final line means all groups passed.
int main() {
    // run one group, then emit its confirmation line on stdout
    const auto run_group = [](void (*group)(), const char * ok_line) {
        group();
        std::fputs(ok_line, stdout);
    };

    run_group(test_non_mtmd_empty_buffer,
              "[server_tokens] non_mtmd_empty_buffer OK\n");
    run_group(test_non_mtmd_text_only,
              "[server_tokens] non_mtmd_text_only OK\n");
    run_group(test_mtmd_empty_chunks,
              "[server_tokens] mtmd_empty_chunks OK\n");
    run_group(test_pure_text_continuation_semantics,
              "[server_tokens] pure_text_continuation_semantics OK\n");

    std::fputs("ALL PASS — 4 test groups, server_tokens C.2.0 foundational API\n", stdout);
    return 0;
}
8 changes: 8 additions & 0 deletions tests/test-speculative-mtp.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "llama.h"
#include "speculative.h"

#include <algorithm>
#include <cstdlib>
Expand All @@ -10,6 +11,13 @@
// Set env vars to run non-skip paths; otherwise exits 0.

int main() {
// Phase C.2.1 — contract smoke: common_speculative_reset / common_speculative_cancel
// must be safe no-ops on a null spec (matches the documented contract in speculative.h).
// Runs unconditionally — no model files required.
common_speculative_cancel(nullptr);
common_speculative_reset(nullptr);
std::cout << "[common_speculative] null-spec cancel + reset OK\n";

const char * path_tgt = std::getenv("LLAMA_MTP_TEST_TARGET");
const char * path_head = std::getenv("LLAMA_MTP_TEST_HEAD");
const char * path_bad = std::getenv("LLAMA_MTP_TEST_BAD_ARCH");
Expand Down
48 changes: 48 additions & 0 deletions tools/server/server-common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,54 @@ const llama_tokens & server_tokens::get_text_tokens() const {
return tokens;
}

// Phase C.2.0 — coexistence APIs (see header for contract).

// End-exclusive index of the last media chunk (its start index plus its token count),
// or 0 when mtmd is disabled or no media chunk is recorded.
size_t server_tokens::last_image_end_idx() const {
    const bool has_media = has_mtmd && !map_idx_to_media.empty();
    if (!has_media) {
        return 0;
    }
    // map_idx_to_media is ordered by chunk start index, so its final entry is the last chunk.
    const auto & tail = *map_idx_to_media.rbegin();
    const size_t chunk_begin = tail.first;
    const size_t chunk_len   = mtmd_input_chunk_get_n_tokens(tail.second.get());
    return chunk_begin + chunk_len;
}

// True when no media chunk extends to or beyond from_idx: either there is no media at all,
// or from_idx is at/after the end of the last media chunk.
bool server_tokens::is_pure_text_continuation(size_t from_idx) const {
    const bool no_media = !has_mtmd || map_idx_to_media.empty();
    // short-circuit keeps last_image_end_idx() from being called when there is no media
    return no_media || from_idx >= last_image_end_idx();
}

// Returns a copy of the tokens located after the last media chunk, with every
// LLAMA_TOKEN_NULL entry removed. When mtmd is disabled or no media chunk is recorded,
// the whole buffer is treated as the suffix.
llama_tokens server_tokens::get_text_tokens_post_media() const {
    // last_image_end_idx() yields 0 when has_mtmd is false or the media map is empty, so the
    // text-only case needs no separate branch: it is the general case starting at index 0.
    const size_t start = last_image_end_idx();

    llama_tokens out;
    if (start >= tokens.size()) {
        // nothing after the last media chunk (or the buffer is empty)
        return out;
    }

    out.reserve(tokens.size() - start);
    for (size_t i = start; i < tokens.size(); ++i) {
        const llama_token t = tokens[i];
        // Defensive: the suffix should not contain LLAMA_TOKEN_NULL, but strip it anyway so
        // the post-condition (no NULL entries in the result) holds uniformly.
        if (t != LLAMA_TOKEN_NULL) {
            out.push_back(t);
        }
    }
    return out;
}

void server_tokens::set_token(llama_pos pos, llama_token id) {
GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled
tokens[pos] = id;
Expand Down
24 changes: 24 additions & 0 deletions tools/server/server-common.h
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,30 @@ struct server_tokens {
// for compatibility with speculative decoding, ctx shift, slot save/load
const llama_tokens & get_text_tokens() const;

// Phase C.2.0 — coexistence APIs for MTP + mmproj dispatch (foundational, no behavior change here).
//
// is_pure_text_continuation(from_idx) — cheap oracle (only the last media chunk is inspected):
// "if a caller decodes starting at position from_idx, will all tokens through end-of-buffer
// be pure text (no image chunks remaining)?"
// Used by the server to gate per-batch MTP draft dispatch when mmproj is also loaded.
// - !has_mtmd → always true
// - map empty → always true
// - from_idx >= last_image_end_idx() → true (we're past every image chunk)
// - otherwise → false (an image chunk still extends past from_idx)
bool is_pure_text_continuation(size_t from_idx) const;

// End-exclusive idx of the last image/audio chunk in the buffer (start + n_tokens).
// Returns 0 if there are no media chunks. !has_mtmd → 0.
size_t last_image_end_idx() const;

// Returns the suffix of text tokens after the last media chunk.
// - !has_mtmd → returns a copy of all tokens (any LLAMA_TOKEN_NULL stripped)
// - map empty → returns a copy of all tokens (any LLAMA_TOKEN_NULL stripped)
// - otherwise → tokens[last_image_end_idx() ..] with any LLAMA_TOKEN_NULL stripped
// Returned by value: the result is a filtered copy (LLAMA_TOKEN_NULL entries are dropped), so it
// cannot be handed out as a reference into the underlying buffer, which may interleave media
// placeholders. Callers typically bind to a const ref of the temporary.
llama_tokens get_text_tokens_post_media() const;

// for compatibility with speculative decoding
void set_token(llama_pos pos, llama_token id);

Expand Down