Skip to content

Commit 7c082bc

Browse files
authored
server: fix report progress for loading spec models, add "stages" list (#24870)
* server: fix report progress for loading spec models, add "stages" list * improve * nits * nits 2
1 parent bddfd2b commit 7c082bc

2 files changed

Lines changed: 50 additions & 29 deletions

File tree

tools/server/README.md

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1863,11 +1863,15 @@ Example events:
18631863
"data": {
18641864
"status": "loading",
18651865
"progress": {
1866-
"stage": "fit_params",
1867-
"value": 0.5 // from 0.0 to 1.0 ; note: not all stages have this "value"
1866+
"stages": ["text_model", "spec_model", "mmproj_model"],
1867+
"current": "text_model",
1868+
"value": 0.5
18681869
}
18691870
}
18701871
}
1872+
// note for "loading" status:
1873+
// - subsequent events will follow the same order as the "stages" list
1874+
// - mmap may report incorrect progress on some platforms; if you need exact progress, use --no-mmap
18711875

18721876
{
18731877
"model": "...",

tools/server/server-context.cpp

Lines changed: 44 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -962,6 +962,7 @@ struct server_context_impl {
962962
struct load_progress_data {
963963
server_context_impl * ctx;
964964
std::string stage;
965+
std::vector<std::string> stages;
965966
int64_t t_last_load_progress_ms = 0;
966967
load_progress_data(server_context_impl * ctx, const std::string & stage) : ctx(ctx), stage(stage) {}
967968
};
@@ -982,7 +983,8 @@ struct server_context_impl {
982983
}
983984
if (d->ctx->callback_state) {
984985
d->ctx->callback_state(SERVER_STATE_LOADING, {
985-
{"stage", d->stage},
986+
{"stages", d->stages},
987+
{"current", d->stage},
986988
{"value", progress},
987989
});
988990
}
@@ -992,18 +994,42 @@ struct server_context_impl {
992994
// load the model and initialize llama_context
993995
// this may also be called to resume from sleeping state
994996
bool load_model(common_params & params) {
995-
load_progress_data load_progress_text(this, "text_model");
997+
load_progress_data load_progress_text (this, "text_model");
996998
load_progress_data load_progress_mmproj(this, "mmproj_model");
999+
load_progress_data load_progress_spec (this, "spec_model");
9971000

998-
bool is_resume = sleeping;
999-
1000-
SRV_INF("loading model '%s'\n", params.model.path.c_str());
1001+
const bool is_resume = sleeping;
10011002

10021003
params_base = params;
10031004
params_base.n_outputs_max = server_n_outputs_max(params_base);
10041005

1006+
const bool has_mmproj = !params.mmproj.path.empty();
1007+
const bool has_draft = params.speculative.has_dft();
1008+
const bool spec_mtp = std::find(params_base.speculative.types.begin(),
1009+
params_base.speculative.types.end(),
1010+
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end();
1011+
const bool has_spec = has_draft || spec_mtp;
1012+
1013+
if (callback_state) {
1014+
std::vector<std::string> stages = {"text_model"};
1015+
if (has_spec) {
1016+
stages.push_back("spec_model");
1017+
}
1018+
if (has_mmproj) {
1019+
stages.push_back("mmproj_model");
1020+
}
1021+
load_progress_text.stages = stages;
1022+
load_progress_mmproj.stages = stages;
1023+
load_progress_spec.stages = stages;
1024+
1025+
// trigger 0% progress
1026+
load_progress_callback(0.0f, &load_progress_text);
1027+
}
1028+
1029+
1030+
SRV_INF("loading model '%s'\n", params.model.path.c_str());
1031+
10051032
std::string & mmproj_path = params_base.mmproj.path;
1006-
bool has_mmproj = !mmproj_path.empty();
10071033
mtmd_context_params mparams = mtmd_context_params_default();
10081034
if (has_mmproj) {
10091035
mparams.use_gpu = params_base.mmproj_use_gpu;
@@ -1050,16 +1076,7 @@ struct server_context_impl {
10501076

10511077
// optionally reserve VRAM for the draft / MTP context before fitting the target model
10521078
if (params_base.fit_params) {
1053-
if (callback_state) {
1054-
callback_state(SERVER_STATE_LOADING, {{"stage", "fit_params"}});
1055-
}
1056-
1057-
const bool spec_mtp = std::find(params_base.speculative.types.begin(),
1058-
params_base.speculative.types.end(),
1059-
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end();
1060-
const bool has_draft = params_base.speculative.has_dft();
1061-
1062-
if (has_draft || spec_mtp) {
1079+
if (has_spec) {
10631080
common_params params_dft = params_base;
10641081
bool measure_model_bytes = true;
10651082

@@ -1151,11 +1168,7 @@ struct server_context_impl {
11511168

11521169
add_bos_token = llama_vocab_get_add_bos(vocab);
11531170

1154-
if (params_base.speculative.has_dft()) {
1155-
if (callback_state) {
1156-
callback_state(SERVER_STATE_LOADING, {{"stage", "spec_model"}});
1157-
}
1158-
1171+
if (has_draft) {
11591172
// TODO speculative: move to common/speculative.cpp?
11601173
const auto & params_spec = params_base.speculative.draft;
11611174

@@ -1178,6 +1191,10 @@ struct server_context_impl {
11781191

11791192
auto mparams_dft = common_model_params_to_llama(params_dft);
11801193

1194+
// progress callback
1195+
mparams_dft.progress_callback = load_progress_callback;
1196+
mparams_dft.progress_callback_user_data = &load_progress_spec;
1197+
11811198
model_dft.reset(llama_model_load_from_file(params_dft.model.path.c_str(), mparams_dft));
11821199
if (model_dft == nullptr) {
11831200
SRV_ERR("failed to load draft model, '%s'\n", params_dft.model.path.c_str());
@@ -1186,10 +1203,6 @@ struct server_context_impl {
11861203

11871204
auto cparams = common_context_params_to_llama(params_dft);
11881205

1189-
const bool spec_mtp = std::find(params_base.speculative.types.begin(),
1190-
params_base.speculative.types.end(),
1191-
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end();
1192-
11931206
if (spec_mtp) {
11941207
cparams.ctx_type = LLAMA_CONTEXT_TYPE_MTP;
11951208
}
@@ -1203,8 +1216,10 @@ struct server_context_impl {
12031216

12041217
params_base.speculative.draft.ctx_tgt = ctx_tgt;
12051218
params_base.speculative.draft.ctx_dft = ctx_dft.get();
1206-
} else if (std::find(params_base.speculative.types.begin(), params_base.speculative.types.end(),
1207-
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end()) {
1219+
} else if (spec_mtp) {
1220+
// no new model load, so we simply report 0.0 and 1.0 progress
1221+
load_progress_callback(0.0f, &load_progress_spec);
1222+
12081223
SRV_INF("creating MTP draft context against the target model '%s'\n",
12091224
params_base.model.path.c_str());
12101225

@@ -1224,6 +1239,8 @@ struct server_context_impl {
12241239

12251240
params_base.speculative.draft.ctx_tgt = ctx_tgt;
12261241
params_base.speculative.draft.ctx_dft = ctx_dft.get();
1242+
1243+
load_progress_callback(1.0f, &load_progress_spec);
12271244
}
12281245

12291246
if (has_mmproj) {

0 commit comments

Comments
 (0)