diff --git a/huggingface_hub/src/api/buckets.rs b/huggingface_hub/src/api/buckets.rs new file mode 100644 index 0000000..ed722a8 --- /dev/null +++ b/huggingface_hub/src/api/buckets.rs @@ -0,0 +1,586 @@ +use std::collections::VecDeque; + +use futures::Stream; +use reqwest::header::CONTENT_TYPE; +use url::Url; + +use crate::buckets::HFBucket; +use crate::error::{HFError, NotFoundContext}; +use crate::pagination::parse_link_header_next; +use crate::types::{ + BatchOp, BatchResult, BucketCreated, BucketOverview, CreateBucketParams, ListTreeParams, ResolvedFile, TreeEntry, + UpdateBucketParams, XetToken, +}; +use crate::{HFClient, Result}; + +/// Maps HTTP status codes to `HFError` variants for bucket API responses. +/// Bucket-level 404s map to `BucketNotFound`; file-level 404s map to `EntryNotFound`. +pub(crate) async fn check_bucket_response( + response: reqwest::Response, + bucket_name: &str, + not_found_ctx: NotFoundContext, +) -> Result { + if response.status().is_success() { + return Ok(response); + } + let status = response.status(); + let url = response.url().to_string(); + let body = response.text().await.unwrap_or_default(); + Err(match status.as_u16() { + 401 => HFError::AuthRequired, + 403 => HFError::Forbidden, + 404 => match not_found_ctx { + NotFoundContext::Bucket => HFError::BucketNotFound { + bucket_name: bucket_name.to_string(), + }, + NotFoundContext::Entry { path } => HFError::EntryNotFound { + path, + repo_id: bucket_name.to_string(), + }, + _ => HFError::Http { status, url, body }, + }, + 409 => HFError::Conflict(body), + 429 => HFError::RateLimited, + _ => HFError::Http { status, url, body }, + }) +} + +impl HFBucket { + fn bucket_name(&self) -> String { + format!("{}/{}", self.namespace, self.bucket) + } + + fn bucket_url(&self) -> String { + format!("{}/api/buckets/{}/{}", self.client.inner.endpoint, self.namespace, self.bucket) + } + + /// Returns metadata about this bucket. 
+ pub async fn info(&self) -> Result { + let resp = self + .client + .inner + .client + .get(self.bucket_url()) + .headers(self.client.auth_headers()) + .send() + .await?; + let resp = check_bucket_response(resp, &self.bucket_name(), NotFoundContext::Bucket).await?; + Ok(resp.json().await?) + } + + /// Updates visibility or CDN configuration for this bucket. + pub async fn update_settings(&self, params: &UpdateBucketParams) -> Result<()> { + let resp = self + .client + .inner + .client + .put(format!("{}/settings", self.bucket_url())) + .headers(self.client.auth_headers()) + .json(params) + .send() + .await?; + check_bucket_response(resp, &self.bucket_name(), NotFoundContext::Bucket).await?; + Ok(()) + } + + /// Adds and/or removes files in a single atomic operation. + /// + /// All `AddFile` operations are sent before `DeleteFile` operations, as required + /// by the batch protocol. The input order within each group is preserved. + pub async fn batch_files(&self, ops: &[BatchOp]) -> Result { + let (adds, deletes): (Vec<_>, Vec<_>) = ops.iter().partition(|op| matches!(op, BatchOp::AddFile(_))); + + let ndjson = adds + .iter() + .chain(deletes.iter()) + .map(|op| serde_json::to_string(op).map(|s| s + "\n")) + .collect::>()?; + + let resp = self + .client + .inner + .client + .post(format!("{}/batch", self.bucket_url())) + .headers(self.client.auth_headers()) + .header(CONTENT_TYPE, "application/x-ndjson") + .body(ndjson) + .send() + .await?; + + let resp = check_bucket_response(resp, &self.bucket_name(), NotFoundContext::Bucket).await?; + Ok(resp.json().await?) + } + + /// Lists files and directories, yielding one entry at a time. + /// + /// Uses cursor-in-body pagination: the stream fetches the next page automatically + /// when the current page's entries are exhausted. No request is made until the + /// first item is polled. 
+    pub fn list_tree(&self, path: &str, params: &ListTreeParams) -> Result<impl Stream<Item = Result<TreeEntry>> + '_> {
+        let base_url = if path.is_empty() {
+            format!("{}/api/buckets/{}/{}/tree", self.client.inner.endpoint, self.namespace, self.bucket)
+        } else {
+            format!("{}/api/buckets/{}/{}/tree/{}", self.client.inner.endpoint, self.namespace, self.bucket, path)
+        };
+        let bucket_name = self.bucket_name();
+        let mut initial_url = Url::parse(&base_url)?;
+        {
+            let mut qp = initial_url.query_pairs_mut();
+            if let Some(l) = params.limit {
+                qp.append_pair("limit", l.to_string().as_str());
+            }
+            if params.recursive {
+                qp.append_pair("recursive", "true");
+            }
+            qp.finish();
+        }
+
+        Ok(futures::stream::try_unfold(
+            (VecDeque::<TreeEntry>::new(), Some(initial_url), false),
+            move |(mut pending, mut next_url, mut fetched)| {
+                let client = self.client.clone();
+                let bucket_name = bucket_name.clone();
+                async move {
+                    // Drain the already-fetched page before touching the network.
+                    if let Some(entry) = pending.pop_front() {
+                        return Ok(Some((entry, (pending, next_url, fetched))));
+                    }
+                    // Follow `next` cursors until an entry arrives or the server
+                    // stops advertising further pages. A page may legitimately be
+                    // empty while still carrying a next cursor; returning early in
+                    // that case would silently truncate the listing.
+                    loop {
+                        let url = match next_url.take() {
+                            Some(url) => url,
+                            None if fetched => return Ok(None),
+                            // The initial URL is always set before the first poll;
+                            // reaching here without having fetched is a logic error.
+                            None => return Err(HFError::Other("Initial list Url not set".to_string())),
+                        };
+                        let req = client.inner.client.get(url).headers(client.auth_headers());
+                        let resp = req.send().await?;
+                        let resp = check_bucket_response(resp, &bucket_name, NotFoundContext::Bucket).await?;
+                        // Capture the cursor before `json()` consumes the response.
+                        next_url = parse_link_header_next(resp.headers());
+                        let entries: Vec<TreeEntry> = resp.json().await?;
+                        fetched = true;
+
+                        pending.extend(entries);
+                        if let Some(entry) = pending.pop_front() {
+                            return Ok(Some((entry, (pending, next_url, fetched))));
+                        }
+                        if next_url.is_none() {
+                            return Ok(None);
+                        }
+                    }
+                }
+            },
+        ))
+    }
+
+    /// Returns metadata for a batch of file paths.
+    ///
+    /// Paths that do not exist in the bucket are silently omitted from the result.
+ pub async fn get_paths_info(&self, paths: &[String]) -> Result> { + #[derive(serde::Serialize)] + struct Body<'a> { + paths: &'a [String], + } + + let resp = self + .client + .inner + .client + .post(format!("{}/paths-info", self.bucket_url())) + .headers(self.client.auth_headers()) + .json(&Body { paths }) + .send() + .await?; + + let resp = check_bucket_response(resp, &self.bucket_name(), NotFoundContext::Bucket).await?; + Ok(resp.json().await?) + } + + /// Returns a short-lived JWT for uploading files to the Xet CAS. + /// Use the returned `cas_url` and `token` to push file bytes before calling `batch_files`. + pub async fn get_xet_write_token(&self) -> Result { + let resp = self + .client + .inner + .client + .get(format!("{}/xet-write-token", self.bucket_url())) + .headers(self.client.auth_headers()) + .send() + .await?; + let resp = check_bucket_response(resp, &self.bucket_name(), NotFoundContext::Bucket).await?; + Ok(resp.json().await?) + } + + /// Returns a short-lived JWT for downloading files from the Xet CAS directly. + pub async fn get_xet_read_token(&self) -> Result { + let resp = self + .client + .inner + .client + .get(format!("{}/xet-read-token", self.bucket_url())) + .headers(self.client.auth_headers()) + .send() + .await?; + let resp = check_bucket_response(resp, &self.bucket_name(), NotFoundContext::Bucket).await?; + Ok(resp.json().await?) + } + + /// Resolves a file path to a direct download URL. + /// + /// Uses the no-redirect client to capture the 302 `Location` header rather than + /// following it. Metadata is extracted from response headers: + /// `X-Linked-Size`, `X-XET-Hash`, `X-Linked-ETag`, `Last-Modified`, and `Link`. 
+ pub async fn resolve_file(&self, path: &str) -> Result { + let url = format!("{}/buckets/{}/{}/resolve/{}", self.client.inner.endpoint, self.namespace, self.bucket, path); + let resp = self + .client + .inner + .no_redirect_client + .get(&url) + .headers(self.client.auth_headers()) + .send() + .await?; + + if !resp.status().is_redirection() { + return Err(check_bucket_response( + resp, + &self.bucket_name(), + NotFoundContext::Entry { path: path.to_string() }, + ) + .await + .unwrap_err()); + } + + let headers = resp.headers(); + + let location = headers + .get("location") + .and_then(|v| v.to_str().ok()) + .map(str::to_owned) + .ok_or_else(|| HFError::Http { + status: resp.status(), + url: url.clone(), + body: "missing Location header".to_string(), + })?; + + let size = headers + .get("x-linked-size") + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.parse::().ok()); + + let xet_hash = headers.get("x-xet-hash").and_then(|v| v.to_str().ok()).map(str::to_owned); + + let etag = headers.get("x-linked-etag").and_then(|v| v.to_str().ok()).map(str::to_owned); + + let last_modified = headers.get("last-modified").and_then(|v| v.to_str().ok()).map(str::to_owned); + + let mut xet_auth_url = None; + let mut xet_reconstruction_url = None; + if let Some(link) = headers.get("link").and_then(|v| v.to_str().ok()) { + for part in link.split(',') { + let part = part.trim(); + if let Some((url_part, rel_part)) = part.split_once(';') { + let u = url_part.trim().trim_start_matches('<').trim_end_matches('>').to_string(); + if rel_part.contains("xet-auth") { + xet_auth_url = Some(u); + } else if rel_part.contains("xet-reconstruction-info") { + xet_reconstruction_url = Some(u); + } + } + } + } + + Ok(ResolvedFile { + url: location, + size, + xet_hash, + etag, + last_modified, + xet_auth_url, + xet_reconstruction_url, + }) + } + + /// Resolves a file path and returns Xet reconstruction metadata. 
+ /// + /// Sends `Accept: application/vnd.xet-fileinfo+json` to request the JSON response + /// instead of a redirect. Use the returned `reconstruction_url` to fetch chunk data + /// from the Xet CAS directly. + #[cfg(feature = "xet")] + pub async fn xet_resolve_file(&self, path: &str) -> Result { + let url = format!("{}/buckets/{}/{}/resolve/{}", self.client.inner.endpoint, self.namespace, self.bucket, path); + let resp = self + .client + .inner + .client + .get(&url) + .headers(self.client.auth_headers()) + .header("accept", "application/vnd.xet-fileinfo+json") + .send() + .await?; + let resp = + check_bucket_response(resp, &self.bucket_name(), NotFoundContext::Entry { path: path.to_string() }).await?; + Ok(resp.json().await?) + } +} + +impl HFClient { + /// Permanently deletes a bucket and all of its files. + pub async fn delete_bucket(&self, namespace: &str, bucket: &str) -> Result<()> { + let url = format!("{}/api/buckets/{}/{}", self.inner.endpoint, namespace, bucket); + let bucket_id = format!("{}/{}", namespace, bucket); + let resp = self.inner.client.delete(&url).headers(self.auth_headers()).send().await?; + check_bucket_response(resp, &bucket_id, NotFoundContext::Bucket).await?; + Ok(()) + } + + /// Creates a new bucket owned by `namespace`. + pub async fn create_bucket( + &self, + namespace: &str, + bucket: &str, + params: &CreateBucketParams, + ) -> Result { + let url = format!("{}/api/buckets/{}/{}", self.inner.endpoint, namespace, bucket); + let resp = self + .inner + .client + .post(&url) + .headers(self.auth_headers()) + .json(params) + .send() + .await?; + let bucket_id = format!("{}/{}", namespace, bucket); + let resp = check_bucket_response(resp, &bucket_id, NotFoundContext::Bucket).await?; + Ok(resp.json().await?) + } + + /// Returns a paginated stream of all buckets owned by `namespace`. + /// Pagination is driven by `Link` response headers. 
+ pub fn list_buckets(&self, namespace: &str) -> Result> + '_> { + let url = Url::parse(&format!("{}/api/buckets/{}", self.inner.endpoint, namespace))?; + Ok(self.paginate(url, vec![], None)) + } +} + +sync_api! { + impl HFBucket -> HFBucketSync { + fn info(&self) -> Result; + fn update_settings(&self, params: &UpdateBucketParams) -> Result<()>; + fn batch_files(&self, ops: &[BatchOp]) -> Result; + fn get_paths_info(&self, paths: &[String]) -> Result>; + fn get_xet_write_token(&self) -> Result; + fn get_xet_read_token(&self) -> Result; + fn resolve_file(&self, path: &str) -> Result; + } +} + +sync_api_stream! { + impl HFBucket -> HFBucketSync { + fn list_tree(&self, path: &str, params: &ListTreeParams) -> TreeEntry; + } +} + +sync_api! { + #[cfg(feature = "xet")] + impl HFBucket -> HFBucketSync { + fn xet_resolve_file(&self, path: &str) -> Result; + } +} + +sync_api! { + impl HFClient -> HFClientSync { + fn delete_bucket(&self, namespace: &str, bucket: &str) -> Result<()>; + fn create_bucket(&self, namespace: &str, bucket: &str, params: &CreateBucketParams) -> Result; + } +} + +sync_api_stream! 
{ + impl HFClient -> HFClientSync { + fn list_buckets(&self, namespace: &str) -> BucketOverview; + } +} + +#[cfg(test)] +mod tests { + use crate::HFClientBuilder; + + #[test] + fn bucket_constructor_sets_namespace_and_bucket() { + let client = HFClientBuilder::new().build().unwrap(); + let bucket = client.bucket("myuser", "my-bucket"); + assert_eq!(bucket.namespace, "myuser"); + assert_eq!(bucket.bucket, "my-bucket"); + } + + #[test] + fn get_bucket_url() { + let client = HFClientBuilder::new().build().unwrap(); + let bucket = client.bucket("myuser", "my-bucket"); + let url = format!("{}/api/buckets/{}/{}", bucket.client.inner.endpoint, bucket.namespace, bucket.bucket); + assert!(url.ends_with("/api/buckets/myuser/my-bucket")); + } + + #[test] + fn update_settings_url() { + let client = HFClientBuilder::new().build().unwrap(); + let bucket = client.bucket("myuser", "my-bucket"); + let url = + format!("{}/api/buckets/{}/{}/settings", bucket.client.inner.endpoint, bucket.namespace, bucket.bucket); + assert!(url.ends_with("/api/buckets/myuser/my-bucket/settings")); + } + + #[test] + fn create_bucket_url() { + let client = HFClientBuilder::new().build().unwrap(); + let url = format!("{}/api/buckets/{}/{}", client.inner.endpoint, "myuser", "new-bucket"); + assert!(url.ends_with("/api/buckets/myuser/new-bucket")); + } + + #[test] + fn list_buckets_url() { + let client = HFClientBuilder::new().build().unwrap(); + let url = format!("{}/api/buckets/{}", client.inner.endpoint, "myuser"); + assert!(url.ends_with("/api/buckets/myuser")); + } + + #[test] + fn batch_files_ndjson_adds_before_deletes() { + use crate::types::{AddFileOp, BatchOp, DeleteFileOp}; + + let ops = vec![ + BatchOp::DeleteFile(DeleteFileOp { + path: "old.parquet".to_string(), + }), + BatchOp::AddFile(AddFileOp { + path: "new.parquet".to_string(), + xet_hash: "abc".to_string(), + content_type: "application/octet-stream".to_string(), + mtime: None, + }), + ]; + let (adds, deletes): (Vec<_>, Vec<_>) = 
ops.into_iter().partition(|op| matches!(op, BatchOp::AddFile(_))); + let ndjson: String = adds + .iter() + .chain(deletes.iter()) + .map(|op| serde_json::to_string(op).map(|s| s + "\n")) + .collect::>() + .unwrap(); + let lines: Vec<&str> = ndjson.lines().collect(); + assert_eq!(lines.len(), 2); + assert!(lines[0].contains("addFile"), "first line must be addFile, got: {}", lines[0]); + assert!(lines[1].contains("deleteFile"), "second line must be deleteFile"); + } + + #[test] + fn batch_files_each_line_ends_with_newline() { + use crate::types::{AddFileOp, BatchOp}; + let ops = vec![BatchOp::AddFile(AddFileOp { + path: "f.parquet".to_string(), + xet_hash: "h".to_string(), + content_type: "application/octet-stream".to_string(), + mtime: None, + })]; + let (adds, deletes): (Vec<_>, Vec<_>) = ops.into_iter().partition(|op| matches!(op, BatchOp::AddFile(_))); + let ndjson: String = adds + .iter() + .chain(deletes.iter()) + .map(|op| serde_json::to_string(op).map(|s| s + "\n")) + .collect::>() + .unwrap(); + assert!(ndjson.ends_with('\n')); + } + + #[test] + fn list_tree_url_empty_path() { + let client = HFClientBuilder::new().build().unwrap(); + let bucket = client.bucket("myuser", "my-bucket"); + let url = if "".is_empty() { + format!("{}/api/buckets/{}/{}/tree", bucket.client.inner.endpoint, bucket.namespace, bucket.bucket) + } else { + format!( + "{}/api/buckets/{}/{}/tree/{}", + bucket.client.inner.endpoint, bucket.namespace, bucket.bucket, "some/path" + ) + }; + assert!(url.ends_with("/api/buckets/myuser/my-bucket/tree")); + } + + #[test] + fn list_tree_url_with_path() { + let client = HFClientBuilder::new().build().unwrap(); + let bucket = client.bucket("myuser", "my-bucket"); + let path = "data/sub"; + let url = format!( + "{}/api/buckets/{}/{}/tree/{}", + bucket.client.inner.endpoint, bucket.namespace, bucket.bucket, path + ); + assert!(url.ends_with("/api/buckets/myuser/my-bucket/tree/data/sub")); + } + + #[test] + fn xet_token_urls() { + let client = 
HFClientBuilder::new().build().unwrap(); + let bucket = client.bucket("myuser", "my-bucket"); + let write_url = format!( + "{}/api/buckets/{}/{}/xet-write-token", + bucket.client.inner.endpoint, bucket.namespace, bucket.bucket + ); + let read_url = format!( + "{}/api/buckets/{}/{}/xet-read-token", + bucket.client.inner.endpoint, bucket.namespace, bucket.bucket + ); + assert!(write_url.ends_with("/xet-write-token")); + assert!(read_url.ends_with("/xet-read-token")); + } + + #[test] + fn paths_info_url() { + let client = HFClientBuilder::new().build().unwrap(); + let bucket = client.bucket("myuser", "my-bucket"); + let url = + format!("{}/api/buckets/{}/{}/paths-info", bucket.client.inner.endpoint, bucket.namespace, bucket.bucket); + assert!(url.ends_with("/paths-info")); + } + + #[test] + fn resolve_file_parses_link_header() { + let link = r#"; rel="xet-auth", ; rel="xet-reconstruction-info""#; + let mut xet_auth = None; + let mut xet_reconstruction = None; + for part in link.split(',') { + let part = part.trim(); + if let Some((url_part, rel_part)) = part.split_once(';') { + let url = url_part.trim().trim_start_matches('<').trim_end_matches('>').to_string(); + let rel = rel_part.trim(); + if rel.contains("xet-auth") { + xet_auth = Some(url); + } else if rel.contains("xet-reconstruction-info") { + xet_reconstruction = Some(url); + } + } + } + assert_eq!(xet_auth.unwrap(), "https://auth.example.com/token"); + assert_eq!(xet_reconstruction.unwrap(), "https://xet.example.com/reconstruct/abc"); + } + + #[test] + fn resolve_file_url() { + let client = HFClientBuilder::new().build().unwrap(); + let bucket = client.bucket("myuser", "my-bucket"); + let url = format!( + "{}/buckets/{}/{}/resolve/{}", + bucket.client.inner.endpoint, bucket.namespace, bucket.bucket, "data/train.parquet" + ); + assert!(url.contains("/buckets/myuser/my-bucket/resolve/data/train.parquet")); + assert!(!url.contains("/api/")); + } + + #[cfg(feature = "xet")] + #[test] + fn xet_resolve_file_url() { 
+ let client = HFClientBuilder::new().build().unwrap(); + let bucket = client.bucket("myuser", "my-bucket"); + let url = format!( + "{}/buckets/{}/{}/resolve/{}", + bucket.client.inner.endpoint, bucket.namespace, bucket.bucket, "data/train.parquet" + ); + assert!(url.contains("/buckets/myuser/my-bucket/resolve/data/train.parquet")); + } +} diff --git a/huggingface_hub/src/api/mod.rs b/huggingface_hub/src/api/mod.rs index 105ac7e..200817c 100644 --- a/huggingface_hub/src/api/mod.rs +++ b/huggingface_hub/src/api/mod.rs @@ -1,3 +1,4 @@ +pub mod buckets; pub mod commits; pub mod files; pub mod repo; diff --git a/huggingface_hub/src/bin/hfrs/main.rs b/huggingface_hub/src/bin/hfrs/main.rs index eed3d5c..a91f815 100644 --- a/huggingface_hub/src/bin/hfrs/main.rs +++ b/huggingface_hub/src/bin/hfrs/main.rs @@ -139,6 +139,9 @@ fn format_hf_error(err: &HFError) -> String { HFError::RepoNotFound { repo_id } => { format!("Repository '{repo_id}' not found. If the repo is private, make sure you are authenticated.") }, + HFError::BucketNotFound { bucket_name } => { + format!("Bucket '{bucket_name}' not found. If the bucket is private, make sure you are authenticated.") + }, HFError::EntryNotFound { path, repo_id } => { format!("File '{path}' not found in repository '{repo_id}'.") }, @@ -148,6 +151,17 @@ fn format_hf_error(err: &HFError) -> String { HFError::AuthRequired => { "Not authenticated. Run `hfrs auth login` or set the HF_TOKEN environment variable.".to_string() }, + HFError::Forbidden => { + "Permission denied. Check that your token has the required scopes for this operation.".to_string() + }, + HFError::Conflict(body) => { + if body.contains("already exists") { + "Resource already exists. Use --exist-ok to skip this error.".to_string() + } else { + format!("Conflict: {body}") + } + }, + HFError::RateLimited => "Rate limited. 
Please wait a moment and try again.".to_string(), HFError::Http { status, url, body } => { let status_code = status.as_u16(); match status_code { @@ -158,9 +172,6 @@ fn format_hf_error(err: &HFError) -> String { } msg }, - 403 => { - "Permission denied. Check that your token has the required scopes for this operation.".to_string() - }, 404 => { format!("Not found: {url}") }, diff --git a/huggingface_hub/src/blocking.rs b/huggingface_hub/src/blocking.rs index 5476e05..a2c62ff 100644 --- a/huggingface_hub/src/blocking.rs +++ b/huggingface_hub/src/blocking.rs @@ -48,6 +48,15 @@ pub struct HFSpaceSync { pub(crate) inner: Arc, } +/// Synchronous handle for Storage Bucket operations. +/// +/// Obtain via [`HFClientSync::bucket`]. All methods block the current thread. +#[derive(Clone)] +pub struct HFBucketSync { + pub(crate) inner: crate::buckets::HFBucket, + pub(crate) runtime: Arc, +} + impl fmt::Debug for HFClientSync { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("HFClientSync").finish() @@ -120,6 +129,14 @@ impl HFClientSync { pub fn space(&self, owner: impl Into, name: impl Into) -> HFSpaceSync { HFSpaceSync::new(self.clone(), owner, name) } + + /// Creates a synchronous bucket handle. + pub fn bucket(&self, namespace: impl Into, repo: impl Into) -> HFBucketSync { + HFBucketSync { + inner: self.inner.bucket(namespace, repo), + runtime: self.runtime.clone(), + } + } } impl Deref for HFClientSync { @@ -206,6 +223,18 @@ impl From for Arc { /// Alias for [`HFRepositorySync`]. 
pub type HFRepoSync = HFRepositorySync; +#[cfg(test)] +mod bucket_tests { + #[test] + fn bucket_sync_constructor() { + use crate::HFClientBuilder; + let client = crate::blocking::HFClientSync::from_api(HFClientBuilder::new().build().unwrap()).unwrap(); + let bucket = client.bucket("myuser", "my-bucket"); + assert_eq!(bucket.inner.namespace, "myuser"); + assert_eq!(bucket.inner.repo, "my-bucket"); + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/huggingface_hub/src/buckets.rs b/huggingface_hub/src/buckets.rs new file mode 100644 index 0000000..532f8de --- /dev/null +++ b/huggingface_hub/src/buckets.rs @@ -0,0 +1,40 @@ +use crate::HFClient; + +/// Handle for operations on a single HuggingFace Storage Bucket. +/// +/// Obtain via [`HFClient::bucket`]. Every method adds `Authorization: Bearer ` +/// using the token configured on the client. +/// +/// # Example +/// +/// ```rust,no_run +/// # #[tokio::main] +/// # async fn main() -> huggingface_hub::Result<()> { +/// let client = huggingface_hub::HFClient::new()?; +/// let bucket = client.bucket("my-org", "my-bucket"); +/// let overview = bucket.info().await?; +/// println!("Bucket size: {} bytes", overview.size); +/// # Ok(()) +/// # } +/// ``` +#[derive(Clone)] +pub struct HFBucket { + pub(crate) client: HFClient, + /// The namespace (user or organization) that owns the bucket. + pub namespace: String, + /// The bucket name within the namespace. + pub bucket: String, +} + +impl HFClient { + /// Creates a handle for operations on a single Storage Bucket. + /// + /// No I/O is performed — this simply binds the namespace and name to the client. 
+    pub fn bucket(&self, namespace: impl Into<String>, repo: impl Into<String>) -> HFBucket {
+        HFBucket {
+            client: self.clone(),
+            namespace: namespace.into(),
+            bucket: repo.into(),
+        }
+    }
+}
diff --git a/huggingface_hub/src/client.rs b/huggingface_hub/src/client.rs
index 891ce5d..f390c88 100644
--- a/huggingface_hub/src/client.rs
+++ b/huggingface_hub/src/client.rs
@@ -293,7 +293,7 @@ impl HFClient {
                 revision,
                 repo_id: repo_id_str,
             }),
-            crate::error::NotFoundContext::Generic => Err(HFError::Http { status, url, body }),
+            _ => Err(HFError::Http { status, url, body }),
         },
         _ => Err(HFError::Http { status, url, body }),
     }
diff --git a/huggingface_hub/src/error.rs b/huggingface_hub/src/error.rs
index e99e9ab..9a708e3 100644
--- a/huggingface_hub/src/error.rs
+++ b/huggingface_hub/src/error.rs
@@ -15,6 +15,9 @@ pub enum HFError {
     #[error("Repository not found: {repo_id}")]
     RepoNotFound { repo_id: String },
 
+    #[error("Bucket not found: {bucket_name}")]
+    BucketNotFound { bucket_name: String },
+
     #[error("Revision not found: {revision} in {repo_id}")]
     RevisionNotFound { repo_id: String, revision: String },
 
@@ -64,6 +67,25 @@ pub enum HFError {
 
     #[error("{0}")]
     Other(String),
+
+    #[error("forbidden")]
+    Forbidden,
+    #[error("conflict: {0}")]
+    Conflict(String),
+    #[error("rate limited")]
+    RateLimited,
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn new_error_variants_display() {
+        assert_eq!(HFError::Forbidden.to_string(), "forbidden");
+        assert_eq!(HFError::Conflict("name taken".to_string()).to_string(), "conflict: name taken");
+        assert_eq!(HFError::RateLimited.to_string(), "rate limited");
+    }
 }
 
 impl HFError {
@@ -96,6 +118,8 @@ pub type Result<T> = std::result::Result<T, HFError>;
 pub(crate) enum NotFoundContext {
     /// 404 means the repository does not exist
     Repo,
+    /// 404 means the bucket does not exist
+    Bucket,
     /// 404 means a file/path does not exist within the repo
     Entry { path: String },
     /// 404 means the revision does not exist
diff --git a/huggingface_hub/src/lib.rs
b/huggingface_hub/src/lib.rs index cffeb6a..2710207 100644 --- a/huggingface_hub/src/lib.rs +++ b/huggingface_hub/src/lib.rs @@ -21,6 +21,7 @@ mod macros; pub mod api; #[cfg(feature = "blocking")] pub mod blocking; +pub mod buckets; pub mod cache; pub mod client; pub(crate) mod constants; @@ -33,7 +34,8 @@ pub mod types; pub mod xet; #[cfg(feature = "blocking")] -pub use blocking::{HFClientSync, HFRepoSync, HFRepositorySync, HFSpaceSync}; +pub use blocking::{HFBucketSync, HFClientSync, HFRepoSync, HFRepositorySync, HFSpaceSync}; +pub use buckets::HFBucket; pub use client::{HFClient, HFClientBuilder}; #[cfg(feature = "cli")] #[doc(hidden)] diff --git a/huggingface_hub/src/pagination.rs b/huggingface_hub/src/pagination.rs index d90a3f3..a528811 100644 --- a/huggingface_hub/src/pagination.rs +++ b/huggingface_hub/src/pagination.rs @@ -107,7 +107,7 @@ impl HFClient { /// Parse the `Link` header for a `rel="next"` URL. /// Format: `; rel="next"` -fn parse_link_header_next(headers: &HeaderMap) -> Option { +pub(crate) fn parse_link_header_next(headers: &HeaderMap) -> Option { let link_header = headers.get("link")?.to_str().ok()?; for part in link_header.split(',') { diff --git a/huggingface_hub/src/types/buckets.rs b/huggingface_hub/src/types/buckets.rs new file mode 100644 index 0000000..6072c01 --- /dev/null +++ b/huggingface_hub/src/types/buckets.rs @@ -0,0 +1,280 @@ +use serde::{Deserialize, Serialize}; +use typed_builder::TypedBuilder; + +// --- Parameter types --- + +/// Parameters for [`HFClient::create_bucket`]. +#[derive(Debug, Clone, TypedBuilder, Serialize)] +pub struct CreateBucketParams { + /// Whether the bucket should be private. Defaults to public when omitted. + #[builder(default, setter(strip_option))] + #[serde(skip_serializing_if = "Option::is_none")] + pub private: Option, + /// Resource group to assign the bucket to. 
+ #[builder(default, setter(strip_option, into))] + #[serde(rename = "resourceGroupId", skip_serializing_if = "Option::is_none")] + pub resource_group_id: Option, + /// CDN regions to enable for this bucket at creation time. + #[builder(default)] + #[serde(skip_serializing_if = "Vec::is_empty")] + pub cdn: Vec, +} + +/// Parameters for [`HFBucket::update_settings`]. +#[derive(Debug, Clone, TypedBuilder, Serialize)] +pub struct UpdateBucketParams { + /// Change the bucket's visibility. Pass `true` to make it private, `false` for public. + #[builder(default, setter(strip_option))] + #[serde(skip_serializing_if = "Option::is_none")] + pub private: Option, + /// Replace the full set of CDN regions. Pass an empty vec to remove all CDN regions. + #[builder(default, setter(strip_option))] + #[serde(rename = "cdnRegions", skip_serializing_if = "Option::is_none")] + pub cdn_regions: Option>, +} + +/// Parameters for [`HFBucket::list_tree`]. +#[derive(Debug, Clone, TypedBuilder)] +pub struct ListTreeParams { + /// Maximum number of entries to return per page. The server default is 1000; maximum is 10 000. + #[builder(default, setter(strip_option))] + pub limit: Option, + /// When `true`, return all entries under the prefix recursively. + /// When `false` (the default), only top-level entries are returned and sub-directories + /// are collapsed into a single [`EntryType::Directory`] entry. + #[builder(default)] + pub recursive: bool, +} + +// --- Response types --- + +/// Returned by [`HFClient::create_bucket`] on success. +#[derive(Debug, Clone, Deserialize)] +pub struct BucketCreated { + /// Full URL of the newly created bucket (e.g. `https://huggingface.co/buckets/my-org/my-bucket`). + pub url: String, + /// Bucket identifier in `namespace/name` format. + pub name: String, + /// Opaque server-side ID for the bucket. + pub id: String, +} + +/// A CDN region specifying a cloud provider and geographic region. 
+///
+/// Used in [`CreateBucketParams`], [`UpdateBucketParams`], and [`BucketOverview`].
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct CdnRegion {
+    /// Cloud provider (e.g. `"gcp"` or `"aws"`).
+    pub provider: String,
+    /// Geographic region identifier (e.g. `"us"` or `"eu"`).
+    pub region: String,
+}
+
+/// Metadata about a Storage Bucket, as returned by [`HFBucket::info`] and [`HFClient::list_buckets`].
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct BucketOverview {
+    /// Internal MongoDB document ID.
+    #[serde(rename = "_id")]
+    pub mongo_id: String,
+    /// Bucket identifier in `namespace/name` format.
+    pub id: String,
+    /// Namespace (user or organization) that owns the bucket.
+    pub author: String,
+    /// Whether the bucket is private. `None` means the server did not specify.
+    pub private: Option<bool>,
+    /// Repository type tag — always `"bucket"`.
+    #[serde(rename = "repoType")]
+    pub repo_type: String,
+    /// ISO 8601 creation timestamp.
+    #[serde(rename = "createdAt")]
+    pub created_at: String,
+    /// ISO 8601 last-updated timestamp.
+    #[serde(rename = "updatedAt")]
+    pub updated_at: String,
+    /// Total storage used by the bucket, in bytes.
+    pub size: u64,
+    /// Number of files currently stored in the bucket.
+    #[serde(rename = "totalFiles")]
+    pub total_files: u64,
+    /// CDN regions configured for this bucket.
+    #[serde(rename = "cdnRegions")]
+    pub cdn_regions: Vec<CdnRegion>,
+    /// Resource group this bucket belongs to, if any.
+    #[serde(rename = "resourceGroup")]
+    pub resource_group: Option<ResourceGroup>,
+}
+
+/// A resource group that a bucket can be associated with.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ResourceGroup {
+    /// Opaque resource group ID.
+    pub id: String,
+    /// Human-readable resource group name.
+    pub name: String,
+    /// Number of members in the resource group, if returned by the server.
+    #[serde(rename = "numUsers")]
+    pub num_users: Option<u64>,
+}
+
+/// A short-lived token for authenticating directly against the Xet CAS (content-addressable storage).
+///
+/// Returned by [`HFBucket::get_xet_write_token`] and [`HFBucket::get_xet_read_token`].
+#[derive(Debug, Clone, Deserialize)]
+pub struct XetToken {
+    /// Bearer token to include in requests to the Xet CAS.
+    #[serde(rename = "accessToken")]
+    pub access_token: String,
+    /// Base URL of the Xet CAS server.
+    #[serde(rename = "casUrl")]
+    pub cas_url: String,
+    /// Token expiry as a Unix epoch timestamp (seconds), following the standard JWT `exp` convention.
+    #[serde(rename = "exp")]
+    pub expires_at: u64,
+}
+
+/// A single entry returned by [`HFBucket::list_tree`] and [`HFBucket::get_paths_info`].
+///
+/// Can represent either a file or a directory. File-only fields (`size`, `xet_hash`,
+/// `content_type`, `mtime`) are `None` for directory entries.
+#[derive(Debug, Clone, Deserialize)]
+pub struct TreeEntry {
+    /// Whether this entry is a file or a directory.
+    #[serde(rename = "type")]
+    pub entry_type: EntryType,
+    /// Path of the entry relative to the bucket root.
+    pub path: String,
+    /// ISO 8601 timestamp of when this entry was added to the bucket.
+    #[serde(rename = "uploadedAt")]
+    pub uploaded_at: String,
+    /// Original file modification time as an ISO 8601 timestamp, if preserved at upload.
+    pub mtime: Option<String>,
+    /// File size in bytes. `None` for directory entries.
+    pub size: Option<u64>,
+    /// Content-addressable Xet hash of the file. `None` for directory entries.
+    #[serde(rename = "xetHash")]
+    pub xet_hash: Option<String>,
+    /// MIME content type of the file. `None` for directory entries.
+    #[serde(rename = "contentType")]
+    pub content_type: Option<String>,
+}
+
+/// Whether a [`TreeEntry`] is a file or a directory.
+#[derive(Debug, Clone, Deserialize)]
+#[serde(rename_all = "lowercase")]
+pub enum EntryType {
+    /// A regular file stored in the bucket.
+    File,
+    /// A virtual directory prefix (only appears when `recursive` is `false`).
+    Directory,
+}
+
+// --- Batch types ---
+
+/// A single operation in a [`HFBucket::batch_files`] call.
+///
+/// All [`BatchOp::AddFile`] operations must precede all [`BatchOp::DeleteFile`] operations —
+/// the client enforces this automatically.
+#[derive(Debug, Clone, Serialize)]
+#[serde(tag = "type")]
+pub enum BatchOp {
+    /// Add or overwrite a file at the given path.
+    ///
+    /// The file contents must already be present in the Xet CAS; obtain a write token
+    /// via [`HFBucket::get_xet_write_token`] before uploading.
+    #[serde(rename = "addFile")]
+    AddFile(AddFileOp),
+    /// Remove a file from the bucket by path.
+    #[serde(rename = "deleteFile")]
+    DeleteFile(DeleteFileOp),
+}
+
+/// Payload for a [`BatchOp::AddFile`] operation.
+#[derive(Debug, Clone, Serialize)]
+pub struct AddFileOp {
+    /// Destination path within the bucket.
+    pub path: String,
+    /// Content-addressable Xet hash of the file, obtained after uploading to the CAS.
+    #[serde(rename = "xetHash")]
+    pub xet_hash: String,
+    /// MIME content type of the file (e.g. `"application/octet-stream"`).
+    #[serde(rename = "contentType")]
+    pub content_type: String,
+    /// Original file modification time as a Unix timestamp in milliseconds, if known.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub mtime: Option<u64>,
+}
+
+/// Payload for a [`BatchOp::DeleteFile`] operation.
+#[derive(Debug, Clone, Serialize)]
+pub struct DeleteFileOp {
+    /// Path of the file to remove from the bucket.
+    pub path: String,
+}
+
+/// Result of a [`HFBucket::batch_files`] call.
+#[derive(Debug, Clone, Deserialize)]
+pub struct BatchResult {
+    /// `true` if every operation in the batch succeeded.
+    pub success: bool,
+    /// Total number of operations attempted.
+    pub processed: u32,
+    /// Number of operations that completed successfully.
+    pub succeeded: u32,
+    /// Details of any operations that failed.
+    pub failed: Vec<BatchFailure>,
+}
+
+/// A single failed operation within a [`BatchResult`].
+#[derive(Debug, Clone, Deserialize)]
+pub struct BatchFailure {
+    /// Path of the file whose operation failed.
+    pub path: String,
+    /// Server-provided error message.
+    pub error: String,
+}
+
+// --- resolve_file types ---
+
+/// A resolved direct download URL for a bucket file, returned by [`HFBucket::resolve_file`].
+#[derive(Debug, Clone)]
+pub struct ResolvedFile {
+    /// Direct download URL (the `Location` from the server's 302 redirect).
+    pub url: String,
+    /// File size in bytes, if provided by the server.
+    pub size: Option<u64>,
+    /// Content-addressable Xet hash of the file, if provided.
+    pub xet_hash: Option<String>,
+    /// ETag of the file, if provided.
+    pub etag: Option<String>,
+    /// `Last-Modified` header value, if provided.
+    pub last_modified: Option<String>,
+    /// URL to obtain a fresh Xet read token for this file, if provided.
+    pub xet_auth_url: Option<String>,
+    /// URL pointing to the Xet CAS reconstruction manifest for this file, if provided.
+    pub xet_reconstruction_url: Option<String>,
+}
+
+// --- xet_resolve_file type (feature = "xet") ---
+
+/// Xet reconstruction metadata for a bucket file, returned by [`HFBucket::xet_resolve_file`].
+///
+/// Only available with the `xet` feature enabled.
+#[cfg(feature = "xet")]
+#[derive(Debug, Clone, Deserialize)]
+pub struct XetFileInfo {
+    /// Content-addressable Xet hash of the file.
+    pub hash: String,
+    /// URL to obtain a fresh Xet read token.
+    #[serde(rename = "refreshUrl")]
+    pub refresh_url: String,
+    /// URL pointing to the Xet CAS reconstruction manifest.
+    #[serde(rename = "reconstructionUrl")]
+    pub reconstruction_url: String,
+    /// ETag of the file.
+    pub etag: String,
+    /// File size in bytes.
+    pub size: u64,
+    /// MIME content type of the file.
+    #[serde(rename = "contentType")]
+    pub content_type: String,
+}
diff --git a/huggingface_hub/src/types/mod.rs b/huggingface_hub/src/types/mod.rs
index 86adfdb..3c02f3b 100644
--- a/huggingface_hub/src/types/mod.rs
+++ b/huggingface_hub/src/types/mod.rs
@@ -1,3 +1,4 @@
+pub mod buckets;
 pub mod cache;
 pub mod commit;
 pub mod params;
@@ -7,6 +8,7 @@ pub mod user;
 #[cfg(feature = "spaces")]
 pub mod spaces;
+pub use buckets::*;
 pub use commit::*;
 pub use params::*;
 pub use repo::*;
diff --git a/huggingface_hub/tests/integration_test.rs b/huggingface_hub/tests/integration_test.rs
index d3daa6e..1df0afa 100644
--- a/huggingface_hub/tests/integration_test.rs
+++ b/huggingface_hub/tests/integration_test.rs
@@ -9,7 +9,10 @@
 //! Feature-gated tests: enable with --features, e.g.:
 //!   HF_TOKEN=hf_xxx cargo test -p huggingface-hub --all-features --test integration_test
 
-use futures::StreamExt;
+use std::path::PathBuf;
+use std::time::{SystemTime, UNIX_EPOCH};
+
+use futures::{StreamExt, TryStreamExt};
 use huggingface_hub::repository::{
     HFRepository, RepoCreateBranchParams, RepoCreateCommitParams, RepoCreateTagParams, RepoDeleteBranchParams,
     RepoDeleteFileParams, RepoDeleteFolderParams, RepoDeleteTagParams, RepoDownloadFileParams, RepoFileExistsParams,
@@ -21,6 +24,8 @@
 use huggingface_hub::types::*;
 use huggingface_hub::{HFClient, HFClientBuilder};
 #[cfg(feature = "spaces")]
 use huggingface_hub::{SpaceSecretDeleteParams, SpaceSecretParams, SpaceVariableDeleteParams, SpaceVariableParams};
+#[cfg(feature = "xet")]
+use xet::xet_session::Sha256Policy;
 
 fn api() -> Option<HFClient> {
     if std::env::var("HF_TOKEN").is_err() {
@@ -891,3 +896,113 @@
 }
+
+// ---- HFBucket integration tests ----
+
+fn test_bucket_name() -> String {
+    format!("test-bucket-{}", uuid_v4_short())
+}
+
+#[tokio::test]
+async fn test_list_buckets() {
+    let Some(api) = api() else { return };
+    let username = cached_username().await;
+    let stream = api.list_buckets(username).expect("list_buckets failed");
+    let buckets: Vec<_> = {
+        use futures::StreamExt;
+        futures::pin_mut!(stream);
+        let mut items = Vec::new();
+        while let Some(item) = stream.next().await {
+            items.push(item.expect("list_buckets item failed"));
+        }
+        items
+    };
+    let _ = buckets;
+}
+
+#[tokio::test]
+async fn test_create_and_delete_bucket() {
+    let Some(api) = api() else { return };
+    if !write_enabled() {
+        return;
+    }
+    let username = cached_username().await;
+    let name = test_bucket_name();
+
+    let created = api
+        .create_bucket(username, &name, &CreateBucketParams::builder().private(true).build())
+        .await
+        .expect("create_bucket failed");
+    assert!(created.name.contains(&name));
+
+    let bucket = api.bucket(username, &name);
+    let info = bucket.info().await.expect("info failed");
+    assert_eq!(info.id, format!("{username}/{name}"));
+    assert!(info.private.unwrap());
+
+    bucket
+        .update_settings(&UpdateBucketParams::builder().private(false).build())
+        .await
+        .expect("update_settings failed");
+
+    let info = bucket.info().await.unwrap();
+    assert!(!info.private.unwrap());
+
+    api.delete_bucket(username, &name).await.expect("delete_bucket failed");
+
+    assert!(matches!(bucket.info().await, Err(huggingface_hub::HFError::BucketNotFound { .. })));
+}
+
+#[tokio::test]
+async fn test_bucket_list_tree_empty() {
+    let Some(api) = api() else { return };
+    if !write_enabled() {
+        return;
+    }
+    let username = cached_username().await;
+    let name = test_bucket_name();
+
+    api.create_bucket(username, &name, &CreateBucketParams::builder().build())
+        .await
+        .expect("create_bucket failed");
+
+    let bucket = api.bucket(username, &name);
+
+    let entries: Vec<_> = bucket
+        .list_tree("", &ListTreeParams::builder().build())
+        .unwrap()
+        .collect::<Vec<_>>()
+        .await
+        .into_iter()
+        .collect::<Result<Vec<_>>>()
+        .expect("list_tree failed");
+
+    assert!(entries.is_empty(), "new bucket should have no files");
+
+    api.delete_bucket(username, &name).await.unwrap();
+}
+
+#[tokio::test]
+async fn test_get_xet_write_and_read_token() {
+    let Some(api) = api() else { return };
+    if !write_enabled() {
+        return;
+    }
+    let username = cached_username().await;
+    let name = test_bucket_name();
+
+    api.create_bucket(username, &name, &CreateBucketParams::builder().build())
+        .await
+        .unwrap();
+
+    let bucket = api.bucket(username, &name);
+
+    let write_tok = bucket.get_xet_write_token().await.expect("xet write token failed");
+    assert!(!write_tok.access_token.is_empty());
+    assert!(!write_tok.cas_url.is_empty());
+
+    let read_tok = bucket.get_xet_read_token().await.expect("xet read token failed");
+    assert!(!read_tok.access_token.is_empty());
+
+    api.delete_bucket(username, &name).await.unwrap();
+}