diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 7aedb9df683..1db615b832a 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -47,6 +47,10 @@ if (WHISPER_COMMON_FFMPEG) endif() +# add json lib (used by the HF cache/download subsystem) +add_library(json_cpp INTERFACE) +target_include_directories(json_cpp INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}) + add_library(${TARGET} STATIC common.h common.cpp @@ -56,12 +60,30 @@ add_library(${TARGET} STATIC common-whisper.cpp grammar-parser.h grammar-parser.cpp + http.h + hf-cache.h + hf-cache.cpp ${COMMON_SOURCES_FFMPEG} ) include(DefaultTargetOptions) -target_link_libraries(${TARGET} PRIVATE whisper ${COMMON_EXTRA_LIBS} ${CMAKE_DL_LIBS}) +# the ported HF cache subsystem (hf-cache.cpp) uses std::filesystem / std::string_view +target_compile_features(${TARGET} PRIVATE cxx_std_17) + +# vendored cpp-httplib header lives under examples/server/ (used by http.h) +target_include_directories(${TARGET} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/server) + +# HTTPS support for the HF download path (cpp-httplib + OpenSSL). Off by default; +# when OFF an https:// attempt prints a rebuild hint (see http.h). 
+option(WHISPER_OPENSSL "whisper: enable OpenSSL for HTTPS HuggingFace downloads" OFF) +if (WHISPER_OPENSSL) + find_package(OpenSSL REQUIRED) + target_compile_definitions(${TARGET} PRIVATE CPPHTTPLIB_OPENSSL_SUPPORT) + target_link_libraries(${TARGET} PRIVATE OpenSSL::SSL OpenSSL::Crypto) +endif() + +target_link_libraries(${TARGET} PRIVATE whisper json_cpp ${COMMON_EXTRA_LIBS} ${CMAKE_DL_LIBS}) set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(${TARGET} PROPERTIES FOLDER "libs") @@ -85,10 +107,6 @@ if (WHISPER_SDL2) set_target_properties(${TARGET} PROPERTIES FOLDER "libs") endif() -# add json lib -add_library(json_cpp INTERFACE) -target_include_directories(json_cpp INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}) - # examples include_directories(${CMAKE_CURRENT_SOURCE_DIR}) diff --git a/examples/cli/cli.cpp b/examples/cli/cli.cpp index e505bf0e18d..f5165320a6e 100644 --- a/examples/cli/cli.cpp +++ b/examples/cli/cli.cpp @@ -85,6 +85,8 @@ struct whisper_params { std::string prompt; std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf"; std::string model = "models/ggml-base.en.bin"; + std::string hf_repo; + std::string hf_file; std::string grammar; std::string grammar_rule; @@ -199,6 +201,8 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params else if ( arg == "--prompt") { params.prompt = ARGV_NEXT; } else if ( arg == "--carry-initial-prompt") { params.carry_initial_prompt = true; } else if (arg == "-m" || arg == "--model") { params.model = ARGV_NEXT; } + else if (arg == "-hf" || arg == "--hf-repo") { params.hf_repo = ARGV_NEXT; } + else if (arg == "-hff" || arg == "--hf-file") { params.hf_file = ARGV_NEXT; } else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(ARGV_NEXT); } else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = ARGV_NEXT; } else if (arg == "-dtw" || arg == "--dtw") { params.dtw = ARGV_NEXT; } @@ -282,6 +286,8 
@@ static void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt (max n_text_ctx/2 tokens)\n", params.prompt.c_str()); fprintf(stderr, " --carry-initial-prompt [%-7s] always prepend initial prompt\n", params.carry_initial_prompt ? "true" : "false"); fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); + fprintf(stderr, " -hf REPO, --hf-repo REPO [%-7s] HuggingFace repo (org/repo) to resolve from cache\n", params.hf_repo.c_str()); + fprintf(stderr, " -hff FILE, --hf-file FILE [%-7s] file within the HuggingFace repo (e.g. ggml-base.en.bin)\n", params.hf_file.c_str()); fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input audio file path\n", ""); fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str()); fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str()); @@ -1070,6 +1076,15 @@ int main(int argc, char ** argv) { } } + // resolve HF repo-id -> cached model path if -hf given and -m was left at its default + if (!params.hf_repo.empty() && params.model == "models/ggml-base.en.bin") { + params.model = whisper_hf_resolve_model(params.hf_repo, params.hf_file); + if (params.model.empty()) { + // whisper_hf_resolve_model prints a specific diagnostic for every failure mode + return 3; + } + } + struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams); if (ctx == nullptr) { diff --git a/examples/common-whisper.cpp b/examples/common-whisper.cpp index b12481c013f..96ca7b66db1 100644 --- a/examples/common-whisper.cpp +++ b/examples/common-whisper.cpp @@ -3,6 +3,7 @@ #include "common-whisper.h" #include "common.h" +#include "hf-cache.h" #include "whisper.h" @@ -31,7 +32,10 @@ #include #endif +#include +#include #include +#include #include #ifdef WHISPER_COMMON_FFMPEG @@ -243,5 +247,113 @@ bool 
speak_with_file(const std::string & command, const std::string & text, cons return true; } +// filename test for whisper GGML models: ggml-*.bin +static bool whisper_hf_is_ggml_bin(const std::string & name) { + return name.rfind("ggml-", 0) == 0 && name.size() >= 4 && + name.compare(name.size() - 4, 4, ".bin") == 0; +} + +// pick the primary file from a listing: exact hf_file match, else the first ggml-*.bin +static const hf_cache::hf_file * whisper_hf_pick_primary(const hf_cache::hf_files & files, const std::string & hf_file) { + for (const auto & file : files) { + if (!hf_file.empty()) { + if (file.path == hf_file) { + return &file; + } + } else { + const std::string name = std::filesystem::path(file.path).filename().string(); + if (whisper_hf_is_ggml_bin(name)) { + return &file; + } + } + } + return nullptr; +} + +// collect the entries whose filename matches ggml-*.bin +static hf_cache::hf_files whisper_hf_ggml_candidates(const hf_cache::hf_files & files) { + hf_cache::hf_files out; + for (const auto & file : files) { + const std::string name = std::filesystem::path(file.path).filename().string(); + if (whisper_hf_is_ggml_bin(name)) { + out.push_back(file); + } + } + return out; +} + +// print an error message followed by the sorted list of candidate filenames +static void whisper_hf_print_candidates(const std::string & msg, const hf_cache::hf_files & candidates) { + fprintf(stderr, "%s\n", msg.c_str()); + std::vector names; + for (const auto & file : candidates) { + names.push_back(std::filesystem::path(file.path).filename().string()); + } + std::sort(names.begin(), names.end()); + for (const auto & name : names) { + fprintf(stderr, " - %s\n", name.c_str()); + } +} + +std::string whisper_hf_resolve_model(const std::string & hf_repo, const std::string & hf_file) { + const char * token_env = std::getenv("HF_TOKEN"); + const std::string token = token_env ? 
token_env : ""; + + // honor an HF offline mode (huggingface_hub convention): skip the network path entirely + const char * offline_env = std::getenv("HF_HUB_OFFLINE"); + const bool offline = offline_env && *offline_env && std::string(offline_env) != "0"; + + // -hf alone (no --hf-file): cache-first, and refuse ambiguity rather than guess. + if (hf_file.empty()) { + const hf_cache::hf_files cached = whisper_hf_ggml_candidates(hf_cache::get_cached_files(hf_repo)); + if (cached.size() == 1) { + return hf_cache::finalize_file(cached.front()); + } + if (cached.size() > 1) { + whisper_hf_print_candidates( + "error: multiple models cached for " + hf_repo + "; specify one with -hff/--hf-file:", cached); + return ""; + } + + // none cached + if (offline) { + fprintf(stderr, "error: %s not found in HF cache\n", hf_repo.c_str()); + return ""; + } + + const hf_cache::hf_files remote = whisper_hf_ggml_candidates(hf_cache::get_repo_files(hf_repo, token)); + if (remote.empty()) { + fprintf(stderr, "error: no models found in %s\n", hf_repo.c_str()); + return ""; + } + // don't auto-pick/download a multi-model repo; list what's available instead + whisper_hf_print_candidates( + "error: multiple models available in " + hf_repo + "; specify one with -hff/--hf-file:", remote); + return ""; + } + + // explicit --hf-file: download-first with cache fall-back (Phase 2, unchanged). + + // 1. try download first: list the repo over the network and fetch the primary file. + // get_repo_files swallows network errors into an empty result (graceful degradation). + if (!offline) { + const hf_cache::hf_files remote = hf_cache::get_repo_files(hf_repo, token); + if (const hf_cache::hf_file * primary = whisper_hf_pick_primary(remote, hf_file)) { + if (hf_cache::download_file(*primary, token)) { + return hf_cache::finalize_file(*primary); + } + } + } + + // 2. fall back to the on-disk HF hub cache scan (Phase 1 behavior). 
+ const hf_cache::hf_files cached = hf_cache::get_cached_files(hf_repo); + if (const hf_cache::hf_file * primary = whisper_hf_pick_primary(cached, hf_file)) { + return hf_cache::finalize_file(*primary); + } + + fprintf(stderr, "error: file '%s' not found in %s (cache or network)\n", hf_file.c_str(), hf_repo.c_str()); + return ""; +} + #undef STB_VORBIS_HEADER_ONLY #include "stb_vorbis.c" diff --git a/examples/common-whisper.h b/examples/common-whisper.h index aec430d3635..f3e55c5af41 100644 --- a/examples/common-whisper.h +++ b/examples/common-whisper.h @@ -33,3 +33,7 @@ int utf8_trailing_bytes_needed(const std::string & s); // write text to file, and call system("command voice_id file") bool speak_with_file(const std::string & command, const std::string & text, const std::string & path, int voice_id); + +// returns a concrete model path, or "" if the repo/file is not resolvable from the local cache. +// Phase 1: cache-only (get_cached_files + finalize_file). Phase 2 adds download. +std::string whisper_hf_resolve_model(const std::string & hf_repo, const std::string & hf_file); diff --git a/examples/hf-cache.cpp b/examples/hf-cache.cpp new file mode 100644 index 00000000000..aa61c723022 --- /dev/null +++ b/examples/hf-cache.cpp @@ -0,0 +1,670 @@ +#include "hf-cache.h" + +#include "whisper.h" +#include "http.h" + +#include "json.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace nl = nlohmann; + +// whisper.cpp does not vendor llama.cpp's log.h; route the ported logging to stderr. +#define LOG_WRN(...) fprintf(stderr, __VA_ARGS__) +#define LOG_ERR(...) 
#if defined(_WIN32)
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
#define NOMINMAX
#endif
#define HOME_DIR "USERPROFILE"
#include <windows.h>
#else
#define HOME_DIR "HOME"
#include <unistd.h>
#include <pwd.h>
#endif

namespace hf_cache {

namespace fs = std::filesystem;

// local string helpers (whisper's examples/common.h lacks llama.cpp's string_* utilities)

// Replaces every non-overlapping occurrence of `search` in `s` with `replace`, left to right.
// Text inserted by a replacement is never re-scanned. An empty `search` is a no-op.
static void string_replace_all(std::string & s, const std::string & search, const std::string & replace) {
    if (search.empty()) {
        return;
    }
    size_t pos = 0;
    while ((pos = s.find(search, pos)) != std::string::npos) {
        s.replace(pos, search.length(), replace);
        pos += replace.length();
    }
}

// Returns true when `str` begins with `prefix`.
static bool string_starts_with(const std::string & str, const std::string & prefix) {
    return str.compare(0, prefix.length(), prefix) == 0;
}

// Returns the model endpoint with a guaranteed trailing '/'.
// Precedence: MODEL_ENDPOINT, then HF_ENDPOINT, then https://huggingface.co/.
// A variable that is set but empty is treated as unset, so the result is never empty
// (calling back() on an empty string would be undefined behaviour).
static std::string get_model_endpoint() {
    const char * endpoint_env = std::getenv("MODEL_ENDPOINT");
    const char * hf_endpoint  = std::getenv("HF_ENDPOINT");

    std::string endpoint = "https://huggingface.co/";
    if (endpoint_env && *endpoint_env) {
        endpoint = endpoint_env;
    } else if (hf_endpoint && *hf_endpoint) {
        endpoint = hf_endpoint;
    }
    if (endpoint.back() != '/') {
        endpoint += '/';
    }
    return endpoint;
}

// Returns the HF hub cache directory; computed once and then reused.
// The first non-empty environment variable in the table wins, with the listed
// sub-path appended; on non-Windows the passwd entry's home is the last resort.
// Throws std::runtime_error when no location can be determined.
static fs::path get_cache_directory() {
    static const fs::path cache = []() {
        struct {
            const char * var;
            fs::path     path;
        } entries[] = {
            {"LLAMA_CACHE",           fs::path()},
            {"HF_HUB_CACHE",          fs::path()},
            {"HUGGINGFACE_HUB_CACHE", fs::path()},
            {"HF_HOME",               fs::path("hub")},
            {"XDG_CACHE_HOME",        fs::path("huggingface") / "hub"},
            {HOME_DIR,                fs::path(".cache") / "huggingface" / "hub"}
        };
        for (const auto & entry : entries) {
            if (auto * p = std::getenv(entry.var); p && *p) {
                fs::path base(p);
                return entry.path.empty() ? base : base / entry.path;
            }
        }
#ifndef _WIN32
        const struct passwd * pw = getpwuid(getuid());

        if (pw && pw->pw_dir && *pw->pw_dir) {
            return fs::path(pw->pw_dir) / ".cache" / "huggingface" / "hub";
        }
#endif
        throw std::runtime_error("Failed to determine HF cache directory");
    }();

    return cache;
}

// "models--org--repo" -> "org/repo"; returns "" when the "models--" prefix is missing.
static std::string folder_name_to_repo(const std::string & folder) {
    constexpr std::string_view prefix = "models--";
    if (folder.compare(0, prefix.length(), prefix) != 0) {
        return {};
    }
    std::string result = folder.substr(prefix.length());
    string_replace_all(result, "--", "/");
    return result;
}

// "org/repo" -> "models--org--repo"
static std::string repo_to_folder_name(const std::string & repo_id) {
    constexpr std::string_view prefix = "models--";
    std::string result = std::string(prefix) + repo_id;
    string_replace_all(result, "/", "--");
    return result;
}

// Cache directory of one repository: <cache>/models--org--repo
static fs::path get_repo_path(const std::string & repo_id) {
    return get_cache_directory() / repo_to_folder_name(repo_id);
}

static bool is_hex_char(const char c) {
    return (c >= 'A' && c <= 'F') ||
           (c >= 'a' && c <= 'f') ||
           (c >= '0' && c <= '9');
}

// True when `s` is exactly `expected_len` hexadecimal digits.
static bool is_hex_string(const std::string & s, size_t expected_len) {
    if (s.length() != expected_len) {
        return false;
    }
    for (const char c : s) {
        if (!is_hex_char(c)) {
            return false;
        }
    }
    return true;
}

static bool is_alphanum(const char c) {
    return (c >= 'A' && c <= 'Z') ||
           (c >= 'a' && c <= 'z') ||
           (c >= '0' && c <= '9');
}

static bool is_special_char(char c) {
    return c == '/' || c == '.' || c == '-';
}

// base chars [A-Za-z0-9_] are always valid
// special chars [/.-] must be surrounded by base chars
// exactly one '/' required
static bool is_valid_repo_id(const std::string & repo_id) {
    if (repo_id.empty() || repo_id.length() > 256) {
        return false;
    }
    int  slash   = 0;
    bool special = true; // start as "after a special" so a leading special char is rejected

    for (const char c : repo_id) {
        if (is_alphanum(c) || c == '_') {
            special = false;
        } else if (is_special_char(c)) {
            if (special) {
                return false;
            }
            slash += (c == '/');
            special = true;
        } else {
            return false;
        }
    }
    return !special && slash == 1;
}

// Accepts "hf_" followed by alphanumerics only, total length 37..256.
static bool is_valid_hf_token(const std::string & token) {
    if (token.length() < 37 || token.length() > 256 ||
        !string_starts_with(token, "hf_")) {
        return false;
    }
    for (size_t i = 3; i < token.length(); ++i) {
        if (!is_alphanum(token[i])) {
            return false;
        }
    }
    return true;
}

// A commit is 40 hex digits.
static bool is_valid_commit(const std::string & hash) {
    return is_hex_string(hash, 40);
}

// An object id is 40 or 64 hex digits.
static bool is_valid_oid(const std::string & oid) {
    return is_hex_string(oid, 40) || is_hex_string(oid, 64);
}

// True when `path / subpath`, after lexical normalisation, still lies under `path`.
// Purely lexical: nothing is looked up on disk.
static bool is_valid_subpath(const fs::path & path, const fs::path & subpath) {
    if (subpath.is_absolute()) {
        return false; // never do a / b with b absolute
    }
    const auto base   = fs::absolute(path).lexically_normal();
    const auto target = (base / subpath).lexically_normal();

    // every component of base must be a leading component of target
    const auto ends = std::mismatch(base.begin(), base.end(), target.begin(), target.end());

    return ends.first == base.end();
}

// Writes `data` to "<path>.tmp" and renames it over `path`, creating parent directories.
// On any failure the temporary file is removed and std::runtime_error is thrown.
static void safe_write_file(const fs::path & path, const std::string & data) {
    fs::path path_tmp = path.string() + ".tmp";

    if (path.has_parent_path()) {
        fs::create_directories(path.parent_path());
    }

    std::ofstream file(path_tmp);
    file << data;
    file.close();

    std::error_code ec;

    if (!file.fail()) {
        fs::rename(path_tmp, path, ec);
    }
    if (file.fail() || ec) {
        fs::remove(path_tmp, ec);
        throw std::runtime_error("failed to write file: " + path.string());
    }
}

} // namespace hf_cache
api_get(const std::string & url, + const std::string & token) { + auto [cli, parts] = common_http_client(url); + + httplib::Headers headers = { + {"User-Agent", "whisper-cpp/" + std::string(whisper_version())}, + {"Accept", "application/json"} + }; + + if (is_valid_hf_token(token)) { + headers.emplace("Authorization", "Bearer " + token); + } else if (!token.empty()) { + LOG_WRN("%s: invalid token, authentication disabled\n", __func__); + } + + if (auto res = cli.Get(parts.path, headers)) { + auto body = res->body; + + if (res->status == 200) { + return nl::json::parse(res->body); + } + try { + body = nl::json::parse(res->body)["error"].get(); + } catch (...) { } + + throw std::runtime_error("GET failed (" + std::to_string(res->status) + "): " + body); + } else { + throw std::runtime_error("HTTPLIB failed: " + httplib::to_string(res.error())); + } +} + +static std::string get_repo_commit(const std::string & repo_id, + const std::string & token) { + try { + auto endpoint = get_model_endpoint(); + auto json = api_get(endpoint + "api/models/" + repo_id + "/refs", token); + + if (!json.is_object() || + !json.contains("branches") || !json["branches"].is_array()) { + LOG_WRN("%s: missing 'branches' for '%s'\n", __func__, repo_id.c_str()); + return {}; + } + + fs::path refs_path = get_repo_path(repo_id) / "refs"; + std::string name; + std::string commit; + + for (const auto & branch : json["branches"]) { + if (!branch.is_object() || + !branch.contains("name") || !branch["name"].is_string() || + !branch.contains("targetCommit") || !branch["targetCommit"].is_string()) { + continue; + } + std::string _name = branch["name"].get(); + std::string _commit = branch["targetCommit"].get(); + + if (!is_valid_subpath(refs_path, _name)) { + LOG_WRN("%s: skip invalid branch: %s\n", __func__, _name.c_str()); + continue; + } + if (!is_valid_commit(_commit)) { + LOG_WRN("%s: skip invalid commit: %s\n", __func__, _commit.c_str()); + continue; + } + + if (_name == "main") { + name = _name; + 
commit = _commit; + break; + } + + if (name.empty() || commit.empty()) { + name = _name; + commit = _commit; + } + } + + if (name.empty() || commit.empty()) { + LOG_WRN("%s: no valid branch for '%s'\n", __func__, repo_id.c_str()); + return {}; + } + + safe_write_file(refs_path / name, commit); + return commit; + + } catch (const nl::json::exception & e) { + LOG_ERR("%s: JSON error: %s\n", __func__, e.what()); + } catch (const std::exception & e) { + LOG_ERR("%s: error: %s\n", __func__, e.what()); + } + return {}; +} + +hf_files get_repo_files(const std::string & repo_id, + const std::string & token) { + if (!is_valid_repo_id(repo_id)) { + LOG_WRN("%s: invalid repository: %s\n", __func__, repo_id.c_str()); + return {}; + } + + std::string commit = get_repo_commit(repo_id, token); + if (commit.empty()) { + LOG_WRN("%s: failed to resolve commit for %s\n", __func__, repo_id.c_str()); + return {}; + } + + fs::path blobs_path = get_repo_path(repo_id) / "blobs"; + fs::path commit_path = get_repo_path(repo_id) / "snapshots" / commit; + + hf_files files; + + try { + auto endpoint = get_model_endpoint(); + auto json = api_get(endpoint + "api/models/" + repo_id + "/tree/" + commit + "?recursive=true", token); + + if (!json.is_array()) { + LOG_WRN("%s: response is not an array for '%s'\n", __func__, repo_id.c_str()); + return {}; + } + + for (const auto & item : json) { + if (!item.is_object() || + !item.contains("type") || !item["type"].is_string() || item["type"] != "file" || + !item.contains("path") || !item["path"].is_string()) { + continue; + } + + hf_file file; + file.repo_id = repo_id; + file.path = item["path"].get(); + + if (!is_valid_subpath(commit_path, file.path)) { + LOG_WRN("%s: skip invalid path: %s\n", __func__, file.path.c_str()); + continue; + } + + if (item.contains("lfs") && item["lfs"].is_object()) { + if (item["lfs"].contains("oid") && item["lfs"]["oid"].is_string()) { + file.oid = item["lfs"]["oid"].get(); + } + } else if (item.contains("oid") && 
item["oid"].is_string()) { + file.oid = item["oid"].get(); + } + + if (!file.oid.empty() && !is_valid_oid(file.oid)) { + LOG_WRN("%s: skip invalid oid: %s\n", __func__, file.oid.c_str()); + continue; + } + + file.url = endpoint + repo_id + "/resolve/" + commit + "/" + file.path; + + fs::path final_path = commit_path / file.path; + file.final_path = final_path.string(); + + if (!file.oid.empty() && !fs::exists(final_path)) { + fs::path local_path = blobs_path / file.oid; + file.local_path = local_path.string(); + } else { + file.local_path = file.final_path; + } + + files.push_back(file); + } + } catch (const nl::json::exception & e) { + LOG_ERR("%s: JSON error: %s\n", __func__, e.what()); + } catch (const std::exception & e) { + LOG_ERR("%s: error: %s\n", __func__, e.what()); + } + return files; +} + +static std::string get_cached_ref(const fs::path & repo_path) { + fs::path refs_path = repo_path / "refs"; + if (!fs::is_directory(refs_path)) { + return {}; + } + std::string fallback; + + for (const auto & entry : fs::directory_iterator(refs_path)) { + if (!entry.is_regular_file()) { + continue; + } + std::ifstream f(entry.path()); + std::string commit; + if (!f || !std::getline(f, commit) || commit.empty()) { + continue; + } + if (!is_valid_commit(commit)) { + LOG_WRN("%s: skip invalid commit: %s\n", __func__, commit.c_str()); + continue; + } + if (entry.path().filename() == "main") { + return commit; + } + if (fallback.empty()) { + fallback = commit; + } + } + return fallback; +} + +hf_files get_cached_files(const std::string & repo_id) { + fs::path cache_dir = get_cache_directory(); + if (!fs::exists(cache_dir)) { + return {}; + } + + if (!repo_id.empty() && !is_valid_repo_id(repo_id)) { + LOG_WRN("%s: invalid repository: %s\n", __func__, repo_id.c_str()); + return {}; + } + + hf_files files; + + for (const auto & repo : fs::directory_iterator(cache_dir)) { + if (!repo.is_directory()) { + continue; + } + fs::path snapshots_path = repo.path() / "snapshots"; + + if 
(!fs::exists(snapshots_path)) { + continue; + } + std::string _repo_id = folder_name_to_repo(repo.path().filename().string()); + + if (!is_valid_repo_id(_repo_id)) { + continue; + } + if (!repo_id.empty() && _repo_id != repo_id) { + continue; + } + std::string commit = get_cached_ref(repo.path()); + fs::path commit_path = snapshots_path / commit; + + if (commit.empty() || !fs::is_directory(commit_path)) { + continue; + } + for (const auto & entry : fs::recursive_directory_iterator(commit_path)) { + if (!entry.is_regular_file() && !entry.is_symlink()) { + continue; + } + fs::path path = entry.path().lexically_relative(commit_path); + + if (!path.empty()) { + hf_file file; + file.repo_id = _repo_id; + file.path = path.generic_string(); + file.local_path = entry.path().string(); + file.final_path = file.local_path; + files.push_back(std::move(file)); + } + } + } + + return files; +} + +bool download_file(const hf_file & file, const std::string & token) { + if (file.url.empty() || file.local_path.empty()) { + return false; + } + + std::error_code ec; + fs::path local_path(file.local_path); + + // already downloaded (blob present) -> nothing to do + if (fs::exists(local_path, ec)) { + return true; + } + + try { + if (local_path.has_parent_path()) { + fs::create_directories(local_path.parent_path(), ec); + } + + fs::path path_tmp = local_path.string() + ".tmp"; + + std::ofstream ofs(path_tmp, std::ios::binary); + if (!ofs.is_open()) { + LOG_ERR("%s: failed to open '%s' for writing\n", __func__, path_tmp.string().c_str()); + return false; + } + + httplib::Headers headers = { + {"User-Agent", "whisper-cpp/" + std::string(whisper_version())} + }; + + const bool have_auth = is_valid_hf_token(token); + if (have_auth) { + headers.emplace("Authorization", "Bearer " + token); + } else if (!token.empty()) { + LOG_WRN("%s: invalid token, authentication disabled\n", __func__); + } + + const char * func = __func__; // avoid __func__ inside a lambda + const std::string origin_host = 
common_http_parse_url(file.url).host; + + // cpp-httplib 0.20 mishandles cross-host redirects to signed CDN URLs + // (the presigned query string is lost, yielding a 403), so follow them + // manually here, re-issuing the request against the exact Location URL. + std::string url = file.url; + bool status_ok = false; + bool got_response = false; + + for (int redirect = 0; redirect <= 10; ++redirect) { + auto [cli, parts] = common_http_client(url); + cli.set_follow_location(false); + // the signed CDN Location already carries a fully percent-encoded + // query string; httplib's default url-encoding would re-encode '+' + // (and ',', ';', ...) inside the signature and break it, so send the + // path verbatim (matching curl) to avoid a 403 from the CDN. + cli.set_url_encode(false); + + // never forward the HF bearer token to a different (CDN) host + httplib::Headers req_headers = headers; + if (have_auth && parts.host != origin_host) { + req_headers.erase("Authorization"); + } + + std::string location; + bool is_redirect = false; + status_ok = false; + got_response = false; + + auto res = cli.Get(parts.path, req_headers, + [&](const httplib::Response & response) { + got_response = true; + if (response.status >= 300 && response.status < 400 && + response.has_header("Location")) { + location = response.get_header_value("Location"); + is_redirect = true; + return false; // stop before streaming the redirect body + } + if (response.status != 200) { + LOG_WRN("%s: download failed (%d) for %s\n", func, response.status, url.c_str()); + return false; + } + status_ok = true; + return true; + }, + [&](const char * data, size_t len) { + ofs.write(data, len); + return (bool) ofs; + }); + + if (is_redirect && !location.empty()) { + url = location; + continue; + } + + if (!got_response) { + LOG_ERR("%s: HTTP error: %s\n", __func__, httplib::to_string(res.error()).c_str()); + } + break; + } + + ofs.close(); + + if (!status_ok || ofs.fail()) { + fs::remove(path_tmp, ec); + return 
false; + } + + fs::rename(path_tmp, local_path, ec); + if (ec) { + LOG_ERR("%s: failed to move '%s' to '%s': %s\n", __func__, + path_tmp.string().c_str(), local_path.string().c_str(), ec.message().c_str()); + fs::remove(path_tmp, ec); + return false; + } + + return true; + + } catch (const std::exception & e) { + LOG_ERR("%s: error: %s\n", __func__, e.what()); + } + return false; +} + +std::string finalize_file(const hf_file & file) { + static std::atomic symlinks_disabled{false}; + + std::error_code ec; + fs::path local_path(file.local_path); + fs::path final_path(file.final_path); + + if (local_path == final_path || fs::exists(final_path, ec)) { + return file.final_path; + } + + if (!fs::exists(local_path, ec)) { + return file.final_path; + } + + fs::create_directories(final_path.parent_path(), ec); + + if (!symlinks_disabled) { + fs::path target = fs::relative(local_path, final_path.parent_path(), ec); + if (!ec) { + fs::create_symlink(target, final_path, ec); + } + if (!ec) { + return file.final_path; + } + } + + if (!symlinks_disabled.exchange(true)) { + LOG_WRN("%s: failed to create symlink: %s\n", __func__, ec.message().c_str()); + LOG_WRN("%s: switching to degraded mode\n", __func__); + } + + fs::rename(local_path, final_path, ec); + if (ec) { + LOG_WRN("%s: failed to move file to snapshots: %s\n", __func__, ec.message().c_str()); + fs::copy(local_path, final_path, ec); + if (ec) { + LOG_ERR("%s: failed to copy file to snapshots: %s\n", __func__, ec.message().c_str()); + } + } + return file.final_path; +} + +bool remove_cached_repo(const std::string & repo_id) { + if (!is_valid_repo_id(repo_id)) { + LOG_WRN("%s: invalid repository: %s\n", __func__, repo_id.c_str()); + return false; + } + fs::path repo_path = get_repo_path(repo_id); + std::error_code ec; + auto removed = fs::remove_all(repo_path, ec); + if (ec) { + LOG_ERR("%s: failed to remove repo cache %s: %s\n", __func__, repo_path.string().c_str(), ec.message().c_str()); + return false; + } + return 
// Components of an http(s) URL as produced by common_http_parse_url().
struct common_http_url {
    std::string scheme;   // "http" or "https"
    std::string user;     // userinfo before ':' (may be empty)
    std::string password; // userinfo after ':' (may be empty)
    std::string host;     // host without IPv6 brackets
    int         port = 0; // explicit port, else 80 / 443 by scheme
    std::string path;     // request target: path plus any query string, always starts with '/'
};

// bracket an IPv6 literal host for a URL authority (RFC 3986)
static std::string common_http_format_host(const std::string & host) {
    return host.find(':') != std::string::npos ? "[" + host + "]" : host;
}

// Parses a decimal port in 1..65535; throws std::runtime_error otherwise.
static int common_http_parse_port(const std::string & port_str) {
    int port = 0;
    bool ok = !port_str.empty() && port_str.size() <= 5;
    for (size_t i = 0; ok && i < port_str.size(); ++i) {
        const char c = port_str[i];
        if (c < '0' || c > '9') {
            ok = false;
        } else {
            port = port * 10 + (c - '0');
        }
    }
    if (!ok || port < 1 || port > 65535) {
        throw std::runtime_error("invalid URL port: " + port_str);
    }
    return port;
}

// Splits an http:// or https:// URL into its components.
//
// The authority ends at the first '/', '?' or '#', so userinfo and port are looked
// for only there: an '@' or ':' in the path or query (common in signed CDN URLs)
// is left alone. A missing path becomes "/".
//
// Throws std::runtime_error when the scheme is missing or unsupported, when an
// IPv6 literal is unterminated, or when the port is not a number in 1..65535.
static common_http_url common_http_parse_url(const std::string & url) {
    common_http_url parts;

    const auto scheme_end = url.find("://");
    if (scheme_end == std::string::npos) {
        throw std::runtime_error("invalid URL: no scheme");
    }
    parts.scheme = url.substr(0, scheme_end);

    if (parts.scheme != "http" && parts.scheme != "https") {
        throw std::runtime_error("unsupported URL scheme: " + parts.scheme);
    }

    const std::string rest = url.substr(scheme_end + 3);

    // authority = [userinfo@]host[:port], terminated by the path, query or fragment
    const auto authority_end = rest.find_first_of("/?#");
    std::string authority = rest.substr(0, authority_end);

    if (authority_end == std::string::npos) {
        parts.path = "/";
    } else {
        parts.path = rest.substr(authority_end);
        if (parts.path.front() != '/') {
            parts.path.insert(0, 1, '/'); // "host?x=1" -> "/?x=1"
        }
    }

    // the last '@' ends the userinfo; the first ':' inside it separates user and password
    const auto at_pos = authority.rfind('@');
    if (at_pos != std::string::npos) {
        const std::string auth = authority.substr(0, at_pos);
        const auto colon_pos = auth.find(':');
        if (colon_pos != std::string::npos) {
            parts.user     = auth.substr(0, colon_pos);
            parts.password = auth.substr(colon_pos + 1);
        } else {
            parts.user = auth;
        }
        authority.erase(0, at_pos + 1);
    }

    // split host and optional port, a bracketed IPv6 literal keeps its inner colons (RFC 3986)
    std::string port_str;
    if (!authority.empty() && authority.front() == '[') {
        const auto close = authority.find(']');
        if (close == std::string::npos) {
            throw std::runtime_error("invalid IPv6 URL authority: " + authority);
        }
        const std::string after = authority.substr(close + 1);
        if (!after.empty() && after.front() == ':') {
            port_str = after.substr(1);
        }
        parts.host = authority.substr(1, close - 1);
    } else {
        const auto colon_pos = authority.find(':');
        if (colon_pos != std::string::npos) {
            port_str = authority.substr(colon_pos + 1);
        }
        parts.host = authority.substr(0, colon_pos);
    }

    if (!port_str.empty()) {
        parts.port = common_http_parse_port(port_str);
    } else {
        // scheme was validated above, so it is one of the two
        parts.port = parts.scheme == "https" ? 443 : 80;
    }

    return parts;
}
+#
+# Usage:
+#   ./tests/test-hf-resolve.sh
+#   WHISPER_CLI=build-ssl/bin/whisper-cli \
+#     WHISPER_CLI_NOSSL=build-nossl/bin/whisper-cli ./tests/test-hf-resolve.sh
+
+set -u
+
+# all paths below are relative to the repository root
+cd "$(dirname "$0")/.." || exit 1
+
+main="${WHISPER_CLI:-./build/bin/whisper-cli}"
+main_nossl="${WHISPER_CLI_NOSSL:-}"
+sample="samples/jfk.wav"
+seed_model="models/for-tests-ggml-base.en.bin"
+repo="ggerganov/whisper.cpp"
+hf_file="ggml-base.en.bin"
+
+for f in "$main" "$sample" "$seed_model"; do
+    if [ ! -e "$f" ]; then
+        printf "required fixture not found: %s\n" "$f" >&2
+        printf "build whisper-cli and ensure test models/samples are present first.\n" >&2
+        exit 1
+    fi
+done
+
+# One private scratch directory holds the seeded cache, the empty cache and the
+# per-check logs, so concurrent runs cannot clobber each other's files and
+# everything is removed on exit.
+tmp_root="$(mktemp -d)" || exit 1
+trap 'rm -rf "$tmp_root"' EXIT
+tmp_cache="$tmp_root/cache"
+log_dir="$tmp_root/logs"
+mkdir -p "$log_dir"
+
+# fake 40-character commit id for the seeded snapshot
+commit="$(printf '%040d' 1 | tr '0' 'a')"
+snapshot_dir="$tmp_cache/models--ggerganov--whisper.cpp/snapshots/$commit"
+refs_dir="$tmp_cache/models--ggerganov--whisper.cpp/refs"
+mkdir -p "$snapshot_dir" "$refs_dir"
+printf '%s' "$commit" > "$refs_dir/main"
+cp "$seed_model" "$snapshot_dir/$hf_file"
+
+fail=0
+
+# 1. cache resolution succeeds (offline: network path skipped, falls back to cache)
+log="$log_dir/ok.log"
+if HF_HUB_OFFLINE=1 HF_HUB_CACHE="$tmp_cache" "$main" -hf "$repo" --hf-file "$hf_file" -f "$sample" >"$log" 2>&1; then
+    if grep -qi "failed to open" "$log"; then
+        printf "FAIL: -hf resolved but model failed to open\n"; cat "$log"; fail=1
+    else
+        printf "PASS: -hf %s --hf-file %s resolved from cache (offline, exit 0)\n" "$repo" "$hf_file"
+    fi
+else
+    printf "FAIL: -hf offline resolution exited non-zero\n"; cat "$log"; fail=1
+fi
+
+# 2. missing file -> exit 3 with clear error
+log="$log_dir/miss.log"
+HF_HUB_OFFLINE=1 HF_HUB_CACHE="$tmp_cache" "$main" -hf "$repo" --hf-file ggml-missing.bin -f "$sample" >"$log" 2>&1
+rc=$?
+if [ "$rc" -eq 3 ] && grep -qiF "file 'ggml-missing.bin' not found" "$log"; then
+    printf "PASS: missing --hf-file reports 'file ... not found' and exits 3\n"
+else
+    printf "FAIL: missing --hf-file expected exit 3 + error message, got exit %s\n" "$rc"
+    cat "$log"; fail=1
+fi
+
+# 2b. -hf alone (no --hf-file), single cached model -> resolves and runs (exit 0)
+log="$log_dir/single.log"
+if HF_HUB_OFFLINE=1 HF_HUB_CACHE="$tmp_cache" "$main" -hf "$repo" -f "$sample" >"$log" 2>&1; then
+    if grep -qi "failed to open" "$log"; then
+        printf "FAIL: -hf alone resolved but model failed to open\n"; cat "$log"; fail=1
+    else
+        printf "PASS: -hf %s (no --hf-file) resolved single cached model (exit 0)\n" "$repo"
+    fi
+else
+    printf "FAIL: -hf alone with a single cached model exited non-zero\n"; cat "$log"; fail=1
+fi
+
+# 2c. -hf alone (no --hf-file), multiple cached models -> exit 3 + "multiple models cached" + both filenames
+second_file="ggml-tiny.en.bin"
+cp "$seed_model" "$snapshot_dir/$second_file"
+log="$log_dir/multi.log"
+HF_HUB_OFFLINE=1 HF_HUB_CACHE="$tmp_cache" "$main" -hf "$repo" -f "$sample" >"$log" 2>&1
+rc=$?
+# -F: the file names contain dots, which must match literally
+if [ "$rc" -eq 3 ] && grep -qi "multiple models cached" "$log" \
+    && grep -qF "$hf_file" "$log" && grep -qF "$second_file" "$log"; then
+    printf "PASS: -hf alone with multiple cached models reports 'multiple models cached' + lists both, exits 3\n"
+else
+    printf "FAIL: -hf alone with multiple cached models expected exit 3 + list, got exit %s\n" "$rc"
+    cat "$log"; fail=1
+fi
+rm -f "$snapshot_dir/$second_file"
+
+# 3. -m regression: explicit path still works
+log="$log_dir/m.log"
+if "$main" -m "$seed_model" -f "$sample" >"$log" 2>&1; then
+    printf "PASS: -m %s still works (exit 0)\n" "$seed_model"
+else
+    printf "FAIL: -m regression exited non-zero\n"; cat "$log"; fail=1
+fi
+
+# 4. bare default unchanged: still points at models/ggml-base.en.bin
+# (the exit status is deliberately ignored; only the referenced path is checked)
+log="$log_dir/bare.log"
+"$main" -f "$sample" >"$log" 2>&1
+if grep -qiF "models/ggml-base.en.bin" "$log"; then
+    printf "PASS: bare invocation still uses models/ggml-base.en.bin default\n"
+else
+    printf "FAIL: bare default no longer references models/ggml-base.en.bin\n"; cat "$log"; fail=1
+fi
+
+# 5. no-OpenSSL build: an https resolve against an empty cache prints the rebuild
+#    hint and exits non-zero. Only runs if a no-OpenSSL binary is provided.
+if [ -n "$main_nossl" ] && [ -e "$main_nossl" ]; then
+    empty_cache="$tmp_root/empty-cache"
+    mkdir -p "$empty_cache"
+    log="$log_dir/nossl.log"
+    # HF_HUB_OFFLINE is dropped in a subshell: if the caller exported it, the
+    # network path this check needs to reach would be skipped
+    ( unset HF_HUB_OFFLINE; HF_HUB_CACHE="$empty_cache" "$main_nossl" -hf "$repo" --hf-file "$hf_file" -f "$sample" ) >"$log" 2>&1
+    rc=$?
+    if [ "$rc" -ne 0 ] && grep -qi "rebuild with -DWHISPER_OPENSSL=ON" "$log"; then
+        printf "PASS: no-OpenSSL https attempt prints rebuild hint and exits non-zero\n"
+    else
+        printf "FAIL: no-OpenSSL https attempt expected rebuild hint + non-zero exit, got exit %s\n" "$rc"
+        cat "$log"; fail=1
+    fi
+else
+    printf "SKIP: no-OpenSSL rebuild-hint check (set WHISPER_CLI_NOSSL to enable)\n"
+fi
+
+if [ "$fail" -ne 0 ]; then
+    printf "\ntest-hf-resolve: FAILED\n"
+    exit 1
+fi
+
+printf "\ntest-hf-resolve: all checks passed\n"