22#include " common.h"
33#include " log.h"
44#include " llama.h"
5+ #include " ../../src/llama-ext.h" // staging API: llama_set_embeddings_pre_norm / llama_get_embeddings_pre_norm_ith (used by MTP)
56#include " gguf.h"
67
78#include < algorithm>
@@ -916,7 +917,53 @@ static void process_logits(
916917 }
917918}
918919
919- static bool compute_imatrix (llama_context * ctx, const common_params & params, const int32_t n_ctx) {
920+ // Run a forward pass through the MTP/NextN draft head so its weights
921+ // (blk.<n_layer>.nextn.eh_proj etc.) receive activations and get recorded by
922+ // the imatrix collector. Mirrors common_speculative_state_draft_mtp::process():
923+ // the MTP head at position p is fed the next-token id (tokens[p+1]) paired
924+ // with the trunk's pre-norm hidden state h[p]. The last position of the
925+ // chunk has no next-token target and is dropped.
926+ static bool compute_imatrix_mtp (
927+ llama_context * ctx_tgt,
928+ llama_context * ctx_mtp,
929+ const llama_token * tokens, // n_tokens consecutive tokens (covers this batch)
930+ int32_t n_tokens,
931+ int32_t pos_first, // absolute position of tokens[0] in the chunk
932+ int32_t n_embd,
933+ llama_seq_id seq_id,
934+ llama_batch & mtp_batch) { // pre-allocated, embd-capable batch (token+embd both alloc'd)
935+ if (n_tokens < 2 ) {
936+ return true ; // need at least one (h[p], token[p+1]) pair
937+ }
938+ const int32_t n_pairs = n_tokens - 1 ;
939+
940+ const size_t row_bytes = (size_t ) n_embd * sizeof (float );
941+
942+ common_batch_clear (mtp_batch);
943+
944+ for (int32_t k = 0 ; k < n_pairs; ++k) {
945+ // MTP position p+1 carries the next-token id and h[p] from the trunk.
946+ common_batch_add (mtp_batch, tokens[k + 1 ], pos_first + k + 1 , { seq_id }, false );
947+ }
948+
949+ // Fill h[p] rows from the trunk's pre-norm output.
950+ for (int32_t k = 0 ; k < n_pairs; ++k) {
951+ const float * h = llama_get_embeddings_pre_norm_ith (ctx_tgt, k);
952+ if (h == nullptr ) {
953+ LOG_ERR (" %s: trunk did not produce pre-norm embedding at row %d (was output enabled?)\n " , __func__, k);
954+ return false ;
955+ }
956+ std::memcpy (mtp_batch.embd + (size_t ) k * n_embd, h, row_bytes);
957+ }
958+
959+ if (llama_decode (ctx_mtp, mtp_batch) != 0 ) {
960+ LOG_ERR (" %s: llama_decode(ctx_mtp) failed\n " , __func__);
961+ return false ;
962+ }
963+ return true ;
964+ }
965+
966+ static bool compute_imatrix (llama_context * ctx, llama_context * ctx_mtp, const common_params & params, const int32_t n_ctx) {
920967 const llama_model * model = llama_get_model (ctx);
921968 const llama_vocab * vocab = llama_model_get_vocab (model);
922969
@@ -975,15 +1022,37 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params, c
9751022
9761023 llama_batch batch = llama_batch_init (std::min (n_batch, n_ctx*n_seq), 0 , 1 );
9771024
1025+ // Optional MTP/NextN draft head batch. Only used when ctx_mtp != nullptr.
1026+ // llama_batch_init() only allocates one of token/embd; MTP needs both, so we
1027+ // patch in a token buffer alongside (same trick as common/speculative.cpp).
1028+ llama_batch mtp_batch = {};
1029+ bool mtp_enabled = (ctx_mtp != nullptr );
1030+ const int n_embd = llama_model_n_embd (model);
1031+ if (mtp_enabled) {
1032+ if (n_seq != 1 ) {
1033+ LOG_WRN (" %s: --imat-mtp is only supported with n_seq=1 (one sequence per batch); disabling MTP collection\n " , __func__);
1034+ mtp_enabled = false ;
1035+ } else {
1036+ mtp_batch = llama_batch_init (std::min (n_batch, n_ctx), n_embd, 1 );
1037+ mtp_batch.token = (llama_token *) malloc (sizeof (llama_token) * std::min (n_batch, n_ctx));
1038+ }
1039+ }
1040+
9781041 std::vector<float > logits;
9791042 if (params.compute_ppl && num_batches > 1 ) {
9801043 logits.reserve ((size_t )n_ctx * n_vocab);
9811044 }
9821045
983- LOG_INF (" %s: computing over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n " , __func__, n_chunk, n_ctx, n_batch, n_seq);
1046+ LOG_INF (" %s: computing over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d%s\n " ,
1047+ __func__, n_chunk, n_ctx, n_batch, n_seq, mtp_enabled ? " (mtp head active)" : " " );
9841048
9851049 std::vector<std::thread> workers (std::thread::hardware_concurrency () - 1 );
9861050
1051+ if (mtp_enabled) {
1052+ // Trunk must expose the pre-norm hidden state so we can feed it into the MTP head.
1053+ llama_set_embeddings_pre_norm (ctx, true , /* masked=*/ false );
1054+ }
1055+
9871056 for (int i = 0 ; i < n_chunk; i += n_seq) {
9881057 const int start = i * n_ctx;
9891058 const int end = start + n_ctx;
@@ -994,6 +1063,9 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params, c
9941063
9951064 // clear the KV cache
9961065 llama_memory_clear (llama_get_memory (ctx), true );
1066+ if (mtp_enabled) {
1067+ llama_memory_clear (llama_get_memory (ctx_mtp), true );
1068+ }
9971069
9981070 for (int j = 0 ; j < num_batches; ++j) {
9991071 const int batch_start = start + j * n_batch;
@@ -1027,9 +1099,27 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params, c
10271099 if (llama_decode (ctx, batch)) {
10281100 LOG_ERR (" %s : failed to eval\n " , __func__);
10291101 llama_batch_free (batch);
1102+ if (mtp_enabled) {
1103+ free (mtp_batch.token );
1104+ llama_batch_free (mtp_batch);
1105+ }
10301106 return false ;
10311107 }
10321108
1109+ if (mtp_enabled) {
1110+ // The sub-batch covers absolute positions [batch_start, batch_start + batch_size).
1111+ // tokens.data() + batch_start gives the matching token ids.
1112+ const int32_t pos_first = j * n_batch;
1113+ if (!compute_imatrix_mtp (ctx, ctx_mtp,
1114+ tokens.data () + batch_start, batch_size,
1115+ pos_first, n_embd, /* seq_id=*/ 0 , mtp_batch)) {
1116+ llama_batch_free (batch);
1117+ free (mtp_batch.token );
1118+ llama_batch_free (mtp_batch);
1119+ return false ;
1120+ }
1121+ }
1122+
10331123 if (params.compute_ppl && num_batches > 1 ) {
10341124 const auto * batch_logits = llama_get_logits (ctx);
10351125 logits.insert (logits.end (), batch_logits, batch_logits + batch_size * n_vocab);
@@ -1089,6 +1179,10 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params, c
10891179 }
10901180
10911181 llama_batch_free (batch);
1182+ if (mtp_enabled) {
1183+ free (mtp_batch.token );
1184+ llama_batch_free (mtp_batch);
1185+ }
10921186
10931187 return true ;
10941188}
@@ -1303,7 +1397,28 @@ int main(int argc, char ** argv) {
13031397 LOG_INF (" %s\n " , common_params_get_system_info (params).c_str ());
13041398 }
13051399
1306- if (!compute_imatrix (ctx, params, n_ctx)) {
1400+ // Optional second context for the MTP/NextN draft head. Shares the same model
1401+ // as `ctx`; uses LLAMA_CONTEXT_TYPE_MTP so the MTP graph is built/run instead
1402+ // of the trunk graph. The trunk feeds it pre-norm hidden states each batch.
1403+ llama_context * ctx_mtp = nullptr ;
1404+ if (params.imat_mtp ) {
1405+ if (llama_model_n_nextn_layer (model) == 0 ) {
1406+ LOG_WRN (" %s: --imat-mtp requested but model has no MTP/NextN layers; ignoring\n " , __func__);
1407+ } else {
1408+ auto cparams_mtp = common_context_params_to_llama (params);
1409+ cparams_mtp.ctx_type = LLAMA_CONTEXT_TYPE_MTP ;
1410+ cparams_mtp.n_rs_seq = 0 ;
1411+ ctx_mtp = llama_init_from_model (model, cparams_mtp);
1412+ if (ctx_mtp == nullptr ) {
1413+ LOG_ERR (" %s : failed to create MTP context\n " , __func__);
1414+ return 1 ;
1415+ }
1416+ LOG_INF (" %s: created MTP draft-head context for imatrix collection\n " , __func__);
1417+ }
1418+ }
1419+
1420+ if (!compute_imatrix (ctx, ctx_mtp, params, n_ctx)) {
1421+ if (ctx_mtp) llama_free (ctx_mtp);
13071422 return 1 ;
13081423 }
13091424
@@ -1312,6 +1427,10 @@ int main(int argc, char ** argv) {
13121427 LOG (" \n " );
13131428 llama_perf_context_print (ctx);
13141429
1430+ if (ctx_mtp) {
1431+ llama_free (ctx_mtp);
1432+ }
1433+
13151434 llama_backend_free ();
13161435
13171436 return 0 ;
0 commit comments