Skip to content

Commit d14ce3d

Browse files
authored
llama : MTP clean-up (#23269)
* llama : disable equal splits for recurrent memory with partial rollback
* spec : re-enable p-min with MTP drafts
* spec : re-enable ngram spec in combination with RS rollback
* spec : fix ngram-map-* params
* spec : fix acceptance logic in combined ngram + draft configs
* graph : fix reuse for combined `token` + `embd` batches
* spec : log parameters for each speculative implementation
  - add LOG_INF in each constructor with implementation type and parameters
  - extract device string logic into common_speculative_get_devices_str()
  - move 'adding speculative implementation' log from init into constructors
  Assisted-by: llama.cpp:local pi
* spec : extend --spec-default with ngram-map-k4v
  Assisted-by: llama.cpp:local pi
* minor : fix n_embd log
* args : update draft.n_max == 3 + regen docs
* spec : relax ngram-mod rejection thold to 0.25 @ 5 low
* logs : improve
* docs : update speculative decoding CLI argument documentation
  - Add missing draft model CPU scheduling and tensor override parameters
  - Update --spec-type to include all available types (excluding draft-eagle3 WIP)
  - Fix default values to match implementation (n_max=3, n_min=0, p_min=0.0)
  - Remove deprecated options (spec-draft-ctx-size, spec-draft-replace)
  - Add environment variables for new parameters
  Assisted-by: llama.cpp:local pi
* arg : step-back on adding k4v to the default spec config
* cont : fix name
1 parent 6db1304 commit d14ce3d

15 files changed

Lines changed: 293 additions & 134 deletions

common/arg.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -536,7 +536,11 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
536536
throw std::invalid_argument(string_format("error: invalid argument: %s", arg.c_str()));
537537
}
538538
if (!seen_args.insert(arg).second) {
539-
LOG_WRN("DEPRECATED: argument '%s' specified multiple times, use comma-separated values instead (only last value will be used)\n", arg.c_str());
539+
const bool skip = (arg == "--spec-type");
540+
541+
if (!skip) {
542+
LOG_WRN("DEPRECATED: argument '%s' specified multiple times, use comma-separated values instead (only last value will be used)\n", arg.c_str());
543+
}
540544
}
541545
auto & tmp = arg_to_options[arg];
542546
auto opt = *tmp.first;
@@ -893,7 +897,11 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<com
893897
throw std::invalid_argument(string_format("error: invalid argument: %s", arg.c_str()));
894898
}
895899
if (!seen_args.insert(arg).second) {
896-
LOG_WRN("DEPRECATED: argument '%s' specified multiple times, use comma-separated values instead (only last value will be used)\n", arg.c_str());
900+
const bool skip = (arg == "--spec-type");
901+
902+
if (!skip) {
903+
LOG_WRN("DEPRECATED: argument '%s' specified multiple times, use comma-separated values instead (only last value will be used)\n", arg.c_str());
904+
}
897905
}
898906
auto opt = *arg_to_options[arg];
899907
std::string val;
@@ -4117,6 +4125,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
41174125
params.speculative.ngram_mod.n_match = 24;
41184126
params.speculative.ngram_mod.n_min = 48;
41194127
params.speculative.ngram_mod.n_max = 64;
4128+
4129+
// TODO: not sure if this is a good config - explore more settings and potentially enable it
4130+
//params.speculative.types.push_back(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V);
4131+
//params.speculative.ngram_map_k4v.size_n = 8;
4132+
//params.speculative.ngram_map_k4v.size_m = 24;
4133+
//params.speculative.ngram_map_k4v.min_hits = 2;
41204134
}
41214135
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
41224136

common/common.cpp

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1256,29 +1256,6 @@ common_init_result::common_init_result(common_params & params, bool model_only)
12561256
cparams.n_samplers = pimpl->samplers_seq_config.size();
12571257
}
12581258

1259-
// [TAG_RS_STATE_ROLLBACK_SUPPORT]
1260-
// TODO: ngram speculative methods require checkpointing in addition to partial RS rollback
1261-
// currently this is not supported. so we disable the partial rollback
1262-
if (cparams.n_rs_seq > 0 && (llama_model_is_recurrent(model) || llama_model_is_hybrid(model))) {
1263-
auto & types = params.speculative.types;
1264-
1265-
for (int i = 0; i < (int) types.size(); i++) {
1266-
if (types[i] == COMMON_SPECULATIVE_TYPE_NONE) {
1267-
continue;
1268-
}
1269-
if (types[i] == COMMON_SPECULATIVE_TYPE_DRAFT_MTP) {
1270-
continue;
1271-
}
1272-
1273-
cparams.n_rs_seq = 0;
1274-
1275-
LOG_WRN("%s: recurrent state rollback is not compatible with '%s' - disabling rollback support\n", __func__,
1276-
common_speculative_type_to_str(types[i]).c_str());
1277-
1278-
break;
1279-
}
1280-
}
1281-
12821259
llama_context * lctx = llama_init_from_model(model, cparams);
12831260
if (lctx == NULL) {
12841261
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());

common/common.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -299,11 +299,11 @@ struct common_params_model {
299299

300300
// draft-model-based speculative decoding parameters
301301
struct common_params_speculative_draft {
302-
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
303-
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
302+
int32_t n_max = 3; // maximum number of tokens to draft during speculative decoding
303+
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
304304

305-
float p_split = 0.1f; // speculative decoding split probability
306-
float p_min = 0.75f; // minimum speculative decoding probability (greedy) // TODO: change default to 0.0f
305+
float p_split = 0.1f; // speculative decoding split probability
306+
float p_min = 0.0f; // minimum speculative decoding probability (greedy)
307307

308308
common_params_model mparams;
309309

common/ngram-map.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -500,7 +500,7 @@ void common_ngram_map_draft(common_ngram_map & map,
500500
draft.push_back(inp[match_pos + n + i]);
501501
}
502502

503-
LOG_INF("%s: key_offset = %zu, slot_max = %d, key_num = %d, draft.size = %zu\n", __func__,
503+
LOG_DBG("%s: key_offset = %zu, slot_max = %d, key_num = %d, draft.size = %zu\n", __func__,
504504
key_offset, slot_max,
505505
curr_key.key_num, draft.size());
506506

0 commit comments

Comments
 (0)