Skip to content

Commit 06d26df

Browse files
authored
download: add option to skip_download (ggml-org#23059)
* download: add option to skip_download * fix * fix 2 * if file doesn't exist, respect skip_download flag
1 parent da3f990 commit 06d26df

8 files changed

Lines changed: 126 additions & 83 deletions

File tree

common/arg.cpp

Lines changed: 41 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -340,9 +340,7 @@ struct handle_model_result {
340340
};
341341

342342
static handle_model_result common_params_handle_model(struct common_params_model & model,
343-
const std::string & bearer_token,
344-
bool offline,
345-
bool search_mtp = false) {
343+
const common_download_opts & opts) {
346344
handle_model_result result;
347345

348346
if (!model.docker_repo.empty()) {
@@ -354,10 +352,9 @@ static handle_model_result common_params_handle_model(struct common_params_model
354352
model.hf_file = model.path;
355353
model.path = "";
356354
}
357-
common_download_opts opts;
358-
opts.bearer_token = bearer_token;
359-
opts.offline = offline;
360-
auto download_result = common_download_model(model, opts, true, search_mtp);
355+
common_download_opts hf_opts = opts;
356+
hf_opts.download_mmproj = true; // also look for mmproj when downloading hf model
357+
auto download_result = common_download_model(model, hf_opts);
361358

362359
if (download_result.model_path.empty()) {
363360
throw std::runtime_error("failed to download model from Hugging Face");
@@ -382,9 +379,6 @@ static handle_model_result common_params_handle_model(struct common_params_model
382379
model.path = fs_get_cache_file(string_split<std::string>(f, '/').back());
383380
}
384381

385-
common_download_opts opts;
386-
opts.bearer_token = bearer_token;
387-
opts.offline = offline;
388382
auto download_result = common_download_model(model, opts);
389383
if (download_result.model_path.empty()) {
390384
throw std::runtime_error("failed to download model from " + model.url);
@@ -441,35 +435,49 @@ static bool parse_bool_value(const std::string & value) {
441435
// CLI argument parsing functions
442436
//
443437

444-
void common_params_handle_models(common_params & params, llama_example curr_ex) {
438+
bool common_params_handle_models(common_params & params, llama_example curr_ex) {
445439
const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(),
446440
params.speculative.types.end(),
447441
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params.speculative.types.end();
448442

449-
auto res = common_params_handle_model(params.model, params.hf_token, params.offline, spec_type_draft_mtp);
450-
if (params.no_mmproj) {
451-
params.mmproj = {};
452-
} else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) {
453-
// optionally, handle mmproj model when -hf is specified
454-
params.mmproj = res.mmproj;
455-
}
456-
// only download mmproj if the current example is using it
457-
for (const auto & ex : mmproj_examples) {
458-
if (curr_ex == ex) {
459-
common_params_handle_model(params.mmproj, params.hf_token, params.offline);
460-
break;
443+
common_download_opts opts;
444+
opts.bearer_token = params.hf_token;
445+
opts.offline = params.offline;
446+
opts.skip_download = params.skip_download;
447+
opts.download_mtp = spec_type_draft_mtp;
448+
449+
try {
450+
auto res = common_params_handle_model(params.model, opts);
451+
if (params.no_mmproj) {
452+
params.mmproj = {};
453+
} else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) {
454+
// optionally, handle mmproj model when -hf is specified
455+
params.mmproj = res.mmproj;
456+
}
457+
// only download mmproj if the current example is using it
458+
for (const auto & ex : mmproj_examples) {
459+
if (curr_ex == ex) {
460+
common_params_handle_model(params.mmproj, opts);
461+
break;
462+
}
461463
}
464+
465+
// when --spec-type mtp is set and no draft model was provided explicitly,
466+
// fall back to the MTP head discovered alongside the -hf model
467+
if (spec_type_draft_mtp && res.found_mtp &&
468+
params.speculative.draft.mparams.path.empty() &&
469+
params.speculative.draft.mparams.hf_repo.empty() &&
470+
params.speculative.draft.mparams.url.empty()) {
471+
params.speculative.draft.mparams.path = res.mtp.path;
472+
}
473+
common_params_handle_model(params.speculative.draft.mparams, opts);
474+
common_params_handle_model(params.vocoder.model, opts);
475+
return true;
476+
} catch (const common_skip_download_exception &) {
477+
return false;
478+
} catch (const std::exception &) {
479+
throw;
462480
}
463-
// when --spec-type mtp is set and no draft model was provided explicitly,
464-
// fall back to the MTP head discovered alongside the -hf model
465-
if (spec_type_draft_mtp && res.found_mtp &&
466-
params.speculative.draft.mparams.path.empty() &&
467-
params.speculative.draft.mparams.hf_repo.empty() &&
468-
params.speculative.draft.mparams.url.empty()) {
469-
params.speculative.draft.mparams.path = res.mtp.path;
470-
}
471-
common_params_handle_model(params.speculative.draft.mparams, params.hf_token, params.offline);
472-
common_params_handle_model(params.vocoder.model, params.hf_token, params.offline);
473481
}
474482

475483
static bool common_params_parse_ex(int argc, char ** argv, common_params_context & ctx_arg) {

common/arg.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,11 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<com
129129
// see: https://github.com/ggml-org/llama.cpp/issues/18163
130130
void common_params_add_preset_options(std::vector<common_arg> & args);
131131

132-
// Populate model paths (main model, mmproj, etc) from -hf if necessary
133-
void common_params_handle_models(common_params & params, llama_example curr_ex);
132+
// populate model paths (main model, mmproj, etc) from -hf if necessary
133+
// return true if the model is ready to use
134+
// throw an exception if there is an error that prevents the model from being used (e.g. network error, model not found, etc)
135+
// if params.skip_download is true, no downloads will be attempted. return false if the model is invalid or missing (e.g. ETag check failed)
136+
bool common_params_handle_models(common_params & params, llama_example curr_ex);
134137

135138
// initialize argument parser context - used by test-arg-parser and preset
136139
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);

common/common.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,7 @@ struct common_params {
479479

480480
std::set<std::string> model_alias; // model aliases // NOLINT
481481
std::set<std::string> model_tags; // model tags (informational, not used for routing) // NOLINT
482-
std::string hf_token = ""; // HF token // NOLINT
482+
std::string hf_token = ""; // HF token (aka bearer token) // NOLINT
483483
std::string prompt = ""; // NOLINT
484484
std::string system_prompt = ""; // NOLINT
485485
std::string prompt_file = ""; // store the external prompt file name // NOLINT
@@ -507,6 +507,7 @@ struct common_params {
507507
int32_t control_vector_layer_start = -1; // layer range for control vector
508508
int32_t control_vector_layer_end = -1; // layer range for control vector
509509
bool offline = false;
510+
bool skip_download = false; // skip model file downloading
510511

511512
int32_t ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
512513
int32_t ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line

common/download.cpp

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,10 @@ static int common_download_file_single_online(const std::string & url,
292292

293293
const bool file_exists = std::filesystem::exists(path);
294294

295+
if (!file_exists && opts.skip_download) {
296+
return -2; // file is missing and download is disabled
297+
}
298+
295299
if (file_exists && skip_etag) {
296300
LOG_DBG("%s: using cached file: %s\n", __func__, path.c_str());
297301
return 304; // 304 Not Modified - fake cached response
@@ -357,6 +361,10 @@ static int common_download_file_single_online(const std::string & url,
357361
LOG_DBG("%s: using cached file (same etag): %s\n", __func__, path.c_str());
358362
return 304; // 304 Not Modified - fake cached response
359363
}
364+
// pass this point, the file exists but is different from the server version, so we need to redownload it
365+
if (opts.skip_download) {
366+
return -2; // special code to indicate that the download was skipped due to etag mismatch
367+
}
360368
if (remove(path.c_str()) != 0) {
361369
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
362370
return -1;
@@ -775,13 +783,13 @@ static std::vector<download_task> get_url_tasks(const common_params_model & mode
775783
}
776784

777785
common_download_model_result common_download_model(const common_params_model & model,
778-
const common_download_opts & opts,
779-
bool download_mmproj,
780-
bool download_mtp) {
786+
const common_download_opts & opts) {
781787
common_download_model_result result;
782788
std::vector<download_task> tasks;
783789
hf_plan hf;
784790

791+
bool download_mmproj = opts.download_mmproj;
792+
bool download_mtp = opts.download_mtp;
785793
bool is_hf = !model.hf_repo.empty();
786794

787795
if (is_hf) {
@@ -806,18 +814,22 @@ common_download_model_result common_download_model(const common_params_model &
806814
return result;
807815
}
808816

809-
std::vector<std::future<bool>> futures;
817+
std::vector<std::future<int>> futures;
810818
for (const auto & task : tasks) {
811819
futures.push_back(std::async(std::launch::async,
812820
[&task, &opts, is_hf]() {
813-
int status = common_download_file_single(task.url, task.path, opts, is_hf);
814-
return is_http_status_ok(status);
821+
return common_download_file_single(task.url, task.path, opts, is_hf);
815822
}
816823
));
817824
}
818825

819826
for (auto & f : futures) {
820-
if (!f.get()) {
827+
int status = f.get();
828+
if (status == -2 && opts.skip_download) {
829+
throw common_skip_download_exception();
830+
}
831+
bool is_ok = is_http_status_ok(status);
832+
if (!is_ok) {
821833
return {};
822834
}
823835
}

common/download.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ struct common_download_opts {
5252
std::string bearer_token;
5353
common_header_list headers;
5454
bool offline = false;
55+
bool skip_download = false; // if true, only validation is performed, common_skip_download_exception may be thrown if the file is missing or invalid
56+
bool download_mmproj = false;
57+
bool download_mtp = false;
5558
common_download_callback * callback = nullptr;
5659
};
5760

@@ -62,6 +65,11 @@ struct common_download_model_result {
6265
std::string mtp_path;
6366
};
6467

68+
// throw if the file is missing or invalid (e.g. ETag check failed)
69+
struct common_skip_download_exception : public std::runtime_error {
70+
common_skip_download_exception() : std::runtime_error("skip download") {}
71+
};
72+
6573
// Download model from HuggingFace repo or URL
6674
//
6775
// input (via model struct):
@@ -89,16 +97,15 @@ struct common_download_model_result {
8997
// returns result with model_path, mmproj_path and mtp_path (empty when not found / on failure)
9098
common_download_model_result common_download_model(
9199
const common_params_model & model,
92-
const common_download_opts & opts = {},
93-
bool download_mmproj = false,
94-
bool download_mtp = false
100+
const common_download_opts & opts = {}
95101
);
96102

97103
// returns list of cached models
98104
std::vector<common_cached_model_info> common_list_cached_models();
99105

100106
// download single file from url to local path
101107
// returns status code or -1 on error
108+
// returns -2 if the download was skipped due to ETag mismatch (file outdated, skip_download=true)
102109
// skip_etag: if true, don't read/write .etag files (for HF cache where filename is the hash)
103110
int common_download_file_single(const std::string & url,
104111
const std::string & path,

tools/server/README.md

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1661,23 +1661,30 @@ Listing all models in cache. The model metadata will also include a field to ind
16611661
{
16621662
"data": [{
16631663
"id": "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M",
1664-
"in_cache": true,
16651664
"path": "/Users/REDACTED/Library/Caches/llama.cpp/ggml-org_gemma-3-4b-it-GGUF_gemma-3-4b-it-Q4_K_M.gguf",
16661665
"status": {
16671666
"value": "loaded",
16681667
"args": ["llama-server", "-ctx", "4096"]
16691668
},
1669+
"architecture": {
1670+
"input_modalities": [
1671+
"text",
1672+
"image"
1673+
],
1674+
"output_modalities": [
1675+
"text"
1676+
]
1677+
},
16701678
...
16711679
}]
16721680
}
16731681
```
16741682

16751683
Note:
1676-
1. For a local GGUF (stored offline in a custom directory), the model object will have `"in_cache": false`.
1677-
2. Adding `?reload=1` to the query params will refresh the list of models. The behavior is as follow:
1684+
1. Adding `?reload=1` to the query params will refresh the list of models. The behavior is as follow:
16781685
- If a model is running but updated or removed from the source, it will be unloaded
16791686
- If a model is not running, it will be added or updated according to the source
1680-
3. When the model is loaded, the info from `/v1/models` is forwarded to router's `/v1/models`. This includes metadata about the model and the runtime instance.
1687+
2. When the model is loaded, the info from `/v1/models` is forwarded to router's `/v1/models`. This includes metadata about the model and the runtime instance.
16811688

16821689
The `status` object can be:
16831690

tools/server/server-models.cpp

Lines changed: 37 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,8 @@ void server_model_meta::update_caps() {
180180
"LLAMA_ARG_HF_REPO",
181181
"LLAMA_ARG_HF_REPO_FILE",
182182
});
183-
params.offline = true; // avoid any unwanted network call during capability detection
183+
params.offline = true;
184+
// params.skip_download = true; // TODO: ideally, we should validate the model here, but it takes too much time
184185
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER);
185186
if (params.mmproj.path.empty()) {
186187
multimodal = { false, false };
@@ -371,18 +372,19 @@ void server_models::load_models() {
371372
// FIRST LOAD: add all models, then unlock for autoloading
372373
for (const auto & [name, preset] : final_presets) {
373374
server_model_meta meta{
374-
/* preset */ preset,
375-
/* name */ name,
376-
/* aliases */ {},
377-
/* tags */ {},
378-
/* port */ 0,
379-
/* status */ SERVER_MODEL_STATUS_UNLOADED,
380-
/* last_used */ 0,
381-
/* args */ std::vector<std::string>(),
382-
/* loaded_info */ {},
383-
/* exit_code */ 0,
384-
/* stop_timeout */ DEFAULT_STOP_TIMEOUT,
385-
/* multimodal */ mtmd_caps{false, false},
375+
/* preset */ preset,
376+
/* name */ name,
377+
/* aliases */ {},
378+
/* tags */ {},
379+
/* port */ 0,
380+
/* status */ SERVER_MODEL_STATUS_UNLOADED,
381+
/* last_used */ 0,
382+
/* args */ std::vector<std::string>(),
383+
/* loaded_info */ {},
384+
/* exit_code */ 0,
385+
/* stop_timeout */ DEFAULT_STOP_TIMEOUT,
386+
/* multimodal */ mtmd_caps{false, false},
387+
/* need_download */ false,
386388
};
387389
add_model(std::move(meta));
388390
}
@@ -524,18 +526,19 @@ void server_models::load_models() {
524526
for (const auto & [name, preset] : final_presets) {
525527
if (mapping.find(name) == mapping.end()) {
526528
server_model_meta meta{
527-
/* preset */ preset,
528-
/* name */ name,
529-
/* aliases */ {},
530-
/* tags */ {},
531-
/* port */ 0,
532-
/* status */ SERVER_MODEL_STATUS_UNLOADED,
533-
/* last_used */ 0,
534-
/* args */ std::vector<std::string>(),
535-
/* loaded_info */ {},
536-
/* exit_code */ 0,
537-
/* stop_timeout */ DEFAULT_STOP_TIMEOUT,
538-
/* multimodal */ mtmd_caps{false, false},
529+
/* preset */ preset,
530+
/* name */ name,
531+
/* aliases */ {},
532+
/* tags */ {},
533+
/* port */ 0,
534+
/* status */ SERVER_MODEL_STATUS_UNLOADED,
535+
/* last_used */ 0,
536+
/* args */ std::vector<std::string>(),
537+
/* loaded_info */ {},
538+
/* exit_code */ 0,
539+
/* stop_timeout */ DEFAULT_STOP_TIMEOUT,
540+
/* multimodal */ mtmd_caps{false, false},
541+
/* need_download */ false,
539542
};
540543
add_model(std::move(meta));
541544
newly_added.push_back(name);
@@ -1263,14 +1266,15 @@ void server_models_routes::init_routes() {
12631266
};
12641267

12651268
json model_info = json {
1266-
{"id", meta.name},
1267-
{"aliases", meta.aliases},
1268-
{"tags", meta.tags},
1269-
{"object", "model"}, // for OAI-compat
1270-
{"owned_by", "llamacpp"}, // for OAI-compat
1271-
{"created", t}, // for OAI-compat
1272-
{"status", status},
1273-
{"architecture", architecture},
1269+
{"id", meta.name},
1270+
{"aliases", meta.aliases},
1271+
{"tags", meta.tags},
1272+
{"object", "model"}, // for OAI-compat
1273+
{"owned_by", "llamacpp"}, // for OAI-compat
1274+
{"created", t}, // for OAI-compat
1275+
{"status", status},
1276+
{"architecture", architecture},
1277+
{"need_download", meta.need_download},
12741278
// TODO: add other fields, may require reading GGUF metadata
12751279
};
12761280

tools/server/server-models.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ struct server_model_meta {
6767
int exit_code = 0; // exit code of the model instance process (only valid if status == FAILED)
6868
int stop_timeout = 0; // seconds to wait before force-killing the model instance during shutdown
6969
mtmd_caps multimodal; // multimodal capabilities
70+
bool need_download = false; // whether the model needs to be downloaded before loading
7071

7172
bool is_ready() const {
7273
return status == SERVER_MODEL_STATUS_LOADED;

0 commit comments

Comments
 (0)