@@ -17,14 +17,11 @@ use std::{
1717} ;
1818
1919use chrono:: { DateTime , Utc } ;
20- use futures:: prelude:: * ;
2120use libp2p:: { PeerId , swarm:: Stream } ;
2221use pluto_core:: version:: { self , SemVer , SemVerError } ;
23- use prost:: Message ;
2422use regex:: Regex ;
2523use tokio:: sync:: Mutex ;
2624use tracing:: { info, warn} ;
27- use unsigned_varint:: aio:: read_usize;
2825
2926use crate :: {
3027 LocalPeerInfo ,
@@ -51,57 +48,6 @@ pub struct ProtocolState {
5148 local_info : LocalPeerInfo ,
5249}
5350
54- /// Writes a protobuf message with unsigned varint length prefix to the stream.
55- ///
56- /// Wire format: `[uvarint length][protobuf bytes]`
57- async fn write_protobuf < M : Message , S : AsyncWrite + Unpin > (
58- stream : & mut S ,
59- msg : & M ,
60- ) -> io:: Result < ( ) > {
61- // Encode message to protobuf bytes
62- let mut buf = Vec :: with_capacity ( msg. encoded_len ( ) ) ;
63- msg. encode ( & mut buf)
64- . map_err ( |e| io:: Error :: new ( io:: ErrorKind :: InvalidData , e) ) ?;
65-
66- // Write unsigned varint length prefix
67- let mut len_buf = unsigned_varint:: encode:: usize_buffer ( ) ;
68- let encoded_len = unsigned_varint:: encode:: usize ( buf. len ( ) , & mut len_buf) ;
69- stream. write_all ( encoded_len) . await ?;
70-
71- // Write protobuf bytes
72- stream. write_all ( & buf) . await ?;
73- stream. flush ( ) . await
74- }
75-
76- /// Reads a protobuf message with unsigned varint length prefix from the stream.
77- ///
78- /// Wire format: `[uvarint length][protobuf bytes]`
79- ///
80- /// Returns an error if the message exceeds `MAX_MESSAGE_SIZE`.
81- async fn read_protobuf < M : Message + Default , S : AsyncRead + Unpin > (
82- stream : & mut S ,
83- ) -> io:: Result < M > {
84- // Read unsigned varint length prefix
85- let msg_len = read_usize ( & mut * stream) . await . map_err ( |e| match e {
86- unsigned_varint:: io:: ReadError :: Io ( io_err) => io_err,
87- other => io:: Error :: new ( io:: ErrorKind :: InvalidData , other) ,
88- } ) ?;
89-
90- if msg_len > MAX_MESSAGE_SIZE {
91- return Err ( io:: Error :: new (
92- io:: ErrorKind :: InvalidData ,
93- format ! ( "message too large: {msg_len} bytes (max: {MAX_MESSAGE_SIZE})" ) ,
94- ) ) ;
95- }
96-
97- // Read exactly `msg_len` protobuf bytes
98- let mut buf = vec ! [ 0u8 ; msg_len] ;
99- stream. read_exact ( & mut buf) . await ?;
100-
101- // Unmarshal protobuf
102- M :: decode ( & buf[ ..] ) . map_err ( |e| io:: Error :: new ( io:: ErrorKind :: InvalidData , e) )
103- }
104-
10551/// Errors that can occur during the protocol.
10652#[ derive( Debug , thiserror:: Error ) ]
10753pub enum ProtocolError {
@@ -317,8 +263,9 @@ impl ProtocolState {
317263 request : & PeerInfo ,
318264 ) -> io:: Result < ( Stream , PeerInfo ) > {
319265 let start = Instant :: now ( ) ;
320- write_protobuf ( & mut stream, request) . await ?;
321- let response = read_protobuf ( & mut stream) . await ?;
266+ pluto_p2p:: proto:: write_protobuf ( & mut stream, request) . await ?;
267+ let response =
268+ pluto_p2p:: proto:: read_protobuf_with_max_size ( & mut stream, MAX_MESSAGE_SIZE ) . await ?;
322269 let rtt = start. elapsed ( ) ;
323270
324271 self . validate_peer_info ( & response, rtt) . await ;
@@ -334,8 +281,9 @@ impl ProtocolState {
334281 mut stream : Stream ,
335282 local_info : & PeerInfo ,
336283 ) -> io:: Result < ( Stream , PeerInfo ) > {
337- let request = read_protobuf ( & mut stream) . await ?;
338- write_protobuf ( & mut stream, local_info) . await ?;
284+ let request =
285+ pluto_p2p:: proto:: read_protobuf_with_max_size ( & mut stream, MAX_MESSAGE_SIZE ) . await ?;
286+ pluto_p2p:: proto:: write_protobuf ( & mut stream, local_info) . await ?;
339287 Ok ( ( stream, request) )
340288 }
341289}
@@ -344,6 +292,7 @@ impl ProtocolState {
344292mod tests {
345293 use super :: * ;
346294 use hex_literal:: hex;
295+ use prost:: Message ;
347296
348297 // Test case: minimal
349298 // CharonVersion: "v1.0.0"
@@ -571,7 +520,9 @@ mod tests {
571520
572521 // Write to a cursor
573522 let mut buf = Vec :: new ( ) ;
574- write_protobuf ( & mut buf, & original) . await . unwrap ( ) ;
523+ pluto_p2p:: proto:: write_protobuf ( & mut buf, & original)
524+ . await
525+ . unwrap ( ) ;
575526
576527 // The wire format should be: [varint length][protobuf bytes]
577528 // Minimal message is 14 bytes, so length prefix is just 1 byte (14 < 128)
@@ -580,7 +531,7 @@ mod tests {
580531
581532 // Read it back
582533 let mut cursor = futures:: io:: Cursor :: new ( & buf[ ..] ) ;
583- let decoded: PeerInfo = read_protobuf ( & mut cursor) . await . unwrap ( ) ;
534+ let decoded: PeerInfo = pluto_p2p :: proto :: read_protobuf ( & mut cursor) . await . unwrap ( ) ;
584535 assert_eq ! ( original, decoded) ;
585536 }
586537
@@ -589,11 +540,13 @@ mod tests {
589540 let original = make_full_peerinfo ( ) ;
590541
591542 let mut buf = Vec :: new ( ) ;
592- write_protobuf ( & mut buf, & original) . await . unwrap ( ) ;
543+ pluto_p2p:: proto:: write_protobuf ( & mut buf, & original)
544+ . await
545+ . unwrap ( ) ;
593546
594547 // Read it back
595548 let mut cursor = futures:: io:: Cursor :: new ( & buf[ ..] ) ;
596- let decoded: PeerInfo = read_protobuf ( & mut cursor) . await . unwrap ( ) ;
549+ let decoded: PeerInfo = pluto_p2p :: proto :: read_protobuf ( & mut cursor) . await . unwrap ( ) ;
597550 assert_eq ! ( original, decoded) ;
598551 }
599552
@@ -609,10 +562,12 @@ mod tests {
609562
610563 for original in variants {
611564 let mut buf = Vec :: new ( ) ;
612- write_protobuf ( & mut buf, & original) . await . unwrap ( ) ;
565+ pluto_p2p:: proto:: write_protobuf ( & mut buf, & original)
566+ . await
567+ . unwrap ( ) ;
613568
614569 let mut cursor = futures:: io:: Cursor :: new ( & buf[ ..] ) ;
615- let decoded: PeerInfo = read_protobuf ( & mut cursor) . await . unwrap ( ) ;
570+ let decoded: PeerInfo = pluto_p2p :: proto :: read_protobuf ( & mut cursor) . await . unwrap ( ) ;
616571 assert_eq ! ( original, decoded) ;
617572 }
618573 }
@@ -627,7 +582,8 @@ mod tests {
627582 buf. extend_from_slice ( encoded_len) ;
628583
629584 let mut cursor = futures:: io:: Cursor :: new ( & buf[ ..] ) ;
630- let result: io:: Result < PeerInfo > = read_protobuf ( & mut cursor) . await ;
585+ let result: io:: Result < PeerInfo > =
586+ pluto_p2p:: proto:: read_protobuf_with_max_size ( & mut cursor, MAX_MESSAGE_SIZE ) . await ;
631587
632588 assert ! ( result. is_err( ) ) ;
633589 let err = result. unwrap_err ( ) ;
@@ -641,7 +597,7 @@ mod tests {
641597 let invalid_data = [ 0x05 , 0xff , 0xff , 0xff , 0xff , 0xff ] ; // length 5, then garbage
642598
643599 let mut cursor = futures:: io:: Cursor :: new ( & invalid_data[ ..] ) ;
644- let result: io:: Result < PeerInfo > = read_protobuf ( & mut cursor) . await ;
600+ let result: io:: Result < PeerInfo > = pluto_p2p :: proto :: read_protobuf ( & mut cursor) . await ;
645601
646602 assert ! ( result. is_err( ) ) ;
647603 assert_eq ! ( result. unwrap_err( ) . kind( ) , io:: ErrorKind :: InvalidData ) ;
@@ -653,7 +609,7 @@ mod tests {
653609 let truncated = [ 0x10 ] ; // claims 16 bytes but has none
654610
655611 let mut cursor = futures:: io:: Cursor :: new ( & truncated[ ..] ) ;
656- let result: io:: Result < PeerInfo > = read_protobuf ( & mut cursor) . await ;
612+ let result: io:: Result < PeerInfo > = pluto_p2p :: proto :: read_protobuf ( & mut cursor) . await ;
657613
658614 assert ! ( result. is_err( ) ) ;
659615 assert_eq ! ( result. unwrap_err( ) . kind( ) , io:: ErrorKind :: UnexpectedEof ) ;
@@ -667,15 +623,21 @@ mod tests {
667623
668624 // Write multiple messages to the same buffer
669625 let mut buf = Vec :: new ( ) ;
670- write_protobuf ( & mut buf, & msg1) . await . unwrap ( ) ;
671- write_protobuf ( & mut buf, & msg2) . await . unwrap ( ) ;
672- write_protobuf ( & mut buf, & msg3) . await . unwrap ( ) ;
626+ pluto_p2p:: proto:: write_protobuf ( & mut buf, & msg1)
627+ . await
628+ . unwrap ( ) ;
629+ pluto_p2p:: proto:: write_protobuf ( & mut buf, & msg2)
630+ . await
631+ . unwrap ( ) ;
632+ pluto_p2p:: proto:: write_protobuf ( & mut buf, & msg3)
633+ . await
634+ . unwrap ( ) ;
673635
674636 // Read them back in order
675637 let mut cursor = futures:: io:: Cursor :: new ( & buf[ ..] ) ;
676- let decoded1: PeerInfo = read_protobuf ( & mut cursor) . await . unwrap ( ) ;
677- let decoded2: PeerInfo = read_protobuf ( & mut cursor) . await . unwrap ( ) ;
678- let decoded3: PeerInfo = read_protobuf ( & mut cursor) . await . unwrap ( ) ;
638+ let decoded1: PeerInfo = pluto_p2p :: proto :: read_protobuf ( & mut cursor) . await . unwrap ( ) ;
639+ let decoded2: PeerInfo = pluto_p2p :: proto :: read_protobuf ( & mut cursor) . await . unwrap ( ) ;
640+ let decoded3: PeerInfo = pluto_p2p :: proto :: read_protobuf ( & mut cursor) . await . unwrap ( ) ;
679641
680642 assert_eq ! ( msg1, decoded1) ;
681643 assert_eq ! ( msg2, decoded2) ;
0 commit comments