Skip to content

Commit e18fd18

Browse files
committed
llama: avoid copying logits during prompt decode in MTP
1 parent 25b1bc9 commit e18fd18

9 files changed

Lines changed: 89 additions & 25 deletions

File tree

common/speculative.cpp

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,11 @@ struct common_speculative_impl {
146146

147147
virtual void accept(llama_seq_id seq_id, uint16_t n_accepted) = 0;
148148

149-
// true if this implementation requires the target context to extract embeddings
149+
// true if this implementation requires the target context to extract post-norm embeddings
150150
virtual bool need_embd() const = 0;
151+
152+
// true if this implementation requires the target context to extract pre-norm embeddings
153+
virtual bool need_embd_pre_norm() const { return false; }
151154
};
152155

153156
struct common_speculative_impl_draft_simple : public common_speculative_impl {
@@ -429,8 +432,8 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl {
429432
s.reset(common_sampler_init(llama_get_model(ctx_dft), sparams));
430433
}
431434

432-
llama_set_embeddings_pre_norm(ctx_tgt, true);
433-
llama_set_embeddings_pre_norm(ctx_dft, true);
435+
llama_set_embeddings_pre_norm(ctx_tgt, true, /*masked*/ false);
436+
llama_set_embeddings_pre_norm(ctx_dft, true, /*masked*/ true);
434437

435438
pending_h.assign(n_seq, std::vector<float>(n_embd, 0.0f));
436439

@@ -691,6 +694,10 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl {
691694
}
692695

693696
bool need_embd() const override {
697+
return false;
698+
}
699+
700+
bool need_embd_pre_norm() const override {
694701
return true;
695702
}
696703
};
@@ -1408,6 +1415,20 @@ bool common_speculative_need_embd(common_speculative * spec) {
14081415
return false;
14091416
}
14101417

1418+
bool common_speculative_need_embd_pre_norm(common_speculative * spec) {
1419+
if (spec == nullptr) {
1420+
return false;
1421+
}
1422+
1423+
for (auto & impl : spec->impls) {
1424+
if (impl->need_embd_pre_norm()) {
1425+
return true;
1426+
}
1427+
}
1428+
1429+
return false;
1430+
}
1431+
14111432
void common_speculative_draft(common_speculative * spec) {
14121433
if (spec == nullptr) {
14131434
return;

common/speculative.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,12 @@ void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, co
5353
// process the batch and update the internal state of the speculative context
5454
bool common_speculative_process(common_speculative * spec, const llama_batch & batch);
5555

56-
// true if any implementation requires target embeddings to be extracted
56+
// true if any implementation requires target post-norm embeddings to be extracted
5757
bool common_speculative_need_embd(common_speculative * spec);
5858

59+
// true if any implementation requires target pre-norm embeddings to be extracted
60+
bool common_speculative_need_embd_pre_norm(common_speculative * spec);
61+
5962
// generate drafts for the sequences specified with `common_speculative_get_draft_params`
6063
void common_speculative_draft(common_speculative * spec);
6164

src/llama-context.cpp

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -895,8 +895,17 @@ float * llama_context::get_embeddings_pre_norm_ith(int32_t i) {
895895
throw std::runtime_error("no pre-norm embeddings");
896896
}
897897

898-
const int64_t j = output_resolve_row(i);
899898
const uint32_t n_embd = model.hparams.n_embd;
899+
900+
if (!cparams.embeddings_pre_norm_masked) {
901+
// unmasked: pre-norm rows are stored densely, indexed by raw token position.
902+
if (i < 0 || (size_t)(i + 1) * n_embd > embd_pre_norm.size) {
903+
throw std::runtime_error(format("out of range [0, %zu)", embd_pre_norm.size / n_embd));
904+
}
905+
return embd_pre_norm.data + (size_t) i * n_embd;
906+
}
907+
908+
const int64_t j = output_resolve_row(i);
900909
return embd_pre_norm.data + j*n_embd;
901910
} catch (const std::exception & err) {
902911
LLAMA_LOG_ERROR("%s: invalid pre-norm embeddings id %d, reason: %s\n", __func__, i, err.what());
@@ -1088,10 +1097,11 @@ void llama_context::set_embeddings(bool value) {
10881097
//sched_need_reserve = true;
10891098
}
10901099

1091-
void llama_context::set_embeddings_pre_norm(bool value) {
1092-
LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
1100+
void llama_context::set_embeddings_pre_norm(bool value, bool masked) {
1101+
LLAMA_LOG_DEBUG("%s: value = %d, masked = %d\n", __func__, value, masked);
10931102

1094-
cparams.embeddings_pre_norm = value;
1103+
cparams.embeddings_pre_norm = value;
1104+
cparams.embeddings_pre_norm_masked = masked;
10951105
}
10961106

10971107
void llama_context::set_causal_attn(bool value) {
@@ -1737,6 +1747,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
17371747
};
17381748

17391749
int64_t n_outputs_prev = 0;
1750+
int64_t n_tokens_prev = 0;
17401751

17411752
do {
17421753
const auto & ubatch = mctx->get_ubatch();
@@ -1882,16 +1893,21 @@ int llama_context::decode(const llama_batch & batch_inp) {
18821893

18831894
// extract pre-norm embeddings (hidden state before the final output norm)
18841895
// only meaningful in LLAMA_POOLING_TYPE_NONE (per-token); other pooling modes are ignored.
1885-
if (embd_pre_norm.data && t_h_pre_norm && n_outputs > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
1886-
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm);
1887-
GGML_ASSERT(backend_h != nullptr);
1896+
{
1897+
const bool masked = cparams.embeddings_pre_norm_masked;
1898+
const int64_t n_rows = masked ? n_outputs : (int64_t) ubatch.n_tokens;
1899+
const int64_t offset = masked ? n_outputs_prev : n_tokens_prev;
1900+
1901+
if (embd_pre_norm.data && t_h_pre_norm && n_rows > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
1902+
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm);
1903+
GGML_ASSERT(backend_h != nullptr);
18881904

1889-
const uint32_t n_embd = hparams.n_embd;
1890-
float * embd_pre_norm_out = embd_pre_norm.data + n_outputs_prev*n_embd;
1905+
const uint32_t n_embd = hparams.n_embd;
1906+
float * embd_pre_norm_out = embd_pre_norm.data + offset*n_embd;
18911907

1892-
GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
1893-
GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_pre_norm.size);
1894-
ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm_out, 0, n_outputs*n_embd*sizeof(float));
1908+
GGML_ASSERT((offset + n_rows)*n_embd <= (int64_t) embd_pre_norm.size);
1909+
ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm_out, 0, n_rows*n_embd*sizeof(float));
1910+
}
18951911
}
18961912

18971913
// Copy backend sampling output if this ubatch produced any sampling tensors.
@@ -1908,6 +1924,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
19081924
}
19091925

19101926
n_outputs_prev += n_outputs;
1927+
n_tokens_prev += ubatch.n_tokens;
19111928
} while (mctx->next());
19121929

19131930
// set to total number of outputs in the batch, for use in llama_get_logits_ith
@@ -1999,6 +2016,12 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
19992016
embd.size = has_embd ? n_embd_out*n_outputs_max : 0;
20002017
embd_pre_norm.size = has_embd_pre_norm ? n_embd*n_outputs_max : 0;
20012018

2019+
if (has_embd_pre_norm && !cparams.embeddings_pre_norm_masked) {
2020+
// unmasked: pre-norm row exists for every token in the ubatch, not just
2021+
// those flagged via batch.logits[i] -> size by token count instead.
2022+
embd_pre_norm.size = (size_t) n_embd * n_batch;
2023+
}
2024+
20022025
// Allocate backend sampling output buffers if there are backend samplers configured.
20032026
const bool has_sampling = !sampling.samplers.empty();
20042027
if (has_sampling) {
@@ -3547,8 +3570,8 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
35473570
return ctx->get_embeddings_seq(seq_id);
35483571
}
35493572

3550-
void llama_set_embeddings_pre_norm(llama_context * ctx, bool value) {
3551-
ctx->set_embeddings_pre_norm(value);
3573+
void llama_set_embeddings_pre_norm(llama_context * ctx, bool value, bool masked) {
3574+
ctx->set_embeddings_pre_norm(value, masked);
35523575
}
35533576

35543577
float * llama_get_embeddings_pre_norm(llama_context * ctx) {

src/llama-context.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ struct llama_context {
110110
void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data);
111111

112112
void set_embeddings (bool value);
113-
void set_embeddings_pre_norm(bool value);
113+
void set_embeddings_pre_norm(bool value, bool masked);
114114
void set_causal_attn(bool value);
115115
void set_warmup(bool value);
116116

src/llama-cparams.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ struct llama_cparams {
2828
float yarn_beta_slow;
2929

3030
bool embeddings;
31-
bool embeddings_pre_norm; // also extract the hidden state before the final output norm
31+
bool embeddings_pre_norm; // also extract the hidden state before the final output norm
32+
bool embeddings_pre_norm_masked; // extract for only rows where batch.logits != 0
3233
bool causal_attn;
3334
bool offload_kqv;
3435
bool flash_attn;

src/llama-ext.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,13 @@ LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_c
9797
// LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);
9898
LLAMA_API void llama_set_embeddings_pre_norm(struct llama_context * ctx, bool value);
9999

100-
// mirrors:
101-
// LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
102-
LLAMA_API float * llama_get_embeddings_pre_norm(struct llama_context * ctx);
100+
// Set whether the context outputs pre-norm embeddings or not
101+
// If masked == true, output the embeddings only for the tokens with batch.logits != 0
102+
// If masked == false, output the embeddings for all tokens in the batch regardless of batch.logits
103+
LLAMA_API void llama_set_embeddings_pre_norm(struct llama_context * ctx, bool value, bool masked);
103104

104105
// mirrors:
106+
// LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
105107
// LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
108+
LLAMA_API float * llama_get_embeddings_pre_norm (struct llama_context * ctx);
106109
LLAMA_API float * llama_get_embeddings_pre_norm_ith(struct llama_context * ctx, int32_t i);

src/models/qwen35.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para
176176
cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il);
177177
}
178178

179-
if (il == n_transformer_layers - 1 && inp_out_ids) {
179+
if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_pre_norm_masked) {
180180
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
181181
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
182182
}
@@ -211,6 +211,10 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para
211211
cb(cur, "h_pre_norm", -1);
212212
res->t_h_pre_norm = cur;
213213

214+
if (!cparams.embeddings_pre_norm_masked && inp_out_ids) {
215+
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
216+
}
217+
214218
// Final norm
215219
cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
216220

src/models/qwen35moe.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p
199199
cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il);
200200
}
201201

202-
if (il == n_transformer_layers - 1 && inp_out_ids) {
202+
if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_pre_norm_masked) {
203203
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
204204
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
205205
}
@@ -234,6 +234,10 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p
234234
cb(cur, "h_pre_norm", -1);
235235
res->t_h_pre_norm = cur;
236236

237+
if (!cparams.embeddings_pre_norm_masked && inp_out_ids) {
238+
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
239+
}
240+
237241
// Final norm
238242
cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
239243

tools/server/server-context.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,11 @@ struct server_slot {
243243
return task->need_embd() || (spec && common_speculative_need_embd(spec));
244244
}
245245

246+
bool need_embd_pre_norm() const {
247+
GGML_ASSERT(task);
248+
return spec && common_speculative_need_embd_pre_norm(spec);
249+
}
250+
246251
// if the context does not have a memory module then all embeddings have to be computed within a single ubatch
247252
// also we cannot split if the pooling would require any past tokens
248253
// (MTP supports splitting — uses task->need_embd() not need_embd())

0 commit comments

Comments
 (0)