@@ -182,6 +182,8 @@ llama_context::llama_context(
182182
183183 cparams.n_ubatch = std::min (cparams.n_batch , params.n_ubatch == 0 ? params.n_batch : params.n_ubatch );
184184
185+ cparams.n_outputs_max = params.n_outputs_max == 0 ? cparams.n_batch : params.n_outputs_max ;
186+
185187 cparams.op_offload = params.op_offload ;
186188 cparams.kv_unified = params.kv_unified ;
187189
@@ -531,7 +533,7 @@ void llama_context::sched_reserve() {
531533 // note: n_outputs must match n_tokens for embedding models with mean/rank pooling,
532534 // because build_pooling creates inp_mean with shape [n_tokens, n_seqs] and multiplies
533535 // it with t_embd which is reduced to [n_outputs, ...] via out_ids. if n_outputs != n_tokens,
534- // the ggml_mul_mat assertion fails. this matches the pp reservation below (line ~553).
536+ // the ggml_mul_mat assertion fails.
535537 const uint32_t n_tokens_ch = 16 *n_seqs;
536538 auto * gf = graph_reserve (n_tokens_ch, n_seqs, n_tokens_ch, mctx.get (), true );
537539 if (!gf) {
@@ -577,16 +579,18 @@ void llama_context::sched_reserve() {
577579 int n_splits_tg = -1 ;
578580 int n_nodes_tg = -1 ;
579581
582+ const uint32_t n_outputs_pp = std::min (n_tokens, cparams.n_outputs_max );
583+
580584 // reserve pp (prompt processing) graph first so that buffers are only allocated once
581585 {
582- auto * gf = graph_reserve (n_tokens, n_seqs, n_tokens , mctx.get (),
586+ auto * gf = graph_reserve (n_tokens, n_seqs, n_outputs_pp , mctx.get (),
583587 model.hparams .no_alloc , model.hparams .no_alloc ? backend_buf_exp_size.data () : nullptr );
584588 if (!gf) {
585589 if (cparams.pipeline_parallel ) {
586590 LLAMA_LOG_WARN (" %s: compute buffer allocation failed, retrying without pipeline parallelism\n " , __func__);
587591 cparams.pipeline_parallel = false ;
588592 sched.reset (ggml_backend_sched_new (backend_ptrs.data (), backend_buft.data (), backend_ptrs.size (), max_nodes, false , cparams.op_offload ));
589- gf = graph_reserve (n_tokens, n_seqs, n_tokens , mctx.get ());
593+ gf = graph_reserve (n_tokens, n_seqs, n_outputs_pp , mctx.get ());
590594 }
591595 if (!gf) {
592596 throw std::runtime_error (" failed to allocate compute pp buffers" );
@@ -614,7 +618,7 @@ void llama_context::sched_reserve() {
614618 //
615619 // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
616620 //
617- auto * gf = graph_reserve (n_tokens, n_seqs, n_tokens , mctx.get (), model.hparams .no_alloc );
621+ auto * gf = graph_reserve (n_tokens, n_seqs, n_outputs_pp , mctx.get (), model.hparams .no_alloc );
618622 if (!gf) {
619623 throw std::runtime_error (" failed to allocate compute pp buffers" );
620624 }
@@ -774,7 +778,9 @@ bool llama_context::memory_update(bool optimize) {
774778 const uint32_t n_seqs = cparams.n_seq_max ;
775779 const uint32_t n_tokens = std::min (cparams.n_ctx , cparams.n_ubatch );
776780
777- auto * gf = graph_reserve (n_tokens, n_seqs, n_tokens, mctx.get ());
781+ const uint32_t n_outputs_max = std::min (n_tokens, cparams.n_outputs_max );
782+
783+ auto * gf = graph_reserve (n_tokens, n_seqs, n_outputs_max, mctx.get ());
778784 if (!gf) {
779785 LLAMA_LOG_ERROR (" %s: failed to reserve graph after the memory update\n " , __func__);
780786 }
@@ -2140,6 +2146,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
21402146
21412147 this ->n_outputs = 0 ;
21422148
2149+ GGML_ASSERT (n_outputs_max <= cparams.n_outputs_max );
2150+
21432151 return n_outputs_max;
21442152}
21452153
@@ -2226,8 +2234,6 @@ ggml_cgraph * llama_context::graph_reserve(
22262234
22272235 if (n_tokens % n_seqs != 0 ) {
22282236 n_tokens = ((n_tokens + (n_seqs - 1 )) / n_seqs) * n_seqs; // round to next multiple of n_seqs
2229- n_outputs = std::max (n_outputs, n_tokens);
2230-
22312237 LLAMA_LOG_DEBUG (" %s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n " , __func__, n_tokens, n_seqs, n_outputs);
22322238 }
22332239
@@ -3337,6 +3343,7 @@ llama_context_params llama_context_default_params() {
33373343 /* .n_ubatch =*/ 512 ,
33383344 /* .n_seq_max =*/ 1 ,
33393345 /* .n_rs_seq =*/ 0 ,
3346+ /* .n_outputs_max =*/ 0 ,
33403347 /* .n_threads =*/ GGML_DEFAULT_N_THREADS , // TODO: better default
33413348 /* .n_threads_batch =*/ GGML_DEFAULT_N_THREADS ,
33423349 /* .ctx_type =*/ LLAMA_CONTEXT_TYPE_DEFAULT ,
0 commit comments