@@ -51,6 +51,8 @@ llama_context::llama_context(
5151 throw std::runtime_error (" n_seq_max must be <= " + std::to_string (LLAMA_MAX_SEQ));
5252 }
5353
54+ cparams.n_outputs_per_seq = params.n_outputs_per_seq ;
55+
5456 cparams.n_rs_seq = params.n_rs_seq ;
5557 if (cparams.n_rs_seq > 0 && !llm_arch_supports_rs_rollback (model.arch )) {
5658 LLAMA_LOG_DEBUG (" %s: n_rs_seq=%u requested but model arch does not support recurrent partial rollback; clamping to 0\n " ,
@@ -577,8 +579,7 @@ void llama_context::sched_reserve() {
577579 int n_splits_tg = -1 ;
578580 int n_nodes_tg = -1 ;
579581
580- const bool reserve_all_outputs = cparams.embeddings || cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
581- const uint32_t n_outputs_pp = reserve_all_outputs ? n_tokens : n_seqs;
582+ const uint32_t n_outputs_pp = graph_n_outputs_pp (n_tokens, n_seqs);
582583
583584 // reserve pp (prompt processing) graph first so that buffers are only allocated once
584585 {
@@ -777,8 +778,7 @@ bool llama_context::memory_update(bool optimize) {
777778 const uint32_t n_seqs = cparams.n_seq_max ;
778779 const uint32_t n_tokens = std::min (cparams.n_ctx , cparams.n_ubatch );
779780
780- const bool reserve_all_outputs = cparams.embeddings || cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
781- const uint32_t n_outputs_pp = reserve_all_outputs ? n_tokens : n_seqs;
781+ const uint32_t n_outputs_pp = graph_n_outputs_pp (n_tokens, n_seqs);
782782
783783 auto * gf = graph_reserve (n_tokens, n_seqs, n_outputs_pp, mctx.get ());
784784 if (!gf) {
@@ -2221,6 +2221,24 @@ uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const {
22212221 return res;
22222222}
22232223
2224+ uint32_t llama_context::graph_n_outputs_pp (uint32_t n_tokens, uint32_t n_seqs) const {
2225+ GGML_ASSERT (n_tokens >= 1 );
2226+ GGML_ASSERT (n_seqs >= 1 );
2227+
2228+ const bool reserve_all_outputs =
2229+ cparams.embeddings ||
2230+ cparams.pooling_type != LLAMA_POOLING_TYPE_NONE ||
2231+ cparams.n_outputs_per_seq == 0 ;
2232+
2233+ if (reserve_all_outputs) {
2234+ return n_tokens;
2235+ }
2236+
2237+ const uint64_t n_outputs = (uint64_t ) n_seqs * cparams.n_outputs_per_seq ;
2238+
2239+ return std::max<uint32_t >(1 , std::min<uint64_t >(n_tokens, n_outputs));
2240+ }
2241+
22242242llm_graph_result * llama_context::get_gf_res_reserve () const {
22252243 return static_cast <llm_graph_result *>(gf_res_reserve.get ());
22262244}
@@ -2230,9 +2248,13 @@ ggml_cgraph * llama_context::graph_reserve(
22302248 LLAMA_LOG_DEBUG (" %s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n " , __func__, n_tokens, n_seqs, n_outputs);
22312249 GGML_ASSERT (n_outputs >= 1 );
22322250
2251+ const bool reserve_all_outputs = n_outputs >= n_tokens;
2252+
22332253 if (n_tokens % n_seqs != 0 ) {
22342254 n_tokens = ((n_tokens + (n_seqs - 1 )) / n_seqs) * n_seqs; // round to next multiple of n_seqs
2235- n_outputs = std::max (n_outputs, n_tokens);
2255+ if (reserve_all_outputs) {
2256+ n_outputs = std::max (n_outputs, n_tokens);
2257+ }
22362258
22372259 LLAMA_LOG_DEBUG (" %s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n " , __func__, n_tokens, n_seqs, n_outputs);
22382260 }
@@ -3343,6 +3365,7 @@ llama_context_params llama_context_default_params() {
33433365 /* .n_ubatch =*/ 512 ,
33443366 /* .n_seq_max =*/ 1 ,
33453367 /* .n_rs_seq =*/ 0 ,
3368+ /* .n_outputs_per_seq =*/ 0 ,
33463369 /* .n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
33473370 /* .n_threads_batch =*/ GGML_DEFAULT_N_THREADS,
33483371 /* .ctx_type =*/ LLAMA_CONTEXT_TYPE_DEFAULT,
0 commit comments