Skip to content

Commit 75ad0b2

Browse files
authored
server: fix remote preset handling, add test (#24938)
* server: add test for remote preset * fix remote preset handling * fix * fix test
1 parent c926ad0 commit 75ad0b2

7 files changed

Lines changed: 52 additions & 9 deletions

File tree

common/arg.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,8 @@ static handle_model_result common_params_handle_model(struct common_params_model
301301
const common_download_opts & opts) {
302302
handle_model_result result;
303303

304+
// TODO @ngxson : refactor this into a new common_model_download_context
305+
304306
if (!model.docker_repo.empty()) {
305307
model.path = common_docker_resolve_model(model.docker_repo);
306308
} else if (!model.hf_repo.empty()) {
@@ -396,7 +398,7 @@ static bool parse_bool_value(const std::string & value) {
396398
// CLI argument parsing functions
397399
//
398400

399-
bool common_params_handle_models(common_params & params, llama_example curr_ex, common_download_callback * callback) {
401+
bool common_params_handle_models(common_params & params, llama_example curr_ex, const common_params_handle_models_params & handle_params) {
400402
const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(),
401403
params.speculative.types.end(),
402404
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params.speculative.types.end();
@@ -407,9 +409,10 @@ bool common_params_handle_models(common_params & params, llama_example curr_ex,
407409
opts.skip_download = params.skip_download;
408410
opts.download_mtp = spec_type_draft_mtp;
409411
opts.download_mmproj = !params.no_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty();
412+
opts.preset_only = handle_params.preset_only;
410413

411-
if (callback) {
412-
opts.callback = callback;
414+
if (handle_params.callback) {
415+
opts.callback = handle_params.callback;
413416
}
414417

415418
// sub-models (draft, mmproj, vocoder) are explicitly specified by the user,
@@ -596,7 +599,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
596599

597600
if (!skip_model_download) {
598601
// handle model and download
599-
common_params_handle_models(params, ctx_arg.ex);
602+
common_params_handle_models(params, ctx_arg.ex, {});
600603

601604
// model is required (except for server)
602605
// TODO @ngxson : maybe show a list of available models in CLI in this case

common/arg.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,14 +130,19 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<com
130130
// see: https://github.com/ggml-org/llama.cpp/issues/18163
131131
void common_params_add_preset_options(std::vector<common_arg> & args);
132132

133+
struct common_params_handle_models_params {
134+
common_download_callback * callback = nullptr;
135+
bool preset_only = false; // if true, only check & download remote preset (for router mode)
136+
};
137+
133138
// populate model paths (main model, mmproj, etc) from -hf if necessary
134139
// return true if the model is ready to use
135140
// throw an exception if there is an error that prevents the model from being used (e.g. network error, model not found, etc)
136141
// if params.skip_download is true, no downloads will be attempted. return false if the model is invalid or missing (e.g. ETag check failed)
137142
bool common_params_handle_models(
138143
common_params & params,
139144
llama_example curr_ex,
140-
common_download_callback * callback = nullptr);
145+
const common_params_handle_models_params & handle_params);
141146

142147
// initialize argument parser context - used by test-arg-parser and preset
143148
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);

common/download.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -799,14 +799,16 @@ common_download_model_result common_download_model(const common_params_model &
799799

800800
bool download_mmproj = opts.download_mmproj;
801801
bool download_mtp = opts.download_mtp;
802+
bool preset_only = opts.preset_only;
802803
bool is_hf = !model.hf_repo.empty();
803804

804805
if (is_hf) {
805806
hf = get_hf_plan(model, opts, download_mmproj, download_mtp);
806807
if (!hf.preset.path.empty()) {
807808
// if preset.ini exists, only download that file alone
808809
tasks.push_back({hf.preset.url, hf.preset.local_path});
809-
} else {
810+
} else if (!preset_only) {
811+
// only add other files if we're NOT in preset-only mode (normal run, non-router)
810812
for (const auto & f : hf.model_files) {
811813
tasks.push_back({f.url, f.local_path});
812814
}

common/download.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ struct common_download_opts {
5555
bool skip_download = false; // if true, only validation is performed, common_skip_download_exception may be thrown if the file is missing or invalid
5656
bool download_mmproj = false;
5757
bool download_mtp = false;
58+
bool preset_only = false; // if true, only check & download remote preset (for router mode)
5859
common_download_callback * callback = nullptr;
5960
};
6061

tools/server/server-models.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ void server_model_meta::update_caps() {
224224
});
225225
params.offline = true;
226226
// params.skip_download = true; // TODO: ideally, we should validate the model here, but it takes too much time
227-
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER);
227+
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, {});
228228
if (params.mmproj.path.empty()) {
229229
multimodal = { false, false };
230230
} else {
@@ -1393,7 +1393,9 @@ struct server_download_state : public common_download_callback {
13931393

13941394
bool run(common_params & params) {
13951395
try {
1396-
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, this);
1396+
common_params_handle_models_params p;
1397+
p.callback = this;
1398+
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, p);
13971399
is_ok = true;
13981400
} catch (const std::exception & e) {
13991401
auto model_name = params.model.get_name();

tools/server/server.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,17 @@ int llama_server(int argc, char ** argv) {
8989
llama_backend_init();
9090
llama_numa_init(params.numa);
9191

92+
// note: router mode also accepts -hf remote-preset, so we need to check that first
93+
if (!params.model.hf_repo.empty()) {
94+
try {
95+
common_params_handle_models_params handle_params;
96+
handle_params.preset_only = true;
97+
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, handle_params);
98+
} catch (const std::exception & e) {
99+
// ignored for now
100+
}
101+
}
102+
92103
// router server never loads a model and must not touch the GPU
93104
const bool is_router_server = params.model.path.empty()
94105
&& params.model.hf_repo.empty();
@@ -263,7 +274,7 @@ int llama_server(int argc, char ** argv) {
263274
return child.run_download(params);
264275
} else if (!is_router_server) {
265276
// single-model mode (NOT spawned by router)
266-
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER);
277+
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, {});
267278
}
268279

269280
//

tools/server/tests/unit/test_router.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,25 @@ def test_router_reload_models():
256256
os.remove(preset_path)
257257

258258

259+
def test_router_remote_preset():
260+
global server
261+
server.model_hf_repo = "ggml-org/test-preset-ci"
262+
server.model_hf_file = None
263+
server.offline = False
264+
server.start()
265+
266+
# Should see preset models in GET /models
267+
res = server.make_request("GET", "/models")
268+
assert res.status_code == 200
269+
ids = {item["id"] for item in res.body.get("data", [])}
270+
assert "tinygemma3-preset" in ids
271+
assert "stories260K-test" in ids
272+
273+
# Should be able to load a preset model
274+
model_id = "tinygemma3-preset"
275+
_load_model_and_wait(model_id)
276+
277+
259278
MODEL_DOWNLOAD_ID = "ggml-org/test-model-router-download:F16"
260279
MODEL_DOWNLOAD_TIMEOUT = 30
261280

0 commit comments

Comments
 (0)