@@ -895,8 +895,17 @@ float * llama_context::get_embeddings_pre_norm_ith(int32_t i) {
895895 throw std::runtime_error (" no pre-norm embeddings" );
896896 }
897897
898- const int64_t j = output_resolve_row (i);
899898 const uint32_t n_embd = model.hparams .n_embd ;
899+
900+ if (!cparams.embeddings_pre_norm_masked ) {
901+ // unmasked: pre-norm rows are stored densely, indexed by raw token position.
902+ if (i < 0 || (size_t )(i + 1 ) * n_embd > embd_pre_norm.size ) {
903+ throw std::runtime_error (format (" out of range [0, %zu)" , embd_pre_norm.size / n_embd));
904+ }
905+ return embd_pre_norm.data + (size_t ) i * n_embd;
906+ }
907+
908+ const int64_t j = output_resolve_row (i);
900909 return embd_pre_norm.data + j*n_embd;
901910 } catch (const std::exception & err) {
902911 LLAMA_LOG_ERROR (" %s: invalid pre-norm embeddings id %d, reason: %s\n " , __func__, i, err.what ());
@@ -1088,10 +1097,11 @@ void llama_context::set_embeddings(bool value) {
10881097 // sched_need_reserve = true;
10891098}
10901099
// Patch hunk for llama_context::set_embeddings_pre_norm: the one-argument
// definition (`-` lines) is replaced by a two-argument one (`+` lines).
//   value  - stored in cparams.embeddings_pre_norm.
//   masked - stored in cparams.embeddings_pre_norm_masked. Elsewhere in this patch
//            the flag selects between one pre-norm row per output (masked) and one
//            row per ubatch token (unmasked) when embd_pre_norm is sized in
//            output_reserve, filled in decode and indexed in
//            get_embeddings_pre_norm_ith.
// The setter itself only emits a debug log line and records the two flags.
// NOTE(review): no default for `masked` is visible in this definition, so existing
// one-argument callers presumably need updating - confirm against the declaration.
// NOTE(review): embd_pre_norm.size is derived from this flag in output_reserve;
// presumably a change of `masked` only takes effect at the next reserve - verify.
1091- void llama_context::set_embeddings_pre_norm (bool value) {
1092- LLAMA_LOG_DEBUG (" %s: value = %d\n " , __func__, value);
1100+ void llama_context::set_embeddings_pre_norm (bool value, bool masked ) {
1101+ LLAMA_LOG_DEBUG (" %s: value = %d, masked = %d \n " , __func__, value, masked );
10931102
// Both flags are plain assignments into cparams; nothing else is touched here.
1094- cparams.embeddings_pre_norm = value;
1103+ cparams.embeddings_pre_norm = value;
1104+ cparams.embeddings_pre_norm_masked = masked;
10951105}
10961106
10971107void llama_context::set_causal_attn (bool value) {
@@ -1737,6 +1747,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
17371747 };
17381748
17391749 int64_t n_outputs_prev = 0 ;
1750+ int64_t n_tokens_prev = 0 ;
17401751
17411752 do {
17421753 const auto & ubatch = mctx->get_ubatch ();
@@ -1882,16 +1893,21 @@ int llama_context::decode(const llama_batch & batch_inp) {
18821893
18831894 // extract pre-norm embeddings (hidden state before the final output norm)
18841895 // only meaningful in LLAMA_POOLING_TYPE_NONE (per-token); other pooling modes are ignored.
1885- if (embd_pre_norm.data && t_h_pre_norm && n_outputs > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
1886- ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend (sched.get (), t_h_pre_norm);
1887- GGML_ASSERT (backend_h != nullptr );
1896+ {
1897+ const bool masked = cparams.embeddings_pre_norm_masked ;
1898+ const int64_t n_rows = masked ? n_outputs : (int64_t ) ubatch.n_tokens ;
1899+ const int64_t offset = masked ? n_outputs_prev : n_tokens_prev;
1900+
1901+ if (embd_pre_norm.data && t_h_pre_norm && n_rows > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
1902+ ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend (sched.get (), t_h_pre_norm);
1903+ GGML_ASSERT (backend_h != nullptr );
18881904
1889- const uint32_t n_embd = hparams.n_embd ;
1890- float * embd_pre_norm_out = embd_pre_norm.data + n_outputs_prev *n_embd;
1905+ const uint32_t n_embd = hparams.n_embd ;
1906+ float * embd_pre_norm_out = embd_pre_norm.data + offset *n_embd;
18911907
1892- GGML_ASSERT ( n_outputs_prev + n_outputs <= n_outputs_all );
1893- GGML_ASSERT ((n_outputs_prev + n_outputs)*n_embd <= ( int64_t ) embd_pre_norm. size );
1894- ggml_backend_tensor_get_async (backend_h, t_h_pre_norm, embd_pre_norm_out, 0 , n_outputs*n_embd* sizeof ( float ));
1908+ GGML_ASSERT ((offset + n_rows)*n_embd <= ( int64_t ) embd_pre_norm. size );
1909+ ggml_backend_tensor_get_async (backend_h, t_h_pre_norm, embd_pre_norm_out, 0 , n_rows*n_embd* sizeof ( float ) );
1910+ }
18951911 }
18961912
18971913 // Copy backend sampling output if this ubatch produced any sampling tensors.
@@ -1908,6 +1924,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
19081924 }
19091925
19101926 n_outputs_prev += n_outputs;
1927+ n_tokens_prev += ubatch.n_tokens ;
19111928 } while (mctx->next ());
19121929
19131930 // set to total number of outputs in the batch, for use in llama_get_logits_ith
@@ -1999,6 +2016,12 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
19992016 embd.size = has_embd ? n_embd_out*n_outputs_max : 0 ;
20002017 embd_pre_norm.size = has_embd_pre_norm ? n_embd*n_outputs_max : 0 ;
20012018
2019+ if (has_embd_pre_norm && !cparams.embeddings_pre_norm_masked ) {
2020+ // unmasked: pre-norm row exists for every token in the ubatch, not just
2021+ // those flagged via batch.logits[i] -> size by token count instead.
2022+ embd_pre_norm.size = (size_t ) n_embd * n_batch;
2023+ }
2024+
20022025 // Allocate backend sampling output buffers if there are backend samplers configured.
20032026 const bool has_sampling = !sampling.samplers .empty ();
20042027 if (has_sampling) {
@@ -3547,8 +3570,8 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
35473570 return ctx->get_embeddings_seq (seq_id);
35483571}
35493572
// Patch hunk for the free-function wrapper llama_set_embeddings_pre_norm: it
// forwards its arguments to ctx->set_embeddings_pre_norm. The `+` version adds a
// `masked` parameter and passes it through unchanged.
//   ctx    - context whose setter is invoked; dereferenced without a null check.
//   value  - forwarded as the first setter argument.
//   masked - forwarded as the second setter argument (new in this patch).
// NOTE(review): the added parameter changes the function's signature, so callers
// of the two-argument form presumably need updating - confirm against the public
// header and any language bindings.
3550- void llama_set_embeddings_pre_norm (llama_context * ctx, bool value) {
3551- ctx->set_embeddings_pre_norm (value);
3573+ void llama_set_embeddings_pre_norm (llama_context * ctx, bool value, bool masked ) {
3574+ ctx->set_embeddings_pre_norm (value, masked );
35523575}
35533576
35543577float * llama_get_embeddings_pre_norm (llama_context * ctx) {
0 commit comments