Skip to content

Commit b82080b

Browse files
jimbothigpen and claude committed
feat(imatrix): port ggml-org#23476 — collect imatrix for MTP/NextN draft-head layers
Adds --imat-mtp flag to llama-imatrix. When set on a model that has bundled NextN layers, creates a second LLAMA_CONTEXT_TYPE_MTP context and runs a forward pass through the draft head after each trunk batch, feeding (token[p+1], h[p]) pairs via the pre-norm embedding interface.

Adaptations vs upstream ggml-org#23476:
- Renamed --mtp flag to --imat-mtp (collision with deprecated --mtp alias at common/arg.cpp:1387, which maps to --spec-type draft-mtp)
- llama_set_embeddings_pre_norm is called with 3 arguments (ctx, true, false) per our fork's API (src/llama-ext.h:107 / llama-context.cpp:3839)
- llama_model_n_nextn_layer is used instead of the PR's new accessor; the PR's additions to llama.h and llama-model.cpp are dropped (the accessor is already present in the fork as llama_model_n_nextn_layer)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 8260329 commit b82080b

3 files changed

Lines changed: 132 additions & 3 deletions

File tree

common/arg.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2891,6 +2891,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
28912891
params.parse_special = true;
28922892
}
28932893
).set_examples({LLAMA_EXAMPLE_IMATRIX}));
2894+
add_opt(common_arg(
2895+
{"--imat-mtp"},
2896+
string_format("also activate the MTP/NextN draft head during imatrix collection so its tensors "
2897+
"(blk.<n>.nextn.eh_proj etc.) receive activations. No-op if the model has no MTP layers. "
2898+
"(default: %s)", params.imat_mtp ? "true" : "false"),
2899+
[](common_params & params) {
2900+
params.imat_mtp = true;
2901+
}
2902+
).set_examples({LLAMA_EXAMPLE_IMATRIX}));
28942903
add_opt(common_arg(
28952904
{"-pps"},
28962905
string_format("is the prompt shared across parallel sequences (default: %s)", params.is_pp_shared ? "true" : "false"),

common/common.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -713,6 +713,7 @@ struct common_params {
713713
bool compute_ppl = true; // whether to compute perplexity
714714
bool show_statistics = false; // show imatrix statistics per tensor
715715
bool parse_special = false; // whether to parse special tokens during imatrix tokenization
716+
bool imat_mtp = false; // also activate the MTP/NextN draft head so its tensors get imatrix data
716717

717718
// cvector-generator params
718719
int n_pca_batch = 100;

tools/imatrix/imatrix.cpp

Lines changed: 122 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include "common.h"
33
#include "log.h"
44
#include "llama.h"
5+
#include "../../src/llama-ext.h" // staging API: llama_set_embeddings_pre_norm / llama_get_embeddings_pre_norm_ith (used by MTP)
56
#include "gguf.h"
67

78
#include <algorithm>
@@ -916,7 +917,53 @@ static void process_logits(
916917
}
917918
}
918919

919-
static bool compute_imatrix(llama_context * ctx, const common_params & params, const int32_t n_ctx) {
920+
// Run a forward pass through the MTP/NextN draft head so its weights
921+
// (blk.<n_layer>.nextn.eh_proj etc.) receive activations and get recorded by
922+
// the imatrix collector. Mirrors common_speculative_state_draft_mtp::process():
923+
// the MTP head at position p is fed the next-token id (tokens[p+1]) paired
924+
// with the trunk's pre-norm hidden state h[p]. The last position of the
925+
// chunk has no next-token target and is dropped.
926+
static bool compute_imatrix_mtp(
927+
llama_context * ctx_tgt,
928+
llama_context * ctx_mtp,
929+
const llama_token * tokens, // n_tokens consecutive tokens (covers this batch)
930+
int32_t n_tokens,
931+
int32_t pos_first, // absolute position of tokens[0] in the chunk
932+
int32_t n_embd,
933+
llama_seq_id seq_id,
934+
llama_batch & mtp_batch) { // pre-allocated, embd-capable batch (token+embd both alloc'd)
935+
if (n_tokens < 2) {
936+
return true; // need at least one (h[p], token[p+1]) pair
937+
}
938+
const int32_t n_pairs = n_tokens - 1;
939+
940+
const size_t row_bytes = (size_t) n_embd * sizeof(float);
941+
942+
common_batch_clear(mtp_batch);
943+
944+
for (int32_t k = 0; k < n_pairs; ++k) {
945+
// MTP position p+1 carries the next-token id and h[p] from the trunk.
946+
common_batch_add(mtp_batch, tokens[k + 1], pos_first + k + 1, { seq_id }, false);
947+
}
948+
949+
// Fill h[p] rows from the trunk's pre-norm output.
950+
for (int32_t k = 0; k < n_pairs; ++k) {
951+
const float * h = llama_get_embeddings_pre_norm_ith(ctx_tgt, k);
952+
if (h == nullptr) {
953+
LOG_ERR("%s: trunk did not produce pre-norm embedding at row %d (was output enabled?)\n", __func__, k);
954+
return false;
955+
}
956+
std::memcpy(mtp_batch.embd + (size_t) k * n_embd, h, row_bytes);
957+
}
958+
959+
if (llama_decode(ctx_mtp, mtp_batch) != 0) {
960+
LOG_ERR("%s: llama_decode(ctx_mtp) failed\n", __func__);
961+
return false;
962+
}
963+
return true;
964+
}
965+
966+
static bool compute_imatrix(llama_context * ctx, llama_context * ctx_mtp, const common_params & params, const int32_t n_ctx) {
920967
const llama_model * model = llama_get_model(ctx);
921968
const llama_vocab * vocab = llama_model_get_vocab(model);
922969

@@ -975,15 +1022,37 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params, c
9751022

9761023
llama_batch batch = llama_batch_init(std::min(n_batch, n_ctx*n_seq), 0, 1);
9771024

1025+
// Optional MTP/NextN draft head batch. Only used when ctx_mtp != nullptr.
1026+
// llama_batch_init() only allocates one of token/embd; MTP needs both, so we
1027+
// patch in a token buffer alongside (same trick as common/speculative.cpp).
1028+
llama_batch mtp_batch = {};
1029+
bool mtp_enabled = (ctx_mtp != nullptr);
1030+
const int n_embd = llama_model_n_embd(model);
1031+
if (mtp_enabled) {
1032+
if (n_seq != 1) {
1033+
LOG_WRN("%s: --imat-mtp is only supported with n_seq=1 (one sequence per batch); disabling MTP collection\n", __func__);
1034+
mtp_enabled = false;
1035+
} else {
1036+
mtp_batch = llama_batch_init(std::min(n_batch, n_ctx), n_embd, 1);
1037+
mtp_batch.token = (llama_token *) malloc(sizeof(llama_token) * std::min(n_batch, n_ctx));
1038+
}
1039+
}
1040+
9781041
std::vector<float> logits;
9791042
if (params.compute_ppl && num_batches > 1) {
9801043
logits.reserve((size_t)n_ctx * n_vocab);
9811044
}
9821045

983-
LOG_INF("%s: computing over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq);
1046+
LOG_INF("%s: computing over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d%s\n",
1047+
__func__, n_chunk, n_ctx, n_batch, n_seq, mtp_enabled ? " (mtp head active)" : "");
9841048

9851049
std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
9861050

1051+
if (mtp_enabled) {
1052+
// Trunk must expose the pre-norm hidden state so we can feed it into the MTP head.
1053+
llama_set_embeddings_pre_norm(ctx, true, /*masked=*/false);
1054+
}
1055+
9871056
for (int i = 0; i < n_chunk; i += n_seq) {
9881057
const int start = i * n_ctx;
9891058
const int end = start + n_ctx;
@@ -994,6 +1063,9 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params, c
9941063

9951064
// clear the KV cache
9961065
llama_memory_clear(llama_get_memory(ctx), true);
1066+
if (mtp_enabled) {
1067+
llama_memory_clear(llama_get_memory(ctx_mtp), true);
1068+
}
9971069

9981070
for (int j = 0; j < num_batches; ++j) {
9991071
const int batch_start = start + j * n_batch;
@@ -1027,9 +1099,27 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params, c
10271099
if (llama_decode(ctx, batch)) {
10281100
LOG_ERR("%s : failed to eval\n", __func__);
10291101
llama_batch_free(batch);
1102+
if (mtp_enabled) {
1103+
free(mtp_batch.token);
1104+
llama_batch_free(mtp_batch);
1105+
}
10301106
return false;
10311107
}
10321108

1109+
if (mtp_enabled) {
1110+
// The sub-batch covers absolute positions [batch_start, batch_start + batch_size).
1111+
// tokens.data() + batch_start gives the matching token ids.
1112+
const int32_t pos_first = j * n_batch;
1113+
if (!compute_imatrix_mtp(ctx, ctx_mtp,
1114+
tokens.data() + batch_start, batch_size,
1115+
pos_first, n_embd, /*seq_id=*/0, mtp_batch)) {
1116+
llama_batch_free(batch);
1117+
free(mtp_batch.token);
1118+
llama_batch_free(mtp_batch);
1119+
return false;
1120+
}
1121+
}
1122+
10331123
if (params.compute_ppl && num_batches > 1) {
10341124
const auto * batch_logits = llama_get_logits(ctx);
10351125
logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
@@ -1089,6 +1179,10 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params, c
10891179
}
10901180

10911181
llama_batch_free(batch);
1182+
if (mtp_enabled) {
1183+
free(mtp_batch.token);
1184+
llama_batch_free(mtp_batch);
1185+
}
10921186

10931187
return true;
10941188
}
@@ -1303,7 +1397,28 @@ int main(int argc, char ** argv) {
13031397
LOG_INF("%s\n", common_params_get_system_info(params).c_str());
13041398
}
13051399

1306-
if (!compute_imatrix(ctx, params, n_ctx)) {
1400+
// Optional second context for the MTP/NextN draft head. Shares the same model
1401+
// as `ctx`; uses LLAMA_CONTEXT_TYPE_MTP so the MTP graph is built/run instead
1402+
// of the trunk graph. The trunk feeds it pre-norm hidden states each batch.
1403+
llama_context * ctx_mtp = nullptr;
1404+
if (params.imat_mtp) {
1405+
if (llama_model_n_nextn_layer(model) == 0) {
1406+
LOG_WRN("%s: --imat-mtp requested but model has no MTP/NextN layers; ignoring\n", __func__);
1407+
} else {
1408+
auto cparams_mtp = common_context_params_to_llama(params);
1409+
cparams_mtp.ctx_type = LLAMA_CONTEXT_TYPE_MTP;
1410+
cparams_mtp.n_rs_seq = 0;
1411+
ctx_mtp = llama_init_from_model(model, cparams_mtp);
1412+
if (ctx_mtp == nullptr) {
1413+
LOG_ERR("%s : failed to create MTP context\n", __func__);
1414+
return 1;
1415+
}
1416+
LOG_INF("%s: created MTP draft-head context for imatrix collection\n", __func__);
1417+
}
1418+
}
1419+
1420+
if (!compute_imatrix(ctx, ctx_mtp, params, n_ctx)) {
1421+
if (ctx_mtp) llama_free(ctx_mtp);
13071422
return 1;
13081423
}
13091424

@@ -1312,6 +1427,10 @@ int main(int argc, char ** argv) {
13121427
LOG("\n");
13131428
llama_perf_context_print(ctx);
13141429

1430+
if (ctx_mtp) {
1431+
llama_free(ctx_mtp);
1432+
}
1433+
13151434
llama_backend_free();
13161435

13171436
return 0;

0 commit comments

Comments
 (0)