Skip to content

Commit 7e00869

Browse files
committed
rename pre_norm to nextn
1 parent e94341d commit 7e00869

11 files changed

Lines changed: 115 additions & 121 deletions

File tree

common/speculative.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
#include "common.h"
44
#include "ggml.h"
55
#include "llama.h"
6-
#include "../src/llama-ext.h" // staging API: llama_set_embeddings_pre_norm / llama_get_embeddings_pre_norm_ith (used by MTP)
6+
#include "../src/llama-ext.h" // staging API: llama_set_embeddings_nextn / llama_get_embeddings_nextn_ith (used by MTP)
77
#include "log.h"
88
#include "ngram-cache.h"
99
#include "ngram-map.h"
@@ -162,7 +162,7 @@ struct common_speculative_impl {
162162
virtual bool need_embd() const = 0;
163163

164164
// true if this implementation requires the target context to extract pre-norm embeddings
165-
virtual bool need_embd_pre_norm() const { return false; }
165+
virtual bool need_embd_nextn() const { return false; }
166166
};
167167

168168
struct common_speculative_impl_draft_simple : public common_speculative_impl {
@@ -487,8 +487,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
487487
}
488488
}
489489

490-
llama_set_embeddings_pre_norm(ctx_tgt, true, /*masked*/ false);
491-
llama_set_embeddings_pre_norm(ctx_dft, true, /*masked*/ true);
490+
llama_set_embeddings_nextn(ctx_tgt, true, /*masked*/ false);
491+
llama_set_embeddings_nextn(ctx_dft, true, /*masked*/ true);
492492

493493
pending_h.assign(n_seq, std::vector<float>(n_embd, 0.0f));
494494

@@ -583,7 +583,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
583583
// ^--- this is a problem
584584
// TODO:this is generally true, but would be nice to assert it
585585
{
586-
const float * h_tgt = llama_get_embeddings_pre_norm(ctx_tgt);
586+
const float * h_tgt = llama_get_embeddings_nextn(ctx_tgt);
587587
std::memcpy(batch.embd + (size_t) 1 * n_embd, h_tgt, row_bytes * (n_tokens-1));
588588

589589
//{
@@ -625,7 +625,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
625625
verify_h[seq_id].resize((size_t) n_rows * n_embd);
626626

627627
for (int32_t i = 0; i < n_rows; ++i) {
628-
const float * h = llama_get_embeddings_pre_norm_ith(ctx_tgt, i_batch_beg[seq_id] + i);
628+
const float * h = llama_get_embeddings_nextn_ith(ctx_tgt, i_batch_beg[seq_id] + i);
629629
std::memcpy(verify_h[seq_id].data() + (size_t) i * n_embd, h, row_bytes);
630630
}
631631

@@ -686,7 +686,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
686686
auto * smpl = smpls[seq_id].get();
687687

688688
common_sampler_sample(smpl, ctx_dft, i_batch, true);
689-
h_row = llama_get_embeddings_pre_norm_ith(ctx_dft, i_batch);
689+
h_row = llama_get_embeddings_nextn_ith(ctx_dft, i_batch);
690690
++i_batch;
691691

692692
const auto * cur_p = common_sampler_get_candidates(smpl, true);
@@ -772,7 +772,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
772772
return false;
773773
}
774774

775-
bool need_embd_pre_norm() const override {
775+
bool need_embd_nextn() const override {
776776
return true;
777777
}
778778
};
@@ -1539,13 +1539,13 @@ bool common_speculative_need_embd(common_speculative * spec) {
15391539
return false;
15401540
}
15411541

1542-
bool common_speculative_need_embd_pre_norm(common_speculative * spec) {
1542+
bool common_speculative_need_embd_nextn(common_speculative * spec) {
15431543
if (spec == nullptr) {
15441544
return false;
15451545
}
15461546

15471547
for (auto & impl : spec->impls) {
1548-
if (impl->need_embd_pre_norm()) {
1548+
if (impl->need_embd_nextn()) {
15491549
return true;
15501550
}
15511551
}

common/speculative.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ bool common_speculative_process(common_speculative * spec, const llama_batch & b
5959
// true if any implementation requires target post-norm embeddings to be extracted
6060
bool common_speculative_need_embd(common_speculative * spec);
6161

62-
// true if any implementation requires target pre-norm embeddings to be extracted
63-
bool common_speculative_need_embd_pre_norm(common_speculative * spec);
62+
// true if any implementation requires target nextn embeddings to be extracted
63+
bool common_speculative_need_embd_nextn(common_speculative * spec);
6464

6565
// generate drafts for the sequences specified with `common_speculative_get_draft_params`
6666
void common_speculative_draft(common_speculative * spec);

src/llama-context.cpp

Lines changed: 71 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -58,19 +58,20 @@ llama_context::llama_context(
5858
cparams.n_rs_seq = 0;
5959
}
6060

61-
cparams.n_threads = params.n_threads;
62-
cparams.n_threads_batch = params.n_threads_batch;
63-
cparams.yarn_ext_factor = params.yarn_ext_factor >= 0.0f ? params.yarn_ext_factor : hparams.yarn_ext_factor;
64-
cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor;
65-
cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast;
66-
cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? params.yarn_beta_slow : hparams.yarn_beta_slow;
67-
cparams.embeddings = params.embeddings;
68-
cparams.embeddings_pre_norm = false;
69-
cparams.embeddings_pre_norm_masked = false;
70-
cparams.offload_kqv = params.offload_kqv;
71-
cparams.no_perf = params.no_perf;
72-
cparams.pooling_type = params.pooling_type;
73-
cparams.warmup = false;
61+
cparams.n_threads = params.n_threads;
62+
cparams.n_threads_batch = params.n_threads_batch;
63+
cparams.yarn_ext_factor = params.yarn_ext_factor >= 0.0f ? params.yarn_ext_factor : hparams.yarn_ext_factor;
64+
cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor;
65+
cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast;
66+
cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? params.yarn_beta_slow : hparams.yarn_beta_slow;
67+
cparams.embeddings = params.embeddings;
68+
cparams.embeddings_nextn = false;
69+
cparams.embeddings_nextn_masked = false;
70+
cparams.offload_kqv = params.offload_kqv;
71+
cparams.no_perf = params.no_perf;
72+
cparams.pooling_type = params.pooling_type;
73+
cparams.warmup = false;
74+
7475

7576
cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
7677
cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
@@ -889,34 +890,34 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
889890
return it->second.data();
890891
}
891892

892-
float * llama_context::get_embeddings_pre_norm() {
893+
float * llama_context::get_embeddings_nextn() {
893894
output_reorder();
894895

895-
return embd_pre_norm.data;
896+
return embd_nextn.data;
896897
}
897898

898-
float * llama_context::get_embeddings_pre_norm_ith(int32_t i) {
899+
float * llama_context::get_embeddings_nextn_ith(int32_t i) {
899900
output_reorder();
900901

901902
try {
902-
if (embd_pre_norm.data == nullptr) {
903-
throw std::runtime_error("no pre-norm embeddings");
903+
if (embd_nextn.data == nullptr) {
904+
throw std::runtime_error("no nextn embeddings");
904905
}
905906

906907
const uint32_t n_embd = model.hparams.n_embd;
907908

908-
if (!cparams.embeddings_pre_norm_masked) {
909-
// unmasked: pre-norm rows are stored densely, indexed by raw token position.
910-
if (i < 0 || (size_t)(i + 1) * n_embd > embd_pre_norm.size) {
911-
throw std::runtime_error(format("out of range [0, %zu)", embd_pre_norm.size / n_embd));
909+
if (!cparams.embeddings_nextn_masked) {
910+
// unmasked: nextn rows are stored densely, indexed by raw token position.
911+
if (i < 0 || (size_t)(i + 1) * n_embd > embd_nextn.size) {
912+
throw std::runtime_error(format("out of range [0, %zu)", embd_nextn.size / n_embd));
912913
}
913-
return embd_pre_norm.data + (size_t) i * n_embd;
914+
return embd_nextn.data + (size_t) i * n_embd;
914915
}
915916

916917
const int64_t j = output_resolve_row(i);
917-
return embd_pre_norm.data + j*n_embd;
918+
return embd_nextn.data + j*n_embd;
918919
} catch (const std::exception & err) {
919-
LLAMA_LOG_ERROR("%s: invalid pre-norm embeddings id %d, reason: %s\n", __func__, i, err.what());
920+
LLAMA_LOG_ERROR("%s: invalid nextn embeddings id %d, reason: %s\n", __func__, i, err.what());
920921
#ifndef NDEBUG
921922
GGML_ABORT("fatal error");
922923
#else
@@ -1105,11 +1106,11 @@ void llama_context::set_embeddings(bool value) {
11051106
//sched_need_reserve = true;
11061107
}
11071108

1108-
void llama_context::set_embeddings_pre_norm(bool value, bool masked) {
1109+
void llama_context::set_embeddings_nextn(bool value, bool masked) {
11091110
LLAMA_LOG_DEBUG("%s: value = %d, masked = %d\n", __func__, value, masked);
11101111

1111-
cparams.embeddings_pre_norm = value;
1112-
cparams.embeddings_pre_norm_masked = masked;
1112+
cparams.embeddings_nextn = value;
1113+
cparams.embeddings_nextn_masked = masked;
11131114
}
11141115

11151116
void llama_context::set_causal_attn(bool value) {
@@ -1326,7 +1327,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
13261327
}
13271328

13281329
int llama_context::encode(const llama_batch & batch_inp) {
1329-
// MTP hook batches carry both token (next-token id) and embd (h_pre_norm row),
1330+
// MTP hook batches carry both token (next-token id) and embd (h_nextn row),
13301331
// so accept either present rather than requiring exactly one.
13311332
GGML_ASSERT(batch_inp.token || batch_inp.embd);
13321333

@@ -1399,9 +1400,9 @@ int llama_context::encode(const llama_batch & batch_inp) {
13991400
}
14001401
}
14011402

1402-
auto * t_logits = res->get_logits();
1403-
auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
1404-
auto * t_h_pre_norm = cparams.embeddings_pre_norm ? res->get_h_pre_norm() : nullptr;
1403+
auto * t_logits = res->get_logits();
1404+
auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
1405+
auto * t_h_nextn = cparams.embeddings_nextn ? res->get_h_nextn() : nullptr;
14051406

14061407
// extract logits
14071408
if (logits.data && t_logits) {
@@ -1467,14 +1468,14 @@ int llama_context::encode(const llama_batch & batch_inp) {
14671468
}
14681469
}
14691470

1470-
// extract pre-norm embeddings (hidden state before the final output norm)
1471-
if (embd_pre_norm.data && t_h_pre_norm && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
1472-
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm);
1471+
// extract nextn embeddings (hidden state before the final output norm)
1472+
if (embd_nextn.data && t_h_nextn && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
1473+
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_nextn);
14731474
GGML_ASSERT(backend_h != nullptr);
14741475

14751476
const uint32_t n_embd = hparams.n_embd;
1476-
GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_pre_norm.size);
1477-
ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm.data, 0, n_tokens*n_embd*sizeof(float));
1477+
GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_nextn.size);
1478+
ggml_backend_tensor_get_async(backend_h, t_h_nextn, embd_nextn.data, 0, n_tokens*n_embd*sizeof(float));
14781479
}
14791480

14801481
// TODO: hacky solution
@@ -1629,7 +1630,7 @@ static bool needs_raw_logits(const llama_ubatch & ubatch, const std::map<llama_s
16291630
}
16301631

16311632
int llama_context::decode(const llama_batch & batch_inp) {
1632-
// MTP hook batches carry both token (next-token id) and embd (h_pre_norm row),
1633+
// MTP hook batches carry both token (next-token id) and embd (h_nextn row),
16331634
// so accept either present rather than requiring exactly one.
16341635
GGML_ASSERT(batch_inp.token || batch_inp.embd);
16351636

@@ -1829,9 +1830,9 @@ int llama_context::decode(const llama_batch & batch_inp) {
18291830
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
18301831
//}
18311832

1832-
auto * t_logits = res->get_logits();
1833-
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
1834-
auto * t_h_pre_norm = cparams.embeddings_pre_norm ? res->get_h_pre_norm() : nullptr;
1833+
auto * t_logits = res->get_logits();
1834+
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
1835+
auto * t_h_nextn = cparams.embeddings_nextn ? res->get_h_nextn() : nullptr;
18351836

18361837
if (t_embd && res->get_embd_pooled()) {
18371838
t_embd = res->get_embd_pooled();
@@ -1912,22 +1913,22 @@ int llama_context::decode(const llama_batch & batch_inp) {
19121913
}
19131914
}
19141915

1915-
// extract pre-norm embeddings (hidden state before the final output norm)
1916+
// extract nextn embeddings before
19161917
// only meaningful in LLAMA_POOLING_TYPE_NONE (per-token); other pooling modes are ignored.
19171918
{
1918-
const bool masked = cparams.embeddings_pre_norm_masked;
1919+
const bool masked = cparams.embeddings_nextn_masked;
19191920
const int64_t n_rows = masked ? n_outputs : (int64_t) ubatch.n_tokens;
19201921
const int64_t offset = masked ? n_outputs_prev : n_tokens_prev;
19211922

1922-
if (embd_pre_norm.data && t_h_pre_norm && n_rows > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
1923-
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm);
1923+
if (embd_nextn.data && t_h_nextn && n_rows > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
1924+
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_nextn);
19241925
GGML_ASSERT(backend_h != nullptr);
19251926

1926-
const uint32_t n_embd = hparams.n_embd;
1927-
float * embd_pre_norm_out = embd_pre_norm.data + offset*n_embd;
1927+
const uint32_t n_embd = hparams.n_embd;
1928+
float * embd_nextn_out = embd_nextn.data + offset*n_embd;
19281929

1929-
GGML_ASSERT((offset + n_rows)*n_embd <= (int64_t) embd_pre_norm.size);
1930-
ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm_out, 0, n_rows*n_embd*sizeof(float));
1930+
GGML_ASSERT((offset + n_rows)*n_embd <= (int64_t) embd_nextn.size);
1931+
ggml_backend_tensor_get_async(backend_h, t_h_nextn, embd_nextn_out, 0, n_rows*n_embd*sizeof(float));
19311932
}
19321933
}
19331934

@@ -2019,9 +2020,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
20192020
const auto n_embd = hparams.n_embd;
20202021
const auto n_embd_out = hparams.n_embd_out();
20212022

2022-
bool has_logits = true;
2023-
bool has_embd = cparams.embeddings;
2024-
bool has_embd_pre_norm = cparams.embeddings_pre_norm;
2023+
bool has_logits = true;
2024+
bool has_embd = cparams.embeddings;
2025+
bool has_embd_nextn = cparams.embeddings_nextn;
20252026

20262027
// TODO: hacky enc-dec support
20272028
if (model.arch == LLM_ARCH_T5) {
@@ -2033,14 +2034,14 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
20332034
size_t backend_float_count = 0;
20342035
size_t backend_token_count = 0;
20352036

2036-
logits.size = has_logits ? n_vocab*n_outputs_max : 0;
2037-
embd.size = has_embd ? n_embd_out*n_outputs_max : 0;
2038-
embd_pre_norm.size = has_embd_pre_norm ? n_embd*n_outputs_max : 0;
2037+
logits.size = has_logits ? n_vocab*n_outputs_max : 0;
2038+
embd.size = has_embd ? n_embd_out*n_outputs_max : 0;
2039+
embd_nextn.size = has_embd_nextn ? n_embd*n_outputs_max : 0;
20392040

2040-
if (has_embd_pre_norm && !cparams.embeddings_pre_norm_masked) {
2041-
// unmasked: pre-norm row exists for every token in the batch, not just
2041+
if (has_embd_nextn && !cparams.embeddings_nextn_masked) {
2042+
// unmasked: nextn row exists for every token in the batch, not just
20422043
// those flagged via batch.logits[i] -> size by token count instead.
2043-
embd_pre_norm.size = (size_t) n_embd * n_batch;
2044+
embd_nextn.size = (size_t) n_embd * n_batch;
20442045
}
20452046

20462047
// Allocate backend sampling output buffers if there are backend samplers configured.
@@ -2057,7 +2058,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
20572058

20582059
const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0;
20592060
const size_t new_size =
2060-
(logits.size + embd.size + embd_pre_norm.size + backend_float_count) * sizeof(float) +
2061+
(logits.size + embd.size + embd_nextn.size + backend_float_count) * sizeof(float) +
20612062
( backend_token_count) * sizeof(llama_token);
20622063

20632064
// alloc only when more than the current capacity is required
@@ -2074,7 +2075,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
20742075
buf_output = nullptr;
20752076
logits.data = nullptr;
20762077
embd.data = nullptr;
2077-
embd_pre_norm.data = nullptr;
2078+
embd_nextn.data = nullptr;
20782079
}
20792080

20802081
auto * buft = ggml_backend_cpu_buffer_type();
@@ -2103,8 +2104,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
21032104
embd = has_embd ? buffer_view<float>{(float *) (base + offset), embd.size} : buffer_view<float>{nullptr, 0};
21042105
offset += embd.size * sizeof(float);
21052106

2106-
embd_pre_norm = has_embd_pre_norm ? buffer_view<float>{(float *) (base + offset), embd_pre_norm.size} : buffer_view<float>{nullptr, 0};
2107-
offset += embd_pre_norm.size * sizeof(float);
2107+
embd_nextn = has_embd_nextn ? buffer_view<float>{(float *) (base + offset), embd_nextn.size} : buffer_view<float>{nullptr, 0};
2108+
offset += embd_nextn.size * sizeof(float);
21082109

21092110
if (has_sampling) {
21102111
sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
@@ -2172,9 +2173,9 @@ void llama_context::output_reorder() {
21722173
}
21732174
}
21742175

2175-
if (embd_pre_norm.size > 0) {
2176+
if (embd_nextn.size > 0) {
21762177
for (uint64_t k = 0; k < n_embd; k++) {
2177-
std::swap(embd_pre_norm.data[i0*n_embd + k], embd_pre_norm.data[i1*n_embd + k]);
2178+
std::swap(embd_nextn.data[i0*n_embd + k], embd_nextn.data[i1*n_embd + k]);
21782179
}
21792180
}
21802181

@@ -3588,20 +3589,20 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
35883589
return ctx->get_embeddings_seq(seq_id);
35893590
}
35903591

3591-
void llama_set_embeddings_pre_norm(llama_context * ctx, bool value, bool masked) {
3592-
ctx->set_embeddings_pre_norm(value, masked);
3592+
void llama_set_embeddings_nextn(llama_context * ctx, bool value, bool masked) {
3593+
ctx->set_embeddings_nextn(value, masked);
35933594
}
35943595

3595-
float * llama_get_embeddings_pre_norm(llama_context * ctx) {
3596+
float * llama_get_embeddings_nextn(llama_context * ctx) {
35963597
ctx->synchronize();
35973598

3598-
return ctx->get_embeddings_pre_norm();
3599+
return ctx->get_embeddings_nextn();
35993600
}
36003601

3601-
float * llama_get_embeddings_pre_norm_ith(llama_context * ctx, int32_t i) {
3602+
float * llama_get_embeddings_nextn_ith(llama_context * ctx, int32_t i) {
36023603
ctx->synchronize();
36033604

3604-
return ctx->get_embeddings_pre_norm_ith(i);
3605+
return ctx->get_embeddings_nextn_ith(i);
36053606
}
36063607

36073608
bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) {

0 commit comments

Comments
 (0)