Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/apollo_infra/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ rand.workspace = true
rstest.workspace = true
serde = { workspace = true, features = ["derive"] }
serde_json.workspace = true
socket2 = { workspace = true, features = ["all"] }
starknet_api.workspace = true
static_assertions.workspace = true
thiserror.workspace = true
Expand All @@ -49,6 +50,5 @@ metrics.workspace = true
metrics-exporter-prometheus.workspace = true
once_cell.workspace = true
pretty_assertions.workspace = true
socket2 = { workspace = true, features = ["all"] }
starknet-types-core.workspace = true
strum = { workspace = true, features = ["derive"] }
2 changes: 1 addition & 1 deletion crates/apollo_infra/src/component_client/definitions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use serde::de::DeserializeOwned;
use serde::Serialize;
use thiserror::Error;

use super::{LocalComponentClient, RemoteComponentClient};
use crate::component_client::{LocalComponentClient, RemoteComponentClient};
use crate::component_definitions::ServerError;

#[derive(Clone, Debug, Error, PartialEq, Eq)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use tracing::field::{display, Empty};
use tracing::{debug, instrument, trace, warn};
use validator::{Validate, ValidationError};

use super::definitions::{ClientError, ClientResult};
use crate::component_client::{ClientError, ClientResult};
use crate::component_definitions::{
ComponentClient,
RequestId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@ use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
use hyper_util::server::conn::auto::Builder as ServerBuilder;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use socket2::{SockRef, TcpKeepalive};
use tokio::net::TcpListener;
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use tracing::{debug, error, instrument, trace, warn};
use validator::Validate;

use crate::component_client::remote_component_client::validate_keepalive_timeout_ms;
use crate::component_client::{ClientError, LocalComponentClient};
use crate::component_definitions::{
ComponentClient,
Expand All @@ -33,6 +35,7 @@ use crate::component_definitions::{
APPLICATION_OCTET_STREAM,
BUSY_PREVIOUS_REQUESTS_MSG,
REQUEST_ID_HEADER,
TCP_KEEPALIVE_FACTOR,
};
use crate::component_server::ComponentServerStarter;
use crate::metrics::RemoteServerMetrics;
Expand All @@ -46,6 +49,9 @@ const DEFAULT_MAX_CONCURRENCY: usize = 128;
const DEFAULT_MAX_REQUEST_BODY_BYTES: usize = 8 * 1024 * 1024;
const DEFAULT_KEEPALIVE_INTERVAL_MS: u64 = 30_000;
const DEFAULT_KEEPALIVE_TIMEOUT_MS: u64 = 10_000;
// Number of unanswered TCP keepalive probes before the OS declares the connection dead.
// 3 probes × keepalive_interval gives a ~90 s probe window at the default interval.
const TCP_KEEPALIVE_RETRIES: u32 = 3;

macro_rules! serve_connection {
(
Expand Down Expand Up @@ -80,6 +86,7 @@ pub struct RemoteServerConfig {
pub max_concurrency: usize,
pub max_request_body_bytes: usize,
pub keepalive_interval_ms: u64,
#[validate(custom(function = "validate_keepalive_timeout_ms"))]
Comment thread
cursor[bot] marked this conversation as resolved.
pub keepalive_timeout_ms: u64,
}

Expand Down Expand Up @@ -393,6 +400,10 @@ where
panic!("Failed to bind remote component server socket {:#?}: {e}", bind_socket)
});

let max_streams = self.config.max_streams_per_connection;
let keepalive_interval = Duration::from_millis(self.config.keepalive_interval_ms);
let keepalive_timeout = Duration::from_millis(self.config.keepalive_timeout_ms);

loop {
let (stream, peer_addr) = match listener.accept().await {
Ok(conn) => conn,
Expand All @@ -407,10 +418,15 @@ where
warn!("Failed to set TCP_NODELAY: {e}");
}

let tcp_keepalive = TcpKeepalive::new()
.with_time(keepalive_timeout.mul_f64(TCP_KEEPALIVE_FACTOR))
.with_interval(keepalive_interval)
.with_retries(TCP_KEEPALIVE_RETRIES);
if let Err(e) = SockRef::from(&stream).set_tcp_keepalive(&tcp_keepalive) {
error!("Failed to set TCP keepalive: {e}");
}

let io = TokioIo::new(stream);
let max_streams = self.config.max_streams_per_connection;
let keepalive_interval = Duration::from_millis(self.config.keepalive_interval_ms);
let keepalive_timeout = Duration::from_millis(self.config.keepalive_timeout_ms);

tokio::spawn(per_connection_service(
io,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -651,13 +651,13 @@ async fn retry_request() {
async fn tcp_keepalive_idle_time_matches_config() {
// 2000 * 1.5 = 3000 ms = 3 s exactly; socket2 stores TCP_KEEPIDLE in whole seconds, so the
// configured duration must be a whole number of seconds or the comparison fails.
const IDLE_TIMEOUT_MS: u64 = 2000;
const KEEPALIVE_TIMEOUT_MS: u64 = 2000;
let expected_keepalive_idle =
Duration::from_millis(IDLE_TIMEOUT_MS).mul_f64(TCP_KEEPALIVE_FACTOR);
Duration::from_millis(KEEPALIVE_TIMEOUT_MS).mul_f64(TCP_KEEPALIVE_FACTOR);
assert_eq!(
expected_keepalive_idle.subsec_nanos(),
0,
"IDLE_TIMEOUT_MS * TCP_KEEPALIVE_FACTOR must be a whole number of seconds"
"KEEPALIVE_TIMEOUT_MS * TCP_KEEPALIVE_FACTOR must be a whole number of seconds"
);

let mut ports = available_ports_factory(unique_u16!());
Expand All @@ -667,7 +667,7 @@ async fn tcp_keepalive_idle_time_matches_config() {
setup_for_tests(VALID_VALUE_A, a_socket, b_socket, MAX_CONCURRENCY, None).await;

let client = ComponentAClient::new(
RemoteClientConfig { keepalive_timeout_ms: IDLE_TIMEOUT_MS, ..Default::default() },
RemoteClientConfig { keepalive_timeout_ms: KEEPALIVE_TIMEOUT_MS, ..Default::default() },
&a_socket.ip().to_string(),
a_socket.port(),
&TEST_REMOTE_CLIENT_METRICS,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::time::Duration;

use apollo_proc_macros::unique_u16;
use rstest::rstest;
use socket2::{SockRef, TcpKeepalive};
use tokio::io::AsyncReadExt;
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::mpsc::channel;
use tokio::task;
use tokio::time::{sleep, timeout};
use tokio::time::timeout;

use crate::component_client::LocalComponentClient;
use crate::component_definitions::RequestWrapper;
use crate::component_server::{ComponentServerStarter, RemoteComponentServer, RemoteServerConfig};
use crate::tests::test_utils::{
available_ports_factory,
connect_zombie,
contains_goaway_frame,
dummy_remote_server_config,
ComponentARequest,
ComponentAResponse,
Expand All @@ -20,18 +25,61 @@ use crate::tests::test_utils::{
TEST_REMOTE_SERVER_METRICS,
};

/// Verifies that the server closes a zombie connection after the HTTP keepalive interval and
/// timeout elapse without receiving a PING response.
/// Verifies that `SO_KEEPALIVE` is enabled on a server-accepted socket once TCP keepalive is set.
///
/// The test accepts the connection itself so it owns the `TcpStream` and can inspect socket
/// options via `SockRef::from` without any unsafe FD scanning.
#[rstest]
#[tokio::test]
async fn zombie_connection_is_evicted_via_http_keepalive() {
async fn server_tcp_keepalive_socket_option_matches_config() {
const SUFFICIENTLY_LONG_KEEPALIVE_TIMEOUT_MS: u64 = 1000;

let listener =
TcpListener::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0)).await.unwrap();
let server_addr = listener.local_addr().unwrap();

let _client_stream = TcpStream::connect(server_addr).await.unwrap();
let (accepted_stream, _) = listener.accept().await.unwrap();

// Mirror the keepalive logic in RemoteComponentServer::start().
let keepalive = TcpKeepalive::new()
.with_time(Duration::from_millis(SUFFICIENTLY_LONG_KEEPALIVE_TIMEOUT_MS));
SockRef::from(&accepted_stream).set_tcp_keepalive(&keepalive).unwrap();

assert!(
SockRef::from(&accepted_stream).keepalive().unwrap(),
"SO_KEEPALIVE on the accepted socket should reflect idle_time_ms"
);
Comment thread
cursor[bot] marked this conversation as resolved.
}

/// Verifies that the server evicts a zombie connection via HTTP/2 PING after the keepalive
/// interval and timeout elapse without receiving a response, and that the TCP keepalive socket
/// option configured on accepted sockets does not interfere with this mechanism.
///
/// # Why TCP keepalive cannot evict the connection in this setup
///
/// The server always configures TCP keepalive on accepted sockets. The two eviction mechanisms
/// are distinguishable by how the zombie socket observes the close:
/// - **TCP keepalive**: the kernel sends a RST after all probes go unanswered → `read_to_end`
/// returns `Err(connection reset by peer)`.
/// - **HTTP/2 PING timeout (hyper)**: the server sends a GOAWAY frame and then closes gracefully →
/// `read_to_end` returns `Ok` with data containing a GOAWAY frame.
///
/// On loopback (`127.0.0.1`) the kernel itself ACKs TCP keepalive probes, even when the remote
/// application ignores them. Probes therefore never go unanswered, and the kernel never sends a
/// RST. Testing TCP keepalive eviction would require a setup where probes can genuinely be
/// dropped — for example, a `veth` pair in separate network namespaces with `tc netem` packet
/// loss applied to ACKs. In the unit-test environment that is not available, so the test asserts
/// `Ok` + GOAWAY to confirm the eviction is via HTTP/2 PING and that TCP keepalive does not
/// interfere.
#[tokio::test]
async fn tcp_keepalive_does_not_interfere_with_http_keepalive_eviction() {
const KEEPALIVE_INTERVAL_MS: u64 = 100;
const KEEPALIVE_TIMEOUT_MS: u64 = 100;
Comment thread
cursor[bot] marked this conversation as resolved.
const MARGIN_MS: u64 = 500;

let socket = available_ports_factory(unique_u16!()).get_next_local_host_socket();

// Start a RemoteComponentServer with very short keepalive values.
// The local channel receiver is intentionally dropped — no requests will be sent.
let (tx, _rx) = channel::<RequestWrapper<ComponentARequest, ComponentAResponse>>(32);
let local_client = LocalComponentClient::<ComponentARequest, ComponentAResponse>::new(
tx,
Expand All @@ -53,17 +101,14 @@ async fn zombie_connection_is_evicted_via_http_keepalive() {

let mut zombie = connect_zombie(socket).await;

// Wait for the keepalive cycle to fire and time out.
sleep(Duration::from_millis(KEEPALIVE_INTERVAL_MS + KEEPALIVE_TIMEOUT_MS + MARGIN_MS)).await;

// The server should have closed the connection; read_to_end should return quickly with
// whatever GOAWAY bytes were sent, and then EOF.
let mut remainder = Vec::new();
let read_result =
timeout(Duration::from_millis(MARGIN_MS), zombie.read_to_end(&mut remainder)).await;
assert!(
read_result.is_ok(),
"Server should have closed the zombie connection after keepalive timeout, but the \
connection is still open"
);
// Closure must be a graceful HTTP/2 GOAWAY (Ok), not a TCP RST (Err).
let mut buf = Vec::new();
let bytes_read = timeout(
Duration::from_millis(KEEPALIVE_INTERVAL_MS + KEEPALIVE_TIMEOUT_MS + MARGIN_MS),
zombie.read_to_end(&mut buf),
)
.await
.expect("server should have closed the zombie connection after keepalive timeout");
bytes_read.expect("connection should close cleanly via GOAWAY, not via TCP RST");
assert!(contains_goaway_frame(&buf), "server should have sent a GOAWAY frame before closing");
}
17 changes: 17 additions & 0 deletions crates/apollo_infra/src/tests/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -416,3 +416,20 @@ pub(crate) async fn connect_zombie(addr: SocketAddr) -> TcpStream {
}
stream
}

/// Returns `true` if `data` contains at least one HTTP/2 GOAWAY frame (type `0x07`).
///
/// Walks the byte stream frame by frame: each HTTP/2 frame starts with a 9-byte
/// header whose first three bytes are the big-endian payload length and whose
/// fourth byte is the frame type. Trailing bytes too short to form a full frame
/// header are ignored.
pub(crate) fn contains_goaway_frame(data: &[u8]) -> bool {
    const GOAWAY_FRAME_TYPE: u8 = 0x07;
    const H2_FRAME_HEADER_LEN: usize = 9;

    let mut remaining = data;
    while remaining.len() >= H2_FRAME_HEADER_LEN {
        // Frame type lives at byte 3 of the 9-byte frame header.
        if remaining[3] == GOAWAY_FRAME_TYPE {
            return true;
        }
        // 24-bit big-endian payload length occupies bytes 0..=2 of the header.
        let payload_len = (usize::from(remaining[0]) << 16)
            | (usize::from(remaining[1]) << 8)
            | usize::from(remaining[2]);
        let frame_len = H2_FRAME_HEADER_LEN + payload_len;
        if frame_len > remaining.len() {
            // Truncated final frame: nothing further to inspect.
            break;
        }
        remaining = &remaining[frame_len..];
    }
    false
}
Loading