Skip to content

Commit aede8aa

Browse files
authored
chore: Introduce client cache support and add to signers (#729)
* chore: Introduce client cache support and add to signers * chore: PR suggestions * chore: PR suggestions * chore: Add unit tests
1 parent 0795504 commit aede8aa

10 files changed

Lines changed: 805 additions & 190 deletions

File tree

src/services/aws_kms/mod.rs

Lines changed: 121 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,14 @@ use aws_sdk_kms::{
4040
};
4141
use once_cell::sync::Lazy;
4242
use serde::Serialize;
43-
use std::collections::HashMap;
43+
use std::{collections::HashMap, sync::Arc};
4444
use tokio::sync::RwLock;
4545

4646
use crate::{
4747
models::{Address, AwsKmsSignerConfig},
48-
services::signer::evm::utils::recover_evm_signature_from_der,
48+
services::{
49+
client_cache::AsyncClientCache, signer::evm::utils::recover_evm_signature_from_der,
50+
},
4951
utils::{
5052
self, derive_ethereum_address_from_der, derive_solana_address_from_der,
5153
derive_stellar_address_from_der,
@@ -191,9 +193,71 @@ static KMS_DER_PK_CACHE: Lazy<RwLock<HashMap<String, Vec<u8>>>> =
191193
static KMS_ED25519_PK_CACHE: Lazy<RwLock<HashMap<String, Vec<u8>>>> =
192194
Lazy::new(|| RwLock::new(HashMap::new()));
193195

196+
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
197+
struct AwsKmsClientKey {
198+
region: String,
199+
}
200+
201+
static KMS_CLIENT_CACHE: Lazy<AsyncClientCache<AwsKmsClientKey, Client>> =
202+
Lazy::new(AsyncClientCache::new);
203+
204+
/// Get or create a shared AWS KMS SDK client for the given signer config.
205+
/// Keyed by resolved region — one client serves all KMS keys in that region.
206+
async fn get_or_create_kms_client(config: &AwsKmsSignerConfig) -> AwsKmsResult<Arc<Client>> {
207+
let resolved_region = resolve_aws_region(config).await?;
208+
let key = AwsKmsClientKey {
209+
region: resolved_region.clone(),
210+
};
211+
212+
KMS_CLIENT_CACHE
213+
.get_or_try_init(key, || async {
214+
debug!(
215+
region = %resolved_region,
216+
"Creating new AWS KMS client"
217+
);
218+
let auth_config = aws_config::defaults(BehaviorVersion::latest())
219+
.region(Region::new(resolved_region))
220+
.load()
221+
.await;
222+
223+
// Client::new() can panic in environments without TLS root certificates
224+
// (e.g., stripped containers). Catch the panic and return a typed error.
225+
std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| Client::new(&auth_config)))
226+
.map_err(|panic| {
227+
let msg = panic
228+
.downcast_ref::<String>()
229+
.map(|s| s.as_str())
230+
.or_else(|| panic.downcast_ref::<&str>().copied())
231+
.unwrap_or("unknown panic");
232+
AwsKmsError::ConfigError(format!(
233+
"Failed to initialize AWS KMS client (check TLS root certificates): {msg}"
234+
))
235+
})
236+
})
237+
.await
238+
}
239+
240+
/// Resolve the AWS region from config or the default provider chain.
241+
async fn resolve_aws_region(config: &AwsKmsSignerConfig) -> AwsKmsResult<String> {
242+
if let Some(region) = &config.region {
243+
return Ok(region.clone());
244+
}
245+
246+
let provider = RegionProviderChain::default_provider();
247+
provider
248+
.region()
249+
.await
250+
.map(|r| r.to_string())
251+
.ok_or_else(|| {
252+
AwsKmsError::ConfigError(
253+
"AWS region not specified and could not be resolved from environment".to_string(),
254+
)
255+
})
256+
}
257+
194258
#[derive(Debug, Clone)]
195259
pub struct AwsKmsClient {
196-
inner: Client,
260+
inner: Arc<Client>,
197261
}
198262

199263
#[async_trait]
@@ -228,10 +292,8 @@ impl AwsKmsK256 for AwsKmsClient {
228292
))?
229293
.into_inner();
230294

231-
// Cache the result
232295
let mut cache_write = KMS_DER_PK_CACHE.write().await;
233296
cache_write.insert(key_id.to_string(), der_pk_blob.clone());
234-
drop(cache_write);
235297

236298
Ok(der_pk_blob)
237299
}
@@ -297,10 +359,8 @@ impl AwsKmsEd25519 for AwsKmsClient {
297359
))?
298360
.into_inner();
299361

300-
// Cache the result
301362
let mut cache_write = KMS_ED25519_PK_CACHE.write().await;
302363
cache_write.insert(key_id.to_string(), der_pk_blob.clone());
303-
drop(cache_write);
304364

305365
Ok(der_pk_blob)
306366
}
@@ -353,20 +413,13 @@ pub struct AwsKmsService<T: AwsKmsK256 + AwsKmsEd25519 + Clone = AwsKmsClient> {
353413

354414
impl AwsKmsService<AwsKmsClient> {
355415
pub async fn new(config: AwsKmsSignerConfig) -> AwsKmsResult<Self> {
356-
let region_provider =
357-
RegionProviderChain::first_try(config.region.map(Region::new)).or_default_provider();
358-
359-
let auth_config = aws_config::defaults(BehaviorVersion::latest())
360-
.region(region_provider)
361-
.load()
362-
.await;
363-
let client = AwsKmsClient {
364-
inner: Client::new(&auth_config),
365-
};
416+
let shared_client = get_or_create_kms_client(&config).await?;
366417

367418
Ok(Self {
368419
kms_key_id: config.key_id,
369-
client,
420+
client: AwsKmsClient {
421+
inner: shared_client,
422+
},
370423
})
371424
}
372425
}
@@ -842,4 +895,54 @@ pub mod tests {
842895
}
843896

844897
// Note: Ed25519 DER parsing tests are in utils/ed25519.rs
898+
899+
#[tokio::test]
900+
async fn test_kms_client_cache_same_region_shares_client() {
901+
let config1 = AwsKmsSignerConfig {
902+
region: Some("us-west-2".to_string()),
903+
key_id: "key-aaa".to_string(),
904+
};
905+
let config2 = AwsKmsSignerConfig {
906+
region: Some("us-west-2".to_string()),
907+
key_id: "key-bbb".to_string(),
908+
};
909+
910+
let result1 = get_or_create_kms_client(&config1).await;
911+
let result2 = get_or_create_kms_client(&config2).await;
912+
913+
match (result1, result2) {
914+
(Ok(client1), Ok(client2)) => {
915+
assert!(Arc::ptr_eq(&client1, &client2));
916+
}
917+
(Err(AwsKmsError::ConfigError(msg)), _) | (_, Err(AwsKmsError::ConfigError(msg))) => {
918+
// In environments without TLS roots, the panic is caught as ConfigError
919+
assert!(
920+
msg.contains("TLS root certificates"),
921+
"Expected TLS-related config error, got: {msg}"
922+
);
923+
}
924+
(Err(e), _) | (_, Err(e)) => {
925+
panic!("Unexpected error: {e:?}");
926+
}
927+
}
928+
}
929+
930+
#[tokio::test]
931+
async fn test_kms_client_returns_config_error_when_region_missing() {
932+
let config = AwsKmsSignerConfig {
933+
region: None,
934+
key_id: "test-key".to_string(),
935+
};
936+
937+
// Covers the missing-region branch in resolve_aws_region().
938+
// Does not exercise Client::new() panic handling (that requires TLS root absence).
939+
let result = get_or_create_kms_client(&config).await;
940+
match result {
941+
Err(AwsKmsError::ConfigError(_)) => {}
942+
Ok(_) => panic!(
943+
"Expected missing-region error; AWS_REGION/AWS_DEFAULT_REGION may be set in env"
944+
),
945+
Err(e) => panic!("Expected ConfigError, got: {e:?}"),
946+
}
947+
}
845948
}

0 commit comments

Comments
 (0)