Skip to content

Commit 7bfe120

Browse files
authored
mtmd, server, common: expose modalities to /v1/models (#22952)
* mtmd, server, common: expose modalities to /v1/models * fix build * rename to mtmd_caps
1 parent 927dada commit 7bfe120

10 files changed

Lines changed: 121 additions & 27 deletions

File tree

common/arg.cpp

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,25 @@ static bool parse_bool_value(const std::string & value) {
435435
// CLI argument parsing functions
436436
//
437437

438+
void common_params_handle_models(common_params & params, llama_example curr_ex) {
439+
auto res = common_params_handle_model(params.model, params.hf_token, params.offline);
440+
if (params.no_mmproj) {
441+
params.mmproj = {};
442+
} else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) {
443+
// optionally, handle mmproj model when -hf is specified
444+
params.mmproj = res.mmproj;
445+
}
446+
// only download mmproj if the current example is using it
447+
for (const auto & ex : mmproj_examples) {
448+
if (curr_ex == ex) {
449+
common_params_handle_model(params.mmproj, params.hf_token, params.offline);
450+
break;
451+
}
452+
}
453+
common_params_handle_model(params.speculative.draft.mparams, params.hf_token, params.offline);
454+
common_params_handle_model(params.vocoder.model, params.hf_token, params.offline);
455+
}
456+
438457
static bool common_params_parse_ex(int argc, char ** argv, common_params_context & ctx_arg) {
439458
common_params & params = ctx_arg.params;
440459

@@ -588,22 +607,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
588607

589608
// handle model and download
590609
if (!skip_model_download) {
591-
auto res = common_params_handle_model(params.model, params.hf_token, params.offline);
592-
if (params.no_mmproj) {
593-
params.mmproj = {};
594-
} else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) {
595-
// optionally, handle mmproj model when -hf is specified
596-
params.mmproj = res.mmproj;
597-
}
598-
// only download mmproj if the current example is using it
599-
for (const auto & ex : mmproj_examples) {
600-
if (ctx_arg.ex == ex) {
601-
common_params_handle_model(params.mmproj, params.hf_token, params.offline);
602-
break;
603-
}
604-
}
605-
common_params_handle_model(params.speculative.draft.mparams, params.hf_token, params.offline);
606-
common_params_handle_model(params.vocoder.model, params.hf_token, params.offline);
610+
common_params_handle_models(params, ctx_arg.ex);
607611
}
608612

609613
// model is required (except for server)

common/arg.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,5 +129,8 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<com
129129
// see: https://github.com/ggml-org/llama.cpp/issues/18163
130130
void common_params_add_preset_options(std::vector<common_arg> & args);
131131

132+
// Populate model paths (main model, mmproj, etc) from -hf if necessary
133+
void common_params_handle_models(common_params & params, llama_example curr_ex);
134+
132135
// initialize argument parser context - used by test-arg-parser and preset
133136
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);

common/preset.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,8 +163,13 @@ void common_preset::merge(const common_preset & other) {
163163
}
164164
}
165165

166-
void common_preset::apply_to_params(common_params & params) const {
166+
void common_preset::apply_to_params(common_params & params, const std::set<std::string> & handled_keys) const {
167167
for (const auto & [opt, val] : options) {
168+
if (!handled_keys.empty()) {
169+
if (!opt.env || handled_keys.find(opt.env) == handled_keys.end()) {
170+
continue;
171+
}
172+
}
168173
// apply each option to params
169174
if (opt.handler_string) {
170175
opt.handler_string(params, val);

common/preset.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ struct common_preset {
4343
void merge(const common_preset & other);
4444

4545
// apply preset options to common_params
46-
void apply_to_params(common_params & params) const;
46+
// optionally specify handled_keys to only apply a subset of options (identified by their env), if empty, apply all options
47+
void apply_to_params(common_params & params, const std::set<std::string> & handled_keys = std::set<std::string>()) const;
4748
};
4849

4950
// interface for multiple presets in one file

tools/mtmd/clip.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -994,7 +994,7 @@ struct clip_model_loader {
994994
bool has_audio = false;
995995

996996
// TODO @ngxson : we should not pass clip_ctx here, it should be clip_model
997-
clip_model_loader(const char * fname) : fname(fname) {
997+
clip_model_loader(const char * fname, bool skip_tensors = false) : fname(fname) {
998998
struct ggml_context * meta = nullptr;
999999

10001000
struct gguf_init_params params = {
@@ -1040,7 +1040,7 @@ struct clip_model_loader {
10401040
}
10411041

10421042
// tensors
1043-
{
1043+
if (!skip_tensors) {
10441044
for (int i = 0; i < n_tensors; ++i) {
10451045
const char * name = gguf_get_tensor_name(ctx_gguf.get(), i);
10461046
const size_t offset = gguf_get_tensor_offset(ctx_gguf.get(), i);
@@ -2927,6 +2927,14 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params
29272927
return {ctx_vision, ctx_audio};
29282928
}
29292929

2930+
struct clip_cap clip_get_cap(const char * fname) {
2931+
clip_cap res;
2932+
clip_model_loader loader(fname, /* skip_tensors= */ true);
2933+
res.has_vision = loader.has_vision;
2934+
res.has_audio = loader.has_audio;
2935+
return res;
2936+
}
2937+
29302938
struct clip_image_size * clip_image_size_init() {
29312939
struct clip_image_size * load_image_size = new struct clip_image_size();
29322940
load_image_size->width = 448;

tools/mtmd/clip.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,9 @@ void clip_image_f32_batch_add_mel(struct clip_image_f32_batch * batch, int n_mel
116116
bool clip_has_vision_encoder(const struct clip_ctx * ctx);
117117
bool clip_has_audio_encoder(const struct clip_ctx * ctx);
118118
bool clip_has_whisper_encoder(const struct clip_ctx * ctx);
119+
120+
struct clip_cap {
121+
bool has_vision;
122+
bool has_audio;
123+
};
124+
struct clip_cap clip_get_cap(const char * fname);

tools/mtmd/mtmd.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1423,6 +1423,19 @@ void mtmd_log_set(ggml_log_callback log_callback, void * user_data) {
14231423
g_logger_state.log_callback_user_data = user_data;
14241424
}
14251425

1426+
struct mtmd_caps mtmd_get_cap_from_file(const char * fname) {
1427+
try {
1428+
auto tmp = clip_get_cap(fname);
1429+
mtmd_caps cap;
1430+
cap.inp_audio = tmp.has_audio;
1431+
cap.inp_vision = tmp.has_vision;
1432+
return cap;
1433+
} catch (const std::exception & e) {
1434+
LOG_ERR("%s: failed to get capabilities from file '%s': %s\n", __func__, fname, e.what());
1435+
return mtmd_caps{ false, false };
1436+
}
1437+
}
1438+
14261439
//
14271440
// Debugging API (NOT intended for public use)
14281441
//

tools/mtmd/mtmd.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,14 @@ MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);
244244
// If this is not called, or NULL is supplied, everything is output on stderr.
245245
MTMD_API void mtmd_log_set(ggml_log_callback log_callback, void * user_data);
246246

247+
// EXPERIMENTAL API to get mmproj's capabilities without initializing the full context
248+
// This is only intended to be used by llama-server, breaking changes is expected
249+
struct mtmd_caps {
250+
bool inp_vision;
251+
bool inp_audio;
252+
};
253+
MTMD_API struct mtmd_caps mtmd_get_cap_from_file(const char * mmproj_fname);
254+
247255
/////////////////////////////////////////
248256

249257
// test function, to be used in test-mtmd-c-api.c

tools/server/server-models.cpp

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,30 @@ void server_model_meta::update_args(common_preset_context & ctx_preset, std::str
161161
args = preset.to_args(bin_path);
162162
}
163163

164+
void server_model_meta::update_caps() {
165+
try {
166+
common_params params;
167+
preset.apply_to_params(params, {
168+
"LLAMA_ARG_MODEL",
169+
"LLAMA_ARG_MODEL_URL",
170+
"LLAMA_ARG_MMPROJ",
171+
"LLAMA_ARG_MMPROJ_URL",
172+
"LLAMA_ARG_HF_REPO",
173+
"LLAMA_ARG_HF_REPO_FILE",
174+
});
175+
params.offline = true; // avoid any unwanted network call during capability detection
176+
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER);
177+
if (params.mmproj.path.empty()) {
178+
multimodal = { false, false };
179+
} else {
180+
multimodal = mtmd_get_cap_from_file(params.mmproj.path.c_str());
181+
}
182+
} catch (const std::exception & e) {
183+
LOG_WRN("failed to initialize common_params for multimodal capability detection: %s\n", e.what());
184+
multimodal = { false, false };
185+
}
186+
}
187+
164188
//
165189
// server_models
166190
//
@@ -236,6 +260,7 @@ void server_models::add_model(server_model_meta && meta) {
236260
}
237261

238262
meta.update_args(ctx_preset, bin_path); // render args
263+
meta.update_caps();
239264
std::string name = meta.name;
240265
mapping[name] = instance_t{
241266
/* subproc */ std::make_shared<subprocess_s>(),
@@ -346,8 +371,10 @@ void server_models::load_models() {
346371
/* status */ SERVER_MODEL_STATUS_UNLOADED,
347372
/* last_used */ 0,
348373
/* args */ std::vector<std::string>(),
374+
/* loaded_info */ {},
349375
/* exit_code */ 0,
350376
/* stop_timeout */ DEFAULT_STOP_TIMEOUT,
377+
/* multimodal */ mtmd_caps{false, false},
351378
};
352379
add_model(std::move(meta));
353380
}
@@ -481,6 +508,7 @@ void server_models::load_models() {
481508

482509
inst.meta.exit_code = 0; // clear failed state so the model can be reloaded
483510
inst.meta.update_args(ctx_preset, bin_path);
511+
inst.meta.update_caps();
484512
}
485513

486514
// add models that are new in this reload
@@ -496,8 +524,10 @@ void server_models::load_models() {
496524
/* status */ SERVER_MODEL_STATUS_UNLOADED,
497525
/* last_used */ 0,
498526
/* args */ std::vector<std::string>(),
527+
/* loaded_info */ {},
499528
/* exit_code */ 0,
500529
/* stop_timeout */ DEFAULT_STOP_TIMEOUT,
530+
/* multimodal */ mtmd_caps{false, false},
501531
};
502532
add_model(std::move(meta));
503533
newly_added.push_back(name);
@@ -1206,14 +1236,28 @@ void server_models_routes::init_routes() {
12061236
status["failed"] = true;
12071237
}
12081238

1239+
// pi coding agent multimodal compatibility
1240+
json input_modalities = json::array({"text"});
1241+
if (meta.multimodal.inp_vision) {
1242+
input_modalities.push_back("image");
1243+
}
1244+
if (meta.multimodal.inp_audio) {
1245+
input_modalities.push_back("audio");
1246+
}
1247+
json architecture {
1248+
{"input_modalities", input_modalities},
1249+
{"output_modalities", json::array({"text"})},
1250+
};
1251+
12091252
json model_info = json {
1210-
{"id", meta.name},
1211-
{"aliases", meta.aliases},
1212-
{"tags", meta.tags},
1213-
{"object", "model"}, // for OAI-compat
1214-
{"owned_by", "llamacpp"}, // for OAI-compat
1215-
{"created", t}, // for OAI-compat
1216-
{"status", status},
1253+
{"id", meta.name},
1254+
{"aliases", meta.aliases},
1255+
{"tags", meta.tags},
1256+
{"object", "model"}, // for OAI-compat
1257+
{"owned_by", "llamacpp"}, // for OAI-compat
1258+
{"created", t}, // for OAI-compat
1259+
{"status", status},
1260+
{"architecture", architecture},
12171261
// TODO: add other fields, may require reading GGUF metadata
12181262
};
12191263

tools/server/server-models.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ struct server_model_meta {
6666
json loaded_info; // info to be reflected via /v1/models endpoint
6767
int exit_code = 0; // exit code of the model instance process (only valid if status == FAILED)
6868
int stop_timeout = 0; // seconds to wait before force-killing the model instance during shutdown
69+
mtmd_caps multimodal; // multimodal capabilities
6970

7071
bool is_ready() const {
7172
return status == SERVER_MODEL_STATUS_LOADED;
@@ -80,6 +81,7 @@ struct server_model_meta {
8081
}
8182

8283
void update_args(common_preset_context & ctx_presets, std::string bin_path);
84+
void update_caps();
8385
};
8486

8587
struct subprocess_s;

0 commit comments

Comments
 (0)