diff --git a/src/rpc/codec.rs b/src/rpc/codec.rs index ff96050..1a82563 100644 --- a/src/rpc/codec.rs +++ b/src/rpc/codec.rs @@ -2601,4 +2601,177 @@ mod tests { err ); } + + /// Trailing bytes after a complete request payload must remain in the source buffer. + /// `upgrade_inbound` then drains the stream and rejects them with `InvalidRequest` + /// per `phase0/p2p-interface.md`. + #[test] + fn test_inbound_codec_leaves_trailing_bytes_in_buffer() { + let config = Arc::new(Config::mainnet()); + let protocol = ProtocolId::new(SupportedProtocol::PingV1, Encoding::SSZSnappy); + let fork_context = Arc::new(ForkContext::dummy::(&config, Phase::Phase0)); + + let mut wire = BytesMut::new(); + let mut outbound = SSZSnappyOutboundCodec::::new( + protocol.clone(), + config.max_payload_size, + fork_context.clone(), + ); + outbound + .encode(RequestType::Ping(Ping { data: 1 }), &mut wire) + .expect("Ping encodes"); + + let trailing = [0xDEu8; 16]; + wire.extend_from_slice(&trailing); + + let mut inbound = SSZSnappyInboundCodec::::new( + config.clone_arc(), + protocol, + config.max_payload_size, + fork_context, + ); + let request = inbound + .decode(&mut wire) + .expect("decode succeeds") + .expect("returns a request"); + assert!(matches!(request, RequestType::Ping(_))); + assert_eq!( + wire.len(), + trailing.len(), + "codec must leave trailing bytes in source buffer for drain check", + ); + } + + /// FramedRead-level: trailing bytes after a clean request are detected on the next poll, + /// matching what `upgrade_inbound`'s post-decode drain observes. + #[tokio::test] + async fn test_framed_read_yields_extra_chunk_when_trailing_bytes() { + use futures::StreamExt as _; + use tokio_util::codec::FramedRead; + + let config = Arc::new(Config::mainnet()); + let protocol = ProtocolId::new(SupportedProtocol::PingV1, Encoding::SSZSnappy); + let fork_context = Arc::new(ForkContext::dummy::(&config, Phase::Phase0)); + + let mut wire = BytesMut::new(); + let mut outbound = SSZSnappyOutboundCodec::::new( + protocol.clone(), + config.max_payload_size, + fork_context.clone(), + ); + outbound + .encode(RequestType::Ping(Ping { data: 1 }), &mut wire) + .expect("Ping encodes"); + wire.extend_from_slice(&[0xDEu8; 16]); + + let cursor = std::io::Cursor::new(wire.freeze()); + let codec = SSZSnappyInboundCodec::::new( + config.clone_arc(), + protocol, + config.max_payload_size, + fork_context, + ); + let mut framed = FramedRead::new(cursor, codec); + + let first = framed.next().await; + assert!( + matches!(first, Some(Ok(RequestType::Ping(_)))), + "first item must decode as Ping, got {first:?}", + ); + + let second = framed.next().await; + assert!( + second.is_some(), + "trailing bytes after request must yield a second item, got None", + ); + } + + /// FramedRead-level: a clean request stream reaches EOF on the next poll, so + /// `upgrade_inbound`'s drain returns `Ok(None)` and accepts the request. + #[tokio::test] + async fn test_framed_read_reaches_eof_on_clean_request() { + use futures::StreamExt as _; + use tokio_util::codec::FramedRead; + + let config = Arc::new(Config::mainnet()); + let protocol = ProtocolId::new(SupportedProtocol::PingV1, Encoding::SSZSnappy); + let fork_context = Arc::new(ForkContext::dummy::(&config, Phase::Phase0)); + + let mut wire = BytesMut::new(); + let mut outbound = SSZSnappyOutboundCodec::::new( + protocol.clone(), + config.max_payload_size, + fork_context.clone(), + ); + outbound + .encode(RequestType::Ping(Ping { data: 1 }), &mut wire) + .expect("Ping encodes"); + + let cursor = std::io::Cursor::new(wire.freeze()); + let codec = SSZSnappyInboundCodec::::new( + config.clone_arc(), + protocol, + config.max_payload_size, + fork_context, + ); + let mut framed = FramedRead::new(cursor, codec); + + let first = framed.next().await; + assert!(matches!(first, Some(Ok(RequestType::Ping(_))))); + + let second = framed.next().await; + assert!( + second.is_none(), + "clean request must reach EOF, got {second:?}", + ); + } + + /// Same as the Ping trailing-bytes test, but for a variable-size request + /// (`BlocksByRootV2`). Variable-size SSZ goes through a different decode branch + /// than fixed-size Ping; the drain check must work for both. + #[tokio::test] + async fn test_framed_read_yields_extra_chunk_when_trailing_bytes_blocks_by_root() { + use futures::StreamExt as _; + use tokio_util::codec::FramedRead; + + let config = Arc::new(Config::mainnet()); + let phase = Phase::Phase0; + let protocol = ProtocolId::new(SupportedProtocol::BlocksByRootV2, Encoding::SSZSnappy); + let fork_context = Arc::new(ForkContext::dummy::(&config, phase)); + + let mut wire = BytesMut::new(); + let mut outbound = SSZSnappyOutboundCodec::::new( + protocol.clone(), + config.max_payload_size, + fork_context.clone(), + ); + outbound + .encode( + RequestType::BlocksByRoot(bbroot_request_v2(&config, phase)), + &mut wire, + ) + .expect("BlocksByRoot encodes"); + wire.extend_from_slice(&[0xDEu8; 16]); + + let cursor = std::io::Cursor::new(wire.freeze()); + let codec = SSZSnappyInboundCodec::::new( + config.clone_arc(), + protocol, + config.max_payload_size, + fork_context, + ); + let mut framed = FramedRead::new(cursor, codec); + + let first = framed.next().await; + assert!( + matches!(first, Some(Ok(RequestType::BlocksByRoot(_)))), + "first item must decode as BlocksByRoot, got {first:?}", + ); + + let second = framed.next().await; + assert!( + second.is_some(), + "trailing bytes after request must yield a second item, got None", + ); + } } diff --git a/src/rpc/protocol.rs b/src/rpc/protocol.rs index b8d97b0..63abffd 100644 --- a/src/rpc/protocol.rs +++ b/src/rpc/protocol.rs @@ -84,6 +84,10 @@ const PROTOCOL_PREFIX: &str = "/eth2/beacon_chain/req"; /// The number of seconds to wait for the first bytes of a request once a protocol has been /// established before the stream is terminated. const REQUEST_TIMEOUT: u64 = 15; +/// After decoding a request, the peer is expected to have closed the write side of the stream. +/// Wait up to this duration for EOF so we can reject trailing bytes per +/// `phase0/p2p-interface.md` (`InvalidRequest` on any remaining bytes after the SSZ payload). +const REQUEST_TRAILING_BYTES_TIMEOUT: Duration = Duration::from_millis(500); /// Returns the rpc limits for beacon_block_by_range and beacon_block_by_root responses. /// @@ -633,7 +637,28 @@ where .await { Err(e) => Err((versioned_protocol.protocol(), RPCError::from(e))), - Ok((Some(Ok(request)), stream)) => Ok((request, stream)), + Ok((Some(Ok(request)), mut stream)) => { + match tokio::time::timeout( + REQUEST_TRAILING_BYTES_TIMEOUT, + stream.next(), + ) + .await + { + Ok(None) => Ok((request, stream)), + Ok(Some(_)) => Err(( + versioned_protocol.protocol(), + RPCError::InvalidData( + "trailing bytes after request payload".into(), + ), + )), + Err(_) => Err(( + versioned_protocol.protocol(), + RPCError::InvalidData( + "request stream not closed after request payload".into(), + ), + )), + } + } Ok((Some(Err(e)), _)) => Err((versioned_protocol.protocol(), e)), Ok((None, _)) => { Err((versioned_protocol.protocol(), RPCError::IncompleteStream))