Skip to content

Commit d936634

Browse files
committed
move n_outputs_max to server-context
1 parent 810aa71 commit d936634

6 files changed

Lines changed: 42 additions & 48 deletions

File tree

common/common.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1562,8 +1562,8 @@ struct llama_context_params common_context_params_to_llama(const common_params &
15621562

15631563
cparams.n_ctx = params.n_ctx;
15641564
cparams.n_seq_max = params.n_parallel;
1565-
cparams.n_rs_seq = params.speculative.need_n_rs_seq();
1566-
cparams.n_outputs_per_seq = std::max(params.n_outputs_per_seq, 0);
1565+
cparams.n_rs_seq = params.speculative.need_n_rs_seq();
1566+
cparams.n_outputs_max = std::max(params.n_outputs_max, 0);
15671567
cparams.n_batch = params.n_batch;
15681568
cparams.n_ubatch = params.n_ubatch;
15691569
cparams.n_threads = params.cpuparams.n_threads;

common/common.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -431,7 +431,7 @@ struct common_params {
431431
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
432432
int32_t n_parallel = 1; // number of parallel sequences to decode
433433
int32_t n_sequences = 1; // number of sequences to decode
434-
int32_t n_outputs_per_seq = 0; // max outputs per sequence in a ubatch (0 = no limit)
434+
int32_t n_outputs_max = 0; // max outputs in a ubatch (0 = n_batch)
435435
int32_t grp_attn_n = 1; // group-attention factor
436436
int32_t grp_attn_w = 512; // group-attention width
437437
int32_t n_print = -1; // print token count every n tokens (-1 = disabled)

include/llama.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -339,7 +339,7 @@ extern "C" {
339339
uint32_t n_ubatch; // physical maximum batch size
340340
uint32_t n_seq_max; // max number of sequences (i.e. distinct states for recurrent models)
341341
uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback (0 = no rollback) [EXPERIMENTAL]
342-
uint32_t n_outputs_per_seq; // max outputs per sequence in a ubatch (0 = no limit)
342+
uint32_t n_outputs_max; // max outputs in a ubatch (0 = n_batch)
343343
int32_t n_threads; // number of threads to use for generation
344344
int32_t n_threads_batch; // number of threads to use for batch processing
345345

src/llama-context.cpp

Lines changed: 7 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -30,24 +30,6 @@ static llm_graph_type ctx_type_to_graph_type(llama_context_type ctx_type) {
3030
throw std::runtime_error("Unsupported ctx type");
3131
}
3232

33-
static uint32_t graph_n_outputs_pp(const llama_cparams & cparams, uint32_t n_tokens, uint32_t n_seqs) {
34-
GGML_ASSERT(n_tokens >= 1);
35-
GGML_ASSERT(n_seqs >= 1);
36-
37-
const bool reserve_all_outputs =
38-
cparams.embeddings ||
39-
cparams.pooling_type != LLAMA_POOLING_TYPE_NONE ||
40-
cparams.n_outputs_per_seq == 0;
41-
42-
if (reserve_all_outputs) {
43-
return n_tokens;
44-
}
45-
46-
const uint64_t n_outputs = (uint64_t) n_seqs * cparams.n_outputs_per_seq;
47-
48-
return std::max<uint32_t>(1, std::min<uint64_t>(n_tokens, n_outputs));
49-
}
50-
5133
llama_context::llama_context(
5234
const llama_model & model,
5335
llama_context_params params) :
@@ -69,8 +51,6 @@ llama_context::llama_context(
6951
throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
7052
}
7153

72-
cparams.n_outputs_per_seq = params.n_outputs_per_seq;
73-
7454
cparams.n_rs_seq = params.n_rs_seq;
7555
if (cparams.n_rs_seq > 0 && !llm_arch_supports_rs_rollback(model.arch)) {
7656
LLAMA_LOG_DEBUG("%s: n_rs_seq=%u requested but model arch does not support recurrent partial rollback; clamping to 0\n",
@@ -202,6 +182,8 @@ llama_context::llama_context(
202182

203183
cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
204184

185+
cparams.n_outputs_max = params.n_outputs_max == 0 ? cparams.n_batch : params.n_outputs_max;
186+
205187
cparams.op_offload = params.op_offload;
206188
cparams.kv_unified = params.kv_unified;
207189

@@ -597,7 +579,7 @@ void llama_context::sched_reserve() {
597579
int n_splits_tg = -1;
598580
int n_nodes_tg = -1;
599581

600-
const uint32_t n_outputs_pp = graph_n_outputs_pp(cparams, n_tokens, n_seqs);
582+
const uint32_t n_outputs_pp = std::min(n_tokens, cparams.n_outputs_max);
601583

602584
// reserve pp (prompt processing) graph first so that buffers are only allocated once
603585
{
@@ -796,7 +778,7 @@ bool llama_context::memory_update(bool optimize) {
796778
const uint32_t n_seqs = cparams.n_seq_max;
797779
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
798780

799-
const uint32_t n_outputs_pp = graph_n_outputs_pp(cparams, n_tokens, n_seqs);
781+
const uint32_t n_outputs_pp = std::min(n_tokens, cparams.n_outputs_max);
800782

801783
auto * gf = graph_reserve(n_tokens, n_seqs, n_outputs_pp, mctx.get());
802784
if (!gf) {
@@ -1804,6 +1786,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
18041786

18051787
// needs to happen before the graph is built
18061788
n_outputs = n_outputs_new;
1789+
1790+
GGML_ASSERT(n_outputs <= cparams.n_outputs_max);
18071791
}
18081792

18091793
ggml_status status;
@@ -3365,7 +3349,7 @@ llama_context_params llama_context_default_params() {
33653349
/*.n_ubatch =*/ 512,
33663350
/*.n_seq_max =*/ 1,
33673351
/*.n_rs_seq =*/ 0,
3368-
/*.n_outputs_per_seq =*/ 0,
3352+
/*.n_outputs_max =*/ 0,
33693353
/*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
33703354
/*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS,
33713355
/*.ctx_type =*/ LLAMA_CONTEXT_TYPE_DEFAULT,

src/llama-cparams.h

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -12,10 +12,10 @@ struct llama_cparams {
1212
uint32_t n_batch;
1313
uint32_t n_ubatch;
1414
uint32_t n_seq_max;
15-
uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback
16-
uint32_t n_outputs_per_seq; // max outputs per sequence in a ubatch (0 = no limit)
17-
int32_t n_threads; // number of threads to use for generation
18-
int32_t n_threads_batch; // number of threads to use for batch processing
15+
uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback
16+
uint32_t n_outputs_max; // max outputs in a ubatch
17+
int32_t n_threads; // number of threads to use for generation
18+
int32_t n_threads_batch; // number of threads to use for batch processing
1919

2020
float rope_freq_base;
2121
float rope_freq_scale;

tools/server/server-context.cpp

Lines changed: 27 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -37,38 +37,48 @@ using json = nlohmann::ordered_json;
3737

3838
constexpr int HTTP_POLLING_SECONDS = 1;
3939

40-
static uint32_t server_n_outputs_per_seq(const common_params_speculative & speculative) {
41-
uint32_t n_outputs = 1;
40+
static uint32_t server_n_outputs_max(const common_params & params) {
41+
const uint32_t n_batch = params.n_batch;
42+
const uint32_t n_ubatch = std::min(n_batch, params.n_ubatch == 0 ? n_batch : params.n_ubatch);
4243

43-
for (const auto type : speculative.types) {
44+
if (params.embedding ||
45+
(params.pooling_type != LLAMA_POOLING_TYPE_UNSPECIFIED && params.pooling_type != LLAMA_POOLING_TYPE_NONE)) {
46+
return n_ubatch;
47+
}
48+
49+
uint32_t n_outputs_per_seq = 1;
50+
51+
for (const auto type : params.speculative.types) {
4452
switch (type) {
4553
case COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE:
4654
case COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3:
4755
case COMMON_SPECULATIVE_TYPE_DRAFT_MTP:
48-
n_outputs = std::max<uint32_t>(n_outputs, 1 + std::max(0, speculative.draft.n_max));
56+
n_outputs_per_seq = std::max<uint32_t>(n_outputs_per_seq, 1 + std::max(0, params.speculative.draft.n_max));
4957
break;
5058
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE:
51-
n_outputs = std::max<uint32_t>(n_outputs, 1 + speculative.ngram_simple.size_m);
59+
n_outputs_per_seq = std::max<uint32_t>(n_outputs_per_seq, 1 + params.speculative.ngram_simple.size_m);
5260
break;
5361
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K:
54-
n_outputs = std::max<uint32_t>(n_outputs, 1 + speculative.ngram_map_k.size_m);
62+
n_outputs_per_seq = std::max<uint32_t>(n_outputs_per_seq, 1 + params.speculative.ngram_map_k.size_m);
5563
break;
5664
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V:
57-
n_outputs = std::max<uint32_t>(n_outputs, 1 + speculative.ngram_map_k4v.size_m);
65+
n_outputs_per_seq = std::max<uint32_t>(n_outputs_per_seq, 1 + params.speculative.ngram_map_k4v.size_m);
5866
break;
5967
case COMMON_SPECULATIVE_TYPE_NGRAM_MOD:
60-
n_outputs = std::max<uint32_t>(n_outputs, 1 + std::max(0, speculative.ngram_mod.n_max));
68+
n_outputs_per_seq = std::max<uint32_t>(n_outputs_per_seq, 1 + std::max(0, params.speculative.ngram_mod.n_max));
6169
break;
6270
case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE:
63-
n_outputs = std::max<uint32_t>(n_outputs, 1 + 8);
71+
n_outputs_per_seq = std::max<uint32_t>(n_outputs_per_seq, 1 + 8);
6472
break;
6573
case COMMON_SPECULATIVE_TYPE_NONE:
6674
case COMMON_SPECULATIVE_TYPE_COUNT:
6775
break;
6876
}
6977
}
7078

71-
return n_outputs;
79+
const uint64_t n_outputs = (uint64_t) params.n_parallel * n_outputs_per_seq;
80+
81+
return std::max<uint32_t>(1, std::min<uint64_t>(n_ubatch, n_outputs));
7282
}
7383

7484
// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283
@@ -787,7 +797,7 @@ struct server_context_impl {
787797
SRV_INF("loading model '%s'\n", params.model.path.c_str());
788798

789799
params_base = params;
790-
params_base.n_outputs_per_seq = server_n_outputs_per_seq(params_base.speculative);
800+
params_base.n_outputs_max = server_n_outputs_max(params_base);
791801

792802
std::string & mmproj_path = params_base.mmproj.path;
793803
bool has_mmproj = !mmproj_path.empty();
@@ -854,7 +864,7 @@ struct server_context_impl {
854864
}
855865

856866
if (!has_draft) {
857-
params_dft.n_outputs_per_seq = 1;
867+
params_dft.n_outputs_max = params_base.n_parallel;
858868
}
859869

860870
auto mparams_dft = common_model_params_to_llama(params_dft);
@@ -980,11 +990,11 @@ struct server_context_impl {
980990
params_base.model.path.c_str());
981991

982992
auto cparams_mtp = common_context_params_to_llama(params_base);
983-
cparams_mtp.ctx_type = LLAMA_CONTEXT_TYPE_MTP;
984-
cparams_mtp.type_k = params_base.speculative.draft.cache_type_k;
985-
cparams_mtp.type_v = params_base.speculative.draft.cache_type_v;
986-
cparams_mtp.n_rs_seq = 0;
987-
cparams_mtp.n_outputs_per_seq = 1;
993+
cparams_mtp.ctx_type = LLAMA_CONTEXT_TYPE_MTP;
994+
cparams_mtp.type_k = params_base.speculative.draft.cache_type_k;
995+
cparams_mtp.type_v = params_base.speculative.draft.cache_type_v;
996+
cparams_mtp.n_rs_seq = 0;
997+
cparams_mtp.n_outputs_max = params_base.n_parallel;
988998

989999
ctx_dft.reset(llama_init_from_model(model_tgt, cparams_mtp));
9901000
if (ctx_dft == nullptr) {

0 commit comments

Comments
 (0)