Skip to content

Commit 9140cb9

Browse files
committed
rename pre_norm to nextn
1 parent fbd94d8 commit 9140cb9

11 files changed

Lines changed: 115 additions & 121 deletions

File tree

common/speculative.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
#include "common.h"
44
#include "ggml.h"
55
#include "llama.h"
6-
#include "../src/llama-ext.h" // staging API: llama_set_embeddings_pre_norm / llama_get_embeddings_pre_norm_ith (used by MTP)
6+
#include "../src/llama-ext.h" // staging API: llama_set_embeddings_nextn / llama_get_embeddings_nextn_ith (used by MTP)
77
#include "log.h"
88
#include "ngram-cache.h"
99
#include "ngram-map.h"
@@ -162,7 +162,7 @@ struct common_speculative_impl {
162162
virtual bool need_embd() const = 0;
163163

164164
// true if this implementation requires the target context to extract nextn embeddings (hidden state before the final output norm)
165-
virtual bool need_embd_pre_norm() const { return false; }
165+
virtual bool need_embd_nextn() const { return false; }
166166
};
167167

168168
struct common_speculative_impl_draft_simple : public common_speculative_impl {
@@ -487,8 +487,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
487487
}
488488
}
489489

490-
llama_set_embeddings_pre_norm(ctx_tgt, true, /*masked*/ false);
491-
llama_set_embeddings_pre_norm(ctx_dft, true, /*masked*/ true);
490+
llama_set_embeddings_nextn(ctx_tgt, true, /*masked*/ false);
491+
llama_set_embeddings_nextn(ctx_dft, true, /*masked*/ true);
492492

493493
pending_h.assign(n_seq, std::vector<float>(n_embd, 0.0f));
494494

@@ -583,7 +583,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
583583
// ^--- this is a problem
584584
// TODO: this is generally true, but would be nice to assert it
585585
{
586-
const float * h_tgt = llama_get_embeddings_pre_norm(ctx_tgt);
586+
const float * h_tgt = llama_get_embeddings_nextn(ctx_tgt);
587587
std::memcpy(batch.embd + (size_t) 1 * n_embd, h_tgt, row_bytes * (n_tokens-1));
588588

589589
//{
@@ -625,7 +625,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
625625
verify_h[seq_id].resize((size_t) n_rows * n_embd);
626626

627627
for (int32_t i = 0; i < n_rows; ++i) {
628-
const float * h = llama_get_embeddings_pre_norm_ith(ctx_tgt, i_batch_beg[seq_id] + i);
628+
const float * h = llama_get_embeddings_nextn_ith(ctx_tgt, i_batch_beg[seq_id] + i);
629629
std::memcpy(verify_h[seq_id].data() + (size_t) i * n_embd, h, row_bytes);
630630
}
631631

@@ -686,7 +686,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
686686
auto * smpl = smpls[seq_id].get();
687687

688688
common_sampler_sample(smpl, ctx_dft, i_batch, true);
689-
h_row = llama_get_embeddings_pre_norm_ith(ctx_dft, i_batch);
689+
h_row = llama_get_embeddings_nextn_ith(ctx_dft, i_batch);
690690
++i_batch;
691691

692692
const auto * cur_p = common_sampler_get_candidates(smpl, true);
@@ -772,7 +772,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
772772
return false;
773773
}
774774

775-
bool need_embd_pre_norm() const override {
775+
bool need_embd_nextn() const override {
776776
return true;
777777
}
778778
};
@@ -1517,13 +1517,13 @@ bool common_speculative_need_embd(common_speculative * spec) {
15171517
return false;
15181518
}
15191519

1520-
bool common_speculative_need_embd_pre_norm(common_speculative * spec) {
1520+
bool common_speculative_need_embd_nextn(common_speculative * spec) {
15211521
if (spec == nullptr) {
15221522
return false;
15231523
}
15241524

15251525
for (auto & impl : spec->impls) {
1526-
if (impl->need_embd_pre_norm()) {
1526+
if (impl->need_embd_nextn()) {
15271527
return true;
15281528
}
15291529
}

common/speculative.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ bool common_speculative_process(common_speculative * spec, const llama_batch & b
5656
// true if any implementation requires target post-norm embeddings to be extracted
5757
bool common_speculative_need_embd(common_speculative * spec);
5858

59-
// true if any implementation requires target pre-norm embeddings to be extracted
60-
bool common_speculative_need_embd_pre_norm(common_speculative * spec);
59+
// true if any implementation requires target nextn embeddings to be extracted
60+
bool common_speculative_need_embd_nextn(common_speculative * spec);
6161

6262
// generate drafts for the sequences specified with `common_speculative_get_draft_params`
6363
void common_speculative_draft(common_speculative * spec);

src/llama-context.cpp

Lines changed: 71 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -58,19 +58,20 @@ llama_context::llama_context(
5858
cparams.n_rs_seq = 0;
5959
}
6060

61-
cparams.n_threads = params.n_threads;
62-
cparams.n_threads_batch = params.n_threads_batch;
63-
cparams.yarn_ext_factor = params.yarn_ext_factor >= 0.0f ? params.yarn_ext_factor : hparams.yarn_ext_factor;
64-
cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor;
65-
cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast;
66-
cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? params.yarn_beta_slow : hparams.yarn_beta_slow;
67-
cparams.embeddings = params.embeddings;
68-
cparams.embeddings_pre_norm = false;
69-
cparams.embeddings_pre_norm_masked = false;
70-
cparams.offload_kqv = params.offload_kqv;
71-
cparams.no_perf = params.no_perf;
72-
cparams.pooling_type = params.pooling_type;
73-
cparams.warmup = false;
61+
cparams.n_threads = params.n_threads;
62+
cparams.n_threads_batch = params.n_threads_batch;
63+
cparams.yarn_ext_factor = params.yarn_ext_factor >= 0.0f ? params.yarn_ext_factor : hparams.yarn_ext_factor;
64+
cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor;
65+
cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast;
66+
cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? params.yarn_beta_slow : hparams.yarn_beta_slow;
67+
cparams.embeddings = params.embeddings;
68+
cparams.embeddings_nextn = false;
69+
cparams.embeddings_nextn_masked = false;
70+
cparams.offload_kqv = params.offload_kqv;
71+
cparams.no_perf = params.no_perf;
72+
cparams.pooling_type = params.pooling_type;
73+
cparams.warmup = false;
74+
7475

7576
cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
7677
cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
@@ -882,34 +883,34 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
882883
return it->second.data();
883884
}
884885

885-
float * llama_context::get_embeddings_pre_norm() {
886+
float * llama_context::get_embeddings_nextn() {
886887
output_reorder();
887888

888-
return embd_pre_norm.data;
889+
return embd_nextn.data;
889890
}
890891

891-
float * llama_context::get_embeddings_pre_norm_ith(int32_t i) {
892+
float * llama_context::get_embeddings_nextn_ith(int32_t i) {
892893
output_reorder();
893894

894895
try {
895-
if (embd_pre_norm.data == nullptr) {
896-
throw std::runtime_error("no pre-norm embeddings");
896+
if (embd_nextn.data == nullptr) {
897+
throw std::runtime_error("no nextn embeddings");
897898
}
898899

899900
const uint32_t n_embd = model.hparams.n_embd;
900901

901-
if (!cparams.embeddings_pre_norm_masked) {
902-
// unmasked: pre-norm rows are stored densely, indexed by raw token position.
903-
if (i < 0 || (size_t)(i + 1) * n_embd > embd_pre_norm.size) {
904-
throw std::runtime_error(format("out of range [0, %zu)", embd_pre_norm.size / n_embd));
902+
if (!cparams.embeddings_nextn_masked) {
903+
// unmasked: nextn rows are stored densely, indexed by raw token position.
904+
if (i < 0 || (size_t)(i + 1) * n_embd > embd_nextn.size) {
905+
throw std::runtime_error(format("out of range [0, %zu)", embd_nextn.size / n_embd));
905906
}
906-
return embd_pre_norm.data + (size_t) i * n_embd;
907+
return embd_nextn.data + (size_t) i * n_embd;
907908
}
908909

909910
const int64_t j = output_resolve_row(i);
910-
return embd_pre_norm.data + j*n_embd;
911+
return embd_nextn.data + j*n_embd;
911912
} catch (const std::exception & err) {
912-
LLAMA_LOG_ERROR("%s: invalid pre-norm embeddings id %d, reason: %s\n", __func__, i, err.what());
913+
LLAMA_LOG_ERROR("%s: invalid nextn embeddings id %d, reason: %s\n", __func__, i, err.what());
913914
#ifndef NDEBUG
914915
GGML_ABORT("fatal error");
915916
#else
@@ -1098,11 +1099,11 @@ void llama_context::set_embeddings(bool value) {
10981099
//sched_need_reserve = true;
10991100
}
11001101

1101-
void llama_context::set_embeddings_pre_norm(bool value, bool masked) {
1102+
void llama_context::set_embeddings_nextn(bool value, bool masked) {
11021103
LLAMA_LOG_DEBUG("%s: value = %d, masked = %d\n", __func__, value, masked);
11031104

1104-
cparams.embeddings_pre_norm = value;
1105-
cparams.embeddings_pre_norm_masked = masked;
1105+
cparams.embeddings_nextn = value;
1106+
cparams.embeddings_nextn_masked = masked;
11061107
}
11071108

11081109
void llama_context::set_causal_attn(bool value) {
@@ -1319,7 +1320,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
13191320
}
13201321

13211322
int llama_context::encode(const llama_batch & batch_inp) {
1322-
// MTP hook batches carry both token (next-token id) and embd (h_pre_norm row),
1323+
// MTP hook batches carry both token (next-token id) and embd (h_nextn row),
13231324
// so accept either present rather than requiring exactly one.
13241325
GGML_ASSERT(batch_inp.token || batch_inp.embd);
13251326

@@ -1392,9 +1393,9 @@ int llama_context::encode(const llama_batch & batch_inp) {
13921393
}
13931394
}
13941395

1395-
auto * t_logits = res->get_logits();
1396-
auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
1397-
auto * t_h_pre_norm = cparams.embeddings_pre_norm ? res->get_h_pre_norm() : nullptr;
1396+
auto * t_logits = res->get_logits();
1397+
auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
1398+
auto * t_h_nextn = cparams.embeddings_nextn ? res->get_h_nextn() : nullptr;
13981399

13991400
// extract logits
14001401
if (logits.data && t_logits) {
@@ -1460,14 +1461,14 @@ int llama_context::encode(const llama_batch & batch_inp) {
14601461
}
14611462
}
14621463

1463-
// extract pre-norm embeddings (hidden state before the final output norm)
1464-
if (embd_pre_norm.data && t_h_pre_norm && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
1465-
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm);
1464+
// extract nextn embeddings (hidden state before the final output norm)
1465+
if (embd_nextn.data && t_h_nextn && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
1466+
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_nextn);
14661467
GGML_ASSERT(backend_h != nullptr);
14671468

14681469
const uint32_t n_embd = hparams.n_embd;
1469-
GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_pre_norm.size);
1470-
ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm.data, 0, n_tokens*n_embd*sizeof(float));
1470+
GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_nextn.size);
1471+
ggml_backend_tensor_get_async(backend_h, t_h_nextn, embd_nextn.data, 0, n_tokens*n_embd*sizeof(float));
14711472
}
14721473

14731474
// TODO: hacky solution
@@ -1622,7 +1623,7 @@ static bool needs_raw_logits(const llama_ubatch & ubatch, const std::map<llama_s
16221623
}
16231624

16241625
int llama_context::decode(const llama_batch & batch_inp) {
1625-
// MTP hook batches carry both token (next-token id) and embd (h_pre_norm row),
1626+
// MTP hook batches carry both token (next-token id) and embd (h_nextn row),
16261627
// so accept either present rather than requiring exactly one.
16271628
GGML_ASSERT(batch_inp.token || batch_inp.embd);
16281629

@@ -1822,9 +1823,9 @@ int llama_context::decode(const llama_batch & batch_inp) {
18221823
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
18231824
//}
18241825

1825-
auto * t_logits = res->get_logits();
1826-
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
1827-
auto * t_h_pre_norm = cparams.embeddings_pre_norm ? res->get_h_pre_norm() : nullptr;
1826+
auto * t_logits = res->get_logits();
1827+
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
1828+
auto * t_h_nextn = cparams.embeddings_nextn ? res->get_h_nextn() : nullptr;
18281829

18291830
if (t_embd && res->get_embd_pooled()) {
18301831
t_embd = res->get_embd_pooled();
@@ -1905,22 +1906,22 @@ int llama_context::decode(const llama_batch & batch_inp) {
19051906
}
19061907
}
19071908

1908-
// extract pre-norm embeddings (hidden state before the final output norm)
1909+
// extract nextn embeddings (hidden state before the final output norm)
19091910
// only meaningful in LLAMA_POOLING_TYPE_NONE (per-token); other pooling modes are ignored.
19101911
{
1911-
const bool masked = cparams.embeddings_pre_norm_masked;
1912+
const bool masked = cparams.embeddings_nextn_masked;
19121913
const int64_t n_rows = masked ? n_outputs : (int64_t) ubatch.n_tokens;
19131914
const int64_t offset = masked ? n_outputs_prev : n_tokens_prev;
19141915

1915-
if (embd_pre_norm.data && t_h_pre_norm && n_rows > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
1916-
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm);
1916+
if (embd_nextn.data && t_h_nextn && n_rows > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
1917+
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_nextn);
19171918
GGML_ASSERT(backend_h != nullptr);
19181919

1919-
const uint32_t n_embd = hparams.n_embd;
1920-
float * embd_pre_norm_out = embd_pre_norm.data + offset*n_embd;
1920+
const uint32_t n_embd = hparams.n_embd;
1921+
float * embd_nextn_out = embd_nextn.data + offset*n_embd;
19211922

1922-
GGML_ASSERT((offset + n_rows)*n_embd <= (int64_t) embd_pre_norm.size);
1923-
ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm_out, 0, n_rows*n_embd*sizeof(float));
1923+
GGML_ASSERT((offset + n_rows)*n_embd <= (int64_t) embd_nextn.size);
1924+
ggml_backend_tensor_get_async(backend_h, t_h_nextn, embd_nextn_out, 0, n_rows*n_embd*sizeof(float));
19241925
}
19251926
}
19261927

@@ -2012,9 +2013,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
20122013
const auto n_embd = hparams.n_embd;
20132014
const auto n_embd_out = hparams.n_embd_out();
20142015

2015-
bool has_logits = true;
2016-
bool has_embd = cparams.embeddings;
2017-
bool has_embd_pre_norm = cparams.embeddings_pre_norm;
2016+
bool has_logits = true;
2017+
bool has_embd = cparams.embeddings;
2018+
bool has_embd_nextn = cparams.embeddings_nextn;
20182019

20192020
// TODO: hacky enc-dec support
20202021
if (model.arch == LLM_ARCH_T5) {
@@ -2026,14 +2027,14 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
20262027
size_t backend_float_count = 0;
20272028
size_t backend_token_count = 0;
20282029

2029-
logits.size = has_logits ? n_vocab*n_outputs_max : 0;
2030-
embd.size = has_embd ? n_embd_out*n_outputs_max : 0;
2031-
embd_pre_norm.size = has_embd_pre_norm ? n_embd*n_outputs_max : 0;
2030+
logits.size = has_logits ? n_vocab*n_outputs_max : 0;
2031+
embd.size = has_embd ? n_embd_out*n_outputs_max : 0;
2032+
embd_nextn.size = has_embd_nextn ? n_embd*n_outputs_max : 0;
20322033

2033-
if (has_embd_pre_norm && !cparams.embeddings_pre_norm_masked) {
2034-
// unmasked: pre-norm row exists for every token in the batch, not just
2034+
if (has_embd_nextn && !cparams.embeddings_nextn_masked) {
2035+
// unmasked: nextn row exists for every token in the batch, not just
20352036
// those flagged via batch.logits[i] -> size by token count instead.
2036-
embd_pre_norm.size = (size_t) n_embd * n_batch;
2037+
embd_nextn.size = (size_t) n_embd * n_batch;
20372038
}
20382039

20392040
// Allocate backend sampling output buffers if there are backend samplers configured.
@@ -2050,7 +2051,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
20502051

20512052
const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0;
20522053
const size_t new_size =
2053-
(logits.size + embd.size + embd_pre_norm.size + backend_float_count) * sizeof(float) +
2054+
(logits.size + embd.size + embd_nextn.size + backend_float_count) * sizeof(float) +
20542055
( backend_token_count) * sizeof(llama_token);
20552056

20562057
// alloc only when more than the current capacity is required
@@ -2067,7 +2068,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
20672068
buf_output = nullptr;
20682069
logits.data = nullptr;
20692070
embd.data = nullptr;
2070-
embd_pre_norm.data = nullptr;
2071+
embd_nextn.data = nullptr;
20712072
}
20722073

20732074
auto * buft = ggml_backend_cpu_buffer_type();
@@ -2096,8 +2097,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
20962097
embd = has_embd ? buffer_view<float>{(float *) (base + offset), embd.size} : buffer_view<float>{nullptr, 0};
20972098
offset += embd.size * sizeof(float);
20982099

2099-
embd_pre_norm = has_embd_pre_norm ? buffer_view<float>{(float *) (base + offset), embd_pre_norm.size} : buffer_view<float>{nullptr, 0};
2100-
offset += embd_pre_norm.size * sizeof(float);
2100+
embd_nextn = has_embd_nextn ? buffer_view<float>{(float *) (base + offset), embd_nextn.size} : buffer_view<float>{nullptr, 0};
2101+
offset += embd_nextn.size * sizeof(float);
21012102

21022103
if (has_sampling) {
21032104
sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
@@ -2163,9 +2164,9 @@ void llama_context::output_reorder() {
21632164
}
21642165
}
21652166

2166-
if (embd_pre_norm.size > 0) {
2167+
if (embd_nextn.size > 0) {
21672168
for (uint64_t k = 0; k < n_embd; k++) {
2168-
std::swap(embd_pre_norm.data[i0*n_embd + k], embd_pre_norm.data[i1*n_embd + k]);
2169+
std::swap(embd_nextn.data[i0*n_embd + k], embd_nextn.data[i1*n_embd + k]);
21692170
}
21702171
}
21712172

@@ -3584,20 +3585,20 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
35843585
return ctx->get_embeddings_seq(seq_id);
35853586
}
35863587

3587-
void llama_set_embeddings_pre_norm(llama_context * ctx, bool value, bool masked) {
3588-
ctx->set_embeddings_pre_norm(value, masked);
3588+
void llama_set_embeddings_nextn(llama_context * ctx, bool value, bool masked) {
3589+
ctx->set_embeddings_nextn(value, masked);
35893590
}
35903591

3591-
float * llama_get_embeddings_pre_norm(llama_context * ctx) {
3592+
float * llama_get_embeddings_nextn(llama_context * ctx) {
35923593
ctx->synchronize();
35933594

3594-
return ctx->get_embeddings_pre_norm();
3595+
return ctx->get_embeddings_nextn();
35953596
}
35963597

3597-
float * llama_get_embeddings_pre_norm_ith(llama_context * ctx, int32_t i) {
3598+
float * llama_get_embeddings_nextn_ith(llama_context * ctx, int32_t i) {
35983599
ctx->synchronize();
35993600

3600-
return ctx->get_embeddings_pre_norm_ith(i);
3601+
return ctx->get_embeddings_nextn_ith(i);
36013602
}
36023603

36033604
bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) {

0 commit comments

Comments
 (0)