Skip to content

Commit abbb1e4

Browse files
committed
add n_outputs_per_seq
1 parent af9f4af commit abbb1e4

7 files changed

Lines changed: 82 additions & 12 deletions

File tree

common/common.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1563,6 +1563,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
15631563
cparams.n_ctx = params.n_ctx;
15641564
cparams.n_seq_max = params.n_parallel;
15651565
cparams.n_rs_seq = params.speculative.need_n_rs_seq();
1566+
cparams.n_outputs_per_seq = std::max(params.n_outputs_per_seq, 0);
15661567
cparams.n_batch = params.n_batch;
15671568
cparams.n_ubatch = params.n_ubatch;
15681569
cparams.n_threads = params.cpuparams.n_threads;

common/common.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,7 @@ struct common_params {
431431
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
432432
int32_t n_parallel = 1; // number of parallel sequences to decode
433433
int32_t n_sequences = 1; // number of sequences to decode
434+
int32_t n_outputs_per_seq = 0; // max outputs per sequence in a ubatch (0 = no limit)
434435
int32_t grp_attn_n = 1; // group-attention factor
435436
int32_t grp_attn_w = 512; // group-attention width
436437
int32_t n_print = -1; // print token count every n tokens (-1 = disabled)

include/llama.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,7 @@ extern "C" {
339339
uint32_t n_ubatch; // physical maximum batch size
340340
uint32_t n_seq_max; // max number of sequences (i.e. distinct states for recurrent models)
341341
uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback (0 = no rollback) [EXPERIMENTAL]
342+
uint32_t n_outputs_per_seq; // max outputs per sequence in a ubatch (0 = no limit)
342343
int32_t n_threads; // number of threads to use for generation
343344
int32_t n_threads_batch; // number of threads to use for batch processing
344345

src/llama-context.cpp

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ llama_context::llama_context(
5151
throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
5252
}
5353

54+
cparams.n_outputs_per_seq = params.n_outputs_per_seq;
55+
5456
cparams.n_rs_seq = params.n_rs_seq;
5557
if (cparams.n_rs_seq > 0 && !llm_arch_supports_rs_rollback(model.arch)) {
5658
LLAMA_LOG_DEBUG("%s: n_rs_seq=%u requested but model arch does not support recurrent partial rollback; clamping to 0\n",
@@ -577,8 +579,7 @@ void llama_context::sched_reserve() {
577579
int n_splits_tg = -1;
578580
int n_nodes_tg = -1;
579581

580-
const bool reserve_all_outputs = cparams.embeddings || cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
581-
const uint32_t n_outputs_pp = reserve_all_outputs ? n_tokens : n_seqs;
582+
const uint32_t n_outputs_pp = graph_n_outputs_pp(n_tokens, n_seqs);
582583

583584
// reserve pp (prompt processing) graph first so that buffers are only allocated once
584585
{
@@ -777,8 +778,7 @@ bool llama_context::memory_update(bool optimize) {
777778
const uint32_t n_seqs = cparams.n_seq_max;
778779
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
779780

780-
const bool reserve_all_outputs = cparams.embeddings || cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
781-
const uint32_t n_outputs_pp = reserve_all_outputs ? n_tokens : n_seqs;
781+
const uint32_t n_outputs_pp = graph_n_outputs_pp(n_tokens, n_seqs);
782782

783783
auto * gf = graph_reserve(n_tokens, n_seqs, n_outputs_pp, mctx.get());
784784
if (!gf) {
@@ -2221,6 +2221,24 @@ uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const {
22212221
return res;
22222222
}
22232223

2224+
// Computes the number of outputs to use when reserving a prompt-processing graph
// for a ubatch of n_tokens tokens spread over n_seqs sequences.
//
// - returns n_tokens when embeddings are enabled, when the pooling type is not
//   LLAMA_POOLING_TYPE_NONE, or when cparams.n_outputs_per_seq == 0
// - otherwise returns n_seqs * cparams.n_outputs_per_seq, capped at n_tokens and
//   never less than 1
//
// Both n_tokens and n_seqs must be >= 1 (asserted below).
uint32_t llama_context::graph_n_outputs_pp(uint32_t n_tokens, uint32_t n_seqs) const {
2225+
GGML_ASSERT(n_tokens >= 1);
2226+
GGML_ASSERT(n_seqs >= 1);
2227+
2228+
// any of these conditions means one output is reserved per token
const bool reserve_all_outputs =
2229+
cparams.embeddings ||
2230+
cparams.pooling_type != LLAMA_POOLING_TYPE_NONE ||
2231+
cparams.n_outputs_per_seq == 0;
2232+
2233+
if (reserve_all_outputs) {
2234+
return n_tokens;
2235+
}
2236+
2237+
// widen to 64 bits so the product of two 32-bit values cannot wrap before it is capped
const uint64_t n_outputs = (uint64_t) n_seqs * cparams.n_outputs_per_seq;
2238+
2239+
// the min is bounded by n_tokens, so narrowing the result back to uint32_t loses nothing
return std::max<uint32_t>(1, std::min<uint64_t>(n_tokens, n_outputs));
2240+
}
2241+
22242242
// Returns the pointer obtained from gf_res_reserve.get(), converted with static_cast
// to llm_graph_result *.
// NOTE(review): presumably gf_res_reserve is an owning smart pointer and the caller
// only borrows the result - confirm against the member declaration.
llm_graph_result * llama_context::get_gf_res_reserve() const {
22252243
return static_cast<llm_graph_result *>(gf_res_reserve.get());
22262244
}
@@ -2230,9 +2248,13 @@ ggml_cgraph * llama_context::graph_reserve(
22302248
LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
22312249
GGML_ASSERT(n_outputs >= 1);
22322250

2251+
const bool reserve_all_outputs = n_outputs >= n_tokens;
2252+
22332253
if (n_tokens % n_seqs != 0) {
22342254
n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs
2235-
n_outputs = std::max(n_outputs, n_tokens);
2255+
if (reserve_all_outputs) {
2256+
n_outputs = std::max(n_outputs, n_tokens);
2257+
}
22362258

22372259
LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
22382260
}
@@ -3343,6 +3365,7 @@ llama_context_params llama_context_default_params() {
33433365
/*.n_ubatch =*/ 512,
33443366
/*.n_seq_max =*/ 1,
33453367
/*.n_rs_seq =*/ 0,
3368+
/*.n_outputs_per_seq =*/ 0,
33463369
/*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
33473370
/*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS,
33483371
/*.ctx_type =*/ LLAMA_CONTEXT_TYPE_DEFAULT,

src/llama-context.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,9 @@ struct llama_context {
238238
// returns the result of ggml_backend_sched_graph_compute_async execution
239239
ggml_status graph_compute(ggml_cgraph * gf, bool batched);
240240

241+
// max outputs to reserve for prompt-processing graphs
242+
uint32_t graph_n_outputs_pp(uint32_t n_tokens, uint32_t n_seqs) const;
243+
241244
// reserve a graph with a dummy ubatch of the specified size
242245
ggml_cgraph * graph_reserve(
243246
uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false, size_t * sizes = nullptr);

src/llama-cparams.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@ struct llama_cparams {
1212
uint32_t n_batch;
1313
uint32_t n_ubatch;
1414
uint32_t n_seq_max;
15-
uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback
16-
int32_t n_threads; // number of threads to use for generation
17-
int32_t n_threads_batch; // number of threads to use for batch processing
15+
uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback
16+
uint32_t n_outputs_per_seq; // max outputs per sequence in a ubatch (0 = no limit)
17+
int32_t n_threads; // number of threads to use for generation
18+
int32_t n_threads_batch; // number of threads to use for batch processing
1819

1920
float rope_freq_base;
2021
float rope_freq_scale;

tools/server/server-context.cpp

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,40 @@ using json = nlohmann::ordered_json;
3737

3838
constexpr int HTTP_POLLING_SECONDS = 1;
3939

40+
// Derives an outputs-per-sequence count from the speculative decoding configuration.
//
// Starts at 1 and, for every entry in speculative.types, raises the result to
// 1 + the size associated with that type:
// - DRAFT_SIMPLE / DRAFT_EAGLE3 / DRAFT_MTP : speculative.draft.n_max (negative values count as 0)
// - NGRAM_SIMPLE / NGRAM_MAP_K / NGRAM_MAP_K4V : the matching size_m field
// - NGRAM_MOD : speculative.ngram_mod.n_max (negative values count as 0)
// - NGRAM_CACHE : the constant 8
// - NONE / COUNT : no contribution
//
// The returned value is therefore always >= 1.
// NOTE(review): presumably the extra 1 accounts for the regularly sampled token on top
// of the drafted tokens - confirm with the speculative decoding code.
static uint32_t server_n_outputs_per_seq(const common_params_speculative & speculative) {
41+
uint32_t n_outputs = 1;
42+
43+
for (const auto type : speculative.types) {
44+
// every enumerator is listed explicitly; there is no default label
switch (type) {
45+
case COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE:
46+
case COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3:
47+
case COMMON_SPECULATIVE_TYPE_DRAFT_MTP:
48+
// clamp n_max at 0 before adding so a negative setting cannot lower the count
n_outputs = std::max<uint32_t>(n_outputs, 1 + std::max(0, speculative.draft.n_max));
49+
break;
50+
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE:
51+
n_outputs = std::max<uint32_t>(n_outputs, 1 + speculative.ngram_simple.size_m);
52+
break;
53+
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K:
54+
n_outputs = std::max<uint32_t>(n_outputs, 1 + speculative.ngram_map_k.size_m);
55+
break;
56+
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V:
57+
n_outputs = std::max<uint32_t>(n_outputs, 1 + speculative.ngram_map_k4v.size_m);
58+
break;
59+
case COMMON_SPECULATIVE_TYPE_NGRAM_MOD:
60+
n_outputs = std::max<uint32_t>(n_outputs, 1 + std::max(0, speculative.ngram_mod.n_max));
61+
break;
62+
case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE:
63+
// NOTE(review): the literal 8 is not explained here - presumably the maximum draft
// length produced by the ngram cache; confirm and consider naming the constant
n_outputs = std::max<uint32_t>(n_outputs, 1 + 8);
64+
break;
65+
case COMMON_SPECULATIVE_TYPE_NONE:
66+
case COMMON_SPECULATIVE_TYPE_COUNT:
67+
break;
68+
}
69+
}
70+
71+
return n_outputs;
72+
}
73+
4074
// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283
4175
enum slot_state {
4276
SLOT_STATE_IDLE,
@@ -753,6 +787,7 @@ struct server_context_impl {
753787
SRV_INF("loading model '%s'\n", params.model.path.c_str());
754788

755789
params_base = params;
790+
params_base.n_outputs_per_seq = server_n_outputs_per_seq(params_base.speculative);
756791

757792
std::string & mmproj_path = params_base.mmproj.path;
758793
bool has_mmproj = !mmproj_path.empty();
@@ -818,6 +853,10 @@ struct server_context_impl {
818853
measure_model_bytes = false;
819854
}
820855

856+
if (!has_draft) {
857+
params_dft.n_outputs_per_seq = 1;
858+
}
859+
821860
auto mparams_dft = common_model_params_to_llama(params_dft);
822861
auto cparams_dft = common_context_params_to_llama(params_dft);
823862
if (spec_mtp) {
@@ -941,10 +980,11 @@ struct server_context_impl {
941980
params_base.model.path.c_str());
942981

943982
auto cparams_mtp = common_context_params_to_llama(params_base);
944-
cparams_mtp.ctx_type = LLAMA_CONTEXT_TYPE_MTP;
945-
cparams_mtp.type_k = params_base.speculative.draft.cache_type_k;
946-
cparams_mtp.type_v = params_base.speculative.draft.cache_type_v;
947-
cparams_mtp.n_rs_seq = 0;
983+
cparams_mtp.ctx_type = LLAMA_CONTEXT_TYPE_MTP;
984+
cparams_mtp.type_k = params_base.speculative.draft.cache_type_k;
985+
cparams_mtp.type_v = params_base.speculative.draft.cache_type_v;
986+
cparams_mtp.n_rs_seq = 0;
987+
cparams_mtp.n_outputs_per_seq = 1;
948988

949989
ctx_dft.reset(llama_init_from_model(model_tgt, cparams_mtp));
950990
if (ctx_dft == nullptr) {

0 commit comments

Comments
 (0)