@@ -312,13 +312,25 @@ macro_rules! composite_custom_message_handler {
312312 }
313313
314314 fn peer_connected( & self , their_node_id: $crate:: bitcoin:: secp256k1:: PublicKey , msg: & $crate:: lightning:: ln:: msgs:: Init , inbound: bool ) -> Result <( ) , ( ) > {
315- let mut result = Ok ( ( ) ) ;
315+ // Per the `CustomMessageHandler::peer_connected` contract, `peer_disconnected`
316+ // will not be called by `PeerManager` if we return `Err`. To avoid leaking
317+ // per-peer state in sub-handlers that already returned `Ok` when a later one
318+ // errors, record each sub-handler's result and roll back the successful ones
319+ // ourselves before propagating the failure.
316320 $(
317- if let Err ( e) = self . $field. peer_connected( their_node_id, msg, inbound) {
318- result = Err ( e) ;
319- }
321+ let $field = self . $field. peer_connected( their_node_id, msg, inbound) ;
320322 ) *
321- result
323+ let any_err = false $( || $field. is_err( ) ) * ;
324+ if any_err {
325+ $(
326+ if $field. is_ok( ) {
327+ self . $field. peer_disconnected( their_node_id) ;
328+ }
329+ ) *
330+ Err ( ( ) )
331+ } else {
332+ Ok ( ( ) )
333+ }
322334 }
323335
324336 fn provided_node_features( & self ) -> $crate:: lightning:: types:: features:: NodeFeatures {
@@ -376,3 +388,143 @@ macro_rules! composite_custom_message_handler {
376388 }
377389 }
378390}
391+
392+ #[ cfg( test) ]
393+ mod tests {
394+ use bitcoin:: secp256k1:: PublicKey ;
395+ use core:: sync:: atomic:: { AtomicUsize , Ordering } ;
396+ use lightning:: io;
397+ use lightning:: ln:: msgs:: { DecodeError , Init , LightningError } ;
398+ use lightning:: ln:: peer_handler:: CustomMessageHandler ;
399+ use lightning:: ln:: wire:: { CustomMessageReader , Type } ;
400+ use lightning:: types:: features:: { InitFeatures , NodeFeatures } ;
401+ use lightning:: util:: ser:: { LengthLimitedRead , Writeable , Writer } ;
402+
403+ #[ derive( Debug ) ]
404+ pub struct Foo ;
405+ impl Type for Foo {
406+ fn type_id ( & self ) -> u16 {
407+ 32768
408+ }
409+ }
410+ impl Writeable for Foo {
411+ fn write < W : Writer > ( & self , _: & mut W ) -> Result < ( ) , io:: Error > {
412+ Ok ( ( ) )
413+ }
414+ }
415+
416+ pub struct CountingHandler {
417+ pub connect_count : AtomicUsize ,
418+ }
419+ impl CustomMessageReader for CountingHandler {
420+ type CustomMessage = Foo ;
421+ fn read < R : LengthLimitedRead > (
422+ & self , _t : u16 , _b : & mut R ,
423+ ) -> Result < Option < Foo > , DecodeError > {
424+ Ok ( None )
425+ }
426+ }
427+ impl CustomMessageHandler for CountingHandler {
428+ fn handle_custom_message ( & self , _msg : Foo , _: PublicKey ) -> Result < ( ) , LightningError > {
429+ Ok ( ( ) )
430+ }
431+ fn get_and_clear_pending_msg ( & self ) -> Vec < ( PublicKey , Foo ) > {
432+ vec ! [ ]
433+ }
434+ fn peer_disconnected ( & self , _: PublicKey ) {
435+ self . connect_count . fetch_sub ( 1 , Ordering :: SeqCst ) ;
436+ }
437+ fn peer_connected ( & self , _: PublicKey , _: & Init , _: bool ) -> Result < ( ) , ( ) > {
438+ self . connect_count . fetch_add ( 1 , Ordering :: SeqCst ) ;
439+ Ok ( ( ) )
440+ }
441+ fn provided_node_features ( & self ) -> NodeFeatures {
442+ NodeFeatures :: empty ( )
443+ }
444+ fn provided_init_features ( & self , _: PublicKey ) -> InitFeatures {
445+ InitFeatures :: empty ( )
446+ }
447+ }
448+
449+ #[ derive( Debug ) ]
450+ pub struct Bar ;
451+ impl Type for Bar {
452+ fn type_id ( & self ) -> u16 {
453+ 32769
454+ }
455+ }
456+ impl Writeable for Bar {
457+ fn write < W : Writer > ( & self , _: & mut W ) -> Result < ( ) , io:: Error > {
458+ Ok ( ( ) )
459+ }
460+ }
461+
462+ pub struct ErroringHandler ;
463+ impl CustomMessageReader for ErroringHandler {
464+ type CustomMessage = Bar ;
465+ fn read < R : LengthLimitedRead > (
466+ & self , _t : u16 , _b : & mut R ,
467+ ) -> Result < Option < Bar > , DecodeError > {
468+ Ok ( None )
469+ }
470+ }
471+ impl CustomMessageHandler for ErroringHandler {
472+ fn handle_custom_message ( & self , _msg : Bar , _: PublicKey ) -> Result < ( ) , LightningError > {
473+ Ok ( ( ) )
474+ }
475+ fn get_and_clear_pending_msg ( & self ) -> Vec < ( PublicKey , Bar ) > {
476+ vec ! [ ]
477+ }
478+ fn peer_disconnected ( & self , _: PublicKey ) {
479+ debug_assert ! ( false ) ;
480+ }
481+ fn peer_connected ( & self , _: PublicKey , _: & Init , _: bool ) -> Result < ( ) , ( ) > {
482+ Err ( ( ) )
483+ }
484+ fn provided_node_features ( & self ) -> NodeFeatures {
485+ NodeFeatures :: empty ( )
486+ }
487+ fn provided_init_features ( & self , _: PublicKey ) -> InitFeatures {
488+ InitFeatures :: empty ( )
489+ }
490+ }
491+
492+ composite_custom_message_handler ! (
493+ pub struct CompositeHandler {
494+ counting: CountingHandler ,
495+ erroring: ErroringHandler ,
496+ }
497+
498+ pub enum CompositeMessage {
499+ Foo ( 32768 ) ,
500+ Bar ( 32769 ) ,
501+ }
502+ ) ;
503+
504+ #[ test]
505+ fn peer_connected_failure_does_not_leak_subhandler_state ( ) {
506+ let composite = CompositeHandler {
507+ counting : CountingHandler { connect_count : AtomicUsize :: new ( 0 ) } ,
508+ erroring : ErroringHandler ,
509+ } ;
510+ let pk_bytes = [
511+ 0x02 , 0x79 , 0xBE , 0x66 , 0x7E , 0xF9 , 0xDC , 0xBB , 0xAC , 0x55 , 0xA0 , 0x62 , 0x95 , 0xCE ,
512+ 0x87 , 0x0B , 0x07 , 0x02 , 0x9B , 0xFC , 0xDB , 0x2D , 0xCE , 0x28 , 0xD9 , 0x59 , 0xF2 , 0x81 ,
513+ 0x5B , 0x16 , 0xF8 , 0x17 , 0x98 ,
514+ ] ;
515+ let pk = PublicKey :: from_slice ( & pk_bytes) . unwrap ( ) ;
516+ let init =
517+ Init { features : InitFeatures :: empty ( ) , networks : None , remote_network_address : None } ;
518+
519+ let result = composite. peer_connected ( pk, & init, true ) ;
520+ assert ! ( result. is_err( ) , "Composite must propagate the inner Err" ) ;
521+
522+ let leaked = composite. counting . connect_count . load ( Ordering :: SeqCst ) ;
523+ assert_eq ! (
524+ leaked, 0 ,
525+ "CountingHandler tracked {leaked} connected peer(s) after the composite \
526+ returned Err; this state will never be cleaned up because per the trait \
527+ contract peer_disconnected won't be called when peer_connected returns Err.",
528+ ) ;
529+ }
530+ }
0 commit comments