Skip to content

Commit e6749d5

Browse files
apollo_infra: add tcp keepalive to remote server accepted sockets
Mirror the client-side SO_KEEPALIVE behaviour on the server: enable TCP keepalive probes on each accepted socket, with the idle time derived from the existing keepalive_timeout_ms setting scaled by TCP_KEEPALIVE_FACTOR (the same rule the client uses). This lets the OS detect dead clients that disappear without sending FIN/RST. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 317605e commit e6749d5

6 files changed

Lines changed: 213 additions & 41 deletions

File tree

crates/apollo_infra/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ metrics-exporter-prometheus.workspace = true
2828
rand.workspace = true
2929
rstest.workspace = true
3030
serde = { workspace = true, features = ["derive"] }
31+
socket2.workspace = true
3132
serde_json.workspace = true
3233
starknet_api.workspace = true
3334
static_assertions.workspace = true

crates/apollo_infra/src/component_client/remote_component_client.rs

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ use hyper_util::client::legacy::Client;
1818
use hyper_util::rt::TokioExecutor;
1919
use serde::de::DeserializeOwned;
2020
use serde::{Deserialize, Serialize};
21-
use static_assertions::const_assert;
2221
use tokio::sync::Mutex;
2322
use tokio::time::Instant;
2423
use tracing::field::{display, Empty};
@@ -32,6 +31,7 @@ use crate::component_definitions::{
3231
ServerError,
3332
APPLICATION_OCTET_STREAM,
3433
REQUEST_ID_HEADER,
34+
TCP_KEEPALIVE_FACTOR,
3535
};
3636
use crate::metrics::RemoteClientMetrics;
3737
use crate::requests::LabeledRequest;
@@ -45,10 +45,6 @@ pub const DEFAULT_RETRIES: usize = 15;
4545
pub const REQUEST_TIMEOUT_ERROR_MESSAGE: &str = "request timed out";
4646

4747
const DEFAULT_IDLE_CONNECTIONS: usize = 10;
48-
pub(crate) const TCP_IDLE_TIMEOUT_FACTOR: f64 = 1.5;
49-
// Ensure tcp connection timeout is greater than http2 connection timeout by requiring a factor
50-
// greater than 1.
51-
const_assert!(TCP_IDLE_TIMEOUT_FACTOR > 1.0);
5248

5349
// 8 MiB — bounds memory materialized from a single response as defense in depth.
5450
const DEFAULT_MAX_RESPONSE_BODY_BYTES: usize = 8 * 1024 * 1024;
@@ -66,8 +62,8 @@ pub struct RemoteClientConfig {
6662
pub retries: usize,
6763
pub idle_connections: usize,
6864
// Determines client connection timeouts. Used plainly for HTTP/2 connections, and with a
69-
// `TCP_IDLE_TIMEOUT_FACTOR` for TCP connections.
70-
#[validate(custom(function = "validate_tcp_exceeds_http_keepalive"))]
65+
// `TCP_KEEPALIVE_FACTOR` for TCP connections.
66+
#[validate(custom(function = "validate_keepalive_timeout_ms"))]
7167
pub keepalive_timeout_ms: u64,
7268
pub attempts_per_log: usize,
7369
pub initial_retry_delay_ms: u64,
@@ -96,24 +92,39 @@ impl Default for RemoteClientConfig {
9692
}
9793

9894
/// Validates that the TCP keepalive duration (at second granularity, as the OS stores
99-
/// `TCP_KEEPIDLE` in whole seconds) is greater than or equal to the HTTP keepalive duration
100-
/// (millisecond granularity). If the configured `keepalive_timeout_ms * TCP_IDLE_TIMEOUT_FACTOR` is
101-
/// less than 1 second, truncation to whole seconds yields 0 s, making the TCP keepalive shorter
102-
/// than the HTTP keepalive.
103-
fn validate_tcp_exceeds_http_keepalive(keepalive_timeout_ms: u64) -> Result<(), ValidationError> {
95+
/// `TCP_KEEPIDLE` in whole seconds) is positive and greater than or equal to the HTTP keepalive
96+
/// duration (millisecond granularity). If the configured
97+
/// `keepalive_timeout_ms * TCP_KEEPALIVE_FACTOR` truncates to 0 s, the kernel rejects the socket
98+
/// option with `EINVAL`. If the truncated value is less than `keepalive_timeout_ms`, the TCP
99+
/// keepalive fires before the HTTP-level timeout.
100+
pub(crate) fn validate_keepalive_timeout_ms(
101+
keepalive_timeout_ms: u64,
102+
) -> Result<(), ValidationError> {
104103
let http_keepalive = Duration::from_millis(keepalive_timeout_ms);
105-
let tcp_keepalive_raw = http_keepalive.mul_f64(TCP_IDLE_TIMEOUT_FACTOR);
104+
let tcp_keepalive_raw = http_keepalive.mul_f64(TCP_KEEPALIVE_FACTOR);
106105
// TCP_KEEPIDLE is stored in whole seconds; fractional seconds are truncated by the OS.
107-
let tcp_keepalive = Duration::from_secs(tcp_keepalive_raw.as_secs());
106+
let tcp_keepalive_secs = tcp_keepalive_raw.as_secs();
107+
if tcp_keepalive_secs == 0 {
108+
return Err(create_validation_error(
109+
format!(
110+
"TCP keepalive rounds to 0 s (keepalive_timeout_ms={keepalive_timeout_ms}, \
111+
factor={TCP_KEEPALIVE_FACTOR}): increase keepalive_timeout_ms so that \
112+
keepalive_timeout_ms * {TCP_KEEPALIVE_FACTOR} is at least 1 s",
113+
),
114+
"tcp_keepalive_zero",
115+
"TCP keepalive (second granularity) must be > 0 s.",
116+
));
117+
}
118+
let tcp_keepalive = Duration::from_secs(tcp_keepalive_secs);
108119
if tcp_keepalive >= http_keepalive {
109120
Ok(())
110121
} else {
111122
Err(create_validation_error(
112123
format!(
113-
"TCP keepalive ({} s) is shorter than HTTP keepalive ({keepalive_timeout_ms} ms): \
114-
increase keepalive_timeout_ms so that keepalive_timeout_ms * \
115-
{TCP_IDLE_TIMEOUT_FACTOR} rounds to at least {keepalive_timeout_ms} ms",
116-
tcp_keepalive.as_secs(),
124+
"TCP keepalive ({tcp_keepalive_secs} s) is shorter than HTTP keepalive \
125+
({keepalive_timeout_ms} ms): increase keepalive_timeout_ms so that \
126+
keepalive_timeout_ms * {TCP_KEEPALIVE_FACTOR} rounds to at least \
127+
{keepalive_timeout_ms} ms",
117128
),
118129
"tcp_keepalive_shorter_than_http_keepalive",
119130
"TCP keepalive (second granularity) must be >= HTTP keepalive (ms granularity).",
@@ -230,7 +241,7 @@ where
230241
let mut connector = HttpConnector::new();
231242
connector.set_nodelay(config.set_tcp_nodelay);
232243
connector.set_connect_timeout(Some(Duration::from_millis(config.connection_timeout_ms)));
233-
connector.set_keepalive(Some(idle_timeout.mul_f64(TCP_IDLE_TIMEOUT_FACTOR)));
244+
connector.set_keepalive(Some(idle_timeout.mul_f64(TCP_KEEPALIVE_FACTOR)));
234245

235246
// Create the HTTP/2 client.
236247
let client = Client::builder(TokioExecutor::new())

crates/apollo_infra/src/component_client/remote_component_client_test.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ fn tcp_keepalive_validation_passes_when_tcp_equals_http_keepalive() {
1818
assert!(config.validate().is_ok());
1919
}
2020

21-
// keepalive_timeout_ms = 100 ms: tcp_raw = 150 ms, tcp_whole_secs = 0 s = 0 ms, which is less than
22-
// http_keepalive = 100 ms.
21+
// keepalive_timeout_ms = 100 ms: tcp_raw = 150 ms, tcp_whole_secs = 0 s, which is rejected as a
22+
// zero TCP keepalive (and is also less than http_keepalive = 100 ms).
2323
#[test]
2424
fn tcp_keepalive_validation_fails_when_tcp_truncated_below_http_keepalive() {
2525
let config = RemoteClientConfig { keepalive_timeout_ms: 100, ..Default::default() };

crates/apollo_infra/src/component_definitions.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use derive_more::{Display, FromStr};
66
use rand::random;
77
use serde::de::DeserializeOwned;
88
use serde::{Deserialize, Serialize};
9+
use static_assertions::const_assert;
910
use thiserror::Error;
1011
use tokio::sync::mpsc::{Receiver, Sender};
1112
use tokio::time::Instant;
@@ -17,6 +18,11 @@ use crate::requests::LabeledRequest;
1718
pub(crate) const APPLICATION_OCTET_STREAM: &str = "application/octet-stream";
1819
pub const BUSY_PREVIOUS_REQUESTS_MSG: &str = "Server is busy addressing previous requests";
1920

21+
pub(crate) const TCP_KEEPALIVE_FACTOR: f64 = 1.5;
22+
// Ensure the TCP keepalive idle time exceeds the HTTP/2 keepalive timeout by requiring a factor
23+
// greater than 1.
24+
const_assert!(TCP_KEEPALIVE_FACTOR > 1.0);
25+
2026
#[async_trait]
2127
pub trait ComponentRequestHandler<Request, Response> {
2228
async fn handle_request(&mut self, request: Request) -> Response;

crates/apollo_infra/src/component_server/remote_component_server.rs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,13 @@ use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
2020
use hyper_util::server::conn::auto::Builder as ServerBuilder;
2121
use serde::de::DeserializeOwned;
2222
use serde::{Deserialize, Serialize};
23+
use socket2::{SockRef, TcpKeepalive};
2324
use tokio::net::TcpListener;
2425
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
2526
use tracing::{debug, error, instrument, trace, warn};
2627
use validator::Validate;
2728

29+
use crate::component_client::remote_component_client::validate_keepalive_timeout_ms;
2830
use crate::component_client::{ClientError, LocalComponentClient};
2931
use crate::component_definitions::{
3032
ComponentClient,
@@ -33,6 +35,7 @@ use crate::component_definitions::{
3335
APPLICATION_OCTET_STREAM,
3436
BUSY_PREVIOUS_REQUESTS_MSG,
3537
REQUEST_ID_HEADER,
38+
TCP_KEEPALIVE_FACTOR,
3639
};
3740
use crate::component_server::ComponentServerStarter;
3841
use crate::metrics::RemoteServerMetrics;
@@ -80,6 +83,7 @@ pub struct RemoteServerConfig {
8083
pub max_concurrency: usize,
8184
pub max_request_body_bytes: usize,
8285
pub keepalive_interval_ms: u64,
86+
#[validate(custom(function = "validate_keepalive_timeout_ms"))]
8387
pub keepalive_timeout_ms: u64,
8488
}
8589

@@ -393,6 +397,10 @@ where
393397
panic!("Failed to bind remote component server socket {:#?}: {e}", bind_socket)
394398
});
395399

400+
let max_streams = self.config.max_streams_per_connection;
401+
let keepalive_interval = Duration::from_millis(self.config.keepalive_interval_ms);
402+
let keepalive_timeout = Duration::from_millis(self.config.keepalive_timeout_ms);
403+
396404
loop {
397405
let (stream, peer_addr) = match listener.accept().await {
398406
Ok(conn) => conn,
@@ -407,10 +415,13 @@ where
407415
warn!("Failed to set TCP_NODELAY: {e}");
408416
}
409417

418+
let tcp_keepalive =
419+
TcpKeepalive::new().with_time(keepalive_timeout.mul_f64(TCP_KEEPALIVE_FACTOR));
420+
if let Err(e) = SockRef::from(&stream).set_tcp_keepalive(&tcp_keepalive) {
421+
error!("Failed to set TCP keepalive: {e}");
422+
}
423+
410424
let io = TokioIo::new(stream);
411-
let max_streams = self.config.max_streams_per_connection;
412-
let keepalive_interval = Duration::from_millis(self.config.keepalive_interval_ms);
413-
let keepalive_timeout = Duration::from_millis(self.config.keepalive_timeout_ms);
414425

415426
tokio::spawn(per_connection_service(
416427
io,

0 commit comments

Comments
 (0)