diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5983133..d45677f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -8,29 +8,46 @@ on: env: CARGO_TERM_COLOR: always - HF_ENDPOINT: https://hub-ci.huggingface.co - RUST_BACKTRACE: 1 jobs: fmt: name: Format runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@nightly + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + - uses: dtolnay/rust-toolchain@3c5f7ea28cd621ae0bf5283f0e981fb97b8a7af9 # master with: + toolchain: nightly components: rustfmt + - uses: actions/cache@668228422ae6a00e4ad889ee87cd7109ec5666a7 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo- - run: cargo +nightly fmt --all --check clippy: name: Lint runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@stable + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + - uses: dtolnay/rust-toolchain@3c5f7ea28cd621ae0bf5283f0e981fb97b8a7af9 # master with: + toolchain: stable components: clippy - - uses: Swatinem/rust-cache@v2 + - uses: actions/cache@668228422ae6a00e4ad889ee87cd7109ec5666a7 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo- # lint with all features, and with only default features - run: cargo clippy -p huggingface-hub --all-features -- -D warnings - run: cargo clippy -p huggingface-hub -- -D warnings @@ -39,9 +56,19 @@ jobs: name: Build (release) runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@stable - - uses: Swatinem/rust-cache@v2 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + - uses: dtolnay/rust-toolchain@3c5f7ea28cd621ae0bf5283f0e981fb97b8a7af9 # master + 
with: + toolchain: stable + - uses: actions/cache@668228422ae6a00e4ad889ee87cd7109ec5666a7 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo- - run: cargo build -p huggingface-hub --all-features --release - run: cargo build -p huggingface-hub --release - run: cargo build -p huggingface-hub --features xet --release @@ -51,9 +78,19 @@ jobs: name: Test runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@stable - - uses: Swatinem/rust-cache@v2 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + - uses: dtolnay/rust-toolchain@3c5f7ea28cd621ae0bf5283f0e981fb97b8a7af9 # master + with: + toolchain: stable + - uses: actions/cache@668228422ae6a00e4ad889ee87cd7109ec5666a7 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo- - name: Set up Python for cache interop tests uses: actions/setup-python@v5 with: @@ -65,7 +102,8 @@ jobs: pip install huggingface_hub[cli] - name: All tests env: - HF_TOKEN: ${{ secrets.HF_TOKEN }} + HF_CI_TOKEN: ${{ secrets.HF_CI_TOKEN }} + HF_PROD_TOKEN: ${{ secrets.HF_PROD_TOKEN }} HF_TEST_WRITE: "1" run: | source .venv/bin/activate diff --git a/Cargo.lock b/Cargo.lock index 6a2ff94..65036d4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -997,6 +997,7 @@ dependencies = [ "futures", "globset", "hf-xet", + "libc", "owo-colors", "pathdiff", "rand 0.9.2", diff --git a/huggingface_hub/Cargo.toml b/huggingface_hub/Cargo.toml index 9728da8..31d9f81 100644 --- a/huggingface_hub/Cargo.toml +++ b/huggingface_hub/Cargo.toml @@ -75,6 +75,7 @@ sha2 = "0.10" serial_test = "3" assert_cmd = "2" anyhow = "1" +libc = "0.2" [[example]] name = "repo" diff --git a/huggingface_hub/src/api/commits.rs b/huggingface_hub/src/api/commits.rs index 284d8c4..12685c5 100644 --- 
a/huggingface_hub/src/api/commits.rs +++ b/huggingface_hub/src/api/commits.rs @@ -5,11 +5,11 @@ use url::Url; use crate::constants; use crate::diff::HFFileDiff; use crate::error::Result; -use crate::repository::{ - HFRepository, RepoCreateBranchParams, RepoCreateTagParams, RepoDeleteBranchParams, RepoDeleteTagParams, +use crate::repository::HFRepository; +use crate::types::{ + GitCommitInfo, GitRefs, RepoCreateBranchParams, RepoCreateTagParams, RepoDeleteBranchParams, RepoDeleteTagParams, RepoGetCommitDiffParams, RepoGetRawDiffParams, RepoListCommitsParams, RepoListRefsParams, }; -use crate::types::{GitCommitInfo, GitRefs}; impl HFRepository { /// Stream commit history for the repository at a given revision. diff --git a/huggingface_hub/src/api/files.rs b/huggingface_hub/src/api/files.rs index 0fef7b5..983b42f 100644 --- a/huggingface_hub/src/api/files.rs +++ b/huggingface_hub/src/api/files.rs @@ -11,12 +11,13 @@ use tokio::io::AsyncWriteExt; use url::Url; use crate::error::{HFError, Result}; -use crate::repository::{ - HFRepository, RepoCreateCommitParams, RepoDeleteFileParams, RepoDeleteFolderParams, RepoDownloadFileParams, - RepoDownloadFileStreamParams, RepoDownloadFileToBytesParams, RepoGetPathsInfoParams, RepoListFilesParams, - RepoListTreeParams, RepoSnapshotDownloadParams, RepoUploadFileParams, RepoUploadFolderParams, +use crate::repository::HFRepository; +use crate::types::{ + AddSource, CommitInfo, CommitOperation, RepoCreateCommitParams, RepoDeleteFileParams, RepoDeleteFolderParams, + RepoDownloadFileParams, RepoDownloadFileStreamParams, RepoDownloadFileToBytesParams, RepoGetPathsInfoParams, + RepoListFilesParams, RepoListTreeParams, RepoSnapshotDownloadParams, RepoTreeEntry, RepoType, RepoUploadFileParams, + RepoUploadFolderParams, }; -use crate::types::{AddSource, CommitInfo, CommitOperation, RepoTreeEntry, RepoType}; use crate::{cache, constants}; impl HFRepository { diff --git a/huggingface_hub/src/api/repo.rs b/huggingface_hub/src/api/repo.rs index 
0f116fd..28dcacd 100644 --- a/huggingface_hub/src/api/repo.rs +++ b/huggingface_hub/src/api/repo.rs @@ -4,10 +4,10 @@ use url::Url; use crate::client::HFClient; use crate::constants; use crate::error::{HFError, Result}; -use crate::repository::{HFRepository, RepoFileExistsParams, RepoRevisionExistsParams, RepoUpdateSettingsParams}; +use crate::repository::HFRepository; use crate::types::{ CreateRepoParams, DatasetInfo, DeleteRepoParams, ListDatasetsParams, ListModelsParams, ListSpacesParams, ModelInfo, - MoveRepoParams, RepoUrl, SpaceInfo, + MoveRepoParams, RepoFileExistsParams, RepoRevisionExistsParams, RepoUpdateSettingsParams, RepoUrl, SpaceInfo, }; impl HFRepository { @@ -471,7 +471,7 @@ sync_api_stream! { sync_api! { impl HFRepository -> HFRepositorySync { - fn info(&self, params: &crate::repository::RepoInfoParams) -> Result; + fn info(&self, params: &crate::types::RepoInfoParams) -> Result; fn exists(&self) -> Result; fn revision_exists(&self, params: &RepoRevisionExistsParams) -> Result; fn file_exists(&self, params: &RepoFileExistsParams) -> Result; diff --git a/huggingface_hub/src/api/spaces.rs b/huggingface_hub/src/api/spaces.rs index 33173bc..6a3724e 100644 --- a/huggingface_hub/src/api/spaces.rs +++ b/huggingface_hub/src/api/spaces.rs @@ -1,10 +1,9 @@ -use crate::SpaceVariableDeleteParams; use crate::error::Result; -use crate::repository::{ - HFSpace, SpaceHardwareRequestParams, SpaceSecretDeleteParams, SpaceSecretParams, SpaceSleepTimeParams, - SpaceVariableParams, +use crate::repository::HFSpace; +use crate::types::{ + DuplicateSpaceParams, RepoUrl, SpaceHardwareRequestParams, SpaceRuntime, SpaceSecretDeleteParams, + SpaceSecretParams, SpaceSleepTimeParams, SpaceVariableDeleteParams, SpaceVariableParams, }; -use crate::types::{DuplicateSpaceParams, RepoUrl, SpaceRuntime}; impl HFSpace { /// Fetch the current runtime state of the Space (hardware, stage, URL, etc.). 
diff --git a/huggingface_hub/src/client.rs b/huggingface_hub/src/client.rs index 891ce5d..45538f8 100644 --- a/huggingface_hub/src/client.rs +++ b/huggingface_hub/src/client.rs @@ -45,6 +45,8 @@ pub(crate) struct HFClientInner { pub(crate) token: Option, pub(crate) cache_dir: std::path::PathBuf, pub(crate) cache_enabled: bool, + #[cfg(feature = "xet")] + pub(crate) xet_state: std::sync::Mutex, } /// Builder for [`HFClient`]. @@ -182,6 +184,8 @@ impl HFClientBuilder { token, cache_dir, cache_enabled: self.cache_enabled.unwrap_or(true), + #[cfg(feature = "xet")] + xet_state: std::sync::Mutex::new(crate::xet::XetState::default()), }), }) } @@ -300,6 +304,50 @@ impl HFClient { } } +#[cfg(feature = "xet")] +impl HFClient { + /// Get or lazily create the cached XetSession. + /// + /// Returns `(session, generation)`. The generation is an opaque counter + /// that identifies which session instance this is. Pass it to + /// [`replace_xet_session`](Self::replace_xet_session) so that only the + /// caller that observed the error triggers a replacement — concurrent + /// callers that already obtained a fresh session won't clobber it. + pub(crate) fn xet_session(&self) -> Result<(xet::xet_session::XetSession, u64)> { + let mut guard = self + .inner + .xet_state + .lock() + .map_err(|e| HFError::Other(format!("xet session mutex poisoned: {e}")))?; + + if let Some(ref session) = guard.session { + return Ok((session.clone(), guard.generation)); + } + + let session = xet::xet_session::XetSessionBuilder::new() + .build() + .map_err(|e| HFError::Other(format!("Failed to build xet session: {e}")))?; + guard.session = Some(session.clone()); + guard.generation += 1; + Ok((session, guard.generation)) + } + + /// Replace the cached XetSession only if the generation matches. + /// + /// Called by xet call sites when a factory method returns an error. 
+ /// The generation check ensures that if another thread already replaced + /// the session, this call is a no-op rather than discarding the fresh one. + pub(crate) fn replace_xet_session(&self, generation: u64, err: &xet::error::XetError) { + tracing::warn!(error = %err, generation, "replacing cached XetSession"); + let Ok(mut guard) = self.inner.xet_state.lock() else { + return; + }; + if guard.generation == generation { + guard.session = None; + } + } +} + /// Resolve token from environment or token file. /// Priority: HF_TOKEN env → HF_TOKEN_PATH file → $HF_HOME/token file. fn resolve_token() -> Option { @@ -357,4 +405,120 @@ mod tests { let path_str = api.cache_dir().to_string_lossy(); assert!(path_str.contains("huggingface") && path_str.ends_with("hub")); } + + #[cfg(feature = "xet")] + #[test] + fn test_xet_session_lazy_creation() { + let client = HFClientBuilder::new().build().unwrap(); + assert!(client.inner.xet_state.lock().unwrap().session.is_none()); + let (_s1, _gen) = client.xet_session().unwrap(); + assert!(client.inner.xet_state.lock().unwrap().session.is_some()); + } + + #[cfg(feature = "xet")] + #[test] + fn test_xet_session_shared_across_clones() { + let client = HFClientBuilder::new().build().unwrap(); + let clone = client.clone(); + let (_s1, _gen) = client.xet_session().unwrap(); + assert!(clone.inner.xet_state.lock().unwrap().session.is_some()); + } + + #[cfg(feature = "xet")] + #[test] + fn test_xet_session_recovers_after_abort() { + let client = HFClientBuilder::new().build().unwrap(); + + let (session, generation) = client.xet_session().unwrap(); + session.abort().unwrap(); + + match session.new_file_download_group() { + Ok(_) => panic!("expected error after abort"), + Err(e) => client.replace_xet_session(generation, &e), + } + + let (recovered, _) = client.xet_session().unwrap(); + assert!(recovered.new_file_download_group().is_ok()); + } + + #[cfg(feature = "xet")] + #[test] + fn test_xet_session_recovers_after_sigint_abort() { + let 
client = HFClientBuilder::new().build().unwrap(); + + let (session, generation) = client.xet_session().unwrap(); + session.sigint_abort().unwrap(); + + client.replace_xet_session(generation, &xet::error::XetError::KeyboardInterrupt); + + let (recovered, _) = client.xet_session().unwrap(); + assert!(recovered.new_file_download_group().is_ok()); + } + + /// Simulates the call-site retry pattern used in xet.rs: + /// 1. Get session + generation, factory call fails + /// 2. Call replace_xet_session(generation) to drop the bad session + /// 3. Get fresh session, factory call succeeds + #[cfg(feature = "xet")] + #[test] + fn test_replace_and_retry_after_abort() { + let client = HFClientBuilder::new().build().unwrap(); + + let (session, generation) = client.xet_session().unwrap(); + assert!(session.new_file_download_group().is_ok()); + + session.abort().unwrap(); + + let group = match session.new_file_download_group() { + Ok(b) => b, + Err(e) => { + client.replace_xet_session(generation, &e); + client + .xet_session() + .unwrap() + .0 + .new_file_download_group() + .expect("fresh session factory call should succeed") + }, + }; + drop(group); + } + + /// Verifies that replace_xet_session with a stale generation is a no-op. 
+ #[cfg(feature = "xet")] + #[test] + fn test_replace_with_stale_generation_is_noop() { + let client = HFClientBuilder::new().build().unwrap(); + + let (session, gen1) = client.xet_session().unwrap(); + session.abort().unwrap(); + + // First replace succeeds + client.replace_xet_session(gen1, &xet::error::XetError::KeyboardInterrupt); + + // Get the fresh session with a new generation + let (_fresh, gen2) = client.xet_session().unwrap(); + assert_ne!(gen1, gen2); + + // Attempting to replace with the old generation is a no-op + client.replace_xet_session(gen1, &xet::error::XetError::KeyboardInterrupt); + + // The fresh session is still cached + let (still_fresh, gen3) = client.xet_session().unwrap(); + assert_eq!(gen2, gen3); + assert!(still_fresh.new_file_download_group().is_ok()); + } + + #[cfg(feature = "xet")] + #[test] + fn test_xet_session_reuse_without_replacement() { + let client = HFClientBuilder::new().build().unwrap(); + + let (s1, g1) = client.xet_session().unwrap(); + let (s2, g2) = client.xet_session().unwrap(); + + assert_eq!(g1, g2); + assert!(s1.new_file_download_group().is_ok()); + assert!(s2.new_file_download_group().is_ok()); + } } diff --git a/huggingface_hub/src/lib.rs b/huggingface_hub/src/lib.rs index cffeb6a..f44cbfa 100644 --- a/huggingface_hub/src/lib.rs +++ b/huggingface_hub/src/lib.rs @@ -32,6 +32,8 @@ pub mod types; #[cfg(feature = "xet")] pub mod xet; +pub mod test_utils; + #[cfg(feature = "blocking")] pub use blocking::{HFClientSync, HFRepoSync, HFRepositorySync, HFSpaceSync}; pub use client::{HFClient, HFClientBuilder}; diff --git a/huggingface_hub/src/repository.rs b/huggingface_hub/src/repository.rs index f573d19..d425ccc 100644 --- a/huggingface_hub/src/repository.rs +++ b/huggingface_hub/src/repository.rs @@ -1,14 +1,10 @@ use std::fmt; use std::ops::Deref; -use std::path::PathBuf; use std::sync::Arc; -use serde::Serialize; -use typed_builder::TypedBuilder; - use crate::client::HFClient; use crate::error::{HFError, Result}; -use 
crate::types::{AddSource, CommitOperation, RepoInfo, RepoType}; +use crate::types::{RepoInfo, RepoInfoParams, RepoType}; /// A handle for a single repository on the Hugging Face Hub. /// @@ -79,389 +75,6 @@ impl fmt::Debug for HFSpace { } } -#[derive(Default, TypedBuilder)] -pub struct RepoInfoParams { - /// Git revision (branch, tag, or commit SHA) to fetch info for. Defaults to the main branch. - #[builder(default, setter(into, strip_option))] - pub revision: Option, -} - -#[derive(TypedBuilder)] -pub struct RepoRevisionExistsParams { - /// Git revision (branch, tag, or commit SHA) to check for existence. - #[builder(setter(into))] - pub revision: String, -} - -#[derive(TypedBuilder)] -pub struct RepoFileExistsParams { - /// Path of the file to check within the repository. - #[builder(setter(into))] - pub filename: String, - /// Git revision to check. Defaults to the main branch. - #[builder(default, setter(into, strip_option))] - pub revision: Option, -} - -#[derive(Default, TypedBuilder)] -pub struct RepoListFilesParams { - /// Git revision to list files from. Defaults to the main branch. - #[builder(default, setter(into, strip_option))] - pub revision: Option, -} - -#[derive(Default, TypedBuilder)] -pub struct RepoListTreeParams { - /// Git revision to list the tree from. Defaults to the main branch. - #[builder(default, setter(into, strip_option))] - pub revision: Option, - /// Whether to list files recursively in subdirectories. - #[builder(default)] - pub recursive: bool, - /// Whether to include expanded metadata (size, LFS info) for each entry. - #[builder(default)] - pub expand: bool, - /// Maximum number of tree entries to return. - #[builder(default, setter(strip_option))] - pub limit: Option, -} - -#[derive(TypedBuilder)] -pub struct RepoGetPathsInfoParams { - /// List of file paths within the repository to retrieve info for. - pub paths: Vec, - /// Git revision to query. Defaults to the main branch. 
- #[builder(default, setter(into, strip_option))] - pub revision: Option, -} - -#[derive(TypedBuilder)] -pub struct RepoDownloadFileParams { - /// Path of the file to download within the repository. - #[builder(setter(into))] - pub filename: String, - /// Local directory to download the file into. When set, the file is saved with its repo path structure. - #[builder(default, setter(strip_option))] - pub local_dir: Option, - /// Git revision to download from. Defaults to the main branch. - #[builder(default, setter(into, strip_option))] - pub revision: Option, - /// If `true`, re-download the file even if a cached copy exists. - #[builder(default, setter(strip_option))] - pub force_download: Option, - /// If `true`, only return the file if it is already cached locally; never make a network request. - #[builder(default, setter(strip_option))] - pub local_files_only: Option, -} - -#[derive(TypedBuilder)] -pub struct RepoDownloadFileStreamParams { - /// Path of the file to stream within the repository. - #[builder(setter(into))] - pub filename: String, - /// Git revision to stream from. Defaults to the main branch. - #[builder(default, setter(into, strip_option))] - pub revision: Option, - /// Byte range to request (HTTP Range header). Useful for partial downloads. - #[builder(default, setter(strip_option))] - pub range: Option>, -} - -pub type RepoDownloadFileToBytesParams = RepoDownloadFileStreamParams; -pub type RepoDownloadFileToBytesParamsBuilder = RepoDownloadFileStreamParamsBuilder; - -#[derive(Default, TypedBuilder)] -pub struct RepoSnapshotDownloadParams { - /// Git revision to download. Defaults to the main branch. - #[builder(default, setter(into, strip_option))] - pub revision: Option, - /// Glob patterns for files to include in the download. Only matching files are downloaded. - #[builder(default, setter(strip_option))] - pub allow_patterns: Option>, - /// Glob patterns for files to exclude from the download. 
- #[builder(default, setter(strip_option))] - pub ignore_patterns: Option>, - /// Local directory to download the snapshot into. - #[builder(default, setter(strip_option))] - pub local_dir: Option, - /// If `true`, re-download all files even if cached copies exist. - #[builder(default, setter(strip_option))] - pub force_download: Option, - /// If `true`, only return files already cached locally; never make network requests. - #[builder(default, setter(strip_option))] - pub local_files_only: Option, - /// Maximum number of concurrent file downloads. - #[builder(default, setter(strip_option))] - pub max_workers: Option, -} - -#[derive(TypedBuilder)] -pub struct RepoUploadFileParams { - /// Source of the file content to upload (bytes or file path). - pub source: AddSource, - /// Destination path within the repository. - #[builder(setter(into))] - pub path_in_repo: String, - /// Git revision (branch) to upload to. Defaults to the main branch. - #[builder(default, setter(into, strip_option))] - pub revision: Option, - /// Commit message for the upload. - #[builder(default, setter(into, strip_option))] - pub commit_message: Option, - /// Extended description for the commit. - #[builder(default, setter(into, strip_option))] - pub commit_description: Option, - /// If `true`, create a pull request instead of committing directly. - #[builder(default, setter(strip_option))] - pub create_pr: Option, - /// Expected parent commit SHA. The upload fails if the branch head has moved past this commit. - #[builder(default, setter(into, strip_option))] - pub parent_commit: Option, -} - -#[derive(TypedBuilder)] -pub struct RepoUploadFolderParams { - /// Local folder path to upload. - #[builder(setter(into))] - pub folder_path: PathBuf, - /// Destination directory within the repository. Defaults to the repo root. - #[builder(default, setter(into, strip_option))] - pub path_in_repo: Option, - /// Git revision (branch) to upload to. Defaults to the main branch. 
- #[builder(default, setter(into, strip_option))] - pub revision: Option, - /// Commit message for the upload. - #[builder(default, setter(into, strip_option))] - pub commit_message: Option, - /// Extended description for the commit. - #[builder(default, setter(into, strip_option))] - pub commit_description: Option, - /// If `true`, create a pull request instead of committing directly. - #[builder(default, setter(strip_option))] - pub create_pr: Option, - /// Glob patterns for files to include from the local folder. - #[builder(default, setter(strip_option))] - pub allow_patterns: Option>, - /// Glob patterns for files to exclude from the local folder. - #[builder(default, setter(strip_option))] - pub ignore_patterns: Option>, - /// Glob patterns for remote files to delete that are not present locally. - #[builder(default, setter(strip_option))] - pub delete_patterns: Option>, -} - -#[derive(TypedBuilder)] -pub struct RepoDeleteFileParams { - /// Path of the file to delete within the repository. - #[builder(setter(into))] - pub path_in_repo: String, - /// Git revision (branch) to delete from. Defaults to the main branch. - #[builder(default, setter(into, strip_option))] - pub revision: Option, - /// Commit message for the deletion. - #[builder(default, setter(into, strip_option))] - pub commit_message: Option, - /// If `true`, create a pull request instead of committing directly. - #[builder(default, setter(strip_option))] - pub create_pr: Option, -} - -#[derive(TypedBuilder)] -pub struct RepoDeleteFolderParams { - /// Path of the folder to delete within the repository. - #[builder(setter(into))] - pub path_in_repo: String, - /// Git revision (branch) to delete from. Defaults to the main branch. - #[builder(default, setter(into, strip_option))] - pub revision: Option, - /// Commit message for the deletion. - #[builder(default, setter(into, strip_option))] - pub commit_message: Option, - /// If `true`, create a pull request instead of committing directly. 
- #[builder(default, setter(strip_option))] - pub create_pr: Option, -} - -#[derive(TypedBuilder)] -pub struct RepoCreateCommitParams { - /// List of file operations (additions, deletions, copies) to include in the commit. - pub operations: Vec, - /// Commit message. - #[builder(setter(into))] - pub commit_message: String, - /// Extended description for the commit. - #[builder(default, setter(into, strip_option))] - pub commit_description: Option, - /// Git revision (branch) to commit to. Defaults to the main branch. - #[builder(default, setter(into, strip_option))] - pub revision: Option, - /// If `true`, create a pull request instead of committing directly. - #[builder(default, setter(strip_option))] - pub create_pr: Option, - /// Expected parent commit SHA. The commit fails if the branch head has moved past this commit. - #[builder(default, setter(into, strip_option))] - pub parent_commit: Option, -} - -#[derive(Default, TypedBuilder)] -pub struct RepoListCommitsParams { - /// Git revision (branch, tag, or commit SHA) to list commits from. Defaults to the main branch. - #[builder(default, setter(into, strip_option))] - pub revision: Option, - /// Maximum number of commits to return. - #[builder(default, setter(strip_option))] - pub limit: Option, -} - -#[derive(Default, TypedBuilder)] -pub struct RepoListRefsParams { - /// Whether to include pull request refs in the listing. - #[builder(default)] - pub include_pull_requests: bool, -} - -#[derive(TypedBuilder)] -pub struct RepoGetCommitDiffParams { - /// Revision to compare against the parent (branch, tag, or commit SHA). - #[builder(setter(into))] - pub compare: String, -} - -#[derive(TypedBuilder)] -pub struct RepoGetRawDiffParams { - /// Revision to compare against the parent (branch, tag, or commit SHA). - #[builder(setter(into))] - pub compare: String, -} - -#[derive(TypedBuilder)] -pub struct RepoCreateBranchParams { - /// Name of the branch to create. 
- #[builder(setter(into))] - pub branch: String, - /// Revision to branch from. Defaults to the current main branch head. - #[builder(default, setter(into, strip_option))] - pub revision: Option, -} - -#[derive(TypedBuilder)] -pub struct RepoDeleteBranchParams { - /// Name of the branch to delete. - #[builder(setter(into))] - pub branch: String, -} - -#[derive(TypedBuilder)] -pub struct RepoCreateTagParams { - /// Name of the tag to create. - #[builder(setter(into))] - pub tag: String, - /// Revision to tag. Defaults to the current main branch head. - #[builder(default, setter(into, strip_option))] - pub revision: Option, - /// Annotation message for the tag. - #[builder(default, setter(into, strip_option))] - pub message: Option, -} - -#[derive(TypedBuilder)] -pub struct RepoDeleteTagParams { - /// Name of the tag to delete. - #[builder(setter(into))] - pub tag: String, -} - -#[derive(Default, TypedBuilder, Serialize)] -#[serde(rename_all = "camelCase")] -pub struct RepoUpdateSettingsParams { - /// Whether the repository should be private. - #[builder(default, setter(strip_option))] - #[serde(skip_serializing_if = "Option::is_none")] - pub private: Option, - /// Access-gating mode for the repository (e.g. `auto`, `manual`). - #[builder(default, setter(strip_option))] - #[serde(skip_serializing_if = "Option::is_none")] - pub gated: Option, - /// Repository description shown on the Hub page. - #[builder(default, setter(into, strip_option))] - #[serde(skip_serializing_if = "Option::is_none")] - pub description: Option, - /// Whether discussions are disabled on this repository. - #[builder(default, setter(strip_option))] - #[serde(skip_serializing_if = "Option::is_none")] - pub discussions_disabled: Option, - /// Email address to receive gated-access request notifications. - #[builder(default, setter(into, strip_option))] - #[serde(skip_serializing_if = "Option::is_none")] - pub gated_notifications_email: Option, - /// When to send gated-access notifications (e.g. 
`each`, `daily`). - #[builder(default, setter(strip_option))] - #[serde(skip_serializing_if = "Option::is_none")] - pub gated_notifications_mode: Option, -} - -#[cfg(feature = "spaces")] -#[derive(TypedBuilder)] -pub struct SpaceHardwareRequestParams { - /// Hardware flavor to request (e.g. `"cpu-basic"`, `"t4-small"`, `"a10g-small"`). - #[builder(setter(into))] - pub hardware: String, - /// Number of seconds of inactivity before the Space is put to sleep. `0` means never sleep. - #[builder(default, setter(strip_option))] - pub sleep_time: Option, -} - -#[cfg(feature = "spaces")] -#[derive(TypedBuilder)] -pub struct SpaceSleepTimeParams { - /// Number of seconds of inactivity before the Space is put to sleep. `0` means never sleep. - pub sleep_time: u64, -} - -#[cfg(feature = "spaces")] -#[derive(TypedBuilder)] -pub struct SpaceSecretParams { - /// Secret key name. - #[builder(setter(into))] - pub key: String, - /// Secret value. - #[builder(setter(into))] - pub value: String, - /// Human-readable description of the secret. - #[builder(default, setter(into, strip_option))] - pub description: Option, -} - -#[cfg(feature = "spaces")] -#[derive(TypedBuilder)] -pub struct SpaceSecretDeleteParams { - /// Secret key name to delete. - #[builder(setter(into))] - pub key: String, -} - -#[cfg(feature = "spaces")] -#[derive(TypedBuilder)] -pub struct SpaceVariableParams { - /// Variable key name. - #[builder(setter(into))] - pub key: String, - /// Variable value. - #[builder(setter(into))] - pub value: String, - /// Human-readable description of the variable. - #[builder(default, setter(into, strip_option))] - pub description: Option, -} - -#[cfg(feature = "spaces")] -#[derive(TypedBuilder)] -pub struct SpaceVariableDeleteParams { - /// Variable key name to delete. - #[builder(setter(into))] - pub key: String, -} - impl HFClient { /// Create an [`HFRepository`] handle for any repo type. 
pub fn repo(&self, repo_type: RepoType, owner: impl Into, name: impl Into) -> HFRepository { diff --git a/huggingface_hub/src/test_utils.rs b/huggingface_hub/src/test_utils.rs new file mode 100644 index 0000000..8e004be --- /dev/null +++ b/huggingface_hub/src/test_utils.rs @@ -0,0 +1,62 @@ +// Shared constants and helpers for integration tests. + +use std::sync::OnceLock; + +// --- Environment variable names --- + +pub const HF_TOKEN: &str = "HF_TOKEN"; +pub const HF_CI_TOKEN: &str = "HF_CI_TOKEN"; +pub const HF_PROD_TOKEN: &str = "HF_PROD_TOKEN"; +pub const HF_TEST_WRITE: &str = "HF_TEST_WRITE"; +pub const HF_ENDPOINT: &str = "HF_ENDPOINT"; +pub const GITHUB_ACTIONS: &str = "GITHUB_ACTIONS"; +pub const HF_HUB_CACHE: &str = "HF_HUB_CACHE"; +pub const HF_HOME: &str = "HF_HOME"; +pub const XDG_CACHE_HOME: &str = "XDG_CACHE_HOME"; + +// --- Endpoints --- + +pub const PROD_ENDPOINT: &str = "https://huggingface.co"; +pub const HUB_CI_ENDPOINT: &str = "https://hub-ci.huggingface.co"; + +// --- Common helpers --- + +pub fn is_ci() -> bool { + static VALUE: OnceLock = OnceLock::new(); + *VALUE.get_or_init(|| std::env::var(GITHUB_ACTIONS).is_ok()) +} + +pub fn write_enabled() -> bool { + static VALUE: OnceLock = OnceLock::new(); + *VALUE.get_or_init(|| std::env::var(HF_TEST_WRITE).ok().is_some_and(|v| v == "1")) +} + +/// Resolve a token suitable for hub-ci writes. +/// CI: uses HF_CI_TOKEN. Local: uses HF_TOKEN. +pub fn resolve_hub_ci_token() -> Option { + static VALUE: OnceLock> = OnceLock::new(); + VALUE + .get_or_init(|| { + if is_ci() { + std::env::var(HF_CI_TOKEN).ok() + } else { + std::env::var(HF_TOKEN).ok() + } + }) + .clone() +} + +/// Resolve a token for production access. +/// CI: uses HF_PROD_TOKEN. Local: uses HF_TOKEN. 
+pub fn resolve_prod_token() -> Option { + static VALUE: OnceLock> = OnceLock::new(); + VALUE + .get_or_init(|| { + if is_ci() { + std::env::var(HF_PROD_TOKEN).ok() + } else { + std::env::var(HF_TOKEN).ok() + } + }) + .clone() +} diff --git a/huggingface_hub/src/types/mod.rs b/huggingface_hub/src/types/mod.rs index 86adfdb..9d57962 100644 --- a/huggingface_hub/src/types/mod.rs +++ b/huggingface_hub/src/types/mod.rs @@ -2,6 +2,7 @@ pub mod cache; pub mod commit; pub mod params; pub mod repo; +pub mod repo_params; pub mod user; #[cfg(feature = "spaces")] @@ -10,6 +11,7 @@ pub mod spaces; pub use commit::*; pub use params::*; pub use repo::*; +pub use repo_params::*; #[cfg(feature = "spaces")] pub use spaces::*; pub use user::*; diff --git a/huggingface_hub/src/types/repo_params.rs b/huggingface_hub/src/types/repo_params.rs new file mode 100644 index 0000000..3a63672 --- /dev/null +++ b/huggingface_hub/src/types/repo_params.rs @@ -0,0 +1,390 @@ +use std::path::PathBuf; + +use serde::Serialize; +use typed_builder::TypedBuilder; + +use super::commit::{AddSource, CommitOperation}; +use super::repo::{GatedApprovalMode, GatedNotificationsMode}; + +#[derive(Default, TypedBuilder)] +pub struct RepoInfoParams { + /// Git revision (branch, tag, or commit SHA) to fetch info for. Defaults to the main branch. + #[builder(default, setter(into, strip_option))] + pub revision: Option, +} + +#[derive(TypedBuilder)] +pub struct RepoRevisionExistsParams { + /// Git revision (branch, tag, or commit SHA) to check for existence. + #[builder(setter(into))] + pub revision: String, +} + +#[derive(TypedBuilder)] +pub struct RepoFileExistsParams { + /// Path of the file to check within the repository. + #[builder(setter(into))] + pub filename: String, + /// Git revision to check. Defaults to the main branch. + #[builder(default, setter(into, strip_option))] + pub revision: Option, +} + +#[derive(Default, TypedBuilder)] +pub struct RepoListFilesParams { + /// Git revision to list files from. 
Defaults to the main branch. + #[builder(default, setter(into, strip_option))] + pub revision: Option, +} + +#[derive(Default, TypedBuilder)] +pub struct RepoListTreeParams { + /// Git revision to list the tree from. Defaults to the main branch. + #[builder(default, setter(into, strip_option))] + pub revision: Option, + /// Whether to list files recursively in subdirectories. + #[builder(default)] + pub recursive: bool, + /// Whether to include expanded metadata (size, LFS info) for each entry. + #[builder(default)] + pub expand: bool, + /// Maximum number of tree entries to return. + #[builder(default, setter(strip_option))] + pub limit: Option, +} + +#[derive(TypedBuilder)] +pub struct RepoGetPathsInfoParams { + /// List of file paths within the repository to retrieve info for. + pub paths: Vec, + /// Git revision to query. Defaults to the main branch. + #[builder(default, setter(into, strip_option))] + pub revision: Option, +} + +#[derive(TypedBuilder)] +pub struct RepoDownloadFileParams { + /// Path of the file to download within the repository. + #[builder(setter(into))] + pub filename: String, + /// Local directory to download the file into. When set, the file is saved with its repo path structure. + #[builder(default, setter(strip_option))] + pub local_dir: Option, + /// Git revision to download from. Defaults to the main branch. + #[builder(default, setter(into, strip_option))] + pub revision: Option, + /// If `true`, re-download the file even if a cached copy exists. + #[builder(default, setter(strip_option))] + pub force_download: Option, + /// If `true`, only return the file if it is already cached locally; never make a network request. + #[builder(default, setter(strip_option))] + pub local_files_only: Option, +} + +#[derive(TypedBuilder)] +pub struct RepoDownloadFileStreamParams { + /// Path of the file to stream within the repository. + #[builder(setter(into))] + pub filename: String, + /// Git revision to stream from. Defaults to the main branch. 
+ #[builder(default, setter(into, strip_option))] + pub revision: Option, + /// Byte range to request (HTTP Range header). Useful for partial downloads. + #[builder(default, setter(strip_option))] + pub range: Option>, +} + +pub type RepoDownloadFileToBytesParams = RepoDownloadFileStreamParams; +pub type RepoDownloadFileToBytesParamsBuilder = RepoDownloadFileStreamParamsBuilder; + +#[derive(Default, TypedBuilder)] +pub struct RepoSnapshotDownloadParams { + /// Git revision to download. Defaults to the main branch. + #[builder(default, setter(into, strip_option))] + pub revision: Option, + /// Glob patterns for files to include in the download. Only matching files are downloaded. + #[builder(default, setter(strip_option))] + pub allow_patterns: Option>, + /// Glob patterns for files to exclude from the download. + #[builder(default, setter(strip_option))] + pub ignore_patterns: Option>, + /// Local directory to download the snapshot into. + #[builder(default, setter(strip_option))] + pub local_dir: Option, + /// If `true`, re-download all files even if cached copies exist. + #[builder(default, setter(strip_option))] + pub force_download: Option, + /// If `true`, only return files already cached locally; never make network requests. + #[builder(default, setter(strip_option))] + pub local_files_only: Option, + /// Maximum number of concurrent file downloads. + #[builder(default, setter(strip_option))] + pub max_workers: Option, +} + +#[derive(TypedBuilder)] +pub struct RepoUploadFileParams { + /// Source of the file content to upload (bytes or file path). + pub source: AddSource, + /// Destination path within the repository. + #[builder(setter(into))] + pub path_in_repo: String, + /// Git revision (branch) to upload to. Defaults to the main branch. + #[builder(default, setter(into, strip_option))] + pub revision: Option, + /// Commit message for the upload. 
+ #[builder(default, setter(into, strip_option))] + pub commit_message: Option, + /// Extended description for the commit. + #[builder(default, setter(into, strip_option))] + pub commit_description: Option, + /// If `true`, create a pull request instead of committing directly. + #[builder(default, setter(strip_option))] + pub create_pr: Option, + /// Expected parent commit SHA. The upload fails if the branch head has moved past this commit. + #[builder(default, setter(into, strip_option))] + pub parent_commit: Option, +} + +#[derive(TypedBuilder)] +pub struct RepoUploadFolderParams { + /// Local folder path to upload. + #[builder(setter(into))] + pub folder_path: PathBuf, + /// Destination directory within the repository. Defaults to the repo root. + #[builder(default, setter(into, strip_option))] + pub path_in_repo: Option, + /// Git revision (branch) to upload to. Defaults to the main branch. + #[builder(default, setter(into, strip_option))] + pub revision: Option, + /// Commit message for the upload. + #[builder(default, setter(into, strip_option))] + pub commit_message: Option, + /// Extended description for the commit. + #[builder(default, setter(into, strip_option))] + pub commit_description: Option, + /// If `true`, create a pull request instead of committing directly. + #[builder(default, setter(strip_option))] + pub create_pr: Option, + /// Glob patterns for files to include from the local folder. + #[builder(default, setter(strip_option))] + pub allow_patterns: Option>, + /// Glob patterns for files to exclude from the local folder. + #[builder(default, setter(strip_option))] + pub ignore_patterns: Option>, + /// Glob patterns for remote files to delete that are not present locally. + #[builder(default, setter(strip_option))] + pub delete_patterns: Option>, +} + +#[derive(TypedBuilder)] +pub struct RepoDeleteFileParams { + /// Path of the file to delete within the repository. 
+ #[builder(setter(into))] + pub path_in_repo: String, + /// Git revision (branch) to delete from. Defaults to the main branch. + #[builder(default, setter(into, strip_option))] + pub revision: Option, + /// Commit message for the deletion. + #[builder(default, setter(into, strip_option))] + pub commit_message: Option, + /// If `true`, create a pull request instead of committing directly. + #[builder(default, setter(strip_option))] + pub create_pr: Option, +} + +#[derive(TypedBuilder)] +pub struct RepoDeleteFolderParams { + /// Path of the folder to delete within the repository. + #[builder(setter(into))] + pub path_in_repo: String, + /// Git revision (branch) to delete from. Defaults to the main branch. + #[builder(default, setter(into, strip_option))] + pub revision: Option, + /// Commit message for the deletion. + #[builder(default, setter(into, strip_option))] + pub commit_message: Option, + /// If `true`, create a pull request instead of committing directly. + #[builder(default, setter(strip_option))] + pub create_pr: Option, +} + +#[derive(TypedBuilder)] +pub struct RepoCreateCommitParams { + /// List of file operations (additions, deletions, copies) to include in the commit. + pub operations: Vec, + /// Commit message. + #[builder(setter(into))] + pub commit_message: String, + /// Extended description for the commit. + #[builder(default, setter(into, strip_option))] + pub commit_description: Option, + /// Git revision (branch) to commit to. Defaults to the main branch. + #[builder(default, setter(into, strip_option))] + pub revision: Option, + /// If `true`, create a pull request instead of committing directly. + #[builder(default, setter(strip_option))] + pub create_pr: Option, + /// Expected parent commit SHA. The commit fails if the branch head has moved past this commit. 
+ #[builder(default, setter(into, strip_option))] + pub parent_commit: Option, +} + +#[derive(Default, TypedBuilder)] +pub struct RepoListCommitsParams { + /// Git revision (branch, tag, or commit SHA) to list commits from. Defaults to the main branch. + #[builder(default, setter(into, strip_option))] + pub revision: Option, + /// Maximum number of commits to return. + #[builder(default, setter(strip_option))] + pub limit: Option, +} + +#[derive(Default, TypedBuilder)] +pub struct RepoListRefsParams { + /// Whether to include pull request refs in the listing. + #[builder(default)] + pub include_pull_requests: bool, +} + +#[derive(TypedBuilder)] +pub struct RepoGetCommitDiffParams { + /// Revision to compare against the parent (branch, tag, or commit SHA). + #[builder(setter(into))] + pub compare: String, +} + +#[derive(TypedBuilder)] +pub struct RepoGetRawDiffParams { + /// Revision to compare against the parent (branch, tag, or commit SHA). + #[builder(setter(into))] + pub compare: String, +} + +#[derive(TypedBuilder)] +pub struct RepoCreateBranchParams { + /// Name of the branch to create. + #[builder(setter(into))] + pub branch: String, + /// Revision to branch from. Defaults to the current main branch head. + #[builder(default, setter(into, strip_option))] + pub revision: Option, +} + +#[derive(TypedBuilder)] +pub struct RepoDeleteBranchParams { + /// Name of the branch to delete. + #[builder(setter(into))] + pub branch: String, +} + +#[derive(TypedBuilder)] +pub struct RepoCreateTagParams { + /// Name of the tag to create. + #[builder(setter(into))] + pub tag: String, + /// Revision to tag. Defaults to the current main branch head. + #[builder(default, setter(into, strip_option))] + pub revision: Option, + /// Annotation message for the tag. + #[builder(default, setter(into, strip_option))] + pub message: Option, +} + +#[derive(TypedBuilder)] +pub struct RepoDeleteTagParams { + /// Name of the tag to delete. 
+ #[builder(setter(into))] + pub tag: String, +} + +#[derive(Default, TypedBuilder, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct RepoUpdateSettingsParams { + /// Whether the repository should be private. + #[builder(default, setter(strip_option))] + #[serde(skip_serializing_if = "Option::is_none")] + pub private: Option, + /// Access-gating mode for the repository (e.g. `auto`, `manual`). + #[builder(default, setter(strip_option))] + #[serde(skip_serializing_if = "Option::is_none")] + pub gated: Option, + /// Repository description shown on the Hub page. + #[builder(default, setter(into, strip_option))] + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + /// Whether discussions are disabled on this repository. + #[builder(default, setter(strip_option))] + #[serde(skip_serializing_if = "Option::is_none")] + pub discussions_disabled: Option, + /// Email address to receive gated-access request notifications. + #[builder(default, setter(into, strip_option))] + #[serde(skip_serializing_if = "Option::is_none")] + pub gated_notifications_email: Option, + /// When to send gated-access notifications (e.g. `each`, `daily`). + #[builder(default, setter(strip_option))] + #[serde(skip_serializing_if = "Option::is_none")] + pub gated_notifications_mode: Option, +} + +#[cfg(feature = "spaces")] +#[derive(TypedBuilder)] +pub struct SpaceHardwareRequestParams { + /// Hardware flavor to request (e.g. `"cpu-basic"`, `"t4-small"`, `"a10g-small"`). + #[builder(setter(into))] + pub hardware: String, + /// Number of seconds of inactivity before the Space is put to sleep. `0` means never sleep. + #[builder(default, setter(strip_option))] + pub sleep_time: Option, +} + +#[cfg(feature = "spaces")] +#[derive(TypedBuilder)] +pub struct SpaceSleepTimeParams { + /// Number of seconds of inactivity before the Space is put to sleep. `0` means never sleep. 
+ pub sleep_time: u64, +} + +#[cfg(feature = "spaces")] +#[derive(TypedBuilder)] +pub struct SpaceSecretParams { + /// Secret key name. + #[builder(setter(into))] + pub key: String, + /// Secret value. + #[builder(setter(into))] + pub value: String, + /// Human-readable description of the secret. + #[builder(default, setter(into, strip_option))] + pub description: Option, +} + +#[cfg(feature = "spaces")] +#[derive(TypedBuilder)] +pub struct SpaceSecretDeleteParams { + /// Secret key name to delete. + #[builder(setter(into))] + pub key: String, +} + +#[cfg(feature = "spaces")] +#[derive(TypedBuilder)] +pub struct SpaceVariableParams { + /// Variable key name. + #[builder(setter(into))] + pub key: String, + /// Variable value. + #[builder(setter(into))] + pub value: String, + /// Human-readable description of the variable. + #[builder(default, setter(into, strip_option))] + pub description: Option, +} + +#[cfg(feature = "spaces")] +#[derive(TypedBuilder)] +pub struct SpaceVariableDeleteParams { + /// Variable key name to delete. 
+ #[builder(setter(into))] + pub key: String, +} diff --git a/huggingface_hub/src/xet.rs b/huggingface_hub/src/xet.rs index 712d5f5..618c986 100644 --- a/huggingface_hub/src/xet.rs +++ b/huggingface_hub/src/xet.rs @@ -7,7 +7,7 @@ use std::path::PathBuf; use serde::Deserialize; -use xet::xet_session::{Sha256Policy, XetFileInfo, XetFileMetadata, XetSession, XetSessionBuilder}; +use xet::xet_session::{Sha256Policy, XetFileInfo, XetFileMetadata}; use crate::client::HFClient; use crate::constants; @@ -23,6 +23,12 @@ struct XetTokenResponse { cas_url: String, } +#[derive(Default)] +pub(crate) struct XetState { + pub(crate) session: Option, + pub(crate) generation: u64, +} + pub struct XetConnectionInfo { pub endpoint: String, pub access_token: String, @@ -66,10 +72,18 @@ fn xet_token_url( format!("{}/api/{}/{}/xet-{}-token/{}", api.endpoint(), segment, repo_id, token_type, revision) } -fn build_xet_session() -> Result { - XetSessionBuilder::new() - .build() - .map_err(|e| HFError::Other(format!("Failed to build xet session: {e}"))) +/// Returns `true` if the error indicates the XetSession is permanently +/// poisoned and must be replaced before retrying. 
+#[cfg(test)] +fn is_session_poisoned(err: &xet::error::XetError) -> bool { + use xet::error::XetError; + matches!( + err, + XetError::UserCancelled(_) + | XetError::AlreadyCompleted + | XetError::PreviousTaskError(_) + | XetError::KeyboardInterrupt + ) } pub(crate) struct XetBatchFile { @@ -94,7 +108,6 @@ impl HFRepository { let file_size: u64 = crate::api::files::extract_file_size(head_response).unwrap_or(0); let conn = fetch_xet_connection_info(&self.hf_client, "read", &repo_path, repo_type, revision).await?; - let session = build_xet_session()?; tokio::fs::create_dir_all(local_dir).await?; let dest_path = local_dir.join(filename); @@ -102,18 +115,27 @@ impl HFRepository { tokio::fs::create_dir_all(parent).await?; } - let group = session - .new_file_download_group() - .map_err(|e| HFError::Other(format!("Xet download failed: {e}")))? - .with_endpoint(conn.endpoint.clone()) - .with_token_info(conn.access_token.clone(), conn.expiration_unix_epoch) - .with_token_refresh_url( - xet_token_url(&self.hf_client, "read", &repo_path, repo_type, revision), - self.hf_client.auth_headers(), - ) - .build() - .await - .map_err(|e| HFError::Other(format!("Xet download failed: {e}")))?; + let (session, generation) = self.hf_client.xet_session()?; + let group = match session.new_file_download_group() { + Ok(b) => b, + Err(e) => { + self.hf_client.replace_xet_session(generation, &e); + self.hf_client + .xet_session()? + .0 + .new_file_download_group() + .map_err(|e| HFError::Other(format!("Xet download failed: {e}")))? 
+ }, + } + .with_endpoint(conn.endpoint.clone()) + .with_token_info(conn.access_token.clone(), conn.expiration_unix_epoch) + .with_token_refresh_url( + xet_token_url(&self.hf_client, "read", &repo_path, repo_type, revision), + self.hf_client.auth_headers(), + ) + .build() + .await + .map_err(|e| HFError::Other(format!("Xet download failed: {e}")))?; let file_info = XetFileInfo::new(file_hash, file_size); @@ -140,7 +162,6 @@ impl HFRepository { let repo_path = self.repo_path(); let repo_type = Some(self.repo_type); let conn = fetch_xet_connection_info(&self.hf_client, "read", &repo_path, repo_type, revision).await?; - let session = build_xet_session()?; if let Some(parent) = path.parent() { tokio::fs::create_dir_all(parent).await?; @@ -148,18 +169,27 @@ impl HFRepository { let incomplete_path = PathBuf::from(format!("{}.incomplete", path.display())); - let group = session - .new_file_download_group() - .map_err(|e| HFError::Other(format!("Xet download failed: {e}")))? - .with_endpoint(conn.endpoint.clone()) - .with_token_info(conn.access_token.clone(), conn.expiration_unix_epoch) - .with_token_refresh_url( - xet_token_url(&self.hf_client, "read", &repo_path, repo_type, revision), - self.hf_client.auth_headers(), - ) - .build() - .await - .map_err(|e| HFError::Other(format!("Xet download failed: {e}")))?; + let (session, generation) = self.hf_client.xet_session()?; + let group = match session.new_file_download_group() { + Ok(b) => b, + Err(e) => { + self.hf_client.replace_xet_session(generation, &e); + self.hf_client + .xet_session()? + .0 + .new_file_download_group() + .map_err(|e| HFError::Other(format!("Xet download failed: {e}")))? 
+ }, + } + .with_endpoint(conn.endpoint.clone()) + .with_token_info(conn.access_token.clone(), conn.expiration_unix_epoch) + .with_token_refresh_url( + xet_token_url(&self.hf_client, "read", &repo_path, repo_type, revision), + self.hf_client.auth_headers(), + ) + .build() + .await + .map_err(|e| HFError::Other(format!("Xet download failed: {e}")))?; let file_info = XetFileInfo::new(file_hash.to_string(), file_size); @@ -185,20 +215,28 @@ impl HFRepository { let repo_path = self.repo_path(); let repo_type = Some(self.repo_type); let conn = fetch_xet_connection_info(&self.hf_client, "read", &repo_path, repo_type, revision).await?; - let session = build_xet_session()?; - - let group = session - .new_file_download_group() - .map_err(|e| HFError::Other(format!("Xet batch download failed: {e}")))? - .with_endpoint(conn.endpoint.clone()) - .with_token_info(conn.access_token.clone(), conn.expiration_unix_epoch) - .with_token_refresh_url( - xet_token_url(&self.hf_client, "read", &repo_path, repo_type, revision), - self.hf_client.auth_headers(), - ) - .build() - .await - .map_err(|e| HFError::Other(format!("Xet batch download failed: {e}")))?; + + let (session, generation) = self.hf_client.xet_session()?; + let group = match session.new_file_download_group() { + Ok(b) => b, + Err(e) => { + self.hf_client.replace_xet_session(generation, &e); + self.hf_client + .xet_session()? + .0 + .new_file_download_group() + .map_err(|e| HFError::Other(format!("Xet batch download failed: {e}")))? 
+ }, + } + .with_endpoint(conn.endpoint.clone()) + .with_token_info(conn.access_token.clone(), conn.expiration_unix_epoch) + .with_token_refresh_url( + xet_token_url(&self.hf_client, "read", &repo_path, repo_type, revision), + self.hf_client.auth_headers(), + ) + .build() + .await + .map_err(|e| HFError::Other(format!("Xet batch download failed: {e}")))?; let mut incomplete_paths = Vec::with_capacity(files.len()); for file in files { @@ -243,20 +281,28 @@ impl HFRepository { let repo_path = self.repo_path(); let repo_type = Some(self.repo_type); let conn = fetch_xet_connection_info(&self.hf_client, "read", &repo_path, repo_type, revision).await?; - let session = build_xet_session()?; - - let group = session - .new_download_stream_group() - .map_err(|e| HFError::Other(format!("Xet stream download failed: {e}")))? - .with_endpoint(conn.endpoint.clone()) - .with_token_info(conn.access_token.clone(), conn.expiration_unix_epoch) - .with_token_refresh_url( - xet_token_url(&self.hf_client, "read", &repo_path, repo_type, revision), - self.hf_client.auth_headers(), - ) - .build() - .await - .map_err(|e| HFError::Other(format!("Xet stream download failed: {e}")))?; + + let (session, generation) = self.hf_client.xet_session()?; + let group = match session.new_download_stream_group() { + Ok(b) => b, + Err(e) => { + self.hf_client.replace_xet_session(generation, &e); + self.hf_client + .xet_session()? + .0 + .new_download_stream_group() + .map_err(|e| HFError::Other(format!("Xet stream download failed: {e}")))? 
+ }, + } + .with_endpoint(conn.endpoint.clone()) + .with_token_info(conn.access_token.clone(), conn.expiration_unix_epoch) + .with_token_refresh_url( + xet_token_url(&self.hf_client, "read", &repo_path, repo_type, revision), + self.hf_client.auth_headers(), + ) + .build() + .await + .map_err(|e| HFError::Other(format!("Xet stream download failed: {e}")))?; let file_info = XetFileInfo::new(file_hash.to_string(), file_size); @@ -285,21 +331,29 @@ impl HFRepository { tracing::info!(repo = repo_path.as_str(), "fetching xet write token"); let conn = fetch_xet_connection_info(&self.hf_client, "write", &repo_path, repo_type, revision).await?; tracing::info!(endpoint = conn.endpoint.as_str(), "xet write token obtained, building session"); - let session = build_xet_session()?; tracing::info!("building xet upload commit"); - let commit = session - .new_upload_commit() - .map_err(|e| HFError::Other(format!("Xet upload failed: {e}")))? - .with_endpoint(conn.endpoint.clone()) - .with_token_info(conn.access_token.clone(), conn.expiration_unix_epoch) - .with_token_refresh_url( - xet_token_url(&self.hf_client, "write", &repo_path, repo_type, revision), - self.hf_client.auth_headers(), - ) - .build() - .await - .map_err(|e| HFError::Other(format!("Xet upload failed: {e}")))?; + let (session, generation) = self.hf_client.xet_session()?; + let commit = match session.new_upload_commit() { + Ok(b) => b, + Err(e) => { + self.hf_client.replace_xet_session(generation, &e); + self.hf_client + .xet_session()? + .0 + .new_upload_commit() + .map_err(|e| HFError::Other(format!("Xet upload failed: {e}")))? 
+ }, + } + .with_endpoint(conn.endpoint.clone()) + .with_token_info(conn.access_token.clone(), conn.expiration_unix_epoch) + .with_token_refresh_url( + xet_token_url(&self.hf_client, "write", &repo_path, repo_type, revision), + self.hf_client.auth_headers(), + ) + .build() + .await + .map_err(|e| HFError::Other(format!("Xet upload failed: {e}")))?; tracing::info!("xet upload commit built, queuing file uploads"); let mut task_ids_in_order = Vec::with_capacity(files.len()); @@ -347,3 +401,47 @@ impl HFClient { fetch_xet_connection_info(self, params.token_type.as_str(), ¶ms.repo_id, params.repo_type, revision).await } } + +#[cfg(test)] +mod tests { + use xet::error::XetError; + + use super::*; + + #[test] + fn test_session_poisoned_positive() { + assert!(is_session_poisoned(&XetError::UserCancelled("test".into()))); + assert!(is_session_poisoned(&XetError::AlreadyCompleted)); + assert!(is_session_poisoned(&XetError::PreviousTaskError("err".into()))); + assert!(is_session_poisoned(&XetError::KeyboardInterrupt)); + } + + #[test] + fn test_session_poisoned_negative() { + let non_poisoned = [ + XetError::Network("timeout".into()), + XetError::Authentication("bad token".into()), + XetError::Io("disk full".into()), + XetError::Internal("bug".into()), + XetError::Timeout("slow".into()), + XetError::NotFound("missing".into()), + XetError::DataIntegrity("corrupt".into()), + XetError::Configuration("bad config".into()), + XetError::Cancelled("cancelled".into()), + XetError::WrongRuntimeMode("wrong mode".into()), + XetError::TaskError("task failed".into()), + ]; + for err in &non_poisoned { + assert!(!is_session_poisoned(err), "{err:?} should NOT be classified as poisoned"); + } + } + + #[test] + fn test_xet_error_message_preserved_in_hferror() { + let xet_err = XetError::Network("connection reset by peer".into()); + let hf_err = HFError::Other(format!("Xet download failed: {xet_err}")); + let msg = hf_err.to_string(); + assert!(msg.contains("Xet download failed"), "missing 
prefix: {msg}"); + assert!(msg.contains("connection reset by peer"), "missing original message: {msg}"); + } +} diff --git a/huggingface_hub/tests/blocking_test.rs b/huggingface_hub/tests/blocking_test.rs index 1fe95b8..905213f 100644 --- a/huggingface_hub/tests/blocking_test.rs +++ b/huggingface_hub/tests/blocking_test.rs @@ -2,70 +2,68 @@ //! Integration tests for the synchronous HFClientSync wrapper. //! -//! These mirror a subset of the async integration tests to verify that the -//! blocking API works correctly end-to-end. +//! Read-only tests use hardcoded prod repos via `prod_sync_api()`. +//! Write tests using hub-ci via `ci_sync_api()` (require HF_TEST_WRITE=1). //! -//! Read-only tests: require HF_TOKEN, skip if not set. -//! Write tests: require HF_TOKEN + HF_TEST_WRITE=1, skip otherwise. -//! -//! Run: HF_TOKEN=hf_xxx cargo test -p huggingface-hub --features blocking --test blocking_test +//! Local: HF_TOKEN=hf_xxx cargo test -p huggingface-hub --features blocking --test blocking_test +//! CI: The workflow sets HF_CI_TOKEN + HF_PROD_TOKEN. 
-use huggingface_hub::repository::{ - RepoCreateBranchParams, RepoCreateCommitParams, RepoDeleteBranchParams, RepoDownloadFileParams, - RepoFileExistsParams, RepoGetCommitDiffParams, RepoGetRawDiffParams, RepoListCommitsParams, RepoListFilesParams, - RepoListRefsParams, RepoListTreeParams, RepoRevisionExistsParams, RepoUploadFileParams, RepoUploadFolderParams, -}; +use huggingface_hub::test_utils::*; use huggingface_hub::types::*; use huggingface_hub::{HFClientBuilder, HFClientSync, RepoInfo, RepoInfoParams}; -fn sync_api() -> Option { - if std::env::var("HF_TOKEN").is_err() { - return None; - } - let api = HFClientBuilder::new().build().expect("Failed to create HFClient"); +fn prod_sync_api() -> Option { + let api = if is_ci() { + let token = resolve_prod_token()?; + HFClientBuilder::new() + .token(token) + .endpoint(PROD_ENDPOINT) + .build() + .expect("Failed to create HFClient") + } else { + if std::env::var(HF_TOKEN).is_err() { + return None; + } + HFClientBuilder::new().build().expect("Failed to create HFClient") + }; Some(HFClientSync::from_api(api).expect("Failed to create HFClientSync")) } -fn write_enabled() -> bool { - std::env::var("HF_TEST_WRITE").ok().is_some_and(|v| v == "1") -} - -fn is_hub_ci() -> bool { - std::env::var("HF_ENDPOINT") - .ok() - .is_some_and(|v| v.contains("hub-ci.huggingface.co")) +fn ci_sync_api() -> Option { + let api = if is_ci() { + let token = std::env::var(HF_CI_TOKEN).ok()?; + HFClientBuilder::new() + .token(token) + .endpoint(HUB_CI_ENDPOINT) + .build() + .expect("Failed to create HFClient") + } else { + if std::env::var(HF_TOKEN).is_err() { + return None; + } + HFClientBuilder::new().build().expect("Failed to create HFClient") + }; + Some(HFClientSync::from_api(api).expect("Failed to create HFClientSync")) } fn test_org() -> &'static str { - if is_hub_ci() { "valid_org" } else { "huggingface" } + "huggingface" } fn test_user() -> &'static str { - if is_hub_ci() { - "huggingface-hub-rust-test-user" - } else { - "julien-c" - 
} + "julien-c" } fn test_model_author() -> &'static str { - if is_hub_ci() { "valid_org" } else { "openai-community" } + "openai-community" } fn test_model_repo() -> &'static str { - if is_hub_ci() { - "huggingface-hub-rust-test-user/gpt2" - } else { - "openai-community/gpt2" - } + "openai-community/gpt2" } fn test_dataset_repo() -> &'static str { - if is_hub_ci() { - "huggingface-hub-rust-test-user/hacker-news" - } else { - "xet-team/xet-spec-reference-files" - } + "xet-team/xet-spec-reference-files" } /// Split a `"owner/name"` string into an `HFRepositorySync` handle. @@ -91,7 +89,7 @@ fn dataset_handle(api: &HFClientSync, repo_id: &str) -> huggingface_hub::blockin #[test] fn test_sync_model_info() { - let Some(api) = sync_api() else { return }; + let Some(api) = prod_sync_api() else { return }; let model_repo = test_model_repo(); let repo = repo_handle(&api, model_repo); let info = repo.info(&RepoInfoParams::default()).unwrap(); @@ -103,7 +101,7 @@ fn test_sync_model_info() { #[test] fn test_sync_dataset_info() { - let Some(api) = sync_api() else { return }; + let Some(api) = prod_sync_api() else { return }; let dataset_repo = test_dataset_repo(); let repo = dataset_handle(&api, dataset_repo); let info = repo.info(&RepoInfoParams::default()).unwrap(); @@ -115,7 +113,7 @@ fn test_sync_dataset_info() { #[test] fn test_sync_repo_handle_info_and_file_exists() { - let Some(api) = sync_api() else { return }; + let Some(api) = prod_sync_api() else { return }; let model_repo = test_model_repo(); let repo = repo_handle(&api, model_repo); @@ -133,14 +131,14 @@ fn test_sync_repo_handle_info_and_file_exists() { #[test] fn test_sync_repo_exists() { - let Some(api) = sync_api() else { return }; + let Some(api) = prod_sync_api() else { return }; assert!(repo_handle(&api, test_model_repo()).exists().unwrap()); assert!(!repo_handle(&api, "this-repo-definitely-does-not-exist-12345").exists().unwrap()); } #[test] fn test_sync_file_exists() { - let Some(api) = sync_api() else { 
return }; + let Some(api) = prod_sync_api() else { return }; let repo = repo_handle(&api, test_model_repo()); assert!( repo.file_exists(&RepoFileExistsParams::builder().filename("config.json").build()) @@ -157,7 +155,7 @@ fn test_sync_file_exists() { #[test] fn test_sync_list_models() { - let Some(api) = sync_api() else { return }; + let Some(api) = prod_sync_api() else { return }; let author = test_model_author(); let params = ListModelsParams::builder().author(author).limit(3_usize).build(); let models = api.list_models(¶ms).unwrap(); @@ -167,7 +165,7 @@ fn test_sync_list_models() { #[test] fn test_sync_list_datasets() { - let Some(api) = sync_api() else { return }; + let Some(api) = prod_sync_api() else { return }; let params = ListDatasetsParams::builder().author(test_org()).limit(3_usize).build(); let datasets = api.list_datasets(¶ms).unwrap(); assert!(!datasets.is_empty()); @@ -175,7 +173,7 @@ fn test_sync_list_datasets() { #[test] fn test_sync_list_repo_files() { - let Some(api) = sync_api() else { return }; + let Some(api) = prod_sync_api() else { return }; let files = repo_handle(&api, test_model_repo()) .list_files(&RepoListFilesParams::default()) .unwrap(); @@ -185,7 +183,7 @@ fn test_sync_list_repo_files() { #[test] fn test_sync_list_repo_tree() { - let Some(api) = sync_api() else { return }; + let Some(api) = prod_sync_api() else { return }; let entries = repo_handle(&api, test_model_repo()) .list_tree(&RepoListTreeParams::default()) .unwrap(); @@ -197,7 +195,7 @@ fn test_sync_list_repo_tree() { #[test] fn test_sync_list_repo_commits() { - let Some(api) = sync_api() else { return }; + let Some(api) = prod_sync_api() else { return }; let commits = repo_handle(&api, test_model_repo()) .list_commits(&RepoListCommitsParams::default()) .unwrap(); @@ -210,7 +208,7 @@ fn test_sync_list_repo_commits() { #[test] fn test_sync_list_repo_refs() { - let Some(api) = sync_api() else { return }; + let Some(api) = prod_sync_api() else { return }; let refs = 
repo_handle(&api, test_model_repo()) .list_refs(&RepoListRefsParams::default()) .unwrap(); @@ -220,7 +218,7 @@ fn test_sync_list_repo_refs() { #[test] fn test_sync_revision_exists() { - let Some(api) = sync_api() else { return }; + let Some(api) = prod_sync_api() else { return }; let repo = repo_handle(&api, test_model_repo()); assert!( repo.revision_exists(&RepoRevisionExistsParams::builder().revision("main").build()) @@ -237,7 +235,7 @@ fn test_sync_revision_exists() { #[test] fn test_sync_download_file() { - let Some(api) = sync_api() else { return }; + let Some(api) = prod_sync_api() else { return }; let dir = tempfile::tempdir().unwrap(); let params = RepoDownloadFileParams::builder() .filename("config.json") @@ -254,20 +252,20 @@ fn test_sync_download_file() { #[test] fn test_sync_whoami() { - let Some(api) = sync_api() else { return }; + let Some(api) = ci_sync_api() else { return }; let user = api.whoami().unwrap(); assert!(!user.username.is_empty()); } #[test] fn test_sync_auth_check() { - let Some(api) = sync_api() else { return }; + let Some(api) = ci_sync_api() else { return }; api.auth_check().unwrap(); } #[test] fn test_sync_get_user_overview() { - let Some(api) = sync_api() else { return }; + let Some(api) = prod_sync_api() else { return }; let username = test_user(); let user = api.get_user_overview(username).unwrap(); assert_eq!(user.username, username); @@ -275,7 +273,7 @@ fn test_sync_get_user_overview() { #[test] fn test_sync_get_organization_overview() { - let Some(api) = sync_api() else { return }; + let Some(api) = prod_sync_api() else { return }; let org_name = test_org(); let org = api.get_organization_overview(org_name).unwrap(); assert_eq!(org.name, org_name); @@ -283,7 +281,7 @@ fn test_sync_get_organization_overview() { #[test] fn test_sync_list_organization_members() { - let Some(api) = sync_api() else { return }; + let Some(api) = prod_sync_api() else { return }; let members = api.list_organization_members(test_org(), None).unwrap(); 
assert!(!members.is_empty()); } @@ -292,7 +290,7 @@ fn test_sync_list_organization_members() { #[test] fn test_sync_get_commit_diff() { - let Some(api) = sync_api() else { return }; + let Some(api) = prod_sync_api() else { return }; let gpt2 = repo_handle(&api, test_model_repo()); let commits = gpt2.list_commits(&RepoListCommitsParams::default()).unwrap(); assert!(commits.len() >= 2); @@ -309,7 +307,7 @@ fn test_sync_get_commit_diff() { #[test] fn test_sync_get_raw_diff_stream() { - let Some(api) = sync_api() else { return }; + let Some(api) = prod_sync_api() else { return }; let gpt2 = repo_handle(&api, test_model_repo()); let commits = gpt2.list_commits(&RepoListCommitsParams::default()).unwrap(); assert!(commits.len() >= 2); @@ -363,7 +361,7 @@ fn delete_test_repo(api: &HFClientSync, repo_id: &str) { #[test] fn test_sync_create_and_delete_repo() { - let Some(api) = sync_api() else { return }; + let Some(api) = ci_sync_api() else { return }; if !write_enabled() { return; } @@ -405,7 +403,7 @@ fn test_sync_create_and_delete_repo() { #[test] fn test_sync_create_commit() { - let Some(api) = sync_api() else { return }; + let Some(api) = ci_sync_api() else { return }; if !write_enabled() { return; } @@ -441,7 +439,7 @@ fn test_sync_create_commit() { #[test] fn test_sync_upload_folder() { - let Some(api) = sync_api() else { return }; + let Some(api) = ci_sync_api() else { return }; if !write_enabled() { return; } @@ -473,7 +471,7 @@ fn test_sync_upload_folder() { #[test] fn test_sync_branch_operations() { - let Some(api) = sync_api() else { return }; + let Some(api) = ci_sync_api() else { return }; if !write_enabled() { return; } diff --git a/huggingface_hub/tests/cache_test.rs b/huggingface_hub/tests/cache_test.rs index f2b473a..260eb03 100644 --- a/huggingface_hub/tests/cache_test.rs +++ b/huggingface_hub/tests/cache_test.rs @@ -9,69 +9,67 @@ use std::path::Path; -use huggingface_hub::repository::{RepoDownloadFileParams, RepoSnapshotDownloadParams}; -use 
huggingface_hub::{HFClient, HFClientBuilder, HFError}; +use huggingface_hub::test_utils::*; +use huggingface_hub::{HFClient, HFClientBuilder, HFError, RepoDownloadFileParams, RepoSnapshotDownloadParams}; use serial_test::serial; fn api() -> Option { - if std::env::var("HF_TOKEN").is_err() { - return None; + if is_ci() { + let token = resolve_prod_token()?; + Some( + HFClientBuilder::new() + .token(token) + .endpoint(PROD_ENDPOINT) + .build() + .expect("Failed to create HFClient"), + ) + } else { + if std::env::var(HF_TOKEN).is_err() { + return None; + } + Some(HFClientBuilder::new().build().expect("Failed to create HFClient")) } - Some(HFClientBuilder::new().build().expect("Failed to create HFClient")) } -fn is_hub_ci() -> bool { - std::env::var("HF_ENDPOINT") - .ok() - .is_some_and(|v| v.contains("hub-ci.huggingface.co")) +fn api_with_cache(cache_dir: &std::path::Path) -> HFClient { + if is_ci() { + let token = std::env::var(HF_PROD_TOKEN).expect("HF_PROD_TOKEN required in CI for prod repo tests"); + HFClientBuilder::new() + .token(token) + .endpoint(PROD_ENDPOINT) + .cache_dir(cache_dir) + .build() + .expect("Failed to create HFClient") + } else { + HFClientBuilder::new() + .cache_dir(cache_dir) + .build() + .expect("Failed to create HFClient") + } } fn test_model_parts() -> (&'static str, &'static str) { - if is_hub_ci() { - ("huggingface-hub-rust-test-user", "gpt2") - } else { - ("", "gpt2") - } + ("", "gpt2") } fn test_model_repo_id() -> &'static str { - if is_hub_ci() { - "huggingface-hub-rust-test-user/gpt2" - } else { - "gpt2" - } + "gpt2" } fn test_dataset_parts() -> (&'static str, &'static str) { - if is_hub_ci() { - ("huggingface-hub-rust-test-user", "hacker-news") - } else { - ("rajpurkar", "squad") - } + ("rajpurkar", "squad") } fn test_dataset_repo_id() -> &'static str { - if is_hub_ci() { - "huggingface-hub-rust-test-user/hacker-news" - } else { - "rajpurkar/squad" - } + "rajpurkar/squad" } fn test_model_cache_fragment() -> &'static str { - if 
is_hub_ci() { - "huggingface-hub-rust-test-user--gpt2" - } else { - "gpt2" - } + "gpt2" } fn test_dataset_cache_fragment() -> &'static str { - if is_hub_ci() { - "datasets--huggingface-hub-rust-test-user--hacker-news" - } else { - "datasets--rajpurkar--squad" - } + "datasets--rajpurkar--squad" } fn find_repo_folder(cache_dir: &Path, name_fragment: &str) -> std::path::PathBuf { @@ -138,7 +136,7 @@ fn list_files_recursive(dir: &Path) -> Vec { async fn test_download_file_to_cache() { let Some(_) = api() else { return }; let cache_dir = tempfile::tempdir().unwrap(); - let api = HFClientBuilder::new().cache_dir(cache_dir.path()).build().unwrap(); + let api = api_with_cache(cache_dir.path()); let path = api .model(test_model_parts().0, test_model_parts().1) @@ -167,7 +165,7 @@ async fn test_download_file_to_cache() { async fn test_download_file_cache_hit() { let Some(_) = api() else { return }; let cache_dir = tempfile::tempdir().unwrap(); - let api = HFClientBuilder::new().cache_dir(cache_dir.path()).build().unwrap(); + let api = api_with_cache(cache_dir.path()); let repo = api.model(test_model_parts().0, test_model_parts().1); let path1 = repo @@ -184,7 +182,7 @@ async fn test_download_file_cache_hit() { #[tokio::test] async fn test_download_file_local_files_only_miss() { let cache_dir = tempfile::tempdir().unwrap(); - let api = HFClientBuilder::new().cache_dir(cache_dir.path()).build().unwrap(); + let api = api_with_cache(cache_dir.path()); let result = api .model(test_model_parts().0, test_model_parts().1) @@ -202,7 +200,7 @@ async fn test_download_file_local_files_only_miss() { async fn test_download_file_local_files_only_hit() { let Some(_) = api() else { return }; let cache_dir = tempfile::tempdir().unwrap(); - let api = HFClientBuilder::new().cache_dir(cache_dir.path()).build().unwrap(); + let api = api_with_cache(cache_dir.path()); let repo = api.model(test_model_parts().0, test_model_parts().1); let path1 = repo @@ -227,7 +225,7 @@ async fn 
test_download_file_local_files_only_hit() { async fn test_download_file_cache_symlink_structure() { let Some(_) = api() else { return }; let cache_dir = tempfile::tempdir().unwrap(); - let api = HFClientBuilder::new().cache_dir(cache_dir.path()).build().unwrap(); + let api = api_with_cache(cache_dir.path()); let path = api .model(test_model_parts().0, test_model_parts().1) @@ -245,7 +243,7 @@ async fn test_download_file_cache_symlink_structure() { async fn test_snapshot_download() { let Some(_) = api() else { return }; let cache_dir = tempfile::tempdir().unwrap(); - let api = HFClientBuilder::new().cache_dir(cache_dir.path()).build().unwrap(); + let api = api_with_cache(cache_dir.path()); let snapshot_dir = api .model(test_model_parts().0, test_model_parts().1) @@ -271,7 +269,7 @@ async fn test_snapshot_download() { async fn test_cache_hit_no_redownload() { let Some(_) = api() else { return }; let cache_dir = tempfile::tempdir().unwrap(); - let api = HFClientBuilder::new().cache_dir(cache_dir.path()).build().unwrap(); + let api = api_with_cache(cache_dir.path()); let repo = api.model(test_model_parts().0, test_model_parts().1); repo.download_file(&RepoDownloadFileParams::builder().filename("config.json").build()) @@ -293,7 +291,7 @@ async fn test_cache_hit_no_redownload() { async fn test_force_download_bypasses_cache() { let Some(_) = api() else { return }; let cache_dir = tempfile::tempdir().unwrap(); - let api = HFClientBuilder::new().cache_dir(cache_dir.path()).build().unwrap(); + let api = api_with_cache(cache_dir.path()); let repo = api.model(test_model_parts().0, test_model_parts().1); repo.download_file(&RepoDownloadFileParams::builder().filename("config.json").build()) @@ -323,7 +321,7 @@ async fn test_force_download_bypasses_cache() { async fn test_force_download_ignores_no_exist() { let Some(_) = api() else { return }; let cache_dir = tempfile::tempdir().unwrap(); - let api = HFClientBuilder::new().cache_dir(cache_dir.path()).build().unwrap(); + let api = 
api_with_cache(cache_dir.path()); // Create a stale .no_exist marker — network download should succeed // regardless since .no_exist is only consulted via resolve_from_cache_only @@ -357,7 +355,7 @@ async fn test_force_download_ignores_no_exist() { async fn test_no_exist_marker_on_404() { let Some(_) = api() else { return }; let cache_dir = tempfile::tempdir().unwrap(); - let api = HFClientBuilder::new().cache_dir(cache_dir.path()).build().unwrap(); + let api = api_with_cache(cache_dir.path()); let result = api .model(test_model_parts().0, test_model_parts().1) @@ -383,7 +381,7 @@ async fn test_no_exist_marker_on_404() { async fn test_no_exist_marker_prevents_request() { let Some(_) = api() else { return }; let cache_dir = tempfile::tempdir().unwrap(); - let api = HFClientBuilder::new().cache_dir(cache_dir.path()).build().unwrap(); + let api = api_with_cache(cache_dir.path()); let repo = api.model(test_model_parts().0, test_model_parts().1); @@ -416,7 +414,7 @@ async fn test_no_exist_marker_prevents_request() { async fn test_no_exist_writes_ref_on_404() { let Some(_) = api() else { return }; let cache_dir = tempfile::tempdir().unwrap(); - let api = HFClientBuilder::new().cache_dir(cache_dir.path()).build().unwrap(); + let api = api_with_cache(cache_dir.path()); let _ = api .model(test_model_parts().0, test_model_parts().1) @@ -441,7 +439,7 @@ async fn test_no_exist_writes_ref_on_404() { async fn test_ref_written_for_branch_download() { let Some(_) = api() else { return }; let cache_dir = tempfile::tempdir().unwrap(); - let api = HFClientBuilder::new().cache_dir(cache_dir.path()).build().unwrap(); + let api = api_with_cache(cache_dir.path()); api.model(test_model_parts().0, test_model_parts().1) .download_file(&RepoDownloadFileParams::builder().filename("config.json").build()) @@ -461,7 +459,7 @@ async fn test_ref_written_for_branch_download() { async fn test_no_ref_for_commit_hash_download() { let Some(_) = api() else { return }; let cache_dir = 
tempfile::tempdir().unwrap(); - let api = HFClientBuilder::new().cache_dir(cache_dir.path()).build().unwrap(); + let api = api_with_cache(cache_dir.path()); // First get the commit hash via a normal download api.model(test_model_parts().0, test_model_parts().1) @@ -477,7 +475,7 @@ async fn test_no_ref_for_commit_hash_download() { // Now download in a fresh cache using the commit hash directly let cache_dir2 = tempfile::tempdir().unwrap(); - let api2 = HFClientBuilder::new().cache_dir(cache_dir2.path()).build().unwrap(); + let api2 = api_with_cache(cache_dir2.path()); api2.model(test_model_parts().0, test_model_parts().1) .download_file( @@ -502,7 +500,7 @@ async fn test_no_ref_for_commit_hash_download() { async fn test_download_by_commit_hash() { let Some(_) = api() else { return }; let cache_dir = tempfile::tempdir().unwrap(); - let api = HFClientBuilder::new().cache_dir(cache_dir.path()).build().unwrap(); + let api = api_with_cache(cache_dir.path()); // Get commit hash from a normal download api.model(test_model_parts().0, test_model_parts().1) @@ -517,7 +515,7 @@ async fn test_download_by_commit_hash() { // Download by commit hash in fresh cache let cache_dir2 = tempfile::tempdir().unwrap(); - let api2 = HFClientBuilder::new().cache_dir(cache_dir2.path()).build().unwrap(); + let api2 = api_with_cache(cache_dir2.path()); let path = api2 .model(test_model_parts().0, test_model_parts().1) .download_file( @@ -545,7 +543,7 @@ async fn test_download_by_commit_hash() { async fn test_offline_fallback_with_cached_file() { let Some(_) = api() else { return }; let cache_dir = tempfile::tempdir().unwrap(); - let api = HFClientBuilder::new().cache_dir(cache_dir.path()).build().unwrap(); + let api = api_with_cache(cache_dir.path()); // Populate cache let original_path = api @@ -597,7 +595,7 @@ async fn test_offline_fallback_without_cache_propagates_error() { async fn test_snapshot_download_ignore_patterns() { let Some(_) = api() else { return }; let cache_dir = 
tempfile::tempdir().unwrap(); - let api = HFClientBuilder::new().cache_dir(cache_dir.path()).build().unwrap(); + let api = api_with_cache(cache_dir.path()); let snapshot_dir = api .model(test_model_parts().0, test_model_parts().1) @@ -619,7 +617,7 @@ async fn test_snapshot_download_ignore_patterns() { #[tokio::test] async fn test_snapshot_download_local_files_only_miss() { let cache_dir = tempfile::tempdir().unwrap(); - let api = HFClientBuilder::new().cache_dir(cache_dir.path()).build().unwrap(); + let api = api_with_cache(cache_dir.path()); let result = api .model(test_model_parts().0, test_model_parts().1) @@ -632,7 +630,7 @@ async fn test_snapshot_download_local_files_only_miss() { async fn test_snapshot_download_local_files_only_hit() { let Some(_) = api() else { return }; let cache_dir = tempfile::tempdir().unwrap(); - let api = HFClientBuilder::new().cache_dir(cache_dir.path()).build().unwrap(); + let api = api_with_cache(cache_dir.path()); let repo = api.model(test_model_parts().0, test_model_parts().1); let dir1 = repo @@ -655,7 +653,7 @@ async fn test_snapshot_download_local_files_only_hit() { async fn test_snapshot_download_by_commit_hash() { let Some(_) = api() else { return }; let cache_dir = tempfile::tempdir().unwrap(); - let api = HFClientBuilder::new().cache_dir(cache_dir.path()).build().unwrap(); + let api = api_with_cache(cache_dir.path()); // First get the commit hash api.model(test_model_parts().0, test_model_parts().1) @@ -670,7 +668,7 @@ async fn test_snapshot_download_by_commit_hash() { // Snapshot download in fresh cache by commit hash let cache_dir2 = tempfile::tempdir().unwrap(); - let api2 = HFClientBuilder::new().cache_dir(cache_dir2.path()).build().unwrap(); + let api2 = api_with_cache(cache_dir2.path()); let snapshot_dir = api2 .model(test_model_parts().0, test_model_parts().1) .snapshot_download( @@ -698,7 +696,7 @@ async fn test_snapshot_download_by_commit_hash() { async fn test_snapshot_download_force_download() { let Some(_) = 
api() else { return }; let cache_dir = tempfile::tempdir().unwrap(); - let api = HFClientBuilder::new().cache_dir(cache_dir.path()).build().unwrap(); + let api = api_with_cache(cache_dir.path()); let repo = api.model(test_model_parts().0, test_model_parts().1); repo.snapshot_download( @@ -731,7 +729,7 @@ async fn test_snapshot_download_force_download() { async fn test_snapshot_download_returns_correct_path() { let Some(_) = api() else { return }; let cache_dir = tempfile::tempdir().unwrap(); - let api = HFClientBuilder::new().cache_dir(cache_dir.path()).build().unwrap(); + let api = api_with_cache(cache_dir.path()); let snapshot_dir = api .model(test_model_parts().0, test_model_parts().1) @@ -757,7 +755,7 @@ async fn test_snapshot_download_returns_correct_path() { async fn test_cache_directory_layout() { let Some(_) = api() else { return }; let cache_dir = tempfile::tempdir().unwrap(); - let api = HFClientBuilder::new().cache_dir(cache_dir.path()).build().unwrap(); + let api = api_with_cache(cache_dir.path()); api.model(test_model_parts().0, test_model_parts().1) .download_file(&RepoDownloadFileParams::builder().filename("config.json").build()) @@ -784,7 +782,7 @@ async fn test_cache_directory_layout() { async fn test_blob_deduplication_across_downloads() { let Some(_) = api() else { return }; let cache_dir = tempfile::tempdir().unwrap(); - let api = HFClientBuilder::new().cache_dir(cache_dir.path()).build().unwrap(); + let api = api_with_cache(cache_dir.path()); let repo = api.model(test_model_parts().0, test_model_parts().1); @@ -813,7 +811,7 @@ async fn test_blob_deduplication_across_downloads() { async fn test_dataset_repo_type_cache_folder() { let Some(_) = api() else { return }; let cache_dir = tempfile::tempdir().unwrap(); - let api = HFClientBuilder::new().cache_dir(cache_dir.path()).build().unwrap(); + let api = api_with_cache(cache_dir.path()); let path = api .dataset(test_dataset_parts().0, test_dataset_parts().1) @@ -831,7 +829,7 @@ async fn 
test_download_to_local_dir_no_cache() { let Some(_) = api() else { return }; let cache_dir = tempfile::tempdir().unwrap(); let local_dir = tempfile::tempdir().unwrap(); - let api = HFClientBuilder::new().cache_dir(cache_dir.path()).build().unwrap(); + let api = api_with_cache(cache_dir.path()); let path = api .model(test_model_parts().0, test_model_parts().1) @@ -864,7 +862,7 @@ async fn test_download_to_local_dir_no_cache() { async fn test_concurrent_downloads_same_file() { let Some(_) = api() else { return }; let cache_dir = tempfile::tempdir().unwrap(); - let api = HFClientBuilder::new().cache_dir(cache_dir.path()).build().unwrap(); + let api = api_with_cache(cache_dir.path()); let mut handles = Vec::new(); for _ in 0..4 { @@ -906,23 +904,23 @@ async fn test_concurrent_downloads_same_file() { fn test_hf_hub_cache_env_var() { let dir = tempfile::tempdir().unwrap(); // Save and set env - let old_val = std::env::var("HF_HUB_CACHE").ok(); + let old_val = std::env::var(HF_HUB_CACHE).ok(); // SAFETY: test runs serially (#[serial]) so no concurrent env access - unsafe { std::env::set_var("HF_HUB_CACHE", dir.path()) }; + unsafe { std::env::set_var(HF_HUB_CACHE, dir.path()) }; let api = HFClientBuilder::new().build().unwrap(); // Verify through a download attempt that would use the cache dir // We can't easily inspect the private field, but we can check the // builder override works by using an explicit cache_dir - let api2 = HFClientBuilder::new().cache_dir(dir.path()).build().unwrap(); + let api2 = api_with_cache(dir.path()); // Both should work without error drop(api); drop(api2); // Restore env match old_val { - Some(v) => unsafe { std::env::set_var("HF_HUB_CACHE", v) }, - None => unsafe { std::env::remove_var("HF_HUB_CACHE") }, + Some(v) => unsafe { std::env::set_var(HF_HUB_CACHE, v) }, + None => unsafe { std::env::remove_var(HF_HUB_CACHE) }, } } @@ -931,15 +929,15 @@ fn test_hf_hub_cache_env_var() { fn test_xdg_cache_home_env_var() { let dir = 
tempfile::tempdir().unwrap(); // Save existing env vars - let old_hub_cache = std::env::var("HF_HUB_CACHE").ok(); - let old_hf_home = std::env::var("HF_HOME").ok(); - let old_xdg = std::env::var("XDG_CACHE_HOME").ok(); + let old_hub_cache = std::env::var(HF_HUB_CACHE).ok(); + let old_hf_home = std::env::var(HF_HOME).ok(); + let old_xdg = std::env::var(XDG_CACHE_HOME).ok(); // SAFETY: test runs serially (#[serial]) so no concurrent env access unsafe { - std::env::remove_var("HF_HUB_CACHE"); - std::env::remove_var("HF_HOME"); - std::env::set_var("XDG_CACHE_HOME", dir.path()); + std::env::remove_var(HF_HUB_CACHE); + std::env::remove_var(HF_HOME); + std::env::set_var(XDG_CACHE_HOME, dir.path()); } let api = HFClientBuilder::new().build().unwrap(); @@ -949,16 +947,16 @@ fn test_xdg_cache_home_env_var() { // SAFETY: test runs serially (#[serial]) so no concurrent env access unsafe { match old_hub_cache { - Some(v) => std::env::set_var("HF_HUB_CACHE", v), - None => std::env::remove_var("HF_HUB_CACHE"), + Some(v) => std::env::set_var(HF_HUB_CACHE, v), + None => std::env::remove_var(HF_HUB_CACHE), } match old_hf_home { - Some(v) => std::env::set_var("HF_HOME", v), - None => std::env::remove_var("HF_HOME"), + Some(v) => std::env::set_var(HF_HOME, v), + None => std::env::remove_var(HF_HOME), } match old_xdg { - Some(v) => std::env::set_var("XDG_CACHE_HOME", v), - None => std::env::remove_var("XDG_CACHE_HOME"), + Some(v) => std::env::set_var(XDG_CACHE_HOME, v), + None => std::env::remove_var(XDG_CACHE_HOME), } } } @@ -1019,7 +1017,7 @@ async fn test_interop_python_downloads_first() { return; }; let python = python_bin(&venv_dir); - let token = std::env::var("HF_TOKEN").unwrap(); + let token = resolve_prod_token().expect("HF_TOKEN or HF_PROD_TOKEN required"); let script = format!( r#" @@ -1044,7 +1042,7 @@ print(path) .unwrap(); let blob_count_before = std::fs::read_dir(repo_folder.path().join("blobs")).unwrap().count(); - let api = 
HFClientBuilder::new().cache_dir(&cache_dir).build().unwrap(); + let api = api_with_cache(&cache_dir); let path = api .model(test_model_parts().0, test_model_parts().1) .download_file(&RepoDownloadFileParams::builder().filename("config.json").build()) @@ -1067,9 +1065,9 @@ async fn test_interop_rust_downloads_first() { return; }; let python = python_bin(&venv_dir); - let token = std::env::var("HF_TOKEN").unwrap(); + let token = resolve_prod_token().expect("HF_TOKEN or HF_PROD_TOKEN required"); - let api = HFClientBuilder::new().cache_dir(&cache_dir).build().unwrap(); + let api = api_with_cache(&cache_dir); api.model(test_model_parts().0, test_model_parts().1) .download_file(&RepoDownloadFileParams::builder().filename("config.json").build()) .await @@ -1107,7 +1105,7 @@ async fn test_interop_mixed_partial_downloads() { return; }; let python = python_bin(&venv_dir); - let token = std::env::var("HF_TOKEN").unwrap(); + let token = resolve_prod_token().expect("HF_TOKEN or HF_PROD_TOKEN required"); let script = format!( r#" @@ -1124,7 +1122,7 @@ hf_hub_download("{repo_id}", "README.md") let output = std::process::Command::new(&python).args(["-c", &script]).output().unwrap(); assert!(output.status.success()); - let api = HFClientBuilder::new().cache_dir(&cache_dir).build().unwrap(); + let api = api_with_cache(&cache_dir); let repo = api.model(test_model_parts().0, test_model_parts().1); repo.download_file(&RepoDownloadFileParams::builder().filename("config.json").build()) .await @@ -1173,7 +1171,7 @@ async fn test_interop_python_snapshot_rust_snapshot() { return; }; let python = python_bin(&venv_dir); - let token = std::env::var("HF_TOKEN").unwrap(); + let token = resolve_prod_token().expect("HF_TOKEN or HF_PROD_TOKEN required"); let script = format!( r#" @@ -1202,7 +1200,7 @@ print(path) .unwrap(); let blob_count_before = std::fs::read_dir(repo_folder.path().join("blobs")).unwrap().count(); - let api = HFClientBuilder::new().cache_dir(&cache_dir).build().unwrap(); + let 
api = api_with_cache(&cache_dir); let snapshot_dir = api .model(test_model_parts().0, test_model_parts().1) .snapshot_download( @@ -1229,10 +1227,10 @@ async fn test_interop_rust_writes_python_validates_cache() { return; }; let python = python_bin(&venv_dir); - let token = std::env::var("HF_TOKEN").unwrap(); + let token = resolve_prod_token().expect("HF_TOKEN or HF_PROD_TOKEN required"); // Rust snapshot_download: multiple files into cache - let api = HFClientBuilder::new().cache_dir(&cache_dir).build().unwrap(); + let api = api_with_cache(&cache_dir); let snapshot_dir = api .model(test_model_parts().0, test_model_parts().1) .snapshot_download( @@ -1329,7 +1327,7 @@ print("ALL_CHECKS_PASSED") async fn test_xet_download_to_cache() { let Some(_) = api() else { return }; let cache_dir = tempfile::tempdir().unwrap(); - let api = HFClientBuilder::new().cache_dir(cache_dir.path()).build().unwrap(); + let api = api_with_cache(cache_dir.path()); let path = match api .model("mcpotato", "42-xet-test-repo") @@ -1406,7 +1404,7 @@ async fn test_xet_download_to_cache() { async fn test_xet_snapshot_download_to_cache() { let Some(_) = api() else { return }; let cache_dir = tempfile::tempdir().unwrap(); - let api = HFClientBuilder::new().cache_dir(cache_dir.path()).build().unwrap(); + let api = api_with_cache(cache_dir.path()); let snapshot_dir = match api .model("mcpotato", "42-xet-test-repo") @@ -1440,7 +1438,7 @@ async fn test_xet_snapshot_download_to_cache() { async fn test_xet_cache_hit_second_download() { let Some(_) = api() else { return }; let cache_dir = tempfile::tempdir().unwrap(); - let api = HFClientBuilder::new().cache_dir(cache_dir.path()).build().unwrap(); + let api = api_with_cache(cache_dir.path()); let repo = api.model("mcpotato", "42-xet-test-repo"); @@ -1488,10 +1486,10 @@ async fn test_interop_rust_no_exist_python_reads() { return; }; let python = python_bin(&venv_dir); - let token = std::env::var("HF_TOKEN").unwrap(); + let token = 
resolve_prod_token().expect("HF_TOKEN or HF_PROD_TOKEN required"); // Rust: trigger a 404 to create a .no_exist marker - let api = HFClientBuilder::new().cache_dir(&cache_dir).build().unwrap(); + let api = api_with_cache(&cache_dir); let _ = api .model(test_model_parts().0, test_model_parts().1) .download_file( @@ -1543,10 +1541,10 @@ async fn test_interop_rust_ref_python_reads() { return; }; let python = python_bin(&venv_dir); - let token = std::env::var("HF_TOKEN").unwrap(); + let token = resolve_prod_token().expect("HF_TOKEN or HF_PROD_TOKEN required"); // Rust: download to create refs/main - let api = HFClientBuilder::new().cache_dir(&cache_dir).build().unwrap(); + let api = api_with_cache(&cache_dir); api.model(test_model_parts().0, test_model_parts().1) .download_file(&RepoDownloadFileParams::builder().filename("config.json").build()) .await @@ -1585,7 +1583,7 @@ async fn test_interop_python_no_exist_rust_reads() { return; }; let python = python_bin(&venv_dir); - let token = std::env::var("HF_TOKEN").unwrap(); + let token = resolve_prod_token().expect("HF_TOKEN or HF_PROD_TOKEN required"); // Python: trigger 404 to create .no_exist marker let script = format!( @@ -1608,7 +1606,7 @@ print("DONE") assert!(output.status.success(), "Python failed: {}", String::from_utf8_lossy(&output.stderr)); // Rust: local_files_only should find the .no_exist marker via resolve_from_cache_only - let api = HFClientBuilder::new().cache_dir(&cache_dir).build().unwrap(); + let api = api_with_cache(&cache_dir); let result = api .model(test_model_parts().0, test_model_parts().1) .download_file( @@ -1635,7 +1633,7 @@ async fn test_interop_python_ref_rust_local_files_only() { return; }; let python = python_bin(&venv_dir); - let token = std::env::var("HF_TOKEN").unwrap(); + let token = resolve_prod_token().expect("HF_TOKEN or HF_PROD_TOKEN required"); // Python downloads file (creates refs/main + blob + symlink) let script = format!( @@ -1654,7 +1652,7 @@ hf_hub_download("{repo_id}", 
"config.json") assert!(output.status.success(), "Python failed: {}", String::from_utf8_lossy(&output.stderr)); // Rust: local_files_only should find the file via Python's ref - let api = HFClientBuilder::new().cache_dir(&cache_dir).build().unwrap(); + let api = api_with_cache(&cache_dir); let path = api .model(test_model_parts().0, test_model_parts().1) .download_file( @@ -1681,10 +1679,10 @@ async fn test_interop_rust_snapshot_python_snapshot_reuse() { return; }; let python = python_bin(&venv_dir); - let token = std::env::var("HF_TOKEN").unwrap(); + let token = resolve_prod_token().expect("HF_TOKEN or HF_PROD_TOKEN required"); // Rust snapshot_download first - let api = HFClientBuilder::new().cache_dir(&cache_dir).build().unwrap(); + let api = api_with_cache(&cache_dir); api.model(test_model_parts().0, test_model_parts().1) .snapshot_download( &RepoSnapshotDownloadParams::builder() @@ -1729,10 +1727,10 @@ async fn test_interop_dataset_repo_type() { return; }; let python = python_bin(&venv_dir); - let token = std::env::var("HF_TOKEN").unwrap(); + let token = resolve_prod_token().expect("HF_TOKEN or HF_PROD_TOKEN required"); // Rust downloads a dataset file - let api = HFClientBuilder::new().cache_dir(&cache_dir).build().unwrap(); + let api = api_with_cache(&cache_dir); api.dataset(test_dataset_parts().0, test_dataset_parts().1) .download_file(&RepoDownloadFileParams::builder().filename("README.md").build()) .await @@ -1773,7 +1771,7 @@ async fn test_interop_symlink_target_format() { return; }; let python = python_bin(&venv_dir); - let token = std::env::var("HF_TOKEN").unwrap(); + let token = resolve_prod_token().expect("HF_TOKEN or HF_PROD_TOKEN required"); // Python downloads file — creates the canonical symlink format let script = format!( @@ -1797,7 +1795,7 @@ print(link) // Rust downloads same file into a fresh cache let cache_dir2 = base_dir.path().join("cache2"); std::fs::create_dir_all(&cache_dir2).unwrap(); - let api = 
HFClientBuilder::new().cache_dir(&cache_dir2).build().unwrap(); + let api = api_with_cache(&cache_dir2); let rust_path = api .model(test_model_parts().0, test_model_parts().1) .download_file(&RepoDownloadFileParams::builder().filename("config.json").build()) @@ -1820,7 +1818,7 @@ async fn test_interop_conditional_request_reuse() { return; }; let python = python_bin(&venv_dir); - let token = std::env::var("HF_TOKEN").unwrap(); + let token = resolve_prod_token().expect("HF_TOKEN or HF_PROD_TOKEN required"); // Python downloads file — creates blob + symlink + ref let script = format!( @@ -1851,7 +1849,7 @@ hf_hub_download("{repo_id}", "config.json") // Rust downloads same file — should read etag from Python's symlink, // send If-None-Match, get 304, and NOT rewrite the blob - let api = HFClientBuilder::new().cache_dir(&cache_dir).build().unwrap(); + let api = api_with_cache(&cache_dir); api.model(test_model_parts().0, test_model_parts().1) .download_file(&RepoDownloadFileParams::builder().filename("config.json").build()) .await @@ -1872,10 +1870,10 @@ async fn test_interop_scan_cache_counts_match() { return; }; let python = python_bin(&venv_dir); - let token = std::env::var("HF_TOKEN").unwrap(); + let token = resolve_prod_token().expect("HF_TOKEN or HF_PROD_TOKEN required"); // Download multiple files via both libraries to populate the cache - let api = HFClientBuilder::new().cache_dir(&cache_dir).build().unwrap(); + let api = api_with_cache(&cache_dir); api.model(test_model_parts().0, test_model_parts().1) .snapshot_download( &RepoSnapshotDownloadParams::builder() diff --git a/huggingface_hub/tests/cli/cli_comparison.rs b/huggingface_hub/tests/cli/cli_comparison.rs index 4cb8595..e69c89b 100644 --- a/huggingface_hub/tests/cli/cli_comparison.rs +++ b/huggingface_hub/tests/cli/cli_comparison.rs @@ -4,61 +4,35 @@ use std::sync::OnceLock; use helpers::{CliRunner, require_cli, require_token, require_write}; -fn is_hub_ci() -> bool { - std::env::var("HF_ENDPOINT") - .ok() 
- .is_some_and(|v| v.contains("hub-ci.huggingface.co")) -} - fn test_model_repo() -> &'static str { - if is_hub_ci() { - "huggingface-hub-rust-test-user/gpt2" - } else { - "gpt2" - } + "gpt2" } fn test_dataset_repo() -> &'static str { - if is_hub_ci() { - "huggingface-hub-rust-test-user/hacker-news" - } else { - "squad" - } + "squad" } fn test_dataset_download_repo() -> &'static str { - if is_hub_ci() { - "huggingface-hub-rust-test-user/hacker-news" - } else { - "xet-team/xet-spec-reference-files" - } + "xet-team/xet-spec-reference-files" } fn test_model_cache_fragment() -> &'static str { - if is_hub_ci() { - "huggingface-hub-rust-test-user--gpt2" - } else { - "gpt2" - } + "gpt2" } fn test_dataset_search() -> &'static str { - if is_hub_ci() { "hacker-news" } else { "squad" } + "squad" } fn test_hf_endpoint() -> &'static str { - if is_hub_ci() { - "https://hub-ci.huggingface.co" - } else { - "https://huggingface.co" - } + "https://huggingface.co" } /// Cached whoami username, fetched once and reused across all tests. 
fn whoami_username() -> &'static str { static USERNAME: OnceLock = OnceLock::new(); USERNAME.get_or_init(|| { - let hfrs = CliRunner::hfrs(); + let hfrs = CliRunner::hfrs_ci(); let out = hfrs .run_json(&["auth", "whoami"]) .expect("whoami should succeed for test setup"); @@ -595,7 +569,7 @@ fn download_default_cache_dir_not_used_when_overridden() { fn write_repo_create_and_delete() { require_token(); require_write(); - let hfrs = CliRunner::hfrs(); + let hfrs = CliRunner::hfrs_ci(); let repo_name = format!( "hfrs-test-{}", @@ -620,7 +594,7 @@ fn write_repo_create_and_delete() { fn write_repo_create_private() { require_token(); require_write(); - let hfrs = CliRunner::hfrs(); + let hfrs = CliRunner::hfrs_ci(); let repo_name = format!( "hfrs-test-private-{}", @@ -646,7 +620,7 @@ fn write_repo_create_private() { fn write_branch_create_and_delete() { require_token(); require_write(); - let hfrs = CliRunner::hfrs(); + let hfrs = CliRunner::hfrs_ci(); let repo_name = format!( "hfrs-test-branch-{}", @@ -674,7 +648,7 @@ fn write_branch_create_and_delete() { fn write_tag_create_and_delete() { require_token(); require_write(); - let hfrs = CliRunner::hfrs(); + let hfrs = CliRunner::hfrs_ci(); let repo_name = format!( "hfrs-test-tag-{}", @@ -1504,7 +1478,7 @@ fn download_no_repo_id_fails() { fn write_upload_single_file() { require_token(); require_write(); - let hfrs = CliRunner::hfrs(); + let hfrs = CliRunner::hfrs_ci(); let repo_name = unique_repo_name("hfrs-upload-file"); let full_repo = full_repo(&repo_name); @@ -1527,7 +1501,7 @@ fn write_upload_single_file() { fn write_upload_auto_create() { require_token(); require_write(); - let hfrs = CliRunner::hfrs(); + let hfrs = CliRunner::hfrs_ci(); let repo_name = unique_repo_name("hfrs-upload-autocreate"); let full_repo = full_repo(&repo_name); @@ -1549,7 +1523,7 @@ fn write_upload_auto_create() { fn write_upload_private_auto_create() { require_token(); require_write(); - let hfrs = CliRunner::hfrs(); + let hfrs = 
CliRunner::hfrs_ci(); let repo_name = unique_repo_name("hfrs-upload-private"); let full_repo = full_repo(&repo_name); @@ -1571,7 +1545,7 @@ fn write_upload_private_auto_create() { fn write_upload_path_in_repo() { require_token(); require_write(); - let hfrs = CliRunner::hfrs(); + let hfrs = CliRunner::hfrs_ci(); let repo_name = unique_repo_name("hfrs-upload-path"); let full_repo = full_repo(&repo_name); @@ -1603,7 +1577,7 @@ fn write_upload_path_in_repo() { fn write_upload_commit_message_and_description() { require_token(); require_write(); - let hfrs = CliRunner::hfrs(); + let hfrs = CliRunner::hfrs_ci(); let repo_name = unique_repo_name("hfrs-upload-commit"); let full_repo = full_repo(&repo_name); @@ -1632,7 +1606,7 @@ fn write_upload_commit_message_and_description() { fn write_upload_create_pr() { require_token(); require_write(); - let hfrs = CliRunner::hfrs(); + let hfrs = CliRunner::hfrs_ci(); let repo_name = unique_repo_name("hfrs-upload-pr"); let full_repo = full_repo(&repo_name); @@ -1661,7 +1635,7 @@ fn write_upload_create_pr() { fn write_upload_to_branch() { require_token(); require_write(); - let hfrs = CliRunner::hfrs(); + let hfrs = CliRunner::hfrs_ci(); let repo_name = unique_repo_name("hfrs-upload-branch"); let full_repo = full_repo(&repo_name); @@ -1690,7 +1664,7 @@ fn write_upload_to_branch() { fn write_upload_quiet() { require_token(); require_write(); - let hfrs = CliRunner::hfrs(); + let hfrs = CliRunner::hfrs_ci(); let repo_name = unique_repo_name("hfrs-upload-quiet"); let full_repo = full_repo(&repo_name); @@ -1714,7 +1688,7 @@ fn write_upload_quiet() { fn write_upload_nonexistent_path_fails() { require_token(); require_write(); - let hfrs = CliRunner::hfrs(); + let hfrs = CliRunner::hfrs_ci(); let repo_name = unique_repo_name("hfrs-upload-nopath"); let full_repo = full_repo(&repo_name); @@ -1734,7 +1708,7 @@ fn write_upload_nonexistent_path_fails() { fn write_upload_folder() { require_token(); require_write(); - let hfrs = CliRunner::hfrs(); 
+ let hfrs = CliRunner::hfrs_ci(); let repo_name = unique_repo_name("hfrs-upload-folder"); let full_repo = full_repo(&repo_name); @@ -1773,7 +1747,7 @@ fn write_upload_folder() { fn write_upload_folder_include() { require_token(); require_write(); - let hfrs = CliRunner::hfrs(); + let hfrs = CliRunner::hfrs_ci(); let repo_name = unique_repo_name("hfrs-upload-include"); let full_repo = full_repo(&repo_name); @@ -1814,7 +1788,7 @@ fn write_upload_folder_include() { fn write_upload_folder_exclude() { require_token(); require_write(); - let hfrs = CliRunner::hfrs(); + let hfrs = CliRunner::hfrs_ci(); let repo_name = unique_repo_name("hfrs-upload-exclude"); let full_repo = full_repo(&repo_name); @@ -1855,7 +1829,7 @@ fn write_upload_folder_exclude() { fn write_upload_folder_delete() { require_token(); require_write(); - let hfrs = CliRunner::hfrs(); + let hfrs = CliRunner::hfrs_ci(); let repo_name = unique_repo_name("hfrs-upload-delete"); let full_repo = full_repo(&repo_name); @@ -1896,7 +1870,7 @@ fn write_upload_folder_delete() { fn write_upload_folder_path_in_repo() { require_token(); require_write(); - let hfrs = CliRunner::hfrs(); + let hfrs = CliRunner::hfrs_ci(); let repo_name = unique_repo_name("hfrs-upload-folder-path"); let full_repo = full_repo(&repo_name); @@ -1928,7 +1902,7 @@ fn write_upload_folder_path_in_repo() { fn write_upload_empty_excluded_folder() { require_token(); require_write(); - let hfrs = CliRunner::hfrs(); + let hfrs = CliRunner::hfrs_ci(); let repo_name = unique_repo_name("hfrs-upload-empty"); let full_repo = full_repo(&repo_name); @@ -1970,7 +1944,7 @@ fn upload_no_repo_id_fails() { fn write_upload_dataset_type() { require_token(); require_write(); - let hfrs = CliRunner::hfrs(); + let hfrs = CliRunner::hfrs_ci(); let repo_name = unique_repo_name("hfrs-upload-dataset"); let full_repo = full_repo(&repo_name); @@ -1992,7 +1966,7 @@ fn write_upload_dataset_type() { fn write_upload_large_file() { require_token(); require_write(); - let hfrs = 
CliRunner::hfrs(); + let hfrs = CliRunner::hfrs_ci(); let repo_name = unique_repo_name("hfrs-upload-large"); let full_repo = full_repo(&repo_name); @@ -2002,8 +1976,9 @@ fn write_upload_large_file() { let tmp = tempfile::tempdir().unwrap(); let file_path = tmp.path().join("large.bin"); - // Create 11MB file (above typical LFS threshold of 10MB) - let data = vec![0u8; 11 * 1024 * 1024]; + // Create 11MB random file (above typical LFS threshold of 10MB) + let mut data = vec![0u8; 11 * 1024 * 1024]; + rand::Fill::fill(&mut data[..], &mut rand::rng()); std::fs::write(&file_path, &data).unwrap(); let result = hfrs.run_raw(&["upload", &full_repo, file_path.to_str().unwrap()]); @@ -2016,7 +1991,7 @@ fn write_upload_large_file() { fn write_upload_special_chars() { require_token(); require_write(); - let hfrs = CliRunner::hfrs(); + let hfrs = CliRunner::hfrs_ci(); let repo_name = unique_repo_name("hfrs-upload-special"); let full_repo = full_repo(&repo_name); @@ -2156,3 +2131,152 @@ fn exit_codes() { let (code, _, _) = hfrs.run_full(&["models", "info", "nonexistent-xyz-12345"]).unwrap(); assert_ne!(code, 0, "failed command should exit non-zero"); } + +// ============================================================================= +// Signal handling tests (xet upload/download abort via SIGINT) +// ============================================================================= + +#[cfg(unix)] +#[test] +fn signal_abort_during_xet_upload() { + use std::time::{Duration, Instant}; + + require_token(); + require_write(); + let hfrs = CliRunner::hfrs_ci(); + + let repo_name = unique_repo_name("hfrs-signal-upload"); + let full_repo = full_repo(&repo_name); + + hfrs.run_raw(&["repos", "create", &repo_name]).expect("repo creation"); + + let tmp = tempfile::tempdir().unwrap(); + let file_path = tmp.path().join("large_signal_test.bin"); + + // 50MB random file — large enough that upload takes a few seconds + let mut data = vec![0u8; 50 * 1024 * 1024]; + rand::Fill::fill(&mut data[..], 
&mut rand::rng()); + std::fs::write(&file_path, &data).unwrap(); + drop(data); + + let mut child = hfrs + .spawn(&["upload", &full_repo, file_path.to_str().unwrap()]) + .expect("failed to spawn upload"); + + let pid = child.id(); + + std::thread::sleep(Duration::from_millis(500)); + + // Check if process already finished before we could signal it + if let Ok(Some(_status)) = child.try_wait() { + eprintln!("upload finished before SIGINT could be sent — skipping abort assertion"); + let _ = hfrs.run_raw(&["repos", "delete", &full_repo]); + return; + } + + unsafe { + libc::kill(pid as libc::pid_t, libc::SIGINT); + } + + let start = Instant::now(); + let timeout = Duration::from_secs(30); + let status = loop { + match child.try_wait() { + Ok(Some(status)) => break status, + Ok(None) => { + if start.elapsed() > timeout { + let _ = child.kill(); + let _ = child.wait(); + panic!("CLI did not exit within {timeout:?} after SIGINT"); + } + std::thread::sleep(Duration::from_millis(100)); + }, + Err(e) => panic!("error waiting for child: {e}"), + } + }; + + assert!(!status.success(), "CLI should exit non-zero after SIGINT, got: {status}"); + + let _ = hfrs.run_raw(&["repos", "delete", &full_repo]); +} + +#[cfg(unix)] +#[test] +fn signal_abort_during_xet_download() { + use std::time::{Duration, Instant}; + + require_token(); + require_write(); + let hfrs = CliRunner::hfrs_ci(); + + // First: create a repo with a large xet file + let repo_name = unique_repo_name("hfrs-signal-download"); + let full_repo = full_repo(&repo_name); + + hfrs.run_raw(&["repos", "create", &repo_name]).expect("repo creation"); + + let tmp_upload = tempfile::tempdir().unwrap(); + let upload_path = tmp_upload.path().join("large_for_download.bin"); + + let mut data = vec![0u8; 50 * 1024 * 1024]; + rand::Fill::fill(&mut data[..], &mut rand::rng()); + std::fs::write(&upload_path, &data).unwrap(); + drop(data); + + let upload_result = hfrs.run_raw(&["upload", &full_repo, upload_path.to_str().unwrap()]); + if 
upload_result.is_err() { + let _ = hfrs.run_raw(&["repos", "delete", &full_repo]); + panic!("upload failed, cannot test download abort: {:?}", upload_result.err()); + } + drop(tmp_upload); + + // Now download and send SIGINT mid-transfer. + // Use a short delay so the process is still transferring when the signal + // arrives. If the download completes before we can signal, skip the + // assertion — we cannot test the abort path on very fast networks. + let tmp_download = tempfile::tempdir().unwrap(); + let mut child = hfrs + .spawn(&[ + "download", + &full_repo, + "--local-dir", + tmp_download.path().to_str().unwrap(), + ]) + .expect("failed to spawn download"); + + let pid = child.id(); + + std::thread::sleep(Duration::from_millis(500)); + + // Check if process already finished before we could signal it + if let Ok(Some(_status)) = child.try_wait() { + eprintln!("download finished before SIGINT could be sent — skipping abort assertion"); + let _ = hfrs.run_raw(&["repos", "delete", &full_repo]); + return; + } + + unsafe { + libc::kill(pid as libc::pid_t, libc::SIGINT); + } + + let start = Instant::now(); + let timeout = Duration::from_secs(30); + let status = loop { + match child.try_wait() { + Ok(Some(status)) => break status, + Ok(None) => { + if start.elapsed() > timeout { + let _ = child.kill(); + let _ = child.wait(); + panic!("CLI did not exit within {timeout:?} after SIGINT"); + } + std::thread::sleep(Duration::from_millis(100)); + }, + Err(e) => panic!("error waiting for child: {e}"), + } + }; + + assert!(!status.success(), "CLI should exit non-zero after SIGINT during download, got: {status}"); + + let _ = hfrs.run_raw(&["repos", "delete", &full_repo]); +} diff --git a/huggingface_hub/tests/cli/helpers.rs b/huggingface_hub/tests/cli/helpers.rs index 5c42051..c808b7e 100644 --- a/huggingface_hub/tests/cli/helpers.rs +++ b/huggingface_hub/tests/cli/helpers.rs @@ -1,6 +1,8 @@ use std::process::Command; use std::time::Duration; +use 
huggingface_hub::test_utils; + pub struct CliRunner { bin: String, bin_path: Option, @@ -20,15 +22,47 @@ impl CliRunner { } } + /// Default runner — targets production (hardcoded repos). + /// In CI: uses HF_PROD_TOKEN and overrides HF_ENDPOINT to huggingface.co. + /// Locally: uses HF_TOKEN with default endpoint. pub fn hfrs() -> Self { + let is_ci = test_utils::is_ci(); + let token = if is_ci { + std::env::var(test_utils::HF_PROD_TOKEN).ok() + } else { + std::env::var(test_utils::HF_TOKEN).ok() + }; + let mut extra_env = vec![ + ("RUST_LOG".to_string(), "info".to_string()), + ("HF_LOG_LEVEL".to_string(), "info".to_string()), + ]; + if is_ci { + extra_env.push((test_utils::HF_ENDPOINT.to_string(), test_utils::PROD_ENDPOINT.to_string())); + } Self { bin: "hfrs".to_string(), bin_path: Some(env!("CARGO_BIN_EXE_hfrs").to_string()), - token: std::env::var("HF_TOKEN").ok(), - extra_env: vec![ - ("RUST_LOG".to_string(), "info".to_string()), - ("HF_LOG_LEVEL".to_string(), "info".to_string()), - ], + token, + extra_env, + env_remove: Vec::new(), + } + } + + /// Runner for write tests (hub-ci in CI, default endpoint locally). + pub fn hfrs_ci() -> Self { + let token = test_utils::resolve_hub_ci_token(); + let mut extra_env = vec![ + ("RUST_LOG".to_string(), "info".to_string()), + ("HF_LOG_LEVEL".to_string(), "info".to_string()), + ]; + if test_utils::is_ci() { + extra_env.push((test_utils::HF_ENDPOINT.to_string(), test_utils::HUB_CI_ENDPOINT.to_string())); + } + Self { + bin: "hfrs".to_string(), + bin_path: Some(env!("CARGO_BIN_EXE_hfrs").to_string()), + token, + extra_env, env_remove: Vec::new(), } } @@ -166,6 +200,17 @@ impl CliRunner { let stderr = String::from_utf8(output.stderr)?; Ok((code, stdout, stderr)) } + + /// Spawn the command as a child process, returning the handle. + /// The caller is responsible for waiting/killing. 
+ pub fn spawn(&self, args: &[&str]) -> anyhow::Result { + let mut cmd = self.build_command(args, &[]); + let child = cmd + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::piped()) + .spawn()?; + Ok(child) + } } pub fn require_cli(runner: &CliRunner) { @@ -175,13 +220,13 @@ pub fn require_cli(runner: &CliRunner) { } pub fn require_token() { - if std::env::var("HF_TOKEN").is_err() { - panic!("HF_TOKEN environment variable is required for integration tests."); + if std::env::var(test_utils::HF_TOKEN).is_err() && std::env::var(test_utils::HF_CI_TOKEN).is_err() { + panic!("HF_TOKEN or HF_CI_TOKEN environment variable is required for integration tests."); } } pub fn require_write() { - if std::env::var("HF_TEST_WRITE").is_err() { + if !test_utils::write_enabled() { panic!("HF_TEST_WRITE=1 is required for write operation tests."); } } diff --git a/huggingface_hub/tests/download_test.rs b/huggingface_hub/tests/download_test.rs index 40a8e3e..3d84e45 100644 --- a/huggingface_hub/tests/download_test.rs +++ b/huggingface_hub/tests/download_test.rs @@ -7,36 +7,34 @@ use futures::StreamExt; use huggingface_hub::repository::HFRepository; +use huggingface_hub::test_utils::*; use huggingface_hub::{HFClient, HFClientBuilder, RepoDownloadFileParams, RepoDownloadFileStreamParams}; use sha2::{Digest, Sha256}; fn api() -> Option { - if std::env::var("HF_TOKEN").is_err() { - return None; + if is_ci() { + let token = resolve_prod_token()?; + Some( + HFClientBuilder::new() + .token(token) + .endpoint(PROD_ENDPOINT) + .build() + .expect("Failed to create HFClient"), + ) + } else { + if std::env::var(HF_TOKEN).is_err() { + return None; + } + Some(HFClientBuilder::new().build().expect("Failed to create HFClient")) } - Some(HFClientBuilder::new().build().expect("Failed to create HFClient")) -} - -fn is_hub_ci() -> bool { - std::env::var("HF_ENDPOINT") - .ok() - .is_some_and(|v| v.contains("hub-ci.huggingface.co")) } fn test_model_parts() -> (&'static str, &'static str) { 
- if is_hub_ci() { - ("huggingface-hub-rust-test-user", "gpt2") - } else { - ("openai-community", "gpt2") - } + ("openai-community", "gpt2") } fn test_dataset_parts() -> (&'static str, &'static str) { - if is_hub_ci() { - ("huggingface-hub-rust-test-user", "hacker-news") - } else { - ("rajpurkar", "squad") - } + ("rajpurkar", "squad") } fn model(api: &HFClient, owner: &str, name: &str) -> HFRepository { diff --git a/huggingface_hub/tests/integration_test.rs b/huggingface_hub/tests/integration_test.rs index d3daa6e..e9f28b8 100644 --- a/huggingface_hub/tests/integration_test.rs +++ b/huggingface_hub/tests/integration_test.rs @@ -1,98 +1,95 @@ //! Integration tests against the live Hugging Face Hub API. //! +//! ## Local development +//! //! Read-only tests: require HF_TOKEN, skip if not set. //! Write tests: require HF_TOKEN + HF_TEST_WRITE=1, skip otherwise. //! //! Run read-only: HF_TOKEN=hf_xxx cargo test -p huggingface-hub --test integration_test //! Run all: HF_TOKEN=hf_xxx HF_TEST_WRITE=1 cargo test -p huggingface-hub --test integration_test //! +//! ## CI (GITHUB_ACTIONS=true) +//! +//! Read-only tests use HF_PROD_TOKEN against https://huggingface.co. +//! Write tests use HF_CI_TOKEN against https://hub-ci.huggingface.co. +//! //! Feature-gated tests: enable with --features, e.g.: //! 
HF_TOKEN=hf_xxx cargo test -p huggingface-hub --all-features --test integration_test use futures::StreamExt; -use huggingface_hub::repository::{ - HFRepository, RepoCreateBranchParams, RepoCreateCommitParams, RepoCreateTagParams, RepoDeleteBranchParams, - RepoDeleteFileParams, RepoDeleteFolderParams, RepoDeleteTagParams, RepoDownloadFileParams, RepoFileExistsParams, - RepoGetCommitDiffParams, RepoGetPathsInfoParams, RepoGetRawDiffParams, RepoInfoParams, RepoListCommitsParams, - RepoListFilesParams, RepoListRefsParams, RepoListTreeParams, RepoRevisionExistsParams, RepoUpdateSettingsParams, - RepoUploadFileParams, RepoUploadFolderParams, -}; +use huggingface_hub::repository::HFRepository; +use huggingface_hub::test_utils::*; use huggingface_hub::types::*; use huggingface_hub::{HFClient, HFClientBuilder}; #[cfg(feature = "spaces")] use huggingface_hub::{SpaceSecretDeleteParams, SpaceSecretParams, SpaceVariableDeleteParams, SpaceVariableParams}; fn api() -> Option { - if std::env::var("HF_TOKEN").is_err() { - return None; + if is_ci() { + let token = std::env::var(HF_CI_TOKEN).ok()?; + Some(build_client(&token, HUB_CI_ENDPOINT)) + } else { + default_api() + } +} + +fn prod_api() -> Option { + if is_ci() { + let token = resolve_prod_token()?; + Some(build_client(&token, PROD_ENDPOINT)) + } else { + default_api() } - Some(HFClientBuilder::new().build().expect("Failed to create HFClient")) } -fn write_enabled() -> bool { - std::env::var("HF_TEST_WRITE").ok().is_some_and(|v| v == "1") +fn default_api() -> Option { + let token = std::env::var(HF_TOKEN).ok()?; + let endpoint = std::env::var(HF_ENDPOINT).unwrap_or_else(|_| PROD_ENDPOINT.to_string()); + Some(build_client(&token, &endpoint)) } -fn is_hub_ci() -> bool { - std::env::var("HF_ENDPOINT") - .ok() - .is_some_and(|v| v.contains("hub-ci.huggingface.co")) +fn build_client(token: &str, endpoint: &str) -> HFClient { + HFClientBuilder::new() + .token(token) + .endpoint(endpoint) + .build() + .expect("Failed to create 
HFClient") } fn test_org() -> &'static str { - if is_hub_ci() { "valid_org" } else { "huggingface" } + "huggingface" } fn test_user() -> &'static str { - if is_hub_ci() { - "huggingface-hub-rust-test-user" - } else { - "julien-c" - } + "julien-c" } fn test_model_author() -> &'static str { - if is_hub_ci() { "valid_org" } else { "openai-community" } + "openai-community" } fn test_model_repo() -> &'static str { - if is_hub_ci() { - "huggingface-hub-rust-test-user/gpt2" - } else { - "openai-community/gpt2" - } + "openai-community/gpt2" } fn test_space_repo() -> (&'static str, &'static str) { - if is_hub_ci() { - ("huggingface-hub-rust-test-user", "test-space") - } else { - ("huggingface-projects", "diffusers-gallery") - } + ("huggingface-projects", "diffusers-gallery") } fn test_space_info_repo() -> &'static str { - if is_hub_ci() { - "huggingface-hub-rust-test-user/test-space" - } else { - "HuggingFaceFW/blogpost-fineweb-v1" - } + "HuggingFaceFW/blogpost-fineweb-v1" } fn test_dataset_repo() -> &'static str { - if is_hub_ci() { - "huggingface-hub-rust-test-user/hacker-news" - } else { - "xet-team/xet-spec-reference-files" - } + "xet-team/xet-spec-reference-files" } -/// Cached whoami username, fetched once and reused across all tests. +/// Cached whoami username, fetched once and reused across write tests. 
async fn cached_username() -> &'static str { static USERNAME: tokio::sync::OnceCell = tokio::sync::OnceCell::const_new(); USERNAME .get_or_init(|| async { - let client = HFClientBuilder::new().build().expect("Failed to create HFClient for whoami"); + let client = api().expect("API client required for cached_username"); client.whoami().await.expect("whoami failed").username }) .await @@ -121,7 +118,7 @@ fn repo_typed(api: &HFClient, repo_id: &str, repo_type: RepoType) -> HFRepositor #[tokio::test] async fn test_model_info() { - let Some(api) = api() else { return }; + let Some(api) = prod_api() else { return }; let model_repo = test_model_repo(); let info = repo(&api, model_repo).info(&RepoInfoParams::default()).await.unwrap(); match info { @@ -132,7 +129,7 @@ async fn test_model_info() { #[tokio::test] async fn test_repo_handle_info_and_file_exists() { - let Some(api) = api() else { return }; + let Some(api) = prod_api() else { return }; let model_repo = test_model_repo(); let repo = repo(&api, model_repo); @@ -151,7 +148,7 @@ async fn test_repo_handle_info_and_file_exists() { #[tokio::test] async fn test_dataset_info() { - let Some(api) = api() else { return }; + let Some(api) = prod_api() else { return }; let dataset_repo = test_dataset_repo(); let info = repo_typed(&api, dataset_repo, RepoType::Dataset) .info(&RepoInfoParams::default()) @@ -165,14 +162,14 @@ async fn test_dataset_info() { #[tokio::test] async fn test_repo_exists() { - let Some(api) = api() else { return }; + let Some(api) = prod_api() else { return }; assert!(repo(&api, test_model_repo()).exists().await.unwrap()); assert!(!repo(&api, "this-repo-definitely-does-not-exist-12345").exists().await.unwrap()); } #[tokio::test] async fn test_file_exists() { - let Some(api) = api() else { return }; + let Some(api) = prod_api() else { return }; let model_repo = test_model_repo(); assert!( repo(&api, model_repo) @@ -191,7 +188,7 @@ async fn test_file_exists() { #[tokio::test] async fn test_list_models() { 
- let Some(api) = api() else { return }; + let Some(api) = prod_api() else { return }; let author = test_model_author(); let params = ListModelsParams::builder().author(author).limit(3_usize).build(); let stream = api.list_models(&params).unwrap(); @@ -208,7 +205,7 @@ async fn test_list_models() { #[tokio::test] async fn test_list_repo_files() { - let Some(api) = api() else { return }; + let Some(api) = prod_api() else { return }; let files = repo(&api, test_model_repo()) .list_files(&RepoListFilesParams::default()) .await @@ -219,7 +216,7 @@ async fn test_list_repo_files() { #[tokio::test] async fn test_list_repo_tree() { - let Some(api) = api() else { return }; + let Some(api) = prod_api() else { return }; let r = repo(&api, test_model_repo()); let stream = r.list_tree(&RepoListTreeParams::default()).unwrap(); futures::pin_mut!(stream); @@ -227,11 +224,11 @@ async fn test_list_repo_tree() { let mut found_config = false; while let Some(entry) = stream.next().await { let entry = entry.unwrap(); - if let RepoTreeEntry::File { path, .. } = &entry { - if path == "config.json" { - found_config = true; - break; - } + if let RepoTreeEntry::File { path, .. 
} = &entry + && path == "config.json" + { + found_config = true; + break; } } assert!(found_config); @@ -239,7 +236,7 @@ async fn test_list_repo_tree() { #[tokio::test] async fn test_list_repo_commits() { - let Some(api) = api() else { return }; + let Some(api) = prod_api() else { return }; let r = repo(&api, test_model_repo()); let stream = r.list_commits(&RepoListCommitsParams::default()).unwrap(); futures::pin_mut!(stream); @@ -251,7 +248,7 @@ async fn test_list_repo_commits() { #[tokio::test] async fn test_list_repo_refs() { - let Some(api) = api() else { return }; + let Some(api) = prod_api() else { return }; let refs = repo(&api, test_model_repo()) .list_refs(&RepoListRefsParams::default()) .await @@ -263,7 +260,7 @@ async fn test_list_repo_refs() { #[tokio::test] async fn test_revision_exists() { - let Some(api) = api() else { return }; + let Some(api) = prod_api() else { return }; let model_repo = test_model_repo(); assert!( repo(&api, model_repo) @@ -282,7 +279,7 @@ async fn test_revision_exists() { #[tokio::test] async fn test_download_file() { - let Some(api) = api() else { return }; + let Some(api) = prod_api() else { return }; let dir = tempfile::tempdir().unwrap(); let path = repo(&api, test_model_repo()) .download_file( @@ -303,20 +300,20 @@ async fn test_download_file() { #[tokio::test] async fn test_whoami() { - let Some(api) = api() else { return }; + let Some(api) = prod_api() else { return }; let user = api.whoami().await.unwrap(); assert!(!user.username.is_empty()); } #[tokio::test] async fn test_auth_check() { - let Some(api) = api() else { return }; + let Some(api) = prod_api() else { return }; api.auth_check().await.unwrap(); } #[tokio::test] async fn test_get_user_overview() { - let Some(api) = api() else { return }; + let Some(api) = prod_api() else { return }; let username = test_user(); let user = api.get_user_overview(username).await.unwrap(); assert_eq!(user.username, username); @@ -324,7 +321,7 @@ async fn test_get_user_overview() { 
#[tokio::test] async fn test_get_organization_overview() { - let Some(api) = api() else { return }; + let Some(api) = prod_api() else { return }; let org_name = test_org(); let org = api.get_organization_overview(org_name).await.unwrap(); assert_eq!(org.name, org_name); @@ -332,7 +329,7 @@ async fn test_get_organization_overview() { #[tokio::test] async fn test_list_user_followers() { - let Some(api) = api() else { return }; + let Some(api) = prod_api() else { return }; let stream = api.list_user_followers(test_user(), None).unwrap(); futures::pin_mut!(stream); let first = stream.next().await; @@ -342,7 +339,7 @@ async fn test_list_user_followers() { #[tokio::test] async fn test_list_user_following() { - let Some(api) = api() else { return }; + let Some(api) = prod_api() else { return }; let stream = api.list_user_following(test_user(), None).unwrap(); futures::pin_mut!(stream); let first = stream.next().await; @@ -352,7 +349,7 @@ async fn test_list_user_following() { #[tokio::test] async fn test_list_organization_members() { - let Some(api) = api() else { return }; + let Some(api) = prod_api() else { return }; let stream = api.list_organization_members(test_org(), None).unwrap(); futures::pin_mut!(stream); let first = stream.next().await; @@ -364,7 +361,7 @@ async fn test_list_organization_members() { #[tokio::test] async fn test_space_info() { - let Some(api) = api() else { return }; + let Some(api) = prod_api() else { return }; let space_repo = test_space_info_repo(); let info = repo_typed(&api, space_repo, RepoType::Space) .info(&RepoInfoParams::default()) @@ -378,7 +375,7 @@ async fn test_space_info() { #[tokio::test] async fn test_list_datasets() { - let Some(api) = api() else { return }; + let Some(api) = prod_api() else { return }; let params = ListDatasetsParams::builder().author(test_org()).limit(3_usize).build(); let stream = api.list_datasets(&params).unwrap(); futures::pin_mut!(stream); @@ -393,7 +390,7 @@ async fn test_list_datasets() { #[tokio::test] 
async fn test_list_spaces() { - let Some(api) = api() else { return }; + let Some(api) = prod_api() else { return }; let params = ListSpacesParams::builder().author(test_org()).limit(3_usize).build(); let stream = api.list_spaces(&params).unwrap(); futures::pin_mut!(stream); @@ -410,7 +407,7 @@ async fn test_list_spaces() { #[tokio::test] async fn test_get_paths_info() { - let Some(api) = api() else { return }; + let Some(api) = prod_api() else { return }; let entries = repo(&api, test_model_repo()) .get_paths_info( &RepoGetPathsInfoParams::builder() @@ -435,7 +432,7 @@ async fn test_get_paths_info() { #[tokio::test] async fn test_get_commit_diff() { - let Some(api) = api() else { return }; + let Some(api) = prod_api() else { return }; let gpt2 = repo(&api, test_model_repo()); let stream = gpt2.list_commits(&RepoListCommitsParams::default()).unwrap(); @@ -457,7 +454,7 @@ async fn test_get_commit_diff() { #[tokio::test] async fn test_get_raw_diff() { - let Some(api) = api() else { return }; + let Some(api) = prod_api() else { return }; let gpt2 = repo(&api, test_model_repo()); let stream = gpt2.list_commits(&RepoListCommitsParams::default()).unwrap(); @@ -825,7 +822,7 @@ async fn test_move_repo() { #[cfg(feature = "spaces")] #[tokio::test] async fn test_get_space_runtime() { - let Some(api) = api() else { return }; + let Some(api) = prod_api() else { return }; let (owner, name) = test_space_repo(); let space = api.space(owner, name); let runtime = space.runtime().await.unwrap(); @@ -835,11 +832,14 @@ async fn test_get_space_runtime() { #[cfg(feature = "spaces")] #[tokio::test] async fn test_duplicate_space() { - let Some(api) = api() else { return }; + let Some(api) = prod_api() else { return }; if !write_enabled() { return; } - let username = cached_username().await; + // Must use prod_api because the source space only exists on production. 
+ // Cannot reuse cached_username() here — it resolves via api() which targets + // hub-ci in CI, a different user than the prod token. + let username = api.whoami().await.expect("whoami failed").username; let to_id = format!("{}/hub-rust-test-dup-space-{}", username, uuid_v4_short()); let params = DuplicateSpaceParams::builder() diff --git a/huggingface_hub/tests/xet_transfer_test.rs b/huggingface_hub/tests/xet_transfer_test.rs index 165fe5a..ca35bfb 100644 --- a/huggingface_hub/tests/xet_transfer_test.rs +++ b/huggingface_hub/tests/xet_transfer_test.rs @@ -19,6 +19,7 @@ use std::sync::atomic::{AtomicU32, Ordering}; use std::time::{SystemTime, UNIX_EPOCH}; use futures::StreamExt; +use huggingface_hub::test_utils::*; use huggingface_hub::types::{AddSource, CreateRepoParams, DeleteRepoParams}; use huggingface_hub::{ HFClient, HFClientBuilder, HFRepository, RepoDownloadFileParams, RepoDownloadFileStreamParams, @@ -30,15 +31,38 @@ use tokio::sync::OnceCell; static WHOAMI_USERNAME: OnceCell = OnceCell::const_new(); +/// Client for write tests — hub-ci in CI, default endpoint locally. fn api() -> Option { - if std::env::var("HF_TOKEN").is_err() { - return None; + if is_ci() { + let token = std::env::var(HF_CI_TOKEN).ok()?; + Some( + HFClientBuilder::new() + .token(token) + .endpoint(HUB_CI_ENDPOINT) + .build() + .expect("Failed to create HFClient"), + ) + } else { + let token = std::env::var(HF_TOKEN).ok()?; + Some(HFClientBuilder::new().token(token).build().expect("Failed to create HFClient")) } - Some(HFClientBuilder::new().build().expect("Failed to create HFClient")) } -fn write_enabled() -> bool { - std::env::var("HF_TEST_WRITE").ok().is_some_and(|v| v == "1") +/// Client for read-only tests against hardcoded production repos. 
+fn prod_api() -> Option { + if is_ci() { + let token = resolve_prod_token()?; + Some( + HFClientBuilder::new() + .token(token) + .endpoint(PROD_ENDPOINT) + .build() + .expect("Failed to create HFClient"), + ) + } else { + let token = std::env::var(HF_TOKEN).ok()?; + Some(HFClientBuilder::new().token(token).build().expect("Failed to create HFClient")) + } } static COUNTER: AtomicU32 = AtomicU32::new(0); @@ -295,7 +319,7 @@ async fn test_upload_from_file_path() { #[tokio::test] async fn test_download_from_known_xet_repo() { - let Some(api) = api() else { return }; + let Some(api) = prod_api() else { return }; let dir = tempfile::tempdir().unwrap(); let result = repo_handle(&api, "mcpotato", "42-xet-test-repo") @@ -382,7 +406,7 @@ async fn test_upload_200mb_random_data_and_verify() { #[tokio::test] async fn test_xet_download_stream_full() { - let Some(api) = api() else { return }; + let Some(api) = prod_api() else { return }; let repo = repo_handle(&api, "mcpotato", "42-xet-test-repo"); @@ -415,7 +439,7 @@ async fn test_xet_download_stream_full() { #[tokio::test] async fn test_xet_download_stream_range() { - let Some(api) = api() else { return }; + let Some(api) = prod_api() else { return }; let repo = repo_handle(&api, "mcpotato", "42-xet-test-repo"); @@ -452,7 +476,7 @@ async fn test_xet_download_stream_range() { #[tokio::test] async fn test_xet_download_stream_range_middle() { - let Some(api) = api() else { return }; + let Some(api) = prod_api() else { return }; let repo = repo_handle(&api, "mcpotato", "42-xet-test-repo");