Skip to content

Commit f6a0283

Browse files
author
lvyichen
committed
fix arg: embeddings\n_min\reuse\model-draft\spec-type
1 parent d2d7bbc commit f6a0283

8 files changed

Lines changed: 88 additions & 42 deletions

File tree

common/arg.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3491,6 +3491,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
34913491
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MOD;
34923492
} else if (value == "mtp") {
34933493
params.speculative.type = COMMON_SPECULATIVE_TYPE_MTP;
3494+
params.mtp = true;
34943495
} else {
34953496
throw std::invalid_argument("unknown speculative decoding type without draft model");
34963497
}

common/common.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1350,6 +1350,7 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
13501350

13511351
struct llama_context_params common_context_params_to_llama(const common_params & params) {
13521352
auto cparams = llama_context_default_params();
1353+
const bool mtp_needs_hidden_states = params.speculative.type == COMMON_SPECULATIVE_TYPE_MTP;
13531354

13541355
cparams.n_ctx = params.n_ctx;
13551356
cparams.n_seq_max = params.n_parallel;
@@ -1358,7 +1359,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
13581359
cparams.n_threads = params.cpuparams.n_threads;
13591360
cparams.n_threads_batch = params.cpuparams_batch.n_threads == -1 ?
13601361
params.cpuparams.n_threads : params.cpuparams_batch.n_threads;
1361-
cparams.embeddings = params.embedding;
1362+
cparams.embeddings = params.embedding || mtp_needs_hidden_states;
13621363
cparams.rope_scaling_type = params.rope_scaling_type;
13631364
cparams.rope_freq_base = params.rope_freq_base;
13641365
cparams.rope_freq_scale = params.rope_freq_scale;

common/common.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -312,10 +312,6 @@ struct common_params_speculative {
312312
bool has_dft() const {
313313
return !mparams_dft.path.empty() || !mparams_dft.hf_repo.empty();
314314
}
315-
316-
bool requires_dft() const {
317-
return type == COMMON_SPECULATIVE_TYPE_DRAFT || type == COMMON_SPECULATIVE_TYPE_EAGLE3;
318-
}
319315
};
320316

321317
struct common_params_vocoder {

common/speculative.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1563,6 +1563,7 @@ common_speculative * common_speculative_init(
15631563
cparams.n_threads_batch = llama_n_threads_batch(ctx_tgt);
15641564
}
15651565

1566+
llama_set_embeddings(ctx_tgt, true);
15661567
cparams.embeddings = true;
15671568

15681569
llama_context * ctx_mtp = llama_init_from_model(const_cast<llama_model *>(llama_get_model(ctx_tgt)), cparams);

examples/speculative-simple/speculative-simple.cpp

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,6 @@ int main(int argc, char ** argv) {
2727

2828
common_init();
2929

30-
if (params.speculative.requires_dft() && !params.speculative.has_dft()) {
31-
LOG_ERR("%s: --model-draft is required\n", __func__);
32-
return 1;
33-
}
34-
3530
// init llama.cpp
3631
llama_backend_init();
3732
llama_numa_init(params.numa);
@@ -55,26 +50,23 @@ int main(int argc, char ** argv) {
5550
{
5651
const auto & params_spec = params.speculative;
5752

58-
auto params_dft = params;
53+
if (params_spec.has_dft()) {
54+
auto params_dft = params;
5955

60-
params_dft.n_parallel = 1;
61-
params_dft.n_ctx = params_spec.n_ctx == 0 ? (int32_t) llama_n_ctx_seq(ctx_tgt) : params_spec.n_ctx;
62-
params_dft.n_batch = llama_n_ctx_seq(ctx_tgt);
63-
params_dft.cache_type_k = params_spec.cache_type_k;
64-
params_dft.cache_type_v = params_spec.cache_type_v;
65-
params_dft.devices = params_spec.devices;
66-
params_dft.n_gpu_layers = params_spec.n_gpu_layers;
67-
68-
if (params_spec.cpuparams.n_threads > 0) {
69-
params_dft.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
70-
params_dft.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
71-
}
56+
params_dft.n_parallel = 1;
57+
params_dft.n_ctx = params_spec.n_ctx == 0 ? (int32_t) llama_n_ctx_seq(ctx_tgt) : params_spec.n_ctx;
58+
params_dft.n_batch = llama_n_ctx_seq(ctx_tgt);
59+
params_dft.cache_type_k = params_spec.cache_type_k;
60+
params_dft.cache_type_v = params_spec.cache_type_v;
61+
params_dft.devices = params_spec.devices;
62+
params_dft.n_gpu_layers = params_spec.n_gpu_layers;
7263

73-
params_dft.tensor_buft_overrides = params.speculative.tensor_buft_overrides;
74-
75-
params.speculative.cparams_dft = common_context_params_to_llama(params_dft);
64+
if (params_spec.cpuparams.n_threads > 0) {
65+
params_dft.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
66+
params_dft.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
67+
}
7668

77-
if (params_spec.requires_dft()) {
69+
params_dft.tensor_buft_overrides = params.speculative.tensor_buft_overrides;
7870
params_dft.model = params_spec.mparams_dft;
7971

8072
auto mparams_dft = common_model_params_to_llama(params_dft);
@@ -86,6 +78,25 @@ int main(int argc, char ** argv) {
8678
}
8779

8880
params.speculative.model_dft = model_dft.get();
81+
params.speculative.cparams_dft = common_context_params_to_llama(params_dft);
82+
} else if (params_spec.type == COMMON_SPECULATIVE_TYPE_MTP) {
83+
auto params_dft = params;
84+
85+
params_dft.n_parallel = 1;
86+
params_dft.n_ctx = params_spec.n_ctx == 0 ? (int32_t) llama_n_ctx_seq(ctx_tgt) : params_spec.n_ctx;
87+
params_dft.n_batch = llama_n_ctx_seq(ctx_tgt);
88+
params_dft.cache_type_k = params_spec.cache_type_k;
89+
params_dft.cache_type_v = params_spec.cache_type_v;
90+
params_dft.devices = params_spec.devices;
91+
params_dft.n_gpu_layers = params_spec.n_gpu_layers;
92+
93+
if (params_spec.cpuparams.n_threads > 0) {
94+
params_dft.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
95+
params_dft.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
96+
}
97+
98+
params_dft.tensor_buft_overrides = params.speculative.tensor_buft_overrides;
99+
params.speculative.cparams_dft = common_context_params_to_llama(params_dft);
89100
}
90101
}
91102

src/llama-graph.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,18 @@ void llm_graph_input_mtp_hidden_state::set_input(const llama_ubatch * ubatch) {
321321
}
322322
}
323323

324+
bool llm_graph_input_mtp_hidden_state::can_reuse(const llm_graph_params & params) {
325+
data = params.mtp_hidden_state;
326+
327+
bool res = true;
328+
329+
res &= hidden_state != nullptr;
330+
res &= data != nullptr;
331+
res &= hidden_state->ne[1] == params.ubatch.n_tokens;
332+
333+
return res;
334+
}
335+
324336
void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
325337
GGML_UNUSED(ubatch);
326338

src/llama-graph.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,8 @@ class llm_graph_input_mtp_hidden_state : public llm_graph_input_i {
270270

271271
void set_input(const llama_ubatch * ubatch) override;
272272

273+
bool can_reuse(const llm_graph_params & params) override;
274+
273275
ggml_tensor * hidden_state = nullptr; // F32 [n_embd, n_tokens]
274276

275277
const float * data = nullptr;

tools/server/server-context.cpp

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -704,7 +704,7 @@ struct server_context_impl {
704704

705705
add_bos_token = llama_vocab_get_add_bos(vocab);
706706

707-
if (params_base.speculative.type == COMMON_SPECULATIVE_TYPE_MTP || params_base.speculative.has_dft()) {
707+
if (params_base.speculative.has_dft()) {
708708
const auto & params_spec = params_base.speculative;
709709

710710
auto params_dft = params_base;
@@ -724,23 +724,39 @@ struct server_context_impl {
724724
}
725725

726726
params_dft.tensor_buft_overrides = params_spec.tensor_buft_overrides;
727-
params_base.speculative.cparams_dft = common_context_params_to_llama(params_dft);
728727

729-
if (params_base.speculative.requires_dft() && params_base.speculative.has_dft()) {
730-
SRV_INF("loading draft model '%s'\n", params_base.speculative.mparams_dft.path.c_str());
728+
SRV_INF("loading draft model '%s'\n", params_base.speculative.mparams_dft.path.c_str());
731729

732-
params_dft.model = params_spec.mparams_dft;
730+
auto mparams_dft = common_model_params_to_llama(params_dft);
733731

734-
auto mparams_dft = common_model_params_to_llama(params_dft);
732+
model_dft.reset(llama_model_load_from_file(params_dft.model.path.c_str(), mparams_dft));
733+
if (model_dft == nullptr) {
734+
SRV_ERR("failed to load draft model, '%s'\n", params_dft.model.path.c_str());
735+
return false;
736+
}
735737

736-
model_dft.reset(llama_model_load_from_file(params_dft.model.path.c_str(), mparams_dft));
737-
if (model_dft == nullptr) {
738-
SRV_ERR("failed to load draft model, '%s'\n", params_dft.model.path.c_str());
739-
return false;
740-
}
738+
params_base.speculative.model_dft = model_dft.get();
739+
params_base.speculative.cparams_dft = common_context_params_to_llama(params_dft);
740+
} else if (params_base.speculative.type == COMMON_SPECULATIVE_TYPE_MTP) {
741+
const auto & params_spec = params_base.speculative;
742+
743+
auto params_dft = params_base;
744+
745+
params_dft.n_parallel = 1;
746+
params_dft.n_ctx = params_spec.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_spec.n_ctx;
747+
params_dft.n_batch = llama_n_ctx_seq(ctx);
748+
params_dft.devices = params_spec.devices;
749+
params_dft.n_gpu_layers = params_spec.n_gpu_layers;
750+
params_dft.cache_type_k = params_spec.cache_type_k;
751+
params_dft.cache_type_v = params_spec.cache_type_v;
741752

742-
params_base.speculative.model_dft = model_dft.get();
753+
if (params_spec.cpuparams.n_threads > 0) {
754+
params_dft.cpuparams.n_threads = params_spec.cpuparams.n_threads;
755+
params_dft.cpuparams_batch.n_threads = params_spec.cpuparams_batch.n_threads;
743756
}
757+
758+
params_dft.tensor_buft_overrides = params_spec.tensor_buft_overrides;
759+
params_base.speculative.cparams_dft = common_context_params_to_llama(params_dft);
744760
}
745761

746762
std::string & mmproj_path = params_base.mmproj.path;
@@ -2162,10 +2178,16 @@ struct server_context_impl {
21622178

21632179
if (slot.task->params.speculative.n_min > (int) draft.size()) {
21642180
SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min);
2165-
// fallback to normal decoding
2166-
slot.i_batch = slot.i_batch_dft[0];
21672181
slot.drafted.clear();
2168-
slot.i_batch_dft.clear();
2182+
if (slot.task->params.speculative.type != COMMON_SPECULATIVE_TYPE_MTP) {
2183+
// Non-MTP speculation can safely fall back to plain decoding.
2184+
slot.i_batch = slot.i_batch_dft[0];
2185+
slot.i_batch_dft.clear();
2186+
} else {
2187+
// MTP still needs a 0-accept speculative round so accept() can stage
2188+
// the frontier hidden state for the next shifted first pass.
2189+
slot.i_batch = -1;
2190+
}
21692191
} else {
21702192
// keep track of total number of drafted tokens tested
21712193
slot.n_draft_total += draft.size();

0 commit comments

Comments
 (0)