Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions crates/apollo_infra/src/component_definitions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ where
pub enum ServerError {
#[error("Could not deserialize client request: {0}")]
RequestDeserializationFailure(String),
#[error("Request body too large: {0}")]
RequestBodyTooLarge(String),
}

#[derive(Debug)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,10 +190,8 @@ where
match Limited::new(http_request.into_body(), max_request_body_bytes).collect().await {
Ok(collected) => collected.to_bytes(),
Err(err) => {
error!("Failed to collect request body: {err}");
let server_error = ServerError::RequestDeserializationFailure(
"Request body too large".to_string(),
);
warn!("Request body too large: {err}");
let server_error = ServerError::RequestBodyTooLarge(err.to_string());
return Ok(HyperResponse::builder()
.status(StatusCode::PAYLOAD_TOO_LARGE)
.header(CONTENT_TYPE, APPLICATION_OCTET_STREAM)
Expand Down
151 changes: 151 additions & 0 deletions crates/apollo_infra/src/tests/max_request_size_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
use std::convert::Infallible;

use apollo_proc_macros::unique_u16;
use bytes::Bytes;
use http::header::CONTENT_TYPE;
use http::{StatusCode, Uri};
use http_body_util::{BodyExt, Full};
use hyper::body::Incoming;
use hyper::service::service_fn;
use hyper::{Request, Response};
use hyper_util::client::legacy::Client;
use hyper_util::rt::{TokioExecutor, TokioIo};
use hyper_util::server::conn::auto::Builder as Http2ServerBuilder;
use tokio::net::TcpListener;
use tokio::sync::mpsc::channel;
use tokio::task;

use crate::component_client::{ClientError, LocalComponentClient, RemoteClientConfig};
use crate::component_definitions::{
RequestId,
RequestWrapper,
ServerError,
APPLICATION_OCTET_STREAM,
REQUEST_ID_HEADER,
};
use crate::component_server::{
ComponentServerStarter,
LocalComponentServer,
LocalServerConfig,
RemoteComponentServer,
RemoteServerConfig,
};
use crate::serde_utils::SerdeWrapper;
use crate::tests::test_utils::{
available_ports_factory,
ComponentA,
ComponentAClient,
ComponentAClientTrait,
ComponentARequest,
ComponentAResponse,
ComponentBClient,
FAST_FAILING_CLIENT_CONFIG,
TEST_LOCAL_CLIENT_METRICS,
TEST_LOCAL_SERVER_METRICS,
TEST_REMOTE_CLIENT_METRICS,
TEST_REMOTE_SERVER_METRICS,
};

/// Server rejects a request whose body exceeds `max_request_body_bytes` with 413 and
/// `ServerError::RequestBodyTooLarge`.
#[tokio::test]
async fn request_body_too_large() {
let mut available_ports = available_ports_factory(unique_u16!());
let a_socket = available_ports.get_next_local_host_socket();
let dummy_b_socket = available_ports.get_next_local_host_socket();

// B client points at a non-existent server; it will never be called because the oversized
// request is rejected at the HTTP layer before any component logic runs.
let b_remote_client = ComponentBClient::new(
RemoteClientConfig::default(),
&dummy_b_socket.ip().to_string(),
dummy_b_socket.port(),
&TEST_REMOTE_CLIENT_METRICS,
);
let component_a = ComponentA::new(Box::new(b_remote_client));

let (tx_a, rx_a) = channel::<RequestWrapper<ComponentARequest, ComponentAResponse>>(32);
let a_local_client = LocalComponentClient::new(tx_a, &TEST_LOCAL_CLIENT_METRICS);

let mut local_server = LocalComponentServer::new(
component_a,
&LocalServerConfig::default(),
rx_a,
&TEST_LOCAL_SERVER_METRICS,
);
task::spawn(async move {
let _ = local_server.start().await;
});

let server_config = RemoteServerConfig { max_request_body_bytes: 1, ..Default::default() };
let mut remote_server = RemoteComponentServer::new(
a_local_client,
server_config,
a_socket.port(),
&TEST_REMOTE_SERVER_METRICS,
);
task::spawn(async move {
let _ = remote_server.start().await;
});
task::yield_now().await;

let uri: Uri = format!("http://[{}]:{}/", a_socket.ip(), a_socket.port()).parse().unwrap();
let http_request = Request::post(uri)
.header(CONTENT_TYPE, APPLICATION_OCTET_STREAM)
.header(REQUEST_ID_HEADER, RequestId::generate().to_string())
.body(Full::new(Bytes::from("x".repeat(1024))))
.unwrap();
let http_response =
Client::builder(TokioExecutor::new()).build_http().request(http_request).await.unwrap();

assert_eq!(http_response.status(), StatusCode::PAYLOAD_TOO_LARGE);
let body_bytes = http_response.into_body().collect().await.unwrap().to_bytes();
let server_error = SerdeWrapper::<ServerError>::wrapper_deserialize(&body_bytes).unwrap();
assert!(matches!(server_error, ServerError::RequestBodyTooLarge(_)));
}

/// Client returns `ResponseParsingFailure` when the server's response body exceeds
/// `max_response_body_bytes`.
#[tokio::test]
async fn response_body_too_large() {
let socket = available_ports_factory(unique_u16!()).get_next_local_host_socket();
task::spawn(async move {
async fn handler(
_http_request: Request<Incoming>,
) -> Result<Response<Full<Bytes>>, Infallible> {
Ok(Response::builder()
.status(StatusCode::OK)
.header(CONTENT_TYPE, APPLICATION_OCTET_STREAM)
.body(Full::new(Bytes::from(vec![0u8; 1024])))
.unwrap())
}

let listener = TcpListener::bind(&socket).await.unwrap();
loop {
let Ok((stream, _)) = listener.accept().await else { continue };
let io = TokioIo::new(stream);
let service = service_fn(|req| async move { handler(req).await });
tokio::spawn(async move {
let _ = Http2ServerBuilder::new(TokioExecutor::new())
.http2()
.serve_connection(io, service)
.await;
});
}
});
task::yield_now().await;

let client_config =
RemoteClientConfig { max_response_body_bytes: 1, retries: 0, ..FAST_FAILING_CLIENT_CONFIG };
let client = ComponentAClient::new(
client_config,
&socket.ip().to_string(),
socket.port(),
&TEST_REMOTE_CLIENT_METRICS,
);

let Err(error) = client.a_get_value().await else {
panic!("Expected an error");
};
assert!(matches!(error, ClientError::ResponseParsingFailure(_)), "unexpected error: {error}");
}
1 change: 1 addition & 0 deletions crates/apollo_infra/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ pub(crate) mod test_utils;
mod concurrent_servers_test;
mod local_component_client_server_test;
mod local_request_prioritization_test;
mod max_request_size_test;
mod remote_client_connection_eviction_test;
mod remote_component_client_server_test;
mod server_metrics_test;
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ use crate::tests::test_utils::{
ResultA,
ResultB,
ValueB,
FAST_FAILING_CLIENT_CONFIG,
MAX_CONCURRENCY,
TEST_LOCAL_CLIENT_METRICS,
TEST_LOCAL_SERVER_METRICS,
Expand All @@ -89,18 +90,6 @@ const ARBITRARY_DATA: &str = "arbitrary data";
// ServerError::RequestDeserializationFailure error message.
const DESERIALIZE_REQ_ERROR_MESSAGE: &str = "Could not deserialize client request";
const BAD_REQUEST_ERROR_MESSAGE: &str = "Got status code: 400 Bad Request";
const FAST_FAILING_CLIENT_CONFIG: RemoteClientConfig = RemoteClientConfig {
retries: 0,
idle_connections: 0,
keepalive_timeout_ms: 0,
max_retry_interval_ms: 0,
initial_retry_delay_ms: 0,
attempts_per_log: 1,
connection_timeout_ms: 500,
request_timeout_ms: 1000,
set_tcp_nodelay: true,
max_response_body_bytes: usize::MAX,
};

#[async_trait]
impl ComponentAClientTrait for RemoteComponentClient<ComponentARequest, ComponentAResponse> {
Expand Down
15 changes: 14 additions & 1 deletion crates/apollo_infra/src/tests/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::sync::Semaphore;

use crate::component_client::{ClientResult, RemoteComponentClient};
use crate::component_client::{ClientResult, RemoteClientConfig, RemoteComponentClient};
use crate::component_definitions::{ComponentRequestHandler, ComponentStarter, PrioritizedRequest};
use crate::component_server::RemoteServerConfig;
use crate::metrics::{
Expand All @@ -43,6 +43,19 @@ pub(crate) type ComponentBClient = RemoteComponentClient<ComponentBRequest, Comp
pub(crate) const VALID_VALUE_A: ValueA = Felt::ONE;
pub(crate) const MAX_CONCURRENCY: usize = 10;

pub(crate) const FAST_FAILING_CLIENT_CONFIG: RemoteClientConfig = RemoteClientConfig {
retries: 0,
idle_connections: 0,
keepalive_timeout_ms: 0,
max_retry_interval_ms: 0,
initial_retry_delay_ms: 0,
attempts_per_log: 1,
connection_timeout_ms: 500,
request_timeout_ms: 1000,
set_tcp_nodelay: true,
max_response_body_bytes: usize::MAX,
};

#[derive(Serialize, Deserialize, Clone, AsRefStr, EnumDiscriminants)]
#[strum_discriminants(
name(ComponentARequestLabelValue),
Expand Down
Loading