@@ -58,19 +58,20 @@ llama_context::llama_context(
5858 cparams.n_rs_seq = 0 ;
5959 }
6060
61- cparams.n_threads = params.n_threads ;
62- cparams.n_threads_batch = params.n_threads_batch ;
63- cparams.yarn_ext_factor = params.yarn_ext_factor >= 0 .0f ? params.yarn_ext_factor : hparams.yarn_ext_factor ;
64- cparams.yarn_attn_factor = params.yarn_attn_factor >= 0 .0f ? params.yarn_attn_factor : hparams.yarn_attn_factor ;
65- cparams.yarn_beta_fast = params.yarn_beta_fast >= 0 .0f ? params.yarn_beta_fast : hparams.yarn_beta_fast ;
66- cparams.yarn_beta_slow = params.yarn_beta_slow >= 0 .0f ? params.yarn_beta_slow : hparams.yarn_beta_slow ;
67- cparams.embeddings = params.embeddings ;
68- cparams.embeddings_pre_norm = false ;
69- cparams.embeddings_pre_norm_masked = false ;
70- cparams.offload_kqv = params.offload_kqv ;
71- cparams.no_perf = params.no_perf ;
72- cparams.pooling_type = params.pooling_type ;
73- cparams.warmup = false ;
61+ cparams.n_threads = params.n_threads ;
62+ cparams.n_threads_batch = params.n_threads_batch ;
63+ cparams.yarn_ext_factor = params.yarn_ext_factor >= 0 .0f ? params.yarn_ext_factor : hparams.yarn_ext_factor ;
64+ cparams.yarn_attn_factor = params.yarn_attn_factor >= 0 .0f ? params.yarn_attn_factor : hparams.yarn_attn_factor ;
65+ cparams.yarn_beta_fast = params.yarn_beta_fast >= 0 .0f ? params.yarn_beta_fast : hparams.yarn_beta_fast ;
66+ cparams.yarn_beta_slow = params.yarn_beta_slow >= 0 .0f ? params.yarn_beta_slow : hparams.yarn_beta_slow ;
67+ cparams.embeddings = params.embeddings ;
68+ cparams.embeddings_nextn = false ;
69+ cparams.embeddings_nextn_masked = false ;
70+ cparams.offload_kqv = params.offload_kqv ;
71+ cparams.no_perf = params.no_perf ;
72+ cparams.pooling_type = params.pooling_type ;
73+ cparams.warmup = false ;
74+
7475
7576 cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx ;
7677 cparams.rope_freq_base = params.rope_freq_base == 0 .0f ? hparams.rope_freq_base_train : params.rope_freq_base ;
@@ -882,34 +883,34 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
882883 return it->second .data ();
883884}
884885
885- float * llama_context::get_embeddings_pre_norm () {
886+ float * llama_context::get_embeddings_nextn () {
886887 output_reorder ();
887888
888- return embd_pre_norm .data ;
889+ return embd_nextn .data ;
889890}
890891
891- float * llama_context::get_embeddings_pre_norm_ith (int32_t i) {
892+ float * llama_context::get_embeddings_nextn_ith (int32_t i) {
892893 output_reorder ();
893894
894895 try {
895- if (embd_pre_norm .data == nullptr ) {
896- throw std::runtime_error (" no pre-norm embeddings" );
896+ if (embd_nextn .data == nullptr ) {
897+ throw std::runtime_error (" no nextn embeddings" );
897898 }
898899
899900 const uint32_t n_embd = model.hparams .n_embd ;
900901
901- if (!cparams.embeddings_pre_norm_masked ) {
902- // unmasked: pre-norm rows are stored densely, indexed by raw token position.
903- if (i < 0 || (size_t )(i + 1 ) * n_embd > embd_pre_norm .size ) {
904- throw std::runtime_error (format (" out of range [0, %zu)" , embd_pre_norm .size / n_embd));
902+ if (!cparams.embeddings_nextn_masked ) {
903+ // unmasked: nextn rows are stored densely, indexed by raw token position.
904+ if (i < 0 || (size_t )(i + 1 ) * n_embd > embd_nextn .size ) {
905+ throw std::runtime_error (format (" out of range [0, %zu)" , embd_nextn .size / n_embd));
905906 }
906- return embd_pre_norm .data + (size_t ) i * n_embd;
907+ return embd_nextn .data + (size_t ) i * n_embd;
907908 }
908909
909910 const int64_t j = output_resolve_row (i);
910- return embd_pre_norm .data + j*n_embd;
911+ return embd_nextn .data + j*n_embd;
911912 } catch (const std::exception & err) {
912- LLAMA_LOG_ERROR (" %s: invalid pre-norm embeddings id %d, reason: %s\n " , __func__, i, err.what ());
913+ LLAMA_LOG_ERROR (" %s: invalid nextn embeddings id %d, reason: %s\n " , __func__, i, err.what ());
913914#ifndef NDEBUG
914915 GGML_ABORT (" fatal error" );
915916#else
@@ -1098,11 +1099,11 @@ void llama_context::set_embeddings(bool value) {
10981099 // sched_need_reserve = true;
10991100}
11001101
1101- void llama_context::set_embeddings_pre_norm (bool value, bool masked) {
1102+ void llama_context::set_embeddings_nextn (bool value, bool masked) {
11021103 LLAMA_LOG_DEBUG (" %s: value = %d, masked = %d\n " , __func__, value, masked);
11031104
1104- cparams.embeddings_pre_norm = value;
1105- cparams.embeddings_pre_norm_masked = masked;
1105+ cparams.embeddings_nextn = value;
1106+ cparams.embeddings_nextn_masked = masked;
11061107}
11071108
11081109void llama_context::set_causal_attn (bool value) {
@@ -1319,7 +1320,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
13191320}
13201321
13211322int llama_context::encode (const llama_batch & batch_inp) {
1322- // MTP hook batches carry both token (next-token id) and embd (h_pre_norm row),
1323+ // MTP hook batches carry both token (next-token id) and embd (h_nextn row),
13231324 // so accept either present rather than requiring exactly one.
13241325 GGML_ASSERT (batch_inp.token || batch_inp.embd );
13251326
@@ -1392,9 +1393,9 @@ int llama_context::encode(const llama_batch & batch_inp) {
13921393 }
13931394 }
13941395
1395- auto * t_logits = res->get_logits ();
1396- auto * t_embd = res->get_embd_pooled () ? res->get_embd_pooled () : res->get_embd ();
1397- auto * t_h_pre_norm = cparams.embeddings_pre_norm ? res->get_h_pre_norm () : nullptr ;
1396+ auto * t_logits = res->get_logits ();
1397+ auto * t_embd = res->get_embd_pooled () ? res->get_embd_pooled () : res->get_embd ();
1398+ auto * t_h_nextn = cparams.embeddings_nextn ? res->get_h_nextn () : nullptr ;
13981399
13991400 // extract logits
14001401 if (logits.data && t_logits) {
@@ -1460,14 +1461,14 @@ int llama_context::encode(const llama_batch & batch_inp) {
14601461 }
14611462 }
14621463
1463- // extract pre-norm embeddings (hidden state before the final output norm)
1464- if (embd_pre_norm .data && t_h_pre_norm && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
1465- ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend (sched.get (), t_h_pre_norm );
1464+ // extract nextn embeddings (hidden state before the final output norm)
1465+ if (embd_nextn .data && t_h_nextn && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
1466+ ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend (sched.get (), t_h_nextn );
14661467 GGML_ASSERT (backend_h != nullptr );
14671468
14681469 const uint32_t n_embd = hparams.n_embd ;
1469- GGML_ASSERT (n_tokens*n_embd <= (int64_t ) embd_pre_norm .size );
1470- ggml_backend_tensor_get_async (backend_h, t_h_pre_norm, embd_pre_norm .data , 0 , n_tokens*n_embd*sizeof (float ));
1470+ GGML_ASSERT (n_tokens*n_embd <= (int64_t ) embd_nextn .size );
1471+ ggml_backend_tensor_get_async (backend_h, t_h_nextn, embd_nextn .data , 0 , n_tokens*n_embd*sizeof (float ));
14711472 }
14721473
14731474 // TODO: hacky solution
@@ -1622,7 +1623,7 @@ static bool needs_raw_logits(const llama_ubatch & ubatch, const std::map<llama_s
16221623}
16231624
16241625int llama_context::decode (const llama_batch & batch_inp) {
1625- // MTP hook batches carry both token (next-token id) and embd (h_pre_norm row),
1626+ // MTP hook batches carry both token (next-token id) and embd (h_nextn row),
16261627 // so accept either present rather than requiring exactly one.
16271628 GGML_ASSERT (batch_inp.token || batch_inp.embd );
16281629
@@ -1822,9 +1823,9 @@ int llama_context::decode(const llama_batch & batch_inp) {
18221823 // ggml_graph_dump_dot(gf, NULL, "llama.dot");
18231824 // }
18241825
1825- auto * t_logits = res->get_logits ();
1826- auto * t_embd = cparams.embeddings ? res->get_embd () : nullptr ;
1827- auto * t_h_pre_norm = cparams.embeddings_pre_norm ? res->get_h_pre_norm () : nullptr ;
1826+ auto * t_logits = res->get_logits ();
1827+ auto * t_embd = cparams.embeddings ? res->get_embd () : nullptr ;
1828+ auto * t_h_nextn = cparams.embeddings_nextn ? res->get_h_nextn () : nullptr ;
18281829
18291830 if (t_embd && res->get_embd_pooled ()) {
18301831 t_embd = res->get_embd_pooled ();
@@ -1905,22 +1906,22 @@ int llama_context::decode(const llama_batch & batch_inp) {
19051906 }
19061907 }
19071908
1908- // extract pre-norm embeddings (hidden state before the final output norm)
1909+ // extract nextn embeddings before
19091910 // only meaningful in LLAMA_POOLING_TYPE_NONE (per-token); other pooling modes are ignored.
19101911 {
1911- const bool masked = cparams.embeddings_pre_norm_masked ;
1912+ const bool masked = cparams.embeddings_nextn_masked ;
19121913 const int64_t n_rows = masked ? n_outputs : (int64_t ) ubatch.n_tokens ;
19131914 const int64_t offset = masked ? n_outputs_prev : n_tokens_prev;
19141915
1915- if (embd_pre_norm .data && t_h_pre_norm && n_rows > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
1916- ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend (sched.get (), t_h_pre_norm );
1916+ if (embd_nextn .data && t_h_nextn && n_rows > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
1917+ ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend (sched.get (), t_h_nextn );
19171918 GGML_ASSERT (backend_h != nullptr );
19181919
1919- const uint32_t n_embd = hparams.n_embd ;
1920- float * embd_pre_norm_out = embd_pre_norm .data + offset*n_embd;
1920+ const uint32_t n_embd = hparams.n_embd ;
1921+ float * embd_nextn_out = embd_nextn .data + offset*n_embd;
19211922
1922- GGML_ASSERT ((offset + n_rows)*n_embd <= (int64_t ) embd_pre_norm .size );
1923- ggml_backend_tensor_get_async (backend_h, t_h_pre_norm, embd_pre_norm_out , 0 , n_rows*n_embd*sizeof (float ));
1923+ GGML_ASSERT ((offset + n_rows)*n_embd <= (int64_t ) embd_nextn .size );
1924+ ggml_backend_tensor_get_async (backend_h, t_h_nextn, embd_nextn_out , 0 , n_rows*n_embd*sizeof (float ));
19241925 }
19251926 }
19261927
@@ -2012,9 +2013,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
20122013 const auto n_embd = hparams.n_embd ;
20132014 const auto n_embd_out = hparams.n_embd_out ();
20142015
2015- bool has_logits = true ;
2016- bool has_embd = cparams.embeddings ;
2017- bool has_embd_pre_norm = cparams.embeddings_pre_norm ;
2016+ bool has_logits = true ;
2017+ bool has_embd = cparams.embeddings ;
2018+ bool has_embd_nextn = cparams.embeddings_nextn ;
20182019
20192020 // TODO: hacky enc-dec support
20202021 if (model.arch == LLM_ARCH_T5) {
@@ -2026,14 +2027,14 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
20262027 size_t backend_float_count = 0 ;
20272028 size_t backend_token_count = 0 ;
20282029
2029- logits.size = has_logits ? n_vocab*n_outputs_max : 0 ;
2030- embd.size = has_embd ? n_embd_out*n_outputs_max : 0 ;
2031- embd_pre_norm .size = has_embd_pre_norm ? n_embd*n_outputs_max : 0 ;
2030+ logits.size = has_logits ? n_vocab*n_outputs_max : 0 ;
2031+ embd.size = has_embd ? n_embd_out*n_outputs_max : 0 ;
2032+ embd_nextn .size = has_embd_nextn ? n_embd*n_outputs_max : 0 ;
20322033
2033- if (has_embd_pre_norm && !cparams.embeddings_pre_norm_masked ) {
2034- // unmasked: pre-norm row exists for every token in the batch, not just
2034+ if (has_embd_nextn && !cparams.embeddings_nextn_masked ) {
2035+ // unmasked: nextn row exists for every token in the batch, not just
20352036 // those flagged via batch.logits[i] -> size by token count instead.
2036- embd_pre_norm .size = (size_t ) n_embd * n_batch;
2037+ embd_nextn .size = (size_t ) n_embd * n_batch;
20372038 }
20382039
20392040 // Allocate backend sampling output buffers if there are backend samplers configured.
@@ -2050,7 +2051,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
20502051
20512052 const size_t prev_size = buf_output ? ggml_backend_buffer_get_size (buf_output.get ()) : 0 ;
20522053 const size_t new_size =
2053- (logits.size + embd.size + embd_pre_norm .size + backend_float_count) * sizeof (float ) +
2054+ (logits.size + embd.size + embd_nextn .size + backend_float_count) * sizeof (float ) +
20542055 ( backend_token_count) * sizeof (llama_token);
20552056
20562057 // alloc only when more than the current capacity is required
@@ -2067,7 +2068,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
20672068 buf_output = nullptr ;
20682069 logits.data = nullptr ;
20692070 embd.data = nullptr ;
2070- embd_pre_norm .data = nullptr ;
2071+ embd_nextn .data = nullptr ;
20712072 }
20722073
20732074 auto * buft = ggml_backend_cpu_buffer_type ();
@@ -2096,8 +2097,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
20962097 embd = has_embd ? buffer_view<float >{(float *) (base + offset), embd.size } : buffer_view<float >{nullptr , 0 };
20972098 offset += embd.size * sizeof (float );
20982099
2099- embd_pre_norm = has_embd_pre_norm ? buffer_view<float >{(float *) (base + offset), embd_pre_norm .size } : buffer_view<float >{nullptr , 0 };
2100- offset += embd_pre_norm .size * sizeof (float );
2100+ embd_nextn = has_embd_nextn ? buffer_view<float >{(float *) (base + offset), embd_nextn .size } : buffer_view<float >{nullptr , 0 };
2101+ offset += embd_nextn .size * sizeof (float );
21012102
21022103 if (has_sampling) {
21032104 sampling.logits = {(float *) (base + offset), (size_t )(n_vocab*n_outputs_max)};
@@ -2163,9 +2164,9 @@ void llama_context::output_reorder() {
21632164 }
21642165 }
21652166
2166- if (embd_pre_norm .size > 0 ) {
2167+ if (embd_nextn .size > 0 ) {
21672168 for (uint64_t k = 0 ; k < n_embd; k++) {
2168- std::swap (embd_pre_norm .data [i0*n_embd + k], embd_pre_norm .data [i1*n_embd + k]);
2169+ std::swap (embd_nextn .data [i0*n_embd + k], embd_nextn .data [i1*n_embd + k]);
21692170 }
21702171 }
21712172
@@ -3584,20 +3585,20 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
35843585 return ctx->get_embeddings_seq (seq_id);
35853586}
35863587
3587- void llama_set_embeddings_pre_norm (llama_context * ctx, bool value, bool masked) {
3588- ctx->set_embeddings_pre_norm (value, masked);
3588+ void llama_set_embeddings_nextn (llama_context * ctx, bool value, bool masked) {
3589+ ctx->set_embeddings_nextn (value, masked);
35893590}
35903591
3591- float * llama_get_embeddings_pre_norm (llama_context * ctx) {
3592+ float * llama_get_embeddings_nextn (llama_context * ctx) {
35923593 ctx->synchronize ();
35933594
3594- return ctx->get_embeddings_pre_norm ();
3595+ return ctx->get_embeddings_nextn ();
35953596}
35963597
3597- float * llama_get_embeddings_pre_norm_ith (llama_context * ctx, int32_t i) {
3598+ float * llama_get_embeddings_nextn_ith (llama_context * ctx, int32_t i) {
35983599 ctx->synchronize ();
35993600
3600- return ctx->get_embeddings_pre_norm_ith (i);
3601+ return ctx->get_embeddings_nextn_ith (i);
36013602}
36023603
36033604bool llama_set_sampler (llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) {
0 commit comments