@@ -58,19 +58,20 @@ llama_context::llama_context(
5858 cparams.n_rs_seq = 0 ;
5959 }
6060
61- cparams.n_threads = params.n_threads ;
62- cparams.n_threads_batch = params.n_threads_batch ;
63- cparams.yarn_ext_factor = params.yarn_ext_factor >= 0 .0f ? params.yarn_ext_factor : hparams.yarn_ext_factor ;
64- cparams.yarn_attn_factor = params.yarn_attn_factor >= 0 .0f ? params.yarn_attn_factor : hparams.yarn_attn_factor ;
65- cparams.yarn_beta_fast = params.yarn_beta_fast >= 0 .0f ? params.yarn_beta_fast : hparams.yarn_beta_fast ;
66- cparams.yarn_beta_slow = params.yarn_beta_slow >= 0 .0f ? params.yarn_beta_slow : hparams.yarn_beta_slow ;
67- cparams.embeddings = params.embeddings ;
68- cparams.embeddings_pre_norm = false ;
69- cparams.embeddings_pre_norm_masked = false ;
70- cparams.offload_kqv = params.offload_kqv ;
71- cparams.no_perf = params.no_perf ;
72- cparams.pooling_type = params.pooling_type ;
73- cparams.warmup = false ;
61+ cparams.n_threads = params.n_threads ;
62+ cparams.n_threads_batch = params.n_threads_batch ;
63+ cparams.yarn_ext_factor = params.yarn_ext_factor >= 0 .0f ? params.yarn_ext_factor : hparams.yarn_ext_factor ;
64+ cparams.yarn_attn_factor = params.yarn_attn_factor >= 0 .0f ? params.yarn_attn_factor : hparams.yarn_attn_factor ;
65+ cparams.yarn_beta_fast = params.yarn_beta_fast >= 0 .0f ? params.yarn_beta_fast : hparams.yarn_beta_fast ;
66+ cparams.yarn_beta_slow = params.yarn_beta_slow >= 0 .0f ? params.yarn_beta_slow : hparams.yarn_beta_slow ;
67+ cparams.embeddings = params.embeddings ;
68+ cparams.embeddings_nextn = false ;
69+ cparams.embeddings_nextn_masked = false ;
70+ cparams.offload_kqv = params.offload_kqv ;
71+ cparams.no_perf = params.no_perf ;
72+ cparams.pooling_type = params.pooling_type ;
73+ cparams.warmup = false ;
74+
7475
7576 cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx ;
7677 cparams.rope_freq_base = params.rope_freq_base == 0 .0f ? hparams.rope_freq_base_train : params.rope_freq_base ;
@@ -889,34 +890,34 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
889890 return it->second .data ();
890891}
891892
892- float * llama_context::get_embeddings_pre_norm () {
893+ float * llama_context::get_embeddings_nextn () {
893894 output_reorder ();
894895
895- return embd_pre_norm .data ;
896+ return embd_nextn .data ;
896897}
897898
898- float * llama_context::get_embeddings_pre_norm_ith (int32_t i) {
899+ float * llama_context::get_embeddings_nextn_ith (int32_t i) {
899900 output_reorder ();
900901
901902 try {
902- if (embd_pre_norm .data == nullptr ) {
903- throw std::runtime_error (" no pre-norm embeddings" );
903+ if (embd_nextn .data == nullptr ) {
904+ throw std::runtime_error (" no nextn embeddings" );
904905 }
905906
906907 const uint32_t n_embd = model.hparams .n_embd ;
907908
908- if (!cparams.embeddings_pre_norm_masked ) {
909- // unmasked: pre-norm rows are stored densely, indexed by raw token position.
910- if (i < 0 || (size_t )(i + 1 ) * n_embd > embd_pre_norm .size ) {
911- throw std::runtime_error (format (" out of range [0, %zu)" , embd_pre_norm .size / n_embd));
909+ if (!cparams.embeddings_nextn_masked ) {
910+ // unmasked: nextn rows are stored densely, indexed by raw token position.
911+ if (i < 0 || (size_t )(i + 1 ) * n_embd > embd_nextn .size ) {
912+ throw std::runtime_error (format (" out of range [0, %zu)" , embd_nextn .size / n_embd));
912913 }
913- return embd_pre_norm .data + (size_t ) i * n_embd;
914+ return embd_nextn .data + (size_t ) i * n_embd;
914915 }
915916
916917 const int64_t j = output_resolve_row (i);
917- return embd_pre_norm .data + j*n_embd;
918+ return embd_nextn .data + j*n_embd;
918919 } catch (const std::exception & err) {
919- LLAMA_LOG_ERROR (" %s: invalid pre-norm embeddings id %d, reason: %s\n " , __func__, i, err.what ());
920+ LLAMA_LOG_ERROR (" %s: invalid nextn embeddings id %d, reason: %s\n " , __func__, i, err.what ());
920921#ifndef NDEBUG
921922 GGML_ABORT (" fatal error" );
922923#else
@@ -1105,11 +1106,11 @@ void llama_context::set_embeddings(bool value) {
11051106 // sched_need_reserve = true;
11061107}
11071108
1108- void llama_context::set_embeddings_pre_norm (bool value, bool masked) {
1109+ void llama_context::set_embeddings_nextn (bool value, bool masked) {
11091110 LLAMA_LOG_DEBUG (" %s: value = %d, masked = %d\n " , __func__, value, masked);
11101111
1111- cparams.embeddings_pre_norm = value;
1112- cparams.embeddings_pre_norm_masked = masked;
1112+ cparams.embeddings_nextn = value;
1113+ cparams.embeddings_nextn_masked = masked;
11131114}
11141115
11151116void llama_context::set_causal_attn (bool value) {
@@ -1326,7 +1327,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
13261327}
13271328
13281329int llama_context::encode (const llama_batch & batch_inp) {
1329- // MTP hook batches carry both token (next-token id) and embd (h_pre_norm row),
1330+ // MTP hook batches carry both token (next-token id) and embd (h_nextn row),
13301331 // so accept either present rather than requiring exactly one.
13311332 GGML_ASSERT (batch_inp.token || batch_inp.embd );
13321333
@@ -1399,9 +1400,9 @@ int llama_context::encode(const llama_batch & batch_inp) {
13991400 }
14001401 }
14011402
1402- auto * t_logits = res->get_logits ();
1403- auto * t_embd = res->get_embd_pooled () ? res->get_embd_pooled () : res->get_embd ();
1404- auto * t_h_pre_norm = cparams.embeddings_pre_norm ? res->get_h_pre_norm () : nullptr ;
1403+ auto * t_logits = res->get_logits ();
1404+ auto * t_embd = res->get_embd_pooled () ? res->get_embd_pooled () : res->get_embd ();
1405+ auto * t_h_nextn = cparams.embeddings_nextn ? res->get_h_nextn () : nullptr ;
14051406
14061407 // extract logits
14071408 if (logits.data && t_logits) {
@@ -1467,14 +1468,14 @@ int llama_context::encode(const llama_batch & batch_inp) {
14671468 }
14681469 }
14691470
1470- // extract pre-norm embeddings (hidden state before the final output norm)
1471- if (embd_pre_norm .data && t_h_pre_norm && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
1472- ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend (sched.get (), t_h_pre_norm );
1471+ // extract nextn embeddings (hidden state before the final output norm)
1472+ if (embd_nextn .data && t_h_nextn && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
1473+ ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend (sched.get (), t_h_nextn );
14731474 GGML_ASSERT (backend_h != nullptr );
14741475
14751476 const uint32_t n_embd = hparams.n_embd ;
1476- GGML_ASSERT (n_tokens*n_embd <= (int64_t ) embd_pre_norm .size );
1477- ggml_backend_tensor_get_async (backend_h, t_h_pre_norm, embd_pre_norm .data , 0 , n_tokens*n_embd*sizeof (float ));
1477+ GGML_ASSERT (n_tokens*n_embd <= (int64_t ) embd_nextn .size );
1478+ ggml_backend_tensor_get_async (backend_h, t_h_nextn, embd_nextn .data , 0 , n_tokens*n_embd*sizeof (float ));
14781479 }
14791480
14801481 // TODO: hacky solution
@@ -1629,7 +1630,7 @@ static bool needs_raw_logits(const llama_ubatch & ubatch, const std::map<llama_s
16291630}
16301631
16311632int llama_context::decode (const llama_batch & batch_inp) {
1632- // MTP hook batches carry both token (next-token id) and embd (h_pre_norm row),
1633+ // MTP hook batches carry both token (next-token id) and embd (h_nextn row),
16331634 // so accept either present rather than requiring exactly one.
16341635 GGML_ASSERT (batch_inp.token || batch_inp.embd );
16351636
@@ -1829,9 +1830,9 @@ int llama_context::decode(const llama_batch & batch_inp) {
18291830 // ggml_graph_dump_dot(gf, NULL, "llama.dot");
18301831 // }
18311832
1832- auto * t_logits = res->get_logits ();
1833- auto * t_embd = cparams.embeddings ? res->get_embd () : nullptr ;
1834- auto * t_h_pre_norm = cparams.embeddings_pre_norm ? res->get_h_pre_norm () : nullptr ;
1833+ auto * t_logits = res->get_logits ();
1834+ auto * t_embd = cparams.embeddings ? res->get_embd () : nullptr ;
1835+ auto * t_h_nextn = cparams.embeddings_nextn ? res->get_h_nextn () : nullptr ;
18351836
18361837 if (t_embd && res->get_embd_pooled ()) {
18371838 t_embd = res->get_embd_pooled ();
@@ -1912,22 +1913,22 @@ int llama_context::decode(const llama_batch & batch_inp) {
19121913 }
19131914 }
19141915
1915- // extract pre-norm embeddings (hidden state before the final output norm)
1916+ // extract nextn embeddings before
19161917 // only meaningful in LLAMA_POOLING_TYPE_NONE (per-token); other pooling modes are ignored.
19171918 {
1918- const bool masked = cparams.embeddings_pre_norm_masked ;
1919+ const bool masked = cparams.embeddings_nextn_masked ;
19191920 const int64_t n_rows = masked ? n_outputs : (int64_t ) ubatch.n_tokens ;
19201921 const int64_t offset = masked ? n_outputs_prev : n_tokens_prev;
19211922
1922- if (embd_pre_norm .data && t_h_pre_norm && n_rows > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
1923- ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend (sched.get (), t_h_pre_norm );
1923+ if (embd_nextn .data && t_h_nextn && n_rows > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
1924+ ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend (sched.get (), t_h_nextn );
19241925 GGML_ASSERT (backend_h != nullptr );
19251926
1926- const uint32_t n_embd = hparams.n_embd ;
1927- float * embd_pre_norm_out = embd_pre_norm .data + offset*n_embd;
1927+ const uint32_t n_embd = hparams.n_embd ;
1928+ float * embd_nextn_out = embd_nextn .data + offset*n_embd;
19281929
1929- GGML_ASSERT ((offset + n_rows)*n_embd <= (int64_t ) embd_pre_norm .size );
1930- ggml_backend_tensor_get_async (backend_h, t_h_pre_norm, embd_pre_norm_out , 0 , n_rows*n_embd*sizeof (float ));
1930+ GGML_ASSERT ((offset + n_rows)*n_embd <= (int64_t ) embd_nextn .size );
1931+ ggml_backend_tensor_get_async (backend_h, t_h_nextn, embd_nextn_out , 0 , n_rows*n_embd*sizeof (float ));
19311932 }
19321933 }
19331934
@@ -2019,9 +2020,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
20192020 const auto n_embd = hparams.n_embd ;
20202021 const auto n_embd_out = hparams.n_embd_out ();
20212022
2022- bool has_logits = true ;
2023- bool has_embd = cparams.embeddings ;
2024- bool has_embd_pre_norm = cparams.embeddings_pre_norm ;
2023+ bool has_logits = true ;
2024+ bool has_embd = cparams.embeddings ;
2025+ bool has_embd_nextn = cparams.embeddings_nextn ;
20252026
20262027 // TODO: hacky enc-dec support
20272028 if (model.arch == LLM_ARCH_T5) {
@@ -2033,14 +2034,14 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
20332034 size_t backend_float_count = 0 ;
20342035 size_t backend_token_count = 0 ;
20352036
2036- logits.size = has_logits ? n_vocab*n_outputs_max : 0 ;
2037- embd.size = has_embd ? n_embd_out*n_outputs_max : 0 ;
2038- embd_pre_norm .size = has_embd_pre_norm ? n_embd*n_outputs_max : 0 ;
2037+ logits.size = has_logits ? n_vocab*n_outputs_max : 0 ;
2038+ embd.size = has_embd ? n_embd_out*n_outputs_max : 0 ;
2039+ embd_nextn .size = has_embd_nextn ? n_embd*n_outputs_max : 0 ;
20392040
2040- if (has_embd_pre_norm && !cparams.embeddings_pre_norm_masked ) {
2041- // unmasked: pre-norm row exists for every token in the batch, not just
2041+ if (has_embd_nextn && !cparams.embeddings_nextn_masked ) {
2042+ // unmasked: nextn row exists for every token in the batch, not just
20422043 // those flagged via batch.logits[i] -> size by token count instead.
2043- embd_pre_norm .size = (size_t ) n_embd * n_batch;
2044+ embd_nextn .size = (size_t ) n_embd * n_batch;
20442045 }
20452046
20462047 // Allocate backend sampling output buffers if there are backend samplers configured.
@@ -2057,7 +2058,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
20572058
20582059 const size_t prev_size = buf_output ? ggml_backend_buffer_get_size (buf_output.get ()) : 0 ;
20592060 const size_t new_size =
2060- (logits.size + embd.size + embd_pre_norm .size + backend_float_count) * sizeof (float ) +
2061+ (logits.size + embd.size + embd_nextn .size + backend_float_count) * sizeof (float ) +
20612062 ( backend_token_count) * sizeof (llama_token);
20622063
20632064 // alloc only when more than the current capacity is required
@@ -2074,7 +2075,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
20742075 buf_output = nullptr ;
20752076 logits.data = nullptr ;
20762077 embd.data = nullptr ;
2077- embd_pre_norm .data = nullptr ;
2078+ embd_nextn .data = nullptr ;
20782079 }
20792080
20802081 auto * buft = ggml_backend_cpu_buffer_type ();
@@ -2103,8 +2104,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
21032104 embd = has_embd ? buffer_view<float >{(float *) (base + offset), embd.size } : buffer_view<float >{nullptr , 0 };
21042105 offset += embd.size * sizeof (float );
21052106
2106- embd_pre_norm = has_embd_pre_norm ? buffer_view<float >{(float *) (base + offset), embd_pre_norm .size } : buffer_view<float >{nullptr , 0 };
2107- offset += embd_pre_norm .size * sizeof (float );
2107+ embd_nextn = has_embd_nextn ? buffer_view<float >{(float *) (base + offset), embd_nextn .size } : buffer_view<float >{nullptr , 0 };
2108+ offset += embd_nextn .size * sizeof (float );
21082109
21092110 if (has_sampling) {
21102111 sampling.logits = {(float *) (base + offset), (size_t )(n_vocab*n_outputs_max)};
@@ -2172,9 +2173,9 @@ void llama_context::output_reorder() {
21722173 }
21732174 }
21742175
2175- if (embd_pre_norm .size > 0 ) {
2176+ if (embd_nextn .size > 0 ) {
21762177 for (uint64_t k = 0 ; k < n_embd; k++) {
2177- std::swap (embd_pre_norm .data [i0*n_embd + k], embd_pre_norm .data [i1*n_embd + k]);
2178+ std::swap (embd_nextn .data [i0*n_embd + k], embd_nextn .data [i1*n_embd + k]);
21782179 }
21792180 }
21802181
@@ -3588,20 +3589,20 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
35883589 return ctx->get_embeddings_seq (seq_id);
35893590}
35903591
3591- void llama_set_embeddings_pre_norm (llama_context * ctx, bool value, bool masked) {
3592- ctx->set_embeddings_pre_norm (value, masked);
3592+ void llama_set_embeddings_nextn (llama_context * ctx, bool value, bool masked) {
3593+ ctx->set_embeddings_nextn (value, masked);
35933594}
35943595
3595- float * llama_get_embeddings_pre_norm (llama_context * ctx) {
3596+ float * llama_get_embeddings_nextn (llama_context * ctx) {
35963597 ctx->synchronize ();
35973598
3598- return ctx->get_embeddings_pre_norm ();
3599+ return ctx->get_embeddings_nextn ();
35993600}
36003601
3601- float * llama_get_embeddings_pre_norm_ith (llama_context * ctx, int32_t i) {
3602+ float * llama_get_embeddings_nextn_ith (llama_context * ctx, int32_t i) {
36023603 ctx->synchronize ();
36033604
3604- return ctx->get_embeddings_pre_norm_ith (i);
3605+ return ctx->get_embeddings_nextn_ith (i);
36053606}
36063607
36073608bool llama_set_sampler (llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) {
0 commit comments