Skip to content

Commit 9e3b928

Browse files
ddh0CISC
andauthored
common : relax sampler name matching (ggml-org#23744)
* common : relax sampler name matching Currently, in some cases, the alternative names for samplers (like `top-k` and `min-p` instead of the canonical `top_k` and `min_p`) are not always recognized by the `common_sampler_types_from_names` function in `common/sampling.cpp`. This PR changes the signature of this function to remove the `bool allow_alt_names` flag, and removes all occurences of the flag from call sites. Therefore, the function will now always match all known names. I also changed the logic of the function to unconditionally check the provided sampler names against both the canonical and alternative names, and to be case-insensitive. This fixes an issue I was seeing wherein samplers specified in the `llama-server` UI were not recognized as valid when the alternative names were used. * add more alt names * cont. fix * cast to unsigned char for correctness * common : unify sampler name mapping * annotate canonical vs. alt sampler name mappings per @CISC * Update common/sampling.cpp Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> * common : auto-generate sampler name aliases per @ngxson * use merged map for matching * use `.merge` instead of iterating * nit: simplify comment * nit: use insert everywhere, not index assignment --------- Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
1 parent 8a963fc commit 9e3b928

5 files changed

Lines changed: 53 additions & 44 deletions

File tree

common/arg.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1615,7 +1615,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
16151615
string_format("samplers that will be used for generation in the order, separated by \';\'\n(default: %s)", sampler_type_names.c_str()),
16161616
[](common_params & params, const std::string & value) {
16171617
const auto sampler_names = string_split<std::string>(value, ';');
1618-
params.sampling.samplers = common_sampler_types_from_names(sampler_names, true);
1618+
params.sampling.samplers = common_sampler_types_from_names(sampler_names);
16191619
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS;
16201620
}
16211621
).set_sampling());

common/common.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1148,7 +1148,7 @@ static void common_init_sampler_from_model(
11481148
if (llama_model_meta_val_str(model, llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE), buf, sizeof(buf)) > 0) {
11491149
const std::vector<std::string> sampler_names = string_split<std::string>(std::string(buf), ';');
11501150
if (!sampler_names.empty()) {
1151-
sparams.samplers = common_sampler_types_from_names(sampler_names, true);
1151+
sparams.samplers = common_sampler_types_from_names(sampler_names);
11521152
}
11531153
}
11541154
}

common/sampling.cpp

Lines changed: 49 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -769,54 +769,63 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
769769
}
770770
}
771771

772-
std::vector<common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
773-
std::unordered_map<std::string, common_sampler_type> sampler_canonical_name_map {
774-
{ "dry", COMMON_SAMPLER_TYPE_DRY },
775-
{ "top_k", COMMON_SAMPLER_TYPE_TOP_K },
776-
{ "top_p", COMMON_SAMPLER_TYPE_TOP_P },
777-
{ "top_n_sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
778-
{ "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P },
779-
{ "min_p", COMMON_SAMPLER_TYPE_MIN_P },
780-
{ "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
781-
{ "xtc", COMMON_SAMPLER_TYPE_XTC },
782-
{ "infill", COMMON_SAMPLER_TYPE_INFILL },
783-
{ "penalties", COMMON_SAMPLER_TYPE_PENALTIES },
784-
{ "adaptive_p", COMMON_SAMPLER_TYPE_ADAPTIVE_P },
785-
};
786-
787-
// since samplers names are written multiple ways
788-
// make it ready for both system names and input names
789-
std::unordered_map<std::string, common_sampler_type> sampler_alt_name_map {
790-
{ "top-k", COMMON_SAMPLER_TYPE_TOP_K },
791-
{ "top-p", COMMON_SAMPLER_TYPE_TOP_P },
792-
{ "top-n-sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
793-
{ "nucleus", COMMON_SAMPLER_TYPE_TOP_P },
794-
{ "typical-p", COMMON_SAMPLER_TYPE_TYPICAL_P },
795-
{ "typical", COMMON_SAMPLER_TYPE_TYPICAL_P },
796-
{ "typ-p", COMMON_SAMPLER_TYPE_TYPICAL_P },
797-
{ "typ", COMMON_SAMPLER_TYPE_TYPICAL_P },
798-
{ "min-p", COMMON_SAMPLER_TYPE_MIN_P },
799-
{ "temp", COMMON_SAMPLER_TYPE_TEMPERATURE },
800-
{ "adaptive-p", COMMON_SAMPLER_TYPE_ADAPTIVE_P },
801-
};
772+
std::vector<common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names) {
773+
// sampler names can be written multiple ways; generate aliases from canonical names
774+
static const auto sampler_name_map = []{
775+
// canonical sampler name mapping
776+
std::unordered_map<std::string, common_sampler_type> canonical_name_map {
777+
{ "dry", COMMON_SAMPLER_TYPE_DRY },
778+
{ "top_k", COMMON_SAMPLER_TYPE_TOP_K },
779+
{ "top_p", COMMON_SAMPLER_TYPE_TOP_P },
780+
{ "top_n_sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
781+
{ "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P },
782+
{ "min_p", COMMON_SAMPLER_TYPE_MIN_P },
783+
{ "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
784+
{ "xtc", COMMON_SAMPLER_TYPE_XTC },
785+
{ "infill", COMMON_SAMPLER_TYPE_INFILL },
786+
{ "penalties", COMMON_SAMPLER_TYPE_PENALTIES },
787+
{ "adaptive_p", COMMON_SAMPLER_TYPE_ADAPTIVE_P }
788+
};
789+
std::unordered_map<std::string, common_sampler_type> alias_name_map;
790+
for (const auto & entry : canonical_name_map) {
791+
const std::string & canonical = entry.first;
792+
if (canonical.find('_') == std::string::npos) {
793+
continue;
794+
}
795+
// kebab-case: "top-k", "min-p", etc.
796+
{
797+
std::string kebab_case = canonical;
798+
std::replace(kebab_case.begin(), kebab_case.end(), '_', '-');
799+
alias_name_map.insert({kebab_case, entry.second});
800+
}
801+
// no dash: "topk", "minp", etc.
802+
{
803+
std::string no_dash = canonical;
804+
no_dash.erase(std::remove(no_dash.begin(), no_dash.end(), '_'), no_dash.end());
805+
alias_name_map.insert({no_dash, entry.second});
806+
}
807+
}
808+
// misc. aliases
809+
alias_name_map.insert({"nucleus", COMMON_SAMPLER_TYPE_TOP_P});
810+
alias_name_map.insert({"temp", COMMON_SAMPLER_TYPE_TEMPERATURE});
811+
alias_name_map.insert({"typ", COMMON_SAMPLER_TYPE_TYPICAL_P});
812+
// include aliases + canonical names in the complete mapping
813+
alias_name_map.merge(canonical_name_map);
814+
return alias_name_map;
815+
}();
802816

803817
std::vector<common_sampler_type> samplers;
804818
samplers.reserve(names.size());
805819

806820
for (const auto & name : names) {
807-
auto sampler = sampler_canonical_name_map.find(name);
808-
if (sampler != sampler_canonical_name_map.end()) {
821+
std::string name_lower = name;
822+
std::transform(name_lower.begin(), name_lower.end(), name_lower.begin(), ::tolower);
823+
auto sampler = sampler_name_map.find(name_lower);
824+
if (sampler != sampler_name_map.end()) {
809825
samplers.push_back(sampler->second);
810826
continue;
811827
}
812-
if (allow_alt_names) {
813-
sampler = sampler_alt_name_map.find(name);
814-
if (sampler != sampler_alt_name_map.end()) {
815-
samplers.push_back(sampler->second);
816-
continue;
817-
}
818-
}
819-
LOG_WRN("%s: unable to match sampler by name '%s'\n", __func__, name.c_str());
828+
LOG_WRN("%s: unable to match sampler by name '%s'\n", __func__, name_lower.c_str());
820829
}
821830

822831
return samplers;

common/sampling.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx,
109109
char common_sampler_type_to_chr(enum common_sampler_type cnstr);
110110
std::string common_sampler_type_to_str(enum common_sampler_type cnstr);
111111

112-
std::vector<enum common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
112+
std::vector<enum common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names);
113113
std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std::string & chars);
114114

115115
llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab,

tools/server/server-task.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -605,7 +605,7 @@ task_params server_task::params_from_json_cmpl(
605605
const auto samplers = data.find("samplers");
606606
if (samplers != data.end()) {
607607
if (samplers->is_array()) {
608-
params.sampling.samplers = common_sampler_types_from_names(*samplers, false);
608+
params.sampling.samplers = common_sampler_types_from_names(*samplers);
609609
} else if (samplers->is_string()){
610610
params.sampling.samplers = common_sampler_types_from_chars(samplers->get<std::string>());
611611
}

0 commit comments

Comments
 (0)