Skip to content

Commit 5dcb711

Browse files
authored
speculative : fix n_outputs_max and remove draft-simple auto-enable (ggml-org#23988)
* speculative : add common_speculative_n_max helper function

  Extract the speculative max-draft-size logic from server_n_outputs_max into a reusable common_speculative_n_max() function in common/speculative.

  Assisted-by: llama.cpp:local pi

* cont : draft context always has n_parallel outputs
* llama : log n_outputs_max
* speculative : remove draft-simple auto-enable
* ci : enable server tests on PRs
1 parent 5aa3a64 commit 5dcb711

6 files changed

Lines changed: 40 additions & 51 deletions

File tree

.github/workflows/server.yml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,6 @@ jobs:
102102

103103
- name: Tests
104104
id: server_integration_tests
105-
if: ${{ !github.event.pull_request }}
106105
run: |
107106
cd tools/server/tests
108107
pytest -v -x -m "not slow"
@@ -116,7 +115,6 @@ jobs:
116115
117116
- name: Tests (Backend sampling)
118117
id: server_integration_tests_backend_sampling
119-
if: ${{ !github.event.pull_request }}
120118
run: |
121119
cd tools/server/tests
122120
export LLAMA_ARG_BACKEND_SAMPLING=1
@@ -169,7 +167,6 @@ jobs:
169167

170168
- name: Tests
171169
id: server_integration_tests
172-
if: ${{ !github.event.pull_request }}
173170
run: |
174171
cd tools/server/tests
175172
$env:PYTHONIOENCODING = ":replace"

common/arg.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1041,11 +1041,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
10411041
// we define here to make sure it's included in llama-gen-docs
10421042
if (ex == LLAMA_EXAMPLE_COMPLETION) {
10431043
params.use_jinja = false; // disable jinja by default
1044-
10451044
} else if (ex == LLAMA_EXAMPLE_MTMD) {
10461045
params.use_jinja = false; // disable jinja by default
10471046
params.sampling.temp = 0.2; // lower temp by default for better quality
1048-
10491047
} else if (ex == LLAMA_EXAMPLE_SERVER) {
10501048
params.n_parallel = -1; // auto by default
10511049
}
@@ -1066,7 +1064,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
10661064
sampler_type_names.pop_back(); // remove last semicolon
10671065
}
10681066

1069-
10701067
/**
10711068
* filter options by example
10721069
* rules:
@@ -1080,7 +1077,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
10801077
}
10811078
};
10821079

1083-
10841080
add_opt(common_arg(
10851081
{"-h", "--help", "--usage"},
10861082
"print usage and exit",

common/speculative.cpp

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1317,6 +1317,40 @@ static uint32_t common_get_enabled_speculative_configs(const std::vector<common_
13171317
return result;
13181318
}
13191319

1320+
int32_t common_speculative_n_max(const common_params_speculative * spec) {
1321+
int32_t n_max = 0;
1322+
1323+
for (const auto type : spec->types) {
1324+
switch (type) {
1325+
case COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE:
1326+
case COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3:
1327+
case COMMON_SPECULATIVE_TYPE_DRAFT_MTP:
1328+
n_max = std::max(n_max, std::max(0, spec->draft.n_max));
1329+
break;
1330+
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE:
1331+
n_max = std::max(n_max, (int32_t) spec->ngram_simple.size_m);
1332+
break;
1333+
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K:
1334+
n_max = std::max(n_max, (int32_t) spec->ngram_map_k.size_m);
1335+
break;
1336+
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V:
1337+
n_max = std::max(n_max, (int32_t) spec->ngram_map_k4v.size_m);
1338+
break;
1339+
case COMMON_SPECULATIVE_TYPE_NGRAM_MOD:
1340+
n_max = std::max(n_max, std::max(0, spec->ngram_mod.n_max));
1341+
break;
1342+
case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE:
1343+
n_max = std::max(n_max, (int32_t) 8);
1344+
break;
1345+
case COMMON_SPECULATIVE_TYPE_NONE:
1346+
case COMMON_SPECULATIVE_TYPE_COUNT:
1347+
break;
1348+
}
1349+
}
1350+
1351+
return n_max;
1352+
}
1353+
13201354
// initialization of the speculative decoding system
13211355
//
13221356
common_speculative * common_speculative_init(common_params_speculative & params, uint32_t n_seq) {
@@ -1325,8 +1359,6 @@ common_speculative * common_speculative_init(common_params_speculative & params,
13251359
{
13261360
uint32_t enabled_configs = common_get_enabled_speculative_configs(params.types);
13271361

1328-
bool has_draft_model_path = !params.draft.mparams.path.empty();
1329-
13301362
bool has_draft_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE));
13311363
bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3
13321364
bool has_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_MTP)) && params.draft.ctx_dft != nullptr;
@@ -1359,16 +1391,6 @@ common_speculative * common_speculative_init(common_params_speculative & params,
13591391
if (has_ngram_cache) {
13601392
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, params));
13611393
}
1362-
if (has_draft_simple) {
1363-
if (!has_draft_model_path) {
1364-
LOG_WRN("%s: draft model is not specified - cannot use 'draft' type\n", __func__);
1365-
has_draft_simple = false;
1366-
}
1367-
} else if (has_draft_model_path && !has_mtp && !has_draft_eagle3) {
1368-
LOG_WRN("%s: draft model is specified but 'draft' speculative type is not explicitly enabled - enabling it\n", __func__);
1369-
has_draft_simple = true;
1370-
}
1371-
13721394
if (has_draft_simple) {
13731395
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE, params));
13741396
}

common/speculative.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
2020
// convert type to string
2121
std::string common_speculative_type_to_str(enum common_speculative_type type);
2222

23+
// return the max number of draft tokens based on the speculative parameters
24+
int32_t common_speculative_n_max(const common_params_speculative * spec);
25+
2326
common_speculative * common_speculative_init(common_params_speculative & params, uint32_t n_seq);
2427

2528
void common_speculative_free(common_speculative * spec);

src/llama-context.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ llama_context::llama_context(
229229
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
230230
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
231231
LLAMA_LOG_INFO("%s: n_rs_seq = %u\n", __func__, cparams.n_rs_seq);
232+
LLAMA_LOG_INFO("%s: n_outputs_max = %u\n", __func__, cparams.n_outputs_max);
232233

233234
if (cparams.n_ctx_seq < hparams.n_ctx_train) {
234235
LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",

tools/server/server-context.cpp

Lines changed: 2 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -45,35 +45,7 @@ static uint32_t server_n_outputs_max(const common_params & params) {
4545
return n_batch;
4646
}
4747

48-
uint32_t n_outputs_per_seq = 1;
49-
50-
for (const auto type : params.speculative.types) {
51-
switch (type) {
52-
case COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE:
53-
case COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3:
54-
case COMMON_SPECULATIVE_TYPE_DRAFT_MTP:
55-
n_outputs_per_seq = std::max<uint32_t>(n_outputs_per_seq, 1 + std::max(0, params.speculative.draft.n_max));
56-
break;
57-
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE:
58-
n_outputs_per_seq = std::max<uint32_t>(n_outputs_per_seq, 1 + params.speculative.ngram_simple.size_m);
59-
break;
60-
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K:
61-
n_outputs_per_seq = std::max<uint32_t>(n_outputs_per_seq, 1 + params.speculative.ngram_map_k.size_m);
62-
break;
63-
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V:
64-
n_outputs_per_seq = std::max<uint32_t>(n_outputs_per_seq, 1 + params.speculative.ngram_map_k4v.size_m);
65-
break;
66-
case COMMON_SPECULATIVE_TYPE_NGRAM_MOD:
67-
n_outputs_per_seq = std::max<uint32_t>(n_outputs_per_seq, 1 + std::max(0, params.speculative.ngram_mod.n_max));
68-
break;
69-
case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE:
70-
n_outputs_per_seq = std::max<uint32_t>(n_outputs_per_seq, 1 + 8);
71-
break;
72-
case COMMON_SPECULATIVE_TYPE_NONE:
73-
case COMMON_SPECULATIVE_TYPE_COUNT:
74-
break;
75-
}
76-
}
48+
const uint32_t n_outputs_per_seq = 1 + common_speculative_n_max(&params.speculative);
7749

7850
const uint64_t n_outputs = (uint64_t) params.n_parallel * n_outputs_per_seq;
7951

@@ -862,9 +834,7 @@ struct server_context_impl {
862834
measure_model_bytes = false;
863835
}
864836

865-
if (!has_draft) {
866-
params_dft.n_outputs_max = params_base.n_parallel;
867-
}
837+
params_dft.n_outputs_max = params_base.n_parallel;
868838

869839
auto mparams_dft = common_model_params_to_llama(params_dft);
870840
auto cparams_dft = common_context_params_to_llama(params_dft);

0 commit comments

Comments
 (0)