Skip to content

Commit 1c7cbce

Browse files
committed
Merge remote-tracking branch 'mtp/mtp-clean' into turboquant-mtp
2 parents 9e7cbb4 + 5d5f1b4 commit 1c7cbce

45 files changed

Lines changed: 1610 additions & 160 deletions

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

common/arg.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3728,12 +3728,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
37283728
}
37293729
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
37303730
add_opt(common_arg(
3731-
{"--spec-type"}, "[none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]",
3731+
{"--spec-type"}, "[none|mtp|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]",
37323732
string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n",
37333733
common_speculative_type_to_str(params.speculative.type).c_str()),
37343734
[](common_params & params, const std::string & value) {
37353735
if (value == "none") {
37363736
params.speculative.type = COMMON_SPECULATIVE_TYPE_NONE;
3737+
} else if (value == "mtp") {
3738+
params.speculative.type = COMMON_SPECULATIVE_TYPE_MTP;
37373739
} else if (value == "ngram-cache") {
37383740
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_CACHE;
37393741
} else if (value == "ngram-simple") {

common/common.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1426,6 +1426,11 @@ common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) {
14261426
goto done;
14271427
}
14281428

1429+
if (llama_n_rs_seq(ctx) > 0) {
1430+
res = COMMON_CONTEXT_SEQ_RM_TYPE_PART_BOUNDED;
1431+
goto done;
1432+
}
1433+
14291434
// try to remove the last tokens
14301435
if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
14311436
LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
@@ -1496,6 +1501,12 @@ struct llama_context_params common_context_params_to_llama(const common_params &
14961501

14971502
cparams.n_ctx = params.n_ctx;
14981503
cparams.n_seq_max = params.n_parallel;
1504+
{
1505+
// enable partial rollback only for MTP, each recurrent slot requires memory
1506+
// and MTP uses max 3-4 slots vs other techniques
1507+
const bool has_mtp_spec = params.speculative.type == COMMON_SPECULATIVE_TYPE_MTP;
1508+
cparams.n_rs_seq = has_mtp_spec ? (uint32_t) params.speculative.draft.n_max : 0u;
1509+
}
14991510
cparams.n_batch = params.n_batch;
15001511
cparams.n_ubatch = params.n_ubatch;
15011512
cparams.n_threads = params.cpuparams.n_threads;

common/common.h

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ enum common_speculative_type {
159159
COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding
160160
COMMON_SPECULATIVE_TYPE_DRAFT, // draft model
161161
COMMON_SPECULATIVE_TYPE_EAGLE3, // eagle draft model
162+
COMMON_SPECULATIVE_TYPE_MTP, // multi-token prediction
162163
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding
163164
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only
164165
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values
@@ -348,11 +349,17 @@ struct common_params_speculative_ngram_cache {
348349
std::string lookup_cache_dynamic; // path of dynamic ngram cache file for lookup decoding
349350
};
350351

352+
// parameters for MTP (multi-token prediction) speculative decoding
struct common_params_speculative_mtp {
    llama_model * model = nullptr;  // MTP head model — NOTE(review): ownership/free responsibility not visible here; confirm against common_speculative_init
    llama_context_params cparams;   // context params passed to llama_init_from_model() when creating the MTP context
};
356+
351357
struct common_params_speculative {
352358
// TODO: become a vector in order to support "chains of speculators"
353359
common_speculative_type type = COMMON_SPECULATIVE_TYPE_NONE;
354360

355361
common_params_speculative_draft draft;
362+
common_params_speculative_mtp mtp;
356363

357364
common_params_speculative_ngram_mod ngram_mod;
358365
common_params_speculative_ngram_map ngram_simple;
@@ -883,9 +890,10 @@ std::string common_get_model_endpoint();
883890
//
884891

885892
// capability level of a llama_context w.r.t. sequence removal (see common_context_can_seq_rm)
enum common_context_seq_rm_type {
    COMMON_CONTEXT_SEQ_RM_TYPE_NO           = 0, // seq_rm not supported (e.g. no memory module)
    COMMON_CONTEXT_SEQ_RM_TYPE_PART         = 1, // can seq_rm partial sequences
    COMMON_CONTEXT_SEQ_RM_TYPE_FULL         = 2, // can seq_rm full sequences only
    COMMON_CONTEXT_SEQ_RM_TYPE_PART_BOUNDED = 3, // can seq_rm partial sequences, bounded by n_rs_seq
};
890898

891899
// check if the llama_context can remove sequences

common/speculative.cpp

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ const std::vector<enum common_speculative_type> common_speculative_types = {
2222
COMMON_SPECULATIVE_TYPE_NONE,
2323
COMMON_SPECULATIVE_TYPE_DRAFT,
2424
COMMON_SPECULATIVE_TYPE_EAGLE3,
25+
COMMON_SPECULATIVE_TYPE_MTP,
2526
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE,
2627
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K,
2728
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V,
@@ -33,6 +34,7 @@ const std::map<std::string, enum common_speculative_type> common_speculative_typ
3334
{"none", COMMON_SPECULATIVE_TYPE_NONE},
3435
{"draft", COMMON_SPECULATIVE_TYPE_DRAFT},
3536
{"eagle3", COMMON_SPECULATIVE_TYPE_EAGLE3},
37+
{"mtp", COMMON_SPECULATIVE_TYPE_MTP},
3638
{"ngram_simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
3739
{"ngram_map_k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
3840
{"ngram_map_k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V},
@@ -642,6 +644,171 @@ struct common_speculative_state_eagle3 : public common_speculative_state {
642644
}
643645
};
644646

647+
// Speculative decoding via a multi-token-prediction (MTP) head.
//
// Owns a second llama_context (ctx_mtp) that runs the MTP head auto-regressively.
// Each draft step conditions the head on one hidden-state row: the first step
// pulls it from the target context (t_h_pre_norm), subsequent steps pull the
// head's own output (t_mtp_out). last_n_drafted / last_n_accepted track the
// previous round so stale draft rows can be evicted from the MTP kv-cache.
struct common_speculative_state_mtp : public common_speculative_state {
    llama_context * ctx_tgt = nullptr; // target (trunk) context — borrowed, not owned
    llama_context * ctx_mtp = nullptr; // MTP head context — owned, freed in the destructor

    llama_batch batch; // single token draft step
    common_sampler * smpl = nullptr; // greedy sampler (top_k = 1) over the MTP head's logits
    int32_t n_embd = 0; // embedding width of the MTP model — one hidden-state row is n_embd floats

    uint16_t last_n_drafted = 0;  // tokens drafted in the previous draft() call
    int32_t last_n_accepted = -1; // -1 => no accept()/draft() since begin() (first draft after prefill)

    common_speculative_state_mtp(enum common_speculative_type type,
        llama_context * ctx_tgt,
        llama_context * ctx_mtp)
        : common_speculative_state(type), ctx_tgt(ctx_tgt), ctx_mtp(ctx_mtp) {
        GGML_ASSERT(ctx_tgt && ctx_mtp);
        const llama_model * model_mtp = llama_get_model(ctx_mtp);
        n_embd = llama_model_n_embd(model_mtp);

        {
            // top_k = 1 makes common_sampler_sample return the arg-max token (greedy drafting)
            common_params_sampling sparams;
            sparams.no_perf = false;
            sparams.top_k = 1;
            sparams.samplers = { COMMON_SAMPLER_TYPE_TOP_K };
            smpl = common_sampler_init(model_mtp, sparams);
        }

        // TODO: multiple seq support
        // The batch carries both an embd row (n_embd floats) and a token id.
        // NOTE(review): llama_batch_init with embd != 0 presumably does not allocate the
        // token array, hence the manual malloc; confirm llama_batch_free releases .token.
        batch = llama_batch_init(/*n_tokens=*/ 1, /*embd=*/ n_embd, /*n_seq_max=*/ 1);
        batch.token = (llama_token *) malloc(sizeof(llama_token));
        batch.n_tokens = 1;
        batch.n_seq_id[0] = 1;
        batch.seq_id[0][0] = 0;
        batch.logits[0] = 1; // request logits for the single draft token

        // register the MTP context on the target context (project-specific hook) —
        // presumably so the target decode exposes the hidden states the head consumes
        llama_set_mtp(ctx_tgt, ctx_mtp);
    }

    ~common_speculative_state_mtp() override {
        // unregister the hook before tearing down the owned MTP context
        llama_set_mtp(ctx_tgt, nullptr);
        llama_batch_free(batch);
        common_sampler_free(smpl);
        if (ctx_mtp) {
            llama_free(ctx_mtp);
        }
    }

    // Reset per-round bookkeeping at the start of a new prompt and sanity-check
    // that the MTP kv-cache has been prefilled up to the end of the prompt.
    void begin(const llama_tokens & prompt) override {
        last_n_accepted = -1;
        last_n_drafted = 0;

        const int32_t N = (int32_t) prompt.size();
        if (N <= 0) {
            return;
        }
        // the MTP cache should already hold positions [0, N-1] from prefill streaming
        const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_mtp), 0);
        if (pos_max < N - 1) {
            LOG_WRN("%s: ctx_mtp pos_max=%d < N-1=%d — "
                    "streaming hook may not be registered or not all prefill rows "
                    "have logits=true. Drafts may degrade.\n",
                    __func__, (int) pos_max, N - 1);
        }
    }

    // Produce up to params.draft.n_max greedy draft tokens, one llama_decode of
    // ctx_mtp per token, conditioning each step on the previous hidden state.
    void draft(
            const common_params_speculative & params,
            const llama_tokens & prompt_tgt,
            llama_token id_last,
            llama_tokens & draft_tokens) override {
        GGML_UNUSED(prompt_tgt);
        draft_tokens.clear();

        // accept with no-accepts (i.e. 0 accepts) returns early, but we still need to remove from the MTP kv-cache
        // TODO: check if bug in other spec states
        if (last_n_drafted > 0) {
            // drop all but the first of last round's drafted rows from the cache —
            // NOTE(review): keeping exactly one row appears to assume the first
            // drafted token was consumed by the target; confirm
            const int32_t n_to_drop = (int32_t) last_n_drafted - 1;
            if (n_to_drop > 0) {
                const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_mtp), 0);
                if (pos_max >= 0) {
                    const llama_pos drop_from = pos_max - n_to_drop + 1;
                    llama_memory_seq_rm(llama_get_memory(ctx_mtp), 0, drop_from, -1);
                }
            }
            last_n_drafted = 0;
            last_n_accepted = 0;
        }

        const int32_t n_max = std::max(1, params.draft.n_max);
        const size_t row_bytes = (size_t) n_embd * sizeof(float); // bytes per hidden-state row

        llama_token cond_tok = id_last; // token conditioning the next head step
        llama_pos pos = llama_memory_seq_pos_max(llama_get_memory(ctx_mtp), 0) + 1; // next free cache position

        // auto-regressive loop for MTP
        for (int32_t k = 0; k < n_max; ++k) {
            ggml_tensor * src;   // tensor holding the conditioning hidden state
            int32_t src_row;     // row of src to copy into batch.embd
            if (k == 0) {
                // first step: hidden state comes from the target context
                src = llama_context_get_t_h_pre_norm(ctx_tgt);
                if (last_n_accepted < 0) {
                    // First draft after begin(): trunk's most recent decode is
                    // the last prefill ubatch; its last row is h_{N-1}.
                    src_row = (src && src->ne[1] > 0) ? (int32_t) src->ne[1] - 1 : 0;
                } else {
                    // otherwise pick the row of the last accepted token
                    src_row = last_n_accepted;
                }
                llama_synchronize(ctx_tgt); // ensure the target's tensors are computed before reading
            } else {
                // for the AR path get the mtp_out from the mtp ctx
                src = llama_context_get_t_mtp_out(ctx_mtp);
                src_row = src ? (int32_t) src->ne[1] - 1 : 0;
                llama_synchronize(ctx_mtp);
            }
            if (!src) {
                LOG_WRN("%s: missing source tensor at k=%d; stopping chain\n", __func__, k);
                return;
            }
            // copy one hidden-state row into the batch's embd buffer
            ggml_backend_tensor_get(src, batch.embd,
                (size_t) src_row * row_bytes, row_bytes);

            batch.token[0] = cond_tok;
            batch.pos[0] = pos;

            const int32_t dec_rc = llama_decode(ctx_mtp, batch);
            if (dec_rc != 0) {
                LOG_DBG("%s: llama_decode rc=%d at k=%d; stopping chain\n", __func__, dec_rc, k);
                return;
            }

            // greedy pick (top_k = 1) and feed it back as the next conditioning token
            const llama_token best = common_sampler_sample(smpl, ctx_mtp, 0);
            common_sampler_accept(smpl, best, /*accept_grammar=*/ false);
            draft_tokens.push_back(best);
            cond_tok = best;
            ++pos;
        }

        last_n_drafted = (uint16_t) draft_tokens.size();
    }

    // Record how many of the last drafted tokens the target accepted and evict
    // the rejected rows from the MTP kv-cache.
    void accept(uint16_t n_accepted) override {
        const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_mtp), 0);
        const int32_t n_drafted_last = (int32_t) last_n_drafted;
        // keep one row past the accepted ones — NOTE(review): presumably the row that
        // conditions the next draft step; confirm against draft()'s src_row logic
        const int32_t n_to_drop = std::max(0, n_drafted_last - (int32_t) n_accepted - 1);
        if (pos_max < 0) {
            // empty cache — nothing to evict
            last_n_accepted = (int32_t) n_accepted;
            return;
        }
        if (n_to_drop > 0) {
            const llama_pos drop_from = pos_max - n_to_drop + 1;
            llama_memory_seq_rm(llama_get_memory(ctx_mtp), /*seq_id=*/ 0,
                /*p0=*/ drop_from, /*p1=*/ -1);
        }
        last_n_drafted = 0;
        last_n_accepted = (int32_t) n_accepted;
    }

    // max/min number of draft tokens per round, clamped to at least 1
    int32_t n_max(const common_params_speculative & params) const override {
        return std::max(1, params.draft.n_max);
    }

    int32_t n_min(const common_params_speculative & params) const override {
        return std::max(1, params.draft.n_min);
    }
};
811+
645812
// state of self-speculation (simple implementation, not ngram-map)
646813
struct common_speculative_state_ngram_simple : public common_speculative_state {
647814
common_ngram_simple_config config;
@@ -995,6 +1162,7 @@ std::string common_speculative_type_to_str(enum common_speculative_type type) {
9951162
case COMMON_SPECULATIVE_TYPE_NONE: return "none";
9961163
case COMMON_SPECULATIVE_TYPE_DRAFT: return "draft";
9971164
case COMMON_SPECULATIVE_TYPE_EAGLE3: return "eagle3";
1165+
case COMMON_SPECULATIVE_TYPE_MTP: return "mtp";
9981166
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram_simple";
9991167
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram_map_k";
10001168
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram_map_k4v";
@@ -1026,11 +1194,24 @@ common_speculative * common_speculative_init(
10261194
}
10271195
}
10281196

1197+
llama_context * ctx_mtp = nullptr;
1198+
if (params.type == COMMON_SPECULATIVE_TYPE_MTP) {
1199+
ctx_mtp = llama_init_from_model(params.mtp.model, params.mtp.cparams);
1200+
if (ctx_mtp == nullptr) {
1201+
LOG_ERR("%s", "failed to create MTP context\n");
1202+
if (ctx_dft) {
1203+
llama_free(ctx_dft);
1204+
}
1205+
return nullptr;
1206+
}
1207+
}
1208+
10291209
// Compute the implementations to use based on the config and their order of preference
10301210
std::vector<common_speculative_config> configs = {}; // list of speculative configs to try
10311211
{
10321212
bool has_draft = !params.draft.mparams.path.empty();
10331213
bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3
1214+
bool has_mtp = (params.type == COMMON_SPECULATIVE_TYPE_MTP) && (ctx_mtp != nullptr);
10341215

10351216
bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE);
10361217
bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE);
@@ -1077,6 +1258,9 @@ common_speculative * common_speculative_init(
10771258
if (has_draft_eagle3) {
10781259
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_EAGLE3, params));
10791260
}
1261+
if (has_mtp) {
1262+
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_MTP, params));
1263+
}
10801264
}
10811265

10821266
std::vector<std::unique_ptr<common_speculative_state>> impls = {};
@@ -1101,6 +1285,11 @@ common_speculative * common_speculative_init(
11011285
impls.push_back(std::make_unique<common_speculative_state_eagle3>(config.type));
11021286
break;
11031287
}
1288+
case COMMON_SPECULATIVE_TYPE_MTP: {
1289+
impls.push_back(std::make_unique<common_speculative_state_mtp>(
1290+
config.type, ctx_tgt, ctx_mtp));
1291+
break;
1292+
}
11041293
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: {
11051294
common_ngram_map ngram_map = get_common_ngram_map(config.type, config.params.ngram_simple);
11061295

0 commit comments

Comments
 (0)