@@ -308,7 +308,8 @@ static bool common_download_head(CURL * curl,
308308// download one single file from remote URL to local path
309309static bool common_download_file_single_online (const std::string & url,
310310 const std::string & path,
311- const std::string & bearer_token) {
311+ const std::string & bearer_token,
312+ const common_header_list & custom_headers) {
312313 static const int max_attempts = 3 ;
313314 static const int retry_delay_seconds = 2 ;
314315 for (int i = 0 ; i < max_attempts; ++i) {
@@ -330,6 +331,11 @@ static bool common_download_file_single_online(const std::string & url,
330331 common_load_model_from_url_headers headers;
331332 curl_easy_setopt (curl.get (), CURLOPT_HEADERDATA, &headers);
332333 curl_slist_ptr http_headers;
334+
335+ for (const auto & h : custom_headers) {
336+ std::string s = h.first + " : " + h.second ;
337+ http_headers.ptr = curl_slist_append (http_headers.ptr , s.c_str ());
338+ }
333339 const bool was_perform_successful = common_download_head (curl.get (), http_headers, url, bearer_token);
334340 if (!was_perform_successful) {
335341 head_request_ok = false ;
@@ -454,8 +460,10 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
454460 curl_easy_setopt (curl.get (), CURLOPT_MAXFILESIZE, params.max_size );
455461 }
456462 http_headers.ptr = curl_slist_append (http_headers.ptr , " User-Agent: llama-cpp" );
463+
457464 for (const auto & header : params.headers ) {
458- http_headers.ptr = curl_slist_append (http_headers.ptr , header.c_str ());
465+ std::string header_ = header.first + " : " + header.second ;
466+ http_headers.ptr = curl_slist_append (http_headers.ptr , header_.c_str ());
459467 }
460468 curl_easy_setopt (curl.get (), CURLOPT_HTTPHEADER, http_headers.ptr );
461469
@@ -619,7 +627,8 @@ static bool common_pull_file(httplib::Client & cli,
619627// download one single file from remote URL to local path
620628static bool common_download_file_single_online (const std::string & url,
621629 const std::string & path,
622- const std::string & bearer_token) {
630+ const std::string & bearer_token,
631+ const common_header_list & custom_headers) {
623632 static const int max_attempts = 3 ;
624633 static const int retry_delay_seconds = 2 ;
625634
@@ -629,6 +638,9 @@ static bool common_download_file_single_online(const std::string & url,
629638 if (!bearer_token.empty ()) {
630639 default_headers.insert ({" Authorization" , " Bearer " + bearer_token});
631640 }
641+ for (const auto & h : custom_headers) {
642+ default_headers.emplace (h.first , h.second );
643+ }
632644 cli.set_default_headers (default_headers);
633645
634646 const bool file_exists = std::filesystem::exists (path);
@@ -734,13 +746,9 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string
734746 auto [cli, parts] = common_http_client (url);
735747
736748 httplib::Headers headers = {{" User-Agent" , " llama-cpp" }};
749+
737750 for (const auto & header : params.headers ) {
738- size_t pos = header.find (' :' );
739- if (pos != std::string::npos) {
740- headers.emplace (header.substr (0 , pos), header.substr (pos + 1 ));
741- } else {
742- headers.emplace (header, " " );
743- }
751+ headers.emplace (header.first , header.second );
744752 }
745753
746754 if (params.timeout > 0 ) {
@@ -772,9 +780,10 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string
772780static bool common_download_file_single (const std::string & url,
773781 const std::string & path,
774782 const std::string & bearer_token,
775- bool offline) {
783+ bool offline,
784+ const common_header_list & headers) {
776785 if (!offline) {
777- return common_download_file_single_online (url, path, bearer_token);
786+ return common_download_file_single_online (url, path, bearer_token, headers );
778787 }
779788
780789 if (!std::filesystem::exists (path)) {
@@ -788,13 +797,24 @@ static bool common_download_file_single(const std::string & url,
788797
789798// download multiple files from remote URLs to local paths
790799// the input is a vector of pairs <url, path>
791- static bool common_download_file_multiple (const std::vector<std::pair<std::string, std::string>> & urls, const std::string & bearer_token, bool offline) {
800+ static bool common_download_file_multiple (const std::vector<std::pair<std::string, std::string>> & urls,
801+ const std::string & bearer_token,
802+ bool offline,
803+ const common_header_list & headers) {
792804 // Prepare download in parallel
793805 std::vector<std::future<bool >> futures_download;
806+ futures_download.reserve (urls.size ());
807+
794808 for (auto const & item : urls) {
795- futures_download.push_back (std::async (std::launch::async, [bearer_token, offline](const std::pair<std::string, std::string> & it) -> bool {
796- return common_download_file_single (it.first , it.second , bearer_token, offline);
797- }, item));
809+ futures_download.push_back (
810+ std::async (
811+ std::launch::async,
812+ [&bearer_token, offline, &headers](const std::pair<std::string, std::string> & it) -> bool {
813+ return common_download_file_single (it.first , it.second , bearer_token, offline, headers);
814+ },
815+ item
816+ )
817+ );
798818 }
799819
800820 // Wait for all downloads to complete
@@ -807,17 +827,17 @@ static bool common_download_file_multiple(const std::vector<std::pair<std::strin
807827 return true ;
808828}
809829
810- bool common_download_model (
811- const common_params_model & model ,
812- const std::string & bearer_token ,
813- bool offline ) {
830+ bool common_download_model (const common_params_model & model,
831+ const std::string & bearer_token ,
832+ bool offline ,
833+ const common_header_list & headers ) {
814834 // Basic validation of the model.url
815835 if (model.url .empty ()) {
816836 LOG_ERR (" %s: invalid model url\n " , __func__);
817837 return false ;
818838 }
819839
820- if (!common_download_file_single (model.url , model.path , bearer_token, offline)) {
840+ if (!common_download_file_single (model.url , model.path , bearer_token, offline, headers )) {
821841 return false ;
822842 }
823843
@@ -876,13 +896,16 @@ bool common_download_model(
876896 }
877897
878898 // Download in parallel
879- common_download_file_multiple (urls, bearer_token, offline);
899+ common_download_file_multiple (urls, bearer_token, offline, headers );
880900 }
881901
882902 return true ;
883903}
884904
885- common_hf_file_res common_get_hf_file (const std::string & hf_repo_with_tag, const std::string & bearer_token, bool offline) {
905+ common_hf_file_res common_get_hf_file (const std::string & hf_repo_with_tag,
906+ const std::string & bearer_token,
907+ bool offline,
908+ const common_header_list & custom_headers) {
886909 auto parts = string_split<std::string>(hf_repo_with_tag, ' :' );
887910 std::string tag = parts.size () > 1 ? parts.back () : " latest" ;
888911 std::string hf_repo = parts[0 ];
@@ -893,10 +916,10 @@ common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, cons
893916 std::string url = get_model_endpoint () + " v2/" + hf_repo + " /manifests/" + tag;
894917
895918 // headers
896- std::vector<std::string> headers;
897- headers.push_back (" Accept: application/json" );
919+ common_header_list headers = custom_headers ;
920+ headers.push_back ({ " Accept" , " application/json" } );
898921 if (!bearer_token.empty ()) {
899- headers.push_back (" Authorization: Bearer " + bearer_token);
922+ headers.push_back ({ " Authorization" , " Bearer " + bearer_token} );
900923 }
901924 // Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response
902925 // User-Agent header is already set in common_remote_get_content, no need to set it here
@@ -1031,9 +1054,10 @@ std::string common_docker_resolve_model(const std::string & docker) {
10311054 const std::string url_prefix = " https://registry-1.docker.io/v2/" + repo;
10321055 std::string manifest_url = url_prefix + " /manifests/" + tag;
10331056 common_remote_params manifest_params;
1034- manifest_params.headers .push_back (" Authorization: Bearer " + token);
1035- manifest_params.headers .push_back (
1036- " Accept: application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json" );
1057+ manifest_params.headers .push_back ({" Authorization" , " Bearer " + token});
1058+ manifest_params.headers .push_back ({" Accept" ,
1059+ " application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json"
1060+ });
10371061 auto manifest_res = common_remote_get_content (manifest_url, manifest_params);
10381062 if (manifest_res.first != 200 ) {
10391063 throw std::runtime_error (" Failed to get Docker manifest, HTTP code: " + std::to_string (manifest_res.first ));
@@ -1070,7 +1094,7 @@ std::string common_docker_resolve_model(const std::string & docker) {
10701094 std::string local_path = fs_get_cache_file (model_filename);
10711095
10721096 const std::string blob_url = url_prefix + " /blobs/" + gguf_digest;
1073- if (!common_download_file_single (blob_url, local_path, token, false )) {
1097+ if (!common_download_file_single (blob_url, local_path, token, false , {} )) {
10741098 throw std::runtime_error (" Failed to download Docker Model" );
10751099 }
10761100
@@ -1084,11 +1108,11 @@ std::string common_docker_resolve_model(const std::string & docker) {
10841108
10851109#else
10861110
1087- common_hf_file_res common_get_hf_file (const std::string &, const std::string &, bool ) {
1111+ common_hf_file_res common_get_hf_file (const std::string &, const std::string &, bool , const common_header_list & ) {
10881112 throw std::runtime_error (" download functionality is not enabled in this build" );
10891113}
10901114
1091- bool common_download_model (const common_params_model &, const std::string &, bool ) {
1115+ bool common_download_model (const common_params_model &, const std::string &, bool , const common_header_list & ) {
10921116 throw std::runtime_error (" download functionality is not enabled in this build" );
10931117}
10941118
0 commit comments