@@ -30,6 +30,24 @@ static llm_graph_type ctx_type_to_graph_type(llama_context_type ctx_type) {
3030 throw std::runtime_error (" Unsupported ctx type" );
3131}
3232
33+ static uint32_t graph_n_outputs_pp (const llama_cparams & cparams, uint32_t n_tokens, uint32_t n_seqs) {
34+ GGML_ASSERT (n_tokens >= 1 );
35+ GGML_ASSERT (n_seqs >= 1 );
36+
37+ const bool reserve_all_outputs =
38+ cparams.embeddings ||
39+ cparams.pooling_type != LLAMA_POOLING_TYPE_NONE ||
40+ cparams.n_outputs_per_seq == 0 ;
41+
42+ if (reserve_all_outputs) {
43+ return n_tokens;
44+ }
45+
46+ const uint64_t n_outputs = (uint64_t ) n_seqs * cparams.n_outputs_per_seq ;
47+
48+ return std::max<uint32_t >(1 , std::min<uint64_t >(n_tokens, n_outputs));
49+ }
50+
3351llama_context::llama_context (
3452 const llama_model & model,
3553 llama_context_params params) :
@@ -51,6 +69,8 @@ llama_context::llama_context(
5169 throw std::runtime_error (" n_seq_max must be <= " + std::to_string (LLAMA_MAX_SEQ ));
5270 }
5371
72+ cparams.n_outputs_per_seq = params.n_outputs_per_seq ;
73+
5474 cparams.n_rs_seq = params.n_rs_seq ;
5575 if (cparams.n_rs_seq > 0 && !llm_arch_supports_rs_rollback (model.arch )) {
5676 LLAMA_LOG_DEBUG (" %s: n_rs_seq=%u requested but model arch does not support recurrent partial rollback; clamping to 0\n " ,
@@ -577,8 +597,7 @@ void llama_context::sched_reserve() {
577597 int n_splits_tg = -1 ;
578598 int n_nodes_tg = -1 ;
579599
580- const bool reserve_all_outputs = cparams.embeddings || cparams.pooling_type != LLAMA_POOLING_TYPE_NONE ;
581- const uint32_t n_outputs_pp = reserve_all_outputs ? n_tokens : n_seqs;
600+ const uint32_t n_outputs_pp = graph_n_outputs_pp (cparams, n_tokens, n_seqs);
582601
583602 // reserve pp (prompt processing) graph first so that buffers are only allocated once
584603 {
@@ -777,8 +796,7 @@ bool llama_context::memory_update(bool optimize) {
777796 const uint32_t n_seqs = cparams.n_seq_max ;
778797 const uint32_t n_tokens = std::min (cparams.n_ctx , cparams.n_ubatch );
779798
780- const bool reserve_all_outputs = cparams.embeddings || cparams.pooling_type != LLAMA_POOLING_TYPE_NONE ;
781- const uint32_t n_outputs_pp = reserve_all_outputs ? n_tokens : n_seqs;
799+ const uint32_t n_outputs_pp = graph_n_outputs_pp (cparams, n_tokens, n_seqs);
782800
783801 auto * gf = graph_reserve (n_tokens, n_seqs, n_outputs_pp, mctx.get ());
784802 if (!gf) {
@@ -2230,9 +2248,13 @@ ggml_cgraph * llama_context::graph_reserve(
22302248 LLAMA_LOG_DEBUG (" %s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n " , __func__, n_tokens, n_seqs, n_outputs);
22312249 GGML_ASSERT (n_outputs >= 1 );
22322250
2251+ const bool reserve_all_outputs = n_outputs >= n_tokens;
2252+
22332253 if (n_tokens % n_seqs != 0 ) {
22342254 n_tokens = ((n_tokens + (n_seqs - 1 )) / n_seqs) * n_seqs; // round to next multiple of n_seqs
2235- n_outputs = std::max (n_outputs, n_tokens);
2255+ if (reserve_all_outputs) {
2256+ n_outputs = std::max (n_outputs, n_tokens);
2257+ }
22362258
22372259 LLAMA_LOG_DEBUG (" %s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n " , __func__, n_tokens, n_seqs, n_outputs);
22382260 }
@@ -3343,6 +3365,7 @@ llama_context_params llama_context_default_params() {
33433365 /* .n_ubatch =*/ 512 ,
33443366 /* .n_seq_max =*/ 1 ,
33453367 /* .n_rs_seq =*/ 0 ,
3368+ /* .n_outputs_per_seq =*/ 0 ,
33463369 /* .n_threads =*/ GGML_DEFAULT_N_THREADS , // TODO: better default
33473370 /* .n_threads_batch =*/ GGML_DEFAULT_N_THREADS ,
33483371 /* .ctx_type =*/ LLAMA_CONTEXT_TYPE_DEFAULT ,
0 commit comments