diff --git a/huggingface_hub/src/api/files.rs b/huggingface_hub/src/api/files.rs index 09d27aa..a757239 100644 --- a/huggingface_hub/src/api/files.rs +++ b/huggingface_hub/src/api/files.rs @@ -171,8 +171,9 @@ impl HfApi { // Determine which files should be uploaded via xet (LFS) vs inline // (regular). Files uploaded via xet are referenced by their SHA256 OID // in the commit NDJSON. - let lfs_uploaded: HashMap = - self.preupload_and_upload_lfs_files(params, revision).await?; + let lfs_uploaded: HashMap = self + .preupload_and_upload_lfs_files(params, revision, params.progress_callback.as_ref()) + .await?; let mut ndjson_lines: Vec> = Vec::new(); @@ -187,6 +188,11 @@ impl HfApi { ndjson_lines.push(serde_json::to_vec(&header_line)?); for op in ¶ms.operations { + let path_in_repo = match op { + CommitOperation::Add { path_in_repo, .. } => path_in_repo, + CommitOperation::Delete { path_in_repo } => path_in_repo, + }; + let is_lfs = lfs_uploaded.contains_key(path_in_repo); let line = match op { CommitOperation::Add { path_in_repo, source } => { if let Some((oid, size)) = lfs_uploaded.get(path_in_repo) { @@ -211,6 +217,13 @@ impl HfApi { }, }; ndjson_lines.push(serde_json::to_vec(&line)?); + + // Call progress callback for non-LFS files (LFS files already triggered callback during upload) + if !is_lfs { + if let Some(ref callback) = params.progress_callback { + callback(path_in_repo); + } + } } let body: Vec = ndjson_lines @@ -448,6 +461,7 @@ impl HfApi { &self, params: &CreateCommitParams, revision: &str, + progress_callback: Option<&crate::types::CommitProgressCallback>, ) -> Result> { let add_ops: Vec<(&String, &AddSource)> = params .operations @@ -497,12 +511,13 @@ impl HfApi { // LFS files require xet upload — fail if the feature is not enabled #[cfg(not(feature = "xet"))] { - let _ = lfs_files; + let _ = (lfs_files, progress_callback); Err(HfError::XetNotEnabled) } #[cfg(feature = "xet")] - self.upload_lfs_files_via_xet(params, revision, &lfs_files).await + self.upload_lfs_files_via_xet(params, revision, &lfs_files, progress_callback) + .await } /// Call the Hub preupload endpoint to determine upload mode per file. @@ -558,6 +573,7 @@ impl HfApi { params: &CreateCommitParams, revision: &str, lfs_files: &[&(String, u64, Vec, &AddSource)], + progress_callback: Option<&crate::types::CommitProgressCallback>, ) -> Result> { // Step 4: Compute SHA256 for LFS files let mut lfs_with_sha: Vec<(String, u64, String, &AddSource)> = Vec::new(); @@ -583,7 +599,8 @@ impl HfApi { .map(|(path, _, _, source)| (path.clone(), (*source).clone())) .collect(); - crate::xet::xet_upload(self, &xet_files, ¶ms.repo_id, params.repo_type, revision).await?; + crate::xet::xet_upload(self, &xet_files, ¶ms.repo_id, params.repo_type, revision, progress_callback) + .await?; let result: HashMap = lfs_with_sha .into_iter() diff --git a/huggingface_hub/src/types/params.rs b/huggingface_hub/src/types/params.rs index 4441c3c..755ba0b 100644 --- a/huggingface_hub/src/types/params.rs +++ b/huggingface_hub/src/types/params.rs @@ -1,10 +1,15 @@ use std::path::PathBuf; +use std::sync::Arc; use typed_builder::TypedBuilder; use super::commit::{AddSource, CommitOperation}; use super::repo::RepoType; +/// Callback function invoked after each operation is processed during a commit. +/// The argument is the path of the file that was just processed. +pub type CommitProgressCallback = Arc; + #[derive(TypedBuilder)] pub struct ModelInfoParams { #[builder(setter(into))] @@ -304,6 +309,8 @@ pub struct CreateCommitParams { pub create_pr: Option, #[builder(default, setter(into, strip_option))] pub parent_commit: Option, + #[builder(default, setter(strip_option))] + pub progress_callback: Option, } #[derive(TypedBuilder)] diff --git a/huggingface_hub/src/xet.rs b/huggingface_hub/src/xet.rs index a91f507..1b73bb0 100644 --- a/huggingface_hub/src/xet.rs +++ b/huggingface_hub/src/xet.rs @@ -148,6 +148,7 @@ pub(crate) async fn xet_upload( repo_id: &str, repo_type: Option, revision: &str, + progress_callback: Option<&crate::types::CommitProgressCallback>, ) -> Result> { let session = api.get_or_init_xet_session("write", repo_id, repo_type, revision).await?; @@ -158,7 +159,7 @@ pub(crate) async fn xet_upload( let mut task_ids_in_order = Vec::with_capacity(files.len()); - for (_path_in_repo, source) in files { + for (path_in_repo, source) in files { let handle = match source { AddSource::File(path) => commit .upload_from_path(path.clone(), Sha256Policy::Compute) @@ -170,6 +171,10 @@ pub(crate) async fn xet_upload( .map_err(|e| HfError::Other(format!("Xet upload failed: {e}")))?, }; task_ids_in_order.push(handle.task_id); + + if let Some(callback) = progress_callback { + callback(path_in_repo); + } } let results = commit