@@ -340,9 +340,7 @@ struct handle_model_result {
340340};
341341
342342static handle_model_result common_params_handle_model (struct common_params_model & model,
343- const std::string & bearer_token,
344- bool offline,
345- bool search_mtp = false ) {
343+ const common_download_opts & opts) {
346344 handle_model_result result;
347345
348346 if (!model.docker_repo .empty ()) {
@@ -354,10 +352,9 @@ static handle_model_result common_params_handle_model(struct common_params_model
354352 model.hf_file = model.path ;
355353 model.path = " " ;
356354 }
357- common_download_opts opts;
358- opts.bearer_token = bearer_token;
359- opts.offline = offline;
360- auto download_result = common_download_model (model, opts, true , search_mtp);
355+ common_download_opts hf_opts = opts;
356+ hf_opts.download_mmproj = true ; // also look for mmproj when downloading hf model
357+ auto download_result = common_download_model (model, hf_opts);
361358
362359 if (download_result.model_path .empty ()) {
363360 throw std::runtime_error (" failed to download model from Hugging Face" );
@@ -382,9 +379,6 @@ static handle_model_result common_params_handle_model(struct common_params_model
382379 model.path = fs_get_cache_file (string_split<std::string>(f, ' /' ).back ());
383380 }
384381
385- common_download_opts opts;
386- opts.bearer_token = bearer_token;
387- opts.offline = offline;
388382 auto download_result = common_download_model (model, opts);
389383 if (download_result.model_path .empty ()) {
390384 throw std::runtime_error (" failed to download model from " + model.url );
@@ -441,35 +435,49 @@ static bool parse_bool_value(const std::string & value) {
441435// CLI argument parsing functions
442436//
443437
444- void common_params_handle_models (common_params & params, llama_example curr_ex) {
438+ bool common_params_handle_models (common_params & params, llama_example curr_ex) {
445439 const bool spec_type_draft_mtp = std::find (params.speculative .types .begin (),
446440 params.speculative .types .end (),
447441 COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params.speculative .types .end ();
448442
449- auto res = common_params_handle_model (params.model , params.hf_token , params.offline , spec_type_draft_mtp);
450- if (params.no_mmproj ) {
451- params.mmproj = {};
452- } else if (res.found_mmproj && params.mmproj .path .empty () && params.mmproj .url .empty ()) {
453- // optionally, handle mmproj model when -hf is specified
454- params.mmproj = res.mmproj ;
455- }
456- // only download mmproj if the current example is using it
457- for (const auto & ex : mmproj_examples) {
458- if (curr_ex == ex) {
459- common_params_handle_model (params.mmproj , params.hf_token , params.offline );
460- break ;
443+ common_download_opts opts;
444+ opts.bearer_token = params.hf_token ;
445+ opts.offline = params.offline ;
446+ opts.skip_download = params.skip_download ;
447+ opts.download_mtp = spec_type_draft_mtp;
448+
449+ try {
450+ auto res = common_params_handle_model (params.model , opts);
451+ if (params.no_mmproj ) {
452+ params.mmproj = {};
453+ } else if (res.found_mmproj && params.mmproj .path .empty () && params.mmproj .url .empty ()) {
454+ // optionally, handle mmproj model when -hf is specified
455+ params.mmproj = res.mmproj ;
456+ }
457+ // only download mmproj if the current example is using it
458+ for (const auto & ex : mmproj_examples) {
459+ if (curr_ex == ex) {
460+ common_params_handle_model (params.mmproj , opts);
461+ break ;
462+ }
461463 }
464+
465+ // when --spec-type mtp is set and no draft model was provided explicitly,
466+ // fall back to the MTP head discovered alongside the -hf model
467+ if (spec_type_draft_mtp && res.found_mtp &&
468+ params.speculative .draft .mparams .path .empty () &&
469+ params.speculative .draft .mparams .hf_repo .empty () &&
470+ params.speculative .draft .mparams .url .empty ()) {
471+ params.speculative .draft .mparams .path = res.mtp .path ;
472+ }
473+ common_params_handle_model (params.speculative .draft .mparams , opts);
474+ common_params_handle_model (params.vocoder .model , opts);
475+ return true ;
476+ } catch (const common_skip_download_exception &) {
477+ return false ;
478+ } catch (const std::exception &) {
479+ throw ;
462480 }
463- // when --spec-type mtp is set and no draft model was provided explicitly,
464- // fall back to the MTP head discovered alongside the -hf model
465- if (spec_type_draft_mtp && res.found_mtp &&
466- params.speculative .draft .mparams .path .empty () &&
467- params.speculative .draft .mparams .hf_repo .empty () &&
468- params.speculative .draft .mparams .url .empty ()) {
469- params.speculative .draft .mparams .path = res.mtp .path ;
470- }
471- common_params_handle_model (params.speculative .draft .mparams , params.hf_token , params.offline );
472- common_params_handle_model (params.vocoder .model , params.hf_token , params.offline );
473481}
474482
475483static bool common_params_parse_ex (int argc, char ** argv, common_params_context & ctx_arg) {
0 commit comments