Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
249 changes: 248 additions & 1 deletion ballista/core/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@
//! Client API for sending requests to executors.

use std::collections::HashMap;
use std::sync::Arc;
use std::sync::{Arc, OnceLock};
use std::time::{Duration, Instant};

use std::{
convert::{TryFrom, TryInto},
task::{Context, Poll},
};

use parking_lot::RwLock;

use crate::error::{BallistaError, Result as BResult};
use crate::serde::scheduler::{Action, PartitionId};

Expand Down Expand Up @@ -287,6 +290,181 @@ impl BallistaClient {
}
}

/// Time-to-live applied to pooled connections when no explicit TTL is given:
/// 300 seconds (5 minutes), measured from the moment an entry is cached.
const DEFAULT_CONNECTION_TTL: Duration = Duration::from_secs(300);

/// A pooled client together with the instant at which it was stored.
///
/// The pool compares `created_at.elapsed()` with its TTL to decide whether
/// the entry may still be handed out.
struct CachedConnection {
// Client handed out (by clone) to callers of the pool.
client: BallistaClient,
// Instant at which this entry was built for insertion into the pool.
created_at: Instant,
}

/// A connection pool for reusing `BallistaClient` connections to executors.
///
/// This pool caches connections by (host, port) to avoid the overhead of
/// establishing new gRPC connections for each partition fetch during shuffle reads.
/// Connections have a configurable time-to-live (TTL) after which they are
/// considered stale and will be replaced with fresh connections.
///
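/// The TTL bounds how long a connection to a removed or replaced executor can be
/// reused: an expired entry is never handed out again, and it is dropped from the
/// pool on the next successful `get_or_connect` for the same (host, port), by
/// `remove`/`clear`, or by a call to `remove_expired`.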
///
/// # Thread Safety
///
/// The pool guards its map with a `RwLock`: lookups take the read lock, and the
/// write lock is taken only to insert or remove entries. New connections are
/// established without holding any lock. The `BallistaClient` itself is `Clone`
/// (wrapping an `Arc`), so cloned clients share the underlying connection.
pub struct BallistaClientPool {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder whether it would make sense to use an off-the-shelf pool such as deadpool or r2d2 instead of implementing our own?

If I'm not mistaken, the current implementation can, in theory, leak connections: a connection is removed from the pool only when it fails, but an executor could be removed or replaced without its connection ever getting the chance to fail.

Copy link
Copy Markdown
Member Author

@andygrove andygrove Jan 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I pushed a TTL improvement, but moved this to draft for now

/// Map from (host, port) to cached client connection with timestamp
connections: RwLock<HashMap<(String, u16), CachedConnection>>,
/// Time-to-live for cached connections
ttl: Duration,
}

impl Default for BallistaClientPool {
fn default() -> Self {
Self::new()
}
}

impl BallistaClientPool {
    /// Creates an empty pool whose entries expire after `DEFAULT_CONNECTION_TTL`.
    pub fn new() -> Self {
        Self::with_ttl(DEFAULT_CONNECTION_TTL)
    }

    /// Creates an empty pool whose entries expire `ttl` after they were cached.
    ///
    /// With a `ttl` of [`Duration::ZERO`] no cached entry is ever considered
    /// valid, so every [`get_or_connect`](Self::get_or_connect) call opens a
    /// new connection.
    pub fn with_ttl(ttl: Duration) -> Self {
        Self {
            connections: RwLock::new(HashMap::new()),
            ttl,
        }
    }

    /// Returns `true` while `cached` is younger than the pool's TTL.
    ///
    /// This is the single definition of "not expired" used by both lookup and
    /// cleanup, so the two can never disagree.
    fn is_connection_valid(&self, cached: &CachedConnection) -> bool {
        cached.created_at.elapsed() < self.ttl
    }

    /// Drops the entry for `key` if one exists and it has outlived the TTL.
    ///
    /// Called after a failed reconnect so that a dead, expired client is not
    /// kept alive by the map until some later cleanup happens to run.
    fn evict_if_expired(&self, key: &(String, u16)) {
        let mut connections = self.connections.write();
        let expired = connections
            .get(key)
            .is_some_and(|cached| !self.is_connection_valid(cached));
        if expired {
            connections.remove(key);
            debug!(
                "Evicted expired connection to {}:{} after failed reconnect",
                key.0, key.1
            );
        }
    }

    /// Returns a client for `host:port`, reusing a cached one when possible.
    ///
    /// If the pool holds a non-expired entry for `(host, port)`, a clone of that
    /// client is returned. Otherwise a new connection is established, stored in
    /// the pool (replacing any expired entry) and returned.
    ///
    /// # Arguments
    ///
    /// * `host` - The hostname or IP address of the executor
    /// * `port` - The port number of the executor's Flight service
    /// * `max_message_size` - Maximum gRPC message size, passed to
    ///   `BallistaClient::try_new` when a new connection is opened. Entries are
    ///   keyed by `(host, port)` only, so a cached client is returned as-is
    ///   regardless of the value it was originally created with.
    ///
    /// # Errors
    ///
    /// Returns the error from `BallistaClient::try_new` if a new connection is
    /// needed and cannot be established. In that case an expired entry for the
    /// same key, if any, is evicted from the pool.
    pub async fn get_or_connect(
        &self,
        host: &str,
        port: u16,
        max_message_size: usize,
    ) -> BResult<BallistaClient> {
        let key = (host.to_string(), port);

        // Fast path: look for a live entry under the read lock. The guard lives
        // in its own scope so it is released before the `.await` below.
        {
            let connections = self.connections.read();
            if let Some(cached) = connections.get(&key) {
                if self.is_connection_valid(cached) {
                    debug!("Reusing cached connection to {host}:{port}");
                    return Ok(cached.client.clone());
                }
                debug!("Cached connection to {host}:{port} has expired, will create new one");
            }
        }

        // Slow path: connect without holding any lock. Several tasks may race to
        // connect to the same executor; only one result is cached and the other
        // freshly created clients are dropped.
        debug!("Creating new connection to {host}:{port}");
        let client = BallistaClient::try_new(host, port, max_message_size)
            .await
            // On failure, do not leave an expired client behind in the map.
            .inspect_err(|_| self.evict_if_expired(&key))?;

        let mut connections = self.connections.write();

        // Another task may have cached a valid connection while we were connecting;
        // prefer it so all callers share one client.
        if let Some(cached) = connections
            .get(&key)
            .filter(|cached| self.is_connection_valid(cached))
        {
            debug!("Using connection to {host}:{port} created by another task");
            return Ok(cached.client.clone());
        }

        // Cache our new connection, overwriting any expired entry for this key.
        connections.insert(
            key,
            CachedConnection {
                client: client.clone(),
                created_at: Instant::now(),
            },
        );
        Ok(client)
    }

    /// Removes the entry for `host:port` from the pool, if present.
    ///
    /// This can be used to force reconnection on the next request,
    /// for example after a connection error.
    pub fn remove(&self, host: &str, port: u16) {
        let key = (host.to_string(), port);
        let mut connections = self.connections.write();
        if connections.remove(&key).is_some() {
            debug!("Removed cached connection to {host}:{port}");
        }
    }

    /// Returns the number of entries in the pool, expired ones included.
    pub fn len(&self) -> usize {
        self.connections.read().len()
    }

    /// Returns `true` if the pool holds no entries at all.
    pub fn is_empty(&self) -> bool {
        self.connections.read().is_empty()
    }

    /// Removes every entry from the pool.
    pub fn clear(&self) {
        let mut connections = self.connections.write();
        let count = connections.len();
        connections.clear();
        debug!("Cleared {count} cached connections from pool");
    }

    /// Removes all expired entries from the pool and returns how many were removed.
    ///
    /// This method can be called periodically to proactively clean up
    /// stale connections rather than waiting for them to be accessed.
    pub fn remove_expired(&self) -> usize {
        let mut connections = self.connections.write();
        let initial_count = connections.len();
        connections.retain(|_, cached| self.is_connection_valid(cached));
        let removed = initial_count - connections.len();
        if removed > 0 {
            debug!("Removed {removed} expired connections from pool");
        }
        removed
    }
}

/// Returns the process-wide [`BallistaClientPool`].
///
/// The pool is created with [`BallistaClientPool::new`] the first time this
/// function is called; every later call returns a reference to that same
/// instance, so callers anywhere in the process share its cached connections.
pub fn global_client_pool() -> &'static BallistaClientPool {
    static SHARED_POOL: OnceLock<BallistaClientPool> = OnceLock::new();
    SHARED_POOL.get_or_init(BallistaClientPool::new)
}

/// [FlightDataStream] facilitates the transfer of shuffle data using the Arrow Flight protocol.
/// Internally, it invokes the `do_get` method on the Arrow Flight server, which returns a stream
/// of messages, each representing a record batch.
Expand Down Expand Up @@ -623,4 +801,73 @@ mod tests {

assert_eq!(batches, result.unwrap())
}

mod connection_pool_tests {
    use super::super::BallistaClientPool;
    use std::time::{Duration, Instant};

    // NOTE: per the original author's comments, a `BallistaClient` cannot be
    // built here without a running server, so none of these tests can place an
    // entry in the pool. They only exercise the pool's behaviour while empty.

    /// A pool built with `new` starts out empty.
    #[test]
    fn test_pool_new_with_default_ttl() {
        let pool = BallistaClientPool::new();
        assert!(pool.is_empty());
        assert_eq!(pool.len(), 0);
    }

    /// `with_ttl` stores exactly the TTL it was given.
    #[test]
    fn test_pool_with_custom_ttl() {
        let expected = Duration::from_secs(60);
        let pool = BallistaClientPool::with_ttl(expected);
        assert_eq!(pool.ttl, expected);
    }

    /// On an empty pool with a non-zero TTL, `remove_expired` removes nothing.
    #[test]
    fn test_is_connection_valid_not_expired() {
        let pool = BallistaClientPool::with_ttl(Duration::from_secs(60));
        let removed = pool.remove_expired();
        assert_eq!(removed, 0);
    }

    /// On an empty pool with a zero TTL, `remove_expired` still removes nothing
    /// and the pool stays empty.
    #[test]
    fn test_remove_expired_with_zero_ttl() {
        let pool = BallistaClientPool::with_ttl(Duration::ZERO);
        let removed = pool.remove_expired();
        assert_eq!(removed, 0);
        assert!(pool.is_empty());
    }

    /// Clearing an empty pool leaves it empty.
    #[test]
    fn test_pool_clear() {
        let pool = BallistaClientPool::new();
        pool.clear();
        assert!(pool.is_empty());
    }

    /// Removing a key that was never inserted must not panic.
    #[test]
    fn test_pool_remove_nonexistent() {
        let pool = BallistaClientPool::new();
        pool.remove("nonexistent", 12345);
        assert!(pool.is_empty());
    }

    /// Sanity check of the `Instant::elapsed` comparison that the TTL logic
    /// relies on; no `CachedConnection` is constructed here.
    #[test]
    fn test_cached_connection_created_at() {
        let started = Instant::now();
        assert!(started.elapsed() < Duration::from_secs(1));
    }
}
}
27 changes: 21 additions & 6 deletions ballista/core/src/execution_plans/shuffle_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use std::result;
use std::sync::Arc;
use std::task::{Context, Poll};

use crate::client::BallistaClient;
use crate::client::global_client_pool;
use crate::extension::SessionConfigExt;
use crate::serde::scheduler::{PartitionLocation, PartitionStats};

Expand Down Expand Up @@ -490,11 +490,13 @@ async fn fetch_partition_remote(
) -> result::Result<SendableRecordBatchStream, BallistaError> {
let metadata = &location.executor_meta;
let partition_id = &location.partition_id;
// TODO for shuffle client connections, we should avoid creating new connections again and again.
// And we should also avoid to keep alive too many connections for long time.
let host = metadata.host.as_str();
let port = metadata.port;
let mut ballista_client = BallistaClient::try_new(host, port, max_message_size)

// Use the global connection pool to reuse connections across partition fetches
let pool = global_client_pool();
let mut ballista_client = pool
.get_or_connect(host, port, max_message_size)
.await
.map_err(|error| match error {
// map grpc connection error to partition fetch error.
Expand All @@ -507,7 +509,7 @@ async fn fetch_partition_remote(
other => other,
})?;

ballista_client
let result = ballista_client
.fetch_partition(
&metadata.id,
partition_id,
Expand All @@ -516,7 +518,20 @@ async fn fetch_partition_remote(
port,
flight_transport,
)
.await
.await;

// On connection-related errors, remove the connection from the pool
// so the next request will create a fresh connection
if let Err(BallistaError::FetchFailed(..)) | Err(BallistaError::GrpcActionError(_)) =
&result
{
debug!(
"Removing potentially stale connection to {host}:{port} from pool after error"
);
pool.remove(host, port);
}

result
}

async fn fetch_partition_local(
Expand Down
Loading