diff --git a/src/services/aws_kms/mod.rs b/src/services/aws_kms/mod.rs index 1790048f2..968da9336 100644 --- a/src/services/aws_kms/mod.rs +++ b/src/services/aws_kms/mod.rs @@ -40,12 +40,14 @@ use aws_sdk_kms::{ }; use once_cell::sync::Lazy; use serde::Serialize; -use std::collections::HashMap; +use std::{collections::HashMap, sync::Arc}; use tokio::sync::RwLock; use crate::{ models::{Address, AwsKmsSignerConfig}, - services::signer::evm::utils::recover_evm_signature_from_der, + services::{ + client_cache::AsyncClientCache, signer::evm::utils::recover_evm_signature_from_der, + }, utils::{ self, derive_ethereum_address_from_der, derive_solana_address_from_der, derive_stellar_address_from_der, @@ -191,9 +193,71 @@ static KMS_DER_PK_CACHE: Lazy>>> = static KMS_ED25519_PK_CACHE: Lazy>>> = Lazy::new(|| RwLock::new(HashMap::new())); +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +struct AwsKmsClientKey { + region: String, +} + +static KMS_CLIENT_CACHE: Lazy> = + Lazy::new(AsyncClientCache::new); + +/// Get or create a shared AWS KMS SDK client for the given signer config. +/// Keyed by resolved region — one client serves all KMS keys in that region. +async fn get_or_create_kms_client(config: &AwsKmsSignerConfig) -> AwsKmsResult> { + let resolved_region = resolve_aws_region(config).await?; + let key = AwsKmsClientKey { + region: resolved_region.clone(), + }; + + KMS_CLIENT_CACHE + .get_or_try_init(key, || async { + debug!( + region = %resolved_region, + "Creating new AWS KMS client" + ); + let auth_config = aws_config::defaults(BehaviorVersion::latest()) + .region(Region::new(resolved_region)) + .load() + .await; + + // Client::new() can panic in environments without TLS root certificates + // (e.g., stripped containers). Catch the panic and return a typed error. 
+ std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| Client::new(&auth_config))) + .map_err(|panic| { + let msg = panic + .downcast_ref::() + .map(|s| s.as_str()) + .or_else(|| panic.downcast_ref::<&str>().copied()) + .unwrap_or("unknown panic"); + AwsKmsError::ConfigError(format!( + "Failed to initialize AWS KMS client (check TLS root certificates): {msg}" + )) + }) + }) + .await +} + +/// Resolve the AWS region from config or the default provider chain. +async fn resolve_aws_region(config: &AwsKmsSignerConfig) -> AwsKmsResult { + if let Some(region) = &config.region { + return Ok(region.clone()); + } + + let provider = RegionProviderChain::default_provider(); + provider + .region() + .await + .map(|r| r.to_string()) + .ok_or_else(|| { + AwsKmsError::ConfigError( + "AWS region not specified and could not be resolved from environment".to_string(), + ) + }) +} + #[derive(Debug, Clone)] pub struct AwsKmsClient { - inner: Client, + inner: Arc, } #[async_trait] @@ -228,10 +292,8 @@ impl AwsKmsK256 for AwsKmsClient { ))? .into_inner(); - // Cache the result let mut cache_write = KMS_DER_PK_CACHE.write().await; cache_write.insert(key_id.to_string(), der_pk_blob.clone()); - drop(cache_write); Ok(der_pk_blob) } @@ -297,10 +359,8 @@ impl AwsKmsEd25519 for AwsKmsClient { ))? 
.into_inner(); - // Cache the result let mut cache_write = KMS_ED25519_PK_CACHE.write().await; cache_write.insert(key_id.to_string(), der_pk_blob.clone()); - drop(cache_write); Ok(der_pk_blob) } @@ -353,20 +413,13 @@ pub struct AwsKmsService { impl AwsKmsService { pub async fn new(config: AwsKmsSignerConfig) -> AwsKmsResult { - let region_provider = - RegionProviderChain::first_try(config.region.map(Region::new)).or_default_provider(); - - let auth_config = aws_config::defaults(BehaviorVersion::latest()) - .region(region_provider) - .load() - .await; - let client = AwsKmsClient { - inner: Client::new(&auth_config), - }; + let shared_client = get_or_create_kms_client(&config).await?; Ok(Self { kms_key_id: config.key_id, - client, + client: AwsKmsClient { + inner: shared_client, + }, }) } } @@ -842,4 +895,54 @@ pub mod tests { } // Note: Ed25519 DER parsing tests are in utils/ed25519.rs + + #[tokio::test] + async fn test_kms_client_cache_same_region_shares_client() { + let config1 = AwsKmsSignerConfig { + region: Some("us-west-2".to_string()), + key_id: "key-aaa".to_string(), + }; + let config2 = AwsKmsSignerConfig { + region: Some("us-west-2".to_string()), + key_id: "key-bbb".to_string(), + }; + + let result1 = get_or_create_kms_client(&config1).await; + let result2 = get_or_create_kms_client(&config2).await; + + match (result1, result2) { + (Ok(client1), Ok(client2)) => { + assert!(Arc::ptr_eq(&client1, &client2)); + } + (Err(AwsKmsError::ConfigError(msg)), _) | (_, Err(AwsKmsError::ConfigError(msg))) => { + // In environments without TLS roots, the panic is caught as ConfigError + assert!( + msg.contains("TLS root certificates"), + "Expected TLS-related config error, got: {msg}" + ); + } + (Err(e), _) | (_, Err(e)) => { + panic!("Unexpected error: {e:?}"); + } + } + } + + #[tokio::test] + async fn test_kms_client_returns_config_error_when_region_missing() { + let config = AwsKmsSignerConfig { + region: None, + key_id: "test-key".to_string(), + }; + + // Covers the 
missing-region branch in resolve_aws_region(). + // Does not exercise Client::new() panic handling (that requires TLS root absence). + let result = get_or_create_kms_client(&config).await; + match result { + Err(AwsKmsError::ConfigError(_)) => {} + Ok(_) => panic!( + "Expected missing-region error; AWS_REGION/AWS_DEFAULT_REGION may be set in env" + ), + Err(e) => panic!("Expected ConfigError, got: {e:?}"), + } + } } diff --git a/src/services/client_cache.rs b/src/services/client_cache.rs new file mode 100644 index 000000000..5ef456c2c --- /dev/null +++ b/src/services/client_cache.rs @@ -0,0 +1,363 @@ +//! # Client Cache +//! +//! Typed caching primitives for long-lived SDK and HTTP clients. +//! +//! - [`AsyncClientCache`] — for client construction that requires `.await` +//! (e.g., AWS KMS via `aws_config::load().await`) +//! - [`SyncClientCache`] — for synchronous client constructors +//! (e.g., `soroban_rs::Client::new(url)`, Solana `RpcClient::new(...)`) +//! +//! Both guarantee at most one in-flight initializer per key. If `init` returns +//! `Err`, the entry is not cached and subsequent calls will retry initialization. + +use std::{ + hash::Hash, + sync::{Arc, Mutex}, +}; + +use dashmap::DashMap; +use tokio::sync::OnceCell; + +// --------------------------------------------------------------------------- +// AsyncClientCache +// --------------------------------------------------------------------------- + +/// A thread-safe cache for async client construction. At most one caller +/// runs the init closure per key; others `.await` the same result. +/// If init returns `Err`, the entry remains uninitialized and the next caller retries. 
+#[derive(Debug)]
+pub struct AsyncClientCache<K, V> {
+    /// One lazily-initialized cell per key. The cell is wrapped in an `Arc` so it
+    /// can be cloned out of the map and awaited without holding a map guard.
+    entries: DashMap<K, Arc<OnceCell<Arc<V>>>>,
+}
+
+impl<K: Eq + Hash, V> AsyncClientCache<K, V> {
+    /// Creates an empty cache.
+    pub fn new() -> Self {
+        Self {
+            entries: DashMap::new(),
+        }
+    }
+
+    /// Returns the value cached under `key`, running `init` to build it when the
+    /// key's cell has not been initialized yet.
+    ///
+    /// # Errors
+    ///
+    /// Propagates the error returned by `init`. In that case the cell is left
+    /// uninitialized, so a later call for the same key runs its own `init` again.
+    pub async fn get_or_try_init<F, Fut, E>(&self, key: K, init: F) -> Result<Arc<V>, E>
+    where
+        F: FnOnce() -> Fut,
+        Fut: std::future::Future<Output = Result<V, E>>,
+    {
+        // The guard returned by `entry(..)` is a temporary that is dropped at the
+        // end of this statement; only the cloned `Arc` cell survives, so no
+        // DashMap guard is held across the `.await` below.
+        let cell = self
+            .entries
+            .entry(key)
+            .or_insert_with(|| Arc::new(OnceCell::new()))
+            .clone();
+
+        let value = cell
+            .get_or_try_init(|| async { Ok(Arc::new(init().await?)) })
+            .await?;
+
+        Ok(Arc::clone(value))
+    }
+
+    /// Drops the entry for `key` (test-only) so the next lookup re-initializes it.
+    /// `Arc`s handed out earlier stay valid.
+    #[cfg(test)]
+    pub fn remove(&self, key: &K) {
+        self.entries.remove(key);
+    }
+
+    /// Number of keys currently present in the map (test-only). This counts
+    /// entries, including any whose cell is still uninitialized after a failed init.
+    #[cfg(test)]
+    pub fn len(&self) -> usize {
+        self.entries.len()
+    }
+}
+
+impl<K: Eq + Hash, V> Default for AsyncClientCache<K, V> {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+// ---------------------------------------------------------------------------
+// SyncClientCache
+// ---------------------------------------------------------------------------
+
+/// A thread-safe cache for synchronous client construction. At most one caller
+/// runs the init closure per key; others block on a mutex and receive the same result.
+/// If init returns `Err`, the entry remains uninitialized and the next caller retries.
+#[derive(Debug)] +pub struct SyncClientCache { + entries: DashMap>>>>, +} + +impl SyncClientCache { + pub fn new() -> Self { + Self { + entries: DashMap::new(), + } + } + + pub fn get_or_try_init(&self, key: K, init: F) -> Result, E> + where + F: FnOnce() -> Result, + { + let cell = self + .entries + .entry(key) + .or_insert_with(|| Arc::new(Mutex::new(None))) + .clone(); + + let mut guard = cell.lock().unwrap_or_else(|e| e.into_inner()); + if let Some(value) = &*guard { + return Ok(Arc::clone(value)); + } + + let value = Arc::new(init()?); + *guard = Some(Arc::clone(&value)); + Ok(value) + } + + #[cfg(test)] + pub fn remove(&self, key: &K) { + self.entries.remove(key); + } + + #[cfg(test)] + pub fn len(&self) -> usize { + self.entries.len() + } +} + +impl Default for SyncClientCache { + fn default() -> Self { + Self::new() + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::atomic::{AtomicUsize, Ordering}; + + // ── AsyncClientCache ───────────────────────────────────────────── + + #[tokio::test] + async fn async_same_key_returns_same_arc() { + let cache = AsyncClientCache::::new(); + + let v1 = cache + .get_or_try_init("key1".to_string(), || async { + Ok::<_, String>("value1".to_string()) + }) + .await + .unwrap(); + + let v2 = cache + .get_or_try_init("key1".to_string(), || async { + Ok::<_, String>("should_not_be_used".to_string()) + }) + .await + .unwrap(); + + assert!(Arc::ptr_eq(&v1, &v2)); + assert_eq!(*v1, "value1"); + } + + #[tokio::test] + async fn async_different_keys_return_different_values() { + let cache = AsyncClientCache::::new(); + + let v1 = cache + .get_or_try_init("key1".to_string(), || async { + Ok::<_, String>("value1".to_string()) + }) + .await + .unwrap(); + + let v2 = cache + .get_or_try_init("key2".to_string(), || async { + Ok::<_, 
String>("value2".to_string()) + }) + .await + .unwrap(); + + assert!(!Arc::ptr_eq(&v1, &v2)); + assert_eq!(cache.len(), 2); + } + + #[tokio::test] + async fn async_concurrent_access_creates_once() { + let cache = Arc::new(AsyncClientCache::::new()); + let init_count = Arc::new(AtomicUsize::new(0)); + + let mut handles = Vec::new(); + for _ in 0..50 { + let cache = Arc::clone(&cache); + let count = Arc::clone(&init_count); + handles.push(tokio::spawn(async move { + cache + .get_or_try_init("shared_key".to_string(), || { + let count = Arc::clone(&count); + async move { + count.fetch_add(1, Ordering::SeqCst); + tokio::task::yield_now().await; + Ok::<_, String>("shared_value".to_string()) + } + }) + .await + .unwrap() + })); + } + + let results: Vec> = futures::future::join_all(handles) + .await + .into_iter() + .map(|r| r.unwrap()) + .collect(); + + for result in &results { + assert!(Arc::ptr_eq(result, &results[0])); + } + assert_eq!(init_count.load(Ordering::SeqCst), 1); + } + + #[tokio::test] + async fn async_error_does_not_cache() { + let cache = AsyncClientCache::::new(); + + let result = cache + .get_or_try_init("key1".to_string(), || async { + Err::("fail".to_string()) + }) + .await; + assert!(result.is_err()); + + let v = cache + .get_or_try_init("key1".to_string(), || async { + Ok::<_, String>("recovered".to_string()) + }) + .await + .unwrap(); + assert_eq!(*v, "recovered"); + } + + #[tokio::test] + async fn async_remove_allows_reinit() { + let cache = AsyncClientCache::::new(); + + let v1 = cache + .get_or_try_init("key1".to_string(), || async { + Ok::<_, String>("first".to_string()) + }) + .await + .unwrap(); + + cache.remove(&"key1".to_string()); + + let v2 = cache + .get_or_try_init("key1".to_string(), || async { + Ok::<_, String>("second".to_string()) + }) + .await + .unwrap(); + + assert!(!Arc::ptr_eq(&v1, &v2)); + assert_eq!(*v2, "second"); + } + + // ── SyncClientCache ────────────────────────────────────────────── + + #[test] + fn 
sync_same_key_returns_same_arc() { + let cache = SyncClientCache::::new(); + + let v1 = cache + .get_or_try_init("key1".to_string(), || Ok::<_, String>("value1".to_string())) + .unwrap(); + + let v2 = cache + .get_or_try_init("key1".to_string(), || { + Ok::<_, String>("should_not_be_used".to_string()) + }) + .unwrap(); + + assert!(Arc::ptr_eq(&v1, &v2)); + assert_eq!(*v1, "value1"); + } + + #[test] + fn sync_different_keys_return_different_values() { + let cache = SyncClientCache::::new(); + + let v1 = cache + .get_or_try_init("key1".to_string(), || Ok::<_, String>("value1".to_string())) + .unwrap(); + + let v2 = cache + .get_or_try_init("key2".to_string(), || Ok::<_, String>("value2".to_string())) + .unwrap(); + + assert!(!Arc::ptr_eq(&v1, &v2)); + assert_eq!(cache.len(), 2); + } + + #[test] + fn sync_concurrent_access_creates_once() { + let cache = Arc::new(SyncClientCache::::new()); + let init_count = Arc::new(AtomicUsize::new(0)); + + let mut handles = Vec::new(); + for _ in 0..50 { + let cache = Arc::clone(&cache); + let count = Arc::clone(&init_count); + handles.push(std::thread::spawn(move || { + cache + .get_or_try_init("shared_key".to_string(), || { + count.fetch_add(1, Ordering::SeqCst); + std::thread::yield_now(); + Ok::<_, String>("shared_value".to_string()) + }) + .unwrap() + })); + } + + let results: Vec> = handles.into_iter().map(|h| h.join().unwrap()).collect(); + + for result in &results { + assert!(Arc::ptr_eq(result, &results[0])); + } + assert_eq!(init_count.load(Ordering::SeqCst), 1); + } + + #[test] + fn sync_error_does_not_cache() { + let cache = SyncClientCache::::new(); + + let result = + cache.get_or_try_init("key1".to_string(), || Err::("fail".to_string())); + assert!(result.is_err()); + + let v = cache + .get_or_try_init("key1".to_string(), || { + Ok::<_, String>("recovered".to_string()) + }) + .unwrap(); + assert_eq!(*v, "recovered"); + } + + #[test] + fn sync_remove_allows_reinit() { + let cache = SyncClientCache::::new(); + + let v1 = 
cache + .get_or_try_init("key1".to_string(), || Ok::<_, String>("first".to_string())) + .unwrap(); + + cache.remove(&"key1".to_string()); + + let v2 = cache + .get_or_try_init("key1".to_string(), || Ok::<_, String>("second".to_string())) + .unwrap(); + + assert!(!Arc::ptr_eq(&v1, &v2)); + assert_eq!(*v2, "second"); + } +} diff --git a/src/services/mod.rs b/src/services/mod.rs index 7e5c6ac2e..c9b699899 100644 --- a/src/services/mod.rs +++ b/src/services/mod.rs @@ -42,3 +42,5 @@ pub mod plugins; pub mod health; pub use health::*; + +pub(crate) mod client_cache; diff --git a/src/services/provider/evm/mod.rs b/src/services/provider/evm/mod.rs index 8fbf6c78e..ed7a6d52f 100644 --- a/src/services/provider/evm/mod.rs +++ b/src/services/provider/evm/mod.rs @@ -4,8 +4,6 @@ //! It implements common operations like getting balances, sending transactions, and querying //! blockchain state. -use std::time::Duration; - use alloy::{ network::AnyNetwork, primitives::{Bytes, TxKind, Uint}, @@ -30,20 +28,12 @@ type EvmProviderType = FillProvider< >; use async_trait::async_trait; use eyre::Result; -use reqwest::ClientBuilder as ReqwestClientBuilder; use serde_json; use tracing::debug; use super::rpc_selector::RpcSelector; use super::{retry_rpc_call, ProviderConfig, RetryConfig}; use crate::{ - constants::{ - DEFAULT_HTTP_CLIENT_CONNECT_TIMEOUT_SECONDS, - DEFAULT_HTTP_CLIENT_HTTP2_KEEP_ALIVE_INTERVAL_SECONDS, - DEFAULT_HTTP_CLIENT_HTTP2_KEEP_ALIVE_TIMEOUT_SECONDS, - DEFAULT_HTTP_CLIENT_POOL_IDLE_TIMEOUT_SECONDS, DEFAULT_HTTP_CLIENT_POOL_MAX_IDLE_PER_HOST, - DEFAULT_HTTP_CLIENT_TCP_KEEPALIVE_SECONDS, - }, models::{ BlockResponse, EvmTransactionData, RpcConfig, TransactionError, TransactionReceipt, U256, }, @@ -51,7 +41,7 @@ use crate::{ utils::mask_url, }; -use crate::utils::{create_secure_redirect_policy, validate_safe_url}; +use crate::utils::validate_safe_url; #[cfg(test)] use mockall::automock; @@ -210,7 +200,6 @@ impl EvmProvider { /// Initialize a provider for a given URL fn 
initialize_provider(&self, url: &str) -> Result { - // Re-validate URL security as a safety net let allowed_hosts = crate::config::ServerConfig::get_rpc_allowed_hosts(); let block_private_ips = crate::config::ServerConfig::get_rpc_block_private_ips(); validate_safe_url(url, &allowed_hosts, block_private_ips).map_err(|e| { @@ -222,25 +211,9 @@ impl EvmProvider { .parse() .map_err(|e| ProviderError::NetworkConfiguration(format!("Invalid URL format: {e}")))?; - // Using use_rustls_tls() forces the use of rustls instead of native-tls to support TLS 1.3 - let client = ReqwestClientBuilder::new() - .timeout(Duration::from_secs(self.timeout_seconds)) - .connect_timeout(Duration::from_secs(DEFAULT_HTTP_CLIENT_CONNECT_TIMEOUT_SECONDS)) - .pool_max_idle_per_host(DEFAULT_HTTP_CLIENT_POOL_MAX_IDLE_PER_HOST) - .pool_idle_timeout(Duration::from_secs(DEFAULT_HTTP_CLIENT_POOL_IDLE_TIMEOUT_SECONDS)) - .tcp_keepalive(Duration::from_secs(DEFAULT_HTTP_CLIENT_TCP_KEEPALIVE_SECONDS)) - .http2_keep_alive_interval(Some(Duration::from_secs( - DEFAULT_HTTP_CLIENT_HTTP2_KEEP_ALIVE_INTERVAL_SECONDS, - ))) - .http2_keep_alive_timeout(Duration::from_secs( - DEFAULT_HTTP_CLIENT_HTTP2_KEEP_ALIVE_TIMEOUT_SECONDS, - )) - .use_rustls_tls() - // Allow only HTTP→HTTPS redirects on same host to handle legitimate protocol upgrades - // while preventing SSRF via redirect chains to different hosts - .redirect(create_secure_redirect_policy()) - .build() - .map_err(|e| ProviderError::Other(format!("Failed to build HTTP client: {e}")))?; + let client = super::build_rpc_http_client_with_timeout(std::time::Duration::from_secs( + self.timeout_seconds, + ))?; let mut transport = Http::new(rpc_url); transport.set_client(client); @@ -561,6 +534,7 @@ mod tests { use lazy_static::lazy_static; use std::str::FromStr; use std::sync::Mutex; + use std::time::Duration; lazy_static! 
{ static ref EVM_TEST_ENV_MUTEX: Mutex<()> = Mutex::new(()); diff --git a/src/services/provider/mod.rs b/src/services/provider/mod.rs index f84d78f84..35108691f 100644 --- a/src/services/provider/mod.rs +++ b/src/services/provider/mod.rs @@ -1,7 +1,20 @@ use std::num::ParseIntError; +use std::time::Duration; + +use once_cell::sync::Lazy; +use reqwest::Client as ReqwestClient; +use tracing::debug; use crate::config::ServerConfig; +use crate::constants::{ + DEFAULT_HTTP_CLIENT_CONNECT_TIMEOUT_SECONDS, + DEFAULT_HTTP_CLIENT_HTTP2_KEEP_ALIVE_INTERVAL_SECONDS, + DEFAULT_HTTP_CLIENT_HTTP2_KEEP_ALIVE_TIMEOUT_SECONDS, + DEFAULT_HTTP_CLIENT_POOL_IDLE_TIMEOUT_SECONDS, DEFAULT_HTTP_CLIENT_POOL_MAX_IDLE_PER_HOST, + DEFAULT_HTTP_CLIENT_TCP_KEEPALIVE_SECONDS, +}; use crate::models::{EvmNetwork, RpcConfig, SolanaNetwork, StellarNetwork}; +use crate::utils::create_secure_redirect_policy; use serde::Serialize; use thiserror::Error; @@ -98,6 +111,60 @@ impl ProviderConfig { } } +/// Pre-configured `reqwest::ClientBuilder` with standard pool, keepalive, TLS, +/// and redirect settings. Callers chain on extras (e.g., `.timeout(...)`) then `.build()`. +fn base_rpc_client_builder() -> reqwest::ClientBuilder { + ReqwestClient::builder() + .connect_timeout(Duration::from_secs( + DEFAULT_HTTP_CLIENT_CONNECT_TIMEOUT_SECONDS, + )) + .pool_max_idle_per_host(DEFAULT_HTTP_CLIENT_POOL_MAX_IDLE_PER_HOST) + .pool_idle_timeout(Duration::from_secs( + DEFAULT_HTTP_CLIENT_POOL_IDLE_TIMEOUT_SECONDS, + )) + .tcp_keepalive(Duration::from_secs( + DEFAULT_HTTP_CLIENT_TCP_KEEPALIVE_SECONDS, + )) + .http2_keep_alive_interval(Some(Duration::from_secs( + DEFAULT_HTTP_CLIENT_HTTP2_KEEP_ALIVE_INTERVAL_SECONDS, + ))) + .http2_keep_alive_timeout(Duration::from_secs( + DEFAULT_HTTP_CLIENT_HTTP2_KEEP_ALIVE_TIMEOUT_SECONDS, + )) + .use_rustls_tls() + .redirect(create_secure_redirect_policy()) +} + +/// Shared `reqwest::Client` for RPC providers that set per-request timeouts +/// (e.g., Stellar raw HTTP). 
No request-level timeout is baked in. +static SHARED_RPC_HTTP_CLIENT: Lazy> = Lazy::new(|| { + debug!("Creating shared RPC HTTP client"); + base_rpc_client_builder() + .build() + .map_err(|e| format!("Failed to create shared RPC HTTP client: {e}")) +}); + +/// Get the shared RPC HTTP client (no per-request timeout). +pub fn get_shared_rpc_http_client() -> Result { + SHARED_RPC_HTTP_CLIENT + .as_ref() + .map(|c| c.clone()) + .map_err(|e| ProviderError::NetworkConfiguration(e.clone())) +} + +/// Build a new RPC HTTP client with standard settings plus a per-request timeout. +/// Use when the provider needs timeouts baked into the client (e.g., EVM via alloy transport). +pub fn build_rpc_http_client_with_timeout( + timeout: Duration, +) -> Result { + base_rpc_client_builder() + .timeout(timeout) + .build() + .map_err(|e| { + ProviderError::NetworkConfiguration(format!("Failed to build RPC HTTP client: {e}")) + }) +} + #[derive(Error, Debug, Serialize)] pub enum ProviderError { #[error("RPC client error: {0}")] diff --git a/src/services/provider/solana/mod.rs b/src/services/provider/solana/mod.rs index 7ca998d0e..e4b94a519 100644 --- a/src/services/provider/solana/mod.rs +++ b/src/services/provider/solana/mod.rs @@ -35,9 +35,14 @@ use spl_token_interface::state::Mint; use std::{str::FromStr, sync::Arc, time::Duration}; use thiserror::Error; +use once_cell::sync::Lazy; + use crate::{ models::{RpcConfig, SolanaTransactionStatus}, - services::provider::{retry_rpc_call, should_mark_provider_failed_by_status_code}, + services::{ + client_cache::SyncClientCache, + provider::{retry_rpc_call, should_mark_provider_failed_by_status_code}, + }, }; use super::ProviderError; @@ -48,6 +53,16 @@ use super::{ use crate::utils::validate_safe_url; +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +struct SolanaRpcClientKey { + url: String, + timeout_ms: u128, + commitment: CommitmentConfig, +} + +static SOLANA_RPC_CLIENT_CACHE: Lazy> = + Lazy::new(SyncClientCache::new); + /// Utility 
function to match error patterns by normalizing both strings. /// Removes spaces and converts to lowercase for flexible matching. /// @@ -602,7 +617,6 @@ impl SolanaProvider { /// which requires adding `solana-rpc-client` as a direct dependency. /// The URL security validation provides the primary SSRF defense for Solana. fn initialize_provider(&self, url: &str) -> Result, SolanaProviderError> { - // Layer 2 validation: Re-validate URL security as a safety net let allowed_hosts = crate::config::ServerConfig::get_rpc_allowed_hosts(); let block_private_ips = crate::config::ServerConfig::get_rpc_block_private_ips(); validate_safe_url(url, &allowed_hosts, block_private_ips).map_err(|e| { @@ -611,17 +625,23 @@ impl SolanaProvider { )) })?; - let rpc_url: Url = url.parse().map_err(|e| { - SolanaProviderError::NetworkConfiguration(format!("Invalid URL format: {e}")) - })?; - - let client = RpcClient::new_with_timeout_and_commitment( - rpc_url.to_string(), - self.timeout_seconds, - self.commitment, - ); - - Ok(Arc::new(client)) + let timeout = self.timeout_seconds; + let commitment = self.commitment; + let cache_key = SolanaRpcClientKey { + url: url.to_string(), + timeout_ms: timeout.as_millis(), + commitment, + }; + SOLANA_RPC_CLIENT_CACHE.get_or_try_init(cache_key, || { + let rpc_url: Url = url.parse().map_err(|e| { + SolanaProviderError::NetworkConfiguration(format!("Invalid URL format: {e}")) + })?; + Ok(RpcClient::new_with_timeout_and_commitment( + rpc_url.to_string(), + timeout, + commitment, + )) + }) } /// Retry helper for Solana RPC calls diff --git a/src/services/provider/stellar/mod.rs b/src/services/provider/stellar/mod.rs index 071bc3977..36ebec9a0 100644 --- a/src/services/provider/stellar/mod.rs +++ b/src/services/provider/stellar/mod.rs @@ -26,14 +26,10 @@ use std::sync::atomic::{AtomicU64, Ordering}; #[cfg(test)] use mockall::automock; -use crate::constants::{ - DEFAULT_HTTP_CLIENT_CONNECT_TIMEOUT_SECONDS, - 
DEFAULT_HTTP_CLIENT_HTTP2_KEEP_ALIVE_INTERVAL_SECONDS, - DEFAULT_HTTP_CLIENT_HTTP2_KEEP_ALIVE_TIMEOUT_SECONDS, - DEFAULT_HTTP_CLIENT_POOL_IDLE_TIMEOUT_SECONDS, DEFAULT_HTTP_CLIENT_POOL_MAX_IDLE_PER_HOST, - DEFAULT_HTTP_CLIENT_TCP_KEEPALIVE_SECONDS, -}; +use once_cell::sync::Lazy; + use crate::models::{JsonRpcId, RpcConfig}; +use crate::services::client_cache::SyncClientCache; use crate::services::provider::is_retriable_error; use crate::services::provider::retry::retry_rpc_call; use crate::services::provider::rpc_selector::RpcSelector; @@ -42,7 +38,7 @@ use crate::services::provider::RetryConfig; use crate::services::provider::{ProviderConfig, ProviderError}; // Reqwest client is used for raw JSON-RPC HTTP requests. Alias to avoid name clash with the // soroban `Client` type imported above. -use crate::utils::{create_secure_redirect_policy, validate_safe_url}; +use crate::utils::validate_safe_url; use reqwest::Client as ReqwestClient; use std::sync::Arc; use std::time::Duration; @@ -60,6 +56,11 @@ fn generate_unique_rpc_id() -> u64 { NEXT_ID.fetch_add(1, Ordering::Relaxed) } +/// Cache for soroban_rs Stellar RPC clients, keyed by URL. +/// Avoids recreating jsonrpsee HTTP clients on every retry attempt. +static STELLAR_RPC_CLIENT_CACHE: Lazy> = + Lazy::new(SyncClientCache::new); + /// Categorizes a Stellar client error into an appropriate `ProviderError` variant. /// /// This function analyzes the given error and maps it to a specific `ProviderError` variant: @@ -407,48 +408,35 @@ impl StellarProvider { self.selector.get_configs() } - /// Initialize a Stellar client for a given URL - fn initialize_provider(&self, url: &str) -> Result { - // Layer 2 validation: Re-validate URL security as a safety net + /// Get or create a cached Stellar RPC client for a given URL. + /// Reuses clients across retry attempts and provider instances. 
+ fn initialize_provider(&self, url: &str) -> Result, ProviderError> { let allowed_hosts = crate::config::ServerConfig::get_rpc_allowed_hosts(); let block_private_ips = crate::config::ServerConfig::get_rpc_block_private_ips(); validate_safe_url(url, &allowed_hosts, block_private_ips).map_err(|e| { ProviderError::NetworkConfiguration(format!("RPC URL security validation failed: {e}")) })?; - Client::new(url).map_err(|e| { - ProviderError::NetworkConfiguration(format!( - "Failed to create Stellar RPC client: {e} - URL: '{url}'" - )) + STELLAR_RPC_CLIENT_CACHE.get_or_try_init(url.to_string(), || { + Client::new(url).map_err(|e| { + let safe_url = crate::utils::mask_url(url); + ProviderError::NetworkConfiguration(format!( + "Failed to create Stellar RPC client: {e} - URL: '{safe_url}'" + )) + }) }) } - /// Initialize a reqwest client for raw HTTP JSON-RPC calls. - /// - /// This centralizes client creation so we can configure timeouts and other options in one place. + /// Get the shared reqwest client for raw HTTP JSON-RPC calls, after + /// validating the URL as an SSRF safety net. 
fn initialize_raw_provider(&self, url: &str) -> Result { - ReqwestClient::builder() - .timeout(self.timeout_seconds) - .connect_timeout(Duration::from_secs(DEFAULT_HTTP_CLIENT_CONNECT_TIMEOUT_SECONDS)) - .pool_max_idle_per_host(DEFAULT_HTTP_CLIENT_POOL_MAX_IDLE_PER_HOST) - .pool_idle_timeout(Duration::from_secs(DEFAULT_HTTP_CLIENT_POOL_IDLE_TIMEOUT_SECONDS)) - .tcp_keepalive(Duration::from_secs(DEFAULT_HTTP_CLIENT_TCP_KEEPALIVE_SECONDS)) - .http2_keep_alive_interval(Some(Duration::from_secs( - DEFAULT_HTTP_CLIENT_HTTP2_KEEP_ALIVE_INTERVAL_SECONDS, - ))) - .http2_keep_alive_timeout(Duration::from_secs( - DEFAULT_HTTP_CLIENT_HTTP2_KEEP_ALIVE_TIMEOUT_SECONDS, - )) - .use_rustls_tls() - // Allow only HTTP→HTTPS redirects on same host to handle legitimate protocol upgrades - // while preventing SSRF via redirect chains to different hosts - .redirect(create_secure_redirect_policy()) - .build() - .map_err(|e| { - ProviderError::NetworkConfiguration(format!( - "Failed to create HTTP client for raw RPC: {e} - URL: '{url}'" - )) - }) + let allowed_hosts = crate::config::ServerConfig::get_rpc_allowed_hosts(); + let block_private_ips = crate::config::ServerConfig::get_rpc_block_private_ips(); + validate_safe_url(url, &allowed_hosts, block_private_ips).map_err(|e| { + ProviderError::NetworkConfiguration(format!("RPC URL security validation failed: {e}")) + })?; + + super::get_shared_rpc_http_client() } /// Helper method to retry RPC calls with exponential backoff @@ -458,7 +446,7 @@ impl StellarProvider { operation: F, ) -> Result where - F: Fn(Client) -> Fut, + F: Fn(Arc) -> Fut, Fut: std::future::Future>, { let provider_url_raw = match self.selector.get_current_url() { diff --git a/src/services/signer/stellar/aws_kms_signer.rs b/src/services/signer/stellar/aws_kms_signer.rs index 5fc433ed5..bc60d028b 100644 --- a/src/services/signer/stellar/aws_kms_signer.rs +++ b/src/services/signer/stellar/aws_kms_signer.rs @@ -17,11 +17,10 @@ use crate::{ use async_trait::async_trait; use 
sha2::{Digest, Sha256}; use soroban_rs::xdr::{ - DecoratedSignature, Hash, HashIdPreimage, HashIdPreimageSorobanAuthorization, Limits, ReadXdr, - ScBytes, ScMap, ScMapEntry, ScSymbol, ScVal, ScVec, Signature, SignatureHint, - SorobanAddressCredentials, SorobanAuthorizationEntry, SorobanCredentials, Transaction, + DecoratedSignature, Hash, Limits, ReadXdr, Signature, SignatureHint, Transaction, TransactionEnvelope, WriteXdr, }; +use tokio::sync::OnceCell; use tracing::debug; pub type DefaultAwsKmsService = AwsKmsService; @@ -31,12 +30,17 @@ where T: AwsKmsStellarService, { aws_kms_service: T, + /// Cached signature hint (last 4 bytes of the public key), computed once on first use. + cached_hint: OnceCell, } impl AwsKmsSigner { /// Creates a new AwsKmsSigner with the default AwsKmsService pub fn new(aws_kms_service: DefaultAwsKmsService) -> Self { - Self { aws_kms_service } + Self { + aws_kms_service, + cached_hint: OnceCell::new(), + } } } @@ -44,7 +48,10 @@ impl AwsKmsSigner { impl AwsKmsSigner { /// Creates a new AwsKmsSigner with a custom service implementation for testing pub fn new_for_testing(aws_kms_service: T) -> Self { - Self { aws_kms_service } + Self { + aws_kms_service, + cached_hint: OnceCell::new(), + } } } @@ -194,51 +201,26 @@ impl AwsKmsSigner { }) } - /// Get the signature hint for this signer (last 4 bytes of the public key) + /// Get the signature hint for this signer (last 4 bytes of the public key). + /// + /// The hint is computed once and cached for the lifetime of the signer, + /// since it is deterministic for a given KMS key. 
async fn get_signature_hint(&self) -> Result { - // Get the public key to derive the signature hint - let stellar_address = self - .aws_kms_service - .get_stellar_address() - .await - .map_err(|e| { - SignerError::SigningError(format!( - "Failed to retrieve Stellar address from AWS KMS: {e}" - )) - })?; - - // Extract hint from the public key (last 4 bytes of public key) - match stellar_address { - Address::Stellar(addr) => { - // Parse the Stellar address to get the public key - use stellar_strkey::ed25519::PublicKey; - let pk = PublicKey::from_string(&addr).map_err(|e| { - SignerError::SigningError(format!( - "Failed to parse Stellar address '{addr}': {e}" - )) - })?; - let pk_bytes = pk.0; - - // Safety check: ensure we have enough bytes for the hint - if pk_bytes.len() < 4 { - return Err(SignerError::SigningError(format!( - "Public key too short for signature hint: {} bytes", - pk_bytes.len() - ))); - } - - let hint_bytes: [u8; 4] = - pk_bytes[pk_bytes.len() - 4..].try_into().map_err(|_| { - SignerError::SigningError( - "Failed to create signature hint from public key".to_string(), - ) + self.cached_hint + .get_or_try_init(|| async { + let address = self + .aws_kms_service + .get_stellar_address() + .await + .map_err(|e| { + SignerError::SigningError(format!( + "Failed to retrieve Stellar address from AWS KMS: {e}" + )) })?; - Ok(SignatureHint(hint_bytes)) - } - _ => Err(SignerError::SigningError(format!( - "Expected Stellar address, got: {stellar_address:?}" - ))), - } + super::derive_signature_hint(&address) + }) + .await + .cloned() } } @@ -565,4 +547,55 @@ mod tests { _ => panic!("Expected SigningError about KMS service"), } } + + #[tokio::test] + async fn test_sign_xdr_hint_retrieval_failure() { + // Tests the get_signature_hint() error path when get_stellar_address fails + // inside the OnceCell init closure + let mut mock_service = MockAwsKmsStellarService::new(); + mock_service + .expect_get_stellar_address() + .times(1) + .returning(|| { + 
Box::pin(async { Err(AwsKmsError::GetError("key not found".to_string())) }) + }); + // sign_stellar succeeds but hint retrieval will fail + mock_service + .expect_sign_stellar() + .times(1) + .returning(|_| Box::pin(async { Ok(vec![1u8; 64]) })); + + let signer = AwsKmsSigner::new_for_testing(mock_service); + + use stellar_strkey::ed25519::PublicKey as StrKeyPublicKey; + let test_pk = StrKeyPublicKey([0u8; 32]); + let test_address = test_pk.to_string(); + + let tx_data = StellarTransactionData { + source_account: test_address, + fee: Some(100), + sequence_number: Some(1), + transaction_input: TransactionInput::Operations(vec![]), + memo: None, + valid_until: None, + network_passphrase: "Test SDF Network ; September 2015".to_string(), + signatures: Vec::new(), + hash: None, + simulation_transaction_data: None, + signed_envelope_xdr: None, + transaction_result_xdr: None, + }; + + let result = signer + .sign_transaction(NetworkTransactionData::Stellar(tx_data)) + .await; + + assert!(result.is_err()); + match result.unwrap_err() { + SignerError::SigningError(msg) => { + assert!(msg.contains("Failed to retrieve Stellar address from AWS KMS")); + } + e => panic!("Expected SigningError about address retrieval, got: {e:?}"), + } + } } diff --git a/src/services/signer/stellar/google_cloud_kms_signer.rs b/src/services/signer/stellar/google_cloud_kms_signer.rs index 5f38d3bf9..c273ee0a5 100644 --- a/src/services/signer/stellar/google_cloud_kms_signer.rs +++ b/src/services/signer/stellar/google_cloud_kms_signer.rs @@ -23,6 +23,7 @@ use soroban_rs::xdr::{ DecoratedSignature, Hash, Limits, ReadXdr, Signature, SignatureHint, Transaction, TransactionEnvelope, WriteXdr, }; +use tokio::sync::OnceCell; use tracing::debug; pub type DefaultGoogleCloudKmsService = GoogleCloudKmsService; @@ -32,6 +33,8 @@ where T: GoogleCloudKmsStellarService + GoogleCloudKmsServiceTrait, { google_cloud_kms_service: T, + /// Cached signature hint (last 4 bytes of the public key), computed once on first 
use. + cached_hint: OnceCell, } impl GoogleCloudKmsSigner { @@ -39,6 +42,7 @@ impl GoogleCloudKmsSigner { pub fn new(google_cloud_kms_service: DefaultGoogleCloudKmsService) -> Self { Self { google_cloud_kms_service, + cached_hint: OnceCell::new(), } } } @@ -49,6 +53,7 @@ impl GoogleCloudKm pub fn new_for_testing(google_cloud_kms_service: T) -> Self { Self { google_cloud_kms_service, + cached_hint: OnceCell::new(), } } } @@ -207,53 +212,28 @@ impl GoogleCloudKm }) } - /// Get the signature hint for this signer (last 4 bytes of the public key) - /// TODO: This can be cached on a future iteration + /// Get the signature hint for this signer (last 4 bytes of the public key). + /// + /// The hint is computed once and cached for the lifetime of the signer, + /// since it is deterministic for a given KMS key. async fn get_signature_hint(&self) -> Result { - use crate::services::GoogleCloudKmsStellarService; + self.cached_hint + .get_or_try_init(|| async { + use crate::services::GoogleCloudKmsStellarService; - // Get the public key to derive the signature hint - let stellar_address = - GoogleCloudKmsStellarService::get_stellar_address(&self.google_cloud_kms_service) + let address = GoogleCloudKmsStellarService::get_stellar_address( + &self.google_cloud_kms_service, + ) .await .map_err(|e| { SignerError::SigningError(format!( "Failed to retrieve Stellar address from Google Cloud KMS: {e}" )) })?; - - // Extract hint from the public key (last 4 bytes of public key) - match stellar_address { - Address::Stellar(addr) => { - // Parse the Stellar address to get the public key - use stellar_strkey::ed25519::PublicKey; - let pk = PublicKey::from_string(&addr).map_err(|e| { - SignerError::SigningError(format!( - "Failed to parse Stellar address '{addr}': {e}" - )) - })?; - let pk_bytes = pk.0; - - // Safety check: ensure we have enough bytes for the hint - if pk_bytes.len() < 4 { - return Err(SignerError::SigningError(format!( - "Public key too short for signature hint: {} bytes", 
- pk_bytes.len() - ))); - } - - let hint_bytes: [u8; 4] = - pk_bytes[pk_bytes.len() - 4..].try_into().map_err(|_| { - SignerError::SigningError( - "Failed to create signature hint from public key".to_string(), - ) - })?; - Ok(SignatureHint(hint_bytes)) - } - _ => Err(SignerError::SigningError(format!( - "Expected Stellar address, got: {stellar_address:?}" - ))), - } + super::derive_signature_hint(&address) + }) + .await + .cloned() } } diff --git a/src/services/signer/stellar/mod.rs b/src/services/signer/stellar/mod.rs index 72c38b068..deeb89255 100644 --- a/src/services/signer/stellar/mod.rs +++ b/src/services/signer/stellar/mod.rs @@ -14,6 +14,8 @@ use local_signer::*; use turnkey_signer::*; use vault_signer::*; +use soroban_rs::xdr::SignatureHint; + use crate::{ domain::{SignDataRequest, SignDataResponse, SignTransactionResponse, SignTypedDataRequest}, models::{ @@ -28,6 +30,27 @@ use crate::{ use super::DataSignerTrait; +/// Derive a `SignatureHint` (last 4 bytes of the Ed25519 public key) from a Stellar address. 
+fn derive_signature_hint(address: &Address) -> Result { + match address { + Address::Stellar(addr) => { + let pk = stellar_strkey::ed25519::PublicKey::from_string(addr).map_err(|e| { + SignerError::SigningError(format!("Failed to parse Stellar address '{addr}': {e}")) + })?; + // pk.0 is [u8; 32], last 4 bytes are the hint + let hint_bytes: [u8; 4] = pk.0[28..].try_into().map_err(|_| { + SignerError::SigningError( + "Failed to create signature hint from public key".to_string(), + ) + })?; + Ok(SignatureHint(hint_bytes)) + } + _ => Err(SignerError::SigningError(format!( + "Expected Stellar address, got: {address:?}" + ))), + } +} + #[cfg(test)] use mockall::automock; @@ -181,3 +204,65 @@ impl StellarSignerFactory { Ok(signer) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_derive_signature_hint_valid_stellar_address() { + let pk = stellar_strkey::ed25519::PublicKey([0u8; 32]); + let address = Address::Stellar(pk.to_string()); + + let hint = derive_signature_hint(&address).unwrap(); + // Last 4 bytes of all-zero key + assert_eq!(hint.0, [0u8; 4]); + } + + #[test] + fn test_derive_signature_hint_extracts_last_four_bytes() { + let mut key_bytes = [0u8; 32]; + key_bytes[28] = 0xAA; + key_bytes[29] = 0xBB; + key_bytes[30] = 0xCC; + key_bytes[31] = 0xDD; + let pk = stellar_strkey::ed25519::PublicKey(key_bytes); + let address = Address::Stellar(pk.to_string()); + + let hint = derive_signature_hint(&address).unwrap(); + assert_eq!(hint.0, [0xAA, 0xBB, 0xCC, 0xDD]); + } + + #[test] + fn test_derive_signature_hint_invalid_stellar_address() { + let address = Address::Stellar("INVALID_ADDRESS".to_string()); + let result = derive_signature_hint(&address); + assert!(result.is_err()); + match result.unwrap_err() { + SignerError::SigningError(msg) => { + assert!(msg.contains("Failed to parse Stellar address")); + } + e => panic!("Expected SigningError, got: {e:?}"), + } + } + + #[test] + fn test_derive_signature_hint_non_stellar_address() { + let address = 
Address::Evm([0u8; 20]); + let result = derive_signature_hint(&address); + assert!(result.is_err()); + match result.unwrap_err() { + SignerError::SigningError(msg) => { + assert!(msg.contains("Expected Stellar address")); + } + e => panic!("Expected SigningError, got: {e:?}"), + } + } + + #[test] + fn test_derive_signature_hint_solana_address_rejected() { + let address = Address::Solana("SomeBase58Address".to_string()); + let result = derive_signature_hint(&address); + assert!(result.is_err()); + } +}