Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 173 additions & 0 deletions src/rpc/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2601,4 +2601,177 @@ mod tests {
err
);
}

/// Trailing bytes after a complete request payload must remain in the source buffer.
/// `upgrade_inbound` then drains the stream and rejects them with `InvalidRequest`
/// per `phase0/p2p-interface.md`.
#[test]
fn test_inbound_codec_leaves_trailing_bytes_in_buffer() {
let config = Arc::new(Config::mainnet());
let protocol = ProtocolId::new(SupportedProtocol::PingV1, Encoding::SSZSnappy);
let fork_context = Arc::new(ForkContext::dummy::<Mainnet>(&config, Phase::Phase0));

let mut wire = BytesMut::new();
let mut outbound = SSZSnappyOutboundCodec::<Mainnet>::new(
protocol.clone(),
config.max_payload_size,
fork_context.clone(),
);
outbound
.encode(RequestType::Ping(Ping { data: 1 }), &mut wire)
.expect("Ping encodes");

let trailing = [0xDEu8; 16];
wire.extend_from_slice(&trailing);

let mut inbound = SSZSnappyInboundCodec::<Mainnet>::new(
config.clone_arc(),
protocol,
config.max_payload_size,
fork_context,
);
let request = inbound
.decode(&mut wire)
.expect("decode succeeds")
.expect("returns a request");
assert!(matches!(request, RequestType::Ping(_)));
assert_eq!(
wire.len(),
trailing.len(),
"codec must leave trailing bytes in source buffer for drain check",
);
}

/// FramedRead-level: trailing bytes after a clean request are detected on the next poll,
/// matching what `upgrade_inbound`'s post-decode drain observes.
#[tokio::test]
async fn test_framed_read_yields_extra_chunk_when_trailing_bytes() {
use futures::StreamExt as _;
use tokio_util::codec::FramedRead;

let config = Arc::new(Config::mainnet());
let protocol = ProtocolId::new(SupportedProtocol::PingV1, Encoding::SSZSnappy);
let fork_context = Arc::new(ForkContext::dummy::<Mainnet>(&config, Phase::Phase0));

let mut wire = BytesMut::new();
let mut outbound = SSZSnappyOutboundCodec::<Mainnet>::new(
protocol.clone(),
config.max_payload_size,
fork_context.clone(),
);
outbound
.encode(RequestType::Ping(Ping { data: 1 }), &mut wire)
.expect("Ping encodes");
wire.extend_from_slice(&[0xDEu8; 16]);

let cursor = std::io::Cursor::new(wire.freeze());
let codec = SSZSnappyInboundCodec::<Mainnet>::new(
config.clone_arc(),
protocol,
config.max_payload_size,
fork_context,
);
let mut framed = FramedRead::new(cursor, codec);

let first = framed.next().await;
assert!(
matches!(first, Some(Ok(RequestType::Ping(_)))),
"first item must decode as Ping, got {first:?}",
);

let second = framed.next().await;
assert!(
second.is_some(),
"trailing bytes after request must yield a second item, got None",
);
}

/// FramedRead-level: a clean request stream reaches EOF on the next poll, so
/// `upgrade_inbound`'s drain returns `Ok(None)` and accepts the request.
#[tokio::test]
async fn test_framed_read_reaches_eof_on_clean_request() {
use futures::StreamExt as _;
use tokio_util::codec::FramedRead;

let config = Arc::new(Config::mainnet());
let protocol = ProtocolId::new(SupportedProtocol::PingV1, Encoding::SSZSnappy);
let fork_context = Arc::new(ForkContext::dummy::<Mainnet>(&config, Phase::Phase0));

let mut wire = BytesMut::new();
let mut outbound = SSZSnappyOutboundCodec::<Mainnet>::new(
protocol.clone(),
config.max_payload_size,
fork_context.clone(),
);
outbound
.encode(RequestType::Ping(Ping { data: 1 }), &mut wire)
.expect("Ping encodes");

let cursor = std::io::Cursor::new(wire.freeze());
let codec = SSZSnappyInboundCodec::<Mainnet>::new(
config.clone_arc(),
protocol,
config.max_payload_size,
fork_context,
);
let mut framed = FramedRead::new(cursor, codec);

let first = framed.next().await;
assert!(matches!(first, Some(Ok(RequestType::Ping(_)))));

let second = framed.next().await;
assert!(
second.is_none(),
"clean request must reach EOF, got {second:?}",
);
}

/// Same as the Ping trailing-bytes test, but for a variable-size request
/// (`BlocksByRootV2`). Variable-size SSZ goes through a different decode branch
/// than fixed-size Ping; the drain check must work for both.
#[tokio::test]
async fn test_framed_read_yields_extra_chunk_when_trailing_bytes_blocks_by_root() {
use futures::StreamExt as _;
use tokio_util::codec::FramedRead;

let config = Arc::new(Config::mainnet());
let phase = Phase::Phase0;
let protocol = ProtocolId::new(SupportedProtocol::BlocksByRootV2, Encoding::SSZSnappy);
let fork_context = Arc::new(ForkContext::dummy::<Mainnet>(&config, phase));

let mut wire = BytesMut::new();
let mut outbound = SSZSnappyOutboundCodec::<Mainnet>::new(
protocol.clone(),
config.max_payload_size,
fork_context.clone(),
);
outbound
.encode(
RequestType::BlocksByRoot(bbroot_request_v2(&config, phase)),
&mut wire,
)
.expect("BlocksByRoot encodes");
wire.extend_from_slice(&[0xDEu8; 16]);

let cursor = std::io::Cursor::new(wire.freeze());
let codec = SSZSnappyInboundCodec::<Mainnet>::new(
config.clone_arc(),
protocol,
config.max_payload_size,
fork_context,
);
let mut framed = FramedRead::new(cursor, codec);

let first = framed.next().await;
assert!(
matches!(first, Some(Ok(RequestType::BlocksByRoot(_)))),
"first item must decode as BlocksByRoot, got {first:?}",
);

let second = framed.next().await;
assert!(
second.is_some(),
"trailing bytes after request must yield a second item, got None",
);
}
}
27 changes: 26 additions & 1 deletion src/rpc/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ const PROTOCOL_PREFIX: &str = "/eth2/beacon_chain/req";
/// The number of seconds to wait for the first bytes of a request once a protocol has been
/// established before the stream is terminated.
const REQUEST_TIMEOUT: u64 = 15;
/// After decoding a request, the peer is expected to have closed the write side of the stream.
/// Wait up to this duration for EOF so we can reject trailing bytes per
/// `phase0/p2p-interface.md` (`InvalidRequest` on any remaining bytes after the SSZ payload).
const REQUEST_TRAILING_BYTES_TIMEOUT: Duration = Duration::from_millis(500);

/// Returns the rpc limits for beacon_block_by_range and beacon_block_by_root responses.
///
Expand Down Expand Up @@ -633,7 +637,28 @@ where
.await
{
Err(e) => Err((versioned_protocol.protocol(), RPCError::from(e))),
Ok((Some(Ok(request)), stream)) => Ok((request, stream)),
Ok((Some(Ok(request)), mut stream)) => {
match tokio::time::timeout(
REQUEST_TRAILING_BYTES_TIMEOUT,
stream.next(),
)
.await
{
Ok(None) => Ok((request, stream)),
Ok(Some(_)) => Err((
versioned_protocol.protocol(),
RPCError::InvalidData(
"trailing bytes after request payload".into(),
),
)),
Err(_) => Err((
versioned_protocol.protocol(),
RPCError::InvalidData(
"request stream not closed after request payload".into(),
),
)),
}
}
Ok((Some(Err(e)), _)) => Err((versioned_protocol.protocol(), e)),
Ok((None, _)) => {
Err((versioned_protocol.protocol(), RPCError::IncompleteStream))
Expand Down