Skip to content

Commit 771372e

Browse files
singhsaabirSaabir Singh
andauthored
fix(aws): Include default headers in signature calculation (#484) (#636)
* fix(aws): Include default headers in signature calculation (#484) When using `ClientOptions::with_default_headers()` to set S3 metadata headers (like `x-amz-meta-*` or `x-amz-tagging`), these headers were not included in the AWS SigV4 signature calculation, causing S3 to reject requests with "headers present in the request which were not signed". This fix adds default headers to the request before signing in all three S3 request paths: `S3Client::request()`, `bulk_delete_request()`, and `get_request()`. * test(aws): Extend default headers signing test to cover all request paths Add test coverage for bulk_delete_request and get_request to verify default headers are included in signature calculation. Extract assert_default_headers_signed and default_headers_config helpers to reduce duplication. * style(aws): Apply rustfmt formatting --------- Co-authored-by: Saabir Singh <saabirs@amazon.com>
1 parent 50e1229 commit 771372e

2 files changed

Lines changed: 134 additions & 2 deletions

File tree

src/aws/client.rs

Lines changed: 129 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -467,9 +467,13 @@ impl S3Client {
467467

468468
pub(crate) fn request<'a>(&'a self, method: Method, path: &'a Path) -> Request<'a> {
469469
let url = self.config.path_url(path);
470+
let mut builder = self.client.request(method, url);
471+
if let Some(headers) = self.config.client_options.get_default_headers() {
472+
builder = builder.headers(headers.clone());
473+
}
470474
Request {
471475
path,
472-
builder: self.client.request(method, url),
476+
builder,
473477
payload: None,
474478
payload_sha256: None,
475479
config: &self.config,
@@ -535,6 +539,9 @@ impl S3Client {
535539
let body = Bytes::from(buffer);
536540

537541
let mut builder = self.client.request(Method::POST, url);
542+
if let Some(headers) = self.config.client_options.get_default_headers() {
543+
builder = builder.headers(headers.clone());
544+
}
538545

539546
let digest = digest::digest(&digest::SHA256, &body);
540547
builder = builder.header(SHA256_CHECKSUM, BASE64_STANDARD.encode(digest));
@@ -863,6 +870,9 @@ impl GetClient for S3Client {
863870
};
864871

865872
let mut builder = self.client.request(method, url);
873+
if let Some(headers) = self.config.client_options.get_default_headers() {
874+
builder = builder.headers(headers.clone());
875+
}
866876
if self
867877
.config
868878
.encryption_headers
@@ -958,10 +968,15 @@ fn encode_path(path: &Path) -> PercentEncode<'_> {
958968
#[cfg(test)]
959969
mod tests {
960970
use super::*;
971+
use crate::GetOptions;
961972
use crate::client::HttpClient;
973+
use crate::client::get::GetClient;
962974
use crate::client::mock_server::MockServer;
975+
use crate::client::retry::RetryContext;
963976
use http::Response;
964-
use http::header::CONTENT_LENGTH;
977+
use http::header::{AUTHORIZATION, CONTENT_LENGTH};
978+
use hyper::Request;
979+
use hyper::body::Incoming;
965980

966981
#[tokio::test]
967982
async fn test_create_multipart_has_content_length() {
@@ -1010,4 +1025,116 @@ mod tests {
10101025
assert_eq!(result.unwrap(), "test-upload-id");
10111026
mock.shutdown().await;
10121027
}
1028+
1029+
fn assert_default_headers_signed(req: &Request<Incoming>) {
1030+
assert_eq!(req.headers().get("x-amz-meta-test").unwrap(), "test-value");
1031+
assert_eq!(req.headers().get("x-amz-tagging").unwrap(), "key=value");
1032+
1033+
let auth = req.headers().get(AUTHORIZATION).unwrap().to_str().unwrap();
1034+
assert!(
1035+
auth.contains("x-amz-meta-test"),
1036+
"x-amz-meta-test not in SignedHeaders: {auth}"
1037+
);
1038+
assert!(
1039+
auth.contains("x-amz-tagging"),
1040+
"x-amz-tagging not in SignedHeaders: {auth}"
1041+
);
1042+
}
1043+
1044+
fn default_headers_config(mock: &MockServer) -> S3Config {
1045+
let credential = AwsCredential {
1046+
key_id: "AKIAIOSFODNN7EXAMPLE".to_string(),
1047+
secret_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(),
1048+
token: None,
1049+
};
1050+
1051+
let mut default_headers = HeaderMap::new();
1052+
default_headers.insert("x-amz-meta-test", "test-value".parse().unwrap());
1053+
default_headers.insert("x-amz-tagging", "key=value".parse().unwrap());
1054+
1055+
S3Config {
1056+
bucket_endpoint: mock.url().to_string(),
1057+
bucket: "test-bucket".to_string(),
1058+
region: "us-east-1".to_string(),
1059+
credentials: Arc::new(crate::StaticCredentialProvider::new(credential)),
1060+
client_options: ClientOptions::new()
1061+
.with_allow_http(true)
1062+
.with_default_headers(default_headers),
1063+
skip_signature: false,
1064+
session_provider: None,
1065+
retry_config: Default::default(),
1066+
sign_payload: false,
1067+
disable_tagging: false,
1068+
checksum: None,
1069+
copy_if_not_exists: None,
1070+
conditional_put: Default::default(),
1071+
encryption_headers: Default::default(),
1072+
request_payer: false,
1073+
}
1074+
}
1075+
1076+
#[tokio::test]
1077+
async fn test_default_headers_signed_request() {
1078+
let mock = MockServer::new().await;
1079+
mock.push_fn(|req| {
1080+
assert_default_headers_signed(&req);
1081+
Response::builder()
1082+
.status(200)
1083+
.header("etag", "\"test-etag\"")
1084+
.body(String::new())
1085+
.unwrap()
1086+
});
1087+
1088+
let config = default_headers_config(&mock);
1089+
let client = S3Client::new(config, HttpClient::new(reqwest::Client::new()));
1090+
let result = client
1091+
.request(Method::PUT, &Path::from("test"))
1092+
.with_payload(PutPayload::default())
1093+
.do_put()
1094+
.await;
1095+
1096+
assert!(result.is_ok());
1097+
mock.shutdown().await;
1098+
}
1099+
1100+
#[tokio::test]
1101+
async fn test_default_headers_signed_bulk_delete() {
1102+
let mock = MockServer::new().await;
1103+
mock.push_fn(|req| {
1104+
assert_default_headers_signed(&req);
1105+
Response::builder()
1106+
.status(200)
1107+
.body("<DeleteResult><Deleted><Key>test</Key></Deleted></DeleteResult>".to_string())
1108+
.unwrap()
1109+
});
1110+
1111+
let config = default_headers_config(&mock);
1112+
let client = S3Client::new(config, HttpClient::new(reqwest::Client::new()));
1113+
let result = client.bulk_delete_request(vec![Path::from("test")]).await;
1114+
1115+
assert!(result.is_ok());
1116+
mock.shutdown().await;
1117+
}
1118+
1119+
#[tokio::test]
1120+
async fn test_default_headers_signed_get_request() {
1121+
let mock = MockServer::new().await;
1122+
mock.push_fn(|req| {
1123+
assert_default_headers_signed(&req);
1124+
Response::builder()
1125+
.status(200)
1126+
.body("test-body".to_string())
1127+
.unwrap()
1128+
});
1129+
1130+
let config = default_headers_config(&mock);
1131+
let client = S3Client::new(config, HttpClient::new(reqwest::Client::new()));
1132+
let mut ctx = RetryContext::new(&client.config.retry_config);
1133+
let result = client
1134+
.get_request(&mut ctx, &Path::from("test"), GetOptions::default())
1135+
.await;
1136+
1137+
assert!(result.is_ok());
1138+
mock.shutdown().await;
1139+
}
10131140
}

src/client/mod.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -700,6 +700,11 @@ impl ClientOptions {
700700
self
701701
}
702702

703+
/// Get the default headers defined through `ClientOptions::with_default_headers`
704+
pub fn get_default_headers(&self) -> Option<&HeaderMap> {
705+
self.default_headers.as_ref()
706+
}
707+
703708
/// Get the mime type for the file in `path` to be uploaded
704709
///
705710
/// Gets the file extension from `path`, and returns the

0 commit comments

Comments
 (0)