Skip to content

Commit a0c02ae

Browse files
author
lvyichen
committed
fix arg: embeddings\n_min\reuse\model-draft\spec-type
1 parent 4f28cc8 commit a0c02ae

8 files changed

Lines changed: 90 additions & 39 deletions

File tree

common/arg.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3525,6 +3525,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
35253525
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MOD;
35263526
} else if (value == "mtp") {
35273527
params.speculative.type = COMMON_SPECULATIVE_TYPE_MTP;
3528+
params.mtp = true;
35283529
} else {
35293530
throw std::invalid_argument("unknown speculative decoding type without draft model");
35303531
}

common/common.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1450,6 +1450,7 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
14501450

14511451
struct llama_context_params common_context_params_to_llama(const common_params & params) {
14521452
auto cparams = llama_context_default_params();
1453+
const bool mtp_needs_hidden_states = params.speculative.type == COMMON_SPECULATIVE_TYPE_MTP;
14531454

14541455
cparams.n_ctx = params.n_ctx;
14551456
cparams.n_seq_max = params.n_parallel;
@@ -1458,7 +1459,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
14581459
cparams.n_threads = params.cpuparams.n_threads;
14591460
cparams.n_threads_batch = params.cpuparams_batch.n_threads == -1 ?
14601461
params.cpuparams.n_threads : params.cpuparams_batch.n_threads;
1461-
cparams.embeddings = params.embedding;
1462+
cparams.embeddings = params.embedding || mtp_needs_hidden_states;
14621463
cparams.rope_scaling_type = params.rope_scaling_type;
14631464
cparams.rope_freq_base = params.rope_freq_base;
14641465
cparams.rope_freq_scale = params.rope_freq_scale;

common/common.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -356,10 +356,6 @@ struct common_params_speculative {
356356
bool has_dft() const {
357357
return !mparams_dft.path.empty() || !mparams_dft.hf_repo.empty();
358358
}
359-
360-
bool requires_dft() const {
361-
return type == COMMON_SPECULATIVE_TYPE_DRAFT || type == COMMON_SPECULATIVE_TYPE_EAGLE3;
362-
}
363359
};
364360

365361
struct common_params_vocoder {

common/speculative.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1563,6 +1563,7 @@ common_speculative * common_speculative_init(
15631563
cparams.n_threads_batch = llama_n_threads_batch(ctx_tgt);
15641564
}
15651565

1566+
llama_set_embeddings(ctx_tgt, true);
15661567
cparams.embeddings = true;
15671568

15681569
llama_context * ctx_mtp = llama_init_from_model(const_cast<llama_model *>(llama_get_model(ctx_tgt)), cparams);

examples/speculative-simple/speculative-simple.cpp

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -55,26 +55,23 @@ int main(int argc, char ** argv) {
5555
{
5656
const auto & params_spec = params.speculative;
5757

58-
auto params_dft = params;
59-
60-
params_dft.n_parallel = 1;
61-
params_dft.n_ctx = params_spec.n_ctx == 0 ? (int32_t) llama_n_ctx_seq(ctx_tgt) : params_spec.n_ctx;
62-
params_dft.n_batch = llama_n_ctx_seq(ctx_tgt);
63-
params_dft.cache_type_k = params_spec.cache_type_k;
64-
params_dft.cache_type_v = params_spec.cache_type_v;
65-
params_dft.devices = params_spec.devices;
66-
params_dft.n_gpu_layers = params_spec.n_gpu_layers;
67-
68-
if (params_spec.cpuparams.n_threads > 0) {
69-
params_dft.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
70-
params_dft.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
71-
}
72-
73-
params_dft.tensor_buft_overrides = params.speculative.tensor_buft_overrides;
74-
75-
params.speculative.cparams_dft = common_context_params_to_llama(params_dft);
58+
if (params_spec.has_dft()) {
59+
auto params_dft = params;
60+
61+
params_dft.n_parallel = 1;
62+
params_dft.n_ctx = params_spec.n_ctx == 0 ? (int32_t) llama_n_ctx_seq(ctx_tgt) : params_spec.n_ctx;
63+
params_dft.n_batch = llama_n_ctx_seq(ctx_tgt);
64+
params_dft.cache_type_k = params_spec.cache_type_k;
65+
params_dft.cache_type_v = params_spec.cache_type_v;
66+
params_dft.devices = params_spec.devices;
67+
params_dft.n_gpu_layers = params_spec.n_gpu_layers;
68+
69+
if (params_spec.cpuparams.n_threads > 0) {
70+
params_dft.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
71+
params_dft.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
72+
}
7673

77-
if (params_spec.requires_dft()) {
74+
params_dft.tensor_buft_overrides = params.speculative.tensor_buft_overrides;
7875
params_dft.model = params_spec.mparams_dft;
7976

8077
auto mparams_dft = common_model_params_to_llama(params_dft);
@@ -86,6 +83,25 @@ int main(int argc, char ** argv) {
8683
}
8784

8885
params.speculative.model_dft = model_dft.get();
86+
params.speculative.cparams_dft = common_context_params_to_llama(params_dft);
87+
} else if (params_spec.type == COMMON_SPECULATIVE_TYPE_MTP) {
88+
auto params_dft = params;
89+
90+
params_dft.n_parallel = 1;
91+
params_dft.n_ctx = params_spec.n_ctx == 0 ? (int32_t) llama_n_ctx_seq(ctx_tgt) : params_spec.n_ctx;
92+
params_dft.n_batch = llama_n_ctx_seq(ctx_tgt);
93+
params_dft.cache_type_k = params_spec.cache_type_k;
94+
params_dft.cache_type_v = params_spec.cache_type_v;
95+
params_dft.devices = params_spec.devices;
96+
params_dft.n_gpu_layers = params_spec.n_gpu_layers;
97+
98+
if (params_spec.cpuparams.n_threads > 0) {
99+
params_dft.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
100+
params_dft.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
101+
}
102+
103+
params_dft.tensor_buft_overrides = params.speculative.tensor_buft_overrides;
104+
params.speculative.cparams_dft = common_context_params_to_llama(params_dft);
89105
}
90106
}
91107

src/llama-graph.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,18 @@ void llm_graph_input_mtp_hidden_state::set_input(const llama_ubatch * ubatch) {
340340
}
341341
}
342342

343+
bool llm_graph_input_mtp_hidden_state::can_reuse(const llm_graph_params & params) {
344+
data = params.mtp_hidden_state;
345+
346+
bool res = true;
347+
348+
res &= hidden_state != nullptr;
349+
res &= data != nullptr;
350+
res &= hidden_state->ne[1] == params.ubatch.n_tokens;
351+
352+
return res;
353+
}
354+
343355
void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
344356
GGML_UNUSED(ubatch);
345357

src/llama-graph.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,8 @@ class llm_graph_input_mtp_hidden_state : public llm_graph_input_i {
270270

271271
void set_input(const llama_ubatch * ubatch) override;
272272

273+
bool can_reuse(const llm_graph_params & params) override;
274+
273275
ggml_tensor * hidden_state = nullptr; // F32 [n_embd, n_tokens]
274276

275277
const float * data = nullptr;

tools/server/server-context.cpp

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -716,7 +716,7 @@ struct server_context_impl {
716716

717717
add_bos_token = llama_vocab_get_add_bos(vocab);
718718

719-
if (params_base.speculative.type == COMMON_SPECULATIVE_TYPE_MTP || params_base.speculative.has_dft()) {
719+
if (params_base.speculative.has_dft()) {
720720
const auto & params_spec = params_base.speculative;
721721

722722
auto params_dft = params_base;
@@ -736,23 +736,39 @@ struct server_context_impl {
736736
}
737737

738738
params_dft.tensor_buft_overrides = params_spec.tensor_buft_overrides;
739-
params_base.speculative.cparams_dft = common_context_params_to_llama(params_dft);
740739

741-
if (params_base.speculative.requires_dft() && params_base.speculative.has_dft()) {
742-
SRV_INF("loading draft model '%s'\n", params_base.speculative.mparams_dft.path.c_str());
740+
SRV_INF("loading draft model '%s'\n", params_base.speculative.mparams_dft.path.c_str());
743741

744-
params_dft.model = params_spec.mparams_dft;
742+
auto mparams_dft = common_model_params_to_llama(params_dft);
745743

746-
auto mparams_dft = common_model_params_to_llama(params_dft);
744+
model_dft.reset(llama_model_load_from_file(params_dft.model.path.c_str(), mparams_dft));
745+
if (model_dft == nullptr) {
746+
SRV_ERR("failed to load draft model, '%s'\n", params_dft.model.path.c_str());
747+
return false;
748+
}
747749

748-
model_dft.reset(llama_model_load_from_file(params_dft.model.path.c_str(), mparams_dft));
749-
if (model_dft == nullptr) {
750-
SRV_ERR("failed to load draft model, '%s'\n", params_dft.model.path.c_str());
751-
return false;
752-
}
750+
params_base.speculative.model_dft = model_dft.get();
751+
params_base.speculative.cparams_dft = common_context_params_to_llama(params_dft);
752+
} else if (params_base.speculative.type == COMMON_SPECULATIVE_TYPE_MTP) {
753+
const auto & params_spec = params_base.speculative;
754+
755+
auto params_dft = params_base;
756+
757+
params_dft.n_parallel = 1;
758+
params_dft.n_ctx = params_spec.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_spec.n_ctx;
759+
params_dft.n_batch = llama_n_ctx_seq(ctx);
760+
params_dft.devices = params_spec.devices;
761+
params_dft.n_gpu_layers = params_spec.n_gpu_layers;
762+
params_dft.cache_type_k = params_spec.cache_type_k;
763+
params_dft.cache_type_v = params_spec.cache_type_v;
753764

754-
params_base.speculative.model_dft = model_dft.get();
765+
if (params_spec.cpuparams.n_threads > 0) {
766+
params_dft.cpuparams.n_threads = params_spec.cpuparams.n_threads;
767+
params_dft.cpuparams_batch.n_threads = params_spec.cpuparams_batch.n_threads;
755768
}
769+
770+
params_dft.tensor_buft_overrides = params_spec.tensor_buft_overrides;
771+
params_base.speculative.cparams_dft = common_context_params_to_llama(params_dft);
756772
}
757773

758774
std::string & mmproj_path = params_base.mmproj.path;
@@ -2196,10 +2212,16 @@ struct server_context_impl {
21962212

21972213
if (slot.task->params.speculative.n_min > (int) draft.size()) {
21982214
SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min);
2199-
// fallback to normal decoding
2200-
slot.i_batch = slot.i_batch_dft[0];
22012215
slot.drafted.clear();
2202-
slot.i_batch_dft.clear();
2216+
if (slot.task->params.speculative.type != COMMON_SPECULATIVE_TYPE_MTP) {
2217+
// Non-MTP speculation can safely fall back to plain decoding.
2218+
slot.i_batch = slot.i_batch_dft[0];
2219+
slot.i_batch_dft.clear();
2220+
} else {
2221+
// MTP still needs a 0-accept speculative round so accept() can stage
2222+
// the frontier hidden state for the next shifted first pass.
2223+
slot.i_batch = -1;
2224+
}
22032225
} else {
22042226
// keep track of total number of drafted tokens tested
22052227
slot.n_draft_total += draft.size();

0 commit comments

Comments
 (0)