Skip to content

Commit d6d8995

Browse files
authored
server: real-time model load progress tracking via /models/sse (#24828)
* server: real-time model load progress tracking via /models/sse * update docs * add mutex for notify_to_router * correct docs
1 parent 8a118ee commit d6d8995

4 files changed

Lines changed: 106 additions & 7 deletions

File tree

tools/server/README.md

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1859,9 +1859,33 @@ Example events:
18591859

18601860
{
18611861
"model": "...",
1862-
"event": "download_finished",
1862+
"event": "model_status",
18631863
"data": {
1864-
"status": "loading"
1864+
"status": "loading",
1865+
"progress": {
1866+
"stage": "fit_params",
1867+
"value": 0.5 // from 0.0 to 1.0 ; note: not all stages have this "value"
1868+
}
1869+
}
1870+
}
1871+
1872+
{
1873+
"model": "...",
1874+
"event": "model_status",
1875+
"data": {
1876+
"status": "loaded",
1877+
"info": {
1878+
// note: only include info on first load
1879+
// waking up from sleep doesn't have this
1880+
}
1881+
}
1882+
}
1883+
1884+
{
1885+
"model": "...",
1886+
"event": "model_status",
1887+
"data": {
1888+
"status": "sleeping"
18651889
}
18661890
}
18671891

tools/server/server-context.cpp

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -833,6 +833,8 @@ struct server_context_impl {
833833

834834
bool sleeping = false;
835835

836+
int64_t t_last_load_progress_ms = 0;
837+
836838
void destroy() {
837839
spec.reset();
838840
ctx_dft.reset();
@@ -863,6 +865,30 @@ struct server_context_impl {
863865
sleeping = new_state;
864866
}
865867

868+
static bool load_progress_callback(float progress, void * user_data) {
869+
auto * ctx = static_cast<server_context_impl *>(user_data);
870+
GGML_ASSERT(ctx);
871+
// always emit the first and final sample; throttle the rest to one per 200ms
872+
{
873+
auto & t_last = ctx->t_last_load_progress_ms;
874+
const int64_t t_now = ggml_time_ms();
875+
const bool first = t_last == 0;
876+
const bool done = progress >= 1.0f;
877+
const bool throttled = !first && !done && (t_now - t_last) < 200;
878+
if (throttled) {
879+
return true;
880+
}
881+
t_last = t_now;
882+
}
883+
if (ctx->callback_state) {
884+
ctx->callback_state(SERVER_STATE_LOADING, {
885+
{"stage", "text_model"},
886+
{"value", progress},
887+
});
888+
}
889+
return true;
890+
}
891+
866892
// load the model and initialize llama_context
867893
// this may also be called to resume from sleeping state
868894
bool load_model(common_params & params) {
@@ -916,6 +942,10 @@ struct server_context_impl {
916942

917943
// optionally reserve VRAM for the draft / MTP context before fitting the target model
918944
if (params_base.fit_params) {
945+
if (callback_state) {
946+
callback_state(SERVER_STATE_LOADING, {{"stage", "fit_params"}});
947+
}
948+
919949
const bool spec_mtp = std::find(params_base.speculative.types.begin(),
920950
params_base.speculative.types.end(),
921951
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end();
@@ -991,6 +1021,13 @@ struct server_context_impl {
9911021
}
9921022
}
9931023

1024+
// attach a progress callback
1025+
{
1026+
t_last_load_progress_ms = 0;
1027+
params_base.load_progress_callback = load_progress_callback;
1028+
params_base.load_progress_callback_user_data = this;
1029+
}
1030+
9941031
llama_init = common_init_from_params(params_base);
9951032

9961033
model_tgt = llama_init->model();
@@ -1008,6 +1045,10 @@ struct server_context_impl {
10081045
add_bos_token = llama_vocab_get_add_bos(vocab);
10091046

10101047
if (params_base.speculative.has_dft()) {
1048+
if (callback_state) {
1049+
callback_state(SERVER_STATE_LOADING, {{"stage", "spec_model"}});
1050+
}
1051+
10111052
// TODO speculative: move to common/speculative.cpp?
10121053
const auto & params_spec = params_base.speculative.draft;
10131054

@@ -1079,6 +1120,10 @@ struct server_context_impl {
10791120
}
10801121

10811122
if (has_mmproj) {
1123+
if (callback_state) {
1124+
callback_state(SERVER_STATE_LOADING, {{"stage", "mmproj_model"}});
1125+
}
1126+
10821127
if (!is_resume) {
10831128
mtmd_helper_log_set(common_log_default_callback, nullptr);
10841129
}
@@ -1259,6 +1304,10 @@ struct server_context_impl {
12591304
return init();
12601305
}
12611306

1307+
if (callback_state) {
1308+
callback_state(SERVER_STATE_READY, {});
1309+
}
1310+
12621311
return true;
12631312
}
12641313

@@ -1335,6 +1384,9 @@ struct server_context_impl {
13351384
const bool enable_thinking = params_base.enable_reasoning != 0 && template_supports_thinking;
13361385
SRV_INF("%s: chat template, thinking = %d\n", __func__, enable_thinking);
13371386

1387+
// IMPORTANT: chat_params is reused across sleeping / resuming states,
1388+
// never store llama_context/llama_model pointers in chat_params,
1389+
// as they may be invalidated after sleeping
13381390
chat_params = {
13391391
/* use_jinja */ params_base.use_jinja,
13401392
/* prefill_assistant */ params_base.prefill_assistant,
@@ -3734,7 +3786,10 @@ struct server_res_generator : server_http_res {
37343786
void server_context::set_state_callback(server_state_callback_t callback) {
37353787
impl->callback_state = std::move(callback);
37363788
impl->queue_tasks.on_sleeping_state([this](bool sleeping) {
3737-
impl->callback_state(sleeping ? SERVER_STATE_SLEEPING : SERVER_STATE_READY, {});
3789+
if (sleeping) {
3790+
impl->callback_state(SERVER_STATE_SLEEPING, {});
3791+
}
3792+
// for sleeping == false, event is emitted by load_model()
37383793
});
37393794
}
37403795

tools/server/server-models.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,7 @@ void server_models::load_models() {
442442
/* last_used */ 0,
443443
/* args */ std::vector<std::string>(),
444444
/* loaded_info */ {},
445+
/* progress */ {},
445446
/* exit_code */ 0,
446447
/* stop_timeout */ DEFAULT_STOP_TIMEOUT,
447448
/* multimodal */ mtmd_caps{false, false},
@@ -608,6 +609,7 @@ void server_models::load_models() {
608609
/* last_used */ 0,
609610
/* args */ std::vector<std::string>(),
610611
/* loaded_info */ {},
612+
/* progress */ {},
611613
/* exit_code */ 0,
612614
/* stop_timeout */ DEFAULT_STOP_TIMEOUT,
613615
/* multimodal */ mtmd_caps{false, false},
@@ -1140,6 +1142,9 @@ void server_models::update_status(const std::string & name, const update_status_
11401142
if (!args.loaded_info.is_null()) {
11411143
meta.loaded_info = args.loaded_info;
11421144
}
1145+
if (!args.progress.is_null()) {
1146+
meta.progress = args.progress;
1147+
}
11431148
}
11441149
// broadcast status change to SSE
11451150
{
@@ -1152,6 +1157,9 @@ void server_models::update_status(const std::string & name, const update_status_
11521157
if (!args.loaded_info.is_null()) {
11531158
data["info"] = args.loaded_info;
11541159
}
1160+
if (!args.progress.is_null()) {
1161+
data["progress"] = args.progress;
1162+
}
11551163
// note: notify_sse doesn't acquire the lock, so no deadlock here
11561164
notify_sse("status_change", name, data);
11571165
}
@@ -1322,16 +1330,21 @@ void server_models::handle_child_state(const std::string & name, const std::stri
13221330
switch (state) {
13231331
case SERVER_STATE_LOADING:
13241332
{
1325-
// do nothing for now
1326-
// TODO: report loading progress for first load and wakeup from sleep
1333+
update_status(name, {
1334+
SERVER_MODEL_STATUS_LOADING,
1335+
0,
1336+
nullptr, // no loaded_info yet
1337+
payload,
1338+
});
13271339
} break;
13281340
case SERVER_STATE_READY:
13291341
{
13301342
update_status(name, {
13311343
SERVER_MODEL_STATUS_LOADED,
13321344
0,
13331345
// note: payload can be empty if this is a wakeup from sleep
1334-
payload.size() > 0 ? payload : nullptr
1346+
payload.size() > 0 ? payload : nullptr,
1347+
{}, // reset progress info
13351348
});
13361349
} break;
13371350
case SERVER_STATE_SLEEPING:
@@ -1384,6 +1397,7 @@ void server_child::notify_to_router(const std::string & state, const json & payl
13841397
{"state", state},
13851398
{"payload", payload},
13861399
};
1400+
std::lock_guard<std::mutex> lk(mtx_stdout);
13871401
common_log_pause(common_log_main());
13881402
fflush(stdout);
13891403
fprintf(stdout, "%s%s\n", CMD_CHILD_TO_ROUTER_STATE, safe_json_to_str(data).c_str());

tools/server/server-models.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ struct server_model_meta {
7272
int64_t last_used = 0; // for LRU unloading
7373
std::vector<std::string> args; // args passed to the model instance, will be populated by render_args()
7474
json loaded_info; // info to be reflected via /v1/models endpoint ; if in DOWNLOADING state, it should contain download progress info
75+
json progress; // reflect load or download progress info, if any
7576
int exit_code = 0; // exit code of the model instance process (only valid if status == FAILED)
7677
int stop_timeout = 0; // seconds to wait before force-killing the model instance during shutdown
7778
mtmd_caps multimodal; // multimodal capabilities
@@ -170,12 +171,14 @@ struct server_models {
170171
// to stop the download, call unload()
171172
void download(common_params_model && model, common_download_opts && opts);
172173

173-
// update the status of a model instance (thread-safe)
174174
struct update_status_args {
175175
server_model_status status;
176176
int exit_code = 0; // only valid if status == UNLOADED
177177
json loaded_info = nullptr;
178+
json progress = nullptr;
178179
};
180+
// update the status of a model instance (thread-safe)
181+
// also send SSE notification to /models/sse endpoint
179182
void update_status(const std::string & name, const update_status_args & args);
180183
void update_download_progress(const std::string & name, const common_download_progress & progress, bool done, bool ok = true);
181184

@@ -208,6 +211,9 @@ struct server_models {
208211
};
209212

210213
struct server_child {
214+
// serializes the notify_to_router writes
215+
std::mutex mtx_stdout;
216+
211217
// return true if the current process is a child server instance
212218
bool is_child();
213219

0 commit comments

Comments
 (0)