Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -335,11 +335,15 @@ static bool common_params_handle_remote_preset(common_params & params, llama_exa
// Outcome of resolving a primary model: records any sibling GGUF artifacts
// (multimodal projector, MTP head) discovered alongside it during download.
struct handle_model_result {
// set when a multimodal projector ("mmproj") GGUF was found next to the model
bool found_mmproj = false;
common_params_model mmproj;

// set when an MTP (multi-token prediction) head GGUF was found next to the model
bool found_mtp = false;
common_params_model mtp;
};

static handle_model_result common_params_handle_model(struct common_params_model & model,
const std::string & bearer_token,
bool offline) {
bool offline,
bool search_mtp = false) {
handle_model_result result;

if (!model.docker_repo.empty()) {
Expand All @@ -354,7 +358,7 @@ static handle_model_result common_params_handle_model(struct common_params_model
common_download_opts opts;
opts.bearer_token = bearer_token;
opts.offline = offline;
auto download_result = common_download_model(model, opts, true);
auto download_result = common_download_model(model, opts, true, search_mtp);

if (download_result.model_path.empty()) {
LOG_ERR("error: failed to download model from Hugging Face\n");
Expand All @@ -368,6 +372,11 @@ static handle_model_result common_params_handle_model(struct common_params_model
result.found_mmproj = true;
result.mmproj.path = download_result.mmproj_path;
}

if (!download_result.mtp_path.empty()) {
result.found_mtp = true;
result.mtp.path = download_result.mtp_path;
}
} else if (!model.url.empty()) {
if (model.path.empty()) {
auto f = string_split<std::string>(model.url, '#').front();
Expand Down Expand Up @@ -436,7 +445,11 @@ static bool parse_bool_value(const std::string & value) {
//

void common_params_handle_models(common_params & params, llama_example curr_ex) {
auto res = common_params_handle_model(params.model, params.hf_token, params.offline);
const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(),
params.speculative.types.end(),
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params.speculative.types.end();

auto res = common_params_handle_model(params.model, params.hf_token, params.offline, spec_type_draft_mtp);
if (params.no_mmproj) {
params.mmproj = {};
} else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) {
Expand All @@ -450,6 +463,14 @@ void common_params_handle_models(common_params & params, llama_example curr_ex)
break;
}
}
// when --spec-type mtp is set and no draft model was provided explicitly,
// fall back to the MTP head discovered alongside the -hf model
if (spec_type_draft_mtp && res.found_mtp &&
params.speculative.draft.mparams.path.empty() &&
params.speculative.draft.mparams.hf_repo.empty() &&
params.speculative.draft.mparams.url.empty()) {
params.speculative.draft.mparams.path = res.mtp.path;
}
common_params_handle_model(params.speculative.draft.mparams, params.hf_token, params.offline);
common_params_handle_model(params.vocoder.model, params.hf_token, params.offline);
}
Expand Down
7 changes: 7 additions & 0 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1420,6 +1420,12 @@ common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) {
goto done;
}

if (llama_n_rs_seq(ctx) > 0) {
LOG_INF("%s: the context supports bounded partial sequence removal\n", __func__);
res = COMMON_CONTEXT_SEQ_RM_TYPE_RS;
goto done;
}

// try to remove the last tokens
if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
LOG_WRN("%s: the context does not support partial sequence removal\n", __func__);
Expand Down Expand Up @@ -1490,6 +1496,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &

cparams.n_ctx = params.n_ctx;
cparams.n_seq_max = params.n_parallel;
cparams.n_rs_seq = params.speculative.need_n_rs_seq();
cparams.n_batch = params.n_batch;
cparams.n_ubatch = params.n_ubatch;
cparams.n_threads = params.cpuparams.n_threads;
Expand Down
17 changes: 14 additions & 3 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <string_view>
#include <vector>
#include <map>
#include <algorithm>

#if defined(_WIN32) && !defined(_WIN32_WINNT)
#define _WIN32_WINNT 0x0A00
Expand Down Expand Up @@ -159,6 +160,7 @@ enum common_speculative_type {
COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding
COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE, // standalone draft model speculative decoding
COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, // Eagle3 speculative decoding
COMMON_SPECULATIVE_TYPE_DRAFT_MTP, // Multi-token prediction
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding based on n-grams
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values
Expand Down Expand Up @@ -355,6 +357,14 @@ struct common_params_speculative {
// A draft model is considered configured when either a local path
// or a Hugging Face repo has been provided for it.
bool has_dft() const {
    const bool no_local_path = draft.mparams.path.empty();
    const bool no_hf_repo    = draft.mparams.hf_repo.empty();
    return !(no_local_path && no_hf_repo);
}

// Number of recurrent-state sequences the context must reserve for
// speculative decoding. MTP drafting needs draft.n_max slots; every
// other speculative type needs none.
// Fix: the lambda captured `[&]` but uses nothing from the enclosing
// scope — made it captureless; the intermediate flag is now const.
uint32_t need_n_rs_seq() const {
    const bool uses_mtp = std::any_of(types.begin(), types.end(), [](auto t) {
        return t == COMMON_SPECULATIVE_TYPE_DRAFT_MTP;
    });

    return uses_mtp ? draft.n_max : 0u;
}
};

struct common_params_vocoder {
Expand Down Expand Up @@ -871,9 +881,10 @@ std::string common_get_model_endpoint();
//

// Capability level for removing (parts of) sequences from a context's memory;
// reported by common_context_can_seq_rm().
// Fix: the span contained both the pre- and post-diff enumerator lines,
// i.e. duplicate enumerator definitions that cannot compile — keep only
// the updated set, which adds the bounded (n_rs_seq) removal mode.
enum common_context_seq_rm_type {
    COMMON_CONTEXT_SEQ_RM_TYPE_NO   = 0, // seq_rm not supported (e.g. no memory module)
    COMMON_CONTEXT_SEQ_RM_TYPE_PART = 1, // can seq_rm partial sequences
    COMMON_CONTEXT_SEQ_RM_TYPE_FULL = 2, // can seq_rm full sequences only
    COMMON_CONTEXT_SEQ_RM_TYPE_RS   = 3, // can seq_rm partial sequences, bounded by n_rs_seq
};

// check if the llama_context can remove sequences
Expand Down
55 changes: 42 additions & 13 deletions common/download.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -566,8 +566,11 @@ static hf_cache::hf_files get_split_files(const hf_cache::hf_files & files,
return result;
}

static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files,
const std::string & model) {
// pick the best sibling GGUF whose filename contains `keyword` (e.g. "mmproj" / "mtp"),
// preferring deeper shared directory prefix with the model, then closest quantization
static hf_cache::hf_file find_best_sibling(const hf_cache::hf_files & files,
const std::string & model,
const std::string & keyword) {
hf_cache::hf_file best;
size_t best_depth = 0;
int best_diff = 0;
Expand All @@ -579,20 +582,20 @@ static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files,

for (const auto & f : files) {
if (!string_ends_with(f.path, ".gguf") ||
f.path.find("mmproj") == std::string::npos) {
f.path.find(keyword) == std::string::npos) {
continue;
}

auto mmproj_parts = string_split<std::string>(f.path, '/');
auto mmproj_dir = mmproj_parts.end() - 1;
auto sib_parts = string_split<std::string>(f.path, '/');
auto sib_dir = sib_parts.end() - 1;

auto [_, dir] = std::mismatch(model_parts.begin(), model_dir,
mmproj_parts.begin(), mmproj_dir);
if (dir != mmproj_dir) {
sib_parts.begin(), sib_dir);
if (dir != sib_dir) {
continue;
}

size_t depth = dir - mmproj_parts.begin();
size_t depth = dir - sib_parts.begin();
auto bits = extract_quant_bits(f.path);
auto diff = std::abs(bits - model_bits);

Expand All @@ -606,6 +609,16 @@ static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files,
return best;
}

// Find the multimodal projector GGUF that best matches `model` among `files`
// (delegates to the generic sibling search with the "mmproj" keyword).
static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files,
                                          const std::string & model) {
    const char * keyword = "mmproj";
    return find_best_sibling(files, model, keyword);
}

// Find the MTP-head GGUF that best matches `model` among `files`
// (delegates to the generic sibling search with the "mtp-" keyword).
static hf_cache::hf_file find_best_mtp(const hf_cache::hf_files & files,
                                       const std::string & model) {
    const char * keyword = "mtp-";
    return find_best_sibling(files, model, keyword);
}

static bool gguf_filename_is_model(const std::string & filepath) {
if (!string_ends_with(filepath, ".gguf")) {
return false;
Expand All @@ -617,7 +630,8 @@ static bool gguf_filename_is_model(const std::string & filepath) {
}

return filename.find("mmproj") == std::string::npos &&
filename.find("imatrix") == std::string::npos;
filename.find("imatrix") == std::string::npos &&
filename.find("mtp-") == std::string::npos;
}

static hf_cache::hf_file find_best_model(const hf_cache::hf_files & files,
Expand Down Expand Up @@ -673,11 +687,13 @@ struct hf_plan {
hf_cache::hf_file primary;
hf_cache::hf_files model_files;
hf_cache::hf_file mmproj;
hf_cache::hf_file mtp;
};

static hf_plan get_hf_plan(const common_params_model & model,
const common_download_opts & opts,
bool download_mmproj) {
bool download_mmproj,
bool download_mtp) {
hf_plan plan;
hf_cache::hf_files all;

Expand Down Expand Up @@ -723,6 +739,10 @@ static hf_plan get_hf_plan(const common_params_model & model,
plan.mmproj = find_best_mmproj(all, primary.path);
}

if (download_mtp) {
plan.mtp = find_best_mtp(all, primary.path);
}

return plan;
}

Expand Down Expand Up @@ -756,21 +776,25 @@ static std::vector<download_task> get_url_tasks(const common_params_model & mode

common_download_model_result common_download_model(const common_params_model & model,
const common_download_opts & opts,
bool download_mmproj) {
bool download_mmproj,
bool download_mtp) {
common_download_model_result result;
std::vector<download_task> tasks;
hf_plan hf;

bool is_hf = !model.hf_repo.empty();

if (is_hf) {
hf = get_hf_plan(model, opts, download_mmproj);
hf = get_hf_plan(model, opts, download_mmproj, download_mtp);
for (const auto & f : hf.model_files) {
tasks.push_back({f.url, f.local_path});
}
if (!hf.mmproj.path.empty()) {
tasks.push_back({hf.mmproj.url, hf.mmproj.local_path});
}
if (!hf.mtp.path.empty()) {
tasks.push_back({hf.mtp.url, hf.mtp.local_path});
}
} else if (!model.url.empty()) {
tasks = get_url_tasks(model);
} else {
Expand Down Expand Up @@ -807,6 +831,10 @@ common_download_model_result common_download_model(const common_params_model &
if (!hf.mmproj.path.empty()) {
result.mmproj_path = hf_cache::finalize_file(hf.mmproj);
}

if (!hf.mtp.path.empty()) {
result.mtp_path = hf_cache::finalize_file(hf.mtp);
}
} else {
result.model_path = model.path;
}
Expand Down Expand Up @@ -946,7 +974,8 @@ std::vector<common_cached_model_info> common_list_cached_models() {
for (const auto & f : files) {
auto split = get_gguf_split_info(f.path);
if (split.index != 1 || split.tag.empty() ||
split.prefix.find("mmproj") != std::string::npos) {
split.prefix.find("mmproj") != std::string::npos ||
split.prefix.find("MTP") != std::string::npos) {
continue;
}
if (seen.insert(f.repo_id + ":" + split.tag).second) {
Expand Down
7 changes: 5 additions & 2 deletions common/download.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ struct common_download_opts {
// Local filesystem paths produced by common_download_model().
// An empty string means the artifact was not found or the download failed.
struct common_download_model_result {
std::string model_path;
std::string mmproj_path; // multimodal projector GGUF (only when requested and found)
std::string mtp_path; // MTP-head GGUF (only when requested and found)
};

// Download model from HuggingFace repo or URL
Expand All @@ -83,12 +84,14 @@ struct common_download_model_result {
// when opts.offline=true, no network requests are made
// when download_mmproj=true, searches for mmproj in same directory as model or any parent directory
// then with the closest quantization bits
// when download_mtp=true, applies the same sibling search for an MTP-head GGUF
//
// returns result with model_path and mmproj_path (empty on failure)
// returns result with model_path, mmproj_path and mtp_path (empty when not found / on failure)
common_download_model_result common_download_model(
const common_params_model & model,
const common_download_opts & opts = {},
bool download_mmproj = false
bool download_mmproj = false,
bool download_mtp = false
);

// returns list of cached models
Expand Down
Loading
Loading