Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 121 additions & 18 deletions src/services/aws_kms/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,14 @@ use aws_sdk_kms::{
};
use once_cell::sync::Lazy;
use serde::Serialize;
use std::collections::HashMap;
use std::{collections::HashMap, sync::Arc};
use tokio::sync::RwLock;

use crate::{
models::{Address, AwsKmsSignerConfig},
services::signer::evm::utils::recover_evm_signature_from_der,
services::{
client_cache::AsyncClientCache, signer::evm::utils::recover_evm_signature_from_der,
},
utils::{
self, derive_ethereum_address_from_der, derive_solana_address_from_der,
derive_stellar_address_from_der,
Expand Down Expand Up @@ -191,9 +193,71 @@ static KMS_DER_PK_CACHE: Lazy<RwLock<HashMap<String, Vec<u8>>>> =
static KMS_ED25519_PK_CACHE: Lazy<RwLock<HashMap<String, Vec<u8>>>> =
Lazy::new(|| RwLock::new(HashMap::new()));

#[derive(Clone, Debug, Eq, PartialEq, Hash)]
struct AwsKmsClientKey {
region: String,
}

static KMS_CLIENT_CACHE: Lazy<AsyncClientCache<AwsKmsClientKey, Client>> =
Lazy::new(AsyncClientCache::new);

/// Get or create a shared AWS KMS SDK client for the given signer config.
/// Keyed by resolved region — one client serves all KMS keys in that region.
async fn get_or_create_kms_client(config: &AwsKmsSignerConfig) -> AwsKmsResult<Arc<Client>> {
let resolved_region = resolve_aws_region(config).await?;
let key = AwsKmsClientKey {
region: resolved_region.clone(),
};

KMS_CLIENT_CACHE
.get_or_try_init(key, || async {
debug!(
region = %resolved_region,
"Creating new AWS KMS client"
);
let auth_config = aws_config::defaults(BehaviorVersion::latest())
.region(Region::new(resolved_region))
.load()
.await;

// Client::new() can panic in environments without TLS root certificates
// (e.g., stripped containers). Catch the panic and return a typed error.
std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| Client::new(&auth_config)))
.map_err(|panic| {
let msg = panic
.downcast_ref::<String>()
.map(|s| s.as_str())
.or_else(|| panic.downcast_ref::<&str>().copied())
.unwrap_or("unknown panic");
AwsKmsError::ConfigError(format!(
"Failed to initialize AWS KMS client (check TLS root certificates): {msg}"
))
})
})
.await
}

/// Resolve the AWS region from config or the default provider chain.
async fn resolve_aws_region(config: &AwsKmsSignerConfig) -> AwsKmsResult<String> {
if let Some(region) = &config.region {
return Ok(region.clone());
}

let provider = RegionProviderChain::default_provider();
provider
.region()
.await
.map(|r| r.to_string())
.ok_or_else(|| {
AwsKmsError::ConfigError(
"AWS region not specified and could not be resolved from environment".to_string(),
)
})
}

#[derive(Debug, Clone)]
pub struct AwsKmsClient {
inner: Client,
inner: Arc<Client>,
}

#[async_trait]
Expand Down Expand Up @@ -228,10 +292,8 @@ impl AwsKmsK256 for AwsKmsClient {
))?
.into_inner();

// Cache the result
let mut cache_write = KMS_DER_PK_CACHE.write().await;
cache_write.insert(key_id.to_string(), der_pk_blob.clone());
drop(cache_write);

Ok(der_pk_blob)
}
Expand Down Expand Up @@ -297,10 +359,8 @@ impl AwsKmsEd25519 for AwsKmsClient {
))?
.into_inner();

// Cache the result
let mut cache_write = KMS_ED25519_PK_CACHE.write().await;
cache_write.insert(key_id.to_string(), der_pk_blob.clone());
drop(cache_write);

Ok(der_pk_blob)
}
Expand Down Expand Up @@ -353,20 +413,13 @@ pub struct AwsKmsService<T: AwsKmsK256 + AwsKmsEd25519 + Clone = AwsKmsClient> {

impl AwsKmsService<AwsKmsClient> {
pub async fn new(config: AwsKmsSignerConfig) -> AwsKmsResult<Self> {
let region_provider =
RegionProviderChain::first_try(config.region.map(Region::new)).or_default_provider();

let auth_config = aws_config::defaults(BehaviorVersion::latest())
.region(region_provider)
.load()
.await;
let client = AwsKmsClient {
inner: Client::new(&auth_config),
};
let shared_client = get_or_create_kms_client(&config).await?;

Ok(Self {
kms_key_id: config.key_id,
client,
client: AwsKmsClient {
inner: shared_client,
},
})
}
}
Expand Down Expand Up @@ -842,4 +895,54 @@ pub mod tests {
}

// Note: Ed25519 DER parsing tests are in utils/ed25519.rs

#[tokio::test]
async fn test_kms_client_cache_same_region_shares_client() {
let config1 = AwsKmsSignerConfig {
region: Some("us-west-2".to_string()),
key_id: "key-aaa".to_string(),
};
let config2 = AwsKmsSignerConfig {
region: Some("us-west-2".to_string()),
key_id: "key-bbb".to_string(),
};

let result1 = get_or_create_kms_client(&config1).await;
let result2 = get_or_create_kms_client(&config2).await;

match (result1, result2) {
(Ok(client1), Ok(client2)) => {
assert!(Arc::ptr_eq(&client1, &client2));
}
(Err(AwsKmsError::ConfigError(msg)), _) | (_, Err(AwsKmsError::ConfigError(msg))) => {
// In environments without TLS roots, the panic is caught as ConfigError
assert!(
msg.contains("TLS root certificates"),
"Expected TLS-related config error, got: {msg}"
);
}
(Err(e), _) | (_, Err(e)) => {
panic!("Unexpected error: {e:?}");
}
}
}

#[tokio::test]
async fn test_kms_client_returns_config_error_when_region_missing() {
let config = AwsKmsSignerConfig {
region: None,
key_id: "test-key".to_string(),
};

// Covers the missing-region branch in resolve_aws_region().
// Does not exercise Client::new() panic handling (that requires TLS root absence).
let result = get_or_create_kms_client(&config).await;
match result {
Err(AwsKmsError::ConfigError(_)) => {}
Ok(_) => panic!(
"Expected missing-region error; AWS_REGION/AWS_DEFAULT_REGION may be set in env"
),
Err(e) => panic!("Expected ConfigError, got: {e:?}"),
}
}
}
Loading
Loading