4 changes: 3 additions & 1 deletion common/arg.cpp
@@ -3550,12 +3550,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_MODEL"));
add_opt(common_arg(
{"--spec-type"}, "[none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]",
{"--spec-type"}, "[none|mtp|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]",
string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n",
common_speculative_type_to_str(params.speculative.type).c_str()),
[](common_params & params, const std::string & value) {
if (value == "none") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_NONE;
} else if (value == "mtp") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_MTP;
} else if (value == "ngram-cache") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_CACHE;
} else if (value == "ngram-simple") {
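For reference, a minimal sketch of how the new value would be selected from the command line (hypothetical model path; assumes the target GGUF bundles an MTP head):

    llama-server -m /path/to/model-with-mtp.gguf --spec-type mtp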
1 change: 1 addition & 0 deletions common/common.h
@@ -159,6 +159,7 @@ enum common_speculative_type {
COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding
COMMON_SPECULATIVE_TYPE_DRAFT, // draft model
COMMON_SPECULATIVE_TYPE_EAGLE3, // eagle draft model
COMMON_SPECULATIVE_TYPE_MTP, // multi-token prediction head loaded from the target GGUF
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values
296 changes: 296 additions & 0 deletions common/speculative.cpp
@@ -3,6 +3,7 @@
#include "common.h"
#include "ggml.h"
#include "llama.h"
#include "../src/llama-ext.h" // staging API: llama_set_embeddings_pre_norm / llama_get_embeddings_pre_norm_ith (used by MTP)
#include "log.h"
#include "ngram-cache.h"
#include "ngram-map.h"
@@ -23,6 +24,7 @@ const std::vector<enum common_speculative_type> common_speculative_types = {
COMMON_SPECULATIVE_TYPE_NONE,
COMMON_SPECULATIVE_TYPE_DRAFT,
COMMON_SPECULATIVE_TYPE_EAGLE3,
COMMON_SPECULATIVE_TYPE_MTP,
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE,
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K,
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V,
Expand All @@ -34,6 +36,7 @@ const std::map<std::string, enum common_speculative_type> common_speculative_typ
{"none", COMMON_SPECULATIVE_TYPE_NONE},
{"draft", COMMON_SPECULATIVE_TYPE_DRAFT},
{"eagle3", COMMON_SPECULATIVE_TYPE_EAGLE3},
{"mtp", COMMON_SPECULATIVE_TYPE_MTP},
{"ngram_simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
{"ngram_map_k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
{"ngram_map_k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V},
@@ -375,6 +378,290 @@ struct common_speculative_state_eagle3 : public common_speculative_impl {
}
};

struct common_speculative_state_mtp : public common_speculative_impl {
common_params_speculative_draft params; // reuses the draft-model params slot (ctx_tgt/ctx_dft)

llama_batch batch;

std::vector<common_sampler_ptr> smpls;

int32_t n_embd = 0;

// Per-sequence cross-batch carryover: pair (h_p, x_{p+1}) at MTP pos p+1.
// The last h-row of one process() call needs the first token of the NEXT
// call to pair with, so it's stashed here until that next call fires.
std::vector<std::vector<float>> pending_h; // [n_seq][n_embd]
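// Example: if one process() call covers tokens x0..x3 and yields hiddens
// h0..h3, then h3 has no successor token yet; it is stashed in pending_h,
// and the next call pairs it with its first token x4 as (h3, x4).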

std::vector<int32_t> i_batch_beg;
std::vector<int32_t> i_batch_end;

common_speculative_state_mtp(const common_params_speculative & params, uint32_t n_seq)
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_MTP, n_seq)
, params(params.draft)
{
auto * ctx_tgt = this->params.ctx_tgt;
auto * ctx_dft = this->params.ctx_dft;
GGML_ASSERT(ctx_tgt && ctx_dft && "MTP requires ctx_tgt and ctx_dft to be set");

n_embd = llama_model_n_embd(llama_get_model(ctx_dft));

const int32_t n_b = (int32_t) llama_n_batch(ctx_dft);
batch = llama_batch_init(/*n_tokens=*/ n_b, /*embd=*/ n_embd, /*n_seq_max=*/ 1);
// llama_batch_init allocates only one of token/embd; MTP needs both.
// TODO: find a way to do this without a separate malloc
batch.token = (llama_token *) malloc(sizeof(llama_token) * n_b);

smpls.resize(n_seq);
for (auto & s : smpls) {
common_params_sampling sparams;
sparams.no_perf = false;
sparams.top_k = 10;
sparams.samplers = { COMMON_SAMPLER_TYPE_TOP_K };
s.reset(common_sampler_init(llama_get_model(ctx_dft), sparams));
}

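// the MTP head conditions on the target's hidden state *before* the final
// norm, so both contexts are switched to expose pre-norm embeddings via the
// staging API from llama-ext.h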
llama_set_embeddings_pre_norm(ctx_tgt, true);
llama_set_embeddings_pre_norm(ctx_dft, true);

pending_h.assign(n_seq, std::vector<float>(n_embd, 0.0f));

i_batch_beg.assign(n_seq, -1);
i_batch_end.assign(n_seq, -1);
}

~common_speculative_state_mtp() override {
if (batch.token != nullptr) {
free(batch.token);
batch.token = nullptr;
}
llama_batch_free(batch);
}

void begin(llama_seq_id seq_id, const llama_tokens & prompt) override {
const int32_t N = (int32_t) prompt.size();
if (N <= 0) {
return;
}
auto * ctx_dft = this->params.ctx_dft;
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
if (pos_max < N - 1) {
LOG_WRN("%s: ctx_dft pos_max=%d < N-1=%d — "
"process() hook may not have run on every prefill ubatch "
"(need_embd / logits=1 on every prompt position?). "
"Drafts may degrade.\n",
__func__, (int) pos_max, N - 1);
}
}

bool process(const llama_batch & batch_in) override {
if (batch_in.n_tokens <= 0) {
return true;
}

// TODO: how to make it work with vision tokens?
if (batch_in.token == nullptr || batch_in.embd != nullptr) {
return true;
}
Comment on lines +461 to +464
I'm not really sure what the correct way to process image embeddings with the MTP context is. In any case, vision MTP already seems to work to a good extent.

Here I ask it to OCR 100 random integers, without speculative decoding and with MTP:

  • Without spec decoding: [screenshot]
  • With MTP: [screenshot]

With MTP it is ~2x faster, which means the MTP context "knows" about the integers in some way. But at the same time, I'm pretty sure the current way of processing is not 100% correct, because the inp->tokens tensor in the MTP graph is used with stale data when the input batch has image embeddings and no tokens.

I think we will figure this out later - not super important atm.


const int32_t n_tokens = batch_in.n_tokens;

// remember the first and last batch index for each sequence
std::fill(i_batch_beg.begin(), i_batch_beg.end(), -1);
std::fill(i_batch_end.begin(), i_batch_end.end(), -1);

for (int k = 0; k < n_tokens; ++k) {
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
GGML_ASSERT(batch_in.n_seq_id[k] == 1);

if (batch_in.seq_id[k][0] == seq_id) {
i_batch_end[seq_id] = k;
if (i_batch_beg[seq_id] < 0) {
i_batch_beg[seq_id] = k;
}
}
}
}

auto * ctx_tgt = this->params.ctx_tgt;
auto * ctx_dft = this->params.ctx_dft;

const size_t row_bytes = (size_t) n_embd * sizeof(float);

common_batch_clear(batch);

for (int k = 0; k < n_tokens; ++k) {
common_batch_add(batch, batch_in.token[k], batch_in.pos[k], { batch_in.seq_id[k][0] }, 0);
}

// shift the tgt embeddings to the right by one position
// assumes that the tokens in the batch are sequential for each sequence
// i.e. we cannot have seq_id like this: [0, 0, 0, 1, 1, 0, 1, 1]
// ^--- this is a problem
// TODO: this is generally true, but it would be nice to assert it
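// Illustration (single sequence, n_tokens = 4, h_prev from pending_h):
//   batch.token: [ x0      x1  x2  x3 ]
//   batch.embd : [ h_prev  h0  h1  h2 ]   where h_i is the pre-norm target hidden of x_i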
{
const float * h_tgt = llama_get_embeddings_pre_norm(ctx_tgt);
std::memcpy(batch.embd + (size_t) 1 * n_embd, h_tgt, row_bytes * (n_tokens-1));

//{
// // string with seq_ids in the batch
// std::stringstream ss;
// for (int i = 0; i < n_tokens; ++i) {
// ss << batch_in.seq_id[i][0] << ",";
// }
// LOG_WRN("%s: batch_in.seq_id = %s\n", __func__, ss.str().c_str());
//}
}

// fill the pending embeddings from a previous run
auto set_h = [&](int idx, const float * h_row) {
std::memcpy(batch.embd + (size_t) idx * n_embd, h_row, row_bytes);
};

for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
if (i_batch_beg[seq_id] < 0) {
continue;
}

set_h(i_batch_beg[seq_id], pending_h[seq_id].data());
}

const int32_t rc = llama_decode(ctx_dft, batch);
if (rc != 0) {
LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d)\n", __func__, (int) rc, (int) batch_in.pos[0]);
return false;
}

for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
if (i_batch_end[seq_id] < 0) {
continue;
}

const float * h_last = llama_get_embeddings_pre_norm_ith(ctx_tgt, i_batch_end[seq_id]);
std::memcpy(pending_h[seq_id].data(), h_last, row_bytes);
}

return true;
}

void draft(common_speculative_draft_params_vec & dparams) override {
auto & ctx_dft = params.ctx_dft;

common_batch_clear(batch);

// keep track of which sequences are still drafting
int n_drafting = 0;
std::vector<bool> drafting(n_seq);

const float * h_row = nullptr;
const size_t row_bytes = (size_t) n_embd * sizeof(float);

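// seed each drafting sequence with the input pair the MTP head expects:
// the last accepted token plus the carried-over pre-norm target hidden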
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
auto & dp = dparams[seq_id];

if (!dp.drafting) {
continue;
}

n_drafting++;
drafting[seq_id] = true;
common_sampler_reset(smpls[seq_id].get());

common_batch_add(batch, dp.id_last, dp.n_past, { seq_id }, true);

h_row = pending_h[seq_id].data();
std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
}

int ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode returned %d\n", __func__, ret);
return;
}

int i = 0;

while (n_drafting > 0) {
int i_batch = 0;

common_batch_clear(batch);

for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
if (!drafting[seq_id]) {
continue;
}

auto * smpl = smpls[seq_id].get();

common_sampler_sample(smpl, ctx_dft, i_batch, true);
h_row = llama_get_embeddings_pre_norm_ith(ctx_dft, i_batch);
++i_batch;

const auto * cur_p = common_sampler_get_candidates(smpl, true);

for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
LOG_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
seq_id, k, i, cur_p->data[k].id, cur_p->data[k].p,
common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
}

// add drafted token for each sequence
const llama_token id = cur_p->data[0].id;

// only collect very high-confidence draft tokens
if (cur_p->data[0].p < params.p_min) {
drafting[seq_id] = false;
n_drafting--;

continue;
}

common_sampler_accept(smpl, id, true);

auto & dp = dparams.at(seq_id);
auto & result = *dp.result;

result.push_back(id);

if ((params.n_max <= (int) result.size()) ||
(dp.n_max > 0 && dp.n_max <= (int) result.size())) {
drafting[seq_id] = false;
n_drafting--;
continue;
}

common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true);
std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
}

if (batch.n_tokens == 0) {
break;
}

// evaluate the drafted tokens on the draft model
ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
break;
}

++i;
}

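// drop drafts that ended up shorter than n_min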
for (auto & dp : dparams) {
if (!dp.drafting) {
continue;
}

if (dp.result->size() < (size_t) params.n_min) {
dp.result->clear();
}
}
}

void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override {
}
};

// state of self-speculation (simple implementation, not ngram-map)
struct common_speculative_state_ngram_simple : public common_speculative_impl {
common_params_speculative_ngram_map params;
@@ -818,6 +1105,7 @@ std::string common_speculative_type_to_str(enum common_speculative_type type) {
case COMMON_SPECULATIVE_TYPE_NONE: return "none";
case COMMON_SPECULATIVE_TYPE_DRAFT: return "draft";
case COMMON_SPECULATIVE_TYPE_EAGLE3: return "eagle3";
case COMMON_SPECULATIVE_TYPE_MTP: return "mtp";
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram_simple";
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram_map_k";
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram_map_k4v";
@@ -843,6 +1131,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
{
bool has_draft = !params.draft.mparams.path.empty();
bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3
bool has_mtp = (params.type == COMMON_SPECULATIVE_TYPE_MTP) && params.draft.ctx_dft != nullptr;

bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE);
bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE);
@@ -875,6 +1164,9 @@ common_speculative * common_speculative_init(common_params_speculative & params,
if (has_draft_eagle3) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_EAGLE3, params));
}
if (has_mtp) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_MTP, params));
}
}

std::vector<std::unique_ptr<common_speculative_impl>> impls = {};
@@ -892,6 +1184,10 @@ common_speculative * common_speculative_init(common_params_speculative & params,
impls.push_back(std::make_unique<common_speculative_state_eagle3>(config.params, n_seq));
break;
}
case COMMON_SPECULATIVE_TYPE_MTP: {
impls.push_back(std::make_unique<common_speculative_state_mtp>(config.params, n_seq));
break;
}
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: {
common_ngram_map ngram_map = get_common_ngram_map(config.type, config.params.ngram_simple);
