Skip to content

Commit de6f727

Browse files
authored
llama: limit max outputs of llama_context (ggml-org#23861)
* llama: save more VRAM by reserving n_outputs == n_seqs when possible * add n_outputs_per_seq * move n_outputs_max to server-context * change ubatch to batch everywhere
1 parent 95b8b8e commit de6f727

6 files changed

Lines changed: 71 additions & 11 deletions

File tree

common/common.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1563,6 +1563,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
15631563
cparams.n_ctx = params.n_ctx;
15641564
cparams.n_seq_max = params.n_parallel;
15651565
cparams.n_rs_seq = params.speculative.need_n_rs_seq();
1566+
cparams.n_outputs_max = std::max(params.n_outputs_max, 0);
15661567
cparams.n_batch = params.n_batch;
15671568
cparams.n_ubatch = params.n_ubatch;
15681569
cparams.n_threads = params.cpuparams.n_threads;

common/common.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,7 @@ struct common_params {
431431
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
432432
int32_t n_parallel = 1; // number of parallel sequences to decode
433433
int32_t n_sequences = 1; // number of sequences to decode
434+
int32_t n_outputs_max = 0; // max outputs in a batch (0 = n_batch)
434435
int32_t grp_attn_n = 1; // group-attention factor
435436
int32_t grp_attn_w = 512; // group-attention width
436437
int32_t n_print = -1; // print token count every n tokens (-1 = disabled)

include/llama.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,7 @@ extern "C" {
339339
uint32_t n_ubatch; // physical maximum batch size
340340
uint32_t n_seq_max; // max number of sequences (i.e. distinct states for recurrent models)
341341
uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback (0 = no rollback) [EXPERIMENTAL]
342+
uint32_t n_outputs_max; // max outputs in a batch (0 = n_batch)
342343
int32_t n_threads; // number of threads to use for generation
343344
int32_t n_threads_batch; // number of threads to use for batch processing
344345

src/llama-context.cpp

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,8 @@ llama_context::llama_context(
182182

183183
cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
184184

185+
cparams.n_outputs_max = params.n_outputs_max == 0 ? cparams.n_batch : params.n_outputs_max;
186+
185187
cparams.op_offload = params.op_offload;
186188
cparams.kv_unified = params.kv_unified;
187189

@@ -531,7 +533,7 @@ void llama_context::sched_reserve() {
531533
// note: n_outputs must match n_tokens for embedding models with mean/rank pooling,
532534
// because build_pooling creates inp_mean with shape [n_tokens, n_seqs] and multiplies
533535
// it with t_embd which is reduced to [n_outputs, ...] via out_ids. if n_outputs != n_tokens,
534-
// the ggml_mul_mat assertion fails. this matches the pp reservation below (line ~553).
536+
// the ggml_mul_mat assertion fails.
535537
const uint32_t n_tokens_ch = 16*n_seqs;
536538
auto * gf = graph_reserve(n_tokens_ch, n_seqs, n_tokens_ch, mctx.get(), true);
537539
if (!gf) {
@@ -577,16 +579,18 @@ void llama_context::sched_reserve() {
577579
int n_splits_tg = -1;
578580
int n_nodes_tg = -1;
579581

582+
const uint32_t n_outputs_pp = std::min(n_tokens, cparams.n_outputs_max);
583+
580584
// reserve pp (prompt processing) graph first so that buffers are only allocated once
581585
{
582-
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(),
586+
auto * gf = graph_reserve(n_tokens, n_seqs, n_outputs_pp, mctx.get(),
583587
model.hparams.no_alloc, model.hparams.no_alloc ? backend_buf_exp_size.data() : nullptr);
584588
if (!gf) {
585589
if (cparams.pipeline_parallel) {
586590
LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__);
587591
cparams.pipeline_parallel = false;
588592
sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload));
589-
gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
593+
gf = graph_reserve(n_tokens, n_seqs, n_outputs_pp, mctx.get());
590594
}
591595
if (!gf) {
592596
throw std::runtime_error("failed to allocate compute pp buffers");
@@ -614,7 +618,7 @@ void llama_context::sched_reserve() {
614618
//
615619
// auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
616620
//
617-
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), model.hparams.no_alloc);
621+
auto * gf = graph_reserve(n_tokens, n_seqs, n_outputs_pp, mctx.get(), model.hparams.no_alloc);
618622
if (!gf) {
619623
throw std::runtime_error("failed to allocate compute pp buffers");
620624
}
@@ -774,7 +778,9 @@ bool llama_context::memory_update(bool optimize) {
774778
const uint32_t n_seqs = cparams.n_seq_max;
775779
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
776780

777-
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
781+
const uint32_t n_outputs_max = std::min(n_tokens, cparams.n_outputs_max);
782+
783+
auto * gf = graph_reserve(n_tokens, n_seqs, n_outputs_max, mctx.get());
778784
if (!gf) {
779785
LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
780786
}
@@ -2140,6 +2146,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
21402146

21412147
this->n_outputs = 0;
21422148

2149+
GGML_ASSERT(n_outputs_max <= cparams.n_outputs_max);
2150+
21432151
return n_outputs_max;
21442152
}
21452153

@@ -2226,8 +2234,6 @@ ggml_cgraph * llama_context::graph_reserve(
22262234

22272235
if (n_tokens % n_seqs != 0) {
22282236
n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs
2229-
n_outputs = std::max(n_outputs, n_tokens);
2230-
22312237
LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
22322238
}
22332239

@@ -3337,6 +3343,7 @@ llama_context_params llama_context_default_params() {
33373343
/*.n_ubatch =*/ 512,
33383344
/*.n_seq_max =*/ 1,
33393345
/*.n_rs_seq =*/ 0,
3346+
/*.n_outputs_max =*/ 0,
33403347
/*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
33413348
/*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS,
33423349
/*.ctx_type =*/ LLAMA_CONTEXT_TYPE_DEFAULT,

src/llama-cparams.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ struct llama_cparams {
1313
uint32_t n_ubatch;
1414
uint32_t n_seq_max;
1515
uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback
16+
uint32_t n_outputs_max; // max outputs supported by the context
1617
int32_t n_threads; // number of threads to use for generation
1718
int32_t n_threads_batch; // number of threads to use for batch processing
1819

tools/server/server-context.cpp

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,49 @@ using json = nlohmann::ordered_json;
3737

3838
constexpr int HTTP_POLLING_SECONDS = 1;
3939

40+
static uint32_t server_n_outputs_max(const common_params & params) {
41+
const uint32_t n_batch = params.n_batch;
42+
43+
if (params.embedding ||
44+
(params.pooling_type != LLAMA_POOLING_TYPE_UNSPECIFIED && params.pooling_type != LLAMA_POOLING_TYPE_NONE)) {
45+
return n_batch;
46+
}
47+
48+
uint32_t n_outputs_per_seq = 1;
49+
50+
for (const auto type : params.speculative.types) {
51+
switch (type) {
52+
case COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE:
53+
case COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3:
54+
case COMMON_SPECULATIVE_TYPE_DRAFT_MTP:
55+
n_outputs_per_seq = std::max<uint32_t>(n_outputs_per_seq, 1 + std::max(0, params.speculative.draft.n_max));
56+
break;
57+
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE:
58+
n_outputs_per_seq = std::max<uint32_t>(n_outputs_per_seq, 1 + params.speculative.ngram_simple.size_m);
59+
break;
60+
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K:
61+
n_outputs_per_seq = std::max<uint32_t>(n_outputs_per_seq, 1 + params.speculative.ngram_map_k.size_m);
62+
break;
63+
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V:
64+
n_outputs_per_seq = std::max<uint32_t>(n_outputs_per_seq, 1 + params.speculative.ngram_map_k4v.size_m);
65+
break;
66+
case COMMON_SPECULATIVE_TYPE_NGRAM_MOD:
67+
n_outputs_per_seq = std::max<uint32_t>(n_outputs_per_seq, 1 + std::max(0, params.speculative.ngram_mod.n_max));
68+
break;
69+
case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE:
70+
n_outputs_per_seq = std::max<uint32_t>(n_outputs_per_seq, 1 + 8);
71+
break;
72+
case COMMON_SPECULATIVE_TYPE_NONE:
73+
case COMMON_SPECULATIVE_TYPE_COUNT:
74+
break;
75+
}
76+
}
77+
78+
const uint64_t n_outputs = (uint64_t) params.n_parallel * n_outputs_per_seq;
79+
80+
return std::max<uint32_t>(1, std::min<uint64_t>(n_batch, n_outputs));
81+
}
82+
4083
// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283
4184
enum slot_state {
4285
SLOT_STATE_IDLE,
@@ -753,6 +796,7 @@ struct server_context_impl {
753796
SRV_INF("loading model '%s'\n", params.model.path.c_str());
754797

755798
params_base = params;
799+
params_base.n_outputs_max = server_n_outputs_max(params_base);
756800

757801
std::string & mmproj_path = params_base.mmproj.path;
758802
bool has_mmproj = !mmproj_path.empty();
@@ -818,6 +862,10 @@ struct server_context_impl {
818862
measure_model_bytes = false;
819863
}
820864

865+
if (!has_draft) {
866+
params_dft.n_outputs_max = params_base.n_parallel;
867+
}
868+
821869
auto mparams_dft = common_model_params_to_llama(params_dft);
822870
auto cparams_dft = common_context_params_to_llama(params_dft);
823871
if (spec_mtp) {
@@ -941,10 +989,11 @@ struct server_context_impl {
941989
params_base.model.path.c_str());
942990

943991
auto cparams_mtp = common_context_params_to_llama(params_base);
944-
cparams_mtp.ctx_type = LLAMA_CONTEXT_TYPE_MTP;
945-
cparams_mtp.type_k = params_base.speculative.draft.cache_type_k;
946-
cparams_mtp.type_v = params_base.speculative.draft.cache_type_v;
947-
cparams_mtp.n_rs_seq = 0;
992+
cparams_mtp.ctx_type = LLAMA_CONTEXT_TYPE_MTP;
993+
cparams_mtp.type_k = params_base.speculative.draft.cache_type_k;
994+
cparams_mtp.type_v = params_base.speculative.draft.cache_type_v;
995+
cparams_mtp.n_rs_seq = 0;
996+
cparams_mtp.n_outputs_max = params_base.n_parallel;
948997

949998
ctx_dft.reset(llama_init_from_model(model_tgt, cparams_mtp));
950999
if (ctx_dft == nullptr) {

0 commit comments

Comments
 (0)