@@ -962,6 +962,7 @@ struct server_context_impl {
962962 struct load_progress_data {
963963 server_context_impl * ctx;
964964 std::string stage;
965+ std::vector<std::string> stages;
965966 int64_t t_last_load_progress_ms = 0 ;
966967 load_progress_data (server_context_impl * ctx, const std::string & stage) : ctx(ctx), stage(stage) {}
967968 };
@@ -982,7 +983,8 @@ struct server_context_impl {
982983 }
983984 if (d->ctx ->callback_state ) {
984985 d->ctx ->callback_state (SERVER_STATE_LOADING , {
985- {" stage" , d->stage },
986+ {" stages" , d->stages },
987+ {" current" , d->stage },
986988 {" value" , progress},
987989 });
988990 }
@@ -992,18 +994,42 @@ struct server_context_impl {
992994 // load the model and initialize llama_context
993995 // this may also be called to resume from sleeping state
994996 bool load_model (common_params & params) {
995- load_progress_data load_progress_text (this , " text_model" );
997+ load_progress_data load_progress_text (this , " text_model" );
996998 load_progress_data load_progress_mmproj (this , " mmproj_model" );
999+ load_progress_data load_progress_spec (this , " spec_model" );
9971000
998- bool is_resume = sleeping;
999-
1000- SRV_INF (" loading model '%s'\n " , params.model .path .c_str ());
1001+ const bool is_resume = sleeping;
10011002
10021003 params_base = params;
10031004 params_base.n_outputs_max = server_n_outputs_max (params_base);
10041005
1006+ const bool has_mmproj = !params.mmproj .path .empty ();
1007+ const bool has_draft = params.speculative .has_dft ();
1008+ const bool spec_mtp = std::find (params_base.speculative .types .begin (),
1009+ params_base.speculative .types .end (),
1010+ COMMON_SPECULATIVE_TYPE_DRAFT_MTP ) != params_base.speculative .types .end ();
1011+ const bool has_spec = has_draft || spec_mtp;
1012+
1013+ if (callback_state) {
1014+ std::vector<std::string> stages = {" text_model" };
1015+ if (has_spec) {
1016+ stages.push_back (" spec_model" );
1017+ }
1018+ if (has_mmproj) {
1019+ stages.push_back (" mmproj_model" );
1020+ }
1021+ load_progress_text.stages = stages;
1022+ load_progress_mmproj.stages = stages;
1023+ load_progress_spec.stages = stages;
1024+
1025+ // trigger 0% progress
1026+ load_progress_callback (0 .0f , &load_progress_text);
1027+ }
1028+
1029+
1030+ SRV_INF (" loading model '%s'\n " , params.model .path .c_str ());
1031+
10051032 std::string & mmproj_path = params_base.mmproj .path ;
1006- bool has_mmproj = !mmproj_path.empty ();
10071033 mtmd_context_params mparams = mtmd_context_params_default ();
10081034 if (has_mmproj) {
10091035 mparams.use_gpu = params_base.mmproj_use_gpu ;
@@ -1050,16 +1076,7 @@ struct server_context_impl {
10501076
10511077 // optionally reserve VRAM for the draft / MTP context before fitting the target model
10521078 if (params_base.fit_params ) {
1053- if (callback_state) {
1054- callback_state (SERVER_STATE_LOADING , {{" stage" , " fit_params" }});
1055- }
1056-
1057- const bool spec_mtp = std::find (params_base.speculative .types .begin (),
1058- params_base.speculative .types .end (),
1059- COMMON_SPECULATIVE_TYPE_DRAFT_MTP ) != params_base.speculative .types .end ();
1060- const bool has_draft = params_base.speculative .has_dft ();
1061-
1062- if (has_draft || spec_mtp) {
1079+ if (has_spec) {
10631080 common_params params_dft = params_base;
10641081 bool measure_model_bytes = true ;
10651082
@@ -1151,11 +1168,7 @@ struct server_context_impl {
11511168
11521169 add_bos_token = llama_vocab_get_add_bos (vocab);
11531170
1154- if (params_base.speculative .has_dft ()) {
1155- if (callback_state) {
1156- callback_state (SERVER_STATE_LOADING , {{" stage" , " spec_model" }});
1157- }
1158-
1171+ if (has_draft) {
11591172 // TODO speculative: move to common/speculative.cpp?
11601173 const auto & params_spec = params_base.speculative .draft ;
11611174
@@ -1178,6 +1191,10 @@ struct server_context_impl {
11781191
11791192 auto mparams_dft = common_model_params_to_llama (params_dft);
11801193
1194+ // progress callback
1195+ mparams_dft.progress_callback = load_progress_callback;
1196+ mparams_dft.progress_callback_user_data = &load_progress_spec;
1197+
11811198 model_dft.reset (llama_model_load_from_file (params_dft.model .path .c_str (), mparams_dft));
11821199 if (model_dft == nullptr ) {
11831200 SRV_ERR (" failed to load draft model, '%s'\n " , params_dft.model .path .c_str ());
@@ -1186,10 +1203,6 @@ struct server_context_impl {
11861203
11871204 auto cparams = common_context_params_to_llama (params_dft);
11881205
1189- const bool spec_mtp = std::find (params_base.speculative .types .begin (),
1190- params_base.speculative .types .end (),
1191- COMMON_SPECULATIVE_TYPE_DRAFT_MTP ) != params_base.speculative .types .end ();
1192-
11931206 if (spec_mtp) {
11941207 cparams.ctx_type = LLAMA_CONTEXT_TYPE_MTP ;
11951208 }
@@ -1203,8 +1216,10 @@ struct server_context_impl {
12031216
12041217 params_base.speculative .draft .ctx_tgt = ctx_tgt;
12051218 params_base.speculative .draft .ctx_dft = ctx_dft.get ();
1206- } else if (std::find (params_base.speculative .types .begin (), params_base.speculative .types .end (),
1207- COMMON_SPECULATIVE_TYPE_DRAFT_MTP ) != params_base.speculative .types .end ()) {
1219+ } else if (spec_mtp) {
1220+ // no new model load, so we simply report 0.0 and 1.0 progress
1221+ load_progress_callback (0 .0f , &load_progress_spec);
1222+
12081223 SRV_INF (" creating MTP draft context against the target model '%s'\n " ,
12091224 params_base.model .path .c_str ());
12101225
@@ -1224,6 +1239,8 @@ struct server_context_impl {
12241239
12251240 params_base.speculative .draft .ctx_tgt = ctx_tgt;
12261241 params_base.speculative .draft .ctx_dft = ctx_dft.get ();
1242+
1243+ load_progress_callback (1 .0f , &load_progress_spec);
12271244 }
12281245
12291246 if (has_mmproj) {
0 commit comments