Skip to content
65 changes: 62 additions & 3 deletions src/aws/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ use crate::aws::{
};
use crate::client::{HttpConnector, TokenCredentialProvider, http_connector};
use crate::config::ConfigValue;
use crate::{ClientConfigKey, ClientOptions, Result, RetryConfig, StaticCredentialProvider};
use crate::{
Capabilities, ClientConfigKey, ClientOptions, Result, RetryConfig, StaticCredentialProvider,
};
use base64::Engine;
use base64::prelude::BASE64_STANDARD;
use itertools::Itertools;
Expand Down Expand Up @@ -193,6 +195,8 @@ pub struct AmazonS3Builder {
request_payer: ConfigValue<RequesterPayer>,
/// The [`HttpConnector`] to use
http_connector: Option<Arc<dyn HttpConnector>>,
/// Capabilities to advertise for this store instance
capabilities: Option<ConfigValue<Capabilities>>,
}

/// Configuration keys for [`AmazonS3Builder`]
Expand Down Expand Up @@ -463,6 +467,9 @@ pub enum AmazonS3ConfigKey {

/// Encryption options
Encryption(S3EncryptionConfigKey),

/// Override the capabilities advertised by this store.
Capabilities,
}

impl AsRef<str> for AmazonS3ConfigKey {
Expand Down Expand Up @@ -497,6 +504,7 @@ impl AsRef<str> for AmazonS3ConfigKey {
Self::RequestPayer => "aws_request_payer",
Self::Client(opt) => opt.as_ref(),
Self::Encryption(opt) => opt.as_ref(),
Self::Capabilities => "aws_capabilities",
}
}
}
Expand Down Expand Up @@ -557,6 +565,7 @@ impl FromStr for AmazonS3ConfigKey {
"aws_sse_customer_key_base64" | "sse_customer_key_base64" => Ok(Self::Encryption(
S3EncryptionConfigKey::CustomerEncryptionKey,
)),
"aws_capabilities" => Ok(Self::Capabilities),
_ => match s.strip_prefix("aws_").unwrap_or(s).parse() {
Ok(key) => Ok(Self::Client(key)),
Err(_) => Err(Error::UnknownConfigurationKey { key: s.into() }.into()),
Expand Down Expand Up @@ -709,6 +718,9 @@ impl AmazonS3Builder {
self.encryption_customer_key_base64 = Some(value.into())
}
},
AmazonS3ConfigKey::Capabilities => {
self.capabilities = Some(ConfigValue::Deferred(value.into()))
}
};
self
}
Expand Down Expand Up @@ -765,6 +777,7 @@ impl AmazonS3Builder {
AmazonS3ConfigKey::DisableTagging => Some(self.disable_tagging.to_string()),
AmazonS3ConfigKey::DisableBulkDelete => Some(self.disable_bulk_delete.to_string()),
AmazonS3ConfigKey::RequestPayer => Some(self.request_payer.to_string()),
AmazonS3ConfigKey::Capabilities => self.capabilities.as_ref().map(ToString::to_string),
AmazonS3ConfigKey::Encryption(key) => match key {
S3EncryptionConfigKey::ServerSideEncryption => {
self.encryption_type.as_ref().map(ToString::to_string)
Expand Down Expand Up @@ -1105,6 +1118,12 @@ impl AmazonS3Builder {
self
}

/// Override the [`Capabilities`] advertised by this store.
pub fn with_capabilities(mut self, capabilities: Capabilities) -> Self {
self.capabilities = Some(ConfigValue::Parsed(capabilities));
self
}

/// Create a [`AmazonS3`] instance from the provided values,
/// consuming `self`.
pub fn build(mut self) -> Result<AmazonS3> {
Expand Down Expand Up @@ -1286,7 +1305,10 @@ impl AmazonS3Builder {
let http_client = http.connect(&config.client_options)?;
let client = Arc::new(S3Client::new(config, http_client));

Ok(AmazonS3 { client })
Ok(AmazonS3 {
client,
capabilities: self.capabilities.map(|x| x.get()).transpose()?,
})
}
}

Expand Down Expand Up @@ -1535,6 +1557,7 @@ impl From<S3EncryptionHeaders> for HeaderMap {
#[cfg(test)]
Comment thread
vitoordaz marked this conversation as resolved.
mod tests {
use super::*;
use crate::Capability;
use std::collections::HashMap;

#[test]
Expand All @@ -1552,6 +1575,7 @@ mod tests {
("aws_session_token", aws_session_token.clone()),
("aws_unsigned_payload", "true".to_string()),
("aws_checksum_algorithm", "sha256".to_string()),
("aws_capabilities", "ordered_listing".to_string()),
]);

let builder = options
Expand All @@ -1571,6 +1595,14 @@ mod tests {
Checksum::SHA256
);
assert!(builder.unsigned_payload.get().unwrap());
assert!(
builder
.capabilities
.unwrap()
.get()
.unwrap()
.has(Capability::OrderedListing)
);
}

#[test]
Expand Down Expand Up @@ -1625,7 +1657,8 @@ mod tests {
.with_config(
"aws_sse_customer_key_base64".parse().unwrap(),
"some_customer_key",
);
)
.with_config(AmazonS3ConfigKey::Capabilities, "ordered_listing");

assert_eq!(
builder
Expand Down Expand Up @@ -1685,6 +1718,12 @@ mod tests {
.unwrap(),
"some_customer_key"
);
assert_eq!(
builder
.get_config_value(&"aws_capabilities".parse().unwrap())
.unwrap(),
"ordered_listing"
);
}

#[test]
Expand Down Expand Up @@ -1908,6 +1947,26 @@ mod tests {
assert!(s3.client.config.request_payer);
}

#[test]
fn test_parse_capabilities() {
// Default: ordered listing disabled
let s3 = AmazonS3Builder::new()
.with_bucket_name("bucket")
.with_region("region")
.build()
.unwrap();
assert!(!s3.capabilities.is_some());

// Explicit override via with_capabilities: no capabilities
let s3 = AmazonS3Builder::new()
.with_capabilities(Capabilities::new([Capability::OrderedListing]))
.with_bucket_name("bucket")
.with_region("region")
.build()
.unwrap();
assert!(s3.capabilities.unwrap().has(Capability::OrderedListing));
}

#[test]
fn test_parse_bucket_az() {
let cases = [
Expand Down
20 changes: 17 additions & 3 deletions src/aws/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ use crate::multipart::{MultipartStore, PartId};
use crate::signer::Signer;
use crate::util::STRICT_ENCODE_SET;
use crate::{
CopyMode, CopyOptions, Error, GetOptions, GetResult, ListResult, MultipartId, MultipartUpload,
ObjectMeta, ObjectStore, Path, PutMode, PutMultipartOptions, PutOptions, PutPayload, PutResult,
Result, UploadPart,
Capabilities, CopyMode, CopyOptions, Error, GetOptions, GetResult, ListResult, MultipartId,
MultipartUpload, ObjectMeta, ObjectStore, Path, PutMode, PutMultipartOptions, PutOptions,
PutPayload, PutResult, Result, UploadPart,
};

static TAGS_HEADER: HeaderName = HeaderName::from_static("x-amz-tagging");
Expand Down Expand Up @@ -79,10 +79,17 @@ use crate::client::parts::Parts;
use crate::list::{PaginatedListOptions, PaginatedListResult, PaginatedListStore};
pub use credential::{AwsAuthorizer, AwsCredential};

// OrderedListing capability depends on the bucket type, it's not enabled for directory bucket.
// https://docs.aws.amazon.com/AmazonS3/latest/API/API_ListObjectsV2.html
fn get_default_capabilities() -> Capabilities {
return Capabilities::new([]);
}

/// Interface for [Amazon S3](https://aws.amazon.com/s3/).
#[derive(Debug, Clone)]
pub struct AmazonS3 {
client: Arc<S3Client>,
capabilities: Option<Capabilities>,
}

impl std::fmt::Display for AmazonS3 {
Expand Down Expand Up @@ -413,6 +420,12 @@ impl ObjectStore for AmazonS3 {
}
}
}

fn capabilities(&self) -> Capabilities {
self.capabilities
.clone()
.unwrap_or_else(get_default_capabilities)
}
}

#[derive(Debug)]
Expand Down Expand Up @@ -707,6 +720,7 @@ mod tests {
tagging(
Arc::new(AmazonS3 {
client: Arc::clone(&integration.client),
capabilities: None,
}),
!config.disable_tagging,
|p| {
Expand Down
49 changes: 47 additions & 2 deletions src/azure/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ use crate::azure::credential::{
use crate::azure::{AzureCredential, AzureCredentialProvider, MicrosoftAzure, STORE};
use crate::client::{HttpConnector, TokenCredentialProvider, http_connector};
use crate::config::ConfigValue;
use crate::{ClientConfigKey, ClientOptions, Result, RetryConfig, StaticCredentialProvider};
use crate::{
Capabilities, ClientConfigKey, ClientOptions, ObjectStoreExt, Result, RetryConfig,
StaticCredentialProvider,
};
use percent_encoding::percent_decode_str;
use serde::{Deserialize, Serialize};
use std::str::FromStr;
Expand Down Expand Up @@ -180,6 +183,8 @@ pub struct MicrosoftAzureBuilder {
fabric_cluster_identifier: Option<String>,
/// The [`HttpConnector`] to use
http_connector: Option<Arc<dyn HttpConnector>>,
/// Capabilities to advertise for this store instance
capabilities: Option<ConfigValue<Capabilities>>,
}

/// Configuration keys for [`MicrosoftAzureBuilder`]
Expand Down Expand Up @@ -382,6 +387,9 @@ pub enum AzureConfigKey {

/// Client options
Client(ClientConfigKey),

/// Override the capabilities advertised by this store.
Capabilities,
}

impl AsRef<str> for AzureConfigKey {
Expand Down Expand Up @@ -411,6 +419,7 @@ impl AsRef<str> for AzureConfigKey {
Self::FabricSessionToken => "azure_fabric_session_token",
Self::FabricClusterIdentifier => "azure_fabric_cluster_identifier",
Self::Client(key) => key.as_ref(),
Self::Capabilities => "azure_capabilities",
}
}
}
Expand Down Expand Up @@ -468,6 +477,7 @@ impl FromStr for AzureConfigKey {
}
// Backwards compatibility
"azure_allow_http" => Ok(Self::Client(ClientConfigKey::AllowHttp)),
"azure_capabilities" => Ok(Self::Capabilities),
_ => match s.strip_prefix("azure_").unwrap_or(s).parse() {
Ok(key) => Ok(Self::Client(key)),
Err(_) => Err(Error::UnknownConfigurationKey { key: s.into() }.into()),
Expand Down Expand Up @@ -594,6 +604,9 @@ impl MicrosoftAzureBuilder {
AzureConfigKey::FabricClusterIdentifier => {
self.fabric_cluster_identifier = Some(value.into())
}
AzureConfigKey::Capabilities => {
self.capabilities = Some(ConfigValue::Deferred(value.into()))
}
};
self
}
Expand Down Expand Up @@ -635,6 +648,7 @@ impl MicrosoftAzureBuilder {
AzureConfigKey::FabricWorkloadHost => self.fabric_workload_host.clone(),
AzureConfigKey::FabricSessionToken => self.fabric_session_token.clone(),
AzureConfigKey::FabricClusterIdentifier => self.fabric_cluster_identifier.clone(),
AzureConfigKey::Capabilities => self.capabilities.as_ref().map(ToString::to_string),
}
}

Expand Down Expand Up @@ -906,6 +920,12 @@ impl MicrosoftAzureBuilder {
self
}

/// Override the [`Capabilities`] advertised by this store.
pub fn with_capabilities(mut self, capabilities: Capabilities) -> Self {
self.capabilities = Some(ConfigValue::Parsed(capabilities));
self
}

/// Configure a connection to container with given name on Microsoft Azure Blob store.
pub fn build(mut self) -> Result<MicrosoftAzure> {
if let Some(url) = self.url.take() {
Expand Down Expand Up @@ -1054,7 +1074,10 @@ impl MicrosoftAzureBuilder {
let http_client = http.connect(&config.client_options)?;
let client = Arc::new(AzureClient::new(config, http_client));

Ok(MicrosoftAzure { client })
Ok(MicrosoftAzure {
client,
capabilities: self.capabilities.map(|x| x.get()).transpose()?,
})
}
}

Expand Down Expand Up @@ -1097,6 +1120,7 @@ pub fn split_sas(sas: &str) -> Result<Vec<(String, String)>> {
#[cfg(test)]
mod tests {
use super::*;
use crate::Capability;
use std::collections::HashMap;

#[test]
Expand Down Expand Up @@ -1244,6 +1268,7 @@ mod tests {
("azure_client_id", azure_client_id),
("azure_storage_account_name", azure_storage_account_name),
("azure_storage_token", azure_storage_token),
("azure_capabilities", "ordered_listing"),
]);

let builder = options
Expand All @@ -1254,6 +1279,26 @@ mod tests {
assert_eq!(builder.client_id.unwrap(), azure_client_id);
assert_eq!(builder.account_name.unwrap(), azure_storage_account_name);
assert_eq!(builder.bearer_token.unwrap(), azure_storage_token);
assert!(
builder
.capabilities
.unwrap()
.get()
.unwrap()
.has(Capability::OrderedListing)
);
}

#[test]
fn azure_test_config_get_value() {
let builder = MicrosoftAzureBuilder::new()
.with_config(AzureConfigKey::Capabilities, "ordered_listing");
assert_eq!(
builder
.get_config_value(&"azure_capabilities".parse().unwrap())
.unwrap(),
"ordered_listing"
);
}

#[test]
Expand Down
Loading