diff --git a/ballista/core/src/client.rs b/ballista/core/src/client.rs
index 53a9aedfa8..86b5a0e529 100644
--- a/ballista/core/src/client.rs
+++ b/ballista/core/src/client.rs
@@ -18,13 +18,16 @@
 //! Client API for sending requests to executors.
 
 use std::collections::HashMap;
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
+use std::time::{Duration, Instant};
 use std::{
     convert::{TryFrom, TryInto},
     task::{Context, Poll},
 };
 
+use parking_lot::RwLock;
+
 use crate::error::{BallistaError, Result as BResult};
 use crate::serde::scheduler::{Action, PartitionId};
 
@@ -287,6 +290,181 @@ impl BallistaClient {
     }
 }
 
+/// Default time-to-live for cached connections (5 minutes).
+/// Connections older than this will be replaced with fresh ones.
+const DEFAULT_CONNECTION_TTL: Duration = Duration::from_secs(5 * 60);
+
+/// A cached connection with its creation timestamp.
+struct CachedConnection {
+    client: BallistaClient,
+    created_at: Instant,
+}
+
+/// A connection pool for reusing `BallistaClient` connections to executors.
+///
+/// This pool caches connections by (host, port) to avoid the overhead of
+/// establishing new gRPC connections for each partition fetch during shuffle reads.
+/// Connections have a configurable time-to-live (TTL) after which they are
+/// considered stale and will be replaced with fresh connections.
+///
+/// This TTL mechanism prevents connection leaks when executors are removed or
+/// replaced, as stale connections will eventually be cleaned up even if they
+/// never fail with an error.
+///
+/// # Thread Safety
+///
+/// The pool uses a `RwLock` to allow concurrent reads while ensuring exclusive
+/// access during connection creation. The `BallistaClient` itself is `Clone`
+/// (wrapping an `Arc`), so cloned clients share the underlying connection.
+pub struct BallistaClientPool {
+    /// Map from (host, port) to cached client connection with timestamp
+    connections: RwLock<HashMap<(String, u16), CachedConnection>>,
+    /// Time-to-live for cached connections
+    ttl: Duration,
+}
+
+impl Default for BallistaClientPool {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl BallistaClientPool {
+    /// Creates a new empty connection pool with the default TTL.
+    pub fn new() -> Self {
+        Self::with_ttl(DEFAULT_CONNECTION_TTL)
+    }
+
+    /// Creates a new empty connection pool with a custom TTL.
+    pub fn with_ttl(ttl: Duration) -> Self {
+        Self {
+            connections: RwLock::new(HashMap::new()),
+            ttl,
+        }
+    }
+
+    /// Checks if a cached connection is still valid (not expired).
+    fn is_connection_valid(&self, cached: &CachedConnection) -> bool {
+        cached.created_at.elapsed() < self.ttl
+    }
+
+    /// Gets an existing connection or creates a new one for the given host and port.
+    ///
+    /// If a valid (non-expired) connection already exists in the pool, it is cloned
+    /// and returned. Otherwise, a new connection is established, cached, and returned.
+    /// Expired connections are automatically replaced.
+    ///
+    /// # Arguments
+    ///
+    /// * `host` - The hostname or IP address of the executor
+    /// * `port` - The port number of the executor's Flight service
+    /// * `max_message_size` - Maximum gRPC message size for new connections
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if connection establishment fails for a new connection.
+    pub async fn get_or_connect(
+        &self,
+        host: &str,
+        port: u16,
+        max_message_size: usize,
+    ) -> BResult<BallistaClient> {
+        let key = (host.to_string(), port);
+
+        // Fast path: check if a valid connection exists with read lock
+        {
+            let connections = self.connections.read();
+            if let Some(cached) = connections.get(&key) {
+                if self.is_connection_valid(cached) {
+                    debug!("Reusing cached connection to {host}:{port}");
+                    return Ok(cached.client.clone());
+                }
+                debug!("Cached connection to {host}:{port} has expired, will create new one");
+            }
+        }
+
+        // Slow path: create new connection without holding lock
+        // Multiple tasks might race to create connections to the same host,
+        // but only one will be cached (the others will be dropped)
+        debug!("Creating new connection to {host}:{port}");
+        let client = BallistaClient::try_new(host, port, max_message_size).await?;
+
+        // Now acquire write lock to cache the connection
+        let mut connections = self.connections.write();
+
+        // Check if another task created a valid connection while we were connecting
+        if let Some(cached) = connections.get(&key)
+            && self.is_connection_valid(cached)
+        {
+            debug!("Using connection to {host}:{port} created by another task");
+            return Ok(cached.client.clone());
+        }
+
+        // Cache our new connection
+        let cached = CachedConnection {
+            client: client.clone(),
+            created_at: Instant::now(),
+        };
+        connections.insert(key, cached);
+        Ok(client)
+    }
+
+    /// Removes a connection from the pool.
+    ///
+    /// This can be used to force reconnection on the next request,
+    /// for example after a connection error.
+    pub fn remove(&self, host: &str, port: u16) {
+        let key = (host.to_string(), port);
+        let mut connections = self.connections.write();
+        if connections.remove(&key).is_some() {
+            debug!("Removed cached connection to {host}:{port}");
+        }
+    }
+
+    /// Returns the number of cached connections.
+    pub fn len(&self) -> usize {
+        self.connections.read().len()
+    }
+
+    /// Returns true if the pool has no cached connections.
+    pub fn is_empty(&self) -> bool {
+        self.connections.read().is_empty()
+    }
+
+    /// Clears all cached connections.
+    pub fn clear(&self) {
+        let mut connections = self.connections.write();
+        let count = connections.len();
+        connections.clear();
+        debug!("Cleared {count} cached connections from pool");
+    }
+
+    /// Removes all expired connections from the pool.
+    ///
+    /// This method can be called periodically to proactively clean up
+    /// stale connections rather than waiting for them to be accessed.
+    /// Returns the number of connections that were removed.
+    pub fn remove_expired(&self) -> usize {
+        let mut connections = self.connections.write();
+        let initial_count = connections.len();
+        connections.retain(|_, cached| cached.created_at.elapsed() < self.ttl);
+        let removed = initial_count - connections.len();
+        if removed > 0 {
+            debug!("Removed {removed} expired connections from pool");
+        }
+        removed
+    }
+}
+
+/// Returns the global connection pool instance.
+///
+/// This pool is shared across all shuffle read operations within the executor
+/// process, enabling connection reuse across different tasks and queries.
+pub fn global_client_pool() -> &'static BallistaClientPool {
+    static POOL: OnceLock<BallistaClientPool> = OnceLock::new();
+    POOL.get_or_init(BallistaClientPool::new)
+}
+
 /// [FlightDataStream] facilitates the transfer of shuffle data using the Arrow Flight protocol.
 /// Internally, it invokes the `do_get` method on the Arrow Flight server, which returns a stream
 /// of messages, each representing a record batch.
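
For orientation, a minimal usage sketch of the pool added above (not part of the patch). It assumes the `ballista_core::client` and `ballista_core::error` paths for a downstream crate, a Tokio runtime, and a reachable executor; the address `executor-1:50051` and the 16 MiB message size are hypothetical placeholders. Note the race handling in `get_or_connect`: concurrent callers may both dial, but the write-lock re-check caches exactly one connection and simply drops the loser.

```rust
use std::sync::Arc;
use std::time::Duration;

use ballista_core::client::BallistaClientPool; // path as added in this diff
use ballista_core::error::Result;

// Dial once, then reuse: the second call within the TTL returns a clone of
// the cached client instead of opening a new gRPC channel.
async fn fetch_twice(pool: &BallistaClientPool) -> Result<()> {
    let _first = pool.get_or_connect("executor-1", 50051, 16 * 1024 * 1024).await?;
    let _reused = pool.get_or_connect("executor-1", 50051, 16 * 1024 * 1024).await?;
    Ok(())
}

// Optional proactive cleanup: evict expired entries on a timer instead of
// waiting for them to be touched by `get_or_connect`.
fn spawn_ttl_sweeper(pool: Arc<BallistaClientPool>) {
    tokio::spawn(async move {
        let mut tick = tokio::time::interval(Duration::from_secs(60));
        loop {
            tick.tick().await;
            pool.remove_expired();
        }
    });
}
```
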
@@ -623,4 +801,73 @@ mod tests {
 
         assert_eq!(batches, result.unwrap())
     }
+
+    mod connection_pool_tests {
+        use super::super::BallistaClientPool;
+        use std::time::{Duration, Instant};
+
+        #[test]
+        fn test_pool_new_with_default_ttl() {
+            let pool = BallistaClientPool::new();
+            assert!(pool.is_empty());
+            assert_eq!(pool.len(), 0);
+        }
+
+        #[test]
+        fn test_pool_with_custom_ttl() {
+            let ttl = Duration::from_secs(60);
+            let pool = BallistaClientPool::with_ttl(ttl);
+            assert_eq!(pool.ttl, ttl);
+        }
+
+        #[test]
+        fn test_is_connection_valid_not_expired() {
+            let pool = BallistaClientPool::with_ttl(Duration::from_secs(60));
+            // A real `BallistaClient` cannot be constructed without a running
+            // Flight server, so the TTL logic is exercised indirectly through
+            // the public API: with a long TTL and an empty pool,
+            // `remove_expired` must not report any evictions.
+            assert_eq!(pool.remove_expired(), 0);
+        }
+
+        #[test]
+        fn test_remove_expired_with_zero_ttl() {
+            // With a zero TTL any cached connection would be considered
+            // expired immediately; on an empty pool `remove_expired` must
+            // still return 0 without panicking.
+            let pool = BallistaClientPool::with_ttl(Duration::ZERO);
+            assert_eq!(pool.remove_expired(), 0);
+            assert!(pool.is_empty());
+        }
+
+        #[test]
+        fn test_pool_clear() {
+            let pool = BallistaClientPool::new();
+            pool.clear();
+            assert!(pool.is_empty());
+        }
+
+        #[test]
+        fn test_pool_remove_nonexistent() {
+            let pool = BallistaClientPool::new();
+            // Should not panic when removing a non-existent connection
+            pool.remove("nonexistent", 12345);
+            assert!(pool.is_empty());
+        }
+
+        #[test]
+        fn test_cached_connection_created_at() {
+            // `CachedConnection` itself cannot be built without a real
+            // client, so this is only a sanity check of the
+            // `Instant`/`Duration` arithmetic that `is_connection_valid`
+            // relies on: a fresh timestamp must have negligible elapsed time.
+            let now = Instant::now();
+            assert!(now.elapsed() < Duration::from_secs(1));
+        }
+    }
 }
diff --git a/ballista/core/src/execution_plans/shuffle_reader.rs b/ballista/core/src/execution_plans/shuffle_reader.rs
index 2ebc2fd3c6..9116268c89 100644
--- a/ballista/core/src/execution_plans/shuffle_reader.rs
+++ b/ballista/core/src/execution_plans/shuffle_reader.rs
@@ -28,7 +28,7 @@
 use std::result;
 use std::sync::Arc;
 use std::task::{Context, Poll};
 
-use crate::client::BallistaClient;
+use crate::client::global_client_pool;
 use crate::extension::SessionConfigExt;
 use crate::serde::scheduler::{PartitionLocation, PartitionStats};
 
@@ -490,11 +490,13 @@ async fn fetch_partition_remote(
 ) -> result::Result<SendableRecordBatchStream, BallistaError> {
     let metadata = &location.executor_meta;
     let partition_id = &location.partition_id;
-    // TODO for shuffle client connections, we should avoid creating new connections again and again.
-    // And we should also avoid to keep alive too many connections for long time.
     let host = metadata.host.as_str();
     let port = metadata.port;
-    let mut ballista_client = BallistaClient::try_new(host, port, max_message_size)
+
+    // Use the global connection pool to reuse connections across partition fetches
+    let pool = global_client_pool();
+    let mut ballista_client = pool
+        .get_or_connect(host, port, max_message_size)
         .await
         .map_err(|error| match error {
             // map grpc connection error to partition fetch error.
@@ -507,7 +509,7 @@
             other => other,
         })?;
 
-    ballista_client
+    let result = ballista_client
         .fetch_partition(
             &metadata.id,
             partition_id,
@@ -516,7 +518,20 @@
             port,
             flight_transport,
         )
-        .await
+        .await;
+
+    // On connection-related errors, remove the connection from the pool
+    // so the next request will create a fresh connection
+    if let Err(BallistaError::FetchFailed(..)) | Err(BallistaError::GrpcActionError(_)) =
+        &result
+    {
+        debug!(
+            "Removing potentially stale connection to {host}:{port} from pool after error"
+        );
+        pool.remove(host, port);
+    }
+
+    result
 }
 
 async fn fetch_partition_local(
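
The reader above evicts the pooled entry on `FetchFailed` and `GrpcActionError` but does not itself retry; the next call to `get_or_connect` simply dials a fresh connection. A caller that wanted immediate retry-on-stale semantics could wrap the pool as in the following sketch, a hypothetical helper that is not part of this patch; it assumes the crate paths shown, that `BallistaClient` is exported from `ballista_core::client`, and that the `log` crate is available.

```rust
use std::future::Future;

use ballista_core::client::{global_client_pool, BallistaClient};
use ballista_core::error::{BallistaError, Result};

/// Hypothetical helper: run `op` against a pooled client and, on the same
/// error kinds that `fetch_partition_remote` treats as connection-related,
/// evict the cached entry and retry once on a freshly dialed connection.
async fn with_pooled_client<T, F, Fut>(
    host: &str,
    port: u16,
    max_message_size: usize,
    mut op: F,
) -> Result<T>
where
    F: FnMut(BallistaClient) -> Fut,
    Fut: Future<Output = Result<T>>,
{
    let pool = global_client_pool();
    let client = pool.get_or_connect(host, port, max_message_size).await?;
    match op(client).await {
        Err(e @ (BallistaError::FetchFailed(..) | BallistaError::GrpcActionError(_))) => {
            // Mirror the eviction done in fetch_partition_remote, then retry
            // once: with the entry removed, get_or_connect dials fresh.
            log::debug!("evicting stale connection to {host}:{port} after {e}");
            pool.remove(host, port);
            let client = pool.get_or_connect(host, port, max_message_size).await?;
            op(client).await
        }
        other => other,
    }
}
```

Retrying at most once keeps this close to the patch's behavior: a second consecutive failure is surfaced to the caller rather than masked by an evict-and-redial loop.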